Handle fetching uninitialized tensors from TensorLists.
This is needed to support use-cases of the form: c = constant([1., 2.]) l = tensor_list_from_tensor(c, []) t = tensor_list_get_item(l, 0) dt = gradients(t, c) In the above example gradient of the TensorListFromTensor op, a TensorListStack op, would attempt to stack a list with DT_INVALID type tensors. Since we only read index 0, the gradient list has only index 0 set. This change replaces the DT_INVALID tensors with zeros of the appropriate shape and type. This change adds an `element_shape` input to TensorListGetItem, TensorListPopBack, TensorListStack and TensorListGather ops. This is the shape used to build a zeros tensor for a DT_INVALID tensor in the list. Returning zeros for DT_VARIANT (nested lists) type tensors is not supported yet. I will add support for TensorListConcat in a follow-up change. PiperOrigin-RevId: 230362520
Loading
Please sign in to comment