Add support for GL_NV_shader_invocation_reorder. (#3054)

This commit is contained in:
alelenv 2022-12-09 12:19:08 -08:00 committed by GitHub
parent 19d816c8c9
commit 4fc43cd3c1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
30 changed files with 7478 additions and 5636 deletions

View file

@ -278,12 +278,10 @@ protected:
// requiring local translation to and from SPIR-V type on every access.
// Maps <builtin-variable-id -> AST-required-type-id>
std::unordered_map<spv::Id, spv::Id> forceType;
// Used later for generating OpTraceKHR/OpExecuteCallableKHR
std::unordered_map<unsigned int, glslang::TIntermSymbol *> locationToSymbol[2];
// Used by Task shader while generating opearnds for OpEmitMeshTasksEXT
spv::Id taskPayloadID;
// Used later for generating OpTraceKHR/OpExecuteCallableKHR/OpHitObjectRecordHit*/OpHitObjectGetShaderBindingTableData
std::unordered_map<unsigned int, glslang::TIntermSymbol *> locationToSymbol[4];
};
//
@ -392,6 +390,7 @@ spv::Decoration TranslateBlockDecoration(const glslang::TType& type, bool useSto
case glslang::EvqHitAttr: return spv::DecorationBlock;
case glslang::EvqCallableData: return spv::DecorationBlock;
case glslang::EvqCallableDataIn: return spv::DecorationBlock;
case glslang::EvqHitObjectAttrNV: return spv::DecorationBlock;
#endif
default:
assert(0);
@ -469,6 +468,7 @@ spv::Decoration TranslateLayoutDecoration(const glslang::TType& type, glslang::T
case glslang::EvqHitAttr:
case glslang::EvqCallableData:
case glslang::EvqCallableDataIn:
case glslang::EvqHitObjectAttrNV:
return spv::DecorationMax;
#endif
default:
@ -1299,7 +1299,7 @@ spv::LoopControlMask TGlslangToSpvTraverser::TranslateLoopControl(const glslang:
// Translate glslang type to SPIR-V storage class.
spv::StorageClass TGlslangToSpvTraverser::TranslateStorageClass(const glslang::TType& type)
{
if (type.getBasicType() == glslang::EbtRayQuery)
if (type.getBasicType() == glslang::EbtRayQuery || type.getBasicType() == glslang::EbtHitObjectNV)
return spv::StorageClassPrivate;
#ifndef GLSLANG_WEB
if (type.getQualifier().isSpirvByReference()) {
@ -1356,6 +1356,7 @@ spv::StorageClass TGlslangToSpvTraverser::TranslateStorageClass(const glslang::T
case glslang::EvqCallableData: return spv::StorageClassCallableDataKHR;
case glslang::EvqCallableDataIn: return spv::StorageClassIncomingCallableDataKHR;
case glslang::EvqtaskPayloadSharedEXT : return spv::StorageClassTaskPayloadWorkgroupEXT;
case glslang::EvqHitObjectAttrNV: return spv::StorageClassHitObjectAttributeNV;
case glslang::EvqSpirvStorageClass: return static_cast<spv::StorageClass>(type.getQualifier().spirvStorageClass);
#endif
default:
@ -2582,6 +2583,35 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
spv::Builder::AccessChain::CoherentFlags lvalueCoherentFlags;
const auto hitObjectOpsWithLvalue = [](glslang::TOperator op) {
switch(op) {
case glslang::EOpReorderThreadNV:
case glslang::EOpHitObjectGetCurrentTimeNV:
case glslang::EOpHitObjectGetHitKindNV:
case glslang::EOpHitObjectGetPrimitiveIndexNV:
case glslang::EOpHitObjectGetGeometryIndexNV:
case glslang::EOpHitObjectGetInstanceIdNV:
case glslang::EOpHitObjectGetInstanceCustomIndexNV:
case glslang::EOpHitObjectGetObjectRayDirectionNV:
case glslang::EOpHitObjectGetObjectRayOriginNV:
case glslang::EOpHitObjectGetWorldRayDirectionNV:
case glslang::EOpHitObjectGetWorldRayOriginNV:
case glslang::EOpHitObjectGetWorldToObjectNV:
case glslang::EOpHitObjectGetObjectToWorldNV:
case glslang::EOpHitObjectGetRayTMaxNV:
case glslang::EOpHitObjectGetRayTMinNV:
case glslang::EOpHitObjectIsEmptyNV:
case glslang::EOpHitObjectIsHitNV:
case glslang::EOpHitObjectIsMissNV:
case glslang::EOpHitObjectRecordEmptyNV:
case glslang::EOpHitObjectGetShaderBindingTableRecordIndexNV:
case glslang::EOpHitObjectGetShaderRecordBufferHandleNV:
return true;
default:
return false;
}
};
#ifndef GLSLANG_WEB
if (node->getOp() == glslang::EOpAtomicCounterIncrement ||
node->getOp() == glslang::EOpAtomicCounterDecrement ||
@ -2596,7 +2626,8 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
node->getOp() == glslang::EOpRayQueryGetIntersectionCandidateAABBOpaque ||
node->getOp() == glslang::EOpRayQueryTerminate ||
node->getOp() == glslang::EOpRayQueryConfirmIntersection ||
(node->getOp() == glslang::EOpSpirvInst && operandNode->getAsTyped()->getQualifier().isSpirvByReference())) {
(node->getOp() == glslang::EOpSpirvInst && operandNode->getAsTyped()->getQualifier().isSpirvByReference()) ||
hitObjectOpsWithLvalue(node->getOp())) {
operand = builder.accessChainGetLValue(); // Special case l-value operands
lvalueCoherentFlags = builder.getAccessChain().coherentFlags;
lvalueCoherentFlags |= TranslateCoherent(operandNode->getAsTyped()->getType());
@ -2731,6 +2762,12 @@ bool TGlslangToSpvTraverser::visitUnary(glslang::TVisit /* visit */, glslang::TI
case glslang::EOpRayQueryConfirmIntersection:
builder.createNoResultOp(spv::OpRayQueryConfirmIntersectionKHR, operand);
return false;
case glslang::EOpReorderThreadNV:
builder.createNoResultOp(spv::OpReorderThreadWithHitObjectNV, operand);
return false;
case glslang::EOpHitObjectRecordEmptyNV:
builder.createNoResultOp(spv::OpHitObjectRecordEmptyNV, operand);
return false;
#endif
default:
@ -3222,6 +3259,43 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
builder.addExtension(spv::E_SPV_EXT_fragment_shader_interlock);
noReturnValue = true;
break;
case glslang::EOpHitObjectTraceRayNV:
case glslang::EOpHitObjectTraceRayMotionNV:
case glslang::EOpHitObjectGetAttributesNV:
case glslang::EOpHitObjectExecuteShaderNV:
case glslang::EOpHitObjectRecordEmptyNV:
case glslang::EOpHitObjectRecordMissNV:
case glslang::EOpHitObjectRecordMissMotionNV:
case glslang::EOpHitObjectRecordHitNV:
case glslang::EOpHitObjectRecordHitMotionNV:
case glslang::EOpHitObjectRecordHitWithIndexNV:
case glslang::EOpHitObjectRecordHitWithIndexMotionNV:
case glslang::EOpReorderThreadNV:
noReturnValue = true;
//Fallthrough
case glslang::EOpHitObjectIsEmptyNV:
case glslang::EOpHitObjectIsMissNV:
case glslang::EOpHitObjectIsHitNV:
case glslang::EOpHitObjectGetRayTMinNV:
case glslang::EOpHitObjectGetRayTMaxNV:
case glslang::EOpHitObjectGetObjectRayOriginNV:
case glslang::EOpHitObjectGetObjectRayDirectionNV:
case glslang::EOpHitObjectGetWorldRayOriginNV:
case glslang::EOpHitObjectGetWorldRayDirectionNV:
case glslang::EOpHitObjectGetObjectToWorldNV:
case glslang::EOpHitObjectGetWorldToObjectNV:
case glslang::EOpHitObjectGetInstanceCustomIndexNV:
case glslang::EOpHitObjectGetInstanceIdNV:
case glslang::EOpHitObjectGetGeometryIndexNV:
case glslang::EOpHitObjectGetPrimitiveIndexNV:
case glslang::EOpHitObjectGetHitKindNV:
case glslang::EOpHitObjectGetCurrentTimeNV:
case glslang::EOpHitObjectGetShaderBindingTableRecordIndexNV:
case glslang::EOpHitObjectGetShaderRecordBufferHandleNV:
builder.addExtension(spv::E_SPV_NV_shader_invocation_reorder);
builder.addCapability(spv::CapabilityShaderInvocationReorderNV);
break;
#endif
case glslang::EOpDebugPrintf:
@ -3279,6 +3353,22 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
lvalue = true;
break;
case glslang::EOpHitObjectRecordHitNV:
case glslang::EOpHitObjectRecordHitMotionNV:
case glslang::EOpHitObjectRecordHitWithIndexNV:
case glslang::EOpHitObjectRecordHitWithIndexMotionNV:
case glslang::EOpHitObjectTraceRayNV:
case glslang::EOpHitObjectTraceRayMotionNV:
case glslang::EOpHitObjectExecuteShaderNV:
case glslang::EOpHitObjectRecordMissNV:
case glslang::EOpHitObjectRecordMissMotionNV:
case glslang::EOpHitObjectGetAttributesNV:
if (arg == 0)
lvalue = true;
break;
case glslang::EOpRayQueryInitialize:
case glslang::EOpRayQueryTerminate:
case glslang::EOpRayQueryConfirmIntersection:
@ -3379,6 +3469,11 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
if (glslangOperands[arg]->getAsTyped()->getQualifier().isSpirvByReference())
lvalue = true;
break;
case glslang::EOpReorderThreadNV:
//Three variants of reorderThreadNV, two of them use hitObjectNV
if (arg == 0 && glslangOperands.size() != 2)
lvalue = true;
break;
#endif
default:
break;
@ -3435,7 +3530,7 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
} else if (arg == 2) {
continue;
}
}
}
#endif
// for l-values, pass the address, for r-values, pass the value
@ -3477,11 +3572,23 @@ bool TGlslangToSpvTraverser::visitAggregate(glslang::TVisit visit, glslang::TInt
operands.push_back(builder.makeIntConstant(cond ? 1 : 0));
} else if ((arg == 10 && glslangOp == glslang::EOpTraceKHR) ||
(arg == 11 && glslangOp == glslang::EOpTraceRayMotionNV) ||
(arg == 1 && glslangOp == glslang::EOpExecuteCallableKHR)) {
const int opdNum = glslangOp == glslang::EOpTraceKHR ? 10 : (glslangOp == glslang::EOpTraceRayMotionNV ? 11 : 1);
(arg == 1 && glslangOp == glslang::EOpExecuteCallableKHR) ||
(arg == 1 && glslangOp == glslang::EOpHitObjectExecuteShaderNV) ||
(arg == 11 && glslangOp == glslang::EOpHitObjectTraceRayNV) ||
(arg == 12 && glslangOp == glslang::EOpHitObjectTraceRayMotionNV)) {
const int set = glslangOp == glslang::EOpExecuteCallableKHR ? 1 : 0;
const int location = glslangOperands[opdNum]->getAsConstantUnion()->getConstArray()[0].getUConst();
const int location = glslangOperands[arg]->getAsConstantUnion()->getConstArray()[0].getUConst();
auto itNode = locationToSymbol[set].find(location);
visitSymbol(itNode->second);
spv::Id symId = getSymbolId(itNode->second);
operands.push_back(symId);
} else if ((arg == 12 && glslangOp == glslang::EOpHitObjectRecordHitNV) ||
(arg == 13 && glslangOp == glslang::EOpHitObjectRecordHitMotionNV) ||
(arg == 11 && glslangOp == glslang::EOpHitObjectRecordHitWithIndexNV) ||
(arg == 12 && glslangOp == glslang::EOpHitObjectRecordHitWithIndexMotionNV) ||
(arg == 1 && glslangOp == glslang::EOpHitObjectGetAttributesNV)) {
const int location = glslangOperands[arg]->getAsConstantUnion()->getConstArray()[0].getUConst();
const int set = 2;
auto itNode = locationToSymbol[set].find(location);
visitSymbol(itNode->second);
spv::Id symId = getSymbolId(itNode->second);
@ -4307,6 +4414,13 @@ spv::Id TGlslangToSpvTraverser::convertGlslangToSpvType(const glslang::TType& ty
case glslang::EbtString:
// no type used for OpString
return 0;
case glslang::EbtHitObjectNV: {
builder.addExtension(spv::E_SPV_NV_shader_invocation_reorder);
builder.addCapability(spv::CapabilityShaderInvocationReorderNV);
spvType = builder.makeHitObjectNVType();
}
break;
#ifndef GLSLANG_WEB
case glslang::EbtSpirvType: {
// GL_EXT_spirv_intrinsics
@ -4726,6 +4840,9 @@ void TGlslangToSpvTraverser::decorateStructType(const glslang::TType& type,
// Decorate the structure
builder.addDecoration(spvType, TranslateLayoutDecoration(type, qualifier.layoutMatrix));
builder.addDecoration(spvType, TranslateBlockDecoration(type, glslangIntermediate->usingStorageBuffer()));
if (qualifier.hasHitObjectShaderRecordNV())
builder.addDecoration(spvType, spv::DecorationHitObjectShaderRecordBufferNV);
}
// Turn the expression forming the array size into an id.
@ -5253,6 +5370,10 @@ void TGlslangToSpvTraverser::collectRayTracingLinkerObjects()
set = 1;
break;
case glslang::EvqHitObjectAttrNV:
set = 2;
break;
default:
set = -1;
}
@ -6897,6 +7018,83 @@ spv::Id TGlslangToSpvTraverser::createUnaryOperation(glslang::TOperator op, OpDe
case glslang::EOpConvUvec2ToAccStruct:
unaryOp = spv::OpConvertUToAccelerationStructureKHR;
break;
case glslang::EOpHitObjectIsEmptyNV:
unaryOp = spv::OpHitObjectIsEmptyNV;
break;
case glslang::EOpHitObjectIsMissNV:
unaryOp = spv::OpHitObjectIsMissNV;
break;
case glslang::EOpHitObjectIsHitNV:
unaryOp = spv::OpHitObjectIsHitNV;
break;
case glslang::EOpHitObjectGetObjectRayOriginNV:
unaryOp = spv::OpHitObjectGetObjectRayOriginNV;
break;
case glslang::EOpHitObjectGetObjectRayDirectionNV:
unaryOp = spv::OpHitObjectGetObjectRayDirectionNV;
break;
case glslang::EOpHitObjectGetWorldRayOriginNV:
unaryOp = spv::OpHitObjectGetWorldRayOriginNV;
break;
case glslang::EOpHitObjectGetWorldRayDirectionNV:
unaryOp = spv::OpHitObjectGetWorldRayDirectionNV;
break;
case glslang::EOpHitObjectGetObjectToWorldNV:
unaryOp = spv::OpHitObjectGetObjectToWorldNV;
break;
case glslang::EOpHitObjectGetWorldToObjectNV:
unaryOp = spv::OpHitObjectGetWorldToObjectNV;
break;
case glslang::EOpHitObjectGetRayTMinNV:
unaryOp = spv::OpHitObjectGetRayTMinNV;
break;
case glslang::EOpHitObjectGetRayTMaxNV:
unaryOp = spv::OpHitObjectGetRayTMaxNV;
break;
case glslang::EOpHitObjectGetPrimitiveIndexNV:
unaryOp = spv::OpHitObjectGetPrimitiveIndexNV;
break;
case glslang::EOpHitObjectGetInstanceIdNV:
unaryOp = spv::OpHitObjectGetInstanceIdNV;
break;
case glslang::EOpHitObjectGetInstanceCustomIndexNV:
unaryOp = spv::OpHitObjectGetInstanceCustomIndexNV;
break;
case glslang::EOpHitObjectGetGeometryIndexNV:
unaryOp = spv::OpHitObjectGetGeometryIndexNV;
break;
case glslang::EOpHitObjectGetHitKindNV:
unaryOp = spv::OpHitObjectGetHitKindNV;
break;
case glslang::EOpHitObjectGetCurrentTimeNV:
unaryOp = spv::OpHitObjectGetCurrentTimeNV;
break;
case glslang::EOpHitObjectGetShaderBindingTableRecordIndexNV:
unaryOp = spv::OpHitObjectGetShaderBindingTableRecordIndexNV;
break;
case glslang::EOpHitObjectGetShaderRecordBufferHandleNV:
unaryOp = spv::OpHitObjectGetShaderRecordBufferHandleNV;
break;
#endif
case glslang::EOpCopyObject:
@ -8638,6 +8836,122 @@ spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv::
case glslang::EOpCooperativeMatrixMulAdd:
opCode = spv::OpCooperativeMatrixMulAddNV;
break;
case glslang::EOpHitObjectTraceRayNV:
builder.createNoResultOp(spv::OpHitObjectTraceRayNV, operands);
return 0;
case glslang::EOpHitObjectTraceRayMotionNV:
builder.createNoResultOp(spv::OpHitObjectTraceRayMotionNV, operands);
return 0;
case glslang::EOpHitObjectRecordHitNV:
builder.createNoResultOp(spv::OpHitObjectRecordHitNV, operands);
return 0;
case glslang::EOpHitObjectRecordHitMotionNV:
builder.createNoResultOp(spv::OpHitObjectRecordHitMotionNV, operands);
return 0;
case glslang::EOpHitObjectRecordHitWithIndexNV:
builder.createNoResultOp(spv::OpHitObjectRecordHitWithIndexNV, operands);
return 0;
case glslang::EOpHitObjectRecordHitWithIndexMotionNV:
builder.createNoResultOp(spv::OpHitObjectRecordHitWithIndexMotionNV, operands);
return 0;
case glslang::EOpHitObjectRecordMissNV:
builder.createNoResultOp(spv::OpHitObjectRecordMissNV, operands);
return 0;
case glslang::EOpHitObjectRecordMissMotionNV:
builder.createNoResultOp(spv::OpHitObjectRecordMissMotionNV, operands);
return 0;
case glslang::EOpHitObjectExecuteShaderNV:
builder.createNoResultOp(spv::OpHitObjectExecuteShaderNV, operands);
return 0;
case glslang::EOpHitObjectIsEmptyNV:
typeId = builder.makeBoolType();
opCode = spv::OpHitObjectIsEmptyNV;
break;
case glslang::EOpHitObjectIsMissNV:
typeId = builder.makeBoolType();
opCode = spv::OpHitObjectIsMissNV;
break;
case glslang::EOpHitObjectIsHitNV:
typeId = builder.makeBoolType();
opCode = spv::OpHitObjectIsHitNV;
break;
case glslang::EOpHitObjectGetRayTMinNV:
typeId = builder.makeFloatType(32);
opCode = spv::OpHitObjectGetRayTMinNV;
break;
case glslang::EOpHitObjectGetRayTMaxNV:
typeId = builder.makeFloatType(32);
opCode = spv::OpHitObjectGetRayTMaxNV;
break;
case glslang::EOpHitObjectGetObjectRayOriginNV:
typeId = builder.makeVectorType(builder.makeFloatType(32), 3);
opCode = spv::OpHitObjectGetObjectRayOriginNV;
break;
case glslang::EOpHitObjectGetObjectRayDirectionNV:
typeId = builder.makeVectorType(builder.makeFloatType(32), 3);
opCode = spv::OpHitObjectGetObjectRayDirectionNV;
break;
case glslang::EOpHitObjectGetWorldRayOriginNV:
typeId = builder.makeVectorType(builder.makeFloatType(32), 3);
opCode = spv::OpHitObjectGetWorldRayOriginNV;
break;
case glslang::EOpHitObjectGetWorldRayDirectionNV:
typeId = builder.makeVectorType(builder.makeFloatType(32), 3);
opCode = spv::OpHitObjectGetWorldRayDirectionNV;
break;
case glslang::EOpHitObjectGetWorldToObjectNV:
typeId = builder.makeMatrixType(builder.makeFloatType(32), 4, 3);
opCode = spv::OpHitObjectGetWorldToObjectNV;
break;
case glslang::EOpHitObjectGetObjectToWorldNV:
typeId = builder.makeMatrixType(builder.makeFloatType(32), 4, 3);
opCode = spv::OpHitObjectGetObjectToWorldNV;
break;
case glslang::EOpHitObjectGetInstanceCustomIndexNV:
typeId = builder.makeIntegerType(32, 1);
opCode = spv::OpHitObjectGetInstanceCustomIndexNV;
break;
case glslang::EOpHitObjectGetInstanceIdNV:
typeId = builder.makeIntegerType(32, 1);
opCode = spv::OpHitObjectGetInstanceIdNV;
break;
case glslang::EOpHitObjectGetGeometryIndexNV:
typeId = builder.makeIntegerType(32, 1);
opCode = spv::OpHitObjectGetGeometryIndexNV;
break;
case glslang::EOpHitObjectGetPrimitiveIndexNV:
typeId = builder.makeIntegerType(32, 1);
opCode = spv::OpHitObjectGetPrimitiveIndexNV;
break;
case glslang::EOpHitObjectGetHitKindNV:
typeId = builder.makeIntegerType(32, 0);
opCode = spv::OpHitObjectGetHitKindNV;
break;
case glslang::EOpHitObjectGetCurrentTimeNV:
typeId = builder.makeFloatType(32);
opCode = spv::OpHitObjectGetCurrentTimeNV;
break;
case glslang::EOpHitObjectGetShaderBindingTableRecordIndexNV:
typeId = builder.makeIntegerType(32, 0);
opCode = spv::OpHitObjectGetShaderBindingTableRecordIndexNV;
return 0;
case glslang::EOpHitObjectGetAttributesNV:
builder.createNoResultOp(spv::OpHitObjectGetAttributesNV, operands);
return 0;
case glslang::EOpHitObjectGetShaderRecordBufferHandleNV:
typeId = builder.makeVectorType(builder.makeUintType(32), 2);
opCode = spv::OpHitObjectGetShaderRecordBufferHandleNV;
break;
case glslang::EOpReorderThreadNV: {
if (operands.size() == 2) {
builder.createNoResultOp(spv::OpReorderThreadWithHintNV, operands);
} else {
builder.createNoResultOp(spv::OpReorderThreadWithHitObjectNV, operands);
}
return 0;
}
break;
#endif // GLSLANG_WEB
default:
return 0;
@ -8947,13 +9261,17 @@ spv::Id TGlslangToSpvTraverser::getSymbolId(const glslang::TIntermSymbol* symbol
}
if (symbol->getQualifier().hasLocation()) {
if (!(glslangIntermediate->isRayTracingStage() && glslangIntermediate->IsRequestedExtension(glslang::E_GL_EXT_ray_tracing)
if (!(glslangIntermediate->isRayTracingStage() &&
(glslangIntermediate->IsRequestedExtension(glslang::E_GL_EXT_ray_tracing) ||
glslangIntermediate->IsRequestedExtension(glslang::E_GL_NV_shader_invocation_reorder))
&& (builder.getStorageClass(id) == spv::StorageClassRayPayloadKHR ||
builder.getStorageClass(id) == spv::StorageClassIncomingRayPayloadKHR ||
builder.getStorageClass(id) == spv::StorageClassCallableDataKHR ||
builder.getStorageClass(id) == spv::StorageClassIncomingCallableDataKHR))) {
// Location values are used to link TraceRayKHR and ExecuteCallableKHR to corresponding variables
// but are not valid in SPIRV since they are supported only for Input/Output Storage classes.
builder.getStorageClass(id) == spv::StorageClassIncomingCallableDataKHR ||
builder.getStorageClass(id) == spv::StorageClassHitObjectAttributeNV))) {
// Location values are used to link TraceRayKHR/ExecuteCallableKHR/HitObjectGetAttributesNV
// to corresponding variables but are not valid in SPIRV since they are supported only
// for Input/Output Storage classes.
builder.addDecoration(id, spv::DecorationLocation, symbol->getQualifier().layoutLocation);
}
}