@@ -15,24 +15,23 @@ def pixel_freq_bands(
1515 num_bands : int ,
1616 max_freq : float = 224. ,
1717 linear_bands : bool = True ,
18- dtype : torch .dtype = torch .float32 ,
1918 device : Optional [torch .device ] = None ,
2019):
2120 if linear_bands :
22- bands = torch .linspace (1.0 , max_freq / 2 , num_bands , dtype = dtype , device = device )
21+ bands = torch .linspace (1.0 , max_freq / 2 , num_bands , dtype = torch . float32 , device = device )
2322 else :
24- bands = 2 ** torch .linspace (0 , math .log (max_freq , 2 ) - 1 , num_bands , dtype = dtype , device = device )
23+ bands = 2 ** torch .linspace (0 , math .log (max_freq , 2 ) - 1 , num_bands , dtype = torch . float32 , device = device )
2524 return bands * torch .pi
2625
2726
2827def freq_bands (
2928 num_bands : int ,
3029 temperature : float = 10000. ,
3130 step : int = 2 ,
32- dtype : torch .dtype = torch .float32 ,
3331 device : Optional [torch .device ] = None ,
3432) -> torch .Tensor :
35- bands = 1. / (temperature ** (torch .arange (0 , num_bands , step , dtype = dtype , device = device ) / num_bands ))
33+ exp = torch .arange (0 , num_bands , step , dtype = torch .int64 , device = device ).to (torch .float32 ) / num_bands
34+ bands = 1. / (temperature ** exp )
3635 return bands
3736
3837
@@ -61,18 +60,20 @@ def build_sincos2d_pos_embed(
6160 """
6261 assert dim % 4 == 0 , 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
6362 pos_dim = dim // 4
64- bands = freq_bands (pos_dim , temperature = temperature , step = 1 , dtype = dtype , device = device )
63+ bands = freq_bands (pos_dim , temperature = temperature , step = 1 , device = device )
6564
6665 if reverse_coord :
6766 feat_shape = feat_shape [::- 1 ] # stack W, H instead of H, W
6867 grid = torch .stack (torch .meshgrid (
69- [torch .arange (s , device = device , dtype = dtype ) for s in feat_shape ])).flatten (1 ).transpose (0 , 1 )
68+ [torch .arange (s , device = device , dtype = torch .int64 ).to (torch .float32 )
69+ for s in feat_shape ])
70+ ).flatten (1 ).transpose (0 , 1 )
7071 pos2 = grid .unsqueeze (- 1 ) * bands .unsqueeze (0 )
7172 # FIXME add support for unflattened spatial dim?
7273
7374 stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
7475 pos_emb = torch .stack ([torch .sin (pos2 ), torch .cos (pos2 )], dim = stack_dim ).flatten (1 )
75- return pos_emb
76+ return pos_emb . to ( dtype = dtype )
7677
7778
7879def build_fourier_pos_embed (
@@ -112,15 +113,13 @@ def build_fourier_pos_embed(
112113 num_bands ,
113114 float (max_res ),
114115 linear_bands = linear_bands ,
115- dtype = dtype ,
116116 device = device ,
117117 )
118118 else :
119119 bands = freq_bands (
120120 num_bands ,
121121 temperature = temperature ,
122122 step = 1 ,
123- dtype = dtype ,
124123 device = device ,
125124 )
126125 else :
@@ -130,9 +129,9 @@ def build_fourier_pos_embed(
130129 dtype = bands .dtype
131130
132131 if in_pixels :
133- t = [torch .linspace (- 1. , 1. , steps = s , device = device , dtype = dtype ) for s in feat_shape ]
132+ t = [torch .linspace (- 1. , 1. , steps = s , device = device , dtype = torch . float32 ) for s in feat_shape ]
134133 else :
135- t = [torch .arange (s , device = device , dtype = dtype ) for s in feat_shape ]
134+ t = [torch .arange (s , device = device , dtype = torch . int64 ). to ( torch . float32 ) for s in feat_shape ]
136135
137136 if ref_feat_shape is not None :
138137 # eva's scheme for resizing rope embeddings (ref shape = pretrain)
@@ -142,7 +141,7 @@ def build_fourier_pos_embed(
142141 grid = grid .unsqueeze (- 1 )
143142 pos = grid * bands
144143
145- pos_sin , pos_cos = pos .sin (), pos .cos ()
144+ pos_sin , pos_cos = pos .sin (). to ( dtype = dtype ) , pos .cos (). to ( dtype )
146145 out = [grid , pos_sin , pos_cos ] if include_grid else [pos_sin , pos_cos ]
147146 return out
148147
0 commit comments