Commit 7df1985f authored by Saurabh Saxena's avatar Saurabh Saxena Committed by TensorFlower Gardener
Browse files

Handle fetching uninitialized tensors from TensorLists.

This is needed to support use-cases of the form:

c = constant([1., 2.])
l = tensor_list_from_tensor(c, [])
t = tensor_list_get_item(l, 0)
dt = gradients(t, c)

In the above example gradient of the TensorListFromTensor op, a TensorListStack op, would attempt to stack a list with DT_INVALID type tensors. Since we only read index 0, the gradient list has only index 0 set.

This change replaces the DT_INVALID tensors with zeros of the appropriate shape and type. This change adds an `element_shape` input to TensorListGetItem, TensorListPopBack, TensorListStack and TensorListGather ops. This is the shape used to build a zeros tensor for a DT_INVALID tensor in the list. Returning zeros for DT_VARIANT (nested lists) type tensors is not supported yet.

I will add support for TensorListConcat in a follow-up change.

PiperOrigin-RevId: 230362520
parent 7d11353d
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment