glslang-zig/glslang/MachineIndependent/propagateNoContraction.cpp
qining 9220dbb078 Precise and noContraction propagation
Reimplement the whole workflow so that the precise'ness of struct
    members won't spread to other non-precise members of the same struct
    instance.

    Approach:
    1. Build the map from symbols to their defining nodes. For each
    object node (StructIndex, DirectIndex, Symbol nodes, etc.), generate an
    accesschain path. Different AST nodes that indicate the same object
    should have the same accesschain path.

    2. Along the building phase in step 1, collect the initial set of
    'precise' (AST qualifier: 'noContraction') objects' accesschain paths.

    3. Start with the initial set of 'precise' accesschain paths, use it as
    a worklist, do as the following steps until the worklist is empty:

        1) Pop an accesschain path from worklist.
        2) Get the symbol part from the accesschain path.
        3) Find the defining nodes of that symbol.
        4) For each defining node, check whether it is defining a 'precise'
        object, or its assignee has nested 'precise' object. Get the
        incremental path from assignee to its nested 'precise' object (if
        any).
        5) Traverse the right side of the defining node, obtain the
        accesschain paths of the corresponding involved 'precise' objects.
        Update the worklist with those new objects' accesschain paths.
        Label involved operations with 'noContraction'.

    In each step, whenever we find that the parent object of a nested object
    is 'precise' (has the 'noContraction' qualifier), we let the nested object
    inherit the 'precise'ness from its parent object.
2016-05-09 10:46:40 -04:00

886 lines
40 KiB
C++

//
// Copyright (C) 2015-2016 Google, Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following
// disclaimer in the documentation and/or other materials provided
// with the distribution.
//
// Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Visit the nodes in the glslang intermediate tree representation to
// propagate 'noContraction' qualifier.
//
#include "propagateNoContraction.h"
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include "localintermediate.h"
namespace {
// Use a string to hold the accesschain information, as in most cases the
// accesschain is short and may contain only one element, which is the symbol
// label.
using ObjectAccessChain = std::string;
// Delimiter placed between the elements of an accesschain string. It can be
// overridden by defining the macro before this point.
#ifndef StructAccessChainDelimiter
#define StructAccessChainDelimiter '/'
#endif
// Mapping from symbol labels (the front element of an accesschain) to the
// operation nodes that define, i.e. assign to, those symbols. A symbol can
// have several defining nodes, hence the multimap.
using NodeMapping = std::unordered_multimap<ObjectAccessChain, glslang::TIntermOperator *>;
// Mapping from object nodes to their accesschain info string.
using AccessChainMapping = std::unordered_map<glslang::TIntermTyped *, ObjectAccessChain>;
// Set of object accesschains.
using ObjectAccesschainSet = std::unordered_set<ObjectAccessChain>;
// Set of return branch nodes.
using ReturnBranchNodeSet = std::unordered_set<glslang::TIntermBranch*>;
// Reports whether the type of the given node carries the 'noContraction'
// qualifier, i.e. whether the node is a 'precise' object node.
bool isPreciseObjectNode(glslang::TIntermTyped *node)
{
    const bool has_no_contraction = node->getType().getQualifier().noContraction;
    return has_no_contraction;
}
// Tells whether the opcode is one of the dereferencing opcodes: direct index,
// direct struct index, indirect index or vector swizzle.
bool isDereferenceOperation(glslang::TOperator op)
{
    return op == glslang::EOpIndexDirect || op == glslang::EOpIndexDirectStruct ||
           op == glslang::EOpIndexIndirect || op == glslang::EOpVectorSwizzle;
}
// Tells whether the opcode writes to its left operand (or its only operand):
// plain assignment, compound assignments, and pre/post increment/decrement.
bool isAssignOperation(glslang::TOperator op)
{
    bool is_assignment = false;
    switch (op) {
    // Plain and compound assignments.
    case glslang::EOpAssign:
    case glslang::EOpAddAssign:
    case glslang::EOpSubAssign:
    case glslang::EOpMulAssign:
    case glslang::EOpVectorTimesMatrixAssign:
    case glslang::EOpVectorTimesScalarAssign:
    case glslang::EOpMatrixTimesScalarAssign:
    case glslang::EOpMatrixTimesMatrixAssign:
    case glslang::EOpDivAssign:
    case glslang::EOpModAssign:
    case glslang::EOpAndAssign:
    case glslang::EOpLeftShiftAssign:
    case glslang::EOpRightShiftAssign:
    case glslang::EOpInclusiveOrAssign:
    case glslang::EOpExclusiveOrAssign:
    // Implicit assignments such as i++ and --i.
    case glslang::EOpPostIncrement:
    case glslang::EOpPostDecrement:
    case glslang::EOpPreIncrement:
    case glslang::EOpPreDecrement:
        is_assignment = true;
        break;
    default:
        break;
    }
    return is_assignment;
}
// Extracts the unsigned value stored in the first slot of a constant union
// node. The node is expected to be a scalar constant union; this is only
// checked by an assertion.
unsigned getStructIndexFromConstantUnion(glslang::TIntermTyped *node)
{
    auto *constant_node = node->getAsConstantUnion();
    assert(constant_node && constant_node->isScalar());
    return constant_node->getConstArray()[0].getUConst();
}
// Builds the label that identifies a symbol inside accesschains. The label
// has the form "<id>(<name>)".
ObjectAccessChain generateSymbolLabel(glslang::TIntermSymbol *node)
{
    ObjectAccessChain label = std::to_string(node->getId());
    label += "(";
    label += node->getName().c_str();
    label += ")";
    return label;
}
// Tells whether the opcode is an arithmetic operation, i.e. one that is
// valid for the 'NoContraction' decoration.
bool isArithmeticOperation(glslang::TOperator op)
{
    bool is_arithmetic = false;
    switch (op) {
    // Compound arithmetic assignments.
    case glslang::EOpAddAssign:
    case glslang::EOpSubAssign:
    case glslang::EOpMulAssign:
    case glslang::EOpVectorTimesMatrixAssign:
    case glslang::EOpVectorTimesScalarAssign:
    case glslang::EOpMatrixTimesScalarAssign:
    case glslang::EOpMatrixTimesMatrixAssign:
    case glslang::EOpDivAssign:
    case glslang::EOpModAssign:
    // Plain arithmetic.
    case glslang::EOpNegative:
    case glslang::EOpAdd:
    case glslang::EOpSub:
    case glslang::EOpMul:
    case glslang::EOpDiv:
    case glslang::EOpMod:
    case glslang::EOpVectorTimesScalar:
    case glslang::EOpVectorTimesMatrix:
    case glslang::EOpMatrixTimesVector:
    case glslang::EOpMatrixTimesScalar:
    case glslang::EOpDot:
    case glslang::EOpAddCarry:
    case glslang::EOpSubBorrow:
    case glslang::EOpUMulExtended:
    case glslang::EOpIMulExtended:
    // Increments and decrements.
    case glslang::EOpPostIncrement:
    case glslang::EOpPostDecrement:
    case glslang::EOpPreIncrement:
    case glslang::EOpPreDecrement:
        is_arithmetic = true;
        break;
    default:
        break;
    }
    return is_arithmetic;
}
// A scope guard for a piece of state of type T. On construction it remembers
// the current value of the pointed-to state (and optionally overwrites it);
// on destruction it writes the remembered value back.
//
// The guard is not copyable: a copy would write the remembered value back a
// second time when it is destroyed, possibly clobbering a later update.
template <typename T> class StateSettingGuard {
public:
    // Remembers *state_ptr and then sets it to new_state_value.
    StateSettingGuard(T *state_ptr, T new_state_value)
        : state_ptr_(state_ptr), previous_state_(*state_ptr)
    {
        *state_ptr = new_state_value;
    }
    // Remembers *state_ptr without modifying it. Use setState() to change the
    // state later within the guarded scope.
    explicit StateSettingGuard(T *state_ptr)
        : state_ptr_(state_ptr), previous_state_(*state_ptr) {}
    StateSettingGuard(const StateSettingGuard &) = delete;
    StateSettingGuard &operator=(const StateSettingGuard &) = delete;
    // Overwrites the guarded state; the remembered value is unaffected.
    void setState(T new_state_value)
    {
        *state_ptr_ = new_state_value;
    }
    // Restores the value the state had when the guard was constructed.
    ~StateSettingGuard() { *state_ptr_ = previous_state_; }
private:
    T *state_ptr_;
    T previous_state_;
};
// Returns the part of the accesschain that precedes the first delimiter, or
// the whole accesschain when it contains no delimiter.
ObjectAccessChain getFrontElement(const ObjectAccessChain &chain)
{
    const size_t first_delimiter = chain.find(StructAccessChainDelimiter);
    if (first_delimiter == std::string::npos) {
        return chain;
    }
    return chain.substr(0, first_delimiter);
}
// Returns the part of the accesschain that follows the first delimiter, or an
// empty accesschain when it contains no delimiter.
ObjectAccessChain subAccessChainFromSecondElement(const ObjectAccessChain& chain)
{
    const size_t first_delimiter = chain.find(StructAccessChainDelimiter);
    if (first_delimiter == std::string::npos) {
        return ObjectAccessChain();
    }
    return chain.substr(first_delimiter + 1);
}
//
// A traverser which traverses the whole AST and populates:
// 1) A mapping from symbol labels to their defining operation nodes.
// 2) A mapping from visited object nodes to their accesschains.
// 3) A set of accesschains of the initial precise object nodes.
// 4) A set of return branch nodes that belong to functions whose return type
// is 'noContraction'.
//
// All four containers are owned by the caller and are passed to the
// constructor as pointers.
//
class TSymbolDefinitionCollectingTraverser : public glslang::TIntermTraverser {
public:
TSymbolDefinitionCollectingTraverser(
NodeMapping *symbol_definition_mapping, AccessChainMapping *accesschain_mapping,
ObjectAccesschainSet *precise_objects,
ReturnBranchNodeSet *precise_return_nodes);
// Records unary assignments (e.g. i++) as defining nodes of their operand.
bool visitUnary(glslang::TVisit, glslang::TIntermUnary *) override;
// Records binary assignments as defining nodes, and builds accesschains for
// dereference nodes.
bool visitBinary(glslang::TVisit, glslang::TIntermBinary *) override;
// Starts a new accesschain with the label of the visited symbol.
void visitSymbol(glslang::TIntermSymbol *) override;
// Tracks the enclosing function definition and traverses the children.
bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate *) override;
// Collects return nodes of functions with a 'noContraction' return type.
bool visitBranch(glslang::TVisit, glslang::TIntermBranch *) override;
protected:
// The mapping from symbol labels to their defining nodes. This should be
// populated along traversing the AST.
NodeMapping &symbol_definition_mapping_;
// The set of accesschains of precise objects, the ones marked as
// 'noContraction'.
ObjectAccesschainSet &precise_objects_;
// The set of precise return nodes.
ReturnBranchNodeSet &precise_return_nodes_;
// A temporary cache of the accesschain of the object whose defining node is
// currently being looked for along the traversal of the AST.
ObjectAccessChain object_to_be_defined_;
// A map from object node to its accesschain. This traverser stores
// the built accesschains into this map for each object node it has
// visited.
AccessChainMapping &accesschain_mapping_;
// The pointer to the Function Definition node, so we can get the
// precise'ness of the return expression from it when we traverse the
// return branch node.
glslang::TIntermAggregate* current_function_definition_node_;
};
// Binds the traverser to the caller-owned result containers. None of the
// pointers may be null, as each one is dereferenced here to initialize a
// reference member.
//
// The member initializers are listed in the order in which the members are
// declared in the class, which is the order in which they are actually
// initialized.
TSymbolDefinitionCollectingTraverser::TSymbolDefinitionCollectingTraverser(
    NodeMapping *symbol_definition_mapping, AccessChainMapping *accesschain_mapping,
    ObjectAccesschainSet *precise_objects,
    ReturnBranchNodeSet *precise_return_nodes)
    : TIntermTraverser(true, false, false),
      symbol_definition_mapping_(*symbol_definition_mapping),
      precise_objects_(*precise_objects),
      precise_return_nodes_(*precise_return_nodes),
      object_to_be_defined_(),
      accesschain_mapping_(*accesschain_mapping),
      current_function_definition_node_(nullptr) {}
// Visits a symbol node: generates the label of the symbol, records it as the
// accesschain of this node, and makes it the accesschain currently under
// construction (object_to_be_defined_).
void TSymbolDefinitionCollectingTraverser::visitSymbol(glslang::TIntermSymbol *node)
{
    const ObjectAccessChain symbol_label = generateSymbolLabel(node);
    accesschain_mapping_[node] = symbol_label;
    object_to_be_defined_ = symbol_label;
}
// Visits an aggregate node and traverses each of its children with a fresh
// accesschain. Always returns false, as the children are traversed here.
bool TSymbolDefinitionCollectingTraverser::visitAggregate(glslang::TVisit,
                                                          glslang::TIntermAggregate *node)
{
    // The guard restores the previously cached function definition node when
    // this call returns.
    StateSettingGuard<glslang::TIntermAggregate *> function_node_guard(
        &current_function_definition_node_);
    const bool is_function_definition = node->getOp() == glslang::EOpFunction;
    if (is_function_definition) {
        // Cache the function definition node so that the precise'ness of its
        // return value is available when its return branches are visited.
        function_node_guard.setState(node);
    }

    glslang::TIntermSequence &children = node->getSequence();
    for (size_t child = 0; child < children.size(); ++child) {
        object_to_be_defined_.clear();
        children[child]->traverse(this);
    }
    return false;
}
// Visits a branch node. A return node with an expression, inside a function
// whose return type is 'noContraction', is recorded as a precise return node
// and its expression is traversed. Every other branch node is skipped.
bool TSymbolDefinitionCollectingTraverser::visitBranch(glslang::TVisit,
                                                       glslang::TIntermBranch *node)
{
    const bool returns_value =
        node->getFlowOp() == glslang::EOpReturn && node->getExpression();
    if (!returns_value) {
        return false;
    }
    if (!current_function_definition_node_ ||
        !current_function_definition_node_->getType().getQualifier().noContraction) {
        return false;
    }
    // The enclosing function has a precise return value: remember this return
    // node and collect the objects involved in its expression.
    precise_return_nodes_.insert(node);
    node->getExpression()->traverse(this);
    return false;
}
// Visits a unary node, which may be an implicit assignment such as i++ or
// i--. For such a node, the operand's symbol is mapped to this node as one of
// its definitions, and a 'precise' operand is added to the initial set of
// precise objects.
bool TSymbolDefinitionCollectingTraverser::visitUnary(glslang::TVisit /* visit */,
                                                      glslang::TIntermUnary *node)
{
    object_to_be_defined_.clear();
    auto *operand = node->getOperand();
    operand->traverse(this);

    // An accesschain for the operand should always be available, but some
    // tests intentionally contain invalid operand nodes; those are skipped.
    if (isAssignOperation(node->getOp()) && !object_to_be_defined_.empty()) {
        if (isPreciseObjectNode(operand)) {
            // The operand is declared 'precise': it belongs to the initial
            // set of 'precise' objects.
            precise_objects_.insert(object_to_be_defined_);
        }
        // The front element of the accesschain is the symbol label; map it to
        // this assignment node.
        symbol_definition_mapping_.insert(
            std::make_pair(getFrontElement(object_to_be_defined_), node));
    }

    // A unary node is not a dereference node, so the accesschain under
    // construction ends here.
    object_to_be_defined_.clear();
    return false;
}
// Visits a binary node. Three kinds of nodes are distinguished:
//  - assignments: the assignee's symbol is mapped to this node, a 'precise'
//    assignee is added to the initial precise objects, and the right side is
//    traversed for further assignments;
//  - dereferences: the accesschain of this object node is recorded;
//  - anything else: the right side is traversed with a fresh accesschain.
bool TSymbolDefinitionCollectingTraverser::visitBinary(glslang::TVisit /* visit */,
                                                       glslang::TIntermBinary *node)
{
    // The left subtree builds the accesschain of the object in
    // object_to_be_defined_.
    object_to_be_defined_.clear();
    auto *left = node->getLeft();
    left->traverse(this);
    const glslang::TOperator op = node->getOp();

    if (isAssignOperation(op)) {
        // An accesschain for the left node should always be available, but
        // some tests intentionally contain invalid left nodes; skip those.
        if (object_to_be_defined_.empty()) {
            return false;
        }
        if (isPreciseObjectNode(left)) {
            // The assignee is declared 'precise' in the shader source, so it
            // is part of the initial worklist.
            precise_objects_.insert(object_to_be_defined_);
        }
        // The front element of the accesschain is the symbol label; map it to
        // this assignment node.
        symbol_definition_mapping_.insert(
            std::make_pair(getFrontElement(object_to_be_defined_), node));
        // The right side may contain other assignment operations.
        object_to_be_defined_.clear();
        node->getRight()->traverse(this);
        return false;
    }

    if (!isDereferenceOperation(op)) {
        // Any other binary node: still traverse the right side.
        object_to_be_defined_.clear();
        node->getRight()->traverse(this);
        return false;
    }

    // Dereference node. The right child is not traversed.
    if (isPreciseObjectNode(left)) {
        // A member of a 'precise' object is 'precise' too, and so are all of
        // its own members; no index needs to be appended to the accesschain.
        node->getWritableType().getQualifier().noContraction = true;
    } else if (op == glslang::EOpIndexDirectStruct) {
        // The parent is a non-'precise' struct: the member index becomes part
        // of the accesschain so that members can be told apart. For the other
        // dereference opcodes the accesschain of the parent is reused as is.
        const unsigned member_index = getStructIndexFromConstantUnion(node->getRight());
        object_to_be_defined_.push_back(StructAccessChainDelimiter);
        object_to_be_defined_.append(std::to_string(member_index));
    }
    accesschain_mapping_[node] = object_to_be_defined_;
    return false;
}
// Traverses the AST and returns a tuple of four members:
// 1) a mapping from symbol labels to the definition nodes (aka. assignment nodes) of these symbols.
// 2) a mapping from object nodes in the AST to the accesschains of these objects.
// 3) a set of accesschains of precise objects.
// 4) a set of return nodes that belong to functions with a precise return value.
// All four members are empty when the intermediate has no tree root.
std::tuple<NodeMapping, AccessChainMapping, ObjectAccesschainSet, ReturnBranchNodeSet>
getSymbolToDefinitionMappingAndPreciseSymbolIDs(const glslang::TIntermediate &intermediate)
{
    auto result_tuple = std::make_tuple(NodeMapping(), AccessChainMapping(),
                                        ObjectAccesschainSet(), ReturnBranchNodeSet());
    TIntermNode *root = intermediate.getTreeRoot();
    if (root == nullptr) {
        return result_tuple;
    }
    // The collector writes directly into the members of the result tuple.
    TSymbolDefinitionCollectingTraverser collector(
        &std::get<0>(result_tuple), &std::get<1>(result_tuple),
        &std::get<2>(result_tuple), &std::get<3>(result_tuple));
    root->traverse(&collector);
    return result_tuple;
}
//
// A traverser that determines whether the left node (or operand node for a
// unary node) of an assignment node is 'precise', contains 'precise' objects,
// or neither, according to the accesschain of a given precise object which
// shares the same symbol as the left node.
//
// Traverses the left node subtree of a binary assignment node and:
//
// 1) Propagates the 'precise' from the left object nodes to this object node.
//
// 2) Builds the object accesschain along the traversal, and also compares it
// with the accesschain of the given 'precise' object along the traversal to
// tell if the node to be defined is 'precise' or not.
//
class TNoContractionAssigneeCheckingTraverser : public glslang::TIntermTraverser {
    enum DecisionStatus {
        // The object node to be assigned to may contain 'precise' objects and
        // also non-'precise' objects.
        Mixed = 0,
        // The object node to be assigned to is either a 'precise' object or a
        // struct object whose members are all 'precise'.
        Precise = 1,
        // The object node to be assigned to is not a 'precise' object. (The
        // spelling of this enumerator is kept as is, because the member
        // functions defined outside of the class body refer to it.)
        NotPreicse = 2,
    };
public:
    TNoContractionAssigneeCheckingTraverser()
        : TIntermTraverser(true, false, false), accesschain_to_precise_object_(),
          decision_(Mixed) {}
    // Checks the precise'ness of a given assignment node with a precise object
    // represented as accesschain. The precise object shares the same symbol
    // with the assignee of the given assignment node. Returns a tuple of two:
    //
    // 1) The precise'ness of the assignee node of this assignment node. True
    // if the assignee contains 'precise' objects or is 'precise', false if
    // the assignee is not 'precise' according to the accesschain of the given
    // precise object.
    //
    // 2) The incremental accesschain from the assignee node to its nested
    // 'precise' object, according to the accesschain of the given precise
    // object. This incremental accesschain can be empty, which means the
    // assignee is 'precise'. Otherwise it shows the path to the nested
    // precise object.
    std::tuple<bool, ObjectAccessChain>
    getPrecisenessAndRemainedAccessChain(glslang::TIntermOperator* node,
                                         const ObjectAccessChain &precise_object)
    {
        assert(isAssignOperation(node->getOp()));
        // Reset the per-query state before traversing the assignment node.
        accesschain_to_precise_object_ = precise_object;
        decision_ = Mixed;
        node->traverse(this);
        return std::make_tuple(decision_ != NotPreicse, accesschain_to_precise_object_);
    }
protected:
    bool visitBinary(glslang::TVisit, glslang::TIntermBinary *node) override;
    bool visitUnary(glslang::TVisit, glslang::TIntermUnary *node) override;
    void visitSymbol(glslang::TIntermSymbol *node) override;
    // The accesschain toward the given precise object. It is initialized
    // with the accesschain of a given precise object, then trimmed along the
    // traversal of the assignee subtree. The remaining accesschain at the end
    // of the traversal shows the path from the assignee node to its nested
    // 'precise' object. If the assignee node is a 'precise' object, this
    // should be empty.
    ObjectAccessChain accesschain_to_precise_object_;
    // A state to tell the precise'ness of the assignee node according to the
    // accesschain of the given precise object:
    //
    // 'Mixed': contains both 'precise' and 'non-precise' objects
    // (accesschain_to_precise_object_ is not empty),
    //
    // 'Precise': is a precise object (accesschain_to_precise_object_ is empty),
    //
    // 'NotPreicse': is not a precise object (mismatch in the struct
    // dereference indices).
    DecisionStatus decision_;
};
// Visits a binary node. Only the left branch is traversed, as this traverser
// checks the precise'ness of the assignee of an assignment. For a struct
// member dereference, the member index found in the right node is compared
// with the front element of the remaining accesschain to the precise object.
bool TNoContractionAssigneeCheckingTraverser::visitBinary(glslang::TVisit,
                                                          glslang::TIntermBinary *node)
{
    node->getLeft()->traverse(this);

    // Only dereference nodes need any checking.
    if (!isDereferenceOperation(node->getOp())) {
        return false;
    }

    if (isPreciseObjectNode(node->getLeft())) {
        // The parent object is 'precise', hence so is the object represented
        // by this node; the accesschain does not need to be checked.
        node->getWritableType().getQualifier().noContraction = true;
        decision_ = Precise;
        return false;
    }

    // Struct member indices are only compared while the decision is open.
    if (node->getOp() != glslang::EOpIndexDirectStruct || decision_ != Mixed) {
        return false;
    }

    const std::string member_index =
        std::to_string(getStructIndexFromConstantUnion(node->getRight()));
    if (getFrontElement(accesschain_to_precise_object_) != member_index) {
        // This node dereferences a different member than the one on the path
        // to the precise object, so it must not be labelled 'precise'.
        decision_ = NotPreicse;
        return false;
    }

    // The index matches: consume it from the front of the accesschain.
    accesschain_to_precise_object_ =
        subAccessChainFromSecondElement(accesschain_to_precise_object_);
    if (accesschain_to_precise_object_.empty()) {
        // Nothing is left of the path, so this node is the precise object.
        node->getWritableType().getQualifier().noContraction = true;
        decision_ = Precise;
    }
    return false;
}
// Visits a unary node and traverses its operand. For an assignment node
// (e.g. i++), the decision follows directly from the precise'ness of the
// operand node.
bool TNoContractionAssigneeCheckingTraverser::visitUnary(glslang::TVisit,
                                                         glslang::TIntermUnary *node)
{
    node->getOperand()->traverse(this);
    if (!isAssignOperation(node->getOp())) {
        return false;
    }
    if (!isPreciseObjectNode(node->getOperand())) {
        decision_ = NotPreicse;
        return false;
    }
    // The assignee is 'precise', and so are all of its members (if any):
    // there is no remaining accesschain to propagate.
    decision_ = Precise;
    accesschain_to_precise_object_.clear();
    return false;
}
// Visits a symbol node. The label of this symbol is expected to be the front
// element of the accesschain of the given 'precise' object; that element is
// consumed here.
void TNoContractionAssigneeCheckingTraverser::visitSymbol(glslang::TIntermSymbol *node)
{
    const ObjectAccessChain label_of_node = generateSymbolLabel(node);
    // The given accesschain must be rooted at the symbol visited here.
    assert(label_of_node == getFrontElement(accesschain_to_precise_object_));

    // Drop the symbol part from the front of the accesschain.
    accesschain_to_precise_object_ =
        subAccessChainFromSecondElement(accesschain_to_precise_object_);

    // Nothing left means the symbol itself is the precise object.
    const bool symbol_is_the_precise_object = accesschain_to_precise_object_.empty();
    if (symbol_is_the_precise_object) {
        node->getWritableType().getQualifier().noContraction = true;
    }
    // A 'precise' symbol makes all of its members 'precise', hence the
    // assignee of the assignment being processed is 'precise' as well.
    if (symbol_is_the_precise_object || isPreciseObjectNode(node)) {
        decision_ = Precise;
    }
}
//
// A traverser that only traverses the right side of binary assignment nodes
// and the operand node of unary assignment nodes.
//
// 1) Marks arithmetic operations 'NoContraction'.
//
// 2) Find the object which should be marked as 'precise' in the right and
// update the 'precise' object worklist.
//
class TNoContractionPropagator : public glslang::TIntermTraverser {
public:
TNoContractionPropagator(ObjectAccesschainSet *precise_objects,
const AccessChainMapping &accesschain_mapping)
: TIntermTraverser(true, false, false), remained_accesschain_(),
precise_objects_(*precise_objects),
accesschain_mapping_(accesschain_mapping), added_precise_object_ids_() {}
// Propagates 'precise' in the right nodes of a given assignment node with
// accesschain record from the assignee node to a 'precise' object it
// contains.
void
propagateNoContractionInOneExpression(glslang::TIntermTyped *defining_node,
const ObjectAccessChain &assignee_remained_accesschain)
{
remained_accesschain_ = assignee_remained_accesschain;
if (glslang::TIntermBinary *BN = defining_node->getAsBinaryNode()) {
assert(isAssignOperation(BN->getOp()));
BN->getRight()->traverse(this);
if (isArithmeticOperation(BN->getOp())) {
BN->getWritableType().getQualifier().noContraction = true;
}
} else if (glslang::TIntermUnary *UN = defining_node->getAsUnaryNode()) {
assert(isAssignOperation(UN->getOp()));
UN->getOperand()->traverse(this);
if (isArithmeticOperation(UN->getOp())) {
UN->getWritableType().getQualifier().noContraction = true;
}
}
}
// Propagates 'precise' in a given precise return node.
void
propagateNoContractionInReturnNode(glslang::TIntermBranch *return_node)
{
remained_accesschain_ = "";
assert(return_node->getFlowOp() == glslang::EOpReturn && return_node->getExpression());
return_node->getExpression()->traverse(this);
}
protected:
// Visit an aggregate node. The node can be a initializer list, in which
// case we need to find the 'precise' or 'precise' containing object node
// with the accesschain record. In other cases, just need to traverse all
// the children nodes.
bool visitAggregate(glslang::TVisit, glslang::TIntermAggregate *node) override
{
if (!remained_accesschain_.empty() && node->getOp() == glslang::EOpConstructStruct) {
// This is a struct initializer node, and the remained
// accesschain is not empty, we need to refer to the
// assignee_remained_access_chain_ to find the nested
// 'precise' object. And we don't need to visit other nodes in this
// aggreagate node.
// Gets the struct dereference index that leads to 'precise' object.
ObjectAccessChain precise_accesschain_index_str =
getFrontElement(remained_accesschain_);
unsigned precise_accesschain_index = std::stoul(precise_accesschain_index_str);
// Gets the node pointed by the accesschain index extracted before.
glslang::TIntermTyped *potential_precise_node =
node->getSequence()[precise_accesschain_index]->getAsTyped();
assert(potential_precise_node);
// Pop the front accesschain index from the path, and visit the nested node.
{
ObjectAccessChain next_level_accesschain =
subAccessChainFromSecondElement(remained_accesschain_);
StateSettingGuard<ObjectAccessChain> setup_remained_accesschain_for_next_level(
&remained_accesschain_, next_level_accesschain);
potential_precise_node->traverse(this);
}
} else {
// If this is not a struct constructor, just visit each nested node.
glslang::TIntermSequence &seq = node->getSequence();
for (int i = 0; i < (int)seq.size(); ++i) {
seq[i]->traverse(this);
}
}
return false;
}
// Visit a binary node. A binary node can be an object node, e.g. a dereference node.
// As only the top object nodes in the right side of an assignment needs to be visited
// and added to 'precise' worklist, this traverser won't visit the children nodes of
// an object node. If the binary node does not represent an object node, it should
// go on to traverse its children nodes and if it is an arithmetic operation node, this
// operation should be marked as 'noContraction'.
bool visitBinary(glslang::TVisit, glslang::TIntermBinary *node) override
{
if (isDereferenceOperation(node->getOp())) {
// This binary node is an object node. Need to update the precise
// object set with the accesschain of this node + remained
// accesschain .
ObjectAccessChain new_precise_accesschain = accesschain_mapping_.at(node);
if (remained_accesschain_.empty()) {
node->getWritableType().getQualifier().noContraction = true;
} else {
new_precise_accesschain +=
StructAccessChainDelimiter + remained_accesschain_;
}
// Cache the accesschain as added precise object, so we won't add the
// same object to the worklist again.
if (!added_precise_object_ids_.count(new_precise_accesschain)) {
precise_objects_.insert(new_precise_accesschain);
added_precise_object_ids_.insert(new_precise_accesschain);
}
// Only the upper-most object nodes should be visited, so do not
// visit children of this object node.
return false;
}
// If this is an arithmetic operation, marks this node as 'noContraction'.
if (isArithmeticOperation(node->getOp())) {
node->getWritableType().getQualifier().noContraction = true;
}
// As this node is not an object node, need to traverse the children nodes.
node->getLeft()->traverse(this);
node->getRight()->traverse(this);
return false;
}
// Visits an unary node. An unary node can not be an object node. If the operation
// is an arithmetic operation, need to mark this node as 'noContraction'.
bool visitUnary(glslang::TVisit /* visit */, glslang::TIntermUnary *node) override
{
// If this is an arithmetic operation, marks this with 'noContraction'
if (isArithmeticOperation(node->getOp())) {
node->getWritableType().getQualifier().noContraction = true;
}
node->getOperand()->traverse(this);
return false;
}
// Visits a symbol node. A symbol node is always an object node. So we
// should always be able to find its in our colected mapping from object
// nodes to accesschains. As an object node, a symbol node can be either
// 'precise' or containing 'precise' objects according to unused
// accesschain information we have when we visit this node.
void visitSymbol(glslang::TIntermSymbol *node) override
{
    // Every symbol node is expected to have had its accesschain recorded
    // before propagation starts.
    assert(accesschain_mapping_.count(node) != 0);
    ObjectAccessChain target_chain = accesschain_mapping_.at(node);

    const bool whole_symbol_is_precise = remained_accesschain_.empty();
    if (whole_symbol_is_precise) {
        // No leftover path: the symbol itself is the 'precise' object.
        node->getWritableType().getQualifier().noContraction = true;
    } else {
        // A leftover path exists: the 'precise' object is nested inside this
        // symbol, so extend the symbol's accesschain with that path.
        target_chain += StructAccessChainDelimiter + remained_accesschain_;
    }

    // Already queued once before: nothing more to do.
    if (added_precise_object_ids_.count(target_chain) != 0) {
        return;
    }

    // First time this accesschain is seen: queue it on the worklist and
    // remember it so it is never queued again.
    precise_objects_.insert(target_chain);
    added_precise_object_ids_.insert(target_chain);
}
// A set of precise objects, represented as accesschains.
ObjectAccesschainSet &precise_objects_;
// Accesschains that have already been added to the worklist; they should not be added again.
ObjectAccesschainSet added_precise_object_ids_;
// The left node of an assignment operation might be a parent of 'precise' objects.
// This means the left node might not be a 'precise' object node itself, but it may contain
// a 'precise' qualifier which should be propagated to the corresponding child node on
// the right. So we need the path from the left node to its nested 'precise' node to
// tell us how to find the corresponding 'precise' node on the right.
ObjectAccessChain remained_accesschain_;
// A map from node pointers to their accesschains.
const AccessChainMapping &accesschain_mapping_;
};
#undef StructAccessChainDelimiter
}
namespace glslang {
void PropagateNoContraction(const glslang::TIntermediate &intermediate)
{
    // Phase 1: one pass over the AST collects everything the propagation
    // needs. The result is unpacked by position below.
    auto collected = getSymbolToDefinitionMappingAndPreciseSymbolIDs(intermediate);

    // Symbol ID -> defining nodes, so definitions can be looked up without
    // walking the tree again.
    NodeMapping &definitions_by_symbol = std::get<0>(collected);
    // Object node -> accesschain recorded for it.
    AccessChainMapping &chain_of_node = std::get<1>(collected);
    // Accesschains of the initially 'precise' objects; doubles as the worklist.
    ObjectAccesschainSet &worklist = std::get<2>(collected);
    // Return nodes that are 'precise'.
    ReturnBranchNodeSet &pending_returns = std::get<3>(collected);

    // Phase 2 helpers: the checker inspects the assignee of a defining node,
    // the propagator marks operations as 'noContraction' and feeds newly
    // discovered 'precise' accesschains back into the worklist it was given.
    TNoContractionAssigneeCheckingTraverser assignee_checker;
    TNoContractionPropagator propagator(&worklist, chain_of_node);

    // 'Precise' return nodes are drained first, so that the objects involved
    // in their return expressions are on the worklist before it is processed.
    while (!pending_returns.empty()) {
        glslang::TIntermBranch *return_node = *pending_returns.begin();
        propagator.propagateNoContractionInReturnNode(return_node);
        pending_returns.erase(return_node);
    }

    // Drain the accesschain worklist until no 'precise' object is left.
    while (!worklist.empty()) {
        // Work on a copy: the entry is only erased from the worklist at the
        // end of this iteration, after the propagator has used the set.
        ObjectAccessChain current_chain = *worklist.begin();
        // The front element of the accesschain is the symbol ID.
        ObjectAccessChain owning_symbol = getFrontElement(current_chain);

        // Walk every node that defines this symbol.
        auto definitions = definitions_by_symbol.equal_range(owning_symbol);
        for (auto def_iter = definitions.first; def_iter != definitions.second; ++def_iter) {
            TIntermOperator *definition = def_iter->second;

            // Ask whether the assignee is 'precise' or contains a 'precise'
            // object, and obtain the remaining path from the assignee to
            // that nested object.
            auto verdict = assignee_checker.getPrecisenessAndRemainedAccessChain(
                definition, current_chain);
            if (!std::get<0>(verdict)) {
                // The assignee is unrelated to the current 'precise' object.
                continue;
            }

            // Push 'precise' into the right-hand side: arithmetic operations
            // get 'noContraction' and the top-level object nodes found there
            // are added to the worklist.
            propagator.propagateNoContractionInOneExpression(definition,
                                                             std::get<1>(verdict));
        }

        // This accesschain is fully processed.
        worklist.erase(current_chain);
    }
}
};