diff --git a/SPIRV/GlslangToSpv.cpp b/SPIRV/GlslangToSpv.cpp index 5f631b4db1f374956297b958dd9a20cb94024586..146666527643f66f262ab27ee750c95701f40ed2 100755 --- a/SPIRV/GlslangToSpv.cpp +++ b/SPIRV/GlslangToSpv.cpp @@ -149,11 +149,14 @@ protected: spv::Id createBinaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id left, spv::Id right, glslang::TBasicType typeProxy, bool reduceComparison = true); spv::Id createBinaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id left, spv::Id right); spv::Id createUnaryOperation(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy); - spv::Id createUnaryMatrixOperation(spv::Op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy); + spv::Id createUnaryMatrixOperation(spv::Op op, spv::Decoration precision, spv::Decoration noContraction, spv::Id typeId, spv::Id operand,glslang::TBasicType typeProxy); spv::Id createConversion(glslang::TOperator op, spv::Decoration precision, spv::Decoration noContraction, spv::Id destTypeId, spv::Id operand, glslang::TBasicType typeProxy); spv::Id makeSmearedConstant(spv::Id constant, int vectorSize); spv::Id createAtomicOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy); - spv::Id createInvocationsOperation(glslang::TOperator, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy); + spv::Id createInvocationsOperation(glslang::TOperator op, spv::Id typeId, spv::Id operand, glslang::TBasicType typeProxy); +#ifdef AMD_EXTENSIONS + spv::Id CreateInvocationsVectorOperation(spv::Op op, spv::Id typeId, spv::Id operand); +#endif spv::Id createMiscOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy); spv::Id createNoArgOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId); spv::Id getSymbolId(const glslang::TIntermSymbol* node); @@ -3940,7 +3943,10 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op spvOp = spv::OpGroupIAdd; } - return builder.createOp(spvOp, typeId, operands); + if (builder.isVectorType(typeId)) + return CreateInvocationsVectorOperation(spvOp, typeId, operand); + else + return builder.createOp(spvOp, typeId, operands); } case glslang::EOpMinInvocationsNonUniform: case glslang::EOpMaxInvocationsNonUniform: @@ -3974,7 +3980,10 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op spvOp = spv::OpGroupIAddNonUniformAMD; } - return builder.createOp(spvOp, typeId, operands); + if (builder.isVectorType(typeId)) + return CreateInvocationsVectorOperation(spvOp, typeId, operand); + else + return builder.createOp(spvOp, typeId, operands); } #endif default: @@ -3983,6 +3992,48 @@ spv::Id TGlslangToSpvTraverser::createInvocationsOperation(glslang::TOperator op } } +#ifdef AMD_EXTENSIONS +// Create group invocation operations on a vector +spv::Id TGlslangToSpvTraverser::CreateInvocationsVectorOperation(spv::Op op, spv::Id typeId, spv::Id operand) +{ + assert(op == spv::OpGroupFMin || op == spv::OpGroupUMin || op == spv::OpGroupSMin || + op == spv::OpGroupFMax || op == spv::OpGroupUMax || op == spv::OpGroupSMax || + op == spv::OpGroupFAdd || op == spv::OpGroupIAdd || + op == spv::OpGroupFMinNonUniformAMD || op == spv::OpGroupUMinNonUniformAMD || op == spv::OpGroupSMinNonUniformAMD || + op == spv::OpGroupFMaxNonUniformAMD || op == spv::OpGroupUMaxNonUniformAMD || op == spv::OpGroupSMaxNonUniformAMD || + op == spv::OpGroupFAddNonUniformAMD || op == spv::OpGroupIAddNonUniformAMD); + + // Handle group invocation operations scalar by scalar. + // The result type is the same type as the original type. + // The algorithm is to: + // - break the vector into scalars + // - apply the operation to each scalar + // - make a vector out the scalar results + + // get the types sorted out + int numComponents = builder.getNumComponents(operand); + spv::Id scalarType = builder.getScalarTypeId(builder.getTypeId(operand)); + std::vector<spv::Id> results; + + // do each scalar op + for (int comp = 0; comp < numComponents; ++comp) { + std::vector<unsigned int> indexes; + indexes.push_back(comp); + spv::Id scalar = builder.createCompositeExtract(operand, scalarType, indexes); + + std::vector<spv::Id> operands; + operands.push_back(builder.makeUintConstant(spv::ScopeSubgroup)); + operands.push_back(spv::GroupOperationReduce); + operands.push_back(scalar); + + results.push_back(builder.createOp(op, scalarType, operands)); + } + + // put the pieces together + return builder.createCompositeConstruct(typeId, results); +} +#endif + spv::Id TGlslangToSpvTraverser::createMiscOperation(glslang::TOperator op, spv::Decoration precision, spv::Id typeId, std::vector<spv::Id>& operands, glslang::TBasicType typeProxy) { bool isUnsigned = typeProxy == glslang::EbtUint || typeProxy == glslang::EbtUint64;