HLSL: implement numthreads for compute shaders

This PR adds handling of the numthreads attribute for compute shaders, as well as a general
infrastructure for returning attribute values from acceptAttributes, which may be needed in other
cases, e.g, unroll(x), or merely to know if some attribute without params was given.

A map of enum values from TAttributeType to TIntermAggregate nodes is built and returned.  It
can be queried with operator[] on the map.  In the future there may be a need to also handle
strings (e.g, for patchconstantfunc), and those can be easily added into the class if needed.

New test is in hlsl.numthreads.comp.
This commit is contained in:
steve-lunarg 2016-10-20 13:07:10 -06:00
parent e19e68d431
commit 1868b14435
10 changed files with 351 additions and 23 deletions

View file

@ -37,6 +37,7 @@
#include "hlslParseHelper.h"
#include "hlslScanContext.h"
#include "hlslGrammar.h"
#include "hlslAttributes.h"
#include "../glslang/MachineIndependent/Scan.h"
#include "../glslang/MachineIndependent/preprocessor/PpContext.h"
@ -1045,7 +1046,8 @@ TFunction& HlslParseContext::handleFunctionDeclarator(const TSourceLoc& loc, TFu
// Handle seeing the function prototype in front of a function definition in the grammar.
// The body is handled after this function returns.
//
TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& loc, TFunction& function)
TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& loc, TFunction& function,
const TAttributeMap& attributes)
{
currentCaller = function.getMangledName();
TSymbol* symbol = symbolTable.find(function.getMangledName());
@ -1134,6 +1136,15 @@ TIntermAggregate* HlslParseContext::handleFunctionDefinition(const TSourceLoc& l
controlFlowNestingLevel = 0;
postMainReturn = false;
// Handle function attributes
const TIntermAggregate* numThreadliterals = attributes[EatNumthreads];
if (numThreadliterals != nullptr && inEntryPoint) {
const TIntermSequence& sequence = numThreadliterals->getSequence();
for (int lid = 0; lid < int(sequence.size()); ++lid)
intermediate.setLocalSize(lid, sequence[lid]->getAsConstantUnion()->getConstArray()[0].getIConst());
}
return paramNodes;
}