GL_NV_integer_cooperative_matrix support
This commit is contained in:
parent
a3bc04b278
commit
387657e4cf
15 changed files with 3580 additions and 2797 deletions
|
|
@ -6192,6 +6192,8 @@ const TFunction* TParseContext::findFunction400(const TSourceLoc& loc, const TFu
|
|||
}
|
||||
if (from.isArray() || to.isArray() || ! from.sameElementShape(to))
|
||||
return false;
|
||||
if (from.isCoopMat() && to.isCoopMat())
|
||||
return from.sameCoopMatBaseType(to);
|
||||
return intermediate.canImplicitlyPromote(from.getBasicType(), to.getBasicType());
|
||||
};
|
||||
|
||||
|
|
@ -6266,6 +6268,8 @@ const TFunction* TParseContext::findFunctionExplicitTypes(const TSourceLoc& loc,
|
|||
}
|
||||
if (from.isArray() || to.isArray() || ! from.sameElementShape(to))
|
||||
return false;
|
||||
if (from.isCoopMat() && to.isCoopMat())
|
||||
return from.sameCoopMatBaseType(to);
|
||||
return intermediate.canImplicitlyPromote(from.getBasicType(), to.getBasicType());
|
||||
};
|
||||
|
||||
|
|
@ -6365,12 +6369,20 @@ TIntermNode* TParseContext::declareVariable(const TSourceLoc& loc, TString& iden
|
|||
if (!publicType.typeParameters || publicType.typeParameters->getNumDims() != 4) {
|
||||
error(loc, "expected four type parameters", identifier.c_str(), "");
|
||||
}
|
||||
if (publicType.typeParameters &&
|
||||
publicType.typeParameters->getDimSize(0) != 16 &&
|
||||
publicType.typeParameters->getDimSize(0) != 32 &&
|
||||
publicType.typeParameters->getDimSize(0) != 64) {
|
||||
error(loc, "expected 16, 32, or 64 bits for first type parameter", identifier.c_str(), "");
|
||||
if (publicType.typeParameters) {
|
||||
if (isTypeFloat(publicType.basicType) &&
|
||||
publicType.typeParameters->getDimSize(0) != 16 &&
|
||||
publicType.typeParameters->getDimSize(0) != 32 &&
|
||||
publicType.typeParameters->getDimSize(0) != 64) {
|
||||
error(loc, "expected 16, 32, or 64 bits for first type parameter", identifier.c_str(), "");
|
||||
}
|
||||
if (isTypeInt(publicType.basicType) &&
|
||||
publicType.typeParameters->getDimSize(0) != 8 &&
|
||||
publicType.typeParameters->getDimSize(0) != 32) {
|
||||
error(loc, "expected 8 or 32 bits for first type parameter", identifier.c_str(), "");
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
if (publicType.typeParameters && publicType.typeParameters->getNumDims() != 0) {
|
||||
error(loc, "unexpected type parameters", identifier.c_str(), "");
|
||||
|
|
@ -7065,19 +7077,90 @@ TIntermTyped* TParseContext::constructBuiltIn(const TType& type, TOperator op, T
|
|||
}
|
||||
node = intermediate.setAggregateOperator(node, EOpConstructCooperativeMatrix, type, node->getLoc());
|
||||
} else {
|
||||
TOperator op;
|
||||
switch (type.getBasicType()) {
|
||||
default:
|
||||
assert(0);
|
||||
break;
|
||||
case EbtInt:
|
||||
{
|
||||
switch (node->getType().getBasicType()) {
|
||||
case EbtFloat: op = EOpConvFloatToInt; break;
|
||||
case EbtFloat16: op = EOpConvFloat16ToInt; break;
|
||||
case EbtUint8: op = EOpConvUint8ToInt; break;
|
||||
case EbtInt8: op = EOpConvInt8ToInt; break;
|
||||
case EbtUint: op = EOpConvUintToInt; break;
|
||||
default: assert(0);
|
||||
}
|
||||
|
||||
}
|
||||
break;
|
||||
case EbtUint:
|
||||
{
|
||||
switch (node->getType().getBasicType()) {
|
||||
case EbtFloat: op = EOpConvFloatToUint; break;
|
||||
case EbtFloat16: op = EOpConvFloat16ToUint; break;
|
||||
case EbtUint8: op = EOpConvUint8ToUint; break;
|
||||
case EbtInt8: op = EOpConvInt8ToUint; break;
|
||||
case EbtInt: op = EOpConvIntToUint; break;
|
||||
case EbtUint: op = EOpConvUintToInt8; break;
|
||||
default: assert(0);
|
||||
}
|
||||
|
||||
}
|
||||
break;
|
||||
|
||||
case EbtInt8:
|
||||
{
|
||||
switch (node->getType().getBasicType()) {
|
||||
case EbtFloat: op = EOpConvFloatToInt8; break;
|
||||
case EbtFloat16: op = EOpConvFloat16ToInt8; break;
|
||||
case EbtUint8: op = EOpConvUint8ToInt8; break;
|
||||
case EbtInt: op = EOpConvIntToInt8; break;
|
||||
case EbtUint: op = EOpConvUintToInt8; break;
|
||||
default: assert(0);
|
||||
}
|
||||
|
||||
}
|
||||
break;
|
||||
case EbtUint8: {
|
||||
switch (node->getType().getBasicType()) {
|
||||
case EbtFloat: op = EOpConvFloatToUint8; break;
|
||||
case EbtFloat16: op = EOpConvFloat16ToUint8; break;
|
||||
case EbtInt8: op = EOpConvInt8ToUint8; break;
|
||||
case EbtInt: op = EOpConvIntToUint8; break;
|
||||
case EbtUint: op = EOpConvUintToUint8; break;
|
||||
default: assert(0);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case EbtFloat:
|
||||
assert(node->getType().getBasicType() == EbtFloat16);
|
||||
node = intermediate.addUnaryNode(EOpConvFloat16ToFloat, node, node->getLoc(), type);
|
||||
{
|
||||
switch (node->getType().getBasicType()) {
|
||||
case EbtFloat16: op = EOpConvFloat16ToFloat; break;
|
||||
case EbtInt8: op = EOpConvInt8ToFloat; break;
|
||||
case EbtUint8: op = EOpConvUint8ToFloat; break;
|
||||
case EbtInt: op = EOpConvIntToFloat; break;
|
||||
case EbtUint: op = EOpConvUintToFloat; break;
|
||||
default: assert(0);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case EbtFloat16:
|
||||
assert(node->getType().getBasicType() == EbtFloat);
|
||||
node = intermediate.addUnaryNode(EOpConvFloatToFloat16, node, node->getLoc(), type);
|
||||
{
|
||||
switch (node->getType().getBasicType()) {
|
||||
case EbtFloat: op = EOpConvFloatToFloat16; break;
|
||||
case EbtInt8: op = EOpConvInt8ToFloat16; break;
|
||||
case EbtUint8: op = EOpConvUint8ToFloat16; break;
|
||||
case EbtInt: op = EOpConvIntToFloat16; break;
|
||||
case EbtUint: op = EOpConvUintToFloat16; break;
|
||||
default: assert(0);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
node = intermediate.addUnaryNode(op, node, node->getLoc(), type);
|
||||
// If it's a (non-specialization) constant, it must be folded.
|
||||
if (node->getAsUnaryNode()->getOperand()->getAsConstantUnion())
|
||||
return node->getAsUnaryNode()->getOperand()->getAsConstantUnion()->fold(op, node->getType());
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue