diff --git a/Test/baseResults/hlsl.flatten.return.frag.out b/Test/baseResults/hlsl.flatten.return.frag.out new file mode 100644 index 0000000000000000000000000000000000000000..39fbf0ef87f08bddf1d62266bbc4ed3650dbc366 --- /dev/null +++ b/Test/baseResults/hlsl.flatten.return.frag.out @@ -0,0 +1,187 @@ +hlsl.flatten.return.frag +Shader version: 450 +gl_FragCoord origin is upper left +0:? Sequence +0:11 Function Definition: Func1( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:11 Function Parameters: +0:? Sequence +0:12 Branch: Return with expression +0:? Constant: +0:? 1.000000 +0:? 1.000000 +0:? 1.000000 +0:? 1.000000 +0:? 2.000000 +0:? 3.000000 +0:? 4.000000 +0:16 Function Definition: main( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:16 Function Parameters: +0:? Sequence +0:17 Sequence +0:17 Sequence +0:17 move second child to first child (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 Function Call: Func1( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 move second child to first child (temp 4-component vector of float) +0:? 'color' (layout(location=0 ) out 4-component vector of float) +0:17 color: direct index for structure (temp 4-component vector of float) +0:17 'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 Constant: +0:17 0 (const int) +0:17 move second child to first child (temp float) +0:? 'other_struct_member1' (layout(location=1 ) out float) +0:17 other_struct_member1: direct index for structure (temp float) +0:17 'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 Constant: +0:17 1 (const int) +0:17 move second child to first child (temp float) +0:? 'other_struct_member2' (layout(location=2 ) out float) +0:17 other_struct_member2: direct index for structure (temp float) +0:17 'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 Constant: +0:17 2 (const int) +0:17 move second child to first child (temp float) +0:? 'other_struct_member3' (layout(location=3 ) out float) +0:17 other_struct_member3: direct index for structure (temp float) +0:17 'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 Constant: +0:17 3 (const int) +0:17 Branch: Return +0:? Linker Objects +0:? 'color' (layout(location=0 ) out 4-component vector of float) +0:? 'other_struct_member1' (layout(location=1 ) out float) +0:? 'other_struct_member2' (layout(location=2 ) out float) +0:? 'other_struct_member3' (layout(location=3 ) out float) + + +Linked fragment stage: + + +Shader version: 450 +gl_FragCoord origin is upper left +0:? Sequence +0:11 Function Definition: Func1( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:11 Function Parameters: +0:? Sequence +0:12 Branch: Return with expression +0:? Constant: +0:? 1.000000 +0:? 1.000000 +0:? 1.000000 +0:? 1.000000 +0:? 2.000000 +0:? 3.000000 +0:? 4.000000 +0:16 Function Definition: main( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:16 Function Parameters: +0:? Sequence +0:17 Sequence +0:17 Sequence +0:17 move second child to first child (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 Function Call: Func1( (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 move second child to first child (temp 4-component vector of float) +0:? 'color' (layout(location=0 ) out 4-component vector of float) +0:17 color: direct index for structure (temp 4-component vector of float) +0:17 'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 Constant: +0:17 0 (const int) +0:17 move second child to first child (temp float) +0:? 'other_struct_member1' (layout(location=1 ) out float) +0:17 other_struct_member1: direct index for structure (temp float) +0:17 'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 Constant: +0:17 1 (const int) +0:17 move second child to first child (temp float) +0:? 'other_struct_member2' (layout(location=2 ) out float) +0:17 other_struct_member2: direct index for structure (temp float) +0:17 'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 Constant: +0:17 2 (const int) +0:17 move second child to first child (temp float) +0:? 'other_struct_member3' (layout(location=3 ) out float) +0:17 other_struct_member3: direct index for structure (temp float) +0:17 'flattenTemp' (temp structure{temp 4-component vector of float color, temp float other_struct_member1, temp float other_struct_member2, temp float other_struct_member3}) +0:17 Constant: +0:17 3 (const int) +0:17 Branch: Return +0:? Linker Objects +0:? 'color' (layout(location=0 ) out 4-component vector of float) +0:? 'other_struct_member1' (layout(location=1 ) out float) +0:? 'other_struct_member2' (layout(location=2 ) out float) +0:? 'other_struct_member3' (layout(location=3 ) out float) + +// Module Version 10000 +// Generated by (magic number): 80001 +// Id's are bound by 45 + + Capability Shader + 1: ExtInstImport "GLSL.std.450" + MemoryModel Logical GLSL450 + EntryPoint Fragment 4 "main" 24 31 36 40 + ExecutionMode 4 OriginUpperLeft + Name 4 "main" + Name 8 "PS_OUTPUT" + MemberName 8(PS_OUTPUT) 0 "color" + MemberName 8(PS_OUTPUT) 1 "other_struct_member1" + MemberName 8(PS_OUTPUT) 2 "other_struct_member2" + MemberName 8(PS_OUTPUT) 3 "other_struct_member3" + Name 10 "Func1(" + Name 21 "flattenTemp" + Name 24 "color" + Name 31 "other_struct_member1" + Name 36 "other_struct_member2" + Name 40 "other_struct_member3" + Decorate 24(color) Location 0 + Decorate 31(other_struct_member1) Location 1 + Decorate 36(other_struct_member2) Location 2 + Decorate 40(other_struct_member3) Location 3 + 2: TypeVoid + 3: TypeFunction 2 + 6: TypeFloat 32 + 7: TypeVector 6(float) 4 + 8(PS_OUTPUT): TypeStruct 7(fvec4) 6(float) 6(float) 6(float) + 9: TypeFunction 8(PS_OUTPUT) + 12: 6(float) Constant 1065353216 + 13: 7(fvec4) ConstantComposite 12 12 12 12 + 14: 6(float) Constant 1073741824 + 15: 6(float) Constant 1077936128 + 16: 6(float) Constant 1082130432 + 17:8(PS_OUTPUT) ConstantComposite 13 14 15 16 + 20: TypePointer Function 8(PS_OUTPUT) + 23: TypePointer Output 7(fvec4) + 24(color): 23(ptr) Variable Output + 25: TypeInt 32 1 + 26: 25(int) Constant 0 + 27: TypePointer Function 7(fvec4) + 30: TypePointer Output 6(float) +31(other_struct_member1): 30(ptr) Variable Output + 32: 25(int) Constant 1 + 33: TypePointer Function 6(float) +36(other_struct_member2): 30(ptr) Variable Output + 37: 25(int) Constant 2 +40(other_struct_member3): 30(ptr) Variable Output + 41: 25(int) Constant 3 + 4(main): 2 Function None 3 + 5: Label + 21(flattenTemp): 20(ptr) Variable Function + 22:8(PS_OUTPUT) FunctionCall 10(Func1() + Store 21(flattenTemp) 22 + 28: 27(ptr) AccessChain 21(flattenTemp) 26 + 29: 7(fvec4) Load 28 + Store 24(color) 29 + 34: 33(ptr) AccessChain 21(flattenTemp) 32 + 35: 6(float) Load 34 + Store 31(other_struct_member1) 35 + 38: 33(ptr) AccessChain 21(flattenTemp) 37 + 39: 6(float) Load 38 + Store 36(other_struct_member2) 39 + 42: 33(ptr) AccessChain 21(flattenTemp) 41 + 43: 6(float) Load 42 + Store 40(other_struct_member3) 43 + Return + FunctionEnd + 10(Func1():8(PS_OUTPUT) Function None 9 + 11: Label + ReturnValue 17 + FunctionEnd diff --git a/Test/hlsl.flatten.return.frag b/Test/hlsl.flatten.return.frag new file mode 100644 index 0000000000000000000000000000000000000000..c633e679882ee1e181ca28400267d20769ac1ffe --- /dev/null +++ b/Test/hlsl.flatten.return.frag @@ -0,0 +1,18 @@ + +struct PS_OUTPUT +{ + float4 color : SV_Target0; + float other_struct_member1; + float other_struct_member2; + float other_struct_member3; +}; + +PS_OUTPUT Func1() +{ + return PS_OUTPUT(float4(1), 2, 3, 4); +} + +PS_OUTPUT main() +{ + return Func1(); +} diff --git a/glslang/MachineIndependent/Intermediate.cpp b/glslang/MachineIndependent/Intermediate.cpp index cababc35905ae8b462c892f722897bfcd31a0dd0..97556202ffd6badb081a644bd10b92598ed6fef9 100644 --- a/glslang/MachineIndependent/Intermediate.cpp +++ b/glslang/MachineIndependent/Intermediate.cpp @@ -73,6 +73,16 @@ TIntermSymbol* TIntermediate::addSymbol(int id, const TString& name, const TType return node; } +TIntermSymbol* TIntermediate::addSymbol(const TIntermSymbol& intermSymbol) +{ + return addSymbol(intermSymbol.getId(), + intermSymbol.getName(), + intermSymbol.getType(), + intermSymbol.getConstArray(), + intermSymbol.getConstSubtree(), + intermSymbol.getLoc()); +} + TIntermSymbol* TIntermediate::addSymbol(const TVariable& variable) { glslang::TSourceLoc loc; // just a null location diff --git a/glslang/MachineIndependent/localintermediate.h b/glslang/MachineIndependent/localintermediate.h index 14b8a00a2c947a742ed71633098b6ab4c46654ba..acfafb1e860a96473f264a57664735f5998bdc71 100644 --- a/glslang/MachineIndependent/localintermediate.h +++ b/glslang/MachineIndependent/localintermediate.h @@ -201,6 +201,7 @@ public: TIntermSymbol* addSymbol(const TVariable&); TIntermSymbol* addSymbol(const TVariable&, const TSourceLoc&); TIntermSymbol* addSymbol(const TType&, const TSourceLoc&); + TIntermSymbol* addSymbol(const TIntermSymbol&); TIntermTyped* addConversion(TOperator, const TType&, TIntermTyped*) const; TIntermTyped* addShapeConversion(TOperator, const TType&, TIntermTyped*); TIntermTyped* addBinaryMath(TOperator, TIntermTyped* left, TIntermTyped* right, TSourceLoc); diff --git a/gtests/Hlsl.FromFile.cpp b/gtests/Hlsl.FromFile.cpp index 7467eb79f29dc8b379be45ec366c603fc03ab748..51be19ddbdaa9a22892feb92bde550e5a0a168a5 100644 --- a/gtests/Hlsl.FromFile.cpp +++ b/gtests/Hlsl.FromFile.cpp @@ -99,6 +99,7 @@ INSTANTIATE_TEST_CASE_P( {"hlsl.entry-out.frag", "PixelShaderFunction"}, {"hlsl.float1.frag", "PixelShaderFunction"}, {"hlsl.float4.frag", "PixelShaderFunction"}, + {"hlsl.flatten.return.frag", "main"}, {"hlsl.forLoop.frag", "PixelShaderFunction"}, {"hlsl.gather.array.dx10.frag", "main"}, {"hlsl.gather.basic.dx10.frag", "main"}, diff --git a/hlsl/hlslParseHelper.cpp b/hlsl/hlslParseHelper.cpp index 082a49604c43f9818604cfb1f407bc0e61220897..08ffd58068dfbed3a3169ea71ffead19b10cfc7d 100755 --- a/hlsl/hlslParseHelper.cpp +++ b/hlsl/hlslParseHelper.cpp @@ -952,10 +952,53 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op const TVector<TVariable*>* leftVariables = nullptr; const TVector<TVariable*>* rightVariables = nullptr; + // A temporary to store the right node's value, so we don't keep indirecting into it + // if it's not a simple symbol. + TVariable* rhsTempVar = nullptr; + + // If the RHS is a simple symbol node, we'll copy it for each member. + TIntermSymbol* cloneSymNode = nullptr; + + // Array structs are not yet handled in flattening. (Compilation error upstream, so + // this should never fire). + assert(!(left->getType().isStruct() && left->getType().isArray())); + + int memberCount = 0; + + // Track how many items there are to copy. + if (left->getType().isStruct()) + memberCount = left->getType().getStruct()->size(); + if (left->getType().isArray()) + memberCount = left->getType().getCumulativeArraySize(); + if (flattenLeft) leftVariables = &flattenMap.find(left->getAsSymbolNode()->getId())->second; - if (flattenRight) + + if (flattenRight) { rightVariables = &flattenMap.find(right->getAsSymbolNode()->getId())->second; + } else { + // The RHS is not flattened. There are several cases: + // 1. 1 item to copy: Use the RHS directly. + // 2. >1 item, simple symbol RHS: we'll create a new TIntermSymbol node for each, but no assign to temp. + // 3. >1 item, complex RHS: assign it to a new temp variable, and create a TIntermSymbol for each member. + + if (memberCount <= 1) { + // case 1: we'll use the symbol directly below. Nothing to do. + } else { + if (right->getAsSymbolNode() != nullptr) { + // case 2: we'll copy the symbol per iteration below. + cloneSymNode = right->getAsSymbolNode(); + } else { + // case 3: assign to a temp, and indirect into that. + rhsTempVar = makeInternalVariable("flattenTemp", right->getType()); + rhsTempVar->getWritableType().getQualifier().makeTemporary(); + TIntermTyped* noFlattenRHS = intermediate.addSymbol(*rhsTempVar, loc); + + // Add this to the aggregate being built. + assignList = intermediate.growAggregate(assignList, intermediate.addAssign(op, noFlattenRHS, right, loc), loc); + } + } + } const auto getMember = [&](bool flatten, TIntermTyped* node, const TVector<TVariable*>& memberVariables, int member, @@ -971,6 +1014,14 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op return subTree; }; + // Return the proper RHS node: a new symbol from a TVariable, copy + // of an TIntermSymbol node, or sometimes the right node directly. + const auto getRHS = [&]() { + return rhsTempVar ? intermediate.addSymbol(*rhsTempVar, loc) : + cloneSymNode ? intermediate.addSymbol(*cloneSymNode) : + right; + }; + // Handle struct assignment if (left->getType().isStruct()) { // If we get here, we are assigning to or from a whole struct that must be @@ -978,7 +1029,7 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op const auto& members = *left->getType().getStruct(); for (int member = 0; member < (int)members.size(); ++member) { - TIntermTyped* subRight = getMember(flattenRight, right, *rightVariables, member, + TIntermTyped* subRight = getMember(flattenRight, getRHS(), *rightVariables, member, EOpIndexDirectStruct, *members[member].type); TIntermTyped* subLeft = getMember(flattenLeft, left, *leftVariables, member, EOpIndexDirectStruct, *members[member].type); @@ -992,10 +1043,10 @@ TIntermTyped* HlslParseContext::handleAssign(const TSourceLoc& loc, TOperator op // flattened, so have to do member-by-member assignment: const TType dereferencedType(left->getType(), 0); - const int size = left->getType().getCumulativeArraySize(); - for (int element=0; element < size; ++element) { - TIntermTyped* subRight = getMember(flattenRight, right, *rightVariables, element, + for (int element=0; element < memberCount; ++element) { + // Add a new AST symbol node if we have a temp variable holding a complex RHS. + TIntermTyped* subRight = getMember(flattenRight, getRHS(), *rightVariables, element, EOpIndexDirect, dereferencedType); TIntermTyped* subLeft = getMember(flattenLeft, left, *leftVariables, element, EOpIndexDirect, dereferencedType); @@ -1235,9 +1286,9 @@ void HlslParseContext::decomposeSampleMethods(const TSourceLoc& loc, TIntermType // Return value from size query TVariable* tempArg = makeInternalVariable("sizeQueryTemp", sizeQuery->getType()); tempArg->getWritableType().getQualifier().makeTemporary(); - TIntermSymbol* sizeQueryReturn = intermediate.addSymbol(*tempArg, loc); - - TIntermTyped* sizeQueryAssign = intermediate.addAssign(EOpAssign, sizeQueryReturn, sizeQuery, loc); + TIntermTyped* sizeQueryAssign = intermediate.addAssign(EOpAssign, + intermediate.addSymbol(*tempArg, loc), + sizeQuery, loc); // Compound statement for assigning outputs TIntermAggregate* compoundStatement = intermediate.makeAggregate(sizeQueryAssign, loc); @@ -1246,6 +1297,7 @@ void HlslParseContext::decomposeSampleMethods(const TSourceLoc& loc, TIntermType for (int compNum = 0; compNum < numDims; ++compNum) { TIntermTyped* indexedOut = nullptr; + TIntermSymbol* sizeQueryReturn = intermediate.addSymbol(*tempArg, loc); if (numDims > 1) { TIntermTyped* component = intermediate.addConstantUnion(compNum, loc, true);