Skip to content

Commit 6862c98

Browse files
authored
fix backward in hgnet
1 parent 6cd28bc commit 6862c98

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

timm/models/hgnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ def __init__(self,
2323
scale_value=1.0,
2424
bias_value=0.0):
2525
super().__init__()
26-
self.scale = nn.Parameter(torch.tensor([scale_value]))
27-
self.bias = nn.Parameter(torch.tensor([bias_value]))
26+
self.scale = nn.Parameter(torch.tensor([scale_value]), requires_grad=True)
27+
self.bias = nn.Parameter(torch.tensor([bias_value]), requires_grad=True)
2828

2929
def forward(self, x):
3030
return self.scale * x + self.bias
@@ -262,7 +262,7 @@ def forward(self, x):
262262
x = torch.cat(output, dim=1)
263263
x = self.aggregation(x)
264264
if self.residual:
265-
x += identity
265+
x = x + identity
266266
return x
267267

268268

0 commit comments

Comments
 (0)