Commit 3234daf

Add missing deprecation mapping for a densenet and xcit model. Fix #2086. Tweak xcit pos embed use of arange for better low prec safety.
1 parent 809a9e1 commit 3234daf

2 files changed: +11 -5 lines changed


timm/models/densenet.py

Lines changed: 5 additions & 1 deletion
@@ -15,7 +15,7 @@
 from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
 from ._builder import build_model_with_cfg
 from ._manipulate import MATCH_PREV_GROUP
-from ._registry import register_model, generate_default_cfgs
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations

 __all__ = ['DenseNet']

@@ -415,3 +415,7 @@ def densenet264d(pretrained=False, **kwargs) -> DenseNet:
     model = _create_densenet('densenet264d', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
+
+
+register_model_deprecations(__name__, {
+    'tv_densenet121': 'densenet121.tv_in1k',
+})
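
With the mapping registered, the deprecated torchvision-weight alias resolves through timm's deprecation handling instead of failing as an unknown model name. A minimal sketch of the expected behavior, assuming the usual create_model deprecation path (the warning wording below is paraphrased, not taken from this commit):

import timm

# The old alias resolves to the tagged replacement; timm emits a
# deprecation warning pointing at 'densenet121.tv_in1k' (paraphrased).
model = timm.create_model('tv_densenet121', pretrained=True)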

timm/models/xcit.py

Lines changed: 6 additions & 4 deletions
@@ -48,18 +48,19 @@ def __init__(self, hidden_dim=32, dim=768, temperature=10000):

     def forward(self, B: int, H: int, W: int):
         device = self.token_projection.weight.device
-        y_embed = torch.arange(1, H+1, dtype=torch.float32, device=device).unsqueeze(1).repeat(1, 1, W)
-        x_embed = torch.arange(1, W+1, dtype=torch.float32, device=device).repeat(1, H, 1)
+        dtype = self.token_projection.weight.dtype
+        y_embed = torch.arange(1, H + 1, device=device).to(torch.float32).unsqueeze(1).repeat(1, 1, W)
+        x_embed = torch.arange(1, W + 1, device=device).to(torch.float32).repeat(1, H, 1)
         y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
         x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
-        dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=device)
+        dim_t = torch.arange(self.hidden_dim, device=device).to(torch.float32)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
         pos_x = torch.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3)
         pos_y = torch.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
-        pos = self.token_projection(pos)
+        pos = self.token_projection(pos.to(dtype))
         return pos.repeat(B, 1, 1, 1)  # (B, C, H, W)
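The arange tweak follows a pattern used elsewhere in timm: build the index sequence in the default integer dtype and upcast afterwards, so the values stay exact even if float tensor creation gets redirected to a low precision dtype by an AMP or training-framework override; the final pos.to(dtype) cast then matches the projection weight so the conv runs in the model's compute dtype. A minimal illustration of the underlying representability issue (an assumption about the motivation, not code from the commit):

import torch

# fp16 can only represent integers exactly up to 2048, so a float16 arange
# silently rounds larger values to a representable neighbor; an integer
# arange stays exact and can be upcast without loss.
direct = torch.arange(1, 4097, dtype=torch.float16)
safe = torch.arange(1, 4097).to(torch.float32)
print(direct[4094].item())  # 4095 is not representable in fp16; prints a rounded neighbor
print(safe[4094].item())    # 4095.0, exact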

@@ -890,6 +891,7 @@ def xcit_large_24_p8_384(pretrained=False, **kwargs) -> Xcit:
     'xcit_small_12_p16_224_dist': 'xcit_small_12_p16_224.fb_dist_in1k',
     'xcit_small_12_p16_384_dist': 'xcit_small_12_p16_384.fb_dist_in1k',
     'xcit_small_24_p16_224_dist': 'xcit_small_24_p16_224.fb_dist_in1k',
+    'xcit_small_24_p16_384_dist': 'xcit_small_24_p16_384.fb_dist_in1k',
     'xcit_medium_24_p16_224_dist': 'xcit_medium_24_p16_224.fb_dist_in1k',
     'xcit_medium_24_p16_384_dist': 'xcit_medium_24_p16_384.fb_dist_in1k',
     'xcit_large_24_p16_224_dist': 'xcit_large_24_p16_224.fb_dist_in1k',
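
Likewise, with the previously missing entry from #2086 added, the pre-tag xcit name resolves again, e.g. (same assumptions as the densenet sketch above):

model = timm.create_model('xcit_small_24_p16_384_dist', pretrained=True)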
