Change infrastructure to support constant folding across built-in functions, as required by 1.2 semantics. Partially fleshed out with min/max and some trig functions. Still have to complete all operations.

git-svn-id: https://cvs.khronos.org/svn/repos/ogl/trunk/ecosystem/public/sdk/tools/glslang@20806 e7fa87d3-cd2b-0410-9028-fcbf551c1848
This commit is contained in:
John Kessenich 2013-03-07 19:22:07 +00:00
parent 3f3e0ad3ad
commit 53fb465729
14 changed files with 737 additions and 481 deletions

View file

@ -43,8 +43,6 @@
#include "RemoveTree.h"
#include <float.h>
bool CompareStructure(const TType& leftNodeType, constUnion* rightUnionArray, constUnion* leftUnionArray);
////////////////////////////////////////////////////////////////////////////
//
// First set of functions are to help build the intermediate representation.
@ -221,7 +219,7 @@ TIntermTyped* TIntermediate::addUnaryMath(TOperator op, TIntermNode* childNode,
if (child->getType().getBasicType() == EbtStruct || child->getType().isArray())
return 0;
}
//
// Do we need to promote the operand?
//
@ -270,7 +268,7 @@ TIntermTyped* TIntermediate::addUnaryMath(TOperator op, TIntermNode* childNode,
return 0;
if (childTempConstant) {
TIntermTyped* newChild = childTempConstant->fold(op, 0, infoSink);
TIntermTyped* newChild = childTempConstant->fold(op, node->getType(), infoSink);
if (newChild)
return newChild;
@ -289,7 +287,7 @@ TIntermTyped* TIntermediate::addUnaryMath(TOperator op, TIntermNode* childNode,
// Returns an aggregate node, which could be the one passed in if
// it was already an aggregate.
//
TIntermAggregate* TIntermediate::setAggregateOperator(TIntermNode* node, TOperator op, TSourceLoc line)
TIntermTyped* TIntermediate::setAggregateOperator(TIntermNode* node, TOperator op, const TType& type, TSourceLoc line)
{
TIntermAggregate* aggNode;
@ -317,7 +315,9 @@ TIntermAggregate* TIntermediate::setAggregateOperator(TIntermNode* node, TOperat
if (line != 0)
aggNode->setLine(line);
return aggNode;
aggNode->setType(type);
return fold(aggNode);
}
//
@ -431,7 +431,7 @@ TIntermTyped* TIntermediate::addConversion(TOperator op, const TType& type, TInt
if (node->getAsConstantUnion()) {
return (promoteConstantUnion(promoteTo, node->getAsConstantUnion()));
return promoteConstantUnion(promoteTo, node->getAsConstantUnion());
} else {
//
// Add a new newNode for the conversion.
@ -822,6 +822,7 @@ bool TIntermOperator::isConstructor() const
{
return op > EOpConstructGuardStart && op < EOpConstructGuardEnd;
}
//
// Make sure the type of a unary operator is appropriate for its
// combination of operation and operand type.
@ -833,10 +834,13 @@ bool TIntermUnary::promote(TInfoSink&)
switch (op) {
case EOpLogicalNot:
if (operand->getBasicType() != EbtBool)
return false;
break;
case EOpBitwiseNot:
if (operand->getBasicType() != EbtInt)
if (operand->getBasicType() != EbtInt &&
operand->getBasicType() != EbtUint)
return false;
break;
case EOpNegative:
@ -844,22 +848,53 @@ bool TIntermUnary::promote(TInfoSink&)
case EOpPostDecrement:
case EOpPreIncrement:
case EOpPreDecrement:
if (operand->getBasicType() == EbtBool)
if (operand->getBasicType() != EbtInt &&
operand->getBasicType() != EbtUint &&
operand->getBasicType() != EbtFloat)
return false;
break;
// operators for built-ins are already type checked against their prototype
//
// Operators for built-ins are already type checked against their prototype.
// Special case the non-float ones, just so we don't give an error.
//
case EOpAny:
case EOpAll:
setType(TType(EbtBool));
return true;
case EOpVectorLogicalNot:
break;
case EOpLength:
setType(TType(EbtFloat, EvqTemporary, operand->getQualifier().precision));
return true;
case EOpTranspose:
setType(TType(operand->getType().getBasicType(), EvqTemporary, operand->getQualifier().precision, 0,
operand->getType().getMatrixRows(),
operand->getType().getMatrixCols()));
return true;
case EOpDeterminant:
setType(TType(operand->getType().getBasicType(), EvqTemporary, operand->getQualifier().precision));
return true;
default:
// TODO: functionality: uint/int versions of built-ins
// make sure all paths set the type
if (operand->getBasicType() != EbtFloat)
return false;
}
setType(operand->getType());
getTypePointer()->getQualifier().storage = EvqTemporary;
return true;
}
@ -1125,30 +1160,6 @@ bool TIntermBinary::promote(TInfoSink& infoSink)
return true;
}
bool CompareStruct(const TType& leftNodeType, constUnion* rightUnionArray, constUnion* leftUnionArray)
{
TTypeList* fields = leftNodeType.getStruct();
size_t structSize = fields->size();
int index = 0;
for (size_t j = 0; j < structSize; j++) {
int size = (*fields)[j].type->getObjectSize();
for (int i = 0; i < size; i++) {
if ((*fields)[j].type->getBasicType() == EbtStruct) {
if (!CompareStructure(*(*fields)[j].type, &rightUnionArray[index], &leftUnionArray[index]))
return false;
} else {
if (leftUnionArray[index] != rightUnionArray[index])
return false;
index++;
}
}
}
return true;
}
void TIntermTyped::propagatePrecision(TPrecisionQualifier newPrecision)
{
if (getQualifier().precision != EpqNone || (getBasicType() != EbtInt && getBasicType() != EbtFloat))
@ -1196,350 +1207,6 @@ void TIntermTyped::propagatePrecision(TPrecisionQualifier newPrecision)
// indexing?
}
bool CompareStructure(const TType& leftNodeType, constUnion* rightUnionArray, constUnion* leftUnionArray)
{
if (leftNodeType.isArray()) {
TType typeWithoutArrayness = leftNodeType;
typeWithoutArrayness.dereference();
int arraySize = leftNodeType.getArraySize();
for (int i = 0; i < arraySize; ++i) {
int offset = typeWithoutArrayness.getObjectSize() * i;
if (!CompareStruct(typeWithoutArrayness, &rightUnionArray[offset], &leftUnionArray[offset]))
return false;
}
} else
return CompareStruct(leftNodeType, rightUnionArray, leftUnionArray);
return true;
}
//
// The fold functions see if an operation on a constant can be done in place,
// without generating run-time code.
//
// Returns the node to keep using, which may or may not be the node passed in.
//
TIntermTyped* TIntermConstantUnion::fold(TOperator op, TIntermTyped* constantNode, TInfoSink& infoSink)
{
constUnion *unionArray = getUnionArrayPointer();
int objectSize = getType().getObjectSize();
if (constantNode) { // binary operations
TIntermConstantUnion *node = constantNode->getAsConstantUnion();
constUnion *rightUnionArray = node->getUnionArrayPointer();
TType returnType = getType();
if (getType().getBasicType() != node->getBasicType()) {
infoSink.info.message(EPrefixInternalError, "Constant folding basic types don't match", getLine());
return 0;
}
if (constantNode->getType().getObjectSize() == 1 && objectSize > 1) {
// for a case like float f = vec4(2,3,4,5) + 1.2;
rightUnionArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; ++i)
rightUnionArray[i] = *node->getUnionArrayPointer();
} else if (constantNode->getType().getObjectSize() > 1 && objectSize == 1) {
// for a case like float f = 1.2 + vec4(2,3,4,5);
rightUnionArray = node->getUnionArrayPointer();
unionArray = new constUnion[constantNode->getType().getObjectSize()];
for (int i = 0; i < constantNode->getType().getObjectSize(); ++i)
unionArray[i] = *getUnionArrayPointer();
returnType = node->getType();
objectSize = constantNode->getType().getObjectSize();
}
constUnion* tempConstArray = 0;
TIntermConstantUnion *tempNode;
int index = 0;
bool boolNodeFlag = false;
switch(op) {
case EOpAdd:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] + rightUnionArray[i];
break;
case EOpSub:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] - rightUnionArray[i];
break;
case EOpMul:
case EOpVectorTimesScalar:
case EOpMatrixTimesScalar:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] * rightUnionArray[i];
break;
case EOpMatrixTimesMatrix:
tempConstArray = new constUnion[getMatrixRows() * node->getMatrixCols()];
for (int row = 0; row < getMatrixRows(); row++) {
for (int column = 0; column < node->getMatrixCols(); column++) {
float sum = 0.0f;
for (int i = 0; i < node->getMatrixRows(); i++)
sum += unionArray[i * getMatrixRows() + row].getFConst() * rightUnionArray[column * node->getMatrixRows() + i].getFConst();
tempConstArray[column * getMatrixRows() + row].setFConst(sum);
}
}
returnType = TType(getType().getBasicType(), EvqConst, 0, getMatrixRows(), node->getMatrixCols());
break;
case EOpOuterProduct:
// TODO: functionality >= 120
break;
case EOpDeterminant:
// TODO: functionality >= 150
break;
case EOpMatrixInverse:
// TODO: functionality >= 150
break;
case EOpTranspose:
// TODO: functionality >= 120
break;
case EOpDiv:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++) {
switch (getType().getBasicType()) {
case EbtFloat:
if (rightUnionArray[i] == 0.0f) {
infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
tempConstArray[i].setFConst(FLT_MAX);
} else
tempConstArray[i].setFConst(unionArray[i].getFConst() / rightUnionArray[i].getFConst());
break;
case EbtInt:
if (rightUnionArray[i] == 0) {
infoSink.info.message(EPrefixWarning, "Divide by zero error during constant folding", getLine());
tempConstArray[i].setIConst(0xEFFFFFFF);
} else
tempConstArray[i].setIConst(unionArray[i].getIConst() / rightUnionArray[i].getIConst());
break;
default:
infoSink.info.message(EPrefixInternalError, "Constant folding cannot be done for \"/\"", getLine());
return 0;
}
}
break;
case EOpMatrixTimesVector:
tempConstArray = new constUnion[getMatrixRows()];
for (int i = 0; i < getMatrixRows(); i++) {
float sum = 0.0f;
for (int j = 0; j < node->getVectorSize(); j++) {
sum += unionArray[j*getMatrixRows() + i].getFConst() * rightUnionArray[j].getFConst();
}
tempConstArray[i].setFConst(sum);
}
tempNode = new TIntermConstantUnion(tempConstArray, TType(getBasicType(), EvqConst, getMatrixRows()));
tempNode->setLine(getLine());
return tempNode;
case EOpVectorTimesMatrix:
tempConstArray = new constUnion[node->getMatrixCols()];
for (int i = 0; i < node->getMatrixCols(); i++) {
float sum = 0.0f;
for (int j = 0; j < getVectorSize(); j++)
sum += unionArray[j].getFConst() * rightUnionArray[i*node->getMatrixRows() + j].getFConst();
tempConstArray[i].setFConst(sum);
}
tempNode = new TIntermConstantUnion(tempConstArray, TType(getBasicType(), EvqConst, node->getMatrixCols()));
tempNode->setLine(getLine());
return tempNode;
case EOpMod:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] % rightUnionArray[i];
break;
case EOpRightShift:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] >> rightUnionArray[i];
break;
case EOpLeftShift:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] << rightUnionArray[i];
break;
case EOpAnd:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] & rightUnionArray[i];
break;
case EOpInclusiveOr:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] | rightUnionArray[i];
break;
case EOpExclusiveOr:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] ^ rightUnionArray[i];
break;
case EOpLogicalAnd: // this code is written for possible future use, will not get executed currently
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] && rightUnionArray[i];
break;
case EOpLogicalOr: // this code is written for possible future use, will not get executed currently
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++)
tempConstArray[i] = unionArray[i] || rightUnionArray[i];
break;
case EOpLogicalXor:
tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++) {
switch (getType().getBasicType()) {
case EbtBool: tempConstArray[i].setBConst((unionArray[i] == rightUnionArray[i]) ? false : true); break;
default: assert(false && "Default missing");
}
}
break;
case EOpLessThan:
assert(objectSize == 1);
tempConstArray = new constUnion[1];
tempConstArray->setBConst(*unionArray < *rightUnionArray);
returnType = TType(EbtBool, EvqConst);
break;
case EOpGreaterThan:
assert(objectSize == 1);
tempConstArray = new constUnion[1];
tempConstArray->setBConst(*unionArray > *rightUnionArray);
returnType = TType(EbtBool, EvqConst);
break;
case EOpLessThanEqual:
{
assert(objectSize == 1);
constUnion constant;
constant.setBConst(*unionArray > *rightUnionArray);
tempConstArray = new constUnion[1];
tempConstArray->setBConst(!constant.getBConst());
returnType = TType(EbtBool, EvqConst);
break;
}
case EOpGreaterThanEqual:
{
assert(objectSize == 1);
constUnion constant;
constant.setBConst(*unionArray < *rightUnionArray);
tempConstArray = new constUnion[1];
tempConstArray->setBConst(!constant.getBConst());
returnType = TType(EbtBool, EvqConst);
break;
}
case EOpEqual:
if (getType().getBasicType() == EbtStruct) {
if (!CompareStructure(node->getType(), node->getUnionArrayPointer(), unionArray))
boolNodeFlag = true;
} else {
for (int i = 0; i < objectSize; i++) {
if (unionArray[i] != rightUnionArray[i]) {
boolNodeFlag = true;
break; // break out of for loop
}
}
}
tempConstArray = new constUnion[1];
if (!boolNodeFlag) {
tempConstArray->setBConst(true);
}
else {
tempConstArray->setBConst(false);
}
tempNode = new TIntermConstantUnion(tempConstArray, TType(EbtBool, EvqConst));
tempNode->setLine(getLine());
return tempNode;
case EOpNotEqual:
if (getType().getBasicType() == EbtStruct) {
if (CompareStructure(node->getType(), node->getUnionArrayPointer(), unionArray))
boolNodeFlag = true;
} else {
for (int i = 0; i < objectSize; i++) {
if (unionArray[i] == rightUnionArray[i]) {
boolNodeFlag = true;
break; // break out of for loop
}
}
}
tempConstArray = new constUnion[1];
if (!boolNodeFlag) {
tempConstArray->setBConst(true);
}
else {
tempConstArray->setBConst(false);
}
tempNode = new TIntermConstantUnion(tempConstArray, TType(EbtBool, EvqConst));
tempNode->setLine(getLine());
return tempNode;
default:
infoSink.info.message(EPrefixInternalError, "Invalid operator for constant folding", getLine());
return 0;
}
tempNode = new TIntermConstantUnion(tempConstArray, returnType);
tempNode->setLine(getLine());
return tempNode;
} else {
//
// Do unary operations
//
TIntermConstantUnion *newNode = 0;
constUnion* tempConstArray = new constUnion[objectSize];
for (int i = 0; i < objectSize; i++) {
switch(op) {
case EOpNegative:
switch (getType().getBasicType()) {
case EbtFloat: tempConstArray[i].setFConst(-unionArray[i].getFConst()); break;
case EbtInt: tempConstArray[i].setIConst(-unionArray[i].getIConst()); break;
default:
infoSink.info.message(EPrefixInternalError, "Unary operation not folded into constant", getLine());
return 0;
}
break;
case EOpLogicalNot: // this code is written for possible future use, will not get executed currently
switch (getType().getBasicType()) {
case EbtBool: tempConstArray[i].setBConst(!unionArray[i].getBConst()); break;
default:
infoSink.info.message(EPrefixInternalError, "Unary operation not folded into constant", getLine());
return 0;
}
break;
default:
return 0;
}
}
newNode = new TIntermConstantUnion(tempConstArray, getType());
newNode->setLine(getLine());
return newNode;
}
return this;
}
TIntermTyped* TIntermediate::promoteConstantUnion(TBasicType promoteTo, TIntermConstantUnion* node)
{
if (node->getType().isArray())