Commit 76aa6cf9 authored by Dimitris Vardoulakis's avatar Dimitris Vardoulakis Committed by TensorFlower Gardener
Browse files

[TF:XLA] Fix the AR/CRS combiner to avoid crashing when two cross-module...

[TF:XLA] Fix the AR/CRS combiner to avoid crashing when two cross-module AllReduces lead to the same cross-replica AllReduce.
Also, add tests RewriteMultipleAdds and RewriteArSubtractCrs to document existing behavior of the pass.

This graph:

A
|
AR    B   C
 \   /    |
   +     AR
    \   /
      +
      |
     CRS

gets rewritten to:

    B  const  C
     \   /    |
 A    Div    AR  const
  \   /      |  /
    +       Div
     \     /
        +
        |
       CRS

Ideally, we would remove both cross-module AllReduces. It's not not straightforward to do that, so I'm leaving it for a separate CL.

PiperOrigin-RevId: 230472380
parent bba3f6a4
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment