Implement support for GL_KHR_cooperative_matrix extension

This commit is contained in:
Boris Zanin 2023-03-16 13:01:01 +01:00 committed by arcady-lunarg
parent 91a97b4c69
commit 808c7ed17c
40 changed files with 8227 additions and 5733 deletions

View file

@ -1501,7 +1501,8 @@ TIntermTyped* TParseContext::handleFunctionCall(const TSourceLoc& loc, TFunction
if (result->getAsTyped()->getType().isCoopMat() &&
!result->getAsTyped()->getType().isParameterized()) {
assert(fnCandidate->getBuiltInOp() == EOpCooperativeMatrixMulAdd);
assert(fnCandidate->getBuiltInOp() == EOpCooperativeMatrixMulAdd ||
fnCandidate->getBuiltInOp() == EOpCooperativeMatrixMulAddNV);
result->setType(result->getAsAggregate()->getSequence()[2]->getAsTyped()->getType());
}
@ -3642,6 +3643,12 @@ bool TParseContext::constructorError(const TSourceLoc& loc, TIntermNode* node, T
}
TIntermTyped* typed = node->getAsTyped();
if (type.isCoopMat() && typed->getType().isCoopMat() &&
!type.sameCoopMatShapeAndUse(typed->getType())) {
error(loc, "Cooperative matrix type parameters mismatch", constructorString.c_str(), "");
return true;
}
if (typed == nullptr) {
error(loc, "constructor argument does not have a type", constructorString.c_str(), "");
return true;
@ -4302,7 +4309,7 @@ TPrecisionQualifier TParseContext::getDefaultPrecision(TPublicType& publicType)
return defaultPrecision[publicType.basicType];
}
void TParseContext::precisionQualifierCheck(const TSourceLoc& loc, TBasicType baseType, TQualifier& qualifier)
void TParseContext::precisionQualifierCheck(const TSourceLoc& loc, TBasicType baseType, TQualifier& qualifier, bool isCoopMat)
{
// Built-in symbols are allowed some ambiguous precisions, to be pinned down
// later by context.
@ -4314,6 +4321,9 @@ void TParseContext::precisionQualifierCheck(const TSourceLoc& loc, TBasicType ba
error(loc, "atomic counters can only be highp", "atomic_uint", "");
#endif
if (isCoopMat)
return;
if (baseType == EbtFloat || baseType == EbtUint || baseType == EbtInt || baseType == EbtSampler || baseType == EbtAtomicUint) {
if (qualifier.precision == EpqNone) {
if (relaxedErrors())
@ -4358,7 +4368,8 @@ bool TParseContext::containsFieldWithBasicType(const TType& type, TBasicType bas
//
// Do size checking for an array type's size.
//
void TParseContext::arraySizeCheck(const TSourceLoc& loc, TIntermTyped* expr, TArraySize& sizePair, const char *sizeType)
void TParseContext::arraySizeCheck(const TSourceLoc& loc, TIntermTyped* expr, TArraySize& sizePair,
const char* sizeType, const bool allowZero)
{
bool isConst = false;
sizePair.node = nullptr;
@ -4378,9 +4389,8 @@ void TParseContext::arraySizeCheck(const TSourceLoc& loc, TIntermTyped* expr, TA
TIntermSymbol* symbol = expr->getAsSymbolNode();
if (symbol && symbol->getConstArray().size() > 0)
size = symbol->getConstArray()[0].getIConst();
} else if (expr->getAsUnaryNode() &&
expr->getAsUnaryNode()->getOp() == glslang::EOpArrayLength &&
expr->getAsUnaryNode()->getOperand()->getType().isCoopMat()) {
} else if (expr->getAsUnaryNode() && expr->getAsUnaryNode()->getOp() == glslang::EOpArrayLength &&
expr->getAsUnaryNode()->getOperand()->getType().isCoopMatNV()) {
isConst = true;
size = 1;
sizePair.node = expr->getAsUnaryNode();
@ -4394,9 +4404,16 @@ void TParseContext::arraySizeCheck(const TSourceLoc& loc, TIntermTyped* expr, TA
return;
}
if (size <= 0) {
error(loc, sizeType, "", "must be a positive integer");
return;
if (allowZero) {
if (size < 0) {
error(loc, sizeType, "", "must be a non-negative integer");
return;
}
} else {
if (size <= 0) {
error(loc, sizeType, "", "must be a positive integer");
return;
}
}
}
@ -7371,6 +7388,43 @@ void TParseContext::declareTypeDefaults(const TSourceLoc& loc, const TPublicType
#endif
}
void TParseContext::coopMatTypeParametersCheck(const TSourceLoc& loc, const TPublicType& publicType)
{
#ifndef GLSLANG_WEB
if (parsingBuiltins)
return;
if (publicType.isCoopmatKHR()) {
if (publicType.typeParameters == nullptr) {
error(loc, "coopmat missing type parameters", "", "");
return;
}
switch (publicType.typeParameters->basicType) {
case EbtFloat:
case EbtFloat16:
case EbtInt:
case EbtInt8:
case EbtInt16:
case EbtUint:
case EbtUint8:
case EbtUint16:
break;
default:
error(loc, "coopmat invalid basic type", TType::getBasicString(publicType.typeParameters->basicType), "");
break;
}
if (publicType.typeParameters->arraySizes->getNumDims() != 4) {
error(loc, "coopmat incorrect number of type parameters", "", "");
return;
}
int use = publicType.typeParameters->arraySizes->getDimSize(3);
if (use < 0 || use > 2) {
error(loc, "coopmat invalid matrix Use", "", "");
return;
}
}
#endif
}
bool TParseContext::vkRelaxedRemapUniformVariable(const TSourceLoc& loc, TString& identifier, const TPublicType&,
TArraySizes*, TIntermTyped* initializer, TType& type)
{
@ -7486,29 +7540,43 @@ TIntermNode* TParseContext::declareVariable(const TSourceLoc& loc, TString& iden
}
if (type.isCoopMat()) {
if (type.isCoopMatKHR()) {
intermediate.setUseVulkanMemoryModel();
intermediate.setUseStorageBuffer();
if (!publicType.typeParameters || publicType.typeParameters->getNumDims() != 4) {
if (!publicType.typeParameters || !publicType.typeParameters->arraySizes ||
publicType.typeParameters->arraySizes->getNumDims() != 3) {
error(loc, "unexpected number type parameters", identifier.c_str(), "");
}
if (publicType.typeParameters) {
if (!isTypeFloat(publicType.typeParameters->basicType) && !isTypeInt(publicType.typeParameters->basicType)) {
error(loc, "expected 8, 16, 32, or 64 bit signed or unsigned integer or 16, 32, or 64 bit float type", identifier.c_str(), "");
}
}
}
else if (type.isCoopMatNV()) {
intermediate.setUseVulkanMemoryModel();
intermediate.setUseStorageBuffer();
if (!publicType.typeParameters || publicType.typeParameters->arraySizes->getNumDims() != 4) {
error(loc, "expected four type parameters", identifier.c_str(), "");
}
if (publicType.typeParameters) {
if (isTypeFloat(publicType.basicType) &&
publicType.typeParameters->getDimSize(0) != 16 &&
publicType.typeParameters->getDimSize(0) != 32 &&
publicType.typeParameters->getDimSize(0) != 64) {
publicType.typeParameters->arraySizes->getDimSize(0) != 16 &&
publicType.typeParameters->arraySizes->getDimSize(0) != 32 &&
publicType.typeParameters->arraySizes->getDimSize(0) != 64) {
error(loc, "expected 16, 32, or 64 bits for first type parameter", identifier.c_str(), "");
}
if (isTypeInt(publicType.basicType) &&
publicType.typeParameters->getDimSize(0) != 8 &&
publicType.typeParameters->getDimSize(0) != 32) {
error(loc, "expected 8 or 32 bits for first type parameter", identifier.c_str(), "");
publicType.typeParameters->arraySizes->getDimSize(0) != 8 &&
publicType.typeParameters->arraySizes->getDimSize(0) != 16 &&
publicType.typeParameters->arraySizes->getDimSize(0) != 32) {
error(loc, "expected 8, 16, or 32 bits for first type parameter", identifier.c_str(), "");
}
}
} else {
if (publicType.typeParameters && publicType.typeParameters->getNumDims() != 0) {
if (publicType.typeParameters && publicType.typeParameters->arraySizes->getNumDims() != 0) {
error(loc, "unexpected type parameters", identifier.c_str(), "");
}
}
@ -8336,14 +8404,18 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
return nullptr;
}
case EOpConstructCooperativeMatrix:
case EOpConstructCooperativeMatrixNV:
case EOpConstructCooperativeMatrixKHR:
if (node->getType() == type) {
return node;
}
if (!node->getType().isCoopMat()) {
if (type.getBasicType() != node->getType().getBasicType()) {
node = intermediate.addConversion(type.getBasicType(), node);
if (node == nullptr)
return nullptr;
}
node = intermediate.setAggregateOperator(node, EOpConstructCooperativeMatrix, type, node->getLoc());
node = intermediate.setAggregateOperator(node, op, type, node->getLoc());
} else {
TOperator op = EOpNull;
switch (type.getBasicType()) {
@ -8356,6 +8428,8 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
case EbtFloat16: op = EOpConvFloat16ToInt; break;
case EbtUint8: op = EOpConvUint8ToInt; break;
case EbtInt8: op = EOpConvInt8ToInt; break;
case EbtUint16: op = EOpConvUint16ToInt; break;
case EbtInt16: op = EOpConvInt16ToInt; break;
case EbtUint: op = EOpConvUintToInt; break;
default: assert(0);
}
@ -8366,8 +8440,33 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
case EbtFloat16: op = EOpConvFloat16ToUint; break;
case EbtUint8: op = EOpConvUint8ToUint; break;
case EbtInt8: op = EOpConvInt8ToUint; break;
case EbtUint16: op = EOpConvUint16ToUint; break;
case EbtInt16: op = EOpConvInt16ToUint; break;
case EbtInt: op = EOpConvIntToUint; break;
case EbtUint: op = EOpConvUintToInt8; break;
default: assert(0);
}
break;
case EbtInt16:
switch (node->getType().getBasicType()) {
case EbtFloat: op = EOpConvFloatToInt16; break;
case EbtFloat16: op = EOpConvFloat16ToInt16; break;
case EbtUint8: op = EOpConvUint8ToInt16; break;
case EbtInt8: op = EOpConvInt8ToInt16; break;
case EbtUint16: op = EOpConvUint16ToInt16; break;
case EbtInt: op = EOpConvIntToInt16; break;
case EbtUint: op = EOpConvUintToInt16; break;
default: assert(0);
}
break;
case EbtUint16:
switch (node->getType().getBasicType()) {
case EbtFloat: op = EOpConvFloatToUint16; break;
case EbtFloat16: op = EOpConvFloat16ToUint16; break;
case EbtUint8: op = EOpConvUint8ToUint16; break;
case EbtInt8: op = EOpConvInt8ToUint16; break;
case EbtInt16: op = EOpConvInt16ToUint16; break;
case EbtInt: op = EOpConvIntToUint16; break;
case EbtUint: op = EOpConvUintToUint16; break;
default: assert(0);
}
break;
@ -8376,6 +8475,8 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
case EbtFloat: op = EOpConvFloatToInt8; break;
case EbtFloat16: op = EOpConvFloat16ToInt8; break;
case EbtUint8: op = EOpConvUint8ToInt8; break;
case EbtInt16: op = EOpConvInt16ToInt8; break;
case EbtUint16: op = EOpConvUint16ToInt8; break;
case EbtInt: op = EOpConvIntToInt8; break;
case EbtUint: op = EOpConvUintToInt8; break;
default: assert(0);
@ -8386,6 +8487,8 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
case EbtFloat: op = EOpConvFloatToUint8; break;
case EbtFloat16: op = EOpConvFloat16ToUint8; break;
case EbtInt8: op = EOpConvInt8ToUint8; break;
case EbtInt16: op = EOpConvInt16ToUint8; break;
case EbtUint16: op = EOpConvUint16ToUint8; break;
case EbtInt: op = EOpConvIntToUint8; break;
case EbtUint: op = EOpConvUintToUint8; break;
default: assert(0);
@ -8396,6 +8499,8 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
case EbtFloat16: op = EOpConvFloat16ToFloat; break;
case EbtInt8: op = EOpConvInt8ToFloat; break;
case EbtUint8: op = EOpConvUint8ToFloat; break;
case EbtInt16: op = EOpConvInt16ToFloat; break;
case EbtUint16: op = EOpConvUint16ToFloat; break;
case EbtInt: op = EOpConvIntToFloat; break;
case EbtUint: op = EOpConvUintToFloat; break;
default: assert(0);
@ -8406,6 +8511,8 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
case EbtFloat: op = EOpConvFloatToFloat16; break;
case EbtInt8: op = EOpConvInt8ToFloat16; break;
case EbtUint8: op = EOpConvUint8ToFloat16; break;
case EbtInt16: op = EOpConvInt16ToFloat16; break;
case EbtUint16: op = EOpConvUint16ToFloat16; break;
case EbtInt: op = EOpConvIntToFloat16; break;
case EbtUint: op = EOpConvUintToFloat16; break;
default: assert(0);