From 05559a2963bde12f47882b63ea3820d0b6ff88e5 Mon Sep 17 00:00:00 2001 From: Daniel Story Date: Thu, 12 Sep 2024 12:18:09 -0700 Subject: [PATCH] Implement optional link-time cross stage optimization --- StandAlone/StandAlone.cpp | 12 ++- glslang/Include/glslang_c_shader_types.h | 1 + glslang/MachineIndependent/ShaderLang.cpp | 9 +- glslang/MachineIndependent/linkValidate.cpp | 102 ++++++++++++++++++ .../MachineIndependent/localintermediate.h | 1 + glslang/Public/ShaderLang.h | 1 + 6 files changed, 124 insertions(+), 2 deletions(-) diff --git a/StandAlone/StandAlone.cpp b/StandAlone/StandAlone.cpp index 3288b887..e7153b7a 100644 --- a/StandAlone/StandAlone.cpp +++ b/StandAlone/StandAlone.cpp @@ -109,6 +109,7 @@ enum TOptions : uint64_t { EOptionDumpBareVersion = (1ull << 31), EOptionCompileOnly = (1ull << 32), EOptionDisplayErrorColumn = (1ull << 33), + EOptionLinkTimeOptimization = (1ull << 34), }; bool targetHlslFunctionality1 = false; bool SpvToolsDisassembler = false; @@ -899,6 +900,8 @@ void ProcessArguments(std::vector>& workItem Options |= EOptionCompileOnly; } else if (lowerword == "error-column") { Options |= EOptionDisplayErrorColumn; + } else if (lowerword == "lto") { + Options |= EOptionLinkTimeOptimization; } else if (lowerword == "help") { usage(); break; @@ -1083,6 +1086,10 @@ void ProcessArguments(std::vector>& workItem if ((Options & EOptionDumpReflection) && !(Options & EOptionLinkProgram)) Error("reflection requires -l for linking"); + // link time optimization makes no sense unless linking + if ((Options & EOptionLinkTimeOptimization) && !(Options & EOptionLinkProgram)) + Error("link time optimization requires -l for linking"); + // -o or -x makes no sense if there is no target binary if (binaryFileName && (Options & EOptionSpv) == 0) Error("no binary generation requested (e.g., -V)"); @@ -1167,6 +1174,8 @@ void SetMessageOptions(EShMessages& messages) messages = (EShMessages)(messages | EShMsgAbsolutePath); if (Options & 
EOptionDisplayErrorColumn) messages = (EShMessages)(messages | EShMsgDisplayErrorColumn); + if (Options & EOptionLinkTimeOptimization) + messages = (EShMessages)(messages | EShMsgLinkTimeOptimization); } // @@ -2135,7 +2144,8 @@ void usage() " initialized with the shader binary code\n" " --no-link Only compile shader; do not link (GLSL-only)\n" " NOTE: this option will set the export linkage\n" - " attribute on all functions\n"); + " attribute on all functions\n" + " --lto perform link time optimization\n"); exit(EFailUsage); } diff --git a/glslang/Include/glslang_c_shader_types.h b/glslang/Include/glslang_c_shader_types.h index 7bb0ccda..768e2e84 100644 --- a/glslang/Include/glslang_c_shader_types.h +++ b/glslang/Include/glslang_c_shader_types.h @@ -176,6 +176,7 @@ typedef enum { GLSLANG_MSG_ENHANCED = (1 << 15), GLSLANG_MSG_ABSOLUTE_PATH = (1 << 16), GLSLANG_MSG_DISPLAY_ERROR_COLUMN = (1 << 17), + GLSLANG_MSG_LINK_TIME_OPTIMIZATION_BIT = (1 << 18), LAST_ELEMENT_MARKER(GLSLANG_MSG_COUNT), } glslang_messages_t; diff --git a/glslang/MachineIndependent/ShaderLang.cpp b/glslang/MachineIndependent/ShaderLang.cpp index 367388f9..040b21da 100644 --- a/glslang/MachineIndependent/ShaderLang.cpp +++ b/glslang/MachineIndependent/ShaderLang.cpp @@ -2055,7 +2055,7 @@ bool TProgram::linkStage(EShLanguage stage, EShMessages messages) // // Return true if no errors. 
// -bool TProgram::crossStageCheck(EShMessages) { +bool TProgram::crossStageCheck(EShMessages messages) { // make temporary intermediates to hold the linkage symbols for each linking interface // while we do the checks @@ -2110,6 +2110,13 @@ bool TProgram::crossStageCheck(EShMessages) { error |= (activeStages[i - 1]->getNumErrors() != 0); } + // if requested, optimize cross stage IO + if (messages & EShMsgLinkTimeOptimization) { + for (unsigned int i = 1; i < activeStages.size(); ++i) { + activeStages[i - 1]->optimizeStageIO(*infoSink, *activeStages[i]); + } + } + return !error; } diff --git a/glslang/MachineIndependent/linkValidate.cpp b/glslang/MachineIndependent/linkValidate.cpp index 182a6775..b0c27e84 100644 --- a/glslang/MachineIndependent/linkValidate.cpp +++ b/glslang/MachineIndependent/linkValidate.cpp @@ -49,6 +49,7 @@ #include "localintermediate.h" #include "../Include/InfoSink.h" #include "SymbolTable.h" +#include "LiveTraverser.h" namespace glslang { @@ -187,6 +188,107 @@ void TIntermediate::checkStageIO(TInfoSink& infoSink, TIntermediate& unit) { } } +void TIntermediate::optimizeStageIO(TInfoSink&, TIntermediate& unit) /* Link-time IO pruning: demotes outputs of this stage that 'unit' does not consume, together with the same-named inputs in 'unit', to plain globals. The TInfoSink argument is unused. */ +{ + // don't do any input/output demotion on compute, raytracing, or task/mesh stages (any stage enumerated after EShLangFragment) + // TODO: support task/mesh + if (getStage() > EShLangFragment || unit.getStage() > EShLangFragment) { + return; + } + + class TIOTraverser : public TLiveTraverser { /* records every visited symbol node whose storage qualifier equals 'storage' (one entry per visit) */ + public: + TIOTraverser(TIntermediate& i, bool all, TIntermSequence& sequence, TStorageQualifier storage) + : TLiveTraverser(i, all, true, false, false), sequence(sequence), storage(storage) + { + } + + virtual void visitSymbol(TIntermSymbol* symbol) + { + if (symbol->getQualifier().storage == storage) { + sequence.push_back(symbol); + } + } + + private: + TIntermSequence& sequence; + TStorageQualifier storage; + }; + + // live symbols only: inputs of 'unit' gathered by traversing from its entry point with the 'all' flag off + TIntermSequence unitLiveInputs; + + TIOTraverser unitTraverser(unit, false, unitLiveInputs, EvqVaryingIn); 
unitTraverser.pushFunction(unit.getEntryPointMangledName().c_str()); + while (! unitTraverser.destinations.empty()) { + TIntermNode* destination = unitTraverser.destinations.back(); + unitTraverser.destinations.pop_back(); + destination->traverse(&unitTraverser); + } + + TIntermSequence allOutputs; /* output symbol nodes gathered from this stage's tree root with the 'all' flag on */ + TIntermSequence unitAllInputs; /* input symbol nodes gathered from unit's tree root with the 'all' flag on */ + + TIOTraverser allTraverser(*this, true, allOutputs, EvqVaryingOut); + getTreeRoot()->traverse(&allTraverser); + + TIOTraverser unitAllTraverser(unit, true, unitAllInputs, EvqVaryingIn); + unit.getTreeRoot()->traverse(&unitAllTraverser); + + // find outputs not consumed by the next stage + std::for_each(allOutputs.begin(), allOutputs.end(), [&unitLiveInputs, &unitAllInputs](TIntermNode* output) { + // don't do anything to builtins (access name begins with "gl_") + if (output->getAsSymbolNode()->getAccessName().compare(0, 3, "gl_") == 0) + return; + + // don't demote block outputs (for now) + if (output->getAsSymbolNode()->getBasicType() == EbtBlock) + return; + + // check if the (loose) output has a matching loose input (equal access names) + auto isMatchingInput = [output](TIntermNode* input) { + return output->getAsSymbolNode()->getAccessName() == input->getAsSymbolNode()->getAccessName(); + }; + + // check if the (loose) output has a matching block member input (a field of an input block named like the output) + auto isMatchingInputBlockMember = [output](TIntermNode* input) { + // ignore loose inputs + if (input->getAsSymbolNode()->getBasicType() != EbtBlock) + return false; + + // don't demote loose outputs with matching input block members; NOTE(review): compares the field name with getName(), while the loose check uses getAccessName() - confirm this is intended + auto isMatchingBlockMember = [output](TTypeLoc type) { + return type.type->getFieldName() == output->getAsSymbolNode()->getName(); + }; + const TTypeList* members = input->getAsSymbolNode()->getType().getStruct(); + return std::any_of(members->begin(), members->end(), isMatchingBlockMember); + }; + + // determine if the input/output pair should be demoted: no live input may match and no input block may contain a matching member + // do the faster (and more likely) loose-loose check first + if (std::none_of(unitLiveInputs.begin(), unitLiveInputs.end(), isMatchingInput) 
&& + std::none_of(unitAllInputs.begin(), unitAllInputs.end(), isMatchingInputBlockMember)) { + // demote any input matching the output (searched in unitAllInputs, not only the live ones) + auto demoteMatchingInputs = [output](TIntermNode* input) { + if (output->getAsSymbolNode()->getAccessName() == input->getAsSymbolNode()->getAccessName()) { + // demote input to a plain variable: global storage, interstage and layout qualifiers cleared + TIntermSymbol* symbol = input->getAsSymbolNode(); + symbol->getQualifier().storage = EvqGlobal; + symbol->getQualifier().clearInterstage(); + symbol->getQualifier().clearLayout(); + } + }; + + // demote this output node to a plain variable, then every same-named input + TIntermSymbol* symbol = output->getAsSymbolNode(); + symbol->getQualifier().storage = EvqGlobal; + symbol->getQualifier().clearInterstage(); + symbol->getQualifier().clearLayout(); + std::for_each(unitAllInputs.begin(), unitAllInputs.end(), demoteMatchingInputs); + } + }); +} + void TIntermediate::mergeCallGraphs(TInfoSink& infoSink, TIntermediate& unit) { if (unit.getNumEntryPoints() > 0) { diff --git a/glslang/MachineIndependent/localintermediate.h b/glslang/MachineIndependent/localintermediate.h index 80638a6b..a2fb9514 100644 --- a/glslang/MachineIndependent/localintermediate.h +++ b/glslang/MachineIndependent/localintermediate.h @@ -1051,6 +1051,7 @@ public: void mergeGlobalUniformBlocks(TInfoSink& infoSink, TIntermediate& unit, bool mergeExistingOnly); void mergeUniformObjects(TInfoSink& infoSink, TIntermediate& unit); void checkStageIO(TInfoSink&, TIntermediate&); + void optimizeStageIO(TInfoSink&, TIntermediate&); bool buildConvertOp(TBasicType dst, TBasicType src, TOperator& convertOp) const; TIntermTyped* createConversion(TBasicType convertTo, TIntermTyped* node) const; diff --git a/glslang/Public/ShaderLang.h b/glslang/Public/ShaderLang.h index b105b5c9..739d7f7b 100644 --- a/glslang/Public/ShaderLang.h +++ b/glslang/Public/ShaderLang.h @@ -271,6 +271,7 @@ enum EShMessages : unsigned { EShMsgEnhanced = (1 << 15), // enhanced message readability EShMsgAbsolutePath = (1 << 16), // Output 
Absolute path for messages EShMsgDisplayErrorColumn = (1 << 17), // Display error message column aswell as line + EShMsgLinkTimeOptimization = (1 << 18), // perform cross-stage optimizations during linking LAST_ELEMENT_MARKER(EShMsgCount), };