Skip to content

Commit d438621

Browse files
committed
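Improve dtype handling for arange and relative position embeddings: keep the calculations in float32 and cast to the target dtype only when the embedding is applied (the application itself may also move to float32 in future). Prevent DeepSpeed ZeRO from hijacking the arange dtype by creating arange tensors as integers and casting them to float32 explicitly.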
1 parent 3234daf commit d438621

File tree

4 files changed

+20
-21
lines changed

4 files changed

+20
-21
lines changed

timm/layers/pos_embed_rel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,8 @@ def gen_relative_log_coords(
311311
):
312312
assert mode in ('swin', 'cr')
313313
# as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
314-
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
315-
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
314+
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
315+
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
316316
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
317317
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
318318
if mode == 'swin':

timm/layers/pos_embed_sincos.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,23 @@ def pixel_freq_bands(
1515
num_bands: int,
1616
max_freq: float = 224.,
1717
linear_bands: bool = True,
18-
dtype: torch.dtype = torch.float32,
1918
device: Optional[torch.device] = None,
2019
):
2120
if linear_bands:
22-
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
21+
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
2322
else:
24-
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
23+
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
2524
return bands * torch.pi
2625

2726

2827
def freq_bands(
2928
num_bands: int,
3029
temperature: float = 10000.,
3130
step: int = 2,
32-
dtype: torch.dtype = torch.float32,
3331
device: Optional[torch.device] = None,
3432
) -> torch.Tensor:
35-
bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
33+
exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
34+
bands = 1. / (temperature ** exp)
3635
return bands
3736

3837

@@ -61,18 +60,20 @@ def build_sincos2d_pos_embed(
6160
"""
6261
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
6362
pos_dim = dim // 4
64-
bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
63+
bands = freq_bands(pos_dim, temperature=temperature, step=1, device=device)
6564

6665
if reverse_coord:
6766
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
6867
grid = torch.stack(torch.meshgrid(
69-
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
68+
[torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
69+
for s in feat_shape])
70+
).flatten(1).transpose(0, 1)
7071
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
7172
# FIXME add support for unflattened spatial dim?
7273

7374
stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
7475
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
75-
return pos_emb
76+
return pos_emb.to(dtype=dtype)
7677

7778

7879
def build_fourier_pos_embed(
@@ -112,15 +113,13 @@ def build_fourier_pos_embed(
112113
num_bands,
113114
float(max_res),
114115
linear_bands=linear_bands,
115-
dtype=dtype,
116116
device=device,
117117
)
118118
else:
119119
bands = freq_bands(
120120
num_bands,
121121
temperature=temperature,
122122
step=1,
123-
dtype=dtype,
124123
device=device,
125124
)
126125
else:
@@ -130,9 +129,9 @@ def build_fourier_pos_embed(
130129
dtype = bands.dtype
131130

132131
if in_pixels:
133-
t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]
132+
t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
134133
else:
135-
t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]
134+
t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
136135

137136
if ref_feat_shape is not None:
138137
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
@@ -142,7 +141,7 @@ def build_fourier_pos_embed(
142141
grid = grid.unsqueeze(-1)
143142
pos = grid * bands
144143

145-
pos_sin, pos_cos = pos.sin(), pos.cos()
144+
pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype)
146145
out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
147146
return out
148147

timm/models/edgenext.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ def forward(self, shape: Tuple[int, int, int]):
4141
device = self.token_projection.weight.device
4242
dtype = self.token_projection.weight.dtype
4343
inv_mask = ~torch.zeros(shape).to(device=device, dtype=torch.bool)
44-
y_embed = inv_mask.cumsum(1, dtype=dtype)
45-
x_embed = inv_mask.cumsum(2, dtype=dtype)
44+
y_embed = inv_mask.cumsum(1, dtype=torch.float32)
45+
x_embed = inv_mask.cumsum(2, dtype=torch.float32)
4646
eps = 1e-6
4747
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
4848
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
4949

50-
dim_t = torch.arange(self.hidden_dim, dtype=dtype, device=device)
50+
dim_t = torch.arange(self.hidden_dim, dtype=torch.int64, device=device).to(torch.float32)
5151
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
5252

5353
pos_x = x_embed[:, :, :, None] / dim_t
@@ -59,7 +59,7 @@ def forward(self, shape: Tuple[int, int, int]):
5959
(pos_y[:, :, :, 0::2].sin(),
6060
pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
6161
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
62-
pos = self.token_projection(pos)
62+
pos = self.token_projection(pos.to(dtype))
6363

6464
return pos
6565

timm/models/swin_transformer_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ def __init__(
105105
)
106106

107107
# get relative_coords_table
108-
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
109-
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
108+
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32)
109+
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32)
110110
relative_coords_table = torch.stack(torch.meshgrid([
111111
relative_coords_h,
112112
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2

0 commit comments

Comments
 (0)