From 8700e9e6d1c5317cb015d29b09e454ca0748b378 Mon Sep 17 00:00:00 2001
From: John Kessenich <cepheus@frii.com>
Date: Fri, 30 Aug 2013 00:45:57 +0000
Subject: [PATCH] Add more constant folding cases for min, max, step,
 smoothstep, mix, clamp, atan, and pow.

git-svn-id: https://cvs.khronos.org/svn/repos/ogl/trunk/ecosystem/public/sdk/tools/glslang@22903 e7fa87d3-cd2b-0410-9028-fcbf551c1848
---
 Test/baseResults/constFold.frag.out     | 199 +++++++++++++++---------
 Test/constFold.frag                     |  21 ++-
 glslang/MachineIndependent/Constant.cpp | 147 ++++++++++++-----
 3 files changed, 256 insertions(+), 111 deletions(-)

diff --git a/Test/baseResults/constFold.frag.out b/Test/baseResults/constFold.frag.out
index e8f180ecb..343077f61 100644
--- a/Test/baseResults/constFold.frag.out
+++ b/Test/baseResults/constFold.frag.out
@@ -1,81 +1,134 @@
 0:? Sequence
-0:26  Function Definition: main( (void)
-0:26    Function Parameters: 
-0:28    Sequence
-0:28      Sequence
-0:28        move second child to first child (4-component vector of float)
-0:28          'dx' (4-component vector of float)
-0:28          dPdx (4-component vector of float)
-0:28            'inv' (smooth in 4-component vector of float)
-0:35      move second child to first child (4-component vector of float)
-0:35        'FragColor' (out 4-component vector of float)
-0:35        2.000000
-0:35        6.000000
-0:35        3.000000
-0:35        171.887339
-0:40      move second child to first child (4-component vector of float)
-0:40        'FragColor' (out 4-component vector of float)
-0:40        3.000000
-0:40        2.000000
-0:40        0.001593
-0:40        -0.999999
-0:41      move second child to first child (2-component vector of float)
-0:41        'out2' (out 2-component vector of float)
-0:41        5.600000
-0:41        5.800000
+0:28  Function Definition: main( (void)
+0:28    Function Parameters: 
+0:30    Sequence
+0:30      Sequence
+0:30        move second child to first child (4-component vector of float)
+0:30          'dx' (4-component vector of float)
+0:30          dPdx (4-component vector of float)
+0:30            'inv' (smooth in 4-component vector of float)
+0:37      move second child to first child (4-component vector of float)
+0:37        'FragColor' (out 4-component vector of float)
+0:37        2.000000
+0:37        6.000000
+0:37        3.000000
+0:37        171.887339
 0:42      move second child to first child (4-component vector of float)
-0:42        'out3' (out 4-component vector of float)
-0:42        20.085537
-0:42        2.302585
-0:42        16.000000
-0:42        8.000000
-0:43      move second child to first child (4-component vector of float)
-0:43        'out4' (out 4-component vector of float)
-0:43        10.000000
-0:43        0.100000
-0:43        4.700000
-0:43        10.900000
-0:44      move second child to first child (4-component vector of int)
-0:44        'out5' (out 4-component vector of int)
-0:44        8 (const int)
-0:44        17 (const int)
-0:44        -1 (const int)
-0:44        1 (const int)
-0:45      move second child to first child (3-component vector of float)
-0:45        'out6' (out 3-component vector of float)
-0:45        -1.000000
-0:45        1.000000
-0:45        0.000000
-0:46      move second child to first child (4-component vector of float)
-0:46        'out7' (out 4-component vector of float)
-0:46        4.000000
-0:46        -4.000000
-0:46        5.000000
-0:46        -5.000000
-0:47      move second child to first child (4-component vector of float)
-0:47        'out8' (out 4-component vector of float)
-0:47        4.000000
-0:47        5.000000
-0:47        4.000000
-0:47        -6.000000
+0:42        'FragColor' (out 4-component vector of float)
+0:42        3.000000
+0:42        2.000000
+0:42        0.001593
+0:42        -0.999999
+0:43      move second child to first child (2-component vector of float)
+0:43        'out2' (out 2-component vector of float)
+0:43        5.600000
+0:43        5.800000
+0:44      move second child to first child (4-component vector of float)
+0:44        'out3' (out 4-component vector of float)
+0:44        20.085537
+0:44        2.302585
+0:44        16.000000
+0:44        8.000000
+0:45      move second child to first child (4-component vector of float)
+0:45        'out4' (out 4-component vector of float)
+0:45        10.000000
+0:45        0.100000
+0:45        4.700000
+0:45        10.900000
+0:46      move second child to first child (4-component vector of int)
+0:46        'out5' (out 4-component vector of int)
+0:46        8 (const int)
+0:46        17 (const int)
+0:46        -1 (const int)
+0:46        1 (const int)
+0:47      move second child to first child (3-component vector of float)
+0:47        'out6' (out 3-component vector of float)
+0:47        -1.000000
+0:47        1.000000
+0:47        0.000000
 0:48      move second child to first child (4-component vector of float)
-0:48        'out9' (out 4-component vector of float)
-0:48        8.000000
+0:48        'out7' (out 4-component vector of float)
+0:48        4.000000
 0:48        -4.000000
-0:48        0.345000
-0:48        0.400000
+0:48        5.000000
+0:48        -5.000000
 0:49      move second child to first child (4-component vector of float)
-0:49        'out10' (out 4-component vector of float)
-0:49        1.000000
-0:49        1.000000
-0:49        0.000000
-0:49        0.000000
+0:49        'out8' (out 4-component vector of float)
+0:49        4.000000
+0:49        5.000000
+0:49        4.000000
+0:49        -6.000000
 0:50      move second child to first child (4-component vector of float)
-0:50        'out11' (out 4-component vector of float)
-0:50        0.000000
-0:50        0.000000
-0:50        1.000000
-0:50        0.000000
+0:50        'out9' (out 4-component vector of float)
+0:50        8.000000
+0:50        -4.000000
+0:50        0.345000
+0:50        0.400000
+0:51      move second child to first child (4-component vector of float)
+0:51        'out10' (out 4-component vector of float)
+0:51        1.000000
+0:51        1.000000
+0:51        0.000000
+0:51        0.000000
+0:52      move second child to first child (4-component vector of float)
+0:52        'out11' (out 4-component vector of float)
+0:52        0.000000
+0:52        0.000000
+0:52        1.000000
+0:52        0.000000
+0:53      move second child to first child (4-component vector of float)
+0:53        'out11' (out 4-component vector of float)
+0:53        1.029639
+0:53        0.799690
+0:53        0.674741
+0:53        1.570696
+0:54      move second child to first child (4-component vector of float)
+0:54        'out11' (out 4-component vector of float)
+0:54        0.000000
+0:54        0.523599
+0:54        1.570796
+0:54        1.047198
+0:58      move second child to first child (4-component vector of float)
+0:58        'out11' (out 4-component vector of float)
+0:58        1.373401
+0:58        0.000000
+0:58        0.896055
+0:58        -0.380506
+0:62      move second child to first child (2-component vector of int)
+0:62        'out12' (out 2-component vector of int)
+0:62        15 (const int)
+0:62        16 (const int)
+0:63      move second child to first child (2-component vector of int)
+0:63        'out12' (out 2-component vector of int)
+0:63        17 (const int)
+0:63        17 (const int)
+0:64      move second child to first child (2-component vector of float)
+0:64        'out2' (out 2-component vector of float)
+0:64        871.421253
+0:64        4913.000000
+0:65      move second child to first child (3-component vector of uint)
+0:65        'out13' (out 3-component vector of uint)
+0:65        10 (const uint)
+0:65        20 (const uint)
+0:65        30 (const uint)
+0:66      move second child to first child (2-component vector of float)
+0:66        'out2' (out 2-component vector of float)
+0:66        3.000000
+0:66        6.000000
+0:67      move second child to first child (2-component vector of float)
+0:67        'out2' (out 2-component vector of float)
+0:67        3.500000
+0:67        4.500000
+0:68      move second child to first child (2-component vector of float)
+0:68        'out2' (out 2-component vector of float)
+0:68        0.000000
+0:68        1.000000
+0:69      move second child to first child (4-component vector of float)
+0:69        'out11' (out 4-component vector of float)
+0:69        0.000000
+0:69        0.028000
+0:69        0.500000
+0:69        1.000000
 0:?   Linker Objects
 0:?     'inv' (smooth in 4-component vector of float)
 0:?     'FragColor' (out 4-component vector of float)
@@ -89,4 +142,6 @@
 0:?     'out9' (out 4-component vector of float)
 0:?     'out10' (out 4-component vector of float)
 0:?     'out11' (out 4-component vector of float)
+0:?     'out12' (out 2-component vector of int)
+0:?     'out13' (out 3-component vector of uint)
 
diff --git a/Test/constFold.frag b/Test/constFold.frag
index 61e5f3b14..0769390c6 100644
--- a/Test/constFold.frag
+++ b/Test/constFold.frag
@@ -21,7 +21,9 @@ out vec4 out7;
 out vec4 out8;
 out vec4 out9;
 out vec4 out10;
-out vec4 out11;
+out vec4 out11; 
+out ivec2 out12;
+out uvec3 out13;
 
 void main()
 {
@@ -48,4 +50,21 @@ void main()
     out9 = vec4(roundEven(7.5), roundEven(-4.5), fract(2.345), fract(-2.6)); // 8, -4, .345, 0.4
     out10 = vec4(isinf(4.0/0.0), isinf(-3.0/0.0), isinf(0.0/0.0), isinf(-93048593405938405938405.0));  // true, true, false, false -> 1.0, 1.0, 0.0, 0.0
     out11 = vec4(isnan(4.0/0.0), isnan(-3.0/0.0), isnan(0.0/0.0), isnan(-93048593405938405938405.0));  // false, false, true, false -> 0.0, 1.0, 0.0, 0.0
+    out11 = vec4(tan(0.8), atan(1.029), atan(8.0, 10.0), atan(10000.0));                               // 1.029, 0.8, 0.6747, 1.57
+    out11 = vec4(asin(0.0), asin(0.5), acos(0.0), acos(0.5));                                          // 0.0, .523599, 1.57, 1.047
+
+    const vec4 v1 = vec4(1.0, 0.0, 0.5, -0.2);
+    const vec4 v2 = vec4(0.2, 0.3, 0.4, 0.5);
+    out11 = atan(v1, v2);                      // 1.373401, 0.0, 0.896055, -0.380506
+
+    const ivec2 v3 = ivec2(15.0, 17.0);
+    const ivec2 v4 = ivec2(17.0, 15.0);
+    out12 = min(v3, 16);                      // 15, 16
+    out12 = max(v3, v4);                      // 17, 17
+    out2 = pow(vec2(v3), vec2(2.5, 3.0));     // 871.4, 4913
+    out13 = clamp(uvec3(1, 20, 50), 10u, 30u);  // 10, 20, 30
+    out2 = mix(vec2(3.0, 4.0), vec2(5.0, 6.0), bvec2(false, true));  // 3.0, 6.0
+    out2 = mix(vec2(3.0, 4.0), vec2(5.0, 6.0), 0.25);  // 3.5, 4.5
+    out2 = step(0.5, vec2(0.2, 0.6));                  // 0.0, 1.0
+    out11 = smoothstep(50.0, 60.0, vec4(40.0, 51.0, 55.0, 70.0)); // 0.0, 0.028, 0.5, 1.0
 }
diff --git a/glslang/MachineIndependent/Constant.cpp b/glslang/MachineIndependent/Constant.cpp
index dcccf164d..174aa55a9 100644
--- a/glslang/MachineIndependent/Constant.cpp
+++ b/glslang/MachineIndependent/Constant.cpp
@@ -373,7 +373,7 @@ TIntermTyped* TIntermConstantUnion::fold(TOperator op, TIntermTyped* constantNod
         break;
 
     default:
-        infoSink.info.message(EPrefixInternalError, "Invalid operator for constant folding", getLoc());
+        infoSink.info.message(EPrefixInternalError, "Invalid binary operator for constant folding", getLoc());
 
         return 0;
     }
@@ -633,21 +633,26 @@ TIntermTyped* TIntermediate::fold(TIntermAggregate* aggrNode)
 
     // First, see if this is an operation to constant fold, kick out if not,
     // see what size the result is if so.
+
+    bool componentwise = false;  // will also say componentwise if a scalar argument gets repeated to make per-component results
     int objectSize;
     switch (aggrNode->getOp()) {
+    case EOpAtan:
+    case EOpPow:
     case EOpMin:
     case EOpMax:
+    case EOpMix:
+    case EOpClamp:
+        componentwise = true;
+        objectSize = children[0]->getAsConstantUnion()->getType().getObjectSize();
+        break;
+    case EOpCross:
     case EOpReflect:
     case EOpRefract:
     case EOpFaceForward:
-    case EOpAtan:
-    case EOpPow:
-    case EOpClamp:
-    case EOpMix:
-    case EOpDistance:
-    case EOpCross:
         objectSize = children[0]->getAsConstantUnion()->getType().getObjectSize();
         break;
+    case EOpDistance:
     case EOpDot:
         objectSize = 1;
         break;
@@ -656,10 +661,12 @@ TIntermTyped* TIntermediate::fold(TIntermAggregate* aggrNode)
                      children[1]->getAsTyped()->getType().getVectorSize();
         break;
     case EOpStep:
+        componentwise = true;
         objectSize = std::max(children[0]->getAsTyped()->getType().getVectorSize(),
                               children[1]->getAsTyped()->getType().getVectorSize());
         break;
     case EOpSmoothStep:
+        componentwise = true;
         objectSize = std::max(children[0]->getAsTyped()->getType().getVectorSize(),
                               children[2]->getAsTyped()->getType().getVectorSize());
         break;
@@ -669,44 +676,108 @@ TIntermTyped* TIntermediate::fold(TIntermAggregate* aggrNode)
     TConstUnion* newConstArray = new TConstUnion[objectSize];
 
     TVector<TConstUnion*> childConstUnions;
-    for (unsigned int i = 0; i < children.size(); ++i)
-        childConstUnions.push_back(children[i]->getAsConstantUnion()->getUnionArrayPointer());
+    for (unsigned int arg = 0; arg < children.size(); ++arg)
+        childConstUnions.push_back(children[arg]->getAsConstantUnion()->getUnionArrayPointer());
 
     // Second, do the actual folding
 
-    // TODO: Functionality: constant folding: separate component-wise from non-component-wise
-    switch (aggrNode->getOp()) {
-    case EOpMin:
-    case EOpMax:
-        for (int i = 0; i < objectSize; i++) {
-            if (aggrNode->getOp() == EOpMax)
-                newConstArray[i].setDConst(std::max(childConstUnions[0]->getDConst(), childConstUnions[1]->getDConst()));
-            else
-                newConstArray[i].setDConst(std::min(childConstUnions[0]->getDConst(), childConstUnions[1]->getDConst()));
+    bool isFloatingPoint = children[0]->getAsTyped()->getBasicType() == EbtFloat ||
+                           children[0]->getAsTyped()->getBasicType() == EbtDouble;
+    bool isSigned = children[0]->getAsTyped()->getBasicType() == EbtInt;
+    if (componentwise) {
+        for (int comp = 0; comp < objectSize; comp++) {
+
+            // some arguments are scalars instead of matching vectors; simulate a smear
+            int arg0comp = std::min(comp, children[0]->getAsTyped()->getType().getVectorSize() - 1);
+            int arg1comp;
+            if (children.size() > 1)
+                arg1comp = std::min(comp, children[1]->getAsTyped()->getType().getVectorSize() - 1);
+            int arg2comp;
+            if (children.size() > 2)
+                arg2comp = std::min(comp, children[2]->getAsTyped()->getType().getVectorSize() - 1);
+
+            switch (aggrNode->getOp()) {
+            case EOpAtan:
+                newConstArray[comp].setDConst(atan2(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
+                break;
+            case EOpPow:
+                newConstArray[comp].setDConst(pow(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
+                break;
+            case EOpMin:
+                if (isFloatingPoint)
+                    newConstArray[comp].setDConst(std::min(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
+                else if (isSigned)
+                    newConstArray[comp].setIConst(std::min(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst()));
+                else
+                    newConstArray[comp].setUConst(std::min(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst()));
+                break;
+            case EOpMax:
+                if (isFloatingPoint)
+                    newConstArray[comp].setDConst(std::max(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()));
+                else if (isSigned)
+                    newConstArray[comp].setIConst(std::max(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst()));
+                else
+                    newConstArray[comp].setUConst(std::max(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst()));
+                break;
+            case EOpClamp:
+                if (isFloatingPoint)
+                    newConstArray[comp].setDConst(std::min(std::max(childConstUnions[0][arg0comp].getDConst(), childConstUnions[1][arg1comp].getDConst()), 
+                                                                                                               childConstUnions[2][arg2comp].getDConst()));
+                else if (isSigned)
+                    newConstArray[comp].setIConst(std::min(std::max(childConstUnions[0][arg0comp].getIConst(), childConstUnions[1][arg1comp].getIConst()), 
+                                                                                                               childConstUnions[2][arg2comp].getIConst()));
+                else
+                    newConstArray[comp].setUConst(std::min(std::max(childConstUnions[0][arg0comp].getUConst(), childConstUnions[1][arg1comp].getUConst()), 
+                                                                                                               childConstUnions[2][arg2comp].getUConst()));
+                break;
+            case EOpMix:
+                if (children[2]->getAsTyped()->getBasicType() == EbtBool)
+                    newConstArray[comp].setDConst(childConstUnions[2][arg2comp].getBConst() ? childConstUnions[1][arg1comp].getDConst() :
+                                                                                              childConstUnions[0][arg0comp].getDConst());
+                else
+                    newConstArray[comp].setDConst(childConstUnions[0][arg0comp].getDConst() * (1.0 - childConstUnions[2][arg2comp].getDConst()) +
+                                                  childConstUnions[1][arg1comp].getDConst() *        childConstUnions[2][arg2comp].getDConst());
+                break;
+            case EOpStep:
+                newConstArray[comp].setDConst(childConstUnions[1][arg1comp].getDConst() < childConstUnions[0][arg0comp].getDConst() ? 0.0 : 1.0);
+                break;
+            case EOpSmoothStep:
+            {
+                double t = (childConstUnions[2][arg2comp].getDConst() - childConstUnions[0][arg0comp].getDConst()) / 
+                           (childConstUnions[1][arg1comp].getDConst() - childConstUnions[0][arg0comp].getDConst());
+                if (t < 0.0)
+                    t = 0.0;
+                if (t > 1.0)
+                    t = 1.0;
+                newConstArray[comp].setDConst(t * t * (3.0 - 2.0 * t));
+                break;
+            }
+            default:
+                infoSink.info.message(EPrefixInternalError, "componentwise constant folding operation not implemented", aggrNode->getLoc());
+                return aggrNode;
+            }
         }
-        break;
+    } else {
+        // Non-componentwise...
 
-    // TODO: Functionality: constant folding: the rest of the ops have to be fleshed out
+        switch (aggrNode->getOp()) {
 
-    case EOpAtan:
-    case EOpPow:
-    case EOpModf:
-    case EOpClamp:
-    case EOpMix:
-    case EOpStep:
-    case EOpSmoothStep:
-    case EOpDistance:
-    case EOpDot:
-    case EOpCross:
-    case EOpFaceForward:
-    case EOpReflect:
-    case EOpRefract:
-    case EOpOuterProduct:
-        infoSink.info.message(EPrefixInternalError, "constant folding operation not implemented", aggrNode->getLoc());
-        return aggrNode;
+        // TODO: Functionality: constant folding: the rest of the ops have to be fleshed out
 
-    default:
-        return aggrNode;
+        case EOpModf:
+        case EOpDistance:
+        case EOpDot:
+        case EOpCross:
+        case EOpFaceForward:
+        case EOpReflect:
+        case EOpRefract:
+        case EOpOuterProduct:
+            infoSink.info.message(EPrefixInternalError, "constant folding operation not implemented", aggrNode->getLoc());
+            return aggrNode;
+
+        default:
+            return aggrNode;
+        }
     }
 
     TIntermConstantUnion *newNode = new TIntermConstantUnion(newConstArray, aggrNode->getType());
-- 
GitLab