Add GPU implementation for tf.segment_sum. (#11630)
* Add GPU implementation for tf.segment_sum. * Refactor segment sum to compute asynchronously. * Add GPU tests and change test datatype to accommodate GPU kernel input requirement. * Add benchmarks. * Benchmark results against baseline unsorted impl. The columns are: datatypes, outer dimension, output outer dimension, inner dimension, execution time difference against baseline impl as a percentage of the time taken by the baseline impl (positive values mean faster execution than baseline). fp32 512 256 512 -11.8632 fp64 512 256 512 -10.1854 fp32 512 256 2048 -6.2147 fp64 512 256 2048 -0.106 fp32 512 256 8192 -0.0867 fp64 512 256 8192 3.6285 fp32 512 256 1120 -10.9163 fp64 512 256 1120 -1.3509 fp32 512 256 1215 -2.0428 fp64 512 256 1215 -7.884 fp32 512 256 1856 -7.3159 fp64 512 256 1856 -0.1011 fp32 512 256 1302 -3.7802 fp64 512 256 1302 4.6675 fp32 512 256 1329 -13.2275 fp64 512 256 1329 4.3505 fp32 512 256 1531 -7.5993 fp64 512 256 1531 -3.5371 fp32 512 256 1313 -5.0677 fp64 512 256 1313 -1.368 fp32 512 256 1672 -0.0907 fp64 512 256 1672 4.5809 fp32 512 256 1851 -10.2862 fp64 512 256 1851 2.3119 fp32 512 256 1584 -10.3406 fp64 512 256 1584 1.4481 fp32 512 64 512 -6.3544 fp64 512 64 512 -18.4343 fp32 512 64 2048 -9.6639 fp64 512 64 2048 0.0714 fp32 512 64 8192 1.2097 fp64 512 64 8192 10.4839 fp32 512 64 1120 -20.1102 fp64 512 64 1120 -5.0784 fp32 512 64 1215 -6.4061 fp64 512 64 1215 -14.1781 fp32 512 64 1856 2.4221 fp64 512 64 1856 -8.4205 fp32 512 64 1302 -10.2403 fp64 512 64 1302 -7.0577 fp32 512 64 1329 -10.515 fp64 512 64 1329 -4.4899 fp32 512 64 1531 -18.4045 fp64 512 64 1531 9.5982 fp32 512 64 1313 -17.0858 fp64 512 64 1313 -2.02 fp32 512 64 1672 -6.1997 fp64 512 64 1672 1.1616 fp32 512 64 1851 -2.3241 fp64 512 64 1851 -1.0585 fp32 512 64 1584 -7.2199 fp64 512 64 1584 -1.0865 fp32 512 16 512 -14.8754 fp64 512 16 512 -10.987 fp32 512 16 2048 -18.3725 fp64 512 16 2048 1.5949 fp32 512 16 8192 2.163 fp64 512 16 8192 12.5431 fp32 512 16 1120 2.2558 fp64 512 16 1120 -2.7135 fp32 512 16 1215 -12.7228 fp64 512 16 1215 11.0343 fp32 512 16 1856 -6.986 fp64 512 16 1856 7.0687 fp32 512 16 1302 -9.3881 fp64 512 16 1302 -8.2974 fp32 512 16 1329 -11.5103 fp64 512 16 1329 18.9707 fp32 512 16 1531 -14.3721 fp64 512 16 1531 9.6774 fp32 512 16 1313 -16.5546 fp64 512 16 1313 11.7528 fp32 512 16 1672 -10.3689 fp64 512 16 1672 15.1197 fp32 512 16 1851 -11.6021 fp64 512 16 1851 8.2983 fp32 512 16 1584 -13.6702 fp64 512 16 1584 9.4635 fp32 2048 1024 512 -5.6482 fp64 2048 1024 512 1.45 fp32 2048 1024 2048 -0.066 fp64 2048 1024 2048 3.6549 fp32 2048 1024 8192 4.3953 fp64 2048 1024 8192 5.0636 fp32 2048 1024 1120 2.3119 fp64 2048 1024 1120 3.4102 fp32 2048 1024 1215 1.6251 fp64 2048 1024 1215 2.4538 fp32 2048 1024 1856 1.4219 fp64 2048 1024 1856 5.2966 fp32 2048 1024 1302 0.4938 fp64 2048 1024 1302 3.6871 fp32 2048 1024 1329 2.7753 fp64 2048 1024 1329 4.2955 fp32 2048 1024 1531 1.8766 fp64 2048 1024 1531 4.4579 fp32 2048 1024 1313 0.6639 fp64 2048 1024 1313 4.5556 fp32 2048 1024 1672 1.1072 fp64 2048 1024 1672 3.8653 fp32 2048 1024 1851 1.1566 fp64 2048 1024 1851 3.6434 fp32 2048 1024 1584 0.7806 fp64 2048 1024 1584 4.3265 fp32 2048 256 512 -10.7236 fp64 2048 256 512 1.011 fp32 2048 256 2048 2.2321 fp64 2048 256 2048 12.2771 fp32 2048 256 8192 8.0287 fp64 2048 256 8192 15.4497 fp32 2048 256 1120 -8.1388 fp64 2048 256 1120 5.8003 fp32 2048 256 1215 1.709 fp64 2048 256 1215 12.4369 fp32 2048 256 1856 5.1844 fp64 2048 256 1856 14.2236 fp32 2048 256 1302 2.8457 fp64 2048 256 1302 10.5728 fp32 2048 256 1329 -2.547 fp64 2048 256 1329 12.1123 fp32 2048 256 1531 2.4946 fp64 2048 256 1531 12.2398 fp32 2048 256 1313 6.1621 fp64 2048 256 1313 9.857 fp32 2048 256 1672 2.176 fp64 2048 256 1672 9.8899 fp32 2048 256 1851 4.6307 fp64 2048 256 1851 15.0223 fp32 2048 256 1584 3.5238 fp64 2048 256 1584 10.3181 fp32 2048 64 512 -11.5325 fp64 2048 64 512 8.5141 fp32 2048 64 2048 0.6066 fp64 2048 64 2048 25.8166 fp32 2048 64 8192 15.5994 fp64 2048 64 8192 29.453 fp32 2048 64 1120 1.5933 fp64 2048 64 1120 17.1686 fp32 2048 64 1215 -11.8064 fp64 2048 64 1215 21.7897 fp32 2048 64 1856 3.3061 fp64 2048 64 1856 17.6379 fp32 2048 64 1302 -1.201 fp64 2048 64 1302 26.775 fp32 2048 64 1329 -1.377 fp64 2048 64 1329 23.6142 fp32 2048 64 1531 0.9212 fp64 2048 64 1531 16.7177 fp32 2048 64 1313 2.8448 fp64 2048 64 1313 26.824 fp32 2048 64 1672 1.5334 fp64 2048 64 1672 23.7874 fp32 2048 64 1851 0.1934 fp64 2048 64 1851 25.1446 fp32 2048 64 1584 -2.8748 fp64 2048 64 1584 22.3902 fp32 8192 4096 512 0.0512 fp64 8192 4096 512 2.8049 fp32 8192 4096 2048 3.6683 fp64 8192 4096 2048 5.7372 fp32 8192 4096 8192 6.2501 fp64 8192 4096 8192 5.6644 fp32 8192 4096 1120 3.4347 fp64 8192 4096 1120 5.9099 fp32 8192 4096 1215 4.0591 fp64 8192 4096 1215 6.2049 fp32 8192 4096 1856 4.5046 fp64 8192 4096 1856 5.9 fp32 8192 4096 1302 3.8744 fp64 8192 4096 1302 5.74 fp32 8192 4096 1329 3.9169 fp64 8192 4096 1329 6.302 fp32 8192 4096 1531 5.0479 fp64 8192 4096 1531 6.048 fp32 8192 4096 1313 3.5261 fp64 8192 4096 1313 6.0544 fp32 8192 4096 1672 4.6081 fp64 8192 4096 1672 5.2568 fp32 8192 4096 1851 4.2022 fp64 8192 4096 1851 6.0934 fp32 8192 4096 1584 3.3852 fp64 8192 4096 1584 5.6772 fp32 8192 1024 512 3.7405 fp64 8192 1024 512 16.4627 fp32 8192 1024 2048 8.3918 fp64 8192 1024 2048 18.5254 fp32 8192 1024 8192 13.7773 fp64 8192 1024 8192 17.4314 fp32 8192 1024 1120 6.2023 fp64 8192 1024 1120 16.689 fp32 8192 1024 1215 9.5441 fp64 8192 1024 1215 19.7246 fp32 8192 1024 1856 9.864 fp64 8192 1024 1856 18.2895 fp32 8192 1024 1302 7.3145 fp64 8192 1024 1302 19.8528 fp32 8192 1024 1329 9.6131 fp64 8192 1024 1329 19.5526 fp32 8192 1024 1531 8.9847 fp64 8192 1024 1531 20.3696 fp32 8192 1024 1313 7.2819 fp64 8192 1024 1313 20.5361 fp32 8192 1024 1672 11.8095 fp64 8192 1024 1672 18.3047 fp32 8192 1024 1851 12.1042 fp64 8192 1024 1851 21.8124 fp32 8192 1024 1584 9.6549 fp64 8192 1024 1584 18.1818 fp32 8192 256 512 8.2649 fp64 8192 256 512 20.9372 fp32 8192 256 2048 15.6297 fp64 8192 256 2048 35.6407 fp32 8192 256 8192 21.7055 fp64 8192 256 8192 37.225 fp32 8192 256 1120 8.322 fp64 8192 256 1120 33.6497 fp32 8192 256 1215 12.9148 fp64 8192 256 1215 40.0554 fp32 8192 256 1856 12.2226 fp64 8192 256 1856 36.2642 fp32 8192 256 1302 12.2956 fp64 8192 256 1302 40.4711 fp32 8192 256 1329 10.2045 fp64 8192 256 1329 38.4891 fp32 8192 256 1531 14.9187 fp64 8192 256 1531 40.7874 fp32 8192 256 1313 9.5106 fp64 8192 256 1313 42.1367 fp32 8192 256 1672 15.2577 fp64 8192 256 1672 36.7527 fp32 8192 256 1851 15.668 fp64 8192 256 1851 40.2035 fp32 8192 256 1584 14.126 fp64 8192 256 1584 32.7602
Loading
Please sign in to comment