[XLA] Use generic fusion's multi-output fusion support.
The main benefit here is a 6% speedup on MLPERF ncf. BEGIN_PUBLIC [XLA] Fix multioutput fusion support in normal instruction fusion. END_PUBLIC ** End-to-end macrobenchmarks ** NOTE: conv_draw runs on a single core on the benchmark without cross replica sum. A replicated version of conv_draw would not create the fusion nodes found here due to the cross replica sum. As a result, I do not expect a performance hit in typical production use cases. Speedups (before / after): INFEED: 0.83x (before 1033.5 ms, after 1250.9 ms): openai_v0_rnn_optimized_df WINDOW SELECTION: 0.89x (before 8517.4 ms, after 9537.7 ms): conv_draw_batch_128_df WINDOW SELECTION: 0.91x (before 15847.4 ms, after 17435.4 ms): conv_draw_batch_128_jf MORE REMAT: 0.93x (before 5367.3 ms, after 5787.3 ms): deeplab_training_img513_depth101_batch_8_jf WINDOW SELECTION: 0.93x (before 5078.1 ms, after 5453.6 ms): conv_draw_batch_16_df WINDOW SELCTION: 0.98x (before 7886.3 ms, after 8052.7 ms): conv_draw_batch_16_jf 0.98x (before 36499.0 us, after 37218.0 us): magenta_jf 0.99x (before 10003.2 ms, after 10124.5 ms): faster_rcnn_image640_batch_8_df 0.99x (before 7618.2 ms, after 7687.7 ms): alphachess_rec_df 0.99x (before 74307.0 us, after 74846.0 us): translate_inference_jf 1.01x (before 29668.0 us, after 29515.0 us): translate_inference_hybrid_df 1.01x (before 795.7 ms, after 791.5 ms): deleuze_learn_jf 1.01x (before 438.1 ms, after 435.6 ms): wavernn_train_jf 1.01x (before 137.9 ms, after 137.0 ms): resnet_v2_50_batch_16_df 1.01x (before 13349.9 ms, after 13263.0 ms): nmt_32k_large_df 1.01x (before 20215.9 ms, after 20083.0 ms): nmt_32k_large_jf 1.01x (before 2602.9 ms, after 2585.3 ms): openai_v0_rnn_natural_df 1.01x (before 539.1 ms, after 535.2 ms): deleuze_learn_df 1.01x (before 3831.9 ms, after 3801.5 ms): nmt_transformer_big_adafactor_with_model_split_jf 1.01x (before 6510.3 ms, after 6456.9 ms): nmt_model_parallel_jf 1.01x (before 1538.1 ms, after 1525.2 ms): babelfish_transformer_big_adafactor_with_model_split_jf 1.01x (before 2384.5 ms, after 2362.3 ms): nmt_transformer_big_adafactor_with_model_split_df 1.01x (before 6567.9 ms, after 6503.3 ms): nmt_single_tpu_jf 1.01x (before 285.7 ms, after 282.8 ms): wavenet_train_ar_jf 1.01x (before 54764.0 us, after 54203.0 us): resnet_imagenet_jf 1.01x (before 193.0 ms, after 190.9 ms): wavenet_train_ar_df 1.01x (before 46145.0 us, after 45632.0 us): translate_inference_hybrid_jf 1.01x (before 4516.0 ms, after 4462.8 ms): nmt_model_parallel_df 1.01x (before 2071.6 ms, after 2045.9 ms): babelfish_jf 1.01x (before 61792.0 us, after 60928.0 us): inception_v3_batch_8_train_jf 1.01x (before 4566.1 ms, after 4499.8 ms): nmt_single_tpu_df 1.02x (before 6090.6 ms, after 5999.2 ms): alphastar_df 1.02x (before 233.6 ms, after 230.1 ms): nmt_model_char_v1_df 1.02x (before 39605.0 us, after 38971.0 us): resnet_imagenet_df 1.02x (before 29482.0 us, after 28982.0 us): inception_v2_batch_8_train_jf 1.02x (before 571.6 ms, after 561.8 ms): nmt_model_char_v0_df 1.02x (before 1173.0 us, after 1153.0 us): smartcompose_single_step_jf 1.02x (before 11952.1 ms, after 11735.3 ms): faster_rcnn_image640_batch_8_jf 1.02x (before 22885.0 us, after 22449.0 us): inception_v2_batch_8_train_df 1.02x (before 31350.0 us, after 30720.0 us): magenta_df 1.02x (before 10627.0 ms, after 10397.1 ms): alphastar_jf 1.02x (before 44625.0 us, after 43601.0 us): inception_v3_batch_8_train_df 1.02x (before 44470.0 us, after 43410.0 us): translate_inference_df 1.02x (before 699.0 us, after 682.0 us): smartcompose_single_step_df 1.03x (before 102.0 ms, after 99217.0 us): magenta_dynamic_jf 1.03x (before 93749.0 us, after 91198.0 us): magenta_dynamic_df 1.85x (before 5873.3 ms, after 3177.0 ms): object_detection_ssd_mobilenet_v1_300x300_batch_128_df Geomean before: 680.7 ms Geomean after: 677.9 ms Geomean speedup: 1.00x (geomean before / geomean after) ** HBM usage results ** Size reduction (before / after): 0.92x (before 5.92 GiB, after 6.42 GiB): unet_training_jf 0.94x (before 1.90 GiB, after 2.02 GiB): babelfish_char2feats02_jf 0.94x (before 1.99 GiB, after 2.11 GiB): babelfish_char2feats01_jf 0.97x (before 1.45 GiB, after 1.50 GiB): nmt_model_parallel_df 0.97x (before 1.45 GiB, after 1.50 GiB): nmt_model_parallel_jf 0.98x (before 1.59 GiB, after 1.62 GiB): conv_draw_batch_128_jf 0.98x (before 1.62 GiB, after 1.65 GiB): conv_draw_batch_128_df 0.99x (before 3.29 GiB, after 3.34 GiB): nmt_32k_large_df 0.99x (before 3.29 GiB, after 3.34 GiB): nmt_32k_large_jf 0.99x (before 3.34 GiB, after 3.39 GiB): deleuze_learn_df 0.99x (before 3.35 GiB, after 3.39 GiB): deleuze_learn_jf 0.99x (before 1.18 GiB, after 1.19 GiB): resnet_imagenet_df 0.99x (before 1.18 GiB, after 1.19 GiB): resnet_imagenet_jf 1.01x (before 8.02 GiB, after 7.98 GiB): retinanet_training_img1024_depth50_batch_8_bf16_df 1.01x (before 1.15 GiB, after 1.14 GiB): inception_v3_batch_8_train_jf 1.01x (before 1.15 GiB, after 1.14 GiB): inception_v3_batch_8_train_df 1.01x (before 8.63 GiB, after 8.55 GiB): retinanet_training_img768_depth50_batch_8_df 1.01x (before 6.09 GiB, after 6.02 GiB): retinanet_training_img640_batch_8_jf 1.01x (before 77.1 MiB, after 76.2 MiB): magenta_jf 1.01x (before 6.63 GiB, after 6.55 GiB): retinanet_training_img768_depth50_batch_8_jf 1.01x (before 6.09 GiB, after 6.01 GiB): retinanet_training_img640_batch_8_df 1.01x (before 76.8 MiB, after 75.7 MiB): magenta_df 1.02x (before 8.62 GiB, after 8.45 GiB): inception_v3_batch_128_train_df 1.02x (before 241 MiB, after 237 MiB): alphachess_rec_df 1.02x (before 6.47 GiB, after 6.34 GiB): inception_v3_batch_128_train_jf 1.03x (before 1.08 GiB, after 1.05 GiB): nmt_model_char_v0_jf 1.03x (before 1.08 GiB, after 1.05 GiB): nmt_model_char_v0_df 1.03x (before 5.19 GiB, after 5.04 GiB): alphastar_jf 1.04x (before 5.47 GiB, after 5.25 GiB): alphastar_df 1.04x (before 698 MiB, after 668 MiB): conv_draw_batch_16_df 1.05x (before 696 MiB, after 665 MiB): conv_draw_batch_16_jf 1.05x (before 6.39 GiB, after 6.07 GiB): deeplab_training_img513_depth101_batch_8_jf 1.07x (before 243 MiB, after 227 MiB): alphachess_rec_jf Geomean before: 1.03 GiB Geomean after: 1.03 GiB Geomean size reduction: 1.00x (geomean before / geomean after) *** Reason for rollback *** Doesn't break speech models anymore *** Original change description *** Automated rollback of changelist 204006812 PiperOrigin-RevId: 219905428
Loading
Please sign in to comment