Implement support for GL_KHR_cooperative_matrix extension
This commit is contained in:
parent
91a97b4c69
commit
808c7ed17c
40 changed files with 8227 additions and 5733 deletions
|
|
@ -4397,6 +4397,94 @@ void TBuiltIns::initialize(int version, EProfile profile, const SpvVersion& spvV
|
|||
"icoopmatNV coopMatMulAddNV(icoopmatNV A, icoopmatNV B, icoopmatNV C);\n"
|
||||
"ucoopmatNV coopMatMulAddNV(ucoopmatNV A, ucoopmatNV B, ucoopmatNV C);\n"
|
||||
);
|
||||
|
||||
std::string cooperativeMatrixFuncs =
|
||||
"void coopMatLoad(out coopmat m, volatile coherent int8_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent int16_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent int32_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent int64_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent uint8_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent uint16_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent uint32_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent uint64_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent float16_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent float[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent float64_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
|
||||
"void coopMatLoad(out coopmat m, volatile coherent i8vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent i16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent i32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent i64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent u8vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent u16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent u32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent u64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent f16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent f32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent f64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
|
||||
"void coopMatLoad(out coopmat m, volatile coherent i8vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent i16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent i32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent i64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent u8vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent u16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent u32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent u64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent f16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent f32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatLoad(out coopmat m, volatile coherent f64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
|
||||
"void coopMatStore(coopmat m, volatile coherent int8_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent int16_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent int32_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent int64_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent uint8_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent uint16_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent uint32_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent uint64_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent float16_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent float[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent float64_t[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
|
||||
"void coopMatStore(coopmat m, volatile coherent i8vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent i16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent i32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent i64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent u8vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent u16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent u32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent u64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent f16vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent f32vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent f64vec2[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
|
||||
"void coopMatStore(coopmat m, volatile coherent i8vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent i16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent i32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent i64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent u8vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent u16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent u32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent u64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent f16vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent f32vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
"void coopMatStore(coopmat m, volatile coherent f64vec4[] buf, uint element, uint stride, int matrixLayout);\n"
|
||||
|
||||
"coopmat coopMatMulAdd(coopmat A, coopmat B, coopmat C);\n"
|
||||
"coopmat coopMatMulAdd(coopmat A, coopmat B, coopmat C, int matrixOperands);\n";
|
||||
|
||||
commonBuiltins.append(cooperativeMatrixFuncs.c_str());
|
||||
|
||||
commonBuiltins.append(
|
||||
"const int gl_MatrixUseA = 0;\n"
|
||||
"const int gl_MatrixUseB = 1;\n"
|
||||
"const int gl_MatrixUseAccumulator = 2;\n"
|
||||
"const int gl_MatrixOperandsSaturatingAccumulation = 0x10;\n"
|
||||
"const int gl_CooperativeMatrixLayoutRowMajor = 0;\n"
|
||||
"const int gl_CooperativeMatrixLayoutColumnMajor = 1;\n"
|
||||
"\n"
|
||||
);
|
||||
}
|
||||
|
||||
//============================================================================
|
||||
|
|
@ -8897,6 +8985,12 @@ void TBuiltIns::identifyBuiltIns(int version, EProfile profile, const SpvVersion
|
|||
symbolTable.setFunctionExtensions("coopMatMulAddNV", 2, coopExt);
|
||||
}
|
||||
|
||||
{
|
||||
symbolTable.setFunctionExtensions("coopMatLoad", 1, &E_GL_KHR_cooperative_matrix);
|
||||
symbolTable.setFunctionExtensions("coopMatStore", 1, &E_GL_KHR_cooperative_matrix);
|
||||
symbolTable.setFunctionExtensions("coopMatMulAdd", 1, &E_GL_KHR_cooperative_matrix);
|
||||
}
|
||||
|
||||
if ((profile != EEsProfile && version >= 450) || (profile == EEsProfile && version >= 320)) {
|
||||
symbolTable.setFunctionExtensions("dFdx", 1, &E_GL_NV_compute_shader_derivatives);
|
||||
symbolTable.setFunctionExtensions("dFdy", 1, &E_GL_NV_compute_shader_derivatives);
|
||||
|
|
@ -10005,9 +10099,13 @@ void TBuiltIns::identifyBuiltIns(int version, EProfile profile, const SpvVersion
|
|||
symbolTable.relateToOperator("dFdyCoarse", EOpDPdyCoarse);
|
||||
symbolTable.relateToOperator("fwidthCoarse",EOpFwidthCoarse);
|
||||
}
|
||||
symbolTable.relateToOperator("coopMatLoadNV", EOpCooperativeMatrixLoad);
|
||||
symbolTable.relateToOperator("coopMatStoreNV", EOpCooperativeMatrixStore);
|
||||
symbolTable.relateToOperator("coopMatMulAddNV", EOpCooperativeMatrixMulAdd);
|
||||
symbolTable.relateToOperator("coopMatLoadNV", EOpCooperativeMatrixLoadNV);
|
||||
symbolTable.relateToOperator("coopMatStoreNV", EOpCooperativeMatrixStoreNV);
|
||||
symbolTable.relateToOperator("coopMatMulAddNV", EOpCooperativeMatrixMulAddNV);
|
||||
|
||||
symbolTable.relateToOperator("coopMatLoad", EOpCooperativeMatrixLoad);
|
||||
symbolTable.relateToOperator("coopMatStore", EOpCooperativeMatrixStore);
|
||||
symbolTable.relateToOperator("coopMatMulAdd", EOpCooperativeMatrixMulAdd);
|
||||
break;
|
||||
|
||||
case EShLangRayGen:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue