@@ -307,6 +307,72 @@ def _jacobian(self, x: Tensor, val: Tensor) -> Tensor:
307307 return jac
308308
309309
310+ class MaxPool1d (AbstractJacobian , nn .MaxPool1d ):
311+ def forward (self , input : Tensor ):
312+ val , idx = F .max_pool1d (
313+ input , self .kernel_size , self .stride ,
314+ self .padding , self .dilation , self .ceil_mode ,
315+ return_indices = True
316+ )
317+ self .idx = idx
318+ return val
319+
320+ def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
321+ b , c1 , l1 = x .shape
322+ c2 , l2 = val .shape [1 :]
323+
324+ jac_in_orig_shape = jac_in .shape
325+ jac_in = jac_in .reshape (- 1 , l1 , * jac_in_orig_shape [3 :])
326+ arange_repeated = torch .repeat_interleave (torch .arange (b * c1 ), l2 ).long ()
327+ idx = self .idx .reshape (- 1 )
328+ jac_in = jac_in [arange_repeated , idx , :, :].reshape (* val .shape , * jac_in_orig_shape [3 :])
329+ return jac_in
330+
331+
332+ class MaxPool2d (AbstractJacobian , nn .MaxPool2d ):
333+ def forward (self , input : Tensor ):
334+ val , idx = F .max_pool2d (
335+ input , self .kernel_size , self .stride ,
336+ self .padding , self .dilation , self .ceil_mode ,
337+ return_indices = True
338+ )
339+ self .idx = idx
340+ return val
341+
342+ def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
343+ b , c1 , h1 , w1 = x .shape
344+ c2 , h2 , w2 = val .shape [1 :]
345+
346+ jac_in_orig_shape = jac_in .shape
347+ jac_in = jac_in .reshape (- 1 , h1 * w1 , * jac_in_orig_shape [4 :])
348+ arange_repeated = torch .repeat_interleave (torch .arange (b * c1 ), h2 * w2 ).long ()
349+ idx = self .idx .reshape (- 1 )
350+ jac_in = jac_in [arange_repeated , idx , :, :, :].reshape (* val .shape , * jac_in_orig_shape [4 :])
351+ return jac_in
352+
353+
354+ class MaxPool3d (AbstractJacobian , nn .MaxPool3d ):
355+ def forward (self , input : Tensor ):
356+ val , idx = F .max_pool3d (
357+ input , self .kernel_size , self .stride ,
358+ self .padding , self .dilation , self .ceil_mode ,
359+ return_indices = True
360+ )
361+ self .idx = idx
362+ return val
363+
364+ def _jacobian_mult (self , x : Tensor , val : Tensor , jac_in : Tensor ) -> Tensor :
365+ b , c1 , d1 , h1 , w1 = x .shape
366+ c2 , d2 , h2 , w2 = val .shape [1 :]
367+
368+ jac_in_orig_shape = jac_in .shape
369+ jac_in = jac_in .reshape (- 1 , d1 * h1 * w1 , * jac_in_orig_shape [5 :])
370+ arange_repeated = torch .repeat_interleave (torch .arange (b * c1 ), h2 * d2 * w2 ).long ()
371+ idx = self .idx .reshape (- 1 )
372+ jac_in = jac_in [arange_repeated , idx , :, :].reshape (* val .shape , * jac_in_orig_shape [5 :])
373+ return jac_in
374+
375+
310376class Sigmoid (AbstractActivationJacobian , nn .Sigmoid ):
311377 def _jacobian (self , x : Tensor , val : Tensor ) -> Tensor :
312378 jac = val * (1.0 - val )
0 commit comments