@@ -55,8 +55,8 @@ def __init__(
         self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)
         if isinstance(self._tokenizer, (NllbTokenizer, NllbTokenizerFast)):
             self._mpn = MosesPunctNormalizer()
-            self._mpn.substitutions = [
-                (str(re.compile(r)), sub)
+            self._mpn.substitutions = [  # type: ignore
+                (re.compile(r), sub)
                 for r, sub in self._mpn.substitutions
                 if isinstance(r, str) and isinstance(sub, str)
             ]
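
The removed line wrapped each compiled pattern back in `str(...)`, which yields the pattern's repr (e.g. `re.compile('\\s+')`) rather than the original regex string, so substitutions built that way could never match; the new code keeps the compiled `re.Pattern` objects, which `re.sub` accepts directly. A minimal sketch of the difference, assuming sacremoses is installed and that `MosesPunctNormalizer.normalize` applies its substitutions via `re.sub`:

import re

from sacremoses import MosesPunctNormalizer

# str() of a compiled pattern is its repr, not the pattern string:
print(str(re.compile(r"\s+")))  # re.compile('\\s+')

# Precompiling as the new code does: the patterns stay usable and each
# one is compiled once instead of on every normalize() call.
mpn = MosesPunctNormalizer()
mpn.substitutions = [
    (re.compile(r), sub)
    for r, sub in mpn.substitutions
    if isinstance(r, str) and isinstance(sub, str)
]
print(mpn.normalize("Hello   world !"))  # collapses runs of spaces, etc.
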
@@ -236,8 +236,12 @@ def _forward(self, model_inputs, **generate_kwargs):
 
         input_tokens = model_inputs["input_tokens"]
         del model_inputs["input_tokens"]
-        generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
-        generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
+        if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
+            config = self.model.generation_config
+        else:
+            config = self.model.config
+        generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length)
+        generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length)
         self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
         output = self.model.generate(
             **model_inputs,
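
This hunk takes the `min_length`/`max_length` defaults from `model.generation_config` when the model exposes one, falling back to `model.config`; newer transformers releases move generation defaults off the model config onto a dedicated `GenerationConfig`. A minimal standalone sketch of the same fallback (the NLLB checkpoint name is only an example, chosen to match the NLLB tokenizer handling above):

from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

# Prefer the dedicated GenerationConfig when present, falling back to the
# legacy model config for older checkpoints or transformers versions.
if getattr(model, "generation_config", None) is not None:
    config = model.generation_config
else:
    config = model.config

print(config.min_length, config.max_length)
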