Fix issues of the interaction between cooperative_matrix and spirv_intrinsics

coopmat<> type definition allows type parameters. To make it accept
types defined by spirv_type directive, we add spirvType info to the type
parameters. This change is to support this case. And a test is added to
show the missing usage.
This commit is contained in:
Rex Xu 2024-03-21 23:09:00 +08:00 committed by GitHub
parent 10ee92feb0
commit 022aea431c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 1136 additions and 1022 deletions

View file

@ -1435,13 +1435,25 @@ class TTypeParameters {
public:
POOL_ALLOCATOR_NEW_DELETE(GetThreadPoolAllocator())
TTypeParameters() : basicType(EbtVoid), arraySizes(nullptr) {}
TTypeParameters() : basicType(EbtVoid), arraySizes(nullptr), spirvType(nullptr) {}
TBasicType basicType;
TArraySizes *arraySizes;
TSpirvType *spirvType;
bool operator==(const TTypeParameters& rhs) const { return basicType == rhs.basicType && *arraySizes == *rhs.arraySizes; }
bool operator!=(const TTypeParameters& rhs) const { return basicType != rhs.basicType || *arraySizes != *rhs.arraySizes; }
bool operator==(const TTypeParameters& rhs) const
{
bool same = basicType == rhs.basicType && *arraySizes == *rhs.arraySizes;
if (same && basicType == EbtSpirvType) {
assert(spirvType && rhs.spirvType);
return *spirvType == *rhs.spirvType;
}
return same;
}
bool operator!=(const TTypeParameters& rhs) const
{
return !(*this == rhs);
}
};
//
@ -1618,6 +1630,10 @@ public:
}
if (p.isCoopmatKHR() && p.typeParameters && p.typeParameters->arraySizes->getNumDims() > 0) {
basicType = p.typeParameters->basicType;
if (isSpirvType()) {
assert(p.typeParameters->spirvType);
spirvType = p.typeParameters->spirvType;
}
if (p.typeParameters->arraySizes->getNumDims() == 4) {
const int dimSize = p.typeParameters->arraySizes->getDimSize(3);
@ -2719,7 +2735,8 @@ public:
if (isCoopMatKHR() && right.isCoopMatKHR()) {
return ((getBasicType() == right.getBasicType()) || (getBasicType() == EbtCoopmat) ||
(right.getBasicType() == EbtCoopmat)) &&
typeParameters == nullptr && right.typeParameters != nullptr;
((typeParameters == nullptr && right.typeParameters != nullptr) ||
(typeParameters != nullptr && right.typeParameters == nullptr));
}
return false;
}
@ -2825,6 +2842,7 @@ protected:
typeParameters = new TTypeParameters;
typeParameters->arraySizes = new TArraySizes;
*typeParameters->arraySizes = *copyOf.typeParameters->arraySizes;
*typeParameters->spirvType = *copyOf.typeParameters->spirvType;
typeParameters->basicType = copyOf.basicType;
}