Implement support for GL_KHR_cooperative_matrix extension
This commit is contained in:
parent
91a97b4c69
commit
808c7ed17c
40 changed files with 8227 additions and 5733 deletions
|
|
@ -1501,7 +1501,8 @@ TIntermTyped* TParseContext::handleFunctionCall(const TSourceLoc& loc, TFunction
|
|||
|
||||
if (result->getAsTyped()->getType().isCoopMat() &&
|
||||
!result->getAsTyped()->getType().isParameterized()) {
|
||||
assert(fnCandidate->getBuiltInOp() == EOpCooperativeMatrixMulAdd);
|
||||
assert(fnCandidate->getBuiltInOp() == EOpCooperativeMatrixMulAdd ||
|
||||
fnCandidate->getBuiltInOp() == EOpCooperativeMatrixMulAddNV);
|
||||
|
||||
result->setType(result->getAsAggregate()->getSequence()[2]->getAsTyped()->getType());
|
||||
}
|
||||
|
|
@ -3642,6 +3643,12 @@ bool TParseContext::constructorError(const TSourceLoc& loc, TIntermNode* node, T
|
|||
}
|
||||
|
||||
TIntermTyped* typed = node->getAsTyped();
|
||||
if (type.isCoopMat() && typed->getType().isCoopMat() &&
|
||||
!type.sameCoopMatShapeAndUse(typed->getType())) {
|
||||
error(loc, "Cooperative matrix type parameters mismatch", constructorString.c_str(), "");
|
||||
return true;
|
||||
}
|
||||
|
||||
if (typed == nullptr) {
|
||||
error(loc, "constructor argument does not have a type", constructorString.c_str(), "");
|
||||
return true;
|
||||
|
|
@ -4302,7 +4309,7 @@ TPrecisionQualifier TParseContext::getDefaultPrecision(TPublicType& publicType)
|
|||
return defaultPrecision[publicType.basicType];
|
||||
}
|
||||
|
||||
void TParseContext::precisionQualifierCheck(const TSourceLoc& loc, TBasicType baseType, TQualifier& qualifier)
|
||||
void TParseContext::precisionQualifierCheck(const TSourceLoc& loc, TBasicType baseType, TQualifier& qualifier, bool isCoopMat)
|
||||
{
|
||||
// Built-in symbols are allowed some ambiguous precisions, to be pinned down
|
||||
// later by context.
|
||||
|
|
@ -4314,6 +4321,9 @@ void TParseContext::precisionQualifierCheck(const TSourceLoc& loc, TBasicType ba
|
|||
error(loc, "atomic counters can only be highp", "atomic_uint", "");
|
||||
#endif
|
||||
|
||||
if (isCoopMat)
|
||||
return;
|
||||
|
||||
if (baseType == EbtFloat || baseType == EbtUint || baseType == EbtInt || baseType == EbtSampler || baseType == EbtAtomicUint) {
|
||||
if (qualifier.precision == EpqNone) {
|
||||
if (relaxedErrors())
|
||||
|
|
@ -4358,7 +4368,8 @@ bool TParseContext::containsFieldWithBasicType(const TType& type, TBasicType bas
|
|||
//
|
||||
// Do size checking for an array type's size.
|
||||
//
|
||||
void TParseContext::arraySizeCheck(const TSourceLoc& loc, TIntermTyped* expr, TArraySize& sizePair, const char *sizeType)
|
||||
void TParseContext::arraySizeCheck(const TSourceLoc& loc, TIntermTyped* expr, TArraySize& sizePair,
|
||||
const char* sizeType, const bool allowZero)
|
||||
{
|
||||
bool isConst = false;
|
||||
sizePair.node = nullptr;
|
||||
|
|
@ -4378,9 +4389,8 @@ void TParseContext::arraySizeCheck(const TSourceLoc& loc, TIntermTyped* expr, TA
|
|||
TIntermSymbol* symbol = expr->getAsSymbolNode();
|
||||
if (symbol && symbol->getConstArray().size() > 0)
|
||||
size = symbol->getConstArray()[0].getIConst();
|
||||
} else if (expr->getAsUnaryNode() &&
|
||||
expr->getAsUnaryNode()->getOp() == glslang::EOpArrayLength &&
|
||||
expr->getAsUnaryNode()->getOperand()->getType().isCoopMat()) {
|
||||
} else if (expr->getAsUnaryNode() && expr->getAsUnaryNode()->getOp() == glslang::EOpArrayLength &&
|
||||
expr->getAsUnaryNode()->getOperand()->getType().isCoopMatNV()) {
|
||||
isConst = true;
|
||||
size = 1;
|
||||
sizePair.node = expr->getAsUnaryNode();
|
||||
|
|
@ -4394,9 +4404,16 @@ void TParseContext::arraySizeCheck(const TSourceLoc& loc, TIntermTyped* expr, TA
|
|||
return;
|
||||
}
|
||||
|
||||
if (size <= 0) {
|
||||
error(loc, sizeType, "", "must be a positive integer");
|
||||
return;
|
||||
if (allowZero) {
|
||||
if (size < 0) {
|
||||
error(loc, sizeType, "", "must be a non-negative integer");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
if (size <= 0) {
|
||||
error(loc, sizeType, "", "must be a positive integer");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -7371,6 +7388,43 @@ void TParseContext::declareTypeDefaults(const TSourceLoc& loc, const TPublicType
|
|||
#endif
|
||||
}
|
||||
|
||||
void TParseContext::coopMatTypeParametersCheck(const TSourceLoc& loc, const TPublicType& publicType)
|
||||
{
|
||||
#ifndef GLSLANG_WEB
|
||||
if (parsingBuiltins)
|
||||
return;
|
||||
if (publicType.isCoopmatKHR()) {
|
||||
if (publicType.typeParameters == nullptr) {
|
||||
error(loc, "coopmat missing type parameters", "", "");
|
||||
return;
|
||||
}
|
||||
switch (publicType.typeParameters->basicType) {
|
||||
case EbtFloat:
|
||||
case EbtFloat16:
|
||||
case EbtInt:
|
||||
case EbtInt8:
|
||||
case EbtInt16:
|
||||
case EbtUint:
|
||||
case EbtUint8:
|
||||
case EbtUint16:
|
||||
break;
|
||||
default:
|
||||
error(loc, "coopmat invalid basic type", TType::getBasicString(publicType.typeParameters->basicType), "");
|
||||
break;
|
||||
}
|
||||
if (publicType.typeParameters->arraySizes->getNumDims() != 4) {
|
||||
error(loc, "coopmat incorrect number of type parameters", "", "");
|
||||
return;
|
||||
}
|
||||
int use = publicType.typeParameters->arraySizes->getDimSize(3);
|
||||
if (use < 0 || use > 2) {
|
||||
error(loc, "coopmat invalid matrix Use", "", "");
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
bool TParseContext::vkRelaxedRemapUniformVariable(const TSourceLoc& loc, TString& identifier, const TPublicType&,
|
||||
TArraySizes*, TIntermTyped* initializer, TType& type)
|
||||
{
|
||||
|
|
@ -7486,29 +7540,43 @@ TIntermNode* TParseContext::declareVariable(const TSourceLoc& loc, TString& iden
|
|||
|
||||
}
|
||||
|
||||
if (type.isCoopMat()) {
|
||||
if (type.isCoopMatKHR()) {
|
||||
intermediate.setUseVulkanMemoryModel();
|
||||
intermediate.setUseStorageBuffer();
|
||||
|
||||
if (!publicType.typeParameters || publicType.typeParameters->getNumDims() != 4) {
|
||||
if (!publicType.typeParameters || !publicType.typeParameters->arraySizes ||
|
||||
publicType.typeParameters->arraySizes->getNumDims() != 3) {
|
||||
error(loc, "unexpected number type parameters", identifier.c_str(), "");
|
||||
}
|
||||
if (publicType.typeParameters) {
|
||||
if (!isTypeFloat(publicType.typeParameters->basicType) && !isTypeInt(publicType.typeParameters->basicType)) {
|
||||
error(loc, "expected 8, 16, 32, or 64 bit signed or unsigned integer or 16, 32, or 64 bit float type", identifier.c_str(), "");
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (type.isCoopMatNV()) {
|
||||
intermediate.setUseVulkanMemoryModel();
|
||||
intermediate.setUseStorageBuffer();
|
||||
|
||||
if (!publicType.typeParameters || publicType.typeParameters->arraySizes->getNumDims() != 4) {
|
||||
error(loc, "expected four type parameters", identifier.c_str(), "");
|
||||
}
|
||||
if (publicType.typeParameters) {
|
||||
if (isTypeFloat(publicType.basicType) &&
|
||||
publicType.typeParameters->getDimSize(0) != 16 &&
|
||||
publicType.typeParameters->getDimSize(0) != 32 &&
|
||||
publicType.typeParameters->getDimSize(0) != 64) {
|
||||
publicType.typeParameters->arraySizes->getDimSize(0) != 16 &&
|
||||
publicType.typeParameters->arraySizes->getDimSize(0) != 32 &&
|
||||
publicType.typeParameters->arraySizes->getDimSize(0) != 64) {
|
||||
error(loc, "expected 16, 32, or 64 bits for first type parameter", identifier.c_str(), "");
|
||||
}
|
||||
if (isTypeInt(publicType.basicType) &&
|
||||
publicType.typeParameters->getDimSize(0) != 8 &&
|
||||
publicType.typeParameters->getDimSize(0) != 32) {
|
||||
error(loc, "expected 8 or 32 bits for first type parameter", identifier.c_str(), "");
|
||||
publicType.typeParameters->arraySizes->getDimSize(0) != 8 &&
|
||||
publicType.typeParameters->arraySizes->getDimSize(0) != 16 &&
|
||||
publicType.typeParameters->arraySizes->getDimSize(0) != 32) {
|
||||
error(loc, "expected 8, 16, or 32 bits for first type parameter", identifier.c_str(), "");
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
if (publicType.typeParameters && publicType.typeParameters->getNumDims() != 0) {
|
||||
if (publicType.typeParameters && publicType.typeParameters->arraySizes->getNumDims() != 0) {
|
||||
error(loc, "unexpected type parameters", identifier.c_str(), "");
|
||||
}
|
||||
}
|
||||
|
|
@ -8336,14 +8404,18 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
case EOpConstructCooperativeMatrix:
|
||||
case EOpConstructCooperativeMatrixNV:
|
||||
case EOpConstructCooperativeMatrixKHR:
|
||||
if (node->getType() == type) {
|
||||
return node;
|
||||
}
|
||||
if (!node->getType().isCoopMat()) {
|
||||
if (type.getBasicType() != node->getType().getBasicType()) {
|
||||
node = intermediate.addConversion(type.getBasicType(), node);
|
||||
if (node == nullptr)
|
||||
return nullptr;
|
||||
}
|
||||
node = intermediate.setAggregateOperator(node, EOpConstructCooperativeMatrix, type, node->getLoc());
|
||||
node = intermediate.setAggregateOperator(node, op, type, node->getLoc());
|
||||
} else {
|
||||
TOperator op = EOpNull;
|
||||
switch (type.getBasicType()) {
|
||||
|
|
@ -8356,6 +8428,8 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
|
|||
case EbtFloat16: op = EOpConvFloat16ToInt; break;
|
||||
case EbtUint8: op = EOpConvUint8ToInt; break;
|
||||
case EbtInt8: op = EOpConvInt8ToInt; break;
|
||||
case EbtUint16: op = EOpConvUint16ToInt; break;
|
||||
case EbtInt16: op = EOpConvInt16ToInt; break;
|
||||
case EbtUint: op = EOpConvUintToInt; break;
|
||||
default: assert(0);
|
||||
}
|
||||
|
|
@ -8366,8 +8440,33 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
|
|||
case EbtFloat16: op = EOpConvFloat16ToUint; break;
|
||||
case EbtUint8: op = EOpConvUint8ToUint; break;
|
||||
case EbtInt8: op = EOpConvInt8ToUint; break;
|
||||
case EbtUint16: op = EOpConvUint16ToUint; break;
|
||||
case EbtInt16: op = EOpConvInt16ToUint; break;
|
||||
case EbtInt: op = EOpConvIntToUint; break;
|
||||
case EbtUint: op = EOpConvUintToInt8; break;
|
||||
default: assert(0);
|
||||
}
|
||||
break;
|
||||
case EbtInt16:
|
||||
switch (node->getType().getBasicType()) {
|
||||
case EbtFloat: op = EOpConvFloatToInt16; break;
|
||||
case EbtFloat16: op = EOpConvFloat16ToInt16; break;
|
||||
case EbtUint8: op = EOpConvUint8ToInt16; break;
|
||||
case EbtInt8: op = EOpConvInt8ToInt16; break;
|
||||
case EbtUint16: op = EOpConvUint16ToInt16; break;
|
||||
case EbtInt: op = EOpConvIntToInt16; break;
|
||||
case EbtUint: op = EOpConvUintToInt16; break;
|
||||
default: assert(0);
|
||||
}
|
||||
break;
|
||||
case EbtUint16:
|
||||
switch (node->getType().getBasicType()) {
|
||||
case EbtFloat: op = EOpConvFloatToUint16; break;
|
||||
case EbtFloat16: op = EOpConvFloat16ToUint16; break;
|
||||
case EbtUint8: op = EOpConvUint8ToUint16; break;
|
||||
case EbtInt8: op = EOpConvInt8ToUint16; break;
|
||||
case EbtInt16: op = EOpConvInt16ToUint16; break;
|
||||
case EbtInt: op = EOpConvIntToUint16; break;
|
||||
case EbtUint: op = EOpConvUintToUint16; break;
|
||||
default: assert(0);
|
||||
}
|
||||
break;
|
||||
|
|
@ -8376,6 +8475,8 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
|
|||
case EbtFloat: op = EOpConvFloatToInt8; break;
|
||||
case EbtFloat16: op = EOpConvFloat16ToInt8; break;
|
||||
case EbtUint8: op = EOpConvUint8ToInt8; break;
|
||||
case EbtInt16: op = EOpConvInt16ToInt8; break;
|
||||
case EbtUint16: op = EOpConvUint16ToInt8; break;
|
||||
case EbtInt: op = EOpConvIntToInt8; break;
|
||||
case EbtUint: op = EOpConvUintToInt8; break;
|
||||
default: assert(0);
|
||||
|
|
@ -8386,6 +8487,8 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
|
|||
case EbtFloat: op = EOpConvFloatToUint8; break;
|
||||
case EbtFloat16: op = EOpConvFloat16ToUint8; break;
|
||||
case EbtInt8: op = EOpConvInt8ToUint8; break;
|
||||
case EbtInt16: op = EOpConvInt16ToUint8; break;
|
||||
case EbtUint16: op = EOpConvUint16ToUint8; break;
|
||||
case EbtInt: op = EOpConvIntToUint8; break;
|
||||
case EbtUint: op = EOpConvUintToUint8; break;
|
||||
default: assert(0);
|
||||
|
|
@ -8396,6 +8499,8 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
|
|||
case EbtFloat16: op = EOpConvFloat16ToFloat; break;
|
||||
case EbtInt8: op = EOpConvInt8ToFloat; break;
|
||||
case EbtUint8: op = EOpConvUint8ToFloat; break;
|
||||
case EbtInt16: op = EOpConvInt16ToFloat; break;
|
||||
case EbtUint16: op = EOpConvUint16ToFloat; break;
|
||||
case EbtInt: op = EOpConvIntToFloat; break;
|
||||
case EbtUint: op = EOpConvUintToFloat; break;
|
||||
default: assert(0);
|
||||
|
|
@ -8406,6 +8511,8 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
|
|||
case EbtFloat: op = EOpConvFloatToFloat16; break;
|
||||
case EbtInt8: op = EOpConvInt8ToFloat16; break;
|
||||
case EbtUint8: op = EOpConvUint8ToFloat16; break;
|
||||
case EbtInt16: op = EOpConvInt16ToFloat16; break;
|
||||
case EbtUint16: op = EOpConvUint16ToFloat16; break;
|
||||
case EbtInt: op = EOpConvIntToFloat16; break;
|
||||
case EbtUint: op = EOpConvUintToFloat16; break;
|
||||
default: assert(0);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue