Resolve comments

1. Sink adding noContraction decoration to createBinaryOperation() and
createUnaryOperation().

2. Fix comments.

3. Remove the #define of my delimiter, use global constant char.
This commit is contained in:
qining 2016-05-06 17:25:16 -04:00
parent 015150e4b3
commit 25262b3fd9
2 changed files with 162 additions and 162 deletions

View file

@ -33,8 +33,6 @@
//ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE //ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
//POSSIBILITY OF SUCH DAMAGE. //POSSIBILITY OF SUCH DAMAGE.
//
// Author: John Kessenich, LunarG
// //
// Visit the nodes in the glslang intermediate tree representation to // Visit the nodes in the glslang intermediate tree representation to
// translate them to SPIR-V. // translate them to SPIR-V.
@ -135,10 +133,10 @@ protected:
spv::Id createImageTextureFunctionCall(glslang::TIntermOperator* node); spv::Id createImageTextureFunctionCall(glslang::TIntermOperator* node);
spv::Id handleUserFunctionCall(const glslang::TIntermAggregate*); spv::Id handleUserFunctionCall(const glslang::TIntermAggregate*);
spv::Id createBinaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right, glslang::TBasicType typeProxy, bool reduceComparison = true); spv::Id createBinaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id left, spv::Id right, glslang::TBasicType typeProxy, bool reduceComparison = true);
spv::Id createBinaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right); spv::Id createBinaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id left, spv::Id right);
spv::Id createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy); spv::Id createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy);
spv::Id createUnaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy); spv::Id createUnaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy);
spv::Id createConversion(glslang::TOperator op, spv::Decoration precision, spv::Id destTypeId, spv::Id operand); spv::Id createConversion(glslang::TOperator op, spv::Decoration precision, spv::Id destTypeId, spv::Id operand);
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize); spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy); spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy);
@ -621,8 +619,7 @@ bool HasNonLayoutQualifiers(const glslang::TQualifier& qualifier)
// - struct members can inherit from a struct declaration // - struct members can inherit from a struct declaration
// - effect decorations on the struct members (note smooth does not, and expecting something like volatile to effect the whole object) // - effect decorations on the struct members (note smooth does not, and expecting something like volatile to effect the whole object)
// - are not part of the offset/st430/etc or row/column-major layout // - are not part of the offset/st430/etc or row/column-major layout
return qualifier.invariant || qualifier.nopersp || qualifier.flat || qualifier.centroid || qualifier.patch || qualifier.sample || qualifier.hasLocation() || return qualifier.invariant || qualifier.nopersp || qualifier.flat || qualifier.centroid || qualifier.patch || qualifier.sample || qualifier.hasLocation();
qualifier.noContraction;
} }
// //
@ -884,12 +881,10 @@ bool TGlslangToSpvTraverser::visitBinary(glslang::TVisit /* visit */, glslang::T
// do the operation // do the operation
rValue = createBinaryOperation(node->getOp(), TranslatePrecisionDecoration(node->getType()), rValue = createBinaryOperation(node->getOp(), TranslatePrecisionDecoration(node->getType()),
TranslateNoContractionDecoration(node->getType().getQualifier()),
convertGlslangToSpvType(node->getType()), leftRValue, rValue, convertGlslangToSpvType(node->getType()), leftRValue, rValue,
node->getType().getBasicType()); node->getType().getBasicType());
// Decorate this instruction, if this node has 'noContraction' qualifier.
addDecoration(rValue, TranslateNoContractionDecoration(node->getType().getQualifier()));
// these all need their counterparts in createBinaryOperation() // these all need their counterparts in createBinaryOperation()
assert(rValue != spv::NoResult); assert(rValue != spv::NoResult);
} }
@ -1005,6 +1000,7 @@ bool TGlslangToSpvTraverser::visitBinary(glslang::TVisit /* visit */, glslang::T
// get result // get result
spv::Id result = createBinaryOperation(node->getOp(), TranslatePrecisionDecoration(node->getType()), spv::Id result = createBinaryOperation(node->getOp(), TranslatePrecisionDecoration(node->getType()),
TranslateNoContractionDecoration(node->getType().getQualifier()),
convertGlslangToSpvType(node->getType()), left, right, convertGlslangToSpvType(node->getType()), left, right,
node->getLeft()->getType().getBasicType()); node->getLeft()->getType().getBasicType());
@ -1013,8 +1009,6 @@ bool TGlslangToSpvTraverser::visitBinary(glslang::TVisit /* visit */, glslang::T
logger->missingFunctionality("unknown glslang binary operation"); logger->missingFunctionality("unknown glslang binary operation");
return true; // pick up a child as the place-holder result return true; // pick up a child as the place-holder result
} else { } else {
// Decorate this instruction, if this node has 'noContraction' qualifier.
addDecoration(result, TranslateNoContractionDecoration(node->getType().getQualifier()));
builder.setAccessChainRValue(result); builder.setAccessChainRValue(result);
return false; return false;
} }
@ -1073,6 +1067,7 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
operand = accessChainLoad(node->getOperand()->getType()); operand = accessChainLoad(node->getOperand()->getType());
spv::Decoration precision = TranslatePrecisionDecoration(node->getType()); spv::Decoration precision = TranslatePrecisionDecoration(node->getType());
spv::Decoration noContraction = TranslateNoContractionDecoration(node->getType().getQualifier());
// it could be a conversion // it could be a conversion
if (! result) if (! result)
@ -1080,11 +1075,9 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
// if not, then possibly an operation // if not, then possibly an operation
if (! result) if (! result)
result = createUnaryOperation(node->getOp(), precision, convertGlslangToSpvType(node->getType()), operand, node->getOperand()->getBasicType()); result = createUnaryOperation(node->getOp(), precision, noContraction, convertGlslangToSpvType(node->getType()), operand, node->getOperand()->getBasicType());
if (result) { if (result) {
// Decorate this instruction, if this node has 'noContraction' qualifier.
addDecoration(result, TranslateNoContractionDecoration(node->getType().getQualifier()));
builder.clearAccessChain(); builder.clearAccessChain();
builder.setAccessChainRValue(result); builder.setAccessChainRValue(result);
@ -1114,11 +1107,10 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
op = glslang::EOpSub; op = glslang::EOpSub;
spv::Id result = createBinaryOperation(op, TranslatePrecisionDecoration(node->getType()), spv::Id result = createBinaryOperation(op, TranslatePrecisionDecoration(node->getType()),
TranslateNoContractionDecoration(node->getType().getQualifier()),
convertGlslangToSpvType(node->getType()), operand, one, convertGlslangToSpvType(node->getType()), operand, one,
node->getType().getBasicType()); node->getType().getBasicType());
assert(result != spv::NoResult); assert(result != spv::NoResult);
// Decorate this instruction, if this node has 'noContraction' qualifier.
addDecoration(result, TranslateNoContractionDecoration(node->getType().getQualifier()));
// The result of operation is always stored, but conditionally the // The result of operation is always stored, but conditionally the
// consumed result. The consumed result is always an r-value. // consumed result. The consumed result is always an r-value.
@ -1414,7 +1406,7 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
right->traverse(this); right->traverse(this);
spv::Id rightId = accessChainLoad(right->getType()); spv::Id rightId = accessChainLoad(right->getType());
result = createBinaryOperation(binOp, precision, result = createBinaryOperation(binOp, precision, TranslateNoContractionDecoration(node->getType().getQualifier()),
convertGlslangToSpvType(node->getType()), leftId, rightId, convertGlslangToSpvType(node->getType()), leftId, rightId,
left->getType().getBasicType(), reduceComparison); left->getType().getBasicType(), reduceComparison);
@ -1488,7 +1480,11 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
result = createNoArgOperation(node->getOp()); result = createNoArgOperation(node->getOp());
break; break;
case 1: case 1:
result = createUnaryOperation(node->getOp(), precision, convertGlslangToSpvType(node->getType()), operands.front(), glslangOperands[0]->getAsTyped()->getBasicType()); result = createUnaryOperation(
node->getOp(), precision,
TranslateNoContractionDecoration(node->getType().getQualifier()),
convertGlslangToSpvType(node->getType()), operands.front(),
glslangOperands[0]->getAsTyped()->getBasicType());
break; break;
default: default:
result = createMiscOperation(node->getOp(), precision, convertGlslangToSpvType(node->getType()), operands, node->getBasicType()); result = createMiscOperation(node->getOp(), precision, convertGlslangToSpvType(node->getType()), operands, node->getBasicType());
@ -2680,6 +2676,7 @@ spv::Id TGlslangToSpvTraverser::handleUserFunctionCall(const glslang::TIntermAgg
// Translate AST operation to SPV operation, already having SPV-based operands/types. // Translate AST operation to SPV operation, already having SPV-based operands/types.
spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv::Decoration precision,
spv::Decoration noContraction,
spv::Id typeId, spv::Id left, spv::Id right, spv::Id typeId, spv::Id left, spv::Id right,
glslang::TBasicType typeProxy, bool reduceComparison) glslang::TBasicType typeProxy, bool reduceComparison)
{ {
@ -2816,13 +2813,15 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
if (binOp != spv::OpNop) { if (binOp != spv::OpNop) {
assert(comparison == false); assert(comparison == false);
if (builder.isMatrix(left) || builder.isMatrix(right)) if (builder.isMatrix(left) || builder.isMatrix(right))
return createBinaryMatrixOperation(binOp, precision, typeId, left, right); return createBinaryMatrixOperation(binOp, precision, noContraction, typeId, left, right);
// No matrix involved; make both operands be the same number of components, if needed // No matrix involved; make both operands be the same number of components, if needed
if (needMatchingVectors) if (needMatchingVectors)
builder.promoteScalar(precision, left, right); builder.promoteScalar(precision, left, right);
return builder.setPrecision(builder.createBinOp(binOp, typeId, left, right), precision); spv::Id result = builder.createBinOp(binOp, typeId, left, right);
addDecoration(result, noContraction);
return builder.setPrecision(result, precision);
} }
if (! comparison) if (! comparison)
@ -2891,8 +2890,11 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
break; break;
} }
if (binOp != spv::OpNop) if (binOp != spv::OpNop) {
return builder.setPrecision(builder.createBinOp(binOp, typeId, left, right), precision); spv::Id result = builder.createBinOp(binOp, typeId, left, right);
addDecoration(result, noContraction);
return builder.setPrecision(result, precision);
}
return 0; return 0;
} }
@ -2911,7 +2913,7 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
// matrix op scalar op in {+, -, /} // matrix op scalar op in {+, -, /}
// scalar op matrix op in {+, -, /} // scalar op matrix op in {+, -, /}
// //
spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Id typeId, spv::Id left, spv::Id right) spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id left, spv::Id right)
{ {
bool firstClass = true; bool firstClass = true;
@ -2947,8 +2949,11 @@ spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Dec
break; break;
} }
if (firstClass) if (firstClass) {
return builder.setPrecision(builder.createBinOp(op, typeId, left, right), precision); spv::Id result = builder.createBinOp(op, typeId, left, right);
addDecoration(result, noContraction);
return builder.setPrecision(result, precision);
}
// Handle component-wise +, -, *, and / for all combinations of type. // Handle component-wise +, -, *, and / for all combinations of type.
// The result type of all of them is the same type as the (a) matrix operand. // The result type of all of them is the same type as the (a) matrix operand.
@ -2983,8 +2988,9 @@ spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Dec
indexes.push_back(c); indexes.push_back(c);
spv::Id leftVec = leftMat ? builder.createCompositeExtract( left, vecType, indexes) : smearVec; spv::Id leftVec = leftMat ? builder.createCompositeExtract( left, vecType, indexes) : smearVec;
spv::Id rightVec = rightMat ? builder.createCompositeExtract(right, vecType, indexes) : smearVec; spv::Id rightVec = rightMat ? builder.createCompositeExtract(right, vecType, indexes) : smearVec;
results.push_back(builder.createBinOp(op, vecType, leftVec, rightVec)); spv::Id result = builder.createBinOp(op, vecType, leftVec, rightVec);
builder.setPrecision(results.back(), precision); addDecoration(result, noContraction);
results.push_back(builder.setPrecision(result, precision));
} }
// put the pieces together // put the pieces together
@ -2996,7 +3002,7 @@ spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, spv::Dec
} }
} }
spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy) spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy)
{ {
spv::Op unaryOp = spv::OpNop; spv::Op unaryOp = spv::OpNop;
int libCall = -1; int libCall = -1;
@ -3008,7 +3014,7 @@ spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv:
if (isFloat) { if (isFloat) {
unaryOp = spv::OpFNegate; unaryOp = spv::OpFNegate;
if (builder.isMatrixType(typeId)) if (builder.isMatrixType(typeId))
return createUnaryMatrixOperation(unaryOp, precision, typeId, operand, typeProxy); return createUnaryMatrixOperation(unaryOp, precision, noContraction, typeId, operand, typeProxy);
} else } else
unaryOp = spv::OpSNegate; unaryOp = spv::OpSNegate;
break; break;
@ -3290,11 +3296,12 @@ spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, spv:
id = builder.createUnaryOp(unaryOp, typeId, operand); id = builder.createUnaryOp(unaryOp, typeId, operand);
} }
addDecoration(id, noContraction);
return builder.setPrecision(id, precision); return builder.setPrecision(id, precision);
} }
// Create a unary operation on a matrix // Create a unary operation on a matrix
spv::Id TGlslangToSpvTraverser::createUnaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Id typeId, spv::Id operand, glslang::TBasicType /* typeProxy */) spv::Id TGlslangToSpvTraverser::createUnaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand, glslang::TBasicType /* typeProxy */)
{ {
// Handle unary operations vector by vector. // Handle unary operations vector by vector.
// The result type is the same type as the original type. // The result type is the same type as the original type.
@ -3315,8 +3322,9 @@ spv::Id TGlslangToSpvTraverser::createUnaryMatrixOperation(spv::Op op, spv::Deco
std::vector<unsigned int> indexes; std::vector<unsigned int> indexes;
indexes.push_back(c); indexes.push_back(c);
spv::Id vec = builder.createCompositeExtract(operand, vecType, indexes); spv::Id vec = builder.createCompositeExtract(operand, vecType, indexes);
results.push_back(builder.createUnaryOp(op, vecType, vec)); spv::Id vec_result = builder.createUnaryOp(op, vecType, vec);
builder.setPrecision(results.back(), precision); addDecoration(vec_result, noContraction);
results.push_back(builder.setPrecision(vec_result, precision));
} }
// put the pieces together // put the pieces together

View file

@ -47,18 +47,27 @@
#include "localintermediate.h" #include "localintermediate.h"
namespace { namespace {
// Use string to hold the accesschain information, as in most cases we the // Use string to hold the accesschain information, as in most cases the
// accesschain is short and may contain only one element, which is the symbol ID. // accesschain is short and may contain only one element, which is the symbol
// ID.
// Example: struct {float a; float b;} s;
// Object s.a will be represented with: <symbol ID of s>/0
// Object s.b will be represented with: <symbol ID of s>/1
// Object s will be representend with: <symbol ID of s>
// For members of vector, matrix and arrays, they will be represented with the
// same symbol ID of their container symbol objects. This is because their
// precise'ness is always the same as their container symbol objects.
using ObjectAccessChain = std::string; using ObjectAccessChain = std::string;
#ifndef StructAccessChainDelimiter
#define StructAccessChainDelimiter '/' // The delimiter used in the ObjectAccessChain string to separate symbol ID and
#endif // different level of struct indices.
const char OBJECT_ACCESSCHAIN_DELIMITER = '/';
// Mapping from Symbol IDs of symbol nodes, to their defining operation // Mapping from Symbol IDs of symbol nodes, to their defining operation
// nodes. // nodes.
using NodeMapping = std::unordered_multimap<ObjectAccessChain, glslang::TIntermOperator *>; using NodeMapping = std::unordered_multimap<ObjectAccessChain, glslang::TIntermOperator*>;
// Mapping from object nodes to their accesschain info string. // Mapping from object nodes to their accesschain info string.
using AccessChainMapping = std::unordered_map<glslang::TIntermTyped *, ObjectAccessChain>; using AccessChainMapping = std::unordered_map<glslang::TIntermTyped*, ObjectAccessChain>;
// Set of object IDs. // Set of object IDs.
using ObjectAccesschainSet = std::unordered_set<ObjectAccessChain>; using ObjectAccesschainSet = std::unordered_set<ObjectAccessChain>;
@ -67,7 +76,7 @@ using ReturnBranchNodeSet = std::unordered_set<glslang::TIntermBranch*>;
// A helper function to tell whether a node is 'noContraction'. Returns true if // A helper function to tell whether a node is 'noContraction'. Returns true if
// the node has 'noContraction' qualifier, otherwise false. // the node has 'noContraction' qualifier, otherwise false.
bool isPreciseObjectNode(glslang::TIntermTyped *node) bool isPreciseObjectNode(glslang::TIntermTyped* node)
{ {
return node->getType().getQualifier().noContraction; return node->getType().getQualifier().noContraction;
} }
@ -118,7 +127,7 @@ bool isAssignOperation(glslang::TOperator op)
// A helper function to get the unsigned int from a given constant union node. // A helper function to get the unsigned int from a given constant union node.
// Note the node should only holds a uint scalar. // Note the node should only holds a uint scalar.
unsigned getStructIndexFromConstantUnion(glslang::TIntermTyped *node) unsigned getStructIndexFromConstantUnion(glslang::TIntermTyped* node)
{ {
assert(node->getAsConstantUnion() && node->getAsConstantUnion()->isScalar()); assert(node->getAsConstantUnion() && node->getAsConstantUnion()->isScalar());
unsigned struct_dereference_index = node->getAsConstantUnion()->getConstArray()[0].getUConst(); unsigned struct_dereference_index = node->getAsConstantUnion()->getConstArray()[0].getUConst();
@ -126,9 +135,10 @@ unsigned getStructIndexFromConstantUnion(glslang::TIntermTyped *node)
} }
// A helper function to generate symbol_label. // A helper function to generate symbol_label.
ObjectAccessChain generateSymbolLabel(glslang::TIntermSymbol *node) ObjectAccessChain generateSymbolLabel(glslang::TIntermSymbol* node)
{ {
ObjectAccessChain symbol_id = std::to_string(node->getId()) + "(" + node->getName().c_str() + ")"; ObjectAccessChain symbol_id =
std::to_string(node->getId()) + "(" + node->getName().c_str() + ")";
return symbol_id; return symbol_id;
} }
@ -180,43 +190,41 @@ bool isArithmeticOperation(glslang::TOperator op)
// A helper class to help managing populating_initial_no_contraction_ flag. // A helper class to help managing populating_initial_no_contraction_ flag.
template <typename T> class StateSettingGuard { template <typename T> class StateSettingGuard {
public: public:
StateSettingGuard(T *state_ptr, T new_state_value) StateSettingGuard(T* state_ptr, T new_state_value)
: state_ptr_(state_ptr), previous_state_(*state_ptr) : state_ptr_(state_ptr), previous_state_(*state_ptr)
{ {
*state_ptr = new_state_value; *state_ptr = new_state_value;
} }
StateSettingGuard(T *state_ptr) : state_ptr_(state_ptr), previous_state_(*state_ptr) {} StateSettingGuard(T* state_ptr) : state_ptr_(state_ptr), previous_state_(*state_ptr) {}
void setState(T new_state_value) void setState(T new_state_value) { *state_ptr_ = new_state_value; }
{
*state_ptr_ = new_state_value;
}
~StateSettingGuard() { *state_ptr_ = previous_state_; } ~StateSettingGuard() { *state_ptr_ = previous_state_; }
private: private:
T *state_ptr_; T* state_ptr_;
T previous_state_; T previous_state_;
}; };
// A helper function to get the front element from a given ObjectAccessChain // A helper function to get the front element from a given ObjectAccessChain
ObjectAccessChain getFrontElement(const ObjectAccessChain &chain) ObjectAccessChain getFrontElement(const ObjectAccessChain& chain)
{ {
size_t pos_delimiter = chain.find(StructAccessChainDelimiter); size_t pos_delimiter = chain.find(OBJECT_ACCESSCHAIN_DELIMITER);
return pos_delimiter == std::string::npos ? chain : chain.substr(0, pos_delimiter); return pos_delimiter == std::string::npos ? chain : chain.substr(0, pos_delimiter);
} }
// A helper function to get the accesschain starting from the second element. // A helper function to get the accesschain starting from the second element.
ObjectAccessChain subAccessChainFromSecondElement(const ObjectAccessChain& chain) ObjectAccessChain subAccessChainFromSecondElement(const ObjectAccessChain& chain)
{ {
size_t pos_delimiter = chain.find(StructAccessChainDelimiter); size_t pos_delimiter = chain.find(OBJECT_ACCESSCHAIN_DELIMITER);
return pos_delimiter == std::string::npos ? "" : chain.substr(pos_delimiter + 1); return pos_delimiter == std::string::npos ? "" : chain.substr(pos_delimiter + 1);
} }
// A helper function to get the accesschain after removing a given prefix. // A helper function to get the accesschain after removing a given prefix.
ObjectAccessChain getSubAccessChainAfterPrefix(const ObjectAccessChain &chain, const ObjectAccessChain &prefix) ObjectAccessChain getSubAccessChainAfterPrefix(const ObjectAccessChain& chain,
const ObjectAccessChain& prefix)
{ {
size_t pos = chain.find(prefix); size_t pos = chain.find(prefix);
if (pos != 0) return chain; if (pos != 0) return chain;
return chain.substr(prefix.length() + sizeof(StructAccessChainDelimiter)); return chain.substr(prefix.length() + sizeof(OBJECT_ACCESSCHAIN_DELIMITER));
} }
// //
@ -226,34 +234,33 @@ ObjectAccessChain getSubAccessChainAfterPrefix(const ObjectAccessChain &chain, c
// //
class TSymbolDefinitionCollectingTraverser : public glslang::TIntermTraverser { class TSymbolDefinitionCollectingTraverser : public glslang::TIntermTraverser {
public: public:
TSymbolDefinitionCollectingTraverser( TSymbolDefinitionCollectingTraverser(NodeMapping* symbol_definition_mapping,
NodeMapping *symbol_definition_mapping, AccessChainMapping *accesschain_mapping, AccessChainMapping* accesschain_mapping,
ObjectAccesschainSet *precise_objects, ObjectAccesschainSet* precise_objects,
ReturnBranchNodeSet *precise_return_nodes); ReturnBranchNodeSet* precise_return_nodes);
// bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate *) override; bool visitUnary(glslang::TVisit, glslang::TIntermUnary*) override;
bool visitUnary(glslang::TVisit, glslang::TIntermUnary *) override; bool visitBinary(glslang::TVisit, glslang::TIntermBinary*) override;
bool visitBinary(glslang::TVisit, glslang::TIntermBinary *) override; void visitSymbol(glslang::TIntermSymbol*) override;
void visitSymbol(glslang::TIntermSymbol *) override; bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate*) override;
bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate *) override; bool visitBranch(glslang::TVisit, glslang::TIntermBranch*) override;
bool visitBranch(glslang::TVisit, glslang::TIntermBranch *) override;
protected: protected:
// The mapping from symbol node IDs to their defining nodes. This should be // The mapping from symbol node IDs to their defining nodes. This should be
// populated along traversing the AST. // populated along traversing the AST.
NodeMapping &symbol_definition_mapping_; NodeMapping& symbol_definition_mapping_;
// The set of symbol node IDs for precise symbol nodes, the ones marked as // The set of symbol node IDs for precise symbol nodes, the ones marked as
// 'noContraction'. // 'noContraction'.
ObjectAccesschainSet &precise_objects_; ObjectAccesschainSet& precise_objects_;
// The set of precise return nodes. // The set of precise return nodes.
ReturnBranchNodeSet &precise_return_nodes_; ReturnBranchNodeSet& precise_return_nodes_;
// A temporary cache of the symbol node whose defining node is to be found // A temporary cache of the symbol node whose defining node is to be found
// currently along traversing the AST. // currently along traversing the AST.
ObjectAccessChain object_to_be_defined_; ObjectAccessChain object_to_be_defined_;
// A map from object node to its accesschain. This traverser stores // A map from object node to its accesschain. This traverser stores
// the built accesschains into this map for each object node it has // the built accesschains into this map for each object node it has
// visited. // visited.
AccessChainMapping &accesschain_mapping_; AccessChainMapping& accesschain_mapping_;
// The pointer to the Function Definition node, so we can get the // The pointer to the Function Definition node, so we can get the
// precise'ness of the return expression from it when we traverse the // precise'ness of the return expression from it when we traverse the
// return branch node. // return branch node.
@ -261,9 +268,9 @@ protected:
}; };
TSymbolDefinitionCollectingTraverser::TSymbolDefinitionCollectingTraverser( TSymbolDefinitionCollectingTraverser::TSymbolDefinitionCollectingTraverser(
NodeMapping *symbol_definition_mapping, AccessChainMapping *accesschain_mapping, NodeMapping* symbol_definition_mapping, AccessChainMapping* accesschain_mapping,
ObjectAccesschainSet *precise_objects, ObjectAccesschainSet* precise_objects,
std::unordered_set<glslang::TIntermBranch *> *precise_return_nodes) std::unordered_set<glslang::TIntermBranch*>* precise_return_nodes)
: TIntermTraverser(true, false, false), symbol_definition_mapping_(*symbol_definition_mapping), : TIntermTraverser(true, false, false), symbol_definition_mapping_(*symbol_definition_mapping),
precise_objects_(*precise_objects), object_to_be_defined_(), precise_objects_(*precise_objects), object_to_be_defined_(),
accesschain_mapping_(*accesschain_mapping), current_function_definition_node_(nullptr), accesschain_mapping_(*accesschain_mapping), current_function_definition_node_(nullptr),
@ -273,7 +280,7 @@ TSymbolDefinitionCollectingTraverser::TSymbolDefinitionCollectingTraverser(
// current node symbol ID, and record a mapping from this node to the current // current node symbol ID, and record a mapping from this node to the current
// object_to_be_defined_, which is the just obtained symbol // object_to_be_defined_, which is the just obtained symbol
// ID. // ID.
void TSymbolDefinitionCollectingTraverser::visitSymbol(glslang::TIntermSymbol *node) void TSymbolDefinitionCollectingTraverser::visitSymbol(glslang::TIntermSymbol* node)
{ {
object_to_be_defined_ = generateSymbolLabel(node); object_to_be_defined_ = generateSymbolLabel(node);
accesschain_mapping_[node] = object_to_be_defined_; accesschain_mapping_[node] = object_to_be_defined_;
@ -281,12 +288,12 @@ void TSymbolDefinitionCollectingTraverser::visitSymbol(glslang::TIntermSymbol *n
// Visits an aggregate node, traverses all of its children. // Visits an aggregate node, traverses all of its children.
bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit, bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit,
glslang::TIntermAggregate *node) glslang::TIntermAggregate* node)
{ {
// This aggreagate node might be a function definition node, in which case we need to // This aggreagate node might be a function definition node, in which case we need to
// cache this node, so we can get the precise'ness information of the return value // cache this node, so we can get the precise'ness information of the return value
// of this function later. // of this function later.
StateSettingGuard<glslang::TIntermAggregate *> current_function_definition_node_setting_guard( StateSettingGuard<glslang::TIntermAggregate*> current_function_definition_node_setting_guard(
&current_function_definition_node_); &current_function_definition_node_);
if (node->getOp() == glslang::EOpFunction) { if (node->getOp() == glslang::EOpFunction) {
// This is function definition node, we need to cache this node so that we can // This is function definition node, we need to cache this node so that we can
@ -294,7 +301,7 @@ bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit,
current_function_definition_node_setting_guard.setState(node); current_function_definition_node_setting_guard.setState(node);
} }
// Traverse the items in the sequence. // Traverse the items in the sequence.
glslang::TIntermSequence &seq = node->getSequence(); glslang::TIntermSequence& seq = node->getSequence();
for (int i = 0; i < (int)seq.size(); ++i) { for (int i = 0; i < (int)seq.size(); ++i) {
object_to_be_defined_.clear(); object_to_be_defined_.clear();
seq[i]->traverse(this); seq[i]->traverse(this);
@ -303,7 +310,7 @@ bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit,
} }
bool TSymbolDefinitionCollectingTraverser::visitBranch(glslang::TVisit, bool TSymbolDefinitionCollectingTraverser::visitBranch(glslang::TVisit,
glslang::TIntermBranch *node) glslang::TIntermBranch* node)
{ {
if (node->getFlowOp() == glslang::EOpReturn && node->getExpression() && if (node->getFlowOp() == glslang::EOpReturn && node->getExpression() &&
current_function_definition_node_ && current_function_definition_node_ &&
@ -319,7 +326,7 @@ bool TSymbolDefinitionCollectingTraverser::visitBranch(glslang::TVisit,
// Visits an unary node. This might be an implicit assignment like i++, i--. etc. // Visits an unary node. This might be an implicit assignment like i++, i--. etc.
bool TSymbolDefinitionCollectingTraverser::visitUnary(glslang::TVisit /* visit */, bool TSymbolDefinitionCollectingTraverser::visitUnary(glslang::TVisit /* visit */,
glslang::TIntermUnary *node) glslang::TIntermUnary* node)
{ {
object_to_be_defined_.clear(); object_to_be_defined_.clear();
node->getOperand()->traverse(this); node->getOperand()->traverse(this);
@ -351,7 +358,7 @@ bool TSymbolDefinitionCollectingTraverser::visitUnary(glslang::TVisit /* visit *
// Visits a binary node and updates the mapping from symbol IDs to the definition // Visits a binary node and updates the mapping from symbol IDs to the definition
// nodes. Also collects the accesschains for the initial precise objects. // nodes. Also collects the accesschains for the initial precise objects.
bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit */, bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit */,
glslang::TIntermBinary *node) glslang::TIntermBinary* node)
{ {
// Traverses the left node to build the accesschain info for the object. // Traverses the left node to build the accesschain info for the object.
object_to_be_defined_.clear(); object_to_be_defined_.clear();
@ -408,7 +415,7 @@ bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit
// object. We need to record the accesschain information of the current // object. We need to record the accesschain information of the current
// node into its object id. // node into its object id.
unsigned struct_dereference_index = getStructIndexFromConstantUnion(node->getRight()); unsigned struct_dereference_index = getStructIndexFromConstantUnion(node->getRight());
object_to_be_defined_.push_back(StructAccessChainDelimiter); object_to_be_defined_.push_back(OBJECT_ACCESSCHAIN_DELIMITER);
object_to_be_defined_.append(std::to_string(struct_dereference_index)); object_to_be_defined_.append(std::to_string(struct_dereference_index));
accesschain_mapping_[node] = object_to_be_defined_; accesschain_mapping_[node] = object_to_be_defined_;
@ -428,17 +435,18 @@ bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit
// 2) a mapping from object nodes in the AST to the accesschains of these objects. // 2) a mapping from object nodes in the AST to the accesschains of these objects.
// 3) a set of accesschains of precise objects. // 3) a set of accesschains of precise objects.
std::tuple<NodeMapping, AccessChainMapping, ObjectAccesschainSet, ReturnBranchNodeSet> std::tuple<NodeMapping, AccessChainMapping, ObjectAccesschainSet, ReturnBranchNodeSet>
getSymbolToDefinitionMappingAndPreciseSymbolIDs(const glslang::TIntermediate &intermediate) getSymbolToDefinitionMappingAndPreciseSymbolIDs(const glslang::TIntermediate& intermediate)
{ {
auto result_tuple = std::make_tuple(NodeMapping(), AccessChainMapping(), ObjectAccesschainSet(), ReturnBranchNodeSet()); auto result_tuple = std::make_tuple(NodeMapping(), AccessChainMapping(), ObjectAccesschainSet(),
ReturnBranchNodeSet());
TIntermNode *root = intermediate.getTreeRoot(); TIntermNode* root = intermediate.getTreeRoot();
if (root == 0) return result_tuple; if (root == 0) return result_tuple;
NodeMapping &symbol_definition_mapping = std::get<0>(result_tuple); NodeMapping& symbol_definition_mapping = std::get<0>(result_tuple);
AccessChainMapping &accesschain_mapping = std::get<1>(result_tuple); AccessChainMapping& accesschain_mapping = std::get<1>(result_tuple);
ObjectAccesschainSet &precise_objects = std::get<2>(result_tuple); ObjectAccesschainSet& precise_objects = std::get<2>(result_tuple);
ReturnBranchNodeSet &precise_return_nodes = std::get<3>(result_tuple); ReturnBranchNodeSet& precise_return_nodes = std::get<3>(result_tuple);
// Traverses the AST and populate the results. // Traverses the AST and populate the results.
TSymbolDefinitionCollectingTraverser collector(&symbol_definition_mapping, &accesschain_mapping, TSymbolDefinitionCollectingTraverser collector(&symbol_definition_mapping, &accesschain_mapping,
@ -474,7 +482,7 @@ class TNoContractionAssigneeCheckingTraverser : public glslang::TIntermTraverser
}; };
public: public:
TNoContractionAssigneeCheckingTraverser(const AccessChainMapping &accesschain_mapping) TNoContractionAssigneeCheckingTraverser(const AccessChainMapping& accesschain_mapping)
: TIntermTraverser(true, false, false), accesschain_mapping_(accesschain_mapping), : TIntermTraverser(true, false, false), accesschain_mapping_(accesschain_mapping),
precise_object_(nullptr) {} precise_object_(nullptr) {}
@ -494,7 +502,7 @@ public:
// precise object. // precise object.
std::tuple<bool, ObjectAccessChain> std::tuple<bool, ObjectAccessChain>
getPrecisenessAndRemainedAccessChain(glslang::TIntermOperator* node, getPrecisenessAndRemainedAccessChain(glslang::TIntermOperator* node,
const ObjectAccessChain &precise_object) const ObjectAccessChain& precise_object)
{ {
assert(isAssignOperation(node->getOp())); assert(isAssignOperation(node->getOp()));
precise_object_ = &precise_object; precise_object_ = &precise_object;
@ -570,23 +578,23 @@ public:
} }
protected: protected:
bool visitBinary(glslang::TVisit, glslang::TIntermBinary *node) override; bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override;
void visitSymbol(glslang::TIntermSymbol *node) override; void visitSymbol(glslang::TIntermSymbol* node) override;
// A map from object nodes to their accesschain string (used as object ID). // A map from object nodes to their accesschain string (used as object ID).
const AccessChainMapping &accesschain_mapping_; const AccessChainMapping& accesschain_mapping_;
// A given precise object, represented in it accesschain string. This // A given precise object, represented in it accesschain string. This
// precise object is used to be compared with the assignee node to tell if // precise object is used to be compared with the assignee node to tell if
// the assignee node is 'precise', contains 'precise' object or not // the assignee node is 'precise', contains 'precise' object or not
// 'precise'. // 'precise'.
const ObjectAccessChain *precise_object_; const ObjectAccessChain* precise_object_;
}; };
// Visit a binary node. If the node is an object node, it must be a dereference // Visit a binary node. If the node is an object node, it must be a dereference
// node. In such cases, if the left node is 'precise', this node should also be // node. In such cases, if the left node is 'precise', this node should also be
// 'precise'. // 'precise'.
bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit, bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit,
glslang::TIntermBinary *node) glslang::TIntermBinary* node)
{ {
// Traverses the left so that we transfer the 'precise' from nesting object // Traverses the left so that we transfer the 'precise' from nesting object
// to its nested object. // to its nested object.
@ -602,7 +610,7 @@ bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit,
// this node should be marked as 'precise'. // this node should be marked as 'precise'.
if (isPreciseObjectNode(node->getLeft())) { if (isPreciseObjectNode(node->getLeft())) {
node->getWritableType().getQualifier().noContraction = true; node->getWritableType().getQualifier().noContraction = true;
} else if (accesschain_mapping_.at(node) == *precise_object_){ } else if (accesschain_mapping_.at(node) == *precise_object_) {
node->getWritableType().getQualifier().noContraction = true; node->getWritableType().getQualifier().noContraction = true;
} }
} }
@ -611,7 +619,7 @@ bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit,
// Visit a symbol node, if the symbol node ID (its accesschain string) matches // Visit a symbol node, if the symbol node ID (its accesschain string) matches
// with the given precise object, this node should be 'precise'. // with the given precise object, this node should be 'precise'.
void TNoContractionAssigneeCheckingTraverser::visitSymbol(glslang::TIntermSymbol *node) void TNoContractionAssigneeCheckingTraverser::visitSymbol(glslang::TIntermSymbol* node)
{ {
// A symbol node should always be an object node, and should have been added // A symbol node should always be an object node, and should have been added
// to the map from object nodes to their accesschain strings. // to the map from object nodes to their accesschain strings.
@ -632,27 +640,27 @@ void TNoContractionAssigneeCheckingTraverser::visitSymbol(glslang::TIntermSymbol
// //
class TNoContractionPropagator : public glslang::TIntermTraverser { class TNoContractionPropagator : public glslang::TIntermTraverser {
public: public:
TNoContractionPropagator(ObjectAccesschainSet *precise_objects, TNoContractionPropagator(ObjectAccesschainSet* precise_objects,
const AccessChainMapping &accesschain_mapping) const AccessChainMapping& accesschain_mapping)
: TIntermTraverser(true, false, false), remained_accesschain_(), : TIntermTraverser(true, false, false), remained_accesschain_(),
precise_objects_(*precise_objects), precise_objects_(*precise_objects), accesschain_mapping_(accesschain_mapping),
accesschain_mapping_(accesschain_mapping), added_precise_object_ids_() {} added_precise_object_ids_() {}
// Propagates 'precise' in the right nodes of a given assignment node with // Propagates 'precise' in the right nodes of a given assignment node with
// accesschain record from the assignee node to a 'precise' object it // accesschain record from the assignee node to a 'precise' object it
// contains. // contains.
void void
propagateNoContractionInOneExpression(glslang::TIntermTyped *defining_node, propagateNoContractionInOneExpression(glslang::TIntermTyped* defining_node,
const ObjectAccessChain &assignee_remained_accesschain) const ObjectAccessChain& assignee_remained_accesschain)
{ {
remained_accesschain_ = assignee_remained_accesschain; remained_accesschain_ = assignee_remained_accesschain;
if (glslang::TIntermBinary *BN = defining_node->getAsBinaryNode()) { if (glslang::TIntermBinary* BN = defining_node->getAsBinaryNode()) {
assert(isAssignOperation(BN->getOp())); assert(isAssignOperation(BN->getOp()));
BN->getRight()->traverse(this); BN->getRight()->traverse(this);
if (isArithmeticOperation(BN->getOp())) { if (isArithmeticOperation(BN->getOp())) {
BN->getWritableType().getQualifier().noContraction = true; BN->getWritableType().getQualifier().noContraction = true;
} }
} else if (glslang::TIntermUnary *UN = defining_node->getAsUnaryNode()) { } else if (glslang::TIntermUnary* UN = defining_node->getAsUnaryNode()) {
assert(isAssignOperation(UN->getOp())); assert(isAssignOperation(UN->getOp()));
UN->getOperand()->traverse(this); UN->getOperand()->traverse(this);
if (isArithmeticOperation(UN->getOp())) { if (isArithmeticOperation(UN->getOp())) {
@ -662,8 +670,7 @@ public:
} }
// Propagates 'precise' in a given precise return node. // Propagates 'precise' in a given precise return node.
void void propagateNoContractionInReturnNode(glslang::TIntermBranch* return_node)
propagateNoContractionInReturnNode(glslang::TIntermBranch *return_node)
{ {
remained_accesschain_ = ""; remained_accesschain_ = "";
assert(return_node->getFlowOp() == glslang::EOpReturn && return_node->getExpression()); assert(return_node->getFlowOp() == glslang::EOpReturn && return_node->getExpression());
@ -675,7 +682,7 @@ protected:
// case we need to find the 'precise' or 'precise' containing object node // case we need to find the 'precise' or 'precise' containing object node
// with the accesschain record. In other cases, just need to traverse all // with the accesschain record. In other cases, just need to traverse all
// the children nodes. // the children nodes.
bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate *node) override bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate* node) override
{ {
if (!remained_accesschain_.empty() && node->getOp() == glslang::EOpConstructStruct) { if (!remained_accesschain_.empty() && node->getOp() == glslang::EOpConstructStruct) {
// This is a struct initializer node, and the remained // This is a struct initializer node, and the remained
@ -689,7 +696,7 @@ protected:
getFrontElement(remained_accesschain_); getFrontElement(remained_accesschain_);
unsigned precise_accesschain_index = std::stoul(precise_accesschain_index_str); unsigned precise_accesschain_index = std::stoul(precise_accesschain_index_str);
// Gets the node pointed by the accesschain index extracted before. // Gets the node pointed by the accesschain index extracted before.
glslang::TIntermTyped *potential_precise_node = glslang::TIntermTyped* potential_precise_node =
node->getSequence()[precise_accesschain_index]->getAsTyped(); node->getSequence()[precise_accesschain_index]->getAsTyped();
assert(potential_precise_node); assert(potential_precise_node);
// Pop the front accesschain index from the path, and visit the nested node. // Pop the front accesschain index from the path, and visit the nested node.
@ -700,16 +707,9 @@ protected:
&remained_accesschain_, next_level_accesschain); &remained_accesschain_, next_level_accesschain);
potential_precise_node->traverse(this); potential_precise_node->traverse(this);
} }
return false;
} else {
// If this is not a struct constructor, just visit each nested node.
glslang::TIntermSequence &seq = node->getSequence();
for (int i = 0; i < (int)seq.size(); ++i) {
seq[i]->traverse(this);
}
} }
return true;
return false;
} }
// Visit a binary node. A binary node can be an object node, e.g. a dereference node. // Visit a binary node. A binary node can be an object node, e.g. a dereference node.
@ -718,7 +718,7 @@ protected:
// an object node. If the binary node does not represent an object node, it should // an object node. If the binary node does not represent an object node, it should
// go on to traverse its children nodes and if it is an arithmetic operation node, this // go on to traverse its children nodes and if it is an arithmetic operation node, this
// operation should be marked as 'noContraction'. // operation should be marked as 'noContraction'.
bool visitBinary(glslang::TVisit, glslang::TIntermBinary *node) override bool visitBinary(glslang::TVisit, glslang::TIntermBinary* node) override
{ {
if (isDereferenceOperation(node->getOp())) { if (isDereferenceOperation(node->getOp())) {
// This binary node is an object node. Need to update the precise // This binary node is an object node. Need to update the precise
@ -728,8 +728,7 @@ protected:
if (remained_accesschain_.empty()) { if (remained_accesschain_.empty()) {
node->getWritableType().getQualifier().noContraction = true; node->getWritableType().getQualifier().noContraction = true;
} else { } else {
new_precise_accesschain += new_precise_accesschain += OBJECT_ACCESSCHAIN_DELIMITER + remained_accesschain_;
StructAccessChainDelimiter + remained_accesschain_;
} }
// Cache the accesschain as added precise object, so we won't add the // Cache the accesschain as added precise object, so we won't add the
// same object to the worklist again. // same object to the worklist again.
@ -746,21 +745,18 @@ protected:
node->getWritableType().getQualifier().noContraction = true; node->getWritableType().getQualifier().noContraction = true;
} }
// As this node is not an object node, need to traverse the children nodes. // As this node is not an object node, need to traverse the children nodes.
node->getLeft()->traverse(this); return true;
node->getRight()->traverse(this);
return false;
} }
// Visits an unary node. An unary node can not be an object node. If the operation // Visits an unary node. An unary node can not be an object node. If the operation
// is an arithmetic operation, need to mark this node as 'noContraction'. // is an arithmetic operation, need to mark this node as 'noContraction'.
bool visitUnary(glslang::TVisit /* visit */, glslang::TIntermUnary *node) override bool visitUnary(glslang::TVisit /* visit */, glslang::TIntermUnary* node) override
{ {
// If this is an arithmetic operation, marks this with 'noContraction' // If this is an arithmetic operation, marks this with 'noContraction'
if (isArithmeticOperation(node->getOp())) { if (isArithmeticOperation(node->getOp())) {
node->getWritableType().getQualifier().noContraction = true; node->getWritableType().getQualifier().noContraction = true;
} }
node->getOperand()->traverse(this); return true;
return false;
} }
// Visits a symbol node. A symbol node is always an object node. So we // Visits a symbol node. A symbol node is always an object node. So we
@ -768,7 +764,7 @@ protected:
// nodes to accesschains. As an object node, a symbol node can be either // nodes to accesschains. As an object node, a symbol node can be either
// 'precise' or containing 'precise' objects according to unused // 'precise' or containing 'precise' objects according to unused
// accesschain information we have when we visit this node. // accesschain information we have when we visit this node.
void visitSymbol(glslang::TIntermSymbol *node) override void visitSymbol(glslang::TIntermSymbol* node) override
{ {
// Symbol nodes are object nodes and should always have an // Symbol nodes are object nodes and should always have an
// accesschain collected before matches with it. // accesschain collected before matches with it.
@ -781,7 +777,7 @@ protected:
if (remained_accesschain_.empty()) { if (remained_accesschain_.empty()) {
node->getWritableType().getQualifier().noContraction = true; node->getWritableType().getQualifier().noContraction = true;
} else { } else {
new_precise_accesschain += StructAccessChainDelimiter + remained_accesschain_; new_precise_accesschain += OBJECT_ACCESSCHAIN_DELIMITER + remained_accesschain_;
} }
// Add the new 'precise' accesschain to the worklist and make sure we // Add the new 'precise' accesschain to the worklist and make sure we
// don't visit it again. // don't visit it again.
@ -792,7 +788,7 @@ protected:
} }
// A set of precise objects, represented as accesschains. // A set of precise objects, represented as accesschains.
ObjectAccesschainSet &precise_objects_; ObjectAccesschainSet& precise_objects_;
// Visited symbol nodes, should not revisit these nodes. // Visited symbol nodes, should not revisit these nodes.
ObjectAccesschainSet added_precise_object_ids_; ObjectAccesschainSet added_precise_object_ids_;
// The left node of an assignment operation might be an parent of 'precise' objects. // The left node of an assignment operation might be an parent of 'precise' objects.
@ -802,15 +798,13 @@ protected:
// tell us how to find the corresponding 'precise' node in the right. // tell us how to find the corresponding 'precise' node in the right.
ObjectAccessChain remained_accesschain_; ObjectAccessChain remained_accesschain_;
// A map from node pointers to their accesschains. // A map from node pointers to their accesschains.
const AccessChainMapping &accesschain_mapping_; const AccessChainMapping& accesschain_mapping_;
}; };
#undef StructAccessChainDelimiter
} }
namespace glslang { namespace glslang {
void PropagateNoContraction(const glslang::TIntermediate &intermediate) void PropagateNoContraction(const glslang::TIntermediate& intermediate)
{ {
// First, traverses the AST, records symbols with their defining operations // First, traverses the AST, records symbols with their defining operations
// and collects the initial set of precise symbols (symbol nodes that marked // and collects the initial set of precise symbols (symbol nodes that marked
@ -821,18 +815,17 @@ void PropagateNoContraction(const glslang::TIntermediate &intermediate)
// The mapping of symbol node IDs to their defining nodes. This enables us // The mapping of symbol node IDs to their defining nodes. This enables us
// to get the defining node directly from a given symbol ID without // to get the defining node directly from a given symbol ID without
// traversing the tree again. // traversing the tree again.
NodeMapping &symbol_definition_mapping = std::get<0>(mappings_and_precise_objects); NodeMapping& symbol_definition_mapping = std::get<0>(mappings_and_precise_objects);
// The mapping of object nodes to their accesschains recorded. // The mapping of object nodes to their accesschains recorded.
AccessChainMapping &accesschain_mapping = std::get<1>(mappings_and_precise_objects); AccessChainMapping& accesschain_mapping = std::get<1>(mappings_and_precise_objects);
// The initial set of 'precise' objects which are represented as the // The initial set of 'precise' objects which are represented as the
// accesschain toward them. // accesschain toward them.
ObjectAccesschainSet &precise_object_accesschains = ObjectAccesschainSet& precise_object_accesschains = std::get<2>(mappings_and_precise_objects);
std::get<2>(mappings_and_precise_objects);
// The set of 'precise' return nodes. // The set of 'precise' return nodes.
ReturnBranchNodeSet &precise_return_nodes = std::get<3>(mappings_and_precise_objects); ReturnBranchNodeSet& precise_return_nodes = std::get<3>(mappings_and_precise_objects);
// Second, uses the initial set of precise objects as a worklist, pops an // Second, uses the initial set of precise objects as a worklist, pops an
// accesschain, extract the symbol ID from it. Then: // accesschain, extract the symbol ID from it. Then:
@ -845,10 +838,9 @@ void PropagateNoContraction(const glslang::TIntermediate &intermediate)
// 'precise' accesschain worklist with new found object nodes. // 'precise' accesschain worklist with new found object nodes.
// Repeat above steps until the worklist is empty. // Repeat above steps until the worklist is empty.
TNoContractionAssigneeCheckingTraverser checker(accesschain_mapping); TNoContractionAssigneeCheckingTraverser checker(accesschain_mapping);
TNoContractionPropagator propagator(&precise_object_accesschains, TNoContractionPropagator propagator(&precise_object_accesschains, accesschain_mapping);
accesschain_mapping);
// We have to initial precise worklist to handle: // We have two initial precise worklists to handle:
// 1) precise return nodes // 1) precise return nodes
// 2) precise object accesschains // 2) precise object accesschains
// We should process the precise return nodes first and the involved // We should process the precise return nodes first and the involved
@ -877,12 +869,12 @@ void PropagateNoContraction(const glslang::TIntermediate &intermediate)
// objects, and mark arithmetic operations as 'noContraction'. // objects, and mark arithmetic operations as 'noContraction'.
for (NodeMapping::iterator defining_node_iter = range.first; for (NodeMapping::iterator defining_node_iter = range.first;
defining_node_iter != range.second; defining_node_iter++) { defining_node_iter != range.second; defining_node_iter++) {
TIntermOperator *defining_node = defining_node_iter->second; TIntermOperator* defining_node = defining_node_iter->second;
// Check the assignee node. // Check the assignee node.
auto checker_result = checker.getPrecisenessAndRemainedAccessChain( auto checker_result = checker.getPrecisenessAndRemainedAccessChain(
defining_node, precise_object_accesschain); defining_node, precise_object_accesschain);
bool &contain_precise = std::get<0>(checker_result); bool& contain_precise = std::get<0>(checker_result);
ObjectAccessChain &remained_accesschain = std::get<1>(checker_result); ObjectAccessChain& remained_accesschain = std::get<1>(checker_result);
// If the assignee node is 'precise' or contains 'precise', propagate the // If the assignee node is 'precise' or contains 'precise', propagate the
// 'precise' to the right. Otherwise just skip this assignment node. // 'precise' to the right. Otherwise just skip this assignment node.
if (contain_precise) { if (contain_precise) {