SPV: Implement composite comparisons (reductions across hierchical compare).
This commit is contained in:
parent
59420fd356
commit
2211835b4d
6 changed files with 470 additions and 71 deletions
|
|
@ -2274,7 +2274,7 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, spv
|
|||
if (reduceComparison && (builder.isVector(left) || builder.isMatrix(left) || builder.isAggregate(left))) {
|
||||
assert(op == glslang::EOpEqual || op == glslang::EOpNotEqual);
|
||||
|
||||
return builder.createCompare(precision, left, right, op == glslang::EOpEqual);
|
||||
return builder.createCompositeCompare(precision, left, right, op == glslang::EOpEqual);
|
||||
}
|
||||
|
||||
switch (op) {
|
||||
|
|
|
|||
|
|
@ -435,7 +435,7 @@ Op Builder::getMostBasicTypeClass(Id typeId) const
|
|||
}
|
||||
}
|
||||
|
||||
int Builder::getNumTypeComponents(Id typeId) const
|
||||
int Builder::getNumTypeConstituents(Id typeId) const
|
||||
{
|
||||
Instruction* instr = module.getInstruction(typeId);
|
||||
|
||||
|
|
@ -447,7 +447,10 @@ int Builder::getNumTypeComponents(Id typeId) const
|
|||
return 1;
|
||||
case OpTypeVector:
|
||||
case OpTypeMatrix:
|
||||
case OpTypeArray:
|
||||
return instr->getImmediateOperand(1);
|
||||
case OpTypeStruct:
|
||||
return instr->getNumOperands();
|
||||
default:
|
||||
assert(0);
|
||||
return 1;
|
||||
|
|
@ -1411,88 +1414,78 @@ Id Builder::createTextureQueryCall(Op opCode, const TextureParameters& parameter
|
|||
return query->getResultId();
|
||||
}
|
||||
|
||||
// Comments in header
|
||||
Id Builder::createCompare(Decoration precision, Id value1, Id value2, bool equal)
|
||||
// External comments in header.
|
||||
// Operates recursively to visit the composite's hierarchy.
|
||||
Id Builder::createCompositeCompare(Decoration precision, Id value1, Id value2, bool equal)
|
||||
{
|
||||
Id boolType = makeBoolType();
|
||||
Id valueType = getTypeId(value1);
|
||||
|
||||
assert(valueType == getTypeId(value2));
|
||||
assert(! isScalar(value1));
|
||||
|
||||
// Vectors
|
||||
Id resultId;
|
||||
|
||||
if (isVectorType(valueType)) {
|
||||
Id boolVectorType = makeVectorType(boolType, getNumTypeComponents(valueType));
|
||||
Id boolVector;
|
||||
int numConstituents = getNumTypeConstituents(valueType);
|
||||
|
||||
// Scalars and Vectors
|
||||
|
||||
if (isScalarType(valueType) || isVectorType(valueType)) {
|
||||
// These just need a single comparison, just have
|
||||
// to figure out what it is.
|
||||
Op op;
|
||||
if (getMostBasicTypeClass(valueType) == OpTypeFloat)
|
||||
switch (getMostBasicTypeClass(valueType)) {
|
||||
case OpTypeFloat:
|
||||
op = equal ? OpFOrdEqual : OpFOrdNotEqual;
|
||||
else
|
||||
break;
|
||||
case OpTypeInt:
|
||||
op = equal ? OpIEqual : OpINotEqual;
|
||||
break;
|
||||
case OpTypeBool:
|
||||
op = equal ? OpLogicalEqual : OpLogicalNotEqual;
|
||||
precision = NoPrecision;
|
||||
break;
|
||||
}
|
||||
|
||||
boolVector = createBinOp(op, boolVectorType, value1, value2);
|
||||
setPrecision(boolVector, precision);
|
||||
if (isScalarType(valueType)) {
|
||||
// scalar
|
||||
resultId = createBinOp(op, boolType, value1, value2);
|
||||
setPrecision(resultId, precision);
|
||||
} else {
|
||||
// vector
|
||||
resultId = createBinOp(op, makeVectorType(boolType, numConstituents), value1, value2);
|
||||
setPrecision(resultId, precision);
|
||||
// reduce vector compares...
|
||||
resultId = createUnaryOp(equal ? OpAll : OpAny, boolType, resultId);
|
||||
}
|
||||
|
||||
// Reduce vector compares with any() and all().
|
||||
|
||||
op = equal ? OpAll : OpAny;
|
||||
|
||||
return createUnaryOp(op, boolType, boolVector);
|
||||
return resultId;
|
||||
}
|
||||
|
||||
spv::MissingFunctionality("Composite comparison of non-vectors");
|
||||
// Only structs, arrays, and matrices should be left.
|
||||
// They share in common the reduction operation across their constituents.
|
||||
assert(isAggregateType(valueType) || isMatrixType(valueType));
|
||||
|
||||
return NoResult;
|
||||
// Compare each pair of constituents
|
||||
for (int constituent = 0; constituent < numConstituents; ++constituent) {
|
||||
std::vector<unsigned> indexes(1, constituent);
|
||||
Id constituentType = getContainedTypeId(valueType, constituent);
|
||||
Id constituent1 = createCompositeExtract(value1, constituentType, indexes);
|
||||
Id constituent2 = createCompositeExtract(value2, constituentType, indexes);
|
||||
|
||||
// Recursively handle aggregates, which include matrices, arrays, and structures
|
||||
// and accumulate the results.
|
||||
Id subResultId = createCompositeCompare(precision, constituent1, constituent2, equal);
|
||||
|
||||
// Matrices
|
||||
if (constituent == 0)
|
||||
resultId = subResultId;
|
||||
else
|
||||
resultId = createBinOp(equal ? OpLogicalAnd : OpLogicalOr, boolType, resultId, subResultId);
|
||||
}
|
||||
|
||||
// Arrays
|
||||
|
||||
//int numElements;
|
||||
//const llvm::ArrayType* arrayType = llvm::dyn_cast<llvm::ArrayType>(value1->getType());
|
||||
//if (arrayType)
|
||||
// numElements = (int)arrayType->getNumElements();
|
||||
//else {
|
||||
// // better be structure
|
||||
// const llvm::StructType* structType = llvm::dyn_cast<llvm::StructType>(value1->getType());
|
||||
// assert(structType);
|
||||
// numElements = structType->getNumElements();
|
||||
//}
|
||||
|
||||
//assert(numElements > 0);
|
||||
|
||||
//for (int element = 0; element < numElements; ++element) {
|
||||
// // Get intermediate comparison values
|
||||
// llvm::Value* element1 = builder.CreateExtractValue(value1, element, "element1");
|
||||
// setInstructionPrecision(element1, precision);
|
||||
// llvm::Value* element2 = builder.CreateExtractValue(value2, element, "element2");
|
||||
// setInstructionPrecision(element2, precision);
|
||||
|
||||
// llvm::Value* subResult = createCompare(precision, element1, element2, equal, "comp");
|
||||
|
||||
// // Accumulate intermediate comparison
|
||||
// if (element == 0)
|
||||
// result = subResult;
|
||||
// else {
|
||||
// if (equal)
|
||||
// result = builder.CreateAnd(result, subResult);
|
||||
// else
|
||||
// result = builder.CreateOr(result, subResult);
|
||||
// setInstructionPrecision(result, precision);
|
||||
// }
|
||||
//}
|
||||
|
||||
//return result;
|
||||
return resultId;
|
||||
}
|
||||
|
||||
// OpCompositeConstruct
|
||||
Id Builder::createCompositeConstruct(Id typeId, std::vector<Id>& constituents)
|
||||
{
|
||||
assert(isAggregateType(typeId) || (getNumTypeComponents(typeId) > 1 && getNumTypeComponents(typeId) == (int)constituents.size()));
|
||||
assert(isAggregateType(typeId) || (getNumTypeConstituents(typeId) > 1 && getNumTypeConstituents(typeId) == (int)constituents.size()));
|
||||
|
||||
Instruction* op = new Instruction(getUniqueId(), typeId, OpCompositeConstruct);
|
||||
for (int c = 0; c < (int)constituents.size(); ++c)
|
||||
|
|
|
|||
|
|
@ -116,7 +116,8 @@ public:
|
|||
Op getTypeClass(Id typeId) const { return getOpCode(typeId); }
|
||||
Op getMostBasicTypeClass(Id typeId) const;
|
||||
int getNumComponents(Id resultId) const { return getNumTypeComponents(getTypeId(resultId)); }
|
||||
int getNumTypeComponents(Id typeId) const;
|
||||
int getNumTypeConstituents(Id typeId) const;
|
||||
int getNumTypeComponents(Id typeId) const { return getNumTypeConstituents(typeId); }
|
||||
Id getScalarTypeId(Id typeId) const;
|
||||
Id getContainedTypeId(Id typeId) const;
|
||||
Id getContainedTypeId(Id typeId, int) const;
|
||||
|
|
@ -150,7 +151,7 @@ public:
|
|||
int getTypeNumColumns(Id typeId) const
|
||||
{
|
||||
assert(isMatrixType(typeId));
|
||||
return getNumTypeComponents(typeId);
|
||||
return getNumTypeConstituents(typeId);
|
||||
}
|
||||
int getNumColumns(Id resultId) const { return getTypeNumColumns(getTypeId(resultId)); }
|
||||
int getTypeNumRows(Id typeId) const
|
||||
|
|
@ -265,11 +266,13 @@ public:
|
|||
// (No true lvalue or stores are used.)
|
||||
Id createLvalueSwizzle(Id typeId, Id target, Id source, std::vector<unsigned>& channels);
|
||||
|
||||
// If the value passed in is an instruction and the precision is not EMpNone,
|
||||
// If the value passed in is an instruction and the precision is not NoPrecision,
|
||||
// it gets tagged with the requested precision.
|
||||
void setPrecision(Id /* value */, Decoration /* precision */)
|
||||
void setPrecision(Id /* value */, Decoration precision)
|
||||
{
|
||||
// TODO
|
||||
if (precision != NoPrecision) {
|
||||
;// TODO
|
||||
}
|
||||
}
|
||||
|
||||
// Can smear a scalar to a vector for the following forms:
|
||||
|
|
@ -322,7 +325,7 @@ public:
|
|||
Id createBitFieldInsertCall(Decoration precision, Id, Id, Id, Id);
|
||||
|
||||
// Reduction comparision for composites: For equal and not-equal resulting in a scalar.
|
||||
Id createCompare(Decoration precision, Id, Id, bool /* true if for equal, fales if for not-equal */);
|
||||
Id createCompositeCompare(Decoration precision, Id, Id, bool /* true if for equal, false if for not-equal */);
|
||||
|
||||
// OpCompositeConstruct
|
||||
Id createCompositeConstruct(Id typeId, std::vector<Id>& constituents);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue