1- from typing import Callable , Generator , Optional , Union
1+ from typing import Callable , Generator , Iterable , Optional , Union
22
3+ from ..corpora .corpora_utils import batch
34from ..corpora .parallel_text_corpus import ParallelTextCorpus
45from ..corpora .parallel_text_row import ParallelTextRow
56from ..utils .progress_status import ProgressStatus
@@ -48,11 +49,11 @@ def is_source_tokenized(self) -> bool:
4849 def is_target_tokenized (self ) -> bool :
4950 return self ._corpus .is_target_tokenized
5051
51- def _get_rows (self ) -> Generator [ParallelTextRow , None , None ]:
52- with self ._corpus .batch ( self . _batch_size ) as batches :
53- for batch in batches :
54- alignments = self ._aligner .align_batch (batch )
55- for row , alignment in zip (batch , alignments ):
52+ def _get_rows (self , text_ids : Optional [ Iterable [ str ]] = None ) -> Generator [ParallelTextRow , None , None ]:
53+ with self ._corpus .get_rows ( text_ids ) as rows :
54+ for row_batch in batch ( rows , self . _batch_size ) :
55+ alignments = self ._aligner .align_batch (row_batch )
56+ for row , alignment in zip (row_batch , alignments ):
5657 known_alignment = WordAlignmentMatrix .from_parallel_text_row (row )
5758 if known_alignment is not None :
5859 known_alignment .priority_symmetrize_with (alignment )
@@ -78,12 +79,12 @@ def is_source_tokenized(self) -> bool:
7879 def is_target_tokenized (self ) -> bool :
7980 return self ._corpus .is_target_tokenized
8081
81- def _get_rows (self ) -> Generator [ParallelTextRow , None , None ]:
82- with self ._corpus .batch ( self . _batch_size ) as batches :
83- for batch in batches :
82+ def _get_rows (self , text_ids : Optional [ Iterable [ str ]] = None ) -> Generator [ParallelTextRow , None , None ]:
83+ with self ._corpus .get_rows ( text_ids ) as rows :
84+ for row_batch in batch ( rows , self . _batch_size ) :
8485 translations = self ._translation_engine .translate_batch (
85- [r .source_segment if self .is_source_tokenized else r .source_text for r in batch ]
86+ [r .source_segment if self .is_source_tokenized else r .source_text for r in row_batch ]
8687 )
87- for row , translation in zip (batch , translations ):
88+ for row , translation in zip (row_batch , translations ):
8889 row .target_segment = translation .target_tokens
8990 yield row
0 commit comments