Skip to content

Commit 0fb9518

Browse files
authored
Fix moses punct (#140)
* Fix Moses punctuation normalization * max_length is not actually in the model config - it's in the generation_config!
1 parent 730ea67 commit 0fb9518

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

machine/translation/huggingface/hugging_face_nmt_engine.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ def __init__(
         self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)
         if isinstance(self._tokenizer, (NllbTokenizer, NllbTokenizerFast)):
             self._mpn = MosesPunctNormalizer()
-            self._mpn.substitutions = [
-                (str(re.compile(r)), sub)
+            self._mpn.substitutions = [  # type: ignore
+                (re.compile(r), sub)
                 for r, sub in self._mpn.substitutions
                 if isinstance(r, str) and isinstance(sub, str)
             ]
@@ -236,8 +236,12 @@ def _forward(self, model_inputs, **generate_kwargs):
 
         input_tokens = model_inputs["input_tokens"]
         del model_inputs["input_tokens"]
-        generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
-        generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
+        if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
+            config = self.model.generation_config
+        else:
+            config = self.model.config
+        generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length)
+        generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length)
         self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
         output = self.model.generate(
             **model_inputs,

machine/translation/huggingface/hugging_face_nmt_model_trainer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,8 @@ def __init__(
         self._add_unk_src_tokens = add_unk_src_tokens
         self._add_unk_tgt_tokens = add_unk_tgt_tokens
         self._mpn = MosesPunctNormalizer()
-        self._mpn.substitutions = [
-            (str(re.compile(r)), sub)
-            for r, sub in self._mpn.substitutions
-            if isinstance(r, str) and isinstance(sub, str)
+        self._mpn.substitutions = [  # type: ignore
+            (re.compile(r), sub) for r, sub in self._mpn.substitutions if isinstance(r, str) and isinstance(sub, str)
         ]
         self._stats = TrainStats()
 
0 commit comments

Comments
 (0)