Commit 728c05d1 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TensorFlower Gardener
Browse files

Checkpoint side-effecting assignments inside staged TensorFlow conditional branch.

Conditionals that branch on Tensors are staged as tf.cond operations. To construct the computation graph, both branches of the conditional are executed when tf.cond is called. As a result a Python side-effecting assignment (e.g. o.x = b) in one branch of staged code may incorrectly influence the results of the other branch (see below). To restore expected semantics, we checkpoint composite symbols (e.g. o.x) recording the value before executing each branch, then restore those values after executing it.

This ensures the staged conditional returns the correct values.

For example:

  class Foo(object):
    def __init__(self):
      self.b = 0

  @tf.function()
  def bar(x, condition):

    if condition:
      x.b = 1
      print(x.b)  # Output: 1
    else:
      print(x.b)  # Output: 1 (incorrect; the correct output should be 0)

    return x.b

  foo = Foo()
  bar(foo, tf.constant(True)) # returns 1

  foo = Foo()
  bar(foo, tf.constant(False)) # returns 0

Before this fix both calls to bar would have returned 1 (as the initial value was overwritten when tracing the true branch even when it would never execute in eager mode).

PiperOrigin-RevId: 238003482
parent 8cb1aa18
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment