Implement support for GL_KHR_cooperative_matrix extension

This commit is contained in:
Boris Zanin 2023-03-16 13:01:01 +01:00 committed by arcady-lunarg
parent 91a97b4c69
commit 808c7ed17c
40 changed files with 8227 additions and 5733 deletions

View file

@ -66,6 +66,7 @@ enum TBasicType {
EbtReference,
EbtRayQuery,
EbtHitObjectNV,
EbtCoopmat,
#ifndef GLSLANG_WEB
// SPIR-V type defined by spirv_type
EbtSpirvType,

View file

@ -1541,6 +1541,19 @@ struct TShaderQualifiers {
}
};
class TTypeParameters {
public:
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
TTypeParameters() : basicType(EbtVoid), arraySizes(nullptr) {}
TBasicType basicType;
TArraySizes *arraySizes;
bool operator==(const TTypeParameters& rhs) const { return basicType == rhs.basicType && *arraySizes == *rhs.arraySizes; }
bool operator!=(const TTypeParameters& rhs) const { return basicType != rhs.basicType || *arraySizes != *rhs.arraySizes; }
};
//
// TPublicType is just temporarily used while parsing and not quite the same
// information kept per node in TType. Due to the bison stack, it can't have
@ -1555,14 +1568,15 @@ public:
TSampler sampler;
TQualifier qualifier;
TShaderQualifiers shaderQualifiers;
int vectorSize : 4;
int matrixCols : 4;
int matrixRows : 4;
bool coopmat : 1;
int vectorSize : 4;
int matrixCols : 4;
int matrixRows : 4;
bool coopmatNV : 1;
bool coopmatKHR : 1;
TArraySizes* arraySizes;
const TType* userDef;
TSourceLoc loc;
TArraySizes* typeParameters;
TTypeParameters* typeParameters;
#ifndef GLSLANG_WEB
// SPIR-V type defined by spirv_type directive
TSpirvType* spirvType;
@ -1570,8 +1584,12 @@ public:
#ifdef GLSLANG_WEB
bool isCoopmat() const { return false; }
bool isCoopmatNV() const { return false; }
bool isCoopmatKHR() const { return false; }
#else
bool isCoopmat() const { return coopmat; }
bool isCoopmat() const { return coopmatNV || coopmatKHR; }
bool isCoopmatNV() const { return coopmatNV; }
bool isCoopmatKHR() const { return coopmatKHR; }
#endif
void initType(const TSourceLoc& l)
@ -1584,7 +1602,8 @@ public:
userDef = nullptr;
loc = l;
typeParameters = nullptr;
coopmat = false;
coopmatNV = false;
coopmatKHR = false;
#ifndef GLSLANG_WEB
spirvType = nullptr;
#endif
@ -1645,7 +1664,7 @@ public:
// for "empty" type (no args) or simple scalar/vector/matrix
explicit TType(TBasicType t = EbtVoid, TStorageQualifier q = EvqTemporary, int vs = 1, int mc = 0, int mr = 0,
bool isVector = false) :
basicType(t), vectorSize(vs), matrixCols(mc), matrixRows(mr), vector1(isVector && vs == 1), coopmat(false),
basicType(t), vectorSize(vs), matrixCols(mc), matrixRows(mr), vector1(isVector && vs == 1), coopmatNV(false), coopmatKHR(false), coopmatKHRuse(-1),
arraySizes(nullptr), structure(nullptr), fieldName(nullptr), typeName(nullptr), typeParameters(nullptr)
#ifndef GLSLANG_WEB
, spirvType(nullptr)
@ -1659,7 +1678,7 @@ public:
// for explicit precision qualifier
TType(TBasicType t, TStorageQualifier q, TPrecisionQualifier p, int vs = 1, int mc = 0, int mr = 0,
bool isVector = false) :
basicType(t), vectorSize(vs), matrixCols(mc), matrixRows(mr), vector1(isVector && vs == 1), coopmat(false),
basicType(t), vectorSize(vs), matrixCols(mc), matrixRows(mr), vector1(isVector && vs == 1), coopmatNV(false), coopmatKHR(false), coopmatKHRuse(-1),
arraySizes(nullptr), structure(nullptr), fieldName(nullptr), typeName(nullptr), typeParameters(nullptr)
#ifndef GLSLANG_WEB
, spirvType(nullptr)
@ -1675,7 +1694,7 @@ public:
// for turning a TPublicType into a TType, using a shallow copy
explicit TType(const TPublicType& p) :
basicType(p.basicType),
vectorSize(p.vectorSize), matrixCols(p.matrixCols), matrixRows(p.matrixRows), vector1(false), coopmat(p.coopmat),
vectorSize(p.vectorSize), matrixCols(p.matrixCols), matrixRows(p.matrixRows), vector1(false), coopmatNV(p.coopmatNV), coopmatKHR(p.coopmatKHR), coopmatKHRuse(-1),
arraySizes(p.arraySizes), structure(nullptr), fieldName(nullptr), typeName(nullptr), typeParameters(p.typeParameters)
#ifndef GLSLANG_WEB
, spirvType(p.spirvType)
@ -1695,23 +1714,37 @@ public:
}
typeName = NewPoolTString(p.userDef->getTypeName().c_str());
}
if (p.isCoopmat() && p.typeParameters && p.typeParameters->getNumDims() > 0) {
int numBits = p.typeParameters->getDimSize(0);
if (p.isCoopmatNV() && p.typeParameters && p.typeParameters->arraySizes->getNumDims() > 0) {
int numBits = p.typeParameters->arraySizes->getDimSize(0);
if (p.basicType == EbtFloat && numBits == 16) {
basicType = EbtFloat16;
qualifier.precision = EpqNone;
} else if (p.basicType == EbtUint && numBits == 8) {
basicType = EbtUint8;
qualifier.precision = EpqNone;
} else if (p.basicType == EbtUint && numBits == 16) {
basicType = EbtUint16;
qualifier.precision = EpqNone;
} else if (p.basicType == EbtInt && numBits == 8) {
basicType = EbtInt8;
qualifier.precision = EpqNone;
} else if (p.basicType == EbtInt && numBits == 16) {
basicType = EbtInt16;
qualifier.precision = EpqNone;
}
}
if (p.isCoopmatKHR() && p.typeParameters && p.typeParameters->arraySizes->getNumDims() > 0) {
basicType = p.typeParameters->basicType;
if (p.typeParameters->arraySizes->getNumDims() == 4) {
coopmatKHRuse = p.typeParameters->arraySizes->getDimSize(3);
p.typeParameters->arraySizes->removeLastSize();
}
}
}
// for construction of sampler types
TType(const TSampler& sampler, TStorageQualifier q = EvqUniform, TArraySizes* as = nullptr) :
basicType(EbtSampler), vectorSize(1), matrixCols(0), matrixRows(0), vector1(false), coopmat(false),
basicType(EbtSampler), vectorSize(1), matrixCols(0), matrixRows(0), vector1(false), coopmatNV(false), coopmatKHR(false), coopmatKHRuse(-1),
arraySizes(as), structure(nullptr), fieldName(nullptr), typeName(nullptr),
sampler(sampler), typeParameters(nullptr)
#ifndef GLSLANG_WEB
@ -1758,14 +1791,16 @@ public:
vectorSize = 1;
vector1 = false;
} else if (isCoopMat()) {
coopmat = false;
coopmatNV = false;
coopmatKHR = false;
coopmatKHRuse = -1;
typeParameters = nullptr;
}
}
}
// for making structures, ...
TType(TTypeList* userDef, const TString& n) :
basicType(EbtStruct), vectorSize(1), matrixCols(0), matrixRows(0), vector1(false), coopmat(false),
basicType(EbtStruct), vectorSize(1), matrixCols(0), matrixRows(0), vector1(false), coopmatNV(false), coopmatKHR(false), coopmatKHRuse(-1),
arraySizes(nullptr), structure(userDef), fieldName(nullptr), typeParameters(nullptr)
#ifndef GLSLANG_WEB
, spirvType(nullptr)
@ -1777,7 +1812,7 @@ public:
}
// For interface blocks
TType(TTypeList* userDef, const TString& n, const TQualifier& q) :
basicType(EbtBlock), vectorSize(1), matrixCols(0), matrixRows(0), vector1(false), coopmat(false),
basicType(EbtBlock), vectorSize(1), matrixCols(0), matrixRows(0), vector1(false), coopmatNV(false), coopmatKHR(false), coopmatKHRuse(-1),
qualifier(q), arraySizes(nullptr), structure(userDef), fieldName(nullptr), typeParameters(nullptr)
#ifndef GLSLANG_WEB
, spirvType(nullptr)
@ -1788,7 +1823,7 @@ public:
}
// for block reference (first parameter must be EbtReference)
explicit TType(TBasicType t, const TType &p, const TString& n) :
basicType(t), vectorSize(1), matrixCols(0), matrixRows(0), vector1(false), coopmat(false),
basicType(t), vectorSize(1), matrixCols(0), matrixRows(0), vector1(false), coopmatNV(false), coopmatKHR(false), coopmatKHRuse(-1),
arraySizes(nullptr), structure(nullptr), fieldName(nullptr), typeName(nullptr), typeParameters(nullptr)
#ifndef GLSLANG_WEB
, spirvType(nullptr)
@ -1827,7 +1862,9 @@ public:
#ifndef GLSLANG_WEB
spirvType = copyOf.spirvType;
#endif
coopmat = copyOf.isCoopMat();
coopmatNV = copyOf.isCoopMatNV();
coopmatKHR = copyOf.isCoopMatKHR();
coopmatKHRuse = copyOf.coopmatKHRuse;
}
// Make complete copy of the whole type graph rooted at 'copyOf'.
@ -1912,8 +1949,8 @@ public:
virtual const TArraySizes* getArraySizes() const { return arraySizes; }
virtual TArraySizes* getArraySizes() { return arraySizes; }
virtual TType* getReferentType() const { return referentType; }
virtual const TArraySizes* getTypeParameters() const { return typeParameters; }
virtual TArraySizes* getTypeParameters() { return typeParameters; }
virtual const TTypeParameters* getTypeParameters() const { return typeParameters; }
virtual TTypeParameters* getTypeParameters() { return typeParameters; }
virtual bool isScalar() const { return ! isVector() && ! isMatrix() && ! isStruct() && ! isArray(); }
virtual bool isScalarOrVec1() const { return isScalar() || vector1; }
@ -1968,14 +2005,19 @@ public:
#ifdef GLSLANG_WEB
bool isAtomic() const { return false; }
bool isCoopMat() const { return false; }
bool isCoopMatNV() const { return false; }
bool isCoopMatKHR() const { return false; }
bool isReference() const { return false; }
bool isSpirvType() const { return false; }
#else
bool isAtomic() const { return basicType == EbtAtomicUint; }
bool isCoopMat() const { return coopmat; }
bool isCoopMat() const { return coopmatNV || coopmatKHR; }
bool isCoopMatNV() const { return coopmatNV; }
bool isCoopMatKHR() const { return coopmatKHR; }
bool isReference() const { return getBasicType() == EbtReference; }
bool isSpirvType() const { return getBasicType() == EbtSpirvType; }
#endif
int getCoopMatKHRuse() const { return coopmatKHRuse; }
// return true if this type contains any subtype which satisfies the given predicate.
template <typename P>
@ -2092,7 +2134,7 @@ public:
}
bool containsCoopMat() const
{
return contains([](const TType* t) { return t->coopmat; } );
return contains([](const TType* t) { return t->coopmatNV || t->coopmatKHR; } );
}
bool containsReference() const
{
@ -2174,43 +2216,12 @@ public:
}
}
void updateTypeParameters(const TType& type)
{
// For when we may already be sharing existing array descriptors,
// keeping the pointers the same, just updating the contents.
assert(typeParameters != nullptr);
assert(type.typeParameters != nullptr);
*typeParameters = *type.typeParameters;
}
void copyTypeParameters(const TArraySizes& s)
void copyTypeParameters(const TTypeParameters& s)
{
// For setting a fresh new set of type parameters, not yet worrying about sharing.
typeParameters = new TArraySizes;
typeParameters = new TTypeParameters;
*typeParameters = s;
}
void transferTypeParameters(TArraySizes* s)
{
// For setting an already allocated set of sizes that this type can use
// (no copy made).
typeParameters = s;
}
void clearTypeParameters()
{
typeParameters = nullptr;
}
// Add inner array sizes, to any existing sizes, via copy; the
// sizes passed in can still be reused for other purposes.
void copyTypeParametersInnerSizes(const TArraySizes* s)
{
if (s != nullptr) {
if (typeParameters == nullptr)
copyTypeParameters(*s);
else
typeParameters->addInnerSizes(*s);
}
}
const char* getBasicString() const
{
@ -2243,6 +2254,7 @@ public:
case EbtReference: return "reference";
case EbtString: return "string";
case EbtSpirvType: return "spirv_type";
case EbtCoopmat: return "coopmat";
#endif
default: return "unknown type";
}
@ -2553,12 +2565,21 @@ public:
}
}
if (isParameterized()) {
if (isCoopMatKHR()) {
appendStr(" ");
appendStr("coopmat");
}
appendStr("<");
for (int i = 0; i < (int)typeParameters->getNumDims(); ++i) {
appendInt(typeParameters->getDimSize(i));
if (i != (int)typeParameters->getNumDims() - 1)
for (int i = 0; i < (int)typeParameters->arraySizes->getNumDims(); ++i) {
appendInt(typeParameters->arraySizes->getDimSize(i));
if (i != (int)typeParameters->arraySizes->getNumDims() - 1)
appendStr(", ");
}
if (coopmatKHRuse != -1) {
appendStr(", ");
appendInt(coopmatKHRuse);
}
appendStr(">");
}
if (getPrecision && qualifier.precision != EpqNone) {
@ -2835,7 +2856,8 @@ public:
matrixCols == right.matrixCols &&
matrixRows == right.matrixRows &&
vector1 == right.vector1 &&
isCoopMat() == right.isCoopMat() &&
isCoopMatNV() == right.isCoopMatNV() &&
isCoopMatKHR() == right.isCoopMatKHR() &&
sameStructType(right, lpidx, rpidx) &&
sameReferenceType(right);
}
@ -2844,29 +2866,70 @@ public:
// an OK function parameter
bool coopMatParameterOK(const TType& right) const
{
return isCoopMat() && right.isCoopMat() && (getBasicType() == right.getBasicType()) &&
typeParameters == nullptr && right.typeParameters != nullptr;
if (isCoopMatNV()) {
return right.isCoopMatNV() && (getBasicType() == right.getBasicType()) && typeParameters == nullptr &&
right.typeParameters != nullptr;
}
if (isCoopMatKHR() && right.isCoopMatKHR()) {
return ((getBasicType() == right.getBasicType()) || (getBasicType() == EbtCoopmat) ||
(right.getBasicType() == EbtCoopmat)) &&
typeParameters == nullptr && right.typeParameters != nullptr;
}
return false;
}
bool sameCoopMatBaseType(const TType &right) const {
bool rv = coopmat && right.coopmat;
if (getBasicType() == EbtFloat || getBasicType() == EbtFloat16)
rv = right.getBasicType() == EbtFloat || right.getBasicType() == EbtFloat16;
else if (getBasicType() == EbtUint || getBasicType() == EbtUint8)
rv = right.getBasicType() == EbtUint || right.getBasicType() == EbtUint8;
else if (getBasicType() == EbtInt || getBasicType() == EbtInt8)
rv = right.getBasicType() == EbtInt || right.getBasicType() == EbtInt8;
else
rv = false;
bool rv = false;
if (isCoopMatNV()) {
rv = isCoopMatNV() && right.isCoopMatNV();
if (getBasicType() == EbtFloat || getBasicType() == EbtFloat16)
rv = right.getBasicType() == EbtFloat || right.getBasicType() == EbtFloat16;
else if (getBasicType() == EbtUint || getBasicType() == EbtUint8 || getBasicType() == EbtUint16)
rv = right.getBasicType() == EbtUint || right.getBasicType() == EbtUint8 || right.getBasicType() == EbtUint16;
else if (getBasicType() == EbtInt || getBasicType() == EbtInt8 || getBasicType() == EbtInt16)
rv = right.getBasicType() == EbtInt || right.getBasicType() == EbtInt8 || right.getBasicType() == EbtInt16;
else
rv = false;
} else if (isCoopMatKHR() && right.isCoopMatKHR()) {
if (getBasicType() == EbtFloat || getBasicType() == EbtFloat16)
rv = right.getBasicType() == EbtFloat || right.getBasicType() == EbtFloat16 || right.getBasicType() == EbtCoopmat;
else if (getBasicType() == EbtUint || getBasicType() == EbtUint8 || getBasicType() == EbtUint16)
rv = right.getBasicType() == EbtUint || right.getBasicType() == EbtUint8 || right.getBasicType() == EbtUint16 || right.getBasicType() == EbtCoopmat;
else if (getBasicType() == EbtInt || getBasicType() == EbtInt8 || getBasicType() == EbtInt16)
rv = right.getBasicType() == EbtInt || right.getBasicType() == EbtInt8 || right.getBasicType() == EbtInt16 || right.getBasicType() == EbtCoopmat;
else
rv = false;
}
return rv;
}
bool sameCoopMatUse(const TType &right) const {
return coopmatKHRuse == right.coopmatKHRuse;
}
bool sameCoopMatShapeAndUse(const TType &right) const
{
if (!isCoopMat() || !right.isCoopMat() || isCoopMatKHR() != right.isCoopMatKHR())
return false;
if (coopmatKHRuse != right.coopmatKHRuse)
return false;
// Skip bit width type parameter (first array size) for coopmatNV
int firstArrayDimToCompare = isCoopMatNV() ? 1 : 0;
for (int i = firstArrayDimToCompare; i < typeParameters->arraySizes->getNumDims(); ++i) {
if (typeParameters->arraySizes->getDimSize(i) != right.typeParameters->arraySizes->getDimSize(i))
return false;
}
return true;
}
// See if two types match in all ways (just the actual type, not qualification)
bool operator==(const TType& right) const
{
#ifndef GLSLANG_WEB
return sameElementType(right) && sameArrayness(right) && sameTypeParameters(right) && sameSpirvType(right);
return sameElementType(right) && sameArrayness(right) && sameTypeParameters(right) && sameCoopMatUse(right) && sameSpirvType(right);
#else
return sameElementType(right) && sameArrayness(right) && sameTypeParameters(right);
#endif
@ -2923,8 +2986,10 @@ protected:
}
if (copyOf.typeParameters) {
typeParameters = new TArraySizes;
*typeParameters = *copyOf.typeParameters;
typeParameters = new TTypeParameters;
typeParameters->arraySizes = new TArraySizes;
*typeParameters->arraySizes = *copyOf.typeParameters->arraySizes;
typeParameters->basicType = copyOf.basicType;
}
if (copyOf.isStruct() && copyOf.structure) {
@ -2962,7 +3027,9 @@ protected:
// functionality is added.
// HLSL does have a 1-component vectors, so this will be true to disambiguate
// from a scalar.
bool coopmat : 1;
bool coopmatNV : 1;
bool coopmatKHR : 1;
int coopmatKHRuse : 4; // Accepts one of three values: 0, 1, 2 (gl_MatrixUseA, gl_MatrixUseB, gl_MatrixUseAccumulator)
TQualifier qualifier;
TArraySizes* arraySizes; // nullptr unless an array; can be shared across types
@ -2975,7 +3042,7 @@ protected:
TString *fieldName; // for structure field names
TString *typeName; // for structure type name
TSampler sampler;
TArraySizes* typeParameters;// nullptr unless a parameterized type; can be shared across types
TTypeParameters *typeParameters;// nullptr unless a parameterized type; can be shared across types
#ifndef GLSLANG_WEB
TSpirvType* spirvType; // SPIR-V type defined by spirv_type directive
#endif

View file

@ -147,6 +147,15 @@ struct TSmallArrayVector {
sizes->erase(sizes->begin());
}
void pop_back()
{
assert(sizes != nullptr && sizes->size() > 0);
if (sizes->size() == 1)
dealloc();
else
sizes->resize(sizes->size() - 1);
}
// 'this' should currently not be holding anything, and copyNonFront
// will make it hold a copy of all but the first element of rhs.
// (This would be useful for making a type that is dereferenced by
@ -306,6 +315,7 @@ struct TArraySizes {
bool isDefaultImplicitlySized() const { return implicitlySized && implicitArraySize == 0; }
void setImplicitlySized(bool isImplicitSizing) { implicitlySized = isImplicitSizing; }
void dereference() { sizes.pop_front(); }
void removeLastSize() { sizes.pop_back(); }
void copyDereferenced(const TArraySizes& rhs)
{
assert(sizes.size() == 0);

View file

@ -629,6 +629,9 @@ enum TOperator {
EOpCooperativeMatrixLoad,
EOpCooperativeMatrixStore,
EOpCooperativeMatrixMulAdd,
EOpCooperativeMatrixLoadNV,
EOpCooperativeMatrixStoreNV,
EOpCooperativeMatrixMulAddNV,
EOpBeginInvocationInterlock, // Fragment only
EOpEndInvocationInterlock, // Fragment only
@ -766,7 +769,8 @@ enum TOperator {
EOpConstructTextureSampler,
EOpConstructNonuniform, // expected to be transformed away, not present in final AST
EOpConstructReference,
EOpConstructCooperativeMatrix,
EOpConstructCooperativeMatrixNV,
EOpConstructCooperativeMatrixKHR,
EOpConstructAccStruct,
EOpConstructGuardEnd,