@@ -48,18 +48,19 @@ def __init__(self, hidden_dim=32, dim=768, temperature=10000):
 
     def forward(self, B: int, H: int, W: int):
         device = self.token_projection.weight.device
-        y_embed = torch.arange(1, H + 1, dtype=torch.float32, device=device).unsqueeze(1).repeat(1, 1, W)
-        x_embed = torch.arange(1, W + 1, dtype=torch.float32, device=device).repeat(1, H, 1)
+        dtype = self.token_projection.weight.dtype
+        y_embed = torch.arange(1, H + 1, device=device).to(torch.float32).unsqueeze(1).repeat(1, 1, W)
+        x_embed = torch.arange(1, W + 1, device=device).to(torch.float32).repeat(1, H, 1)
         y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
         x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
-        dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=device)
+        dim_t = torch.arange(self.hidden_dim, device=device).to(torch.float32)
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
         pos_x = x_embed[:, :, :, None] / dim_t
         pos_y = y_embed[:, :, :, None] / dim_t
         pos_x = torch.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3)
         pos_y = torch.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3)
         pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
-        pos = self.token_projection(pos)
+        pos = self.token_projection(pos.to(dtype))
         return pos.repeat(B, 1, 1, 1)  # (B, C, H, W)
 
 
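For context, here is a minimal, self-contained sketch of the pattern the hunk above introduces: the sinusoidal grid is built in float32 for numerical stability, then cast to the projection weight's dtype only at the 1x1 conv projection, so the module also works when its parameters are in half precision. The class layout mirrors timm's PositionalEncodingFourier, but treat this as an illustration of the change, not the library code itself (the __init__ attributes here are assumed).

import math
import torch
import torch.nn as nn

class FourierPosEmbedSketch(nn.Module):
    # Illustrative stand-in for timm's PositionalEncodingFourier (attribute names assumed).
    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
        super().__init__()
        self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.eps = 1e-6

    def forward(self, B: int, H: int, W: int):
        device = self.token_projection.weight.device
        dtype = self.token_projection.weight.dtype
        # Build and normalize the coordinate grid in float32, regardless of the
        # module's parameter dtype.
        y_embed = torch.arange(1, H + 1, device=device).to(torch.float32).unsqueeze(1).repeat(1, 1, W)
        x_embed = torch.arange(1, W + 1, device=device).to(torch.float32).repeat(1, H, 1)
        y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
        # Sinusoidal frequencies, also computed in float32.
        dim_t = torch.arange(self.hidden_dim, device=device).to(torch.float32)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3)
        pos_y = torch.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        # Cast to the parameter dtype only here, so the projection does not hit a
        # float32-vs-half dtype mismatch.
        pos = self.token_projection(pos.to(dtype))
        return pos.repeat(B, 1, 1, 1)  # (B, dim, H, W)

Computing in float32 and casting only at the projection keeps the normalization and sin/cos math in full precision while avoiding a dtype mismatch when the module's weights are in float16 or bfloat16.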
@@ -890,6 +891,7 @@ def xcit_large_24_p8_384(pretrained=False, **kwargs) -> Xcit:
     'xcit_small_12_p16_224_dist': 'xcit_small_12_p16_224.fb_dist_in1k',
     'xcit_small_12_p16_384_dist': 'xcit_small_12_p16_384.fb_dist_in1k',
     'xcit_small_24_p16_224_dist': 'xcit_small_24_p16_224.fb_dist_in1k',
+    'xcit_small_24_p16_384_dist': 'xcit_small_24_p16_384.fb_dist_in1k',
     'xcit_medium_24_p16_224_dist': 'xcit_medium_24_p16_224.fb_dist_in1k',
     'xcit_medium_24_p16_384_dist': 'xcit_medium_24_p16_384.fb_dist_in1k',
     'xcit_large_24_p16_224_dist': 'xcit_large_24_p16_224.fb_dist_in1k',
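Assuming this dict feeds timm's deprecated-name resolution (as the neighbouring entries suggest), the new entry should let the old 384 distilled name resolve to the tagged pretrained weights. A hedged usage sketch:

import timm

# Deprecated alias (the entry added above); timm should resolve it to the new name.
model = timm.create_model('xcit_small_24_p16_384_dist', pretrained=True)

# Canonical model name plus the pretrained tag it maps to.
model = timm.create_model('xcit_small_24_p16_384.fb_dist_in1k', pretrained=True)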