Implement support for GL_KHR_cooperative_matrix extension

This commit is contained in:
Boris Zanin 2023-03-16 13:01:01 +01:00 committed by arcady-lunarg
parent 91a97b4c69
commit 808c7ed17c
40 changed files with 8227 additions and 5733 deletions

View file

@@ -4397,6 +4397,94 @@ void TBuiltIns::initialize(int version, EProfile profile, const SpvVersion& spvV
"icoopmatNV coopMatMulAddNV(icoopmatNV A, icoopmatNV B, icoopmatNV C);\n"
"ucoopmatNV coopMatMulAddNV(ucoopmatNV A, ucoopmatNV B, ucoopmatNV C);\n"
);
// Prototypes for the un-suffixed cooperative-matrix built-ins (the ones tied to
// E_GL_KHR_cooperative_matrix in identifyBuiltIns), collected into one string
// and then appended to commonBuiltins.
//
//  - coopMatLoad(out coopmat m, buf, element, stride, matrixLayout) and
//    coopMatStore(coopmat m, buf, element, stride, matrixLayout): one overload
//    per element type of the `volatile coherent` unsized buffer array. The
//    element types are listed in the same order for both functions:
//      scalars : int8_t..int64_t, uint8_t..uint64_t, float16_t, float, float64_t
//      vec2    : i8vec2..i64vec2, u8vec2..u64vec2, f16vec2, f32vec2, f64vec2
//      vec4    : i8vec4..i64vec4, u8vec4..u64vec4, f16vec4, f32vec4, f64vec4
//  - coopMatMulAdd(A, B, C): a three-operand form plus a form taking an extra
//    `int matrixOperands` argument.
//
// Keep the Load and Store lists in sync when adding or removing an element type.
std::string cooperativeMatrixFuncs =
"void coopMatLoad(out coopmat m, volatile coherent int8_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent int16_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent int32_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent int64_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent uint8_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent uint16_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent uint32_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent uint64_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent float16_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent float[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent float64_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent i8vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent i16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent i32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent i64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent u8vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent u16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent u32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent u64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent f16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent f32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent f64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent i8vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent i16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent i32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent i64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent u8vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent u16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent u32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent u64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent f16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent f32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatLoad(out coopmat m, volatile coherent f64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent int8_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent int16_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent int32_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent int64_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent uint8_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent uint16_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent uint32_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent uint64_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent float16_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent float[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent float64_t[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent i8vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent i16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent i32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent i64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent u8vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent u16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent u32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent u64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent f16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent f32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent f64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent i8vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent i16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent i32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent i64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent u8vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent u16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent u32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent u64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent f16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent f32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"void coopMatStore(coopmat m, volatile coherent f64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
"coopmat coopMatMulAdd(coopmat A, coopmat B, coopmat C);\n"
"coopmat coopMatMulAdd(coopmat A, coopmat B, coopmat C, int matrixOperands);\n";
// Appended through c_str().
// NOTE(review): presumably because commonBuiltins is not a std::string — confirm.
commonBuiltins.append(cooperativeMatrixFuncs.c_str());
// Integer constants declared alongside the cooperative-matrix built-ins:
//   gl_MatrixUseA / gl_MatrixUseB / gl_MatrixUseAccumulator      = 0 / 1 / 2
//   gl_MatrixOperandsSaturatingAccumulation                      = 0x10
//   gl_CooperativeMatrixLayoutRowMajor / ...LayoutColumnMajor    = 0 / 1
// NOTE(review): presumably the layout constants feed the `matrixLayout`
// parameter of coopMatLoad/coopMatStore and the operands constant feeds the
// `matrixOperands` parameter of coopMatMulAdd — confirm against the extension spec.
commonBuiltins.append(
"const int gl_MatrixUseA = 0;\n"
"const int gl_MatrixUseB = 1;\n"
"const int gl_MatrixUseAccumulator = 2;\n"
"const int gl_MatrixOperandsSaturatingAccumulation = 0x10;\n"
"const int gl_CooperativeMatrixLayoutRowMajor = 0;\n"
"const int gl_CooperativeMatrixLayoutColumnMajor = 1;\n"
"\n"
);
}
//============================================================================
@@ -8897,6 +8985,12 @@ void TBuiltIns::identifyBuiltIns(int version, EProfile profile, const SpvVersion
symbolTable.setFunctionExtensions("coopMatMulAddNV", 2, coopExt);
}
// Associate the un-suffixed cooperative-matrix built-ins with
// E_GL_KHR_cooperative_matrix.
// NOTE(review): the second argument (1) is presumably the number of extension
// names passed — the NV call above passes 2 with an array; confirm.
// NOTE(review): this is a bare scope with no version/profile condition, whereas
// the group that follows is guarded by one — confirm that is intended.
{
symbolTable.setFunctionExtensions("coopMatLoad", 1, &E_GL_KHR_cooperative_matrix);
symbolTable.setFunctionExtensions("coopMatStore", 1, &E_GL_KHR_cooperative_matrix);
symbolTable.setFunctionExtensions("coopMatMulAdd", 1, &E_GL_KHR_cooperative_matrix);
}
if ((profile != EEsProfile && version >= 450) || (profile == EEsProfile && version >= 320)) {
symbolTable.setFunctionExtensions("dFdx", 1, &E_GL_NV_compute_shader_derivatives);
symbolTable.setFunctionExtensions("dFdy", 1, &E_GL_NV_compute_shader_derivatives);
@@ -10005,9 +10099,13 @@ void TBuiltIns::identifyBuiltIns(int version, EProfile profile, const SpvVersion
symbolTable.relateToOperator("dFdyCoarse", EOpDPdyCoarse);
symbolTable.relateToOperator("fwidthCoarse",EOpFwidthCoarse);
}
// Relate the cooperative-matrix built-in names to their operators.
// NOTE(review): the three *NV names are related twice below — first to the
// un-suffixed operators, then to the *NV operators. This looks like the
// "before" and "after" sides of the change shown together without +/- markers.
// Confirm that the real file keeps only the *NV -> EOp...NV mappings for the NV
// names, with the un-suffixed operators used solely by coopMatLoad,
// coopMatStore and coopMatMulAdd.
symbolTable.relateToOperator("coopMatLoadNV", EOpCooperativeMatrixLoad);
symbolTable.relateToOperator("coopMatStoreNV", EOpCooperativeMatrixStore);
symbolTable.relateToOperator("coopMatMulAddNV", EOpCooperativeMatrixMulAdd);
symbolTable.relateToOperator("coopMatLoadNV", EOpCooperativeMatrixLoadNV);
symbolTable.relateToOperator("coopMatStoreNV", EOpCooperativeMatrixStoreNV);
symbolTable.relateToOperator("coopMatMulAddNV", EOpCooperativeMatrixMulAddNV);
symbolTable.relateToOperator("coopMatLoad", EOpCooperativeMatrixLoad);
symbolTable.relateToOperator("coopMatStore", EOpCooperativeMatrixStore);
symbolTable.relateToOperator("coopMatMulAdd", EOpCooperativeMatrixMulAdd);
break;
case EShLangRayGen: