|
1 | 1 | import argparse |
2 | | -import json |
3 | 2 | import logging |
4 | | -import os |
5 | | -from typing import Callable, Optional, cast |
| 3 | +from typing import Callable, Optional |
6 | 4 |
|
7 | 5 | from clearml import Task |
8 | 6 |
|
9 | 7 | from ..utils.canceled_error import CanceledError |
10 | 8 | from ..utils.progress_status import ProgressStatus |
11 | | -from .build_clearml_helper import report_clearml_progress |
| 9 | +from .build_clearml_helper import report_clearml_progress, update_settings |
12 | 10 | from .config import SETTINGS |
13 | 11 | from .nmt_engine_build_job import NmtEngineBuildJob |
14 | 12 | from .nmt_model_factory import NmtModelFactory |
@@ -47,26 +45,11 @@ def clearml_progress(status: ProgressStatus) -> None: |
47 | 45 |
|
48 | 46 | try: |
49 | 47 | logger.info("NMT Engine Build Job started") |
50 | | - |
51 | | - SETTINGS.update(args) |
52 | | - model_type = cast(str, SETTINGS.model_type).lower() |
53 | | - if "build_options" in SETTINGS: |
54 | | - try: |
55 | | - build_options = json.loads(cast(str, SETTINGS.build_options)) |
56 | | - except ValueError as e: |
57 | | - raise ValueError("Build options could not be parsed: Invalid JSON") from e |
58 | | - except TypeError as e: |
59 | | - raise TypeError(f"Build options could not be parsed: {e}") from e |
60 | | - SETTINGS.update({model_type: build_options}) |
61 | | - if "align_pretranslations" in build_options: |
62 | | - SETTINGS.update({"align_pretranslations": build_options["align_pretranslations"]}) |
63 | | - SETTINGS.data_dir = os.path.expanduser(cast(str, SETTINGS.data_dir)) |
64 | | - |
65 | | - logger.info(f"Config: {SETTINGS.as_dict()}") |
| 48 | + update_settings(SETTINGS, args, task, logger) |
66 | 49 |
|
67 | 50 | translation_file_service = TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS) |
68 | 51 | nmt_model_factory: NmtModelFactory |
69 | | - if model_type == "huggingface": |
| 52 | + if SETTINGS.model_type == "huggingface": |
70 | 53 | from .huggingface.hugging_face_nmt_model_factory import HuggingFaceNmtModelFactory |
71 | 54 |
|
72 | 55 | nmt_model_factory = HuggingFaceNmtModelFactory(SETTINGS) |
|
0 commit comments