Implement optional link-time cross stage optimization

This commit is contained in:
Daniel Story 2024-09-12 12:18:09 -07:00 committed by arcady-lunarg
parent 9d00d6d6ca
commit 05559a2963
6 changed files with 124 additions and 2 deletions

View file

@ -109,6 +109,7 @@ enum TOptions : uint64_t {
EOptionDumpBareVersion = (1ull << 31),
EOptionCompileOnly = (1ull << 32),
EOptionDisplayErrorColumn = (1ull << 33),
EOptionLinkTimeOptimization = (1ull << 34),
};
bool targetHlslFunctionality1 = false;
bool SpvToolsDisassembler = false;
@ -899,6 +900,8 @@ void ProcessArguments(std::vector<std::unique_ptr<glslang::TWorkItem>>& workItem
Options |= EOptionCompileOnly;
} else if (lowerword == "error-column") {
Options |= EOptionDisplayErrorColumn;
} else if (lowerword == "lto") {
Options |= EOptionLinkTimeOptimization;
} else if (lowerword == "help") {
usage();
break;
@ -1083,6 +1086,10 @@ void ProcessArguments(std::vector<std::unique_ptr<glslang::TWorkItem>>& workItem
if ((Options & EOptionDumpReflection) && !(Options & EOptionLinkProgram))
Error("reflection requires -l for linking");
// link time optimization makes no sense unless linking
if ((Options & EOptionLinkTimeOptimization) && !(Options & EOptionLinkProgram))
Error("link time optimization requires -l for linking");
// -o or -x makes no sense if there is no target binary
if (binaryFileName && (Options & EOptionSpv) == 0)
Error("no binary generation requested (e.g., -V)");
@ -1167,6 +1174,8 @@ void SetMessageOptions(EShMessages& messages)
messages = (EShMessages)(messages | EShMsgAbsolutePath);
if (Options & EOptionDisplayErrorColumn)
messages = (EShMessages)(messages | EShMsgDisplayErrorColumn);
if (Options & EOptionLinkTimeOptimization)
messages = (EShMessages)(messages | EShMsgLinkTimeOptimization);
}
//
@ -2135,7 +2144,8 @@ void usage()
" initialized with the shader binary code\n"
" --no-link Only compile shader; do not link (GLSL-only)\n"
" NOTE: this option will set the export linkage\n"
" attribute on all functions\n");
" attribute on all functions\n"
" --lto perform link time optimization\n");
exit(EFailUsage);
}

View file

@ -176,6 +176,7 @@ typedef enum {
GLSLANG_MSG_ENHANCED = (1 << 15),
GLSLANG_MSG_ABSOLUTE_PATH = (1 << 16),
GLSLANG_MSG_DISPLAY_ERROR_COLUMN = (1 << 17),
GLSLANG_MSG_LINK_TIME_OPTIMIZATION_BIT = (1 << 18),
LAST_ELEMENT_MARKER(GLSLANG_MSG_COUNT),
} glslang_messages_t;

View file

@ -2055,7 +2055,7 @@ bool TProgram::linkStage(EShLanguage stage, EShMessages messages)
//
// Return true if no errors.
//
bool TProgram::crossStageCheck(EShMessages) {
bool TProgram::crossStageCheck(EShMessages messages) {
// make temporary intermediates to hold the linkage symbols for each linking interface
// while we do the checks
@ -2110,6 +2110,13 @@ bool TProgram::crossStageCheck(EShMessages) {
error |= (activeStages[i - 1]->getNumErrors() != 0);
}
// if requested, optimize cross stage IO
if (messages & EShMsgLinkTimeOptimization) {
for (unsigned int i = 1; i < activeStages.size(); ++i) {
activeStages[i - 1]->optimizeStageIO(*infoSink, *activeStages[i]);
}
}
return !error;
}

View file

@ -49,6 +49,7 @@
#include "localintermediate.h"
#include "../Include/InfoSink.h"
#include "SymbolTable.h"
#include "LiveTraverser.h"
namespace glslang {
@ -187,6 +188,107 @@ void TIntermediate::checkStageIO(TInfoSink& infoSink, TIntermediate& unit) {
}
}
void TIntermediate::optimizeStageIO(TInfoSink&, TIntermediate& unit)
{
// don't do any input/output demotion on compute, raytracing, or task/mesh stages
// TODO: support task/mesh
if (getStage() > EShLangFragment || unit.getStage() > EShLangFragment) {
return;
}
class TIOTraverser : public TLiveTraverser {
public:
TIOTraverser(TIntermediate& i, bool all, TIntermSequence& sequence, TStorageQualifier storage)
: TLiveTraverser(i, all, true, false, false), sequence(sequence), storage(storage)
{
}
virtual void visitSymbol(TIntermSymbol* symbol)
{
if (symbol->getQualifier().storage == storage) {
sequence.push_back(symbol);
}
}
private:
TIntermSequence& sequence;
TStorageQualifier storage;
};
// live symbols only
TIntermSequence unitLiveInputs;
TIOTraverser unitTraverser(unit, false, unitLiveInputs, EvqVaryingIn);
unitTraverser.pushFunction(unit.getEntryPointMangledName().c_str());
while (! unitTraverser.destinations.empty()) {
TIntermNode* destination = unitTraverser.destinations.back();
unitTraverser.destinations.pop_back();
destination->traverse(&unitTraverser);
}
TIntermSequence allOutputs;
TIntermSequence unitAllInputs;
TIOTraverser allTraverser(*this, true, allOutputs, EvqVaryingOut);
getTreeRoot()->traverse(&allTraverser);
TIOTraverser unitAllTraverser(unit, true, unitAllInputs, EvqVaryingIn);
unit.getTreeRoot()->traverse(&unitAllTraverser);
// find outputs not consumed by the next stage
std::for_each(allOutputs.begin(), allOutputs.end(), [&unitLiveInputs, &unitAllInputs](TIntermNode* output) {
// don't do anything to builtins
if (output->getAsSymbolNode()->getAccessName().compare(0, 3, "gl_") == 0)
return;
// don't demote block outputs (for now)
if (output->getAsSymbolNode()->getBasicType() == EbtBlock)
return;
// check if the (loose) output has a matching loose input
auto isMatchingInput = [output](TIntermNode* input) {
return output->getAsSymbolNode()->getAccessName() == input->getAsSymbolNode()->getAccessName();
};
// check if the (loose) output has a matching block member input
auto isMatchingInputBlockMember = [output](TIntermNode* input) {
// ignore loose inputs
if (input->getAsSymbolNode()->getBasicType() != EbtBlock)
return false;
// don't demote loose outputs with matching input block members
auto isMatchingBlockMember = [output](TTypeLoc type) {
return type.type->getFieldName() == output->getAsSymbolNode()->getName();
};
const TTypeList* members = input->getAsSymbolNode()->getType().getStruct();
return std::any_of(members->begin(), members->end(), isMatchingBlockMember);
};
// determine if the input/output pair should be demoted
// do the faster (and more likely) loose-loose check first
if (std::none_of(unitLiveInputs.begin(), unitLiveInputs.end(), isMatchingInput) &&
std::none_of(unitAllInputs.begin(), unitAllInputs.end(), isMatchingInputBlockMember)) {
// demote any input matching the output
auto demoteMatchingInputs = [output](TIntermNode* input) {
if (output->getAsSymbolNode()->getAccessName() == input->getAsSymbolNode()->getAccessName()) {
// demote input to a plain variable
TIntermSymbol* symbol = input->getAsSymbolNode();
symbol->getQualifier().storage = EvqGlobal;
symbol->getQualifier().clearInterstage();
symbol->getQualifier().clearLayout();
}
};
// demote all matching outputs to a plain variable
TIntermSymbol* symbol = output->getAsSymbolNode();
symbol->getQualifier().storage = EvqGlobal;
symbol->getQualifier().clearInterstage();
symbol->getQualifier().clearLayout();
std::for_each(unitAllInputs.begin(), unitAllInputs.end(), demoteMatchingInputs);
}
});
}
void TIntermediate::mergeCallGraphs(TInfoSink& infoSink, TIntermediate& unit)
{
if (unit.getNumEntryPoints() > 0) {

View file

@ -1051,6 +1051,7 @@ public:
void mergeGlobalUniformBlocks(TInfoSink& infoSink, TIntermediate& unit, bool mergeExistingOnly);
void mergeUniformObjects(TInfoSink& infoSink, TIntermediate& unit);
void checkStageIO(TInfoSink&, TIntermediate&);
void optimizeStageIO(TInfoSink&, TIntermediate&);
bool buildConvertOp(TBasicType dst, TBasicType src, TOperator& convertOp) const;
TIntermTyped* createConversion(TBasicType convertTo, TIntermTyped* node) const;

View file

@ -271,6 +271,7 @@ enum EShMessages : unsigned {
EShMsgEnhanced = (1 << 15), // enhanced message readability
EShMsgAbsolutePath = (1 << 16), // Output Absolute path for messages
EShMsgDisplayErrorColumn = (1 << 17), // Display error message column aswell as line
EShMsgLinkTimeOptimization = (1 << 18), // perform cross-stage optimizations during linking
LAST_ELEMENT_MARKER(EShMsgCount),
};