Imagine a time series of financial calculations. You compute, say, state income tax on realized gains for each of a large number of rows in a database.
The formula might look like this:
Net Profit = Net Sale - Net Purchase
Tax Amount = Tax Rate * Net Profit
Now you give this to someone in a feed, but you round to the nearest penny, because… well, it’s meaningless to talk about 4/10 of a penny.
Net Profit = $1046.62
Tax Amount = $209.32
Except… you rounded down. And the accountant reconciling this feed will notice, once enough of those fractions of a penny pile up, that the final total is significantly different from what he expected.
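To make the accumulation concrete, here is a small sketch. It assumes a 20% tax rate, which the example above doesn't state but which turns the $1046.62 profit into an unrounded $209.324, and a hypothetical feed of one million identical rows (a worst case, since every row loses the same amount). Python's `decimal` module keeps the arithmetic exact, so the only error is the rounding itself.

```python
from decimal import Decimal, ROUND_HALF_EVEN

# Assumptions for illustration only: a 20% rate and one million identical rows.
TAX_RATE = Decimal("0.20")
NET_PROFIT = Decimal("1046.62")
ROWS = 1_000_000
PENNY = Decimal("0.01")

exact_tax = TAX_RATE * NET_PROFIT                                   # 209.3240
rounded_tax = exact_tax.quantize(PENNY, rounding=ROUND_HALF_EVEN)  # 209.32

# Every row drops the same 4/10 of a penny, so the errors never cancel.
exact_total = exact_tax * ROWS
reported_total = rounded_tax * ROWS
shortfall = exact_total - reported_total

print(exact_tax, rounded_tax)  # 209.3240 209.32
print(shortfall)               # 4000.0000
```

Rounding to the nearest penny is the right choice for any single row; the problem only shows up in the total, where $4,000 has quietly gone missing.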
You can reduce this error, however, with “random rounding”: round up with a probability equal to the fraction being discarded. If the leftover amount is 0.004, that gives you a 60% chance of rounding down to 0.00 and a 40% chance of rounding up to 0.01. Each rounded value is then correct on average, so errors tend to cancel in aggregates instead of piling up in one direction. It might not matter much for a handful of numbers, but if you’re doing millions of small-value calculations that actually affect customers, errors add up (and not always in your favor).
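Here is a minimal sketch of that rule applied to a single unrounded tax amount of $209.324 (again assuming a 20% rate on the $1046.62 profit, which the text doesn't specify). The helper `stochastic_round_to_penny` is a hypothetical name made up for this example. It uses `Decimal` so the discarded fraction is exactly 0.4 of a penny rather than a floating-point approximation, and a seeded `random.Random` only so the run is repeatable.

```python
import random
from decimal import Decimal, ROUND_FLOOR

PENNY = Decimal("0.01")


def stochastic_round_to_penny(amount: Decimal, rng: random.Random) -> Decimal:
    """Round amount to a whole penny, rounding up with probability equal to
    the fraction of a penny being discarded (hypothetical helper)."""
    lower = amount.quantize(PENNY, rounding=ROUND_FLOOR)  # 209.324 -> 209.32
    fraction = (amount - lower) / PENNY                   # 0.4 of a penny
    if Decimal(rng.random()) < fraction:
        return lower + PENNY
    return lower


rng = random.Random(42)  # fixed seed only so the example is repeatable
exact_tax = Decimal("209.324")
ROWS = 100_000

exact_total = exact_tax * ROWS
stochastic_total = sum(stochastic_round_to_penny(exact_tax, rng) for _ in range(ROWS))
always_down_total = exact_tax.quantize(PENNY, rounding=ROUND_FLOOR) * ROWS

print(exact_total - always_down_total)  # 400.000
print(exact_total - stochastic_total)   # typically within a few dollars of zero
```

Individual rows are now off by up to a penny instead of 4/10 of a penny, but the errors point both ways, so the total stays close to the exact figure.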
More information on stochastic rounding from the Royal Society.
A simple Python library to do this is appended below, with both random and deterministic versions.
Careful use of randomization also comes into play in integer-only arithmetic. If results must be stored as integers, as is often the case in optimized int8 AI libraries, blockchain, and cryptographic arithmetic, then 1 + 0.01 always rounds back to 1. Repeat that addition and the accumulated loss grows without bound, which can lead to extreme hallucination, for example, in some memory-optimized image models.
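A toy illustration of that loss, with a plain Python `int` standing in for an integer-only accumulator (an assumption for the example, not any particular library's behavior): after every update the running total is rounded to the nearest integer.

```python
# Integer-only running total, rounded to the nearest integer after each update.
total = 1
true_total = 1.0
for _ in range(1000):
    total = round(total + 0.01)  # 1.01 rounds back to 1 every time
    true_total += 0.01

print(total, round(true_total, 2))  # 1 11.0
```

After k additions the stored value is still 1 while the true value is 1 + 0.01k, so the error keeps growing for as long as the loop runs.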
But if you use stochastic rounding, 1 + 0.01 rounds to 1 99% of the time and to 2 1% of the time. That means that over the long term, the expected error is zero. Whew, no more crazy hands!
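The same integer accumulator with stochastic rounding can be sketched as follows. The helper `stochastic_round` is a hypothetical name for this example; it rounds up with probability equal to the fractional part. Averaging many independent runs shows the expected result landing near the exact answer of 2.0 (1 plus 100 additions of 0.01), even though any single run can end above or below it.

```python
import math
import random


def stochastic_round(x: float, rng: random.Random) -> int:
    """Round x to an integer, rounding up with probability equal to its
    fractional part, so the expected result equals x."""
    lower = math.floor(x)
    return lower + (1 if rng.random() < x - lower else 0)


rng = random.Random(7)  # fixed seed only for repeatability
TRIALS = 2000
STEPS = 100

finals = []
for _ in range(TRIALS):
    total = 1
    for _ in range(STEPS):
        total = stochastic_round(total + 0.01, rng)
    finals.append(total)

print(sum(finals) / TRIALS)  # close to 2.0; the exact value depends on the seed
```

Each step is 1.01 only on average, so a single run is noisy; the benefit is that the noise has no direction, unlike round-to-nearest, which loses the full 0.01 every time.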
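cround.py (Python):

```python
import hashlib
import random


def custom_round(value, n=2):
    """Stochastically round value to n decimal places.

    Rounds up with probability equal to the discarded fraction, so the
    expected result equals the input.
    """
    factor = 10 ** n
    # Fraction of one unit in the nth decimal place being discarded, in [0, 1).
    # Python's % is floor-based, so this also works for negative values.
    beyond_nth_decimal = (value * factor) % 1
    if random.random() < beyond_nth_decimal:
        return round(value - beyond_nth_decimal / factor + 1 / factor, n)
    else:
        return round(value - beyond_nth_decimal / factor, n)


def deterministic_hash(value):
    # We just want something to rotate among rounded values: map each float
    # to a repeatable pseudo-random number in [0, 1). SHA-256 mixes every byte
    # of the float's hex representation, so nearby values get unrelated
    # numbers. Keeping 53 bits makes the division exact and always < 1.
    digest = hashlib.sha256(float(value).hex().encode()).digest()
    return (int.from_bytes(digest[:8], 'big') >> 11) / 2**53


def custom_round_fixed(value, n=2):
    """Like custom_round, but the same input always rounds the same way."""
    factor = 10 ** n
    beyond_nth_decimal = (value * factor) % 1
    pseudo_random = deterministic_hash(value)
    if pseudo_random < beyond_nth_decimal:
        return round(value - beyond_nth_decimal / factor + 1 / factor, n)
    else:
        return round(value - beyond_nth_decimal / factor, n)
```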