Implement optional link-time cross stage optimization
This commit is contained in:
parent
9d00d6d6ca
commit
05559a2963
6 changed files with 124 additions and 2 deletions
|
|
@ -49,6 +49,7 @@
|
|||
#include "localintermediate.h"
|
||||
#include "../Include/InfoSink.h"
|
||||
#include "SymbolTable.h"
|
||||
#include "LiveTraverser.h"
|
||||
|
||||
namespace glslang {
|
||||
|
||||
|
|
@ -187,6 +188,107 @@ void TIntermediate::checkStageIO(TInfoSink& infoSink, TIntermediate& unit) {
|
|||
}
|
||||
}
|
||||
|
||||
void TIntermediate::optimizeStageIO(TInfoSink&, TIntermediate& unit)
|
||||
{
|
||||
// don't do any input/output demotion on compute, raytracing, or task/mesh stages
|
||||
// TODO: support task/mesh
|
||||
if (getStage() > EShLangFragment || unit.getStage() > EShLangFragment) {
|
||||
return;
|
||||
}
|
||||
|
||||
class TIOTraverser : public TLiveTraverser {
|
||||
public:
|
||||
TIOTraverser(TIntermediate& i, bool all, TIntermSequence& sequence, TStorageQualifier storage)
|
||||
: TLiveTraverser(i, all, true, false, false), sequence(sequence), storage(storage)
|
||||
{
|
||||
}
|
||||
|
||||
virtual void visitSymbol(TIntermSymbol* symbol)
|
||||
{
|
||||
if (symbol->getQualifier().storage == storage) {
|
||||
sequence.push_back(symbol);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
TIntermSequence& sequence;
|
||||
TStorageQualifier storage;
|
||||
};
|
||||
|
||||
// live symbols only
|
||||
TIntermSequence unitLiveInputs;
|
||||
|
||||
TIOTraverser unitTraverser(unit, false, unitLiveInputs, EvqVaryingIn);
|
||||
unitTraverser.pushFunction(unit.getEntryPointMangledName().c_str());
|
||||
while (! unitTraverser.destinations.empty()) {
|
||||
TIntermNode* destination = unitTraverser.destinations.back();
|
||||
unitTraverser.destinations.pop_back();
|
||||
destination->traverse(&unitTraverser);
|
||||
}
|
||||
|
||||
TIntermSequence allOutputs;
|
||||
TIntermSequence unitAllInputs;
|
||||
|
||||
TIOTraverser allTraverser(*this, true, allOutputs, EvqVaryingOut);
|
||||
getTreeRoot()->traverse(&allTraverser);
|
||||
|
||||
TIOTraverser unitAllTraverser(unit, true, unitAllInputs, EvqVaryingIn);
|
||||
unit.getTreeRoot()->traverse(&unitAllTraverser);
|
||||
|
||||
// find outputs not consumed by the next stage
|
||||
std::for_each(allOutputs.begin(), allOutputs.end(), [&unitLiveInputs, &unitAllInputs](TIntermNode* output) {
|
||||
// don't do anything to builtins
|
||||
if (output->getAsSymbolNode()->getAccessName().compare(0, 3, "gl_") == 0)
|
||||
return;
|
||||
|
||||
// don't demote block outputs (for now)
|
||||
if (output->getAsSymbolNode()->getBasicType() == EbtBlock)
|
||||
return;
|
||||
|
||||
// check if the (loose) output has a matching loose input
|
||||
auto isMatchingInput = [output](TIntermNode* input) {
|
||||
return output->getAsSymbolNode()->getAccessName() == input->getAsSymbolNode()->getAccessName();
|
||||
};
|
||||
|
||||
// check if the (loose) output has a matching block member input
|
||||
auto isMatchingInputBlockMember = [output](TIntermNode* input) {
|
||||
// ignore loose inputs
|
||||
if (input->getAsSymbolNode()->getBasicType() != EbtBlock)
|
||||
return false;
|
||||
|
||||
// don't demote loose outputs with matching input block members
|
||||
auto isMatchingBlockMember = [output](TTypeLoc type) {
|
||||
return type.type->getFieldName() == output->getAsSymbolNode()->getName();
|
||||
};
|
||||
const TTypeList* members = input->getAsSymbolNode()->getType().getStruct();
|
||||
return std::any_of(members->begin(), members->end(), isMatchingBlockMember);
|
||||
};
|
||||
|
||||
// determine if the input/output pair should be demoted
|
||||
// do the faster (and more likely) loose-loose check first
|
||||
if (std::none_of(unitLiveInputs.begin(), unitLiveInputs.end(), isMatchingInput) &&
|
||||
std::none_of(unitAllInputs.begin(), unitAllInputs.end(), isMatchingInputBlockMember)) {
|
||||
// demote any input matching the output
|
||||
auto demoteMatchingInputs = [output](TIntermNode* input) {
|
||||
if (output->getAsSymbolNode()->getAccessName() == input->getAsSymbolNode()->getAccessName()) {
|
||||
// demote input to a plain variable
|
||||
TIntermSymbol* symbol = input->getAsSymbolNode();
|
||||
symbol->getQualifier().storage = EvqGlobal;
|
||||
symbol->getQualifier().clearInterstage();
|
||||
symbol->getQualifier().clearLayout();
|
||||
}
|
||||
};
|
||||
|
||||
// demote all matching outputs to a plain variable
|
||||
TIntermSymbol* symbol = output->getAsSymbolNode();
|
||||
symbol->getQualifier().storage = EvqGlobal;
|
||||
symbol->getQualifier().clearInterstage();
|
||||
symbol->getQualifier().clearLayout();
|
||||
std::for_each(unitAllInputs.begin(), unitAllInputs.end(), demoteMatchingInputs);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void TIntermediate::mergeCallGraphs(TInfoSink& infoSink, TIntermediate& unit)
|
||||
{
|
||||
if (unit.getNumEntryPoints() > 0) {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue