@@ -276,21 +276,30 @@ def connecting_geodesic(self, p1, p2, curve=None):
276276 curve input.
277277 """
278278 device = p1 .device
279- idx1 = self ._grid_point (p1 )
280- idx2 = self ._grid_point (p2 )
281- path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
282- weights = [self .G .edges [path [k ], path [k + 1 ]]['weight' ] for k in range (len (path ) - 1 )]
283- mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
284- raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
285- coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
286- t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
279+ if p1 .ndim == 1 :
280+ p1 = p1 .unsqueeze (0 ) # 1xD
281+ if p2 .ndim == 1 :
282+ p2 = p2 .unsqueeze (0 ) # 1xD
283+ B = p1 .shape [0 ]
284+ if p1 .shape != p2 .shape :
285+ raise NameError ('shape mismatch' )
287286
288287 if curve is None :
289288 curve = CubicSpline (p1 , p2 )
290289 else :
291290 curve .begin = p1
292291 curve .end = p2
293292
294- curve .fit (t , coordinates )
293+ for b in range (B ):
294+ idx1 = self ._grid_point (p1 [b ].unsqueeze (0 ))
295+ idx2 = self ._grid_point (p2 [b ].unsqueeze (0 ))
296+ path = nx .shortest_path (self .G , source = idx1 , target = idx2 , weight = 'weight' ) # list with N elements
297+ weights = [self .G .edges [path [k ], path [k + 1 ]]['weight' ] for k in range (len (path ) - 1 )]
298+ mesh = torch .meshgrid (* self .grid , indexing = 'ij' )
299+ raw_coordinates = [m .flatten ()[path [1 :- 1 ]].view (- 1 , 1 ) for m in mesh ]
300+ coordinates = torch .cat (raw_coordinates , dim = 1 ) # Nx(dim)
301+ t = torch .tensor (weights [:- 1 ], device = device ).cumsum (dim = 0 ) / sum (weights )
302+
303+ curve [b ].fit (t , coordinates )
295304
296305 return curve , True
0 commit comments