Implement GL_NV_cooperative_matrix
This commit is contained in:
parent
ec484527b6
commit
4605e2ed2b
37 changed files with 5630 additions and 4211 deletions
|
|
@ -72,4 +72,7 @@ const char* const E_SPV_NV_ray_tracing = "SPV_NV_ray_tracing";
|
|||
//SPV_NV_shading_rate
|
||||
const char* const E_SPV_NV_shading_rate = "SPV_NV_shading_rate";
|
||||
|
||||
//SPV_NV_cooperative_matrix
|
||||
const char* const E_SPV_NV_cooperative_matrix = "SPV_NV_cooperative_matrix";
|
||||
|
||||
#endif // #ifndef GLSLextNV_H
|
||||
|
|
|
|||
149
SPIRV/GlslangToSpv.cpp
Executable file → Normal file
149
SPIRV/GlslangToSpv.cpp
Executable file → Normal file
|
|
@ -1330,6 +1330,10 @@ TGlslangToSpvTraverser::TGlslangToSpvTraverser(unsigned int spvVersion, const gl
|
|||
}
|
||||
builder.setMemoryModel(addressingModel, memoryModel);
|
||||
|
||||
if (glslangIntermediate->usingVariablePointers()) {
|
||||
builder.addCapability(spv::CapabilityVariablePointers);
|
||||
}
|
||||
|
||||
shaderEntry = builder.makeEntryPoint(glslangIntermediate->getEntryPointName().c_str());
|
||||
entryPoint = builder.addEntryPoint(executionModel, shaderEntry, glslangIntermediate->getEntryPointName().c_str());
|
||||
|
||||
|
|
@ -1870,16 +1874,31 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
|
|||
// So, this has to be block.lastMember.length().
|
||||
// SPV wants "block" and member number as the operands, go get them.
|
||||
|
||||
glslang::TIntermTyped* block = node->getOperand()->getAsBinaryNode()->getLeft();
|
||||
block->traverse(this);
|
||||
unsigned int member = node->getOperand()->getAsBinaryNode()->getRight()->getAsConstantUnion()->getConstArray()[0].getUConst();
|
||||
spv::Id length = builder.createArrayLength(builder.accessChainGetLValue(), member);
|
||||
spv::Id length;
|
||||
if (node->getOperand()->getType().isCoopMat()) {
|
||||
spec_constant_op_mode_setter.turnOnSpecConstantOpMode();
|
||||
|
||||
spv::Id typeId = convertGlslangToSpvType(node->getOperand()->getType());
|
||||
assert(builder.isCooperativeMatrixType(typeId));
|
||||
|
||||
length = builder.createCooperativeMatrixLength(typeId);
|
||||
} else {
|
||||
glslang::TIntermTyped* block = node->getOperand()->getAsBinaryNode()->getLeft();
|
||||
block->traverse(this);
|
||||
unsigned int member = node->getOperand()->getAsBinaryNode()->getRight()->getAsConstantUnion()->getConstArray()[0].getUConst();
|
||||
length = builder.createArrayLength(builder.accessChainGetLValue(), member);
|
||||
}
|
||||
|
||||
// GLSL semantics say the result of .length() is an int, while SPIR-V says
|
||||
// signedness must be 0. So, convert from SPIR-V unsigned back to GLSL's
|
||||
// AST expectation of a signed result.
|
||||
if (glslangIntermediate->getSource() == glslang::EShSourceGlsl)
|
||||
length = builder.createUnaryOp(spv::OpBitcast, builder.makeIntType(32), length);
|
||||
if (glslangIntermediate->getSource() == glslang::EShSourceGlsl) {
|
||||
if (builder.isInSpecConstCodeGenMode()) {
|
||||
length = builder.createBinOp(spv::OpIAdd, builder.makeIntType(32), length, builder.makeIntConstant(0));
|
||||
} else {
|
||||
length = builder.createUnaryOp(spv::OpBitcast, builder.makeIntType(32), length);
|
||||
}
|
||||
}
|
||||
|
||||
builder.clearAccessChain();
|
||||
builder.setAccessChainRValue(length);
|
||||
|
|
@ -2222,6 +2241,7 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
case glslang::EOpConstructStruct:
|
||||
case glslang::EOpConstructTextureSampler:
|
||||
case glslang::EOpConstructReference:
|
||||
case glslang::EOpConstructCooperativeMatrix:
|
||||
{
|
||||
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
|
||||
std::vector<spv::Id> arguments;
|
||||
|
|
@ -2229,7 +2249,9 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
spv::Id constructed;
|
||||
if (node->getOp() == glslang::EOpConstructTextureSampler)
|
||||
constructed = builder.createOp(spv::OpSampledImage, resultType(), arguments);
|
||||
else if (node->getOp() == glslang::EOpConstructStruct || node->getType().isArray()) {
|
||||
else if (node->getOp() == glslang::EOpConstructStruct ||
|
||||
node->getOp() == glslang::EOpConstructCooperativeMatrix ||
|
||||
node->getType().isArray()) {
|
||||
std::vector<spv::Id> constituents;
|
||||
for (int c = 0; c < (int)arguments.size(); ++c)
|
||||
constituents.push_back(arguments[c]);
|
||||
|
|
@ -2347,6 +2369,10 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
noReturnValue = true;
|
||||
break;
|
||||
#endif
|
||||
case glslang::EOpCooperativeMatrixLoad:
|
||||
case glslang::EOpCooperativeMatrixStore:
|
||||
noReturnValue = true;
|
||||
break;
|
||||
|
||||
default:
|
||||
break;
|
||||
|
|
@ -2389,6 +2415,7 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
//
|
||||
glslang::TIntermSequence& glslangOperands = node->getSequence();
|
||||
std::vector<spv::Id> operands;
|
||||
std::vector<spv::IdImmediate> memoryAccessOperands;
|
||||
for (int arg = 0; arg < (int)glslangOperands.size(); ++arg) {
|
||||
// special case l-value operands; there are just a few
|
||||
bool lvalue = false;
|
||||
|
|
@ -2445,6 +2472,14 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
if (arg >= 2)
|
||||
lvalue = true;
|
||||
break;
|
||||
case glslang::EOpCooperativeMatrixLoad:
|
||||
if (arg == 0 || arg == 1)
|
||||
lvalue = true;
|
||||
break;
|
||||
case glslang::EOpCooperativeMatrixStore:
|
||||
if (arg == 1)
|
||||
lvalue = true;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
@ -2453,6 +2488,50 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
glslangOperands[0]->getAsBinaryNode()->getLeft()->traverse(this);
|
||||
else
|
||||
glslangOperands[arg]->traverse(this);
|
||||
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad ||
|
||||
node->getOp() == glslang::EOpCooperativeMatrixStore) {
|
||||
|
||||
if (arg == 1) {
|
||||
// fold "element" parameter into the access chain
|
||||
spv::Builder::AccessChain save = builder.getAccessChain();
|
||||
builder.clearAccessChain();
|
||||
glslangOperands[2]->traverse(this);
|
||||
|
||||
spv::Id elementId = accessChainLoad(glslangOperands[2]->getAsTyped()->getType());
|
||||
|
||||
builder.setAccessChain(save);
|
||||
|
||||
// Point to the first element of the array.
|
||||
builder.accessChainPush(elementId, TranslateCoherent(glslangOperands[arg]->getAsTyped()->getType()),
|
||||
getBufferReferenceAlignment(glslangOperands[arg]->getAsTyped()->getType()));
|
||||
|
||||
spv::Builder::AccessChain::CoherentFlags coherentFlags = builder.getAccessChain().coherentFlags;
|
||||
unsigned int alignment = builder.getAccessChain().alignment;
|
||||
|
||||
int memoryAccess = TranslateMemoryAccess(coherentFlags);
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad)
|
||||
memoryAccess &= ~spv::MemoryAccessMakePointerAvailableKHRMask;
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixStore)
|
||||
memoryAccess &= ~spv::MemoryAccessMakePointerVisibleKHRMask;
|
||||
if (builder.getStorageClass(builder.getAccessChain().base) == spv::StorageClassPhysicalStorageBufferEXT) {
|
||||
memoryAccess = (spv::MemoryAccessMask)(memoryAccess | spv::MemoryAccessAlignedMask);
|
||||
}
|
||||
|
||||
memoryAccessOperands.push_back(spv::IdImmediate(false, memoryAccess));
|
||||
|
||||
if (memoryAccess & spv::MemoryAccessAlignedMask) {
|
||||
memoryAccessOperands.push_back(spv::IdImmediate(false, alignment));
|
||||
}
|
||||
|
||||
if (memoryAccess & (spv::MemoryAccessMakePointerAvailableKHRMask | spv::MemoryAccessMakePointerVisibleKHRMask)) {
|
||||
memoryAccessOperands.push_back(spv::IdImmediate(true, builder.makeUintConstant(TranslateMemoryScope(coherentFlags))));
|
||||
}
|
||||
} else if (arg == 2) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (lvalue)
|
||||
operands.push_back(builder.accessChainGetLValue());
|
||||
else {
|
||||
|
|
@ -2462,7 +2541,33 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
|
|||
}
|
||||
|
||||
builder.setLine(node->getLoc().line, node->getLoc().getFilename());
|
||||
if (atomic) {
|
||||
if (node->getOp() == glslang::EOpCooperativeMatrixLoad) {
|
||||
std::vector<spv::IdImmediate> idImmOps;
|
||||
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[1])); // buf
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
|
||||
idImmOps.insert(idImmOps.end(), memoryAccessOperands.begin(), memoryAccessOperands.end());
|
||||
// get the pointee type
|
||||
spv::Id typeId = builder.getContainedTypeId(builder.getTypeId(operands[0]));
|
||||
assert(builder.isCooperativeMatrixType(typeId));
|
||||
// do the op
|
||||
spv::Id result = builder.createOp(spv::OpCooperativeMatrixLoadNV, typeId, idImmOps);
|
||||
// store the result to the pointer (out param 'm')
|
||||
builder.createStore(result, operands[0]);
|
||||
result = 0;
|
||||
} else if (node->getOp() == glslang::EOpCooperativeMatrixStore) {
|
||||
std::vector<spv::IdImmediate> idImmOps;
|
||||
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[1])); // buf
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[0])); // object
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[2])); // stride
|
||||
idImmOps.push_back(spv::IdImmediate(true, operands[3])); // colMajor
|
||||
idImmOps.insert(idImmOps.end(), memoryAccessOperands.begin(), memoryAccessOperands.end());
|
||||
|
||||
builder.createNoResultOp(spv::OpCooperativeMatrixStoreNV, idImmOps);
|
||||
result = 0;
|
||||
} else if (atomic) {
|
||||
// Handle all atomics
|
||||
result = createAtomicOperation(node->getOp(), precision, resultType(), operands, node->getBasicType());
|
||||
} else {
|
||||
|
|
@ -3090,6 +3195,19 @@ spv::Id TGlslangToSpvTraverser::convertGlslangToSpvType(const glslang::TType& ty
|
|||
spvType = builder.makeVectorType(spvType, type.getVectorSize());
|
||||
}
|
||||
|
||||
if (type.isCoopMat()) {
|
||||
builder.addCapability(spv::CapabilityCooperativeMatrixNV);
|
||||
builder.addExtension(spv::E_SPV_NV_cooperative_matrix);
|
||||
if (type.getBasicType() == glslang::EbtFloat16)
|
||||
builder.addCapability(spv::CapabilityFloat16);
|
||||
|
||||
spv::Id scope = makeArraySizeId(*type.getTypeParameters(), 1);
|
||||
spv::Id rows = makeArraySizeId(*type.getTypeParameters(), 2);
|
||||
spv::Id cols = makeArraySizeId(*type.getTypeParameters(), 3);
|
||||
|
||||
spvType = builder.makeCooperativeMatrixType(spvType, scope, rows, cols);
|
||||
}
|
||||
|
||||
if (type.isArray()) {
|
||||
int stride = 0; // keep this 0 unless doing an explicit layout; 0 will mean no decoration, no stride
|
||||
|
||||
|
|
@ -4847,7 +4965,8 @@ spv::Id TGlslangToSpvTraverser::createBinaryOperation(glslang::TOperator op, OpD
|
|||
// handle mapped binary operations (should be non-comparison)
|
||||
if (binOp != spv::OpNop) {
|
||||
assert(comparison == false);
|
||||
if (builder.isMatrix(left) || builder.isMatrix(right))
|
||||
if (builder.isMatrix(left) || builder.isMatrix(right) ||
|
||||
builder.isCooperativeMatrix(left) || builder.isCooperativeMatrix(right))
|
||||
return createBinaryMatrixOperation(binOp, decorations, typeId, left, right);
|
||||
|
||||
// No matrix involved; make both operands be the same number of components, if needed
|
||||
|
|
@ -4968,7 +5087,7 @@ spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, OpDecora
|
|||
firstClass = false;
|
||||
break;
|
||||
case spv::OpMatrixTimesScalar:
|
||||
if (builder.isMatrix(right))
|
||||
if (builder.isMatrix(right) || builder.isCooperativeMatrix(right))
|
||||
std::swap(left, right);
|
||||
assert(builder.isScalar(right));
|
||||
break;
|
||||
|
|
@ -4989,6 +5108,9 @@ spv::Id TGlslangToSpvTraverser::createBinaryMatrixOperation(spv::Op op, OpDecora
|
|||
break;
|
||||
}
|
||||
|
||||
if (builder.isCooperativeMatrix(left) || builder.isCooperativeMatrix(right))
|
||||
firstClass = true;
|
||||
|
||||
if (firstClass) {
|
||||
spv::Id result = builder.createBinOp(op, typeId, left, right);
|
||||
builder.addDecoration(result, decorations.noContraction);
|
||||
|
|
@ -7030,6 +7152,10 @@ spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv::
|
|||
builder.createNoResultOp(spv::OpWritePackedPrimitiveIndices4x8NV, operands);
|
||||
return 0;
|
||||
#endif
|
||||
case glslang::EOpCooperativeMatrixMulAdd:
|
||||
opCode = spv::OpCooperativeMatrixMulAddNV;
|
||||
break;
|
||||
|
||||
default:
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -7486,6 +7612,9 @@ spv::Id TGlslangToSpvTraverser::createSpvConstantFromConstUnionArray(const glsla
|
|||
glslang::TType vectorType(glslangType, 0);
|
||||
for (int col = 0; col < glslangType.getMatrixCols(); ++col)
|
||||
spvConsts.push_back(createSpvConstantFromConstUnionArray(vectorType, consts, nextConst, false));
|
||||
} else if (glslangType.isCoopMat()) {
|
||||
glslang::TType componentType(glslangType.getBasicType());
|
||||
spvConsts.push_back(createSpvConstantFromConstUnionArray(componentType, consts, nextConst, false));
|
||||
} else if (glslangType.isStruct()) {
|
||||
glslang::TVector<glslang::TTypeLoc>::const_iterator iter;
|
||||
for (iter = glslangType.getStruct()->begin(); iter != glslangType.getStruct()->end(); ++iter)
|
||||
|
|
|
|||
60
SPIRV/SpvBuilder.cpp
Executable file → Normal file
60
SPIRV/SpvBuilder.cpp
Executable file → Normal file
|
|
@ -388,6 +388,33 @@ Id Builder::makeMatrixType(Id component, int cols, int rows)
|
|||
return type->getResultId();
|
||||
}
|
||||
|
||||
Id Builder::makeCooperativeMatrixType(Id component, Id scope, Id rows, Id cols)
|
||||
{
|
||||
// try to find it
|
||||
Instruction* type;
|
||||
for (int t = 0; t < (int)groupedTypes[OpTypeCooperativeMatrixNV].size(); ++t) {
|
||||
type = groupedTypes[OpTypeCooperativeMatrixNV][t];
|
||||
if (type->getIdOperand(0) == component &&
|
||||
type->getIdOperand(1) == scope &&
|
||||
type->getIdOperand(2) == rows &&
|
||||
type->getIdOperand(3) == cols)
|
||||
return type->getResultId();
|
||||
}
|
||||
|
||||
// not found, make it
|
||||
type = new Instruction(getUniqueId(), NoType, OpTypeCooperativeMatrixNV);
|
||||
type->addIdOperand(component);
|
||||
type->addIdOperand(scope);
|
||||
type->addIdOperand(rows);
|
||||
type->addIdOperand(cols);
|
||||
groupedTypes[OpTypeCooperativeMatrixNV].push_back(type);
|
||||
constantsTypesGlobals.push_back(std::unique_ptr<Instruction>(type));
|
||||
module.mapInstruction(type);
|
||||
|
||||
return type->getResultId();
|
||||
}
|
||||
|
||||
|
||||
// TODO: performance: track arrays per stride
|
||||
// If a stride is supplied (non-zero) make an array.
|
||||
// If no stride (0), reuse previous array types.
|
||||
|
|
@ -623,6 +650,9 @@ int Builder::getNumTypeConstituents(Id typeId) const
|
|||
}
|
||||
case OpTypeStruct:
|
||||
return instr->getNumOperands();
|
||||
case OpTypeCooperativeMatrixNV:
|
||||
// has only one constituent when used with OpCompositeConstruct.
|
||||
return 1;
|
||||
default:
|
||||
assert(0);
|
||||
return 1;
|
||||
|
|
@ -669,6 +699,7 @@ Id Builder::getContainedTypeId(Id typeId, int member) const
|
|||
case OpTypeMatrix:
|
||||
case OpTypeArray:
|
||||
case OpTypeRuntimeArray:
|
||||
case OpTypeCooperativeMatrixNV:
|
||||
return instr->getIdOperand(0);
|
||||
case OpTypePointer:
|
||||
return instr->getIdOperand(1);
|
||||
|
|
@ -981,15 +1012,14 @@ Id Builder::makeFpConstant(Id type, double d, bool specConstant)
|
|||
return NoResult;
|
||||
}
|
||||
|
||||
Id Builder::findCompositeConstant(Op typeClass, const std::vector<Id>& comps)
|
||||
Id Builder::findCompositeConstant(Op typeClass, Id typeId, const std::vector<Id>& comps)
|
||||
{
|
||||
Instruction* constant = 0;
|
||||
bool found = false;
|
||||
for (int i = 0; i < (int)groupedConstants[typeClass].size(); ++i) {
|
||||
constant = groupedConstants[typeClass][i];
|
||||
|
||||
// same shape?
|
||||
if (constant->getNumOperands() != (int)comps.size())
|
||||
if (constant->getTypeId() != typeId)
|
||||
continue;
|
||||
|
||||
// same contents?
|
||||
|
|
@ -1044,8 +1074,9 @@ Id Builder::makeCompositeConstant(Id typeId, const std::vector<Id>& members, boo
|
|||
case OpTypeVector:
|
||||
case OpTypeArray:
|
||||
case OpTypeMatrix:
|
||||
case OpTypeCooperativeMatrixNV:
|
||||
if (! specConstant) {
|
||||
Id existing = findCompositeConstant(typeClass, members);
|
||||
Id existing = findCompositeConstant(typeClass, typeId, members);
|
||||
if (existing)
|
||||
return existing;
|
||||
}
|
||||
|
|
@ -1408,6 +1439,23 @@ Id Builder::createArrayLength(Id base, unsigned int member)
|
|||
return length->getResultId();
|
||||
}
|
||||
|
||||
Id Builder::createCooperativeMatrixLength(Id type)
|
||||
{
|
||||
spv::Id intType = makeUintType(32);
|
||||
|
||||
// Generate code for spec constants if in spec constant operation
|
||||
// generation mode.
|
||||
if (generatingOpCodeForSpecConst) {
|
||||
return createSpecConstantOp(OpCooperativeMatrixLengthNV, intType, std::vector<Id>(1, type), std::vector<Id>());
|
||||
}
|
||||
|
||||
Instruction* length = new Instruction(getUniqueId(), intType, OpCooperativeMatrixLengthNV);
|
||||
length->addIdOperand(type);
|
||||
buildPoint->addInstruction(std::unique_ptr<Instruction>(length));
|
||||
|
||||
return length->getResultId();
|
||||
}
|
||||
|
||||
Id Builder::createCompositeExtract(Id composite, Id typeId, unsigned index)
|
||||
{
|
||||
// Generate code for spec constants if in spec constant operation
|
||||
|
|
@ -2598,9 +2646,9 @@ Id Builder::accessChainLoad(Decoration precision, Decoration nonUniform, Id resu
|
|||
}
|
||||
}
|
||||
|
||||
if (constant)
|
||||
if (constant) {
|
||||
id = createCompositeExtract(accessChain.base, swizzleBase, indexes);
|
||||
else {
|
||||
} else {
|
||||
// make a new function variable for this r-value
|
||||
Id lValue = createVariable(StorageClassFunction, getTypeId(accessChain.base), "indexable");
|
||||
|
||||
|
|
|
|||
10
SPIRV/SpvBuilder.h
Executable file → Normal file
10
SPIRV/SpvBuilder.h
Executable file → Normal file
|
|
@ -155,6 +155,7 @@ public:
|
|||
Id makeImageType(Id sampledType, Dim, bool depth, bool arrayed, bool ms, unsigned sampled, ImageFormat format);
|
||||
Id makeSamplerType();
|
||||
Id makeSampledImageType(Id imageType);
|
||||
Id makeCooperativeMatrixType(Id component, Id scope, Id rows, Id cols);
|
||||
|
||||
// accelerationStructureNV type
|
||||
Id makeAccelerationStructureNVType();
|
||||
|
|
@ -178,6 +179,7 @@ public:
|
|||
bool isScalar(Id resultId) const { return isScalarType(getTypeId(resultId)); }
|
||||
bool isVector(Id resultId) const { return isVectorType(getTypeId(resultId)); }
|
||||
bool isMatrix(Id resultId) const { return isMatrixType(getTypeId(resultId)); }
|
||||
bool isCooperativeMatrix(Id resultId)const { return isCooperativeMatrixType(getTypeId(resultId)); }
|
||||
bool isAggregate(Id resultId) const { return isAggregateType(getTypeId(resultId)); }
|
||||
bool isSampledImage(Id resultId) const { return isSampledImageType(getTypeId(resultId)); }
|
||||
|
||||
|
|
@ -191,7 +193,8 @@ public:
|
|||
bool isMatrixType(Id typeId) const { return getTypeClass(typeId) == OpTypeMatrix; }
|
||||
bool isStructType(Id typeId) const { return getTypeClass(typeId) == OpTypeStruct; }
|
||||
bool isArrayType(Id typeId) const { return getTypeClass(typeId) == OpTypeArray; }
|
||||
bool isAggregateType(Id typeId) const { return isArrayType(typeId) || isStructType(typeId); }
|
||||
bool isCooperativeMatrixType(Id typeId)const { return getTypeClass(typeId) == OpTypeCooperativeMatrixNV; }
|
||||
bool isAggregateType(Id typeId) const { return isArrayType(typeId) || isStructType(typeId) || isCooperativeMatrixType(typeId); }
|
||||
bool isImageType(Id typeId) const { return getTypeClass(typeId) == OpTypeImage; }
|
||||
bool isSamplerType(Id typeId) const { return getTypeClass(typeId) == OpTypeSampler; }
|
||||
bool isSampledImageType(Id typeId) const { return getTypeClass(typeId) == OpTypeSampledImage; }
|
||||
|
|
@ -314,6 +317,9 @@ public:
|
|||
// Create an OpArrayLength instruction
|
||||
Id createArrayLength(Id base, unsigned int member);
|
||||
|
||||
// Create an OpCooperativeMatrixLengthNV instruction
|
||||
Id createCooperativeMatrixLength(Id type);
|
||||
|
||||
// Create an OpCompositeExtract instruction
|
||||
Id createCompositeExtract(Id composite, Id typeId, unsigned index);
|
||||
Id createCompositeExtract(Id composite, Id typeId, const std::vector<unsigned>& indexes);
|
||||
|
|
@ -670,7 +676,7 @@ public:
|
|||
Id makeInt64Constant(Id typeId, unsigned long long value, bool specConstant);
|
||||
Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned value);
|
||||
Id findScalarConstant(Op typeClass, Op opcode, Id typeId, unsigned v1, unsigned v2);
|
||||
Id findCompositeConstant(Op typeClass, const std::vector<Id>& comps);
|
||||
Id findCompositeConstant(Op typeClass, Id typeId, const std::vector<Id>& comps);
|
||||
Id findStructConstant(Id typeId, const std::vector<Id>& comps);
|
||||
Id collapseAccessChain();
|
||||
void remapDynamicSwizzle();
|
||||
|
|
|
|||
|
|
@ -930,6 +930,10 @@ const char* CapabilityString(int info)
|
|||
|
||||
case CapabilityPhysicalStorageBufferAddressesEXT: return "CapabilityPhysicalStorageBufferAddressesEXT";
|
||||
|
||||
case CapabilityVariablePointers: return "CapabilityVariablePointers";
|
||||
|
||||
case CapabilityCooperativeMatrixNV: return "CapabilityCooperativeMatrixNV";
|
||||
|
||||
default: return "Bad";
|
||||
}
|
||||
}
|
||||
|
|
@ -1333,6 +1337,12 @@ const char* OpcodeString(int op)
|
|||
case OpWritePackedPrimitiveIndices4x8NV: return "OpWritePackedPrimitiveIndices4x8NV";
|
||||
#endif
|
||||
|
||||
case OpTypeCooperativeMatrixNV: return "OpTypeCooperativeMatrixNV";
|
||||
case OpCooperativeMatrixLoadNV: return "OpCooperativeMatrixLoadNV";
|
||||
case OpCooperativeMatrixStoreNV: return "OpCooperativeMatrixStoreNV";
|
||||
case OpCooperativeMatrixMulAddNV: return "OpCooperativeMatrixMulAddNV";
|
||||
case OpCooperativeMatrixLengthNV: return "OpCooperativeMatrixLengthNV";
|
||||
|
||||
default:
|
||||
return "Bad";
|
||||
}
|
||||
|
|
@ -1444,6 +1454,8 @@ void Parameterize()
|
|||
InstructionDesc[OpGroupWaitEvents].setResultAndType(false, false);
|
||||
InstructionDesc[OpAtomicFlagClear].setResultAndType(false, false);
|
||||
InstructionDesc[OpModuleProcessed].setResultAndType(false, false);
|
||||
InstructionDesc[OpTypeCooperativeMatrixNV].setResultAndType(true, false);
|
||||
InstructionDesc[OpCooperativeMatrixStoreNV].setResultAndType(false, false);
|
||||
|
||||
// Specific additional context-dependent operands
|
||||
|
||||
|
|
@ -2714,6 +2726,32 @@ void Parameterize()
|
|||
InstructionDesc[OpWritePackedPrimitiveIndices4x8NV].operands.push(OperandId, "'Index Offset'");
|
||||
InstructionDesc[OpWritePackedPrimitiveIndices4x8NV].operands.push(OperandId, "'Packed Indices'");
|
||||
#endif
|
||||
|
||||
InstructionDesc[OpTypeCooperativeMatrixNV].operands.push(OperandId, "'Component Type'");
|
||||
InstructionDesc[OpTypeCooperativeMatrixNV].operands.push(OperandId, "'Scope'");
|
||||
InstructionDesc[OpTypeCooperativeMatrixNV].operands.push(OperandId, "'Rows'");
|
||||
InstructionDesc[OpTypeCooperativeMatrixNV].operands.push(OperandId, "'Columns'");
|
||||
|
||||
InstructionDesc[OpCooperativeMatrixLoadNV].operands.push(OperandId, "'Pointer'");
|
||||
InstructionDesc[OpCooperativeMatrixLoadNV].operands.push(OperandId, "'Stride'");
|
||||
InstructionDesc[OpCooperativeMatrixLoadNV].operands.push(OperandId, "'Column Major'");
|
||||
InstructionDesc[OpCooperativeMatrixLoadNV].operands.push(OperandMemoryAccess, "'Memory Access'");
|
||||
InstructionDesc[OpCooperativeMatrixLoadNV].operands.push(OperandLiteralNumber, "", true);
|
||||
InstructionDesc[OpCooperativeMatrixLoadNV].operands.push(OperandId, "", true);
|
||||
|
||||
InstructionDesc[OpCooperativeMatrixStoreNV].operands.push(OperandId, "'Pointer'");
|
||||
InstructionDesc[OpCooperativeMatrixStoreNV].operands.push(OperandId, "'Object'");
|
||||
InstructionDesc[OpCooperativeMatrixStoreNV].operands.push(OperandId, "'Stride'");
|
||||
InstructionDesc[OpCooperativeMatrixStoreNV].operands.push(OperandId, "'Column Major'");
|
||||
InstructionDesc[OpCooperativeMatrixStoreNV].operands.push(OperandMemoryAccess, "'Memory Access'");
|
||||
InstructionDesc[OpCooperativeMatrixStoreNV].operands.push(OperandLiteralNumber, "", true);
|
||||
InstructionDesc[OpCooperativeMatrixStoreNV].operands.push(OperandId, "", true);
|
||||
|
||||
InstructionDesc[OpCooperativeMatrixMulAddNV].operands.push(OperandId, "'A'");
|
||||
InstructionDesc[OpCooperativeMatrixMulAddNV].operands.push(OperandId, "'B'");
|
||||
InstructionDesc[OpCooperativeMatrixMulAddNV].operands.push(OperandId, "'C'");
|
||||
|
||||
InstructionDesc[OpCooperativeMatrixLengthNV].operands.push(OperandId, "'Type'");
|
||||
}
|
||||
|
||||
}; // end spv namespace
|
||||
|
|
|
|||
|
|
@ -811,6 +811,7 @@ enum Capability {
|
|||
CapabilityVulkanMemoryModelDeviceScopeKHR = 5346,
|
||||
CapabilityPhysicalStorageBufferAddressesEXT = 5347,
|
||||
CapabilityComputeDerivativeGroupLinearNV = 5350,
|
||||
CapabilityCooperativeMatrixNV = 5357,
|
||||
CapabilitySubgroupShuffleINTEL = 5568,
|
||||
CapabilitySubgroupBufferBlockIOINTEL = 5569,
|
||||
CapabilitySubgroupImageBlockIOINTEL = 5570,
|
||||
|
|
@ -1183,6 +1184,11 @@ enum Op {
|
|||
OpTraceNV = 5337,
|
||||
OpTypeAccelerationStructureNV = 5341,
|
||||
OpExecuteCallableNV = 5344,
|
||||
OpTypeCooperativeMatrixNV = 5358,
|
||||
OpCooperativeMatrixLoadNV = 5359,
|
||||
OpCooperativeMatrixStoreNV = 5360,
|
||||
OpCooperativeMatrixMulAddNV = 5361,
|
||||
OpCooperativeMatrixLengthNV = 5362,
|
||||
OpSubgroupShuffleINTEL = 5571,
|
||||
OpSubgroupShuffleDownINTEL = 5572,
|
||||
OpSubgroupShuffleUpINTEL = 5573,
|
||||
|
|
|
|||
|
|
@ -83,6 +83,7 @@ const MemorySemanticsMask MemorySemanticsAllMemory =
|
|||
struct IdImmediate {
|
||||
bool isId; // true if word is an Id, false if word is an immediate
|
||||
unsigned word;
|
||||
IdImmediate(bool i, unsigned w) : isId(i), word(w) {}
|
||||
};
|
||||
|
||||
//
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue