Fix an floating point inaccuracy issue in precision_recall_at_equal_thresholds due
to accumulating the tp/fp/tn/fn values in float32, which can become highly inaccurate as the number of values increases. In the common case, the method sums the value 1.0f to the tp/fp/tn/fn bucket for every value in the predictions tensor. If the tensor is large (say, it represents an image and we have one tp/fp/tn/fn value per pixel), then we are essentially adding many 1.0f's together, across the entire batch and also across all the batches. By doing it in float32 the value starts becoming inaccurate at around 16M, which is very small. In practice, we see a deviation of 100x when the total reaches about 3e10 (the previous code reports a number about 1e8 when the actual value should be 3e10). We avoid all these issues by always accumulating in float64. Also fix a bug that the method cannot be called with predictions dtype being anything other than float32. Preivously it would crash due to the eps code near the end. Added tests for using float64 and float16. PiperOrigin-RevId: 199216173
Loading
Please sign in to comment