HLSL: HS return is arrayed to match SPIR-V semantics

HLSL HS outputs a per ctrl point value, and the DS reads an array
of that type.  (It also has a per patch frequency).  The per-ctrl-pt
frequency is arrayed on just one side, as opposed to SPIR-V which
is arrayed on both.  To match semantics, the compiler creates an
array behind the scenes and indexes it by invocation ID, assigning
the HS return value to it.
This commit is contained in:
steve-lunarg 2017-03-22 18:39:25 -06:00
parent 7afe1344ca
commit e752f463c5
6 changed files with 477 additions and 382 deletions

View file

@ -1598,48 +1598,10 @@ TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& l
return paramNodes;
}
//
// Do all special handling for the entry point, including wrapping
// the shader's entry point with the official entry point that will call it.
//
// The following:
//
// retType shaderEntryPoint(args...) // shader declared entry point
// { body }
//
// Becomes
//
// out retType ret;
// in iargs<that are input>...;
// out oargs<that are output> ...;
//
// void shaderEntryPoint() // synthesized, but official, entry point
// {
// args<that are input> = iargs...;
// ret = @shaderEntryPoint(args...);
// oargs = args<that are output>...;
// }
//
// The symbol table will still map the original entry point name to the
// the modified function and it's new name:
//
// symbol table: shaderEntryPoint -> @shaderEntryPoint
//
// Returns nullptr if no entry-point tree was built, otherwise, returns
// a subtree that creates the entry point.
//
TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
// Handle all [attrib] attribute for the shader entry point
void HlslParseContext::handleEntryPointAttributes(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
{
// if we aren't in the entry point, fix the IO as such and exit
if (userFunction.getName().compare(intermediate.getEntryPointName().c_str()) != 0) {
remapNonEntryPointIO(userFunction);
return nullptr;
}
entryPointFunction = &userFunction; // needed in finish()
// entry point logic...
// Handle entry-point function attributes
const TIntermAggregate* numThreads = attributes[EatNumThreads];
if (numThreads != nullptr) {
@ -1691,8 +1653,12 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
error(loc, "unsupported domain type", domainStr.c_str(), "");
}
if (! intermediate.setInputPrimitive(domain)) {
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
if (language == EShLangTessEvaluation) {
if (! intermediate.setInputPrimitive(domain))
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
} else {
if (! intermediate.setOutputPrimitive(domain))
error(loc, "cannot change previously set domain", TQualifier::getGeometryString(domain), "");
}
}
}
@ -1770,6 +1736,52 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
}
}
}
}
//
// Do all special handling for the entry point, including wrapping
// the shader's entry point with the official entry point that will call it.
//
// The following:
//
// retType shaderEntryPoint(args...) // shader declared entry point
// { body }
//
// Becomes
//
// out retType ret;
// in iargs<that are input>...;
// out oargs<that are output> ...;
//
// void shaderEntryPoint() // synthesized, but official, entry point
// {
// args<that are input> = iargs...;
// ret = @shaderEntryPoint(args...);
// oargs = args<that are output>...;
// }
//
// The symbol table will still map the original entry point name to the
// the modified function and it's new name:
//
// symbol table: shaderEntryPoint -> @shaderEntryPoint
//
// Returns nullptr if no entry-point tree was built, otherwise, returns
// a subtree that creates the entry point.
//
TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunction& userFunction, const TAttributeMap& attributes)
{
// if we aren't in the entry point, fix the IO as such and exit
if (userFunction.getName().compare(intermediate.getEntryPointName().c_str()) != 0) {
remapNonEntryPointIO(userFunction);
return nullptr;
}
entryPointFunction = &userFunction; // needed in finish()
// Handle entry point attributes
handleEntryPointAttributes(loc, userFunction, attributes);
// entry point logic...
// Move parameters and return value to shader in/out
TVariable* entryPointOutput; // gets created in remapEntryPointIO
@ -1838,10 +1850,37 @@ TIntermNode* HlslParseContext::transformEntryPoint(const TSourceLoc& loc, TFunct
currentCaller = userFunction.getMangledName();
// Return value
if (entryPointOutput)
intermediate.growAggregate(synthBody, handleAssign(loc, EOpAssign,
intermediate.addSymbol(*entryPointOutput), callReturn));
else
if (entryPointOutput) {
TIntermTyped* returnAssign;
if (language == EShLangTessControl) {
TIntermSymbol* invocationIdSym = findLinkageSymbol(EbvInvocationId);
// If there is no user declared invocation ID, we must make one.
if (invocationIdSym == nullptr) {
TType invocationIdType(EbtUint, EvqIn, 1);
TString* invocationIdName = NewPoolTString("InvocationId");
invocationIdType.getQualifier().builtIn = EbvInvocationId;
TVariable* variable = makeInternalVariable(*invocationIdName, invocationIdType);
globalQualifierFix(loc, variable->getWritableType().getQualifier());
trackLinkage(*variable);
invocationIdSym = intermediate.addSymbol(*variable);
}
TIntermTyped* element = intermediate.addIndex(EOpIndexIndirect, intermediate.addSymbol(*entryPointOutput),
invocationIdSym, loc);
element->setType(callReturn->getType());
returnAssign = handleAssign(loc, EOpAssign, element, callReturn);
} else {
returnAssign = handleAssign(loc, EOpAssign, intermediate.addSymbol(*entryPointOutput), callReturn);
}
intermediate.growAggregate(synthBody, returnAssign);
} else
intermediate.growAggregate(synthBody, callReturn);
// Output copies
@ -1914,10 +1953,29 @@ void HlslParseContext::remapEntryPointIO(TFunction& function, TVariable*& return
};
// return value is actually a shader-scoped output (out)
if (function.getType().getBasicType() == EbtVoid)
if (function.getType().getBasicType() == EbtVoid) {
returnValue = nullptr;
else
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
} else {
if (language == EShLangTessControl) {
// tessellation evaluation in HLSL writes a per-ctrl-pt value, but it needs to be an
// array in SPIR-V semantics. We'll write to it indexed by invocation ID.
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
TType outputType;
outputType.shallowCopy(function.getType());
// vertices has necessarily already been set when handling entry point attributes.
TArraySizes arraySizes;
arraySizes.addInnerSize(intermediate.getVertices());
outputType.newArraySizes(arraySizes);
clearUniformInputOutput(function.getWritableType().getQualifier());
returnValue = makeIoVariable("@entryPointOutput", outputType, EvqVaryingOut);
} else {
returnValue = makeIoVariable("@entryPointOutput", function.getWritableType(), EvqVaryingOut);
}
}
// parameters are actually shader-scoped inputs and outputs (in or out)
for (int i = 0; i < function.getParamCount(); i++) {
@ -7410,6 +7468,17 @@ void HlslParseContext::clearUniformInputOutput(TQualifier& qualifier)
correctUniform(qualifier);
}
// Return a symbol for the linkage variable of the given TBuiltInVariable type
TIntermSymbol* HlslParseContext::findLinkageSymbol(TBuiltInVariable biType) const
{
const auto it = builtInLinkageSymbols.find(biType);
if (it == builtInLinkageSymbols.end()) // if it wasn't declared by the user, return nullptr
return nullptr;
return intermediate.addSymbol(*it->second->getAsVariable());
}
// Add patch constant function invocation
void HlslParseContext::addPatchConstantInvocation()
{
@ -7481,15 +7550,6 @@ void HlslParseContext::addPatchConstantInvocation()
}
};
// Return a symbol for the linkage variable of the given TBuiltInVariable type
const auto findLinkageSymbol = [this](TBuiltInVariable biType) -> TIntermSymbol* {
const auto it = builtInLinkageSymbols.find(biType);
if (it == builtInLinkageSymbols.end()) // if it wasn't declared by the user, return nullptr
return nullptr;
return intermediate.addSymbol(*it->second->getAsVariable());
};
const auto isPerCtrlPt = [this](const TType& type) {
// TODO: this is not sufficient to reject all such cases in malformed shaders.
return type.isArray() && !type.isRuntimeSizedArray();