HLSL: implement numthreads for compute shaders
This PR adds handling of the numthreads attribute for compute shaders, as well as a general infrastructure for returning attribute values from acceptAttributes, which may be needed in other cases, e.g, unroll(x), or merely to know if some attribute without params was given. A map of enum values from TAttributeType to TIntermAggregate nodes is built and returned. It can be queried with operator[] on the map. In the future there may be a need to also handle strings (e.g, for patchconstantfunc), and those can be easily added into the class if needed. New test is in hlsl.numthreads.comp.
This commit is contained in:
parent
e19e68d431
commit
1868b14435
10 changed files with 351 additions and 23 deletions
|
|
@ -37,6 +37,7 @@
|
|||
#include "hlslParseHelper.h"
|
||||
#include "hlslScanContext.h"
|
||||
#include "hlslGrammar.h"
|
||||
#include "hlslAttributes.h"
|
||||
|
||||
#include "../glslang/MachineIndependent/Scan.h"
|
||||
#include "../glslang/MachineIndependent/preprocessor/PpContext.h"
|
||||
|
|
@ -1045,7 +1046,8 @@ TFunction& HlslParseContext::handleFunctionDeclarator(const TSourceLoc& loc, TFu
|
|||
// Handle seeing the function prototype in front of a function definition in the grammar.
|
||||
// The body is handled after this function returns.
|
||||
//
|
||||
TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& loc, TFunction& function)
|
||||
TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& loc, TFunction& function,
|
||||
const TAttributeMap& attributes)
|
||||
{
|
||||
currentCaller = function.getMangledName();
|
||||
TSymbol* symbol = symbolTable.find(function.getMangledName());
|
||||
|
|
@ -1134,6 +1136,15 @@ TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& l
|
|||
controlFlowNestingLevel = 0;
|
||||
postMainReturn = false;
|
||||
|
||||
// Handle function attributes
|
||||
const TIntermAggregate* numThreadliterals = attributes[EatNumthreads];
|
||||
if (numThreadliterals != nullptr && inEntryPoint) {
|
||||
const TIntermSequence& sequence = numThreadliterals->getSequence();
|
||||
|
||||
for (int lid = 0; lid < int(sequence.size()); ++lid)
|
||||
intermediate.setLocalSize(lid, sequence[lid]->getAsConstantUnion()->getConstArray()[0].getIConst());
|
||||
}
|
||||
|
||||
return paramNodes;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue