Commit 7d195d0d authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Fix a floating-point inaccuracy issue in precision_recall_at_equal_thresholds due
to accumulating the tp/fp/tn/fn values in float32, which can become highly inaccurate
as the number of values increases.

In the common case, the method adds the value 1.0f to a tp/fp/tn/fn bucket for every
value in the predictions tensor.  If the tensor is large (say, it represents an image
and we have one tp/fp/tn/fn value per pixel), then we are essentially adding many 1.0f's
together, across the entire batch and also across all the batches.  In float32 the sum
starts becoming inaccurate at around 16M (2^24 = 16,777,216, beyond which adding 1.0f
no longer changes the value), which is a very small count for this kind of input.

In practice, we see a deviation of more than 100x when the total reaches about 3e10
(the previous code reports a number of about 1e8 when the actual value should be 3e10).

We avoid all these issues by always accumulating in float64.
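
The effect can be reproduced outside TensorFlow.  The following is a minimal sketch,
not the code from this change: it emulates a float32 accumulator in plain Python by
rounding through the struct module after every addition.  The helper names to_float32
and accumulate_ones are hypothetical and exist only for this illustration.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (a float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate_ones(start: float, count: int, use_float32: bool) -> float:
    """Add 1.0 to `start` `count` times, optionally rounding to float32 each step."""
    total = start
    for _ in range(count):
        total = total + 1.0
        if use_float32:
            # Emulate a float32 accumulator: the result of each addition
            # must itself be representable in float32.
            total = to_float32(total)
    return total


# Start at 2^24, the point where float32 can no longer represent every integer.
START = float(2 ** 24)

f32_total = accumulate_ones(START, 1000, use_float32=True)
f64_total = accumulate_ones(START, 1000, use_float32=False)

print(f32_total)  # 16777216.0: every +1.0 was rounded away
print(f64_total)  # 16778216.0: all 1000 increments were kept
```

float64 has the same kind of limit, but only at 2^53 (about 9e15), which is far above
the 3e10 total mentioned above.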

Also fix a bug where the method could not be called with a predictions dtype other
than float32.  Previously it would crash due to the eps code near the end.
Added tests for using float64 and float16.

PiperOrigin-RevId: 199216173
parent 14d4d163