[TF:XLA] Fix the AR/CRS combiner to avoid crashing when two cross-module...
[TF:XLA] Fix the AR/CRS combiner to avoid crashing when two cross-module AllReduces lead to the same cross-replica AllReduce.
Also, add tests RewriteMultipleAdds and RewriteArSubtractCrs to document existing behavior of the pass.
This graph:
A
|
AR B C
\ / |
+ AR
\ /
+
|
CRS
gets rewritten to:
B const C
\ / |
A Div AR const
\ / | /
+ Div
\ /
+
|
CRS
Ideally, we would remove both cross-module AllReduces. It's not not straightforward to do that, so I'm leaving it for a separate CL.
PiperOrigin-RevId: 230472380
Loading
Please sign in to comment