@@ -4,7 +4,7 @@
 import logging
 import re
 from math import exp, prod
-from typing import Iterable, List, Optional, Sequence, Tuple, Union, cast
+from typing import Collection, Iterable, List, Optional, Sequence, Tuple, Union, cast

 import torch  # pyright: ignore[reportMissingImports]
 from sacremoses import MosesPunctNormalizer
@@ -24,6 +24,7 @@
 from transformers.tokenization_utils import BatchEncoding, TruncationStrategy

 from ...annotations.range import Range
+from ...corpora.aligned_word_pair import AlignedWordPair
 from ...utils.typeshed import StrPath
 from ..translation_engine import TranslationEngine
 from ..translation_result import TranslationResult
@@ -163,10 +164,11 @@ def _try_translate_n_batch(
                 builder = TranslationResultBuilder(input_tokens)
                 for token, score in zip(output["translation_tokens"], output["token_scores"]):
                     builder.append_token(token, TranslationSources.NMT, exp(score))
-                src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
-                wa_matrix = WordAlignmentMatrix.from_word_pairs(
-                    len(input_tokens), output_length, set(zip(src_indices, range(output_length)))
-                )
+                word_pairs: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None
+                if output.get("token_attentions") is not None:
+                    src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
+                    word_pairs = set(zip(src_indices, range(output_length)))
+                wa_matrix = WordAlignmentMatrix.from_word_pairs(len(input_tokens), output_length, word_pairs)
                 builder.mark_phrase(Range.create(0, len(input_tokens)), wa_matrix)
                 segment_results.append(builder.to_result(output["translation_text"]))
             all_results.append(segment_results)
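
Context for the hunk above: when attentions are present, the alignment step takes, for each target token, the source position with the highest cross-attention weight. A minimal standalone sketch in plain PyTorch, using a made-up attention matrix rather than real model output:

import torch

# Made-up cross-attention weights for one hypothesis:
# rows are target tokens, columns are source tokens.
token_attentions = torch.tensor(
    [
        [0.7, 0.2, 0.1],  # target token 0 mostly attends to source 0
        [0.1, 0.8, 0.1],  # target token 1 mostly attends to source 1
        [0.2, 0.3, 0.5],  # target token 2 mostly attends to source 2
    ]
)

output_length = token_attentions.shape[0]
# Most-attended source index per target token.
src_indices = torch.argmax(token_attentions, dim=1).tolist()
# (source, target) pairs, the same data the code above hands to
# WordAlignmentMatrix.from_word_pairs when attentions are available.
word_pairs = set(zip(src_indices, range(output_length)))
print(sorted(word_pairs))  # [(0, 0), (1, 1), (2, 2)]
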
@@ -242,12 +244,12 @@ def _forward(self, model_inputs, **generate_kwargs):
         config = self.model.config
         generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length)
         generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length)
+        generate_kwargs["output_attentions"] = generate_kwargs.get("output_attentions", True)
         self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
         output = self.model.generate(
             **model_inputs,
             **generate_kwargs,
             output_scores=True,
-            output_attentions=True,
             return_dict_in_generate=True,
         )

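
The new default uses the same dict-get idiom already applied to min_length and max_length, so the later generate_kwargs["output_attentions"] lookup always finds a key and a caller can still opt out explicitly. A tiny illustration with a hypothetical helper, not part of the pipeline:

def apply_defaults(**generate_kwargs):
    # Keep an explicit value if the caller supplied one,
    # otherwise default output_attentions to True.
    generate_kwargs["output_attentions"] = generate_kwargs.get("output_attentions", True)
    return generate_kwargs


print(apply_defaults())                         # {'output_attentions': True}
print(apply_defaults(output_attentions=False))  # {'output_attentions': False}
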
@@ -285,36 +287,39 @@ def _forward(self, model_inputs, **generate_kwargs):
         if self.model.config.decoder_start_token_id is not None:
             scores = torch.cat((torch.zeros(scores.shape[0], scores.shape[1], 1, device=scores.device), scores), dim=2)

-        assert attentions is not None
-        num_heads = attentions[0][0].shape[1]
-        indices = torch.stack(
-            (
-                torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(in_b, n_sequences, -1),
-                torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
-            ),
-            dim=3,
-        )
-        num_layers = len(attentions[0])
-        layer = (2 * num_layers) // 3
-        attentions = (
-            torch.stack([cast(Tuple[torch.FloatTensor, ...], a)[layer][:, :, -1, :] for a in attentions], dim=0)
-            .squeeze()
-            .reshape(len(attentions), in_b, num_beams, num_heads, -1)
-            .transpose(0, 1)
-        )
-        attentions = torch.mean(attentions, dim=3)
-        attentions = torch_gather_nd(attentions, indices, 1)
-        if self.model.config.decoder_start_token_id is not None:
-            attentions = torch.cat(
+        if generate_kwargs["output_attentions"] is True:
+            assert attentions is not None
+            num_heads = attentions[0][0].shape[1]
+            indices = torch.stack(
                 (
-                    torch.zeros(
-                        (attentions.shape[0], attentions.shape[1], 1, attentions.shape[3]),
-                        device=attentions.device,
+                    torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(
+                        in_b, n_sequences, -1
                     ),
-                    attentions,
+                    torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
                 ),
-                dim=2,
+                dim=3,
             )
+            num_layers = len(attentions[0])
+            layer = (2 * num_layers) // 3
+            attentions = (
+                torch.stack([cast(Tuple[torch.FloatTensor, ...], a)[layer][:, :, -1, :] for a in attentions], dim=0)
+                .squeeze()
+                .reshape(len(attentions), in_b, num_beams, num_heads, -1)
+                .transpose(0, 1)
+            )
+            attentions = torch.mean(attentions, dim=3)
+            attentions = torch_gather_nd(attentions, indices, 1)
+            if self.model.config.decoder_start_token_id is not None:
+                attentions = torch.cat(
+                    (
+                        torch.zeros(
+                            (attentions.shape[0], attentions.shape[1], 1, attentions.shape[3]),
+                            device=attentions.device,
+                        ),
+                        attentions,
+                    ),
+                    dim=2,
+                )

         output_ids = output_ids.reshape(in_b, n_sequences, *output_ids.shape[1:])
         return {
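
The attention handling above is mostly shape bookkeeping: pick a decoder layer roughly two thirds of the way up, keep only the newest target position per generation step, stack the steps, and average over heads. A rough sketch on synthetic tensors whose shapes are assumed to mirror the per-step cross-attentions from generate; the beam reordering done with torch_gather_nd is omitted:

import torch

# Synthetic stand-in: one entry per generated token, each a tuple with one
# tensor per decoder layer of shape (batch * num_beams, num_heads, 1, source_len).
num_steps, num_layers, batch_beams, num_heads, source_len = 4, 6, 2, 8, 5
attentions = tuple(
    tuple(torch.rand(batch_beams, num_heads, 1, source_len) for _ in range(num_layers))
    for _ in range(num_steps)
)

layer = (2 * num_layers) // 3  # same layer choice as the code above
# Keep the attention of the newest target token at every step and stack the steps.
stacked = torch.stack([step[layer][:, :, -1, :] for step in attentions], dim=0)
print(stacked.shape)         # torch.Size([4, 2, 8, 5]) -> (steps, batch*beams, heads, source)

# Averaging over heads leaves one (target_len, source_len) matrix per hypothesis.
per_hypothesis = stacked.mean(dim=2).permute(1, 0, 2)
print(per_hypothesis.shape)  # torch.Size([2, 4, 5])
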
@@ -339,37 +344,55 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
         input_tokens = model_outputs["input_tokens"][0]

         records = []
-        output_ids: torch.Tensor
-        scores: torch.Tensor
-        attentions: torch.Tensor
-        for output_ids, scores, attentions in zip(
-            model_outputs["output_ids"][0],
-            model_outputs["scores"][0],
-            model_outputs["attentions"][0],
-        ):
+
+        has_attentions = model_outputs.get("attentions") is not None and model_outputs["attentions"][0] is not None
+        if has_attentions:
+            zipped = zip(
+                model_outputs["output_ids"][0],
+                model_outputs["scores"][0],
+                model_outputs["attentions"][0],
+            )
+        else:
+            zipped = zip(
+                model_outputs["output_ids"][0],
+                model_outputs["scores"][0],
+            )
+
+        for item in zipped:
+            if has_attentions:
+                output_ids, scores, attentions = cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], item)
+            else:
+                output_ids, scores = cast(Tuple[torch.Tensor, torch.Tensor], item)
+                attentions = None
+
             output_tokens: List[str] = []
             output_indices: List[int] = []
             for i, output_id in enumerate(output_ids):
                 id = cast(int, output_id.item())
                 if id not in all_special_ids:
                     output_tokens.append(self.tokenizer.convert_ids_to_tokens(id))
                     output_indices.append(i)
+
             scores = scores[output_indices]
-            attentions = attentions[output_indices]
-            attentions = attentions[:, input_indices]
-            records.append(
-                {
-                    "input_tokens": input_tokens,
-                    "translation_tokens": output_tokens,
-                    "token_scores": scores,
-                    "token_attentions": attentions,
-                    "translation_text": self.tokenizer.decode(
-                        output_ids,
-                        skip_special_tokens=True,
-                        clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-                    ),
-                }
-            )
+
+            record = {
+                "input_tokens": input_tokens,
+                "translation_tokens": output_tokens,
+                "token_scores": scores,
+                "translation_text": self.tokenizer.decode(
+                    output_ids,
+                    skip_special_tokens=True,
+                    clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                ),
+            }
+
+            if attentions is not None:
+                attentions = attentions[output_indices]
+                attentions = attentions[:, input_indices]
+                record["token_attentions"] = attentions
+
+            records.append(record)
+
         return records

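
postprocess now attaches token_attentions only when the forward pass produced them; the rest of the record is built as before. A small sketch of the special-token filtering and optional attention slicing, with hypothetical ids and shapes:

import torch

# Hypothetical hypothesis: ids 0 and 2 are special tokens (e.g. BOS/EOS)
# and should not contribute tokens, scores, or attention rows.
all_special_ids = {0, 2}
output_ids = torch.tensor([0, 17, 23, 2])
scores = torch.tensor([0.0, -0.1, -0.3, -0.05])
attentions = torch.rand(4, 6)  # (target_len, source_len); may also be None

output_indices = [i for i, token_id in enumerate(output_ids.tolist()) if token_id not in all_special_ids]
record = {"token_scores": scores[output_indices]}
if attentions is not None:
    # Only attach attentions when they were actually computed.
    record["token_attentions"] = attentions[output_indices]

print(record["token_scores"])            # tensor([-0.1000, -0.3000])
print(record["token_attentions"].shape)  # torch.Size([2, 6])
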