diff --git a/the_well/benchmark/models/fno/__init__.py b/the_well/benchmark/models/fno/__init__.py index 1773ac2a..7437f206 100755 --- a/the_well/benchmark/models/fno/__init__.py +++ b/the_well/benchmark/models/fno/__init__.py @@ -47,7 +47,7 @@ def forward(self, x: torch.Tensor, output_shape=None, **kwargs): x = self.domain_padding.pad(x) for layer_idx in range(self.n_layers): - self.optional_checkpointing( + x = self.optional_checkpointing( self.fno_blocks, x, layer_idx, output_shape=output_shape[layer_idx] ) diff --git a/the_well/benchmark/models/tfno/__init__.py b/the_well/benchmark/models/tfno/__init__.py index 21d08518..31eb36a8 100755 --- a/the_well/benchmark/models/tfno/__init__.py +++ b/the_well/benchmark/models/tfno/__init__.py @@ -45,7 +45,7 @@ def forward(self, x: torch.Tensor, output_shape=None, **kwargs): x = self.domain_padding.pad(x) for layer_idx in range(self.n_layers): - self.optional_checkpointing( + x = self.optional_checkpointing( self.fno_blocks, x, layer_idx, output_shape=output_shape[layer_idx] )