Commit c0e3d0a7 authored by Tong Shen's avatar Tong Shen Committed by TensorFlower Gardener
Browse files

Avoid quadratic node explosion in FunctionalizeCond.

Before this change, if we have the following tf.cond:

  def computation(p, a, b, c, d):
    def f():
      return (a, a + b, a + b + c, a + b + c + d)
    return tf.cond(p, f, f)

We will create one "If" node for each Merge node, because those Merge nodes have different sets of ancestor Switch nodes. Merge node for first return value will have "Switch for input a", Merge node for second return value will have "Switch for input a & Switch for input b", and so on. As a result, the total number of If nodes is the number of nodes, and total number of nodes in If's branch functions is quadratic.

In this CL, we add predicate of the Switch nodes in ancestor id. For the previous example, now we only generate 1 If node.

PiperOrigin-RevId: 231844072
parent 25062cbb
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment