Implement support for GL_KHR_cooperative_matrix extension
This commit is contained in:
parent
91a97b4c69
commit
808c7ed17c
40 changed files with 8227 additions and 5733 deletions
|
|
@ -55,5 +55,6 @@ static const char* const E_SPV_KHR_subgroup_uniform_control_flow = "SPV_KHR_subg
|
|||
static const char* const E_SPV_KHR_fragment_shader_barycentric = "SPV_KHR_fragment_shader_barycentric";
|
||||
static const char* const E_SPV_AMD_shader_early_and_late_fragment_tests = "SPV_AMD_shader_early_and_late_fragment_tests";
|
||||
static const char* const E_SPV_KHR_ray_tracing_position_fetch = "SPV_KHR_ray_tracing_position_fetch";
|
||||
static const char* const E_SPV_KHR_cooperative_matrix = "SPV_KHR_cooperative_matrix";
|
||||
|
||||
#endif // #ifndef GLSLextKHR_H
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ protected:
|
|||
glslang::TLayoutPacking, const glslang::TQualifier&);
|
||||
void decorateStructType(const glslang::TType&, const glslang::TTypeList* glslangStruct, glslang::TLayoutPacking,
|
||||
const glslang::TQualifier&, spv::Id, const std::vector<spv::Id>& spvMembers);
|
||||
spv::Id makeArraySizeId(const glslang::TArraySizes&, int dim);
|
||||
spv::Id makeArraySizeId(const glslang::TArraySizes&, int dim, bool allowZero = false);
|
||||
spv::Id accessChainLoad(const glslang::TType& type);
|
||||
void accessChainStore(const glslang::TType& type, spv::Id rvalue);
|
||||
void multiTypeStore(const glslang::TType&, spv::Id rValue);
|
||||
|
|
@ -212,7 +212,7 @@ protected:
|
|||
glslang::TBasicType typeProxy);
|
||||
spv::Id createConversion(glslang::TOperator op, OpDecorations&, spv::Id destTypeId, spv::Id operand,
|
||||
glslang::TBasicType typeProxy);
|
||||
spv::Id createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize);
|
||||
spv::Id createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize, spv::Id destType);
|
||||
spv::Id makeSmearedConstant(spv::Id constant, int vectorSize);
|
||||
spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId,
|
||||
std::vector<spv::Id>& operands, glslang::TBasicType typeProxy,
|
||||
|
|
@ -2560,12 +2560,15 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
|
|||
|
||||
spv::Id length;
|
||||
if (node->getOperand()->getType().isCoopMat()) {
|
||||
spec_constant_op_mode_setter.turnOnSpecConstantOpMode();
|
||||
|
||||
spv::Id typeId = convertGlslangToSpvType(node->getOperand()->getType());
|
||||
assert(builder.isCooperativeMatrixType(typeId));
|
||||
|
||||
length = builder.createCooperativeMatrixLength(typeId);
|
||||
if (node->getOperand()->getType().isCoopMatKHR()) {
|
||||
length = builder.createCooperativeMatrixLengthKHR(typeId);
|
||||
} else {
|
||||
spec_constant_op_mode_setter.turnOnSpecConstantOpMode();
|
||||
length = builder.createCooperativeMatrixLengthNV(typeId);
|
||||
}
|
||||
} else {
|
||||
glslang::TIntermTyped* block = node->getOperand()->getAsBinaryNode()->getLeft();
|
||||
block->traverse(this);
|
||||
|
|
@ -3099,7 +3102,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
case glslang::EOpConstructStruct:
|
||||
case glslang::EOpConstructTextureSampler:
|
||||
case glslang::EOpConstructReference:
|
||||
case glslang::EOpConstructCooperativeMatrix:
|
||||
case glslang::EOpConstructCooperativeMatrixNV:
|
||||
case glslang::EOpConstructCooperativeMatrixKHR:
|
||||
{
|
||||
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
|
||||
std::vector<spv::Id> arguments;
|
||||
|
|
@ -3116,7 +3120,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
} else
|
||||
constructed = builder.createOp(spv::OpSampledImage, resultType(), arguments);
|
||||
} else if (node->getOp() == glslang::EOpConstructStruct ||
|
||||
node->getOp() == glslang::EOpConstructCooperativeMatrix ||
|
||||
node->getOp() == glslang::EOpConstructCooperativeMatrixNV ||
|
||||
node->getOp() == glslang::EOpConstructCooperativeMatrixKHR ||
|
||||
node->getType().isArray()) {
|
||||
std::vector<spv::Id> constituents;
|
||||
for (int c = 0; c < (int)arguments.size(); ++c)
|
||||
|
|
@ -3291,6 +3296,8 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
break;
|
||||
case glslang::EOpCooperativeMatrixLoad:
|
||||
case glslang::EOpCooperativeMatrixStore:
|
||||
case glslang::EOpCooperativeMatrixLoadNV:
|
||||
case glslang::EOpCooperativeMatrixStoreNV:
|
||||
noReturnValue = true;
|
||||
break;
|
||||
case glslang::EOpBeginInvocationInterlock:
|
||||
|
|
@ -3502,10 +3509,12 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
lvalue = true;
|
||||
break;
|
||||
case glslang::EOpCooperativeMatrixLoad:
|
||||
case glslang::EOpCooperativeMatrixLoadNV:
|
||||
if (arg == 0 || arg == 1)
|
||||
lvalue = true;
|
||||
break;
|
||||
case glslang::EOpCooperativeMatrixStore:
|
||||
case glslang::EOpCooperativeMatrixStoreNV:
|
||||
if (arg == 1)
|
||||
lvalue = true;
|
||||
break;
|
||||
|
|
@ -3534,7 +3543,9 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
|
||||
#ifndef GLSLANG_WEB
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStore) {
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStore ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixLoadNV ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStoreNV) {
|
||||
|
||||
if (arg == 1) {
|
||||
// fold "element" parameter into the access chain
|
||||
|
|
@ -3555,9 +3566,11 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
unsigned int alignment = builder.getAccessChain().alignment;
|
||||
|
||||
int memoryAccess = TranslateMemoryAccess(coherentFlags);
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad)
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixLoadNV)
|
||||
memoryAccess &= ~spv::MemoryAccessMakePointerAvailableKHRMask;
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixStore)
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixStore ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStoreNV)
|
||||
memoryAccess &= ~spv::MemoryAccessMakePointerVisibleKHRMask;
|
||||
if (builder.getStorageClass(builder.getAccessChain().base) ==
|
||||
spv::StorageClassPhysicalStorageBufferEXT) {
|
||||
|
|
@ -3655,31 +3668,48 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
|
||||
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
|
||||
#ifndef GLSLANG_WEB
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad) {
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixLoadNV) {
|
||||
std::vector<spv::IdImmediate> idImmOps;
|
||||
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[1])); // buf
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad) {
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // matrixLayout
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
} else {
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
|
||||
}
|
||||
idImmOps.insert(idImmOps.end(), memoryAccessOperands.begin(), memoryAccessOperands.end());
|
||||
// get the pointee type
|
||||
spv::Id typeId = builder.getContainedTypeId(builder.getTypeId(operands[0]));
|
||||
assert(builder.isCooperativeMatrixType(typeId));
|
||||
// do the op
|
||||
spv::Id result = builder.createOp(spv::OpCooperativeMatrixLoadNV, typeId, idImmOps);
|
||||
spv::Id result = node->getOp() == glslang::EOpCooperativeMatrixLoad
|
||||
? builder.createOp(spv::OpCooperativeMatrixLoadKHR, typeId, idImmOps)
|
||||
: builder.createOp(spv::OpCooperativeMatrixLoadNV, typeId, idImmOps);
|
||||
// store the result to the pointer (out param 'm')
|
||||
builder.createStore(result, operands[0]);
|
||||
result = 0;
|
||||
} else if (node->getOp() == glslang::EOpCooperativeMatrixStore) {
|
||||
} else if (node->getOp() == glslang::EOpCooperativeMatrixStore ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStoreNV) {
|
||||
std::vector<spv::IdImmediate> idImmOps;
|
||||
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[1])); // buf
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[0])); // object
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixStore) {
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // matrixLayout
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
} else {
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
|
||||
}
|
||||
idImmOps.insert(idImmOps.end(), memoryAccessOperands.begin(), memoryAccessOperands.end());
|
||||
|
||||
builder.createNoResultOp(spv::OpCooperativeMatrixStoreNV, idImmOps);
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixStore)
|
||||
builder.createNoResultOp(spv::OpCooperativeMatrixStoreKHR, idImmOps);
|
||||
else
|
||||
builder.createNoResultOp(spv::OpCooperativeMatrixStoreNV, idImmOps);
|
||||
result = 0;
|
||||
} else if (node->getOp() == glslang::EOpRayQueryGetIntersectionTriangleVertexPositionsEXT) {
|
||||
std::vector<spv::IdImmediate> idImmOps;
|
||||
|
|
@ -3694,6 +3724,32 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
// store the result to the pointer (out param 'm')
|
||||
builder.createStore(result, operands[2]);
|
||||
result = 0;
|
||||
} else if (node->getOp() == glslang::EOpCooperativeMatrixMulAdd) {
|
||||
uint32_t matrixOperands = 0;
|
||||
|
||||
// If the optional operand is present, initialize matrixOperands to that value.
|
||||
if (glslangOperands.size() == 4 && glslangOperands[3]->getAsConstantUnion()) {
|
||||
matrixOperands = glslangOperands[3]->getAsConstantUnion()->getConstArray()[0].getIConst();
|
||||
}
|
||||
|
||||
// Determine Cooperative Matrix Operands bits from the signedness of the types.
|
||||
if (isTypeSignedInt(glslangOperands[0]->getAsTyped()->getBasicType()))
|
||||
matrixOperands |= spv::CooperativeMatrixOperandsMatrixASignedComponentsMask;
|
||||
if (isTypeSignedInt(glslangOperands[1]->getAsTyped()->getBasicType()))
|
||||
matrixOperands |= spv::CooperativeMatrixOperandsMatrixBSignedComponentsMask;
|
||||
if (isTypeSignedInt(glslangOperands[2]->getAsTyped()->getBasicType()))
|
||||
matrixOperands |= spv::CooperativeMatrixOperandsMatrixCSignedComponentsMask;
|
||||
if (isTypeSignedInt(node->getBasicType()))
|
||||
matrixOperands |= spv::CooperativeMatrixOperandsMatrixResultSignedComponentsMask;
|
||||
|
||||
std::vector<spv::IdImmediate> idImmOps;
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[0]));
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[1]));
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2]));
|
||||
if (matrixOperands != 0)
|
||||
idImmOps.push_back(spv::IdImmediate(false, matrixOperands));
|
||||
|
||||
result = builder.createOp(spv::OpCooperativeMatrixMulAddKHR, resultType(), idImmOps);
|
||||
} else
|
||||
#endif
|
||||
if (atomic) {
|
||||
|
|
@ -4586,9 +4642,10 @@ spv::Id TGlslangToSpvTraverser::convertGlslangToSpvType(const glslang::TType& ty
|
|||
spvType = builder.makeVectorType(spvType, type.getVectorSize());
|
||||
}
|
||||
|
||||
if (type.isCoopMat()) {
|
||||
if (type.isCoopMatNV()) {
|
||||
builder.addCapability(spv::CapabilityCooperativeMatrixNV);
|
||||
builder.addExtension(spv::E_SPV_NV_cooperative_matrix);
|
||||
|
||||
if (type.getBasicType() == glslang::EbtFloat16)
|
||||
builder.addCapability(spv::CapabilityFloat16);
|
||||
if (type.getBasicType() == glslang::EbtUint8 ||
|
||||
|
|
@ -4596,11 +4653,29 @@ spv::Id TGlslangToSpvTraverser::convertGlslangToSpvType(const glslang::TType& ty
|
|||
builder.addCapability(spv::CapabilityInt8);
|
||||
}
|
||||
|
||||
spv::Id scope = makeArraySizeId(*type.getTypeParameters(), 1);
|
||||
spv::Id rows = makeArraySizeId(*type.getTypeParameters(), 2);
|
||||
spv::Id cols = makeArraySizeId(*type.getTypeParameters(), 3);
|
||||
spv::Id scope = makeArraySizeId(*type.getTypeParameters()->arraySizes, 1);
|
||||
spv::Id rows = makeArraySizeId(*type.getTypeParameters()->arraySizes, 2);
|
||||
spv::Id cols = makeArraySizeId(*type.getTypeParameters()->arraySizes, 3);
|
||||
|
||||
spvType = builder.makeCooperativeMatrixType(spvType, scope, rows, cols);
|
||||
spvType = builder.makeCooperativeMatrixTypeNV(spvType, scope, rows, cols);
|
||||
}
|
||||
|
||||
if (type.isCoopMatKHR()) {
|
||||
builder.addCapability(spv::CapabilityCooperativeMatrixKHR);
|
||||
builder.addExtension(spv::E_SPV_KHR_cooperative_matrix);
|
||||
|
||||
if (type.getBasicType() == glslang::EbtFloat16)
|
||||
builder.addCapability(spv::CapabilityFloat16);
|
||||
if (type.getBasicType() == glslang::EbtUint8 || type.getBasicType() == glslang::EbtInt8) {
|
||||
builder.addCapability(spv::CapabilityInt8);
|
||||
}
|
||||
|
||||
spv::Id scope = makeArraySizeId(*type.getTypeParameters()->arraySizes, 0);
|
||||
spv::Id rows = makeArraySizeId(*type.getTypeParameters()->arraySizes, 1);
|
||||
spv::Id cols = makeArraySizeId(*type.getTypeParameters()->arraySizes, 2);
|
||||
spv::Id use = builder.makeUintConstant(type.getCoopMatKHRuse());
|
||||
|
||||
spvType = builder.makeCooperativeMatrixTypeKHR(spvType, scope, rows, cols, use);
|
||||
}
|
||||
|
||||
if (type.isArray()) {
|
||||
|
|
@ -4951,7 +5026,7 @@ void TGlslangToSpvTraverser::decorateStructType(const glslang::TType& type,
|
|||
// This is not quite trivial, because of specialization constants.
|
||||
// Sometimes, a raw constant is turned into an Id, and sometimes
|
||||
// a specialization constant expression is.
|
||||
spv::Id TGlslangToSpvTraverser::makeArraySizeId(const glslang::TArraySizes& arraySizes, int dim)
|
||||
spv::Id TGlslangToSpvTraverser::makeArraySizeId(const glslang::TArraySizes& arraySizes, int dim, bool allowZero)
|
||||
{
|
||||
// First, see if this is sized with a node, meaning a specialization constant:
|
||||
glslang::TIntermTyped* specNode = arraySizes.getDimNode(dim);
|
||||
|
|
@ -4965,7 +5040,10 @@ spv::Id TGlslangToSpvTraverser::makeArraySizeId(const glslang::TArraySizes& arra
|
|||
|
||||
// Otherwise, need a compile-time (front end) size, get it:
|
||||
int size = arraySizes.getDimSize(dim);
|
||||
assert(size > 0);
|
||||
|
||||
if (!allowZero)
|
||||
assert(size > 0);
|
||||
|
||||
return builder.makeUintConstant(size);
|
||||
}
|
||||
|
||||
|
|
@ -7287,7 +7365,9 @@ spv::Id TGlslangToSpvTraverser::createUnaryMatrixOperation(spv::Op op, OpDecorat
|
|||
// For converting integers where both the bitwidth and the signedness could
|
||||
// change, but only do the width change here. The caller is still responsible
|
||||
// for the signedness conversion.
|
||||
spv::Id TGlslangToSpvTraverser::createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize)
|
||||
// destType is the final type that will be converted to, but this function
|
||||
// may only be doing part of that conversion.
|
||||
spv::Id TGlslangToSpvTraverser::createIntWidthConversion(glslang::TOperator op, spv::Id operand, int vectorSize, spv::Id destType)
|
||||
{
|
||||
// Get the result type width, based on the type to convert to.
|
||||
int width = 32;
|
||||
|
|
@ -7358,6 +7438,11 @@ spv::Id TGlslangToSpvTraverser::createIntWidthConversion(glslang::TOperator op,
|
|||
|
||||
if (vectorSize > 0)
|
||||
type = builder.makeVectorType(type, vectorSize);
|
||||
else if (builder.getOpCode(destType) == spv::OpTypeCooperativeMatrixKHR ||
|
||||
builder.getOpCode(destType) == spv::OpTypeCooperativeMatrixNV) {
|
||||
|
||||
type = builder.makeCooperativeMatrixTypeWithSameShape(type, destType);
|
||||
}
|
||||
|
||||
return builder.createUnaryOp(convOp, type, operand);
|
||||
}
|
||||
|
|
@ -7630,7 +7715,7 @@ spv::Id TGlslangToSpvTraverser::createConversion(glslang::TOperator op, OpDecora
|
|||
case glslang::EOpConvUint64ToInt16:
|
||||
case glslang::EOpConvUint64ToInt:
|
||||
// OpSConvert/OpUConvert + OpBitCast
|
||||
operand = createIntWidthConversion(op, operand, vectorSize);
|
||||
operand = createIntWidthConversion(op, operand, vectorSize, destType);
|
||||
|
||||
if (builder.isInSpecConstCodeGenMode()) {
|
||||
// Build zero scalar or vector for OpIAdd.
|
||||
|
|
@ -8963,7 +9048,7 @@ spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv::
|
|||
case glslang::EOpSetMeshOutputsEXT:
|
||||
builder.createNoResultOp(spv::OpSetMeshOutputsEXT, operands);
|
||||
return 0;
|
||||
case glslang::EOpCooperativeMatrixMulAdd:
|
||||
case glslang::EOpCooperativeMatrixMulAddNV:
|
||||
opCode = spv::OpCooperativeMatrixMulAddNV;
|
||||
break;
|
||||
case glslang::EOpHitObjectTraceRayNV:
|
||||
|
|
|
|||
|
|
@ -680,6 +680,7 @@ namespace spv {
|
|||
case spv::OperandKernelEnqueueFlags:
|
||||
case spv::OperandKernelProfilingInfo:
|
||||
case spv::OperandCapability:
|
||||
case spv::OperandCooperativeMatrixOperands:
|
||||
++word;
|
||||
break;
|
||||
|
||||
|
|
|
|||
|
|
@ -481,15 +481,41 @@ Id Builder::makeMatrixType(Id component, int cols, int rows)
|
|||
return type->getResultId();
|
||||
}
|
||||
|
||||
Id Builder::makeCooperativeMatrixType(Id component, Id scope, Id rows, Id cols)
|
||||
// Return the Id of the OpTypeCooperativeMatrixKHR type built from the given
// operand Ids. A type made earlier with the same five operands is reused;
// otherwise a new type instruction is created, recorded in groupedTypes and
// constantsTypesGlobals, and mapped in the module.
Id Builder::makeCooperativeMatrixTypeKHR(Id component, Id scope, Id rows, Id cols, Id use)
{
    auto& knownTypes = groupedTypes[OpTypeCooperativeMatrixKHR];

    // Reuse an existing type when every operand agrees.
    const int knownCount = static_cast<int>(knownTypes.size());
    for (int i = 0; i != knownCount; ++i) {
        Instruction* candidate = knownTypes[i];
        const bool sameOperands = candidate->getIdOperand(0) == component &&
                                  candidate->getIdOperand(1) == scope &&
                                  candidate->getIdOperand(2) == rows &&
                                  candidate->getIdOperand(3) == cols &&
                                  candidate->getIdOperand(4) == use;
        if (sameOperands)
            return candidate->getResultId();
    }

    // No match: build the type. Operand order is component, scope, rows, cols, use.
    Instruction* created = new Instruction(getUniqueId(), NoType, OpTypeCooperativeMatrixKHR);
    const Id typeOperands[] = { component, scope, rows, cols, use };
    for (Id operand : typeOperands)
        created->addIdOperand(operand);

    knownTypes.push_back(created);
    constantsTypesGlobals.push_back(std::unique_ptr<Instruction>(created));
    module.mapInstruction(created);

    return created->getResultId();
}
|
||||
|
||||
Id Builder::makeCooperativeMatrixTypeNV(Id component, Id scope, Id rows, Id cols)
|
||||
{
|
||||
// try to find it
|
||||
Instruction* type;
|
||||
for (int t = 0; t < (int)groupedTypes[OpTypeCooperativeMatrixNV].size(); ++t) {
|
||||
type = groupedTypes[OpTypeCooperativeMatrixNV][t];
|
||||
if (type->getIdOperand(0) == component &&
|
||||
type->getIdOperand(1) == scope &&
|
||||
type->getIdOperand(2) == rows &&
|
||||
if (type->getIdOperand(0) == component && type->getIdOperand(1) == scope && type->getIdOperand(2) == rows &&
|
||||
type->getIdOperand(3) == cols)
|
||||
return type->getResultId();
|
||||
}
|
||||
|
|
@ -507,6 +533,17 @@ Id Builder::makeCooperativeMatrixType(Id component, Id scope, Id rows, Id cols)
|
|||
return type->getResultId();
|
||||
}
|
||||
|
||||
// Make a cooperative matrix type that has 'component' as its component type
// and copies the remaining operands (1..3 for the NV form, 1..4 for the KHR
// form) from the existing cooperative matrix type 'otherType'. The result is
// the same flavor (NV or KHR) as 'otherType'.
Id Builder::makeCooperativeMatrixTypeWithSameShape(Id component, Id otherType)
{
    Instruction* shapeSource = module.getInstruction(otherType);

    if (shapeSource->getOpCode() != OpTypeCooperativeMatrixNV) {
        // Anything that is not the NV form must be the KHR form.
        assert(shapeSource->getOpCode() == OpTypeCooperativeMatrixKHR);
        return makeCooperativeMatrixTypeKHR(component,
                                            shapeSource->getIdOperand(1),
                                            shapeSource->getIdOperand(2),
                                            shapeSource->getIdOperand(3),
                                            shapeSource->getIdOperand(4));
    }

    return makeCooperativeMatrixTypeNV(component,
                                       shapeSource->getIdOperand(1),
                                       shapeSource->getIdOperand(2),
                                       shapeSource->getIdOperand(3));
}
|
||||
|
||||
Id Builder::makeGenericType(spv::Op opcode, std::vector<spv::IdImmediate>& operands)
|
||||
{
|
||||
// try to find it
|
||||
|
|
@ -1254,6 +1291,7 @@ int Builder::getNumTypeConstituents(Id typeId) const
|
|||
}
|
||||
case OpTypeStruct:
|
||||
return instr->getNumOperands();
|
||||
case OpTypeCooperativeMatrixKHR:
|
||||
case OpTypeCooperativeMatrixNV:
|
||||
// has only one constituent when used with OpCompositeConstruct.
|
||||
return 1;
|
||||
|
|
@ -1303,6 +1341,7 @@ Id Builder::getContainedTypeId(Id typeId, int member) const
|
|||
case OpTypeMatrix:
|
||||
case OpTypeArray:
|
||||
case OpTypeRuntimeArray:
|
||||
case OpTypeCooperativeMatrixKHR:
|
||||
case OpTypeCooperativeMatrixNV:
|
||||
return instr->getIdOperand(0);
|
||||
case OpTypePointer:
|
||||
|
|
@ -1769,6 +1808,7 @@ Id Builder::makeCompositeConstant(Id typeId, const std::vector<Id>& members, boo
|
|||
case OpTypeVector:
|
||||
case OpTypeArray:
|
||||
case OpTypeMatrix:
|
||||
case OpTypeCooperativeMatrixKHR:
|
||||
case OpTypeCooperativeMatrixNV:
|
||||
if (! specConstant) {
|
||||
Id existing = findCompositeConstant(typeClass, typeId, members);
|
||||
|
|
@ -2405,7 +2445,24 @@ Id Builder::createArrayLength(Id base, unsigned int member)
|
|||
return length->getResultId();
|
||||
}
|
||||
|
||||
Id Builder::createCooperativeMatrixLength(Id type)
|
||||
// Emit OpCooperativeMatrixLengthKHR for the cooperative matrix type 'type'
// and return the Id of the 32-bit unsigned result. While the builder is in
// spec-constant code generation mode, the op is emitted through
// createSpecConstantOp instead of being added at the current build point.
Id Builder::createCooperativeMatrixLengthKHR(Id type)
{
    const spv::Id resultType = makeUintType(32);

    if (!generatingOpCodeForSpecConst) {
        // Ordinary instruction at the current build point.
        Instruction* lengthOp = new Instruction(getUniqueId(), resultType, OpCooperativeMatrixLengthKHR);
        lengthOp->addIdOperand(type);
        buildPoint->addInstruction(std::unique_ptr<Instruction>(lengthOp));
        return lengthOp->getResultId();
    }

    // Spec-constant mode: the type Id is the single operand, with no literals.
    return createSpecConstantOp(OpCooperativeMatrixLengthKHR, resultType,
                                std::vector<Id>(1, type), std::vector<Id>());
}
|
||||
|
||||
Id Builder::createCooperativeMatrixLengthNV(Id type)
|
||||
{
|
||||
spv::Id intType = makeUintType(32);
|
||||
|
||||
|
|
|
|||
|
|
@ -203,7 +203,9 @@ public:
|
|||
Id makeImageType(Id sampledType, Dim, bool depth, bool arrayed, bool ms, unsigned sampled, ImageFormat format);
|
||||
Id makeSamplerType();
|
||||
Id makeSampledImageType(Id imageType);
|
||||
Id makeCooperativeMatrixType(Id component, Id scope, Id rows, Id cols);
|
||||
Id makeCooperativeMatrixTypeKHR(Id component, Id scope, Id rows, Id cols, Id use);
|
||||
Id makeCooperativeMatrixTypeNV(Id component, Id scope, Id rows, Id cols);
|
||||
Id makeCooperativeMatrixTypeWithSameShape(Id component, Id otherType);
|
||||
Id makeGenericType(spv::Op opcode, std::vector<spv::IdImmediate>& operands);
|
||||
|
||||
// SPIR-V NonSemantic Shader DebugInfo Instructions
|
||||
|
|
@ -286,7 +288,10 @@ public:
|
|||
#ifdef GLSLANG_WEB
|
||||
bool isCooperativeMatrixType(Id typeId)const { return false; }
|
||||
#else
|
||||
bool isCooperativeMatrixType(Id typeId)const { return getTypeClass(typeId) == OpTypeCooperativeMatrixNV; }
|
||||
bool isCooperativeMatrixType(Id typeId)const
|
||||
{
|
||||
return getTypeClass(typeId) == OpTypeCooperativeMatrixKHR || getTypeClass(typeId) == OpTypeCooperativeMatrixNV;
|
||||
}
|
||||
#endif
|
||||
bool isAggregateType(Id typeId) const
|
||||
{ return isArrayType(typeId) || isStructType(typeId) || isCooperativeMatrixType(typeId); }
|
||||
|
|
@ -464,8 +469,10 @@ public:
|
|||
// Create an OpArrayLength instruction
|
||||
Id createArrayLength(Id base, unsigned int member);
|
||||
|
||||
// Create an OpCooperativeMatrixLengthKHR instruction
|
||||
Id createCooperativeMatrixLengthKHR(Id type);
|
||||
// Create an OpCooperativeMatrixLengthNV instruction
|
||||
Id createCooperativeMatrixLength(Id type);
|
||||
Id createCooperativeMatrixLengthNV(Id type);
|
||||
|
||||
// Create an OpCompositeExtract instruction
|
||||
Id createCompositeExtract(Id composite, Id typeId, unsigned index);
|
||||
|
|
|
|||
|
|
@ -790,6 +790,21 @@ const char* MemoryAccessString(int mem)
|
|||
}
|
||||
}
|
||||
|
||||
const int CooperativeMatrixOperandsCeiling = 6;
|
||||
|
||||
const char* CooperativeMatrixOperandsString(int op)
|
||||
{
|
||||
switch (op) {
|
||||
case CooperativeMatrixOperandsMatrixASignedComponentsShift: return "ASignedComponents";
|
||||
case CooperativeMatrixOperandsMatrixBSignedComponentsShift: return "BSignedComponents";
|
||||
case CooperativeMatrixOperandsMatrixCSignedComponentsShift: return "CSignedComponents";
|
||||
case CooperativeMatrixOperandsMatrixResultSignedComponentsShift: return "ResultSignedComponents";
|
||||
case CooperativeMatrixOperandsSaturatingAccumulationShift: return "SaturatingAccumulation";
|
||||
|
||||
default: return "Bad";
|
||||
}
|
||||
}
|
||||
|
||||
const char* ScopeString(int mem)
|
||||
{
|
||||
switch (mem) {
|
||||
|
|
@ -993,6 +1008,7 @@ const char* CapabilityString(int info)
|
|||
case CapabilityVariablePointers: return "VariablePointers";
|
||||
|
||||
case CapabilityCooperativeMatrixNV: return "CooperativeMatrixNV";
|
||||
case CapabilityCooperativeMatrixKHR: return "CooperativeMatrixKHR";
|
||||
case CapabilityShaderSMBuiltinsNV: return "ShaderSMBuiltinsNV";
|
||||
|
||||
case CapabilityFragmentShaderSampleInterlockEXT: return "CapabilityFragmentShaderSampleInterlockEXT";
|
||||
|
|
@ -1473,6 +1489,11 @@ const char* OpcodeString(int op)
|
|||
case OpCooperativeMatrixStoreNV: return "OpCooperativeMatrixStoreNV";
|
||||
case OpCooperativeMatrixMulAddNV: return "OpCooperativeMatrixMulAddNV";
|
||||
case OpCooperativeMatrixLengthNV: return "OpCooperativeMatrixLengthNV";
|
||||
case OpTypeCooperativeMatrixKHR: return "OpTypeCooperativeMatrixKHR";
|
||||
case OpCooperativeMatrixLoadKHR: return "OpCooperativeMatrixLoadKHR";
|
||||
case OpCooperativeMatrixStoreKHR: return "OpCooperativeMatrixStoreKHR";
|
||||
case OpCooperativeMatrixMulAddKHR: return "OpCooperativeMatrixMulAddKHR";
|
||||
case OpCooperativeMatrixLengthKHR: return "OpCooperativeMatrixLengthKHR";
|
||||
case OpDemoteToHelperInvocationEXT: return "OpDemoteToHelperInvocationEXT";
|
||||
case OpIsHelperInvocationEXT: return "OpIsHelperInvocationEXT";
|
||||
|
||||
|
|
@ -1536,6 +1557,7 @@ EnumParameters LoopControlParams[FunctionControlCeiling];
|
|||
EnumParameters SelectionControlParams[SelectControlCeiling];
|
||||
EnumParameters FunctionControlParams[FunctionControlCeiling];
|
||||
EnumParameters MemoryAccessParams[MemoryAccessCeiling];
|
||||
EnumParameters CooperativeMatrixOperandsParams[CooperativeMatrixOperandsCeiling];
|
||||
|
||||
// Set up all the parameterizing descriptions of the opcodes, operands, etc.
|
||||
void Parameterize()
|
||||
|
|
@ -1630,6 +1652,8 @@ void Parameterize()
|
|||
InstructionDesc[OpModuleProcessed].setResultAndType(false, false);
|
||||
InstructionDesc[OpTypeCooperativeMatrixNV].setResultAndType(true, false);
|
||||
InstructionDesc[OpCooperativeMatrixStoreNV].setResultAndType(false, false);
|
||||
InstructionDesc[OpTypeCooperativeMatrixKHR].setResultAndType(true, false);
|
||||
InstructionDesc[OpCooperativeMatrixStoreKHR].setResultAndType(false, false);
|
||||
InstructionDesc[OpBeginInvocationInterlockEXT].setResultAndType(false, false);
|
||||
InstructionDesc[OpEndInvocationInterlockEXT].setResultAndType(false, false);
|
||||
|
||||
|
|
@ -1701,6 +1725,7 @@ void Parameterize()
|
|||
OperandClassParams[OperandKernelEnqueueFlags].set(0, KernelEnqueueFlagsString, nullptr);
|
||||
OperandClassParams[OperandKernelProfilingInfo].set(0, KernelProfilingInfoString, nullptr, true);
|
||||
OperandClassParams[OperandCapability].set(0, CapabilityString, nullptr);
|
||||
OperandClassParams[OperandCooperativeMatrixOperands].set(CooperativeMatrixOperandsCeiling, CooperativeMatrixOperandsString, CooperativeMatrixOperandsParams, true);
|
||||
OperandClassParams[OperandOpcode].set(OpCodeMask + 1, OpcodeString, nullptr);
|
||||
|
||||
// set name of operator, an initial set of <id> style operands, and the description
|
||||
|
|
@ -3093,6 +3118,34 @@ void Parameterize()
|
|||
|
||||
InstructionDesc[OpCooperativeMatrixLengthNV].operands.push(OperandId, "'Type'");
|
||||
|
||||
InstructionDesc[OpTypeCooperativeMatrixKHR].operands.push(OperandId, "'Component Type'");
|
||||
InstructionDesc[OpTypeCooperativeMatrixKHR].operands.push(OperandId, "'Scope'");
|
||||
InstructionDesc[OpTypeCooperativeMatrixKHR].operands.push(OperandId, "'Rows'");
|
||||
InstructionDesc[OpTypeCooperativeMatrixKHR].operands.push(OperandId, "'Columns'");
|
||||
InstructionDesc[OpTypeCooperativeMatrixKHR].operands.push(OperandId, "'Use'");
|
||||
|
||||
InstructionDesc[OpCooperativeMatrixLoadKHR].operands.push(OperandId, "'Pointer'");
|
||||
InstructionDesc[OpCooperativeMatrixLoadKHR].operands.push(OperandId, "'Memory Layout'");
|
||||
InstructionDesc[OpCooperativeMatrixLoadKHR].operands.push(OperandId, "'Stride'");
|
||||
InstructionDesc[OpCooperativeMatrixLoadKHR].operands.push(OperandMemoryAccess, "'Memory Access'");
|
||||
InstructionDesc[OpCooperativeMatrixLoadKHR].operands.push(OperandLiteralNumber, "", true);
|
||||
InstructionDesc[OpCooperativeMatrixLoadKHR].operands.push(OperandId, "", true);
|
||||
|
||||
InstructionDesc[OpCooperativeMatrixStoreKHR].operands.push(OperandId, "'Pointer'");
|
||||
InstructionDesc[OpCooperativeMatrixStoreKHR].operands.push(OperandId, "'Object'");
|
||||
InstructionDesc[OpCooperativeMatrixStoreKHR].operands.push(OperandId, "'Memory Layout'");
|
||||
InstructionDesc[OpCooperativeMatrixStoreKHR].operands.push(OperandId, "'Stride'");
|
||||
InstructionDesc[OpCooperativeMatrixStoreKHR].operands.push(OperandMemoryAccess, "'Memory Access'");
|
||||
InstructionDesc[OpCooperativeMatrixStoreKHR].operands.push(OperandLiteralNumber, "", true);
|
||||
InstructionDesc[OpCooperativeMatrixStoreKHR].operands.push(OperandId, "", true);
|
||||
|
||||
InstructionDesc[OpCooperativeMatrixMulAddKHR].operands.push(OperandId, "'A'");
|
||||
InstructionDesc[OpCooperativeMatrixMulAddKHR].operands.push(OperandId, "'B'");
|
||||
InstructionDesc[OpCooperativeMatrixMulAddKHR].operands.push(OperandId, "'C'");
|
||||
InstructionDesc[OpCooperativeMatrixMulAddKHR].operands.push(OperandCooperativeMatrixOperands, "'Cooperative Matrix Operands'", true);
|
||||
|
||||
InstructionDesc[OpCooperativeMatrixLengthKHR].operands.push(OperandId, "'Type'");
|
||||
|
||||
InstructionDesc[OpDemoteToHelperInvocationEXT].setResultAndType(false, false);
|
||||
|
||||
InstructionDesc[OpReadClockKHR].operands.push(OperandScope, "'Scope'");
|
||||
|
|
|
|||
|
|
@ -156,6 +156,7 @@ enum OperandClass {
|
|||
OperandKernelEnqueueFlags,
|
||||
OperandKernelProfilingInfo,
|
||||
OperandCapability,
|
||||
OperandCooperativeMatrixOperands,
|
||||
|
||||
OperandOpcode,
|
||||
|
||||
|
|
|
|||
|
|
@ -1144,6 +1144,7 @@ enum Capability {
|
|||
CapabilityDotProduct = 6019,
|
||||
CapabilityDotProductKHR = 6019,
|
||||
CapabilityRayCullMaskKHR = 6020,
|
||||
CapabilityCooperativeMatrixKHR = 6022,
|
||||
CapabilityBitInstructions = 6025,
|
||||
CapabilityGroupNonUniformRotateKHR = 6026,
|
||||
CapabilityAtomicFloat32AddEXT = 6033,
|
||||
|
|
@ -1261,6 +1262,37 @@ enum PackedVectorFormat {
|
|||
PackedVectorFormatMax = 0x7fffffff,
|
||||
};
|
||||
|
||||
// Bit positions of the cooperative matrix operand flags. Each entry is the
// shift amount of the matching CooperativeMatrixOperands*Mask value.
enum CooperativeMatrixOperandsShift {
    CooperativeMatrixOperandsMatrixASignedComponentsShift      = 0,
    CooperativeMatrixOperandsMatrixBSignedComponentsShift      = 1,
    CooperativeMatrixOperandsMatrixCSignedComponentsShift      = 2,
    CooperativeMatrixOperandsMatrixResultSignedComponentsShift = 3,
    CooperativeMatrixOperandsSaturatingAccumulationShift       = 4,
    // Upper-bound entry rather than a flag position.
    CooperativeMatrixOperandsMax                               = 0x7fffffff,
};
|
||||
|
||||
// Bit masks for the cooperative matrix operand flags; the values are
// single bits (0x01 through 0x10) that can be OR-ed together.
enum CooperativeMatrixOperandsMask {
    CooperativeMatrixOperandsMaskNone                         = 0,
    CooperativeMatrixOperandsMatrixASignedComponentsMask      = 0x00000001,
    CooperativeMatrixOperandsMatrixBSignedComponentsMask      = 0x00000002,
    CooperativeMatrixOperandsMatrixCSignedComponentsMask      = 0x00000004,
    CooperativeMatrixOperandsMatrixResultSignedComponentsMask = 0x00000008,
    CooperativeMatrixOperandsSaturatingAccumulationMask       = 0x00000010,
};
|
||||
|
||||
// Cooperative matrix memory layout selectors: row-major is 0, column-major is 1.
enum CooperativeMatrixLayout {
    CooperativeMatrixLayoutCooperativeMatrixRowMajorKHR    = 0,
    CooperativeMatrixLayoutCooperativeMatrixColumnMajorKHR = 1,
    // Upper-bound entry rather than a layout.
    CooperativeMatrixLayoutMax                             = 0x7fffffff,
};
|
||||
|
||||
// Cooperative matrix use selectors: matrix A, matrix B, or accumulator.
enum CooperativeMatrixUse {
    CooperativeMatrixUseMatrixAKHR           = 0,
    CooperativeMatrixUseMatrixBKHR           = 1,
    CooperativeMatrixUseMatrixAccumulatorKHR = 2,
    // Upper-bound entry rather than a use.
    CooperativeMatrixUseMax                  = 0x7fffffff,
};
|
||||
|
||||
enum Op {
|
||||
OpNop = 0,
|
||||
OpUndef = 1,
|
||||
|
|
@ -1634,6 +1666,11 @@ enum Op {
|
|||
OpUDotAccSatKHR = 4454,
|
||||
OpSUDotAccSat = 4455,
|
||||
OpSUDotAccSatKHR = 4455,
|
||||
OpTypeCooperativeMatrixKHR = 4456,
|
||||
OpCooperativeMatrixLoadKHR = 4457,
|
||||
OpCooperativeMatrixStoreKHR = 4458,
|
||||
OpCooperativeMatrixMulAddKHR = 4459,
|
||||
OpCooperativeMatrixLengthKHR = 4460,
|
||||
OpTypeRayQueryKHR = 4472,
|
||||
OpRayQueryInitializeKHR = 4473,
|
||||
OpRayQueryTerminateKHR = 4474,
|
||||
|
|
@ -2346,6 +2383,11 @@ inline void HasResultAndType(Op opcode, bool *hasResult, bool *hasResultType) {
|
|||
case OpSDotAccSat: *hasResult = true; *hasResultType = true; break;
|
||||
case OpUDotAccSat: *hasResult = true; *hasResultType = true; break;
|
||||
case OpSUDotAccSat: *hasResult = true; *hasResultType = true; break;
|
||||
case OpTypeCooperativeMatrixKHR: *hasResult = true; *hasResultType = false; break;
|
||||
case OpCooperativeMatrixLoadKHR: *hasResult = true; *hasResultType = true; break;
|
||||
case OpCooperativeMatrixStoreKHR: *hasResult = false; *hasResultType = false; break;
|
||||
case OpCooperativeMatrixMulAddKHR: *hasResult = true; *hasResultType = true; break;
|
||||
case OpCooperativeMatrixLengthKHR: *hasResult = true; *hasResultType = true; break;
|
||||
case OpTypeRayQueryKHR: *hasResult = true; *hasResultType = false; break;
|
||||
case OpRayQueryInitializeKHR: *hasResult = false; *hasResultType = false; break;
|
||||
case OpRayQueryTerminateKHR: *hasResult = false; *hasResultType = false; break;
|
||||
|
|
@ -2722,6 +2764,10 @@ inline FragmentShadingRateMask operator|(FragmentShadingRateMask a, FragmentShad
|
|||
inline FragmentShadingRateMask operator&(FragmentShadingRateMask a, FragmentShadingRateMask b) { return FragmentShadingRateMask(unsigned(a) & unsigned(b)); }
|
||||
inline FragmentShadingRateMask operator^(FragmentShadingRateMask a, FragmentShadingRateMask b) { return FragmentShadingRateMask(unsigned(a) ^ unsigned(b)); }
|
||||
inline FragmentShadingRateMask operator~(FragmentShadingRateMask a) { return FragmentShadingRateMask(~unsigned(a)); }
|
||||
// Bitwise operators for CooperativeMatrixOperandsMask. Each one converts its
// operands to unsigned, applies the operation, and converts the result back
// to the mask type.
inline CooperativeMatrixOperandsMask operator|(CooperativeMatrixOperandsMask a, CooperativeMatrixOperandsMask b)
{
    return static_cast<CooperativeMatrixOperandsMask>(static_cast<unsigned>(a) | static_cast<unsigned>(b));
}
inline CooperativeMatrixOperandsMask operator&(CooperativeMatrixOperandsMask a, CooperativeMatrixOperandsMask b)
{
    return static_cast<CooperativeMatrixOperandsMask>(static_cast<unsigned>(a) & static_cast<unsigned>(b));
}
inline CooperativeMatrixOperandsMask operator^(CooperativeMatrixOperandsMask a, CooperativeMatrixOperandsMask b)
{
    return static_cast<CooperativeMatrixOperandsMask>(static_cast<unsigned>(a) ^ static_cast<unsigned>(b));
}
inline CooperativeMatrixOperandsMask operator~(CooperativeMatrixOperandsMask a)
{
    return static_cast<CooperativeMatrixOperandsMask>(~static_cast<unsigned>(a));
}
|
||||
|
||||
} // end namespace spv
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue