Skip to content

Commit ef180a3

Browse files
authored
Allow output_attentions to be set to False (#223)
1 parent 60417c4 commit ef180a3

File tree

3 files changed

+105
-68
lines changed

3 files changed

+105
-68
lines changed

machine/translation/huggingface/hugging_face_nmt_engine.py

Lines changed: 78 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import re
66
from math import exp, prod
7-
from typing import Iterable, List, Optional, Sequence, Tuple, Union, cast
7+
from typing import Collection, Iterable, List, Optional, Sequence, Tuple, Union, cast
88

99
import torch # pyright: ignore[reportMissingImports]
1010
from sacremoses import MosesPunctNormalizer
@@ -24,6 +24,7 @@
2424
from transformers.tokenization_utils import BatchEncoding, TruncationStrategy
2525

2626
from ...annotations.range import Range
27+
from ...corpora.aligned_word_pair import AlignedWordPair
2728
from ...utils.typeshed import StrPath
2829
from ..translation_engine import TranslationEngine
2930
from ..translation_result import TranslationResult
@@ -163,10 +164,11 @@ def _try_translate_n_batch(
163164
builder = TranslationResultBuilder(input_tokens)
164165
for token, score in zip(output["translation_tokens"], output["token_scores"]):
165166
builder.append_token(token, TranslationSources.NMT, exp(score))
166-
src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
167-
wa_matrix = WordAlignmentMatrix.from_word_pairs(
168-
len(input_tokens), output_length, set(zip(src_indices, range(output_length)))
169-
)
167+
word_pairs: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None
168+
if output.get("token_attentions") is not None:
169+
src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
170+
word_pairs = set(zip(src_indices, range(output_length)))
171+
wa_matrix = WordAlignmentMatrix.from_word_pairs(len(input_tokens), output_length, word_pairs)
170172
builder.mark_phrase(Range.create(0, len(input_tokens)), wa_matrix)
171173
segment_results.append(builder.to_result(output["translation_text"]))
172174
all_results.append(segment_results)
@@ -242,12 +244,12 @@ def _forward(self, model_inputs, **generate_kwargs):
242244
config = self.model.config
243245
generate_kwargs["min_length"] = generate_kwargs.get("min_length", config.min_length)
244246
generate_kwargs["max_length"] = generate_kwargs.get("max_length", config.max_length)
247+
generate_kwargs["output_attentions"] = generate_kwargs.get("output_attentions", True)
245248
self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
246249
output = self.model.generate(
247250
**model_inputs,
248251
**generate_kwargs,
249252
output_scores=True,
250-
output_attentions=True,
251253
return_dict_in_generate=True,
252254
)
253255

@@ -285,36 +287,39 @@ def _forward(self, model_inputs, **generate_kwargs):
285287
if self.model.config.decoder_start_token_id is not None:
286288
scores = torch.cat((torch.zeros(scores.shape[0], scores.shape[1], 1, device=scores.device), scores), dim=2)
287289

288-
assert attentions is not None
289-
num_heads = attentions[0][0].shape[1]
290-
indices = torch.stack(
291-
(
292-
torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(in_b, n_sequences, -1),
293-
torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
294-
),
295-
dim=3,
296-
)
297-
num_layers = len(attentions[0])
298-
layer = (2 * num_layers) // 3
299-
attentions = (
300-
torch.stack([cast(Tuple[torch.FloatTensor, ...], a)[layer][:, :, -1, :] for a in attentions], dim=0)
301-
.squeeze()
302-
.reshape(len(attentions), in_b, num_beams, num_heads, -1)
303-
.transpose(0, 1)
304-
)
305-
attentions = torch.mean(attentions, dim=3)
306-
attentions = torch_gather_nd(attentions, indices, 1)
307-
if self.model.config.decoder_start_token_id is not None:
308-
attentions = torch.cat(
290+
if generate_kwargs["output_attentions"] is True:
291+
assert attentions is not None
292+
num_heads = attentions[0][0].shape[1]
293+
indices = torch.stack(
309294
(
310-
torch.zeros(
311-
(attentions.shape[0], attentions.shape[1], 1, attentions.shape[3]),
312-
device=attentions.device,
295+
torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(
296+
in_b, n_sequences, -1
313297
),
314-
attentions,
298+
torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
315299
),
316-
dim=2,
300+
dim=3,
317301
)
302+
num_layers = len(attentions[0])
303+
layer = (2 * num_layers) // 3
304+
attentions = (
305+
torch.stack([cast(Tuple[torch.FloatTensor, ...], a)[layer][:, :, -1, :] for a in attentions], dim=0)
306+
.squeeze()
307+
.reshape(len(attentions), in_b, num_beams, num_heads, -1)
308+
.transpose(0, 1)
309+
)
310+
attentions = torch.mean(attentions, dim=3)
311+
attentions = torch_gather_nd(attentions, indices, 1)
312+
if self.model.config.decoder_start_token_id is not None:
313+
attentions = torch.cat(
314+
(
315+
torch.zeros(
316+
(attentions.shape[0], attentions.shape[1], 1, attentions.shape[3]),
317+
device=attentions.device,
318+
),
319+
attentions,
320+
),
321+
dim=2,
322+
)
318323

319324
output_ids = output_ids.reshape(in_b, n_sequences, *output_ids.shape[1:])
320325
return {
@@ -339,37 +344,55 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
339344
input_tokens = model_outputs["input_tokens"][0]
340345

341346
records = []
342-
output_ids: torch.Tensor
343-
scores: torch.Tensor
344-
attentions: torch.Tensor
345-
for output_ids, scores, attentions in zip(
346-
model_outputs["output_ids"][0],
347-
model_outputs["scores"][0],
348-
model_outputs["attentions"][0],
349-
):
347+
348+
has_attentions = model_outputs.get("attentions") is not None and model_outputs["attentions"][0] is not None
349+
if has_attentions:
350+
zipped = zip(
351+
model_outputs["output_ids"][0],
352+
model_outputs["scores"][0],
353+
model_outputs["attentions"][0],
354+
)
355+
else:
356+
zipped = zip(
357+
model_outputs["output_ids"][0],
358+
model_outputs["scores"][0],
359+
)
360+
361+
for item in zipped:
362+
if has_attentions:
363+
output_ids, scores, attentions = cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], item)
364+
else:
365+
output_ids, scores = cast(Tuple[torch.Tensor, torch.Tensor], item)
366+
attentions = None
367+
350368
output_tokens: List[str] = []
351369
output_indices: List[int] = []
352370
for i, output_id in enumerate(output_ids):
353371
id = cast(int, output_id.item())
354372
if id not in all_special_ids:
355373
output_tokens.append(self.tokenizer.convert_ids_to_tokens(id))
356374
output_indices.append(i)
375+
357376
scores = scores[output_indices]
358-
attentions = attentions[output_indices]
359-
attentions = attentions[:, input_indices]
360-
records.append(
361-
{
362-
"input_tokens": input_tokens,
363-
"translation_tokens": output_tokens,
364-
"token_scores": scores,
365-
"token_attentions": attentions,
366-
"translation_text": self.tokenizer.decode(
367-
output_ids,
368-
skip_special_tokens=True,
369-
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
370-
),
371-
}
372-
)
377+
378+
record = {
379+
"input_tokens": input_tokens,
380+
"translation_tokens": output_tokens,
381+
"token_scores": scores,
382+
"translation_text": self.tokenizer.decode(
383+
output_ids,
384+
skip_special_tokens=True,
385+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
386+
),
387+
}
388+
389+
if attentions is not None:
390+
attentions = attentions[output_indices]
391+
attentions = attentions[:, input_indices]
392+
record["token_attentions"] = attentions
393+
394+
records.append(record)
395+
373396
return records
374397

375398

machine/translation/word_alignment_matrix.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ def from_word_pairs(
2323
cls,
2424
row_count: int,
2525
column_count: int,
26-
set_values: Collection[Union[AlignedWordPair, Tuple[int, int]]] = set(),
26+
set_values: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None,
2727
) -> WordAlignmentMatrix:
28+
if set_values is None:
29+
set_values = set()
2830
matrix = np.full((row_count, column_count), False)
2931
for i, j in set_values:
3032
matrix[i, j] = True

tests/translation/huggingface/test_hugging_face_nmt_engine.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,39 +5,51 @@
55

66
skip("skipping Hugging Face tests on MacOS", allow_module_level=True)
77

8-
from pytest import approx, raises
8+
from pytest import approx, mark, raises
99

1010
from machine.translation.huggingface import HuggingFaceNmtEngine
1111

1212

13-
def test_translate_n_batch_beam() -> None:
14-
with HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="en", tgt_lang="es", num_beams=2, max_length=10) as engine:
13+
@mark.parametrize("output_attentions", [True, False])
14+
def test_translate_n_batch_beam(output_attentions: bool) -> None:
15+
with HuggingFaceNmtEngine(
16+
"stas/tiny-m2m_100",
17+
src_lang="en",
18+
tgt_lang="es",
19+
num_beams=2,
20+
max_length=10,
21+
output_attentions=output_attentions,
22+
) as engine:
1523
results = engine.translate_n_batch(
1624
n=2,
1725
segments=["This is a test string", "Hello, world!"],
1826
)
1927
assert results[0][0].translation == "skaberskaber Dollar Dollar ፤ ፤ gerekir gerekir"
2028
assert results[0][0].confidences[0] == approx(1.08e-05, 0.01)
21-
assert str(results[0][0].alignment) == "2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7"
29+
assert str(results[0][0].alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
2230
assert results[0][1].translation == "skaberskaber Dollar Dollar ፤ ፤ ፤ gerekir"
2331
assert results[0][1].confidences[0] == approx(1.08e-05, 0.01)
24-
assert str(results[0][1].alignment) == "2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7"
32+
assert str(results[0][1].alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
2533
assert results[1][0].translation == "skaberskaber Dollar Dollar ፤ ፤ gerekir gerekir"
2634
assert results[1][0].confidences[0] == approx(1.08e-05, 0.01)
27-
assert str(results[1][0].alignment) == "0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6"
35+
assert str(results[1][0].alignment) == ("0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" if output_attentions else "")
2836
assert results[1][1].translation == "skaberskaber Dollar Dollar ፤ ፤ ፤ gerekir"
2937
assert results[1][1].confidences[0] == approx(1.08e-05, 0.01)
30-
assert str(results[1][1].alignment) == "0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6"
38+
assert str(results[1][1].alignment) == ("0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" if output_attentions else "")
3139

3240

33-
def test_translate_greedy() -> None:
34-
with HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="en", tgt_lang="es", max_length=10) as engine:
41+
@mark.parametrize("output_attentions", [True, False])
42+
def test_translate_greedy(output_attentions: bool) -> None:
43+
with HuggingFaceNmtEngine(
44+
"stas/tiny-m2m_100", src_lang="en", tgt_lang="es", max_length=10, output_attentions=output_attentions
45+
) as engine:
3546
result = engine.translate("This is a test string")
3647
assert result.translation == "skaberskaber Dollar Dollar Dollar ፤ gerekir gerekir"
3748
assert result.confidences[0] == approx(1.08e-05, 0.01)
38-
assert str(result.alignment) == "2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7"
49+
assert str(result.alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
3950

4051

41-
def test_construct_invalid_lang() -> None:
52+
@mark.parametrize("output_attentions", [True, False])
53+
def test_construct_invalid_lang(output_attentions: bool) -> None:
4254
with raises(ValueError):
43-
HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="qaa", tgt_lang="es")
55+
HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="qaa", tgt_lang="es", output_attentions=output_attentions)

0 commit comments

Comments
 (0)