Implement optional link-time cross stage optimization

This commit is contained in:
Daniel Story 2024-09-12 12:18:09 -07:00 committed by arcady-lunarg
parent 9d00d6d6ca
commit 05559a2963
6 changed files with 124 additions and 2 deletions

View file

@ -49,6 +49,7 @@
#include "localintermediate.h"
#include "../Include/InfoSink.h"
#include "SymbolTable.h"
#include "LiveTraverser.h"
namespace glslang {
@ -187,6 +188,107 @@ void TIntermediate::checkStageIO(TInfoSink& infoSink, TIntermediate& unit) {
}
}
void TIntermediate::optimizeStageIO(TInfoSink&, TIntermediate& unit)
{
// don't do any input/output demotion on compute, raytracing, or task/mesh stages
// TODO: support task/mesh
if (getStage() > EShLangFragment || unit.getStage() > EShLangFragment) {
return;
}
class TIOTraverser : public TLiveTraverser {
public:
TIOTraverser(TIntermediate& i, bool all, TIntermSequence& sequence, TStorageQualifier storage)
: TLiveTraverser(i, all, true, false, false), sequence(sequence), storage(storage)
{
}
virtual void visitSymbol(TIntermSymbol* symbol)
{
if (symbol->getQualifier().storage == storage) {
sequence.push_back(symbol);
}
}
private:
TIntermSequence& sequence;
TStorageQualifier storage;
};
// live symbols only
TIntermSequence unitLiveInputs;
TIOTraverser unitTraverser(unit, false, unitLiveInputs, EvqVaryingIn);
unitTraverser.pushFunction(unit.getEntryPointMangledName().c_str());
while (! unitTraverser.destinations.empty()) {
TIntermNode* destination = unitTraverser.destinations.back();
unitTraverser.destinations.pop_back();
destination->traverse(&unitTraverser);
}
TIntermSequence allOutputs;
TIntermSequence unitAllInputs;
TIOTraverser allTraverser(*this, true, allOutputs, EvqVaryingOut);
getTreeRoot()->traverse(&allTraverser);
TIOTraverser unitAllTraverser(unit, true, unitAllInputs, EvqVaryingIn);
unit.getTreeRoot()->traverse(&unitAllTraverser);
// find outputs not consumed by the next stage
std::for_each(allOutputs.begin(), allOutputs.end(), [&unitLiveInputs, &unitAllInputs](TIntermNode* output) {
// don't do anything to builtins
if (output->getAsSymbolNode()->getAccessName().compare(0, 3, "gl_") == 0)
return;
// don't demote block outputs (for now)
if (output->getAsSymbolNode()->getBasicType() == EbtBlock)
return;
// check if the (loose) output has a matching loose input
auto isMatchingInput = [output](TIntermNode* input) {
return output->getAsSymbolNode()->getAccessName() == input->getAsSymbolNode()->getAccessName();
};
// check if the (loose) output has a matching block member input
auto isMatchingInputBlockMember = [output](TIntermNode* input) {
// ignore loose inputs
if (input->getAsSymbolNode()->getBasicType() != EbtBlock)
return false;
// don't demote loose outputs with matching input block members
auto isMatchingBlockMember = [output](TTypeLoc type) {
return type.type->getFieldName() == output->getAsSymbolNode()->getName();
};
const TTypeList* members = input->getAsSymbolNode()->getType().getStruct();
return std::any_of(members->begin(), members->end(), isMatchingBlockMember);
};
// determine if the input/output pair should be demoted
// do the faster (and more likely) loose-loose check first
if (std::none_of(unitLiveInputs.begin(), unitLiveInputs.end(), isMatchingInput) &&
std::none_of(unitAllInputs.begin(), unitAllInputs.end(), isMatchingInputBlockMember)) {
// demote any input matching the output
auto demoteMatchingInputs = [output](TIntermNode* input) {
if (output->getAsSymbolNode()->getAccessName() == input->getAsSymbolNode()->getAccessName()) {
// demote input to a plain variable
TIntermSymbol* symbol = input->getAsSymbolNode();
symbol->getQualifier().storage = EvqGlobal;
symbol->getQualifier().clearInterstage();
symbol->getQualifier().clearLayout();
}
};
// demote all matching outputs to a plain variable
TIntermSymbol* symbol = output->getAsSymbolNode();
symbol->getQualifier().storage = EvqGlobal;
symbol->getQualifier().clearInterstage();
symbol->getQualifier().clearLayout();
std::for_each(unitAllInputs.begin(), unitAllInputs.end(), demoteMatchingInputs);
}
});
}
void TIntermediate::mergeCallGraphs(TInfoSink& infoSink, TIntermediate& unit)
{
if (unit.getNumEntryPoints() > 0) {