Avoid quadratic node explosion in FunctionalizeCond.
Before this change, if we have the following tf.cond:
def computation(p, a, b, c, d):
def f():
return (a, a + b, a + b + c, a + b + c + d)
return tf.cond(p, f, f)
We will create one "If" node for each Merge node, because those Merge nodes have different sets of ancestor Switch nodes. Merge node for first return value will have "Switch for input a", Merge node for second return value will have "Switch for input a & Switch for input b", and so on. As a result, the total number of If nodes is the number of nodes, and total number of nodes in If's branch functions is quadratic.
In this CL, we add predicate of the Switch nodes in ancestor id. For the previous example, now we only generate 1 If node.
PiperOrigin-RevId: 231844072
Loading
Please sign in to comment