Skip to content

Commit 0354cc7

Browse files
authored
Fixes for generating alignments alongside pretranslations (#179)
1 parent 0ff0f8d commit 0354cc7

File tree

5 files changed

+17
-15
lines changed

5 files changed

+17
-15
lines changed

machine/jobs/build_nmt_engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def clearml_progress(status: ProgressStatus) -> None:
5555
except TypeError as e:
5656
raise TypeError(f"Build options could not be parsed: {e}") from e
5757
SETTINGS.update({model_type: build_options})
58+
if "align_pretranslations" in build_options:
59+
SETTINGS.update({"align_pretranslations": build_options["align_pretranslations"]})
5860
SETTINGS.data_dir = os.path.expanduser(cast(str, SETTINGS.data_dir))
5961

6062
logger.info(f"Config: {SETTINGS.as_dict()}")

machine/jobs/nmt_engine_build_job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ def _align(
157157
check_canceled()
158158

159159
for i in range(len(pretranslations)):
160-
pretranslations[i]["source_tokens"] = list(src_tokenized[i])
161-
pretranslations[i]["translation_tokens"] = list(trg_tokenized[i])
160+
pretranslations[i]["sourceTokens"] = list(src_tokenized[i])
161+
pretranslations[i]["translationTokens"] = list(trg_tokenized[i])
162162
pretranslations[i]["alignment"] = alignments[i]
163163

164164
return pretranslations

machine/jobs/translation_file_service.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ class PretranslationInfo(TypedDict):
1616
textId: str # noqa: N815
1717
refs: List[str]
1818
translation: str
19-
source_tokens: List[str]
20-
translation_tokens: List[str]
19+
sourceTokens: List[str] # noqa: N815
20+
translationTokens: List[str] # noqa: N815
2121
alignment: str
2222

2323

@@ -65,9 +65,9 @@ def generator() -> Generator[PretranslationInfo, None, None]:
6565
textId=pi["textId"],
6666
refs=list(pi["refs"]),
6767
translation=pi["translation"],
68-
source_tokens=list(pi["source_tokens"]),
69-
translation_tokens=list(pi["translation_tokens"]),
70-
alignment=pi["alignment"],
68+
sourceTokens=list(),
69+
translationTokens=list(),
70+
alignment="",
7171
)
7272

7373
return ContextManagedGenerator(generator())

tests/jobs/test_nmt_engine_build_job.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_run(decoy: Decoy) -> None:
3838
assert len(pretranslations) == 1
3939
assert pretranslations[0]["translation"] == "Please, I have booked a room."
4040
if is_eflomal_available():
41-
assert pretranslations[0]["source_tokens"] == [
41+
assert pretranslations[0]["sourceTokens"] == [
4242
"Por",
4343
"favor",
4444
",",
@@ -48,11 +48,11 @@ def test_run(decoy: Decoy) -> None:
4848
"habitación",
4949
".",
5050
]
51-
assert pretranslations[0]["translation_tokens"] == ["Please", ",", "I", "have", "booked", "a", "room", "."]
51+
assert pretranslations[0]["translationTokens"] == ["Please", ",", "I", "have", "booked", "a", "room", "."]
5252
assert len(pretranslations[0]["alignment"]) > 0
5353
else:
54-
assert pretranslations[0]["source_tokens"] == []
55-
assert pretranslations[0]["translation_tokens"] == []
54+
assert pretranslations[0]["sourceTokens"] == []
55+
assert pretranslations[0]["translationTokens"] == []
5656
assert len(pretranslations[0]["alignment"]) == 0
5757
decoy.verify(env.translation_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1)
5858

@@ -131,8 +131,8 @@ def __init__(self, decoy: Decoy) -> None:
131131
textId="text1",
132132
refs=["ref1"],
133133
translation="Por favor, tengo reservada una habitación.",
134-
source_tokens=[],
135-
translation_tokens=[],
134+
sourceTokens=[],
135+
translationTokens=[],
136136
alignment="",
137137
)
138138
]

tests/jobs/test_smt_engine_build_job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ def __init__(self, decoy: Decoy) -> None:
137137
textId="text1",
138138
refs=["ref1"],
139139
translation="Por favor, tengo reservada una habitación.",
140-
source_tokens=[],
141-
translation_tokens=[],
140+
sourceTokens=[],
141+
translationTokens=[],
142142
alignment="",
143143
)
144144
]

0 commit comments

Comments
 (0)