Description
I found a value that I believe is rounded incorrectly when using a custom FormatInfo that matches NVIDIA TensorFloat32 (a 19-bit floating-point format). This happens with the TowardZero rounding mode.
EDIT: This issue happens for built-in formats as well. (See comment)
Minimal Example:
import gfloat
format_info_tf32 = gfloat.FormatInfo(
name="TensorFloat32",
k=19, # total of 19 bits
precision=11, # 10 bits in mantissa field + implied 1
emax=127, # 254 - 127 = 127
has_nz=True,
has_infs=True,
num_high_nans=2**10 - 1,
has_subnormals=True,
is_signed=True,
is_twos_complement=False,
)
a = 6.6461399789245764e+35
b = 6.642894793387995e+35
rounded_a = gfloat.round_float(format_info_tf32, a, gfloat.RoundMode.TowardZero)
rounded_b = gfloat.round_float(format_info_tf32, b, gfloat.RoundMode.TowardZero)
print("a: :", a)
print("b: :", b)
print("rounded_a :", rounded_a)
print("rounded_b :", rounded_b)
print("b == rounded_b :", b == rounded_b) # True only if `b` is in the space of TensorFloat32
print("(rounded_a < b < a) :", rounded_a < b and b < a) # Rounding skipped `b`
Output:
a: : 6.6461399789245764e+35
b: : 6.642894793387995e+35
rounded_a : 6.639649607851411e+35
rounded_b : 6.642894793387995e+35
b == rounded_b : True
(rounded_a < b < a) : True
a is the value that gets rounded incorrectly. b is a value that can be encoded by TensorFloat32.
a is larger than b, but rounding a down yields a value that is less than b.
It appears to have skipped a value when rounding down.
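To make the skipped value concrete, here is a small check in plain Python (no gfloat needed). It assumes only what the format definition above states: with precision=11, representable values in the binade [2**118, 2**119) are spaced 2**108 apart. The name QUANTUM is introduced here for illustration only.

```python
import math

# With 11 significant bits, values in [2**118, 2**119) are multiples of
# 2**(118 - 10) = 2**108. QUANTUM is an illustrative name, not a gfloat API.
QUANTUM = math.ldexp(1.0, 108)

a = 6.6461399789245764e+35
b = 6.642894793387995e+35
rounded_a = 6.639649607851411e+35  # value gfloat reported for `a`

# Dividing by a power of two is exact here, so these are the exact
# positions of each value on the TensorFloat32 grid for this binade.
units_a = a / QUANTUM
units_b = b / QUANTUM
units_rounded_a = rounded_a / QUANTUM

print("a / 2**108         :", units_a)          # slightly below 2048
print("b / 2**108         :", units_b)          # a whole number, so b is on the grid
print("rounded_a / 2**108 :", units_rounded_a)  # one grid step below b

# Rounding toward zero should land on the grid point at or just below `a`.
expected = math.floor(units_a) * QUANTUM
print("expected == b      :", expected == b)
```

If the assumption about the grid spacing holds, `a` sits between grid points 2047 and 2048, so rounding toward zero should return `b` (grid point 2047) rather than grid point 2046.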
I dug a little and believe the root cause is the np.log2(vpos) call on line 63 of round.py.
np.log2(vpos) appears to round its result up (at least on my platform): it returns an integer-valued float even though the input is not a power of 2.
Calculation:
import numpy as np
vpos = 6.6461399789245764e+35 # = a
print("vpos :", vpos)
print("log2(vpos) :", np.log2(vpos))
print("floor(...) :", np.floor(np.log2(vpos)))
print("int(floor(...)) :", int(np.floor(np.log2(vpos))))
print("int(vpos) :", int(vpos))
print("2**119 :", 2**119)
print("int(vpos) < 2**119 :", int(vpos) < 2**119)
Output:
vpos : 6.6461399789245764e+35
log2(vpos) : 119.0
floor(...) : 119.0
int(floor(...)) : 119
int(vpos) : 664613997892457641303998350787346432
2**119 : 664613997892457936451903530140172288
int(vpos) < 2**119 : True
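A possible explanation for the 119.0, offered as a sketch rather than a statement about NumPy internals: the integers printed above show that vpos equals 2**119 - 2**68, so its true base-2 logarithm is only about 6e-16 below 119. That is much smaller than the spacing of doubles near 119, so 119.0 would be the nearest representable result even for a correctly rounded log2. The variable names below are illustrative.

```python
import math

vpos = 6.6461399789245764e+35

# From the integers printed above: vpos == 2**119 - 2**68 == 2**119 * (1 - 2**-51)
assert int(vpos) == 2**119 - 2**68

# True value of log2(vpos) - 119, i.e. log2(1 - 2**-51).
# log1p keeps precision for arguments this close to zero.
offset = math.log1p(-2.0**-51) / math.log(2.0)

# Distance from 119.0 to the next double below it.
gap_below = 119.0 - math.nextafter(119.0, 0.0)

print("log2(vpos) - 119 (true value) :", offset)
print("gap to next double below 119  :", gap_below)
print("119.0 is the nearest double   :", abs(offset) < gap_below / 2)
```

Under this reading the behaviour is not specific to one platform: any log2 that rounds to nearest can return exactly 119.0 for this input. Note that math.nextafter requires Python 3.9 or later.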
Because the code floors log2(vpos), I believe it relies on the result satisfying 2**floor(log2(vpos)) <= vpos.
However, log2(vpos) rounds up to exactly 119.0 here, so floor leaves it at 119 instead of lowering it to the correct value, 118.
The exponent is then one too large, the rounding step is doubled, and what would be the lsb of the mantissa is rounded off.
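One way to avoid the problem, sketched here as a suggestion and not as gfloat's actual code, is to read the exponent directly from the binary representation with math.frexp instead of going through a logarithm. The helpers exact_floor_log2 and truncate_to_precision are hypothetical names; the sketch handles only finite positive inputs and ignores subnormals, overflow, and the other rounding modes of the target format.

```python
import math

def exact_floor_log2(x: float) -> int:
    """Return floor(log2(x)) for finite x > 0 using the binary exponent."""
    _, e = math.frexp(x)  # x == m * 2**e with 0.5 <= m < 1
    return e - 1

def truncate_to_precision(x: float, precision: int) -> float:
    """Round finite x > 0 toward zero to `precision` significant bits."""
    expval = exact_floor_log2(x)
    quantum_exp = expval - (precision - 1)  # exponent of the last kept bit
    # Scaling by a power of two is exact, so floor() sees the exact value.
    units = math.floor(math.ldexp(x, -quantum_exp))
    return math.ldexp(float(units), quantum_exp)

a = 6.6461399789245764e+35
b = 6.642894793387995e+35

print("exact_floor_log2(a)          :", exact_floor_log2(a))
print("truncate_to_precision(a, 11) :", truncate_to_precision(a, 11))
print("equals b                     :", truncate_to_precision(a, 11) == b)
```

Because frexp never rounds, the exponent for `a` comes out as 118, and truncating to 11 bits yields `b` as expected. An array-based implementation could use np.frexp in the same way.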