Commit 730ea67

fix deprecated tokenizer methods, add test cases (#139)
* fix deprecated tokenizer methods, add test cases
* make add_lang_code_to_tokenizer a public function again
* fix test import
1 parent 1d1dc25 commit 730ea67

File tree

4 files changed: +150 -36 lines changed

machine/translation/huggingface/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -8,6 +8,6 @@
 
 from .hugging_face_nmt_engine import HuggingFaceNmtEngine
 from .hugging_face_nmt_model import HuggingFaceNmtModel
-from .hugging_face_nmt_model_trainer import HuggingFaceNmtModelTrainer
+from .hugging_face_nmt_model_trainer import HuggingFaceNmtModelTrainer, add_lang_code_to_tokenizer
 
-__all__ = ["HuggingFaceNmtEngine", "HuggingFaceNmtModel", "HuggingFaceNmtModelTrainer"]
+__all__ = ["add_lang_code_to_tokenizer", "HuggingFaceNmtEngine", "HuggingFaceNmtModel", "HuggingFaceNmtModelTrainer"]

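With the re-export above, the helper is part of the package's public API again. A minimal import sketch (illustrative only, not part of this commit):

from machine.translation.huggingface import (
    HuggingFaceNmtEngine,
    HuggingFaceNmtModelTrainer,
    add_lang_code_to_tokenizer,
)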
machine/translation/huggingface/hugging_face_nmt_engine.py

Lines changed: 12 additions & 5 deletions

@@ -4,14 +4,15 @@
 import logging
 import re
 from math import exp, prod
-from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union, cast
+from typing import Iterable, List, Optional, Sequence, Tuple, Union, cast
 
 import torch  # pyright: ignore[reportMissingImports]
 from sacremoses import MosesPunctNormalizer
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
     AutoTokenizer,
+    M2M100Tokenizer,
     NllbTokenizer,
     NllbTokenizerFast,
     PreTrainedModel,
@@ -73,17 +74,23 @@ def __init__(
             self._pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
         else:
             additional_special_tokens = self._tokenizer.additional_special_tokens
+            if isinstance(self._tokenizer, M2M100Tokenizer):
+                src_lang_token = self._tokenizer.lang_code_to_token.get(src_lang) if src_lang is not None else None
+                tgt_lang_token = self._tokenizer.lang_code_to_token.get(tgt_lang) if tgt_lang is not None else None
+            else:
+                src_lang_token = src_lang
+                tgt_lang_token = tgt_lang
             if (
                 src_lang is not None
-                and src_lang not in cast(Any, self._tokenizer).lang_code_to_id
-                and src_lang not in additional_special_tokens
+                and src_lang_token not in self._tokenizer.added_tokens_encoder
+                and src_lang_token not in additional_special_tokens
             ):
                 raise ValueError(f"The specified model does not support the language code '{src_lang}'")
 
             if (
                 tgt_lang is not None
-                and tgt_lang not in cast(Any, self._tokenizer).lang_code_to_id
-                and tgt_lang not in additional_special_tokens
+                and tgt_lang_token not in self._tokenizer.added_tokens_encoder
+                and tgt_lang_token not in additional_special_tokens
             ):
                 raise ValueError(f"The specified model does not support the language code '{tgt_lang}'")
 

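For context: the engine no longer relies on the deprecated lang_code_to_id lookup when validating src_lang and tgt_lang. The sketch below restates that check as a standalone function; it is not part of the commit, and the function name supports_lang_code and the checkpoint id are illustrative placeholders.

from typing import Optional, Union

from transformers import AutoTokenizer, M2M100Tokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast


def supports_lang_code(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: Optional[str]) -> bool:
    # No language code requested means there is nothing to validate.
    if lang_code is None:
        return True
    # M2M-100 stores language codes as "__code__" tokens, so map the code to its
    # token form before looking it up; other tokenizers use the code directly.
    if isinstance(tokenizer, M2M100Tokenizer):
        lang_token = tokenizer.lang_code_to_token.get(lang_code)
    else:
        lang_token = lang_code
    # Same membership test the engine now performs instead of lang_code_to_id.
    return lang_token in tokenizer.added_tokens_encoder or lang_token in tokenizer.additional_special_tokens


tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
print(supports_lang_code(tokenizer, "eng_Latn"))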
machine/translation/huggingface/hugging_face_nmt_model_trainer.py

Lines changed: 27 additions & 18 deletions

@@ -24,6 +24,7 @@
     NllbTokenizer,
     NllbTokenizerFast,
     PreTrainedModel,
+    PreTrainedTokenizer,
     PreTrainedTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
@@ -218,24 +219,6 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
         if missing_tokens:
             tokenizer = add_tokens(tokenizer, missing_tokens)
 
-        def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
-            if lang_code in tokenizer.lang_code_to_id:
-                return
-            tokenizer.add_special_tokens(
-                {"additional_special_tokens": tokenizer.additional_special_tokens + [lang_code]}
-            )
-            lang_id = tokenizer.convert_tokens_to_ids(lang_code)
-            tokenizer.lang_code_to_id[lang_code] = lang_id
-
-            if isinstance(tokenizer, (MBart50Tokenizer, MBartTokenizer)):
-                tokenizer.id_to_lang_code[lang_id] = lang_code
-                tokenizer.fairseq_tokens_to_ids[lang_code] = lang_id
-                tokenizer.fairseq_ids_to_tokens[lang_id] = lang_code
-            elif isinstance(tokenizer, M2M100Tokenizer):
-                tokenizer.lang_code_to_token[lang_code] = lang_code
-                tokenizer.lang_token_to_id[lang_code] = lang_id
-                tokenizer.id_to_lang_token[lang_id] = lang_code
-
         if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
             logger.info("Add new language codes as tokens")
             if self._src_lang is not None:
@@ -413,3 +396,29 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
             if self._max_steps is None
             else ProgressStatus.from_step(state.global_step, self._max_steps)
         )
+
+
+def add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str):
+    if isinstance(tokenizer, M2M100Tokenizer):
+        lang_token = "__" + lang_code + "__"
+    else:
+        lang_token = lang_code
+
+    if lang_token in tokenizer.added_tokens_encoder:
+        return
+
+    tokenizer.add_special_tokens(
+        {"additional_special_tokens": tokenizer.additional_special_tokens + [lang_token]}  # type: ignore
+    )
+    lang_id = cast(int, tokenizer.convert_tokens_to_ids(lang_token))
+
+    if isinstance(tokenizer, (MBart50Tokenizer, MBartTokenizer)):
+        tokenizer.lang_code_to_id[lang_code] = lang_id
+        tokenizer.id_to_lang_code[lang_id] = lang_code
+        tokenizer.fairseq_tokens_to_ids[lang_code] = lang_id
+        tokenizer.fairseq_ids_to_tokens[lang_id] = lang_code
+    elif isinstance(tokenizer, M2M100Tokenizer):
+        tokenizer.lang_code_to_id[lang_code] = lang_id
+        tokenizer.lang_code_to_token[lang_code] = lang_token
+        tokenizer.lang_token_to_id[lang_token] = lang_id
+        tokenizer.id_to_lang_token[lang_id] = lang_token

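For context, a usage sketch of the now module-level helper, mirroring the new test cases below; the checkpoint id and language code come from the NLLB test, and the save path is a placeholder:

from transformers import NllbTokenizer

from machine.translation.huggingface import add_lang_code_to_tokenizer

# Register a custom language code on a pretrained NLLB tokenizer.
tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
add_lang_code_to_tokenizer(tokenizer, "new_lang")

# The code is now a known special token; calling the helper again is a no-op.
assert "new_lang" in tokenizer.added_tokens_encoder
add_lang_code_to_tokenizer(tokenizer, "new_lang")

# Saving and reloading preserves the added code (what the tests below verify).
tokenizer.save_pretrained("path/to/saved_tokenizer")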
tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py

Lines changed: 109 additions & 11 deletions

@@ -6,11 +6,22 @@
     skip("skipping Hugging Face tests on MacOS", allow_module_level=True)
 
 from tempfile import TemporaryDirectory
-
-from transformers import PreTrainedTokenizerFast, Seq2SeqTrainingArguments
+from typing import cast
+
+from transformers import (
+    M2M100Tokenizer,
+    MBart50Tokenizer,
+    MBart50TokenizerFast,
+    MBartTokenizer,
+    MBartTokenizerFast,
+    NllbTokenizer,
+    NllbTokenizerFast,
+    PreTrainedTokenizerFast,
+    Seq2SeqTrainingArguments,
+)
 
 from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow
-from machine.translation.huggingface import HuggingFaceNmtEngine, HuggingFaceNmtModelTrainer
+from machine.translation.huggingface import HuggingFaceNmtEngine, HuggingFaceNmtModelTrainer, add_lang_code_to_tokenizer
 
 
 def test_train_non_empty_corpus() -> None:
@@ -142,10 +153,8 @@ def test_update_tokenizer_missing_char() -> None:
            "Ḻ, ḻ, Ṉ, ॽ, " + "‌ and " + "‍" + " are new characters"
        )
        finetuned_result_nochar_composite = finetuned_engine_nochar.tokenizer.encode("Ḏ is a composite character")
-        normalized_result_nochar1 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str(
-            "‌ "
-        )
-        normalized_result_nochar2 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")
+        norm_result_nochar1 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str("‌ ")
+        norm_result_nochar2 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")
 
        with HuggingFaceNmtModelTrainer(
            "hf-internal-testing/tiny-random-nllb",
@@ -167,11 +176,11 @@ def test_update_tokenizer_missing_char() -> None:
            "Ḻ, ḻ, Ṉ, ॽ, " + "‌ and " + "‍" + " are new characters"
        )
        finetuned_result_char_composite = finetuned_engine_char.tokenizer.encode("Ḏ is a composite character")
-        normalized_result_char1 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‌ ")
-        normalized_result_char2 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")
+        norm_result_char1 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‌ ")
+        norm_result_char2 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")
 
-        assert normalized_result_nochar1 != normalized_result_char1
-        assert normalized_result_nochar2 != normalized_result_char2
+        assert norm_result_nochar1 != norm_result_char1
+        assert norm_result_nochar2 != norm_result_char2
 
        assert finetuned_result_nochar != finetuned_result_char
        assert finetuned_result_nochar_composite != finetuned_result_char_composite
@@ -467,5 +476,94 @@ def test_update_tokenizer_no_missing_char() -> None:
    assert finetuned_result_nochar == finetuned_result_char
 
 
+def test_nllb_tokenizer_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(NllbTokenizer, NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M"))
+        assert "new_lang" not in tokenizer.added_tokens_encoder
+        add_lang_code_to_tokenizer(tokenizer, "new_lang")
+        assert "new_lang" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(NllbTokenizer, NllbTokenizer.from_pretrained(temp_dir))
+        assert "new_lang" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_nllb_tokenizer_fast_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(NllbTokenizerFast, NllbTokenizerFast.from_pretrained("facebook/nllb-200-distilled-600M"))
+        assert "new_lang" not in tokenizer.added_tokens_encoder
+        add_lang_code_to_tokenizer(tokenizer, "new_lang")
+        assert "new_lang" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(NllbTokenizerFast, NllbTokenizerFast.from_pretrained(temp_dir))
+        assert "new_lang" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_mbart_tokenizer_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(MBartTokenizer, MBartTokenizer.from_pretrained("hf-internal-testing/tiny-random-nllb"))
+        assert "nl_NS" not in tokenizer.added_tokens_encoder
+        add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        assert "nl_NS" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(MBartTokenizer, MBartTokenizer.from_pretrained(temp_dir))
+        assert "nl_NS" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_mbart_tokenizer_fast_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(MBartTokenizerFast, MBartTokenizerFast.from_pretrained("hf-internal-testing/tiny-random-nllb"))
+        assert "nl_NS" not in tokenizer.added_tokens_encoder
+        add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        assert "nl_NS" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(MBartTokenizerFast, MBartTokenizerFast.from_pretrained(temp_dir))
+        assert "nl_NS" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_mbart_50_tokenizer_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(MBart50Tokenizer, MBart50Tokenizer.from_pretrained("hf-internal-testing/tiny-random-mbart50"))
+        assert "nl_NS" not in tokenizer.added_tokens_encoder
+        add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        assert "nl_NS" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(MBart50Tokenizer, MBart50Tokenizer.from_pretrained(temp_dir))
+        assert "nl_NS" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_mbart_50_tokenizer_fast_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(
+            MBart50TokenizerFast, MBart50TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-mbart50")
+        )
+        assert "nl_NS" not in tokenizer.added_tokens_encoder
+        add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        assert "nl_NS" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(MBart50TokenizerFast, MBart50TokenizerFast.from_pretrained(temp_dir))
+        assert "nl_NS" in new_tokenizer.added_tokens_encoder
+        return
+
+
+def test_m2m_100_tokenizer_add_lang_code() -> None:
+    with TemporaryDirectory() as temp_dir:
+        tokenizer = cast(M2M100Tokenizer, M2M100Tokenizer.from_pretrained("stas/tiny-m2m_100"))
+        assert "nc" not in tokenizer.lang_code_to_id
+        assert "__nc__" not in tokenizer.added_tokens_encoder
+        add_lang_code_to_tokenizer(tokenizer, "nc")
+        assert "nc" in tokenizer.lang_code_to_id
+        assert "__nc__" in tokenizer.added_tokens_encoder
+        tokenizer.save_pretrained(temp_dir)
+        new_tokenizer = cast(M2M100Tokenizer, M2M100Tokenizer.from_pretrained(temp_dir))
+        assert "nc" in tokenizer.lang_code_to_id
+        assert "__nc__" in new_tokenizer.added_tokens_encoder
+        return
+
+
 def _row(row_ref: int, text: str) -> TextRow:
     return TextRow("text1", row_ref, segment=[text])
