Make DS InputIterator get_next() API behave correctly in the last partial...
Make DS InputIterator get_next() API behave correctly in the last partial batch case with multiple devices/workers. Now there are 3 cases for get_next(): 1. Each replica gets a full batch, the behavior is the same as before. 2. Some replicas get full batches, some get partial batches, and some get no data. get_next() will return a list with tensors from all replicas which include partial batch data and tensors with batch dimension 0 representing no data. 3. If there is no data in any replicas, an OutOfRange error will be triggered. PiperOrigin-RevId: 238491718
Loading
Please sign in to comment