Set device on Identity ops representing inlined function inputs/outputs.
This fixes a strange multi-GPU performance issue where the identities from inlined cond_v2 branch functions (part of the control flow lowering pass) were being placed alone on a GPU, creating unnecessary device traffic. Ideally the placer wouldn't be this stupid or grappler would remove the useless identities, but this fix is easier for now. This also adds a simple benchmark testing nested function calls, which currently are inlined. This benchmark doesn't expose the original problem (all of the ops, including the identities added by inlining, are placed on a single device regardless), but I've included it in case it catches any future regressions. I also increased the number of iterations, as I found this made the results more stable. PiperOrigin-RevId: 228619628
Loading
Please sign in to comment