Commit b5214cab authored by Tian Jin, committed by Martin Wicke

Add GPU implementation for tf.segment_sum. (#11630)

* Add GPU implementation for tf.segment_sum.
* Refactor segment sum to compute asynchronously.
* Add GPU tests and change the test datatype to meet the GPU kernel's input requirements.
* Add benchmarks.
* Benchmark results against baseline unsorted impl.
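
For readers unfamiliar with the op, the following is a minimal pure-Python sketch of what a sorted segment sum computes. It is a reference for the semantics only, not the GPU kernel added by this commit; the function name and the list-of-rows representation are assumptions made for illustration. Rows of the input that share a segment id are summed into one output row, and the ids must be sorted.

```python
def segment_sum(data, segment_ids):
    """Reference semantics only (hypothetical helper, not the GPU kernel).

    `data` is a list of equal-length rows and `segment_ids` holds one
    non-negative, sorted (non-decreasing) id per row.
    """
    if len(data) != len(segment_ids):
        raise ValueError("need exactly one segment id per row")
    if any(a > b for a, b in zip(segment_ids, segment_ids[1:])):
        raise ValueError("segment_ids must be sorted")
    if not data:
        return []
    if segment_ids[0] < 0:
        raise ValueError("segment ids must be non-negative")

    inner = len(data[0])                # inner dimension
    num_segments = segment_ids[-1] + 1  # output outer dimension
    out = [[0] * inner for _ in range(num_segments)]

    # Accumulate each input row into the output row named by its id.
    for row, seg in zip(data, segment_ids):
        for j, value in enumerate(row):
            out[seg][j] += value
    return out


# Outer dimension 4, inner dimension 2, output outer dimension 4.
data = [[1, 2], [3, 4], [5, 6], [7, 8]]
result = segment_sum(data, [0, 0, 1, 3])
print(result)
```

In this example id 2 never appears, so its output row stays all zeros. The unsorted variant used as the baseline differs in that it does not require the ids to be sorted.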

The columns are: data type, outer dimension, output outer dimension, inner dimension, and the difference in execution time relative to the baseline implementation, expressed as a percentage of the baseline's execution time (positive values mean faster execution than the baseline).

dtype   outer   output  inner   diff (%)
fp32    512     256     512     -11.8632
fp64    512     256     512     -10.1854
fp32    512     256     2048    -6.2147
fp64    512     256     2048    -0.106
fp32    512     256     8192    -0.0867
fp64    512     256     8192    3.6285
fp32    512     256     1120    -10.9163
fp64    512     256     1120    -1.3509
fp32    512     256     1215    -2.0428
fp64    512     256     1215    -7.884
fp32    512     256     1856    -7.3159
fp64    512     256     1856    -0.1011
fp32    512     256     1302    -3.7802
fp64    512     256     1302    4.6675
fp32    512     256     1329    -13.2275
fp64    512     256     1329    4.3505
fp32    512     256     1531    -7.5993
fp64    512     256     1531    -3.5371
fp32    512     256     1313    -5.0677
fp64    512     256     1313    -1.368
fp32    512     256     1672    -0.0907
fp64    512     256     1672    4.5809
fp32    512     256     1851    -10.2862
fp64    512     256     1851    2.3119
fp32    512     256     1584    -10.3406
fp64    512     256     1584    1.4481
fp32    512     64      512     -6.3544
fp64    512     64      512     -18.4343
fp32    512     64      2048    -9.6639
fp64    512     64      2048    0.0714
fp32    512     64      8192    1.2097
fp64    512     64      8192    10.4839
fp32    512     64      1120    -20.1102
fp64    512     64      1120    -5.0784
fp32    512     64      1215    -6.4061
fp64    512     64      1215    -14.1781
fp32    512     64      1856    2.4221
fp64    512     64      1856    -8.4205
fp32    512     64      1302    -10.2403
fp64    512     64      1302    -7.0577
fp32    512     64      1329    -10.515
fp64    512     64      1329    -4.4899
fp32    512     64      1531    -18.4045
fp64    512     64      1531    9.5982
fp32    512     64      1313    -17.0858
fp64    512     64      1313    -2.02
fp32    512     64      1672    -6.1997
fp64    512     64      1672    1.1616
fp32    512     64      1851    -2.3241
fp64    512     64      1851    -1.0585
fp32    512     64      1584    -7.2199
fp64    512     64      1584    -1.0865
fp32    512     16      512     -14.8754
fp64    512     16      512     -10.987
fp32    512     16      2048    -18.3725
fp64    512     16      2048    1.5949
fp32    512     16      8192    2.163
fp64    512     16      8192    12.5431
fp32    512     16      1120    2.2558
fp64    512     16      1120    -2.7135
fp32    512     16      1215    -12.7228
fp64    512     16      1215    11.0343
fp32    512     16      1856    -6.986
fp64    512     16      1856    7.0687
fp32    512     16      1302    -9.3881
fp64    512     16      1302    -8.2974
fp32    512     16      1329    -11.5103
fp64    512     16      1329    18.9707
fp32    512     16      1531    -14.3721
fp64    512     16      1531    9.6774
fp32    512     16      1313    -16.5546
fp64    512     16      1313    11.7528
fp32    512     16      1672    -10.3689
fp64    512     16      1672    15.1197
fp32    512     16      1851    -11.6021
fp64    512     16      1851    8.2983
fp32    512     16      1584    -13.6702
fp64    512     16      1584    9.4635
fp32    2048    1024    512     -5.6482
fp64    2048    1024    512     1.45
fp32    2048    1024    2048    -0.066
fp64    2048    1024    2048    3.6549
fp32    2048    1024    8192    4.3953
fp64    2048    1024    8192    5.0636
fp32    2048    1024    1120    2.3119
fp64    2048    1024    1120    3.4102
fp32    2048    1024    1215    1.6251
fp64    2048    1024    1215    2.4538
fp32    2048    1024    1856    1.4219
fp64    2048    1024    1856    5.2966
fp32    2048    1024    1302    0.4938
fp64    2048    1024    1302    3.6871
fp32    2048    1024    1329    2.7753
fp64    2048    1024    1329    4.2955
fp32    2048    1024    1531    1.8766
fp64    2048    1024    1531    4.4579
fp32    2048    1024    1313    0.6639
fp64    2048    1024    1313    4.5556
fp32    2048    1024    1672    1.1072
fp64    2048    1024    1672    3.8653
fp32    2048    1024    1851    1.1566
fp64    2048    1024    1851    3.6434
fp32    2048    1024    1584    0.7806
fp64    2048    1024    1584    4.3265
fp32    2048    256     512     -10.7236
fp64    2048    256     512     1.011
fp32    2048    256     2048    2.2321
fp64    2048    256     2048    12.2771
fp32    2048    256     8192    8.0287
fp64    2048    256     8192    15.4497
fp32    2048    256     1120    -8.1388
fp64    2048    256     1120    5.8003
fp32    2048    256     1215    1.709
fp64    2048    256     1215    12.4369
fp32    2048    256     1856    5.1844
fp64    2048    256     1856    14.2236
fp32    2048    256     1302    2.8457
fp64    2048    256     1302    10.5728
fp32    2048    256     1329    -2.547
fp64    2048    256     1329    12.1123
fp32    2048    256     1531    2.4946
fp64    2048    256     1531    12.2398
fp32    2048    256     1313    6.1621
fp64    2048    256     1313    9.857
fp32    2048    256     1672    2.176
fp64    2048    256     1672    9.8899
fp32    2048    256     1851    4.6307
fp64    2048    256     1851    15.0223
fp32    2048    256     1584    3.5238
fp64    2048    256     1584    10.3181
fp32    2048    64      512     -11.5325
fp64    2048    64      512     8.5141
fp32    2048    64      2048    0.6066
fp64    2048    64      2048    25.8166
fp32    2048    64      8192    15.5994
fp64    2048    64      8192    29.453
fp32    2048    64      1120    1.5933
fp64    2048    64      1120    17.1686
fp32    2048    64      1215    -11.8064
fp64    2048    64      1215    21.7897
fp32    2048    64      1856    3.3061
fp64    2048    64      1856    17.6379
fp32    2048    64      1302    -1.201
fp64    2048    64      1302    26.775
fp32    2048    64      1329    -1.377
fp64    2048    64      1329    23.6142
fp32    2048    64      1531    0.9212
fp64    2048    64      1531    16.7177
fp32    2048    64      1313    2.8448
fp64    2048    64      1313    26.824
fp32    2048    64      1672    1.5334
fp64    2048    64      1672    23.7874
fp32    2048    64      1851    0.1934
fp64    2048    64      1851    25.1446
fp32    2048    64      1584    -2.8748
fp64    2048    64      1584    22.3902
fp32    8192    4096    512     0.0512
fp64    8192    4096    512     2.8049
fp32    8192    4096    2048    3.6683
fp64    8192    4096    2048    5.7372
fp32    8192    4096    8192    6.2501
fp64    8192    4096    8192    5.6644
fp32    8192    4096    1120    3.4347
fp64    8192    4096    1120    5.9099
fp32    8192    4096    1215    4.0591
fp64    8192    4096    1215    6.2049
fp32    8192    4096    1856    4.5046
fp64    8192    4096    1856    5.9
fp32    8192    4096    1302    3.8744
fp64    8192    4096    1302    5.74
fp32    8192    4096    1329    3.9169
fp64    8192    4096    1329    6.302
fp32    8192    4096    1531    5.0479
fp64    8192    4096    1531    6.048
fp32    8192    4096    1313    3.5261
fp64    8192    4096    1313    6.0544
fp32    8192    4096    1672    4.6081
fp64    8192    4096    1672    5.2568
fp32    8192    4096    1851    4.2022
fp64    8192    4096    1851    6.0934
fp32    8192    4096    1584    3.3852
fp64    8192    4096    1584    5.6772
fp32    8192    1024    512     3.7405
fp64    8192    1024    512     16.4627
fp32    8192    1024    2048    8.3918
fp64    8192    1024    2048    18.5254
fp32    8192    1024    8192    13.7773
fp64    8192    1024    8192    17.4314
fp32    8192    1024    1120    6.2023
fp64    8192    1024    1120    16.689
fp32    8192    1024    1215    9.5441
fp64    8192    1024    1215    19.7246
fp32    8192    1024    1856    9.864
fp64    8192    1024    1856    18.2895
fp32    8192    1024    1302    7.3145
fp64    8192    1024    1302    19.8528
fp32    8192    1024    1329    9.6131
fp64    8192    1024    1329    19.5526
fp32    8192    1024    1531    8.9847
fp64    8192    1024    1531    20.3696
fp32    8192    1024    1313    7.2819
fp64    8192    1024    1313    20.5361
fp32    8192    1024    1672    11.8095
fp64    8192    1024    1672    18.3047
fp32    8192    1024    1851    12.1042
fp64    8192    1024    1851    21.8124
fp32    8192    1024    1584    9.6549
fp64    8192    1024    1584    18.1818
fp32    8192    256     512     8.2649
fp64    8192    256     512     20.9372
fp32    8192    256     2048    15.6297
fp64    8192    256     2048    35.6407
fp32    8192    256     8192    21.7055
fp64    8192    256     8192    37.225
fp32    8192    256     1120    8.322
fp64    8192    256     1120    33.6497
fp32    8192    256     1215    12.9148
fp64    8192    256     1215    40.0554
fp32    8192    256     1856    12.2226
fp64    8192    256     1856    36.2642
fp32    8192    256     1302    12.2956
fp64    8192    256     1302    40.4711
fp32    8192    256     1329    10.2045
fp64    8192    256     1329    38.4891
fp32    8192    256     1531    14.9187
fp64    8192    256     1531    40.7874
fp32    8192    256     1313    9.5106
fp64    8192    256     1313    42.1367
fp32    8192    256     1672    15.2577
fp64    8192    256     1672    36.7527
fp32    8192    256     1851    15.668
fp64    8192    256     1851    40.2035
fp32    8192    256     1584    14.126
fp64    8192    256     1584    32.7602
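
As a reading aid for the last column, here is a small sketch of how such a percentage could be derived from two timings. It follows the column description above (difference relative to the baseline, positive meaning faster); the function name and the sample timings are hypothetical and are not taken from the benchmark run.

```python
def percent_faster(baseline_time, new_time):
    """Time saved relative to the baseline, as a percentage of baseline time.

    Positive: the new implementation is faster. Negative: it is slower.
    """
    if baseline_time <= 0:
        raise ValueError("baseline_time must be positive")
    return (baseline_time - new_time) / baseline_time * 100.0


# Hypothetical timings in arbitrary units, for illustration only.
faster = percent_faster(2.0, 1.5)  # new run takes less time than baseline
slower = percent_faster(2.0, 2.5)  # new run takes more time than baseline
print(faster, slower)
```

Under this reading, a value such as 42.1367 means the new implementation took about 42% less time than the baseline for that configuration, while -11.8632 means it took about 12% more.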
parent 674db817