Checkpoint side-effecting assignments inside staged TensorFlow conditional branches.
Conditionals that branch on Tensors are staged as tf.cond operations. To construct the computation graph, both branches of the conditional are executed when tf.cond is called. As a result, a Python side-effecting assignment (e.g. o.x = b) in one branch of the staged code may incorrectly influence the results of the other branch (see the example below). To restore the expected semantics, we checkpoint composite symbols (e.g. o.x): we record their values before executing each branch and restore those values after executing it.
This ensures the staged conditional returns the correct values.
For example:
  class Foo(object):
    def __init__(self):
      self.b = 0

  @tf.function()
  def bar(x, condition):
    if condition:
      x.b = 1
      print(x.b)  # Output: 1
    else:
      print(x.b)  # Output: 1 (incorrect; the correct output should be 0)
    return x.b

  foo = Foo()
  bar(foo, tf.constant(True))  # returns 1

  foo = Foo()
  bar(foo, tf.constant(False))  # returns 0
Before this fix, both calls to bar would have returned 1, because the initial value was overwritten while tracing the true branch, even though that branch would never execute in eager mode.
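The checkpoint-and-restore idea can be illustrated without TensorFlow. The following is a minimal plain-Python sketch, not the actual AutoGraph implementation: trace_branch and staged_cond are hypothetical helpers invented for this illustration, and staged_cond only imitates the fact that both branch bodies run during graph construction.

```python
class Foo:
    def __init__(self):
        self.b = 0


def trace_branch(obj, attr, branch):
    """Hypothetical helper: run `branch` once without leaking its side effects."""
    saved = getattr(obj, attr)       # checkpoint before the branch body runs
    try:
        branch()                     # the body may assign to obj.<attr>
        return getattr(obj, attr)    # the value this branch would produce
    finally:
        setattr(obj, attr, saved)    # restore so the next branch sees the original


def staged_cond(pred, obj, attr, true_fn, false_fn):
    """Hypothetical stand-in for a staged conditional: both branches always run."""
    true_value = trace_branch(obj, attr, true_fn)
    false_value = trace_branch(obj, attr, false_fn)
    # Imitates the staged op selecting one branch's output at run time.
    result = true_value if pred else false_value
    setattr(obj, attr, result)
    return result


def bar(x, condition):
    def true_fn():
        x.b = 1

    def false_fn():
        pass  # leaves x.b unchanged

    return staged_cond(condition, x, "b", true_fn, false_fn)


print(bar(Foo(), True))   # prints 1
print(bar(Foo(), False))  # prints 0
```

Without the restore step in trace_branch, the assignment made while running true_fn would still be visible when false_fn runs, which is the behavior described above.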
PiperOrigin-RevId: 238003482