GL_NV_integer_cooperative_matrix support

This commit is contained in:
Jeff Bolz 2019-08-22 20:28:00 -05:00
parent a3bc04b278
commit 387657e4cf
15 changed files with 3580 additions and 2797 deletions

View file

@ -6192,6 +6192,8 @@ const TFunction* TParseContext::findFunction400(const TSourceLoc& loc, const TFu
}
if (from.isArray() || to.isArray() || ! from.sameElementShape(to))
return false;
if (from.isCoopMat() && to.isCoopMat())
return from.sameCoopMatBaseType(to);
return intermediate.canImplicitlyPromote(from.getBasicType(), to.getBasicType());
};
@ -6266,6 +6268,8 @@ const TFunction* TParseContext::findFunctionExplicitTypes(const TSourceLoc& loc,
}
if (from.isArray() || to.isArray() || ! from.sameElementShape(to))
return false;
if (from.isCoopMat() && to.isCoopMat())
return from.sameCoopMatBaseType(to);
return intermediate.canImplicitlyPromote(from.getBasicType(), to.getBasicType());
};
@ -6365,12 +6369,20 @@ TIntermNode* TParseContext::declareVariable(const TSourceLoc& loc, TString& iden
if (!publicType.typeParameters || publicType.typeParameters->getNumDims() != 4) {
error(loc, "expected four type parameters", identifier.c_str(), "");
}
if (publicType.typeParameters &&
publicType.typeParameters->getDimSize(0) != 16 &&
publicType.typeParameters->getDimSize(0) != 32 &&
publicType.typeParameters->getDimSize(0) != 64) {
error(loc, "expected 16, 32, or 64 bits for first type parameter", identifier.c_str(), "");
if (publicType.typeParameters) {
if (isTypeFloat(publicType.basicType) &&
publicType.typeParameters->getDimSize(0) != 16 &&
publicType.typeParameters->getDimSize(0) != 32 &&
publicType.typeParameters->getDimSize(0) != 64) {
error(loc, "expected 16, 32, or 64 bits for first type parameter", identifier.c_str(), "");
}
if (isTypeInt(publicType.basicType) &&
publicType.typeParameters->getDimSize(0) != 8 &&
publicType.typeParameters->getDimSize(0) != 32) {
error(loc, "expected 8 or 32 bits for first type parameter", identifier.c_str(), "");
}
}
} else {
if (publicType.typeParameters && publicType.typeParameters->getNumDims() != 0) {
error(loc, "unexpected type parameters", identifier.c_str(), "");
@ -7065,19 +7077,90 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
}
node = intermediate.setAggregateOperator(node, EOpConstructCooperativeMatrix, type, node->getLoc());
} else {
TOperator op;
switch (type.getBasicType()) {
default:
assert(0);
break;
case EbtInt:
{
switch (node->getType().getBasicType()) {
case EbtFloat: op = EOpConvFloatToInt; break;
case EbtFloat16: op = EOpConvFloat16ToInt; break;
case EbtUint8: op = EOpConvUint8ToInt; break;
case EbtInt8: op = EOpConvInt8ToInt; break;
case EbtUint: op = EOpConvUintToInt; break;
default: assert(0);
}
}
break;
case EbtUint:
{
switch (node->getType().getBasicType()) {
case EbtFloat: op = EOpConvFloatToUint; break;
case EbtFloat16: op = EOpConvFloat16ToUint; break;
case EbtUint8: op = EOpConvUint8ToUint; break;
case EbtInt8: op = EOpConvInt8ToUint; break;
case EbtInt: op = EOpConvIntToUint; break;
case EbtUint: op = EOpConvUintToInt8; break;
default: assert(0);
}
}
break;
case EbtInt8:
{
switch (node->getType().getBasicType()) {
case EbtFloat: op = EOpConvFloatToInt8; break;
case EbtFloat16: op = EOpConvFloat16ToInt8; break;
case EbtUint8: op = EOpConvUint8ToInt8; break;
case EbtInt: op = EOpConvIntToInt8; break;
case EbtUint: op = EOpConvUintToInt8; break;
default: assert(0);
}
}
break;
case EbtUint8: {
switch (node->getType().getBasicType()) {
case EbtFloat: op = EOpConvFloatToUint8; break;
case EbtFloat16: op = EOpConvFloat16ToUint8; break;
case EbtInt8: op = EOpConvInt8ToUint8; break;
case EbtInt: op = EOpConvIntToUint8; break;
case EbtUint: op = EOpConvUintToUint8; break;
default: assert(0);
}
}
break;
case EbtFloat:
assert(node->getType().getBasicType() == EbtFloat16);
node = intermediate.addUnaryNode(EOpConvFloat16ToFloat, node, node->getLoc(), type);
{
switch (node->getType().getBasicType()) {
case EbtFloat16: op = EOpConvFloat16ToFloat; break;
case EbtInt8: op = EOpConvInt8ToFloat; break;
case EbtUint8: op = EOpConvUint8ToFloat; break;
case EbtInt: op = EOpConvIntToFloat; break;
case EbtUint: op = EOpConvUintToFloat; break;
default: assert(0);
}
}
break;
case EbtFloat16:
assert(node->getType().getBasicType() == EbtFloat);
node = intermediate.addUnaryNode(EOpConvFloatToFloat16, node, node->getLoc(), type);
{
switch (node->getType().getBasicType()) {
case EbtFloat: op = EOpConvFloatToFloat16; break;
case EbtInt8: op = EOpConvInt8ToFloat16; break;
case EbtUint8: op = EOpConvUint8ToFloat16; break;
case EbtInt: op = EOpConvIntToFloat16; break;
case EbtUint: op = EOpConvUintToFloat16; break;
default: assert(0);
}
}
break;
}
node = intermediate.addUnaryNode(op, node, node->getLoc(), type);
// If it's a (non-specialization) constant, it must be folded.
if (node->getAsUnaryNode()->getOperand()->getAsConstantUnion())
return node->getAsUnaryNode()->getOperand()->getAsConstantUnion()->fold(op, node->getType());