@@ -285,20 +285,21 @@ def connecting_geodesic(self, p1, p2, curve=None):
285285 raise NameError ('shape mismatch' )
286286
287287 if curve is None :
288- curve = CubicSpline (p1 . detach () , p2 . detach () )
288+ curve = CubicSpline (p1 , p2 )
289289 else :
290290 curve .begin = p1
291291 curve .end = p2
292292
293293 for b in range (B ):
294- idx1 = self ._grid_point (p1 [b ].unsqueeze (0 ))
295- idx2 = self ._grid_point (p2 [b ].unsqueeze (0 ))
296- path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
297- weights = [self .G .edges [path [k ], path [k + 1 ]]['weight' ] for k in range (len (path ) - 1 )]
298- mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
299- raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
300- coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
301- t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
294+ with torch .no_grad ():
295+ idx1 = self ._grid_point (p1 [b ].unsqueeze (0 ))
296+ idx2 = self ._grid_point (p2 [b ].unsqueeze (0 ))
297+ path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
298+ weights = [self .G .edges [path [k ], path [k + 1 ]]['weight' ] for k in range (len (path ) - 1 )]
299+ mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
300+ raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
301+ coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
302+ t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
302303
303304 curve [b ].fit (t , coordinates )
304305
0 commit comments