Constant folding: Correct result type of non-square matrix folding.

This also made the function easier to read by identifying
left and right operands more clearly.
This commit is contained in:
John Kessenich 2015-12-14 18:21:19 -07:00
parent ea0cb2eb11
commit 61c47a951b
5 changed files with 106 additions and 47 deletions

View file

@ -81,11 +81,12 @@ namespace glslang {
//
//
// Do folding between a pair of nodes
// Do folding between a pair of nodes.
// 'this' is the left-hand operand and 'rightConstantNode' is the right-hand operand.
//
// Returns a new node representing the result.
//
TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TIntermTyped* constantNode) const
TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TIntermTyped* rightConstantNode) const
{
// For most cases, the return type matches the argument type, so set that
// up and just code to exceptions below.
@ -96,37 +97,37 @@ TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TIntermTyped* const
// A pair of nodes is to be folded together
//
const TIntermConstantUnion *node = constantNode->getAsConstantUnion();
TConstUnionArray unionArray = getConstArray();
TConstUnionArray rightUnionArray = node->getConstArray();
const TIntermConstantUnion *rightNode = rightConstantNode->getAsConstantUnion();
TConstUnionArray leftUnionArray = getConstArray();
TConstUnionArray rightUnionArray = rightNode->getConstArray();
// Figure out the size of the result
int newComps;
int constComps;
switch(op) {
case EOpMatrixTimesMatrix:
newComps = getMatrixRows() * node->getMatrixCols();
newComps = rightNode->getMatrixCols() * getMatrixRows();
break;
case EOpMatrixTimesVector:
newComps = getMatrixRows();
break;
case EOpVectorTimesMatrix:
newComps = node->getMatrixCols();
newComps = rightNode->getMatrixCols();
break;
default:
newComps = getType().computeNumComponents();
constComps = constantNode->getType().computeNumComponents();
constComps = rightConstantNode->getType().computeNumComponents();
if (constComps == 1 && newComps > 1) {
// for a case like vec4 f = vec4(2,3,4,5) + 1.2;
TConstUnionArray smearedArray(newComps, node->getConstArray()[0]);
TConstUnionArray smearedArray(newComps, rightNode->getConstArray()[0]);
rightUnionArray = smearedArray;
} else if (constComps > 1 && newComps == 1) {
// for a case like vec4 f = 1.2 + vec4(2,3,4,5);
newComps = constComps;
rightUnionArray = node->getConstArray();
rightUnionArray = rightNode->getConstArray();
TConstUnionArray smearedArray(newComps, getConstArray()[0]);
unionArray = smearedArray;
returnType.shallowCopy(node->getType());
leftUnionArray = smearedArray;
returnType.shallowCopy(rightNode->getType());
}
break;
}
@ -137,52 +138,52 @@ TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TIntermTyped* const
switch(op) {
case EOpAdd:
for (int i = 0; i < newComps; i++)
newConstArray[i] = unionArray[i] + rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] + rightUnionArray[i];
break;
case EOpSub:
for (int i = 0; i < newComps; i++)
newConstArray[i] = unionArray[i] - rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] - rightUnionArray[i];
break;
case EOpMul:
case EOpVectorTimesScalar:
case EOpMatrixTimesScalar:
for (int i = 0; i < newComps; i++)
newConstArray[i] = unionArray[i] * rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] * rightUnionArray[i];
break;
case EOpMatrixTimesMatrix:
for (int row = 0; row < getMatrixRows(); row++) {
for (int column = 0; column < node->getMatrixCols(); column++) {
for (int column = 0; column < rightNode->getMatrixCols(); column++) {
double sum = 0.0f;
for (int i = 0; i < node->getMatrixRows(); i++)
sum += unionArray[i * getMatrixRows() + row].getDConst() * rightUnionArray[column * node->getMatrixRows() + i].getDConst();
for (int i = 0; i < rightNode->getMatrixRows(); i++)
sum += leftUnionArray[i * getMatrixRows() + row].getDConst() * rightUnionArray[column * rightNode->getMatrixRows() + i].getDConst();
newConstArray[column * getMatrixRows() + row].setDConst(sum);
}
}
returnType.shallowCopy(TType(getType().getBasicType(), EvqConst, 0, getMatrixRows(), node->getMatrixCols()));
returnType.shallowCopy(TType(getType().getBasicType(), EvqConst, 0, rightNode->getMatrixCols(), getMatrixRows()));
break;
case EOpDiv:
for (int i = 0; i < newComps; i++) {
switch (getType().getBasicType()) {
case EbtDouble:
case EbtFloat:
newConstArray[i].setDConst(unionArray[i].getDConst() / rightUnionArray[i].getDConst());
newConstArray[i].setDConst(leftUnionArray[i].getDConst() / rightUnionArray[i].getDConst());
break;
case EbtInt:
if (rightUnionArray[i] == 0)
newConstArray[i].setIConst(0x7FFFFFFF);
else if (rightUnionArray[i].getIConst() == -1 && unionArray[i].getIConst() == 0x80000000)
else if (rightUnionArray[i].getIConst() == -1 && leftUnionArray[i].getIConst() == 0x80000000)
newConstArray[i].setIConst(0x80000000);
else
newConstArray[i].setIConst(unionArray[i].getIConst() / rightUnionArray[i].getIConst());
newConstArray[i].setIConst(leftUnionArray[i].getIConst() / rightUnionArray[i].getIConst());
break;
case EbtUint:
if (rightUnionArray[i] == 0) {
newConstArray[i].setUConst(0xFFFFFFFF);
} else
newConstArray[i].setUConst(unionArray[i].getUConst() / rightUnionArray[i].getUConst());
newConstArray[i].setUConst(leftUnionArray[i].getUConst() / rightUnionArray[i].getUConst());
break;
default:
return 0;
@ -193,8 +194,8 @@ TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TIntermTyped* const
case EOpMatrixTimesVector:
for (int i = 0; i < getMatrixRows(); i++) {
double sum = 0.0f;
for (int j = 0; j < node->getVectorSize(); j++) {
sum += unionArray[j*getMatrixRows() + i].getDConst() * rightUnionArray[j].getDConst();
for (int j = 0; j < rightNode->getVectorSize(); j++) {
sum += leftUnionArray[j*getMatrixRows() + i].getDConst() * rightUnionArray[j].getDConst();
}
newConstArray[i].setDConst(sum);
}
@ -203,89 +204,89 @@ TIntermTyped* TIntermConstantUnion::fold(TOperator op, const TIntermTyped* const
break;
case EOpVectorTimesMatrix:
for (int i = 0; i < node->getMatrixCols(); i++) {
for (int i = 0; i < rightNode->getMatrixCols(); i++) {
double sum = 0.0f;
for (int j = 0; j < getVectorSize(); j++)
sum += unionArray[j].getDConst() * rightUnionArray[i*node->getMatrixRows() + j].getDConst();
sum += leftUnionArray[j].getDConst() * rightUnionArray[i*rightNode->getMatrixRows() + j].getDConst();
newConstArray[i].setDConst(sum);
}
returnType.shallowCopy(TType(getBasicType(), EvqConst, node->getMatrixCols()));
returnType.shallowCopy(TType(getBasicType(), EvqConst, rightNode->getMatrixCols()));
break;
case EOpMod:
for (int i = 0; i < newComps; i++) {
if (rightUnionArray[i] == 0)
newConstArray[i] = unionArray[i];
newConstArray[i] = leftUnionArray[i];
else
newConstArray[i] = unionArray[i] % rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] % rightUnionArray[i];
}
break;
case EOpRightShift:
for (int i = 0; i < newComps; i++)
newConstArray[i] = unionArray[i] >> rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] >> rightUnionArray[i];
break;
case EOpLeftShift:
for (int i = 0; i < newComps; i++)
newConstArray[i] = unionArray[i] << rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] << rightUnionArray[i];
break;
case EOpAnd:
for (int i = 0; i < newComps; i++)
newConstArray[i] = unionArray[i] & rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] & rightUnionArray[i];
break;
case EOpInclusiveOr:
for (int i = 0; i < newComps; i++)
newConstArray[i] = unionArray[i] | rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] | rightUnionArray[i];
break;
case EOpExclusiveOr:
for (int i = 0; i < newComps; i++)
newConstArray[i] = unionArray[i] ^ rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] ^ rightUnionArray[i];
break;
case EOpLogicalAnd: // this code is written for possible future use, will not get executed currently
for (int i = 0; i < newComps; i++)
newConstArray[i] = unionArray[i] && rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] && rightUnionArray[i];
break;
case EOpLogicalOr: // this code is written for possible future use, will not get executed currently
for (int i = 0; i < newComps; i++)
newConstArray[i] = unionArray[i] || rightUnionArray[i];
newConstArray[i] = leftUnionArray[i] || rightUnionArray[i];
break;
case EOpLogicalXor:
for (int i = 0; i < newComps; i++) {
switch (getType().getBasicType()) {
case EbtBool: newConstArray[i].setBConst((unionArray[i] == rightUnionArray[i]) ? false : true); break;
case EbtBool: newConstArray[i].setBConst((leftUnionArray[i] == rightUnionArray[i]) ? false : true); break;
default: assert(false && "Default missing");
}
}
break;
case EOpLessThan:
newConstArray[0].setBConst(unionArray[0] < rightUnionArray[0]);
newConstArray[0].setBConst(leftUnionArray[0] < rightUnionArray[0]);
returnType.shallowCopy(constBool);
break;
case EOpGreaterThan:
newConstArray[0].setBConst(unionArray[0] > rightUnionArray[0]);
newConstArray[0].setBConst(leftUnionArray[0] > rightUnionArray[0]);
returnType.shallowCopy(constBool);
break;
case EOpLessThanEqual:
newConstArray[0].setBConst(! (unionArray[0] > rightUnionArray[0]));
newConstArray[0].setBConst(! (leftUnionArray[0] > rightUnionArray[0]));
returnType.shallowCopy(constBool);
break;
case EOpGreaterThanEqual:
newConstArray[0].setBConst(! (unionArray[0] < rightUnionArray[0]));
newConstArray[0].setBConst(! (leftUnionArray[0] < rightUnionArray[0]));
returnType.shallowCopy(constBool);
break;
case EOpEqual:
newConstArray[0].setBConst(node->getConstArray() == unionArray);
newConstArray[0].setBConst(rightNode->getConstArray() == leftUnionArray);
returnType.shallowCopy(constBool);
break;
case EOpNotEqual:
newConstArray[0].setBConst(node->getConstArray() != unionArray);
newConstArray[0].setBConst(rightNode->getConstArray() != leftUnionArray);
returnType.shallowCopy(constBool);
break;