diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6313617e..28b63c37 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,8 +8,8 @@ repos: hooks: - id: end-of-file-fixer - id: trailing-whitespace - - repo: https://github.com/psf/black - rev: 25.1.0 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 26.1.0 hooks: - id: black args: [--line-length=128, --extend-exclude=.ipynb, --verbose] @@ -20,12 +20,12 @@ repos: additional_dependencies: [pycodestyle>=2.11.0] args: [--max-line-length=128, '--exclude=./.*,build,dist', '--ignore=E501,W503,E203,F841,E231,W604', --count, --statistics, --show-source] - repo: https://github.com/pycqa/isort - rev: 6.0.1 + rev: 7.0.0 hooks: - id: isort args: [--profile=black, --line-length=128] - repo: https://github.com/kynan/nbstripout - rev: 0.8.1 + rev: 0.9.0 hooks: - id: nbstripout - repo: https://github.com/nbQA-dev/nbQA diff --git a/benchmarks/train_crnn_cinc2020/trainer.py b/benchmarks/train_crnn_cinc2020/trainer.py index 3cb80060..f9804d0e 100644 --- a/benchmarks/train_crnn_cinc2020/trainer.py +++ b/benchmarks/train_crnn_cinc2020/trainer.py @@ -245,8 +245,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: head_labels = all_labels[:head_num, ...] head_labels_classes = [np.array(classes)[np.where(row)] for row in head_labels] for n in range(head_num): - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" ---------------------------------------------- scalar prediction: {[round(n, 3) for n in head_scalar_preds[n].tolist()]} binary prediction: {head_bin_preds[n].tolist()} @@ -254,8 +253,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: predicted classes: {head_preds_classes[n].tolist()} label classes: {head_labels_classes[n].tolist()} ---------------------------------------------- - """ - ) + """) self.log_manager.log_message(msg) ( diff --git a/benchmarks/train_crnn_cinc2021/dataset.py b/benchmarks/train_crnn_cinc2021/dataset.py index 30e19bbb..ad0fc25e 100644 --- a/benchmarks/train_crnn_cinc2021/dataset.py +++ b/benchmarks/train_crnn_cinc2021/dataset.py @@ -314,14 +314,10 @@ def _train_test_split(self, train_ratio: float = 0.8, force_recompute: bool = Fa ) train_file.write_text(json.dumps(train_set, ensure_ascii=False)) test_file.write_text(json.dumps(test_set, ensure_ascii=False)) - print( - textwrap.dedent( - f""" + print(textwrap.dedent(f""" train set saved to \n\042{str(train_file)}\042 test set saved to \n\042{str(test_file)}\042 - """ - ) - ) + """)) else: train_set = json.loads(train_file.read_text()) test_set = json.loads(test_file.read_text()) @@ -362,16 +358,12 @@ def _check_train_test_split_validity(self, train_set: List[str], test_set: List[ test_classes = set(list_sum([self.reader.get_labels(rec, fmt="a") for rec in test_set])) test_classes.intersection_update(all_classes) is_valid = len(all_classes) == len(train_classes) == len(test_classes) - print( - textwrap.dedent( - f""" + print(textwrap.dedent(f""" all_classes: {all_classes} train_classes: {train_classes} test_classes: {test_classes} is_valid: {is_valid} - """ - ) - ) + """)) return is_valid def persistence(self) -> None: diff --git a/benchmarks/train_crnn_cinc2021/trainer.py b/benchmarks/train_crnn_cinc2021/trainer.py index ba548ae7..bb348d2e 100644 --- a/benchmarks/train_crnn_cinc2021/trainer.py +++ b/benchmarks/train_crnn_cinc2021/trainer.py @@ -233,8 +233,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: head_labels = all_labels[:head_num, ...] head_labels_classes = [np.array(classes)[np.where(row)] for row in head_labels] for n in range(head_num): - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" ---------------------------------------------- scalar prediction: {[round(n, 3) for n in head_scalar_preds[n].tolist()]} binary prediction: {head_bin_preds[n].tolist()} @@ -242,8 +241,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: predicted classes: {head_preds_classes[n].tolist()} label classes: {head_labels_classes[n].tolist()} ---------------------------------------------- - """ - ) + """) self.log_manager.log_message(msg) ( diff --git a/benchmarks/train_crnn_cinc2023/trainer.py b/benchmarks/train_crnn_cinc2023/trainer.py index a44095f0..749dd20b 100644 --- a/benchmarks/train_crnn_cinc2023/trainer.py +++ b/benchmarks/train_crnn_cinc2023/trainer.py @@ -346,8 +346,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: ] log_head_num = min(log_head_num, len(head_scalar_preds)) for n in range(log_head_num): - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" ---------------------------------------------- cpc scalar prediction: {[round(item, 3) for item in head_scalar_preds[n].tolist()]} cpc binary prediction: {head_bin_preds[n].tolist()} @@ -355,8 +354,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: cpc predicted classes: {head_preds_classes[n].tolist()} cpc label classes: {head_labels_classes[n].tolist()} ---------------------------------------------- - """ - ) + """) self.log_manager.log_message(msg) eval_res = compute_challenge_metrics( diff --git a/benchmarks/train_hybrid_cpsc2020/trainer.py b/benchmarks/train_hybrid_cpsc2020/trainer.py index 9fe5cbe7..7b935306 100644 --- a/benchmarks/train_hybrid_cpsc2020/trainer.py +++ b/benchmarks/train_hybrid_cpsc2020/trainer.py @@ -162,8 +162,7 @@ def train( # max_itr = n_epochs * n_train - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" Starting training: ------------------ Epochs: {n_epochs} @@ -174,8 +173,7 @@ def train( Device: {device.type} Optimizer: {config.train_optimizer} ----------------------------------------- - """ - ) + """) # print(msg) # in case no logger if logger: logger.info(msg) @@ -351,20 +349,17 @@ def train( scheduler.step() if debug: - eval_train_msg = textwrap.dedent( - f""" + eval_train_msg = textwrap.dedent(f""" train/auroc: {eval_train_res[0]} train/auprc: {eval_train_res[1]} train/accuracy: {eval_train_res[2]} train/f_measure: {eval_train_res[3]} train/f_beta_measure: {eval_train_res[4]} train/g_beta_measure: {eval_train_res[5]} - """ - ) + """) else: eval_train_msg = "" - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" Train epoch_{epoch + 1}: -------------------- train/epoch_loss: {epoch_loss}{eval_train_msg} @@ -375,8 +370,7 @@ def train( test/f_beta_measure: {eval_res[4]} test/g_beta_measure: {eval_res[5]} --------------------------------- - """ - ) + """) elif config.model_name == "seq_lab": eval_res = evaluate_seq_lab(model, val_loader, config, device, debug) model.train() @@ -403,8 +397,7 @@ def train( scheduler.step() if debug: - eval_train_msg = textwrap.dedent( - f""" + eval_train_msg = textwrap.dedent(f""" train/total_loss: {eval_train_res.total_loss} train/spb_loss: {eval_train_res.spb_loss} train/pvc_loss: {eval_train_res.pvc_loss} @@ -414,12 +407,10 @@ def train( train/pvc_fp: {eval_train_res.pvc_fp} train/spb_fn: {eval_train_res.spb_fn} train/pvc_fn: {eval_train_res.pvc_fn} - """ - ) + """) else: eval_train_msg = "" - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" Train epoch_{epoch + 1}: -------------------- train/epoch_loss: {epoch_loss}{eval_train_msg} @@ -433,8 +424,7 @@ def train( test/spb_fn: {eval_res.spb_fn} test/pvc_fn: {eval_res.pvc_fn} --------------------------------- - """ - ) + """) # print(msg) # in case no logger if logger: @@ -460,12 +450,10 @@ def train( print(msg) break - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" best challenge metric = {best_challenge_metric}, obtained at epoch {best_epoch} - """ - ) + """) if logger: logger.info(msg) else: @@ -615,8 +603,7 @@ def evaluate_crnn( head_labels = all_labels[:head_num, ...] head_labels_classes = [np.array(classes)[np.where(row)] for row in head_labels] for n in range(head_num): - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" ---------------------------------------------- scalar prediction: {[round(n, 3) for n in head_scalar_preds[n].tolist()]} binary prediction: {head_bin_preds[n].tolist()} @@ -624,8 +611,7 @@ def evaluate_crnn( predicted classes: {head_preds_classes[n].tolist()} label classes: {head_labels_classes[n].tolist()} ---------------------------------------------- - """ - ) + """) if logger: logger.info(msg) else: diff --git a/benchmarks/train_hybrid_cpsc2021/dataset.py b/benchmarks/train_hybrid_cpsc2021/dataset.py index 3d93f857..a9a3d9fc 100644 --- a/benchmarks/train_hybrid_cpsc2021/dataset.py +++ b/benchmarks/train_hybrid_cpsc2021/dataset.py @@ -1183,14 +1183,10 @@ def _train_test_split(self, train_ratio: float = 0.8, force_recompute: bool = Fa train_file.write_text(json.dumps(train_set, ensure_ascii=False)) test_file.write_text(json.dumps(test_set, ensure_ascii=False)) - print( - nildent( - f""" + print(nildent(f""" train set saved to \n\042{str(train_file)}\042 test set saved to \n\042{str(test_file)}\042 - """ - ) - ) + """)) else: train_set = json.loads(train_file.read_text()) test_set = json.loads(test_file.read_text()) diff --git a/benchmarks/train_mtl_cinc2022/inputs.py b/benchmarks/train_mtl_cinc2022/inputs.py index a878c417..1c482bb9 100644 --- a/benchmarks/train_mtl_cinc2022/inputs.py +++ b/benchmarks/train_mtl_cinc2022/inputs.py @@ -404,9 +404,7 @@ def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tens class SpectralInput(_SpectralInput): - __doc__ = ( - _SpectralInput.__doc__ - + """ + __doc__ = _SpectralInput.__doc__ + """ Concatenation of 3 different types of spectrograms: - Spectrogram @@ -444,7 +442,6 @@ class SpectralInput(_SpectralInput): (32, 3, 224, 2308) """ - ) __name__ = "SpectralInput" diff --git a/benchmarks/train_mtl_cinc2022/trainer.py b/benchmarks/train_mtl_cinc2022/trainer.py index 9517c3c9..722e91f2 100644 --- a/benchmarks/train_mtl_cinc2022/trainer.py +++ b/benchmarks/train_mtl_cinc2022/trainer.py @@ -353,8 +353,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: ] log_head_num = min(log_head_num, len(head_scalar_preds)) for n in range(log_head_num): - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" ---------------------------------------------- murmur scalar prediction: {[round(item, 3) for item in head_scalar_preds[n].tolist()]} murmur binary prediction: {head_bin_preds[n].tolist()} @@ -362,8 +361,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: murmur predicted classes: {head_preds_classes[n].tolist()} murmur label classes: {head_labels_classes[n].tolist()} ---------------------------------------------- - """ - ) + """) self.log_manager.log_message(msg) if "outcome" in input_tensors: head_scalar_preds = all_outputs[0].outcome_output.prob[:log_head_num] @@ -382,8 +380,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: ] log_head_num = min(log_head_num, len(head_scalar_preds)) for n in range(log_head_num): - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" ---------------------------------------------- outcome scalar prediction: {[round(item, 3) for item in head_scalar_preds[n].tolist()]} outcome binary prediction: {head_bin_preds[n].tolist()} @@ -391,8 +388,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: outcome predicted classes: {head_preds_classes[n].tolist()} outcome label classes: {head_labels_classes[n].tolist()} ---------------------------------------------- - """ - ) + """) self.log_manager.log_message(msg) eval_res = compute_challenge_metrics( diff --git a/docs/source/conf.py b/docs/source/conf.py index 89afcbb3..3fdc4eda 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -216,13 +216,11 @@ def setup(app): ] # https://sphinxcontrib-bibtex.readthedocs.io/en/latest/usage.html#latex-backend-fails-with-citations-in-figure-captions -latex_elements = { - "preamble": r""" +latex_elements = {"preamble": r""" % make phantomsection empty inside figures \usepackage{etoolbox} \AtBeginEnvironment{figure}{\renewcommand{\phantomsection}{}} - """ -} + """} man_pages = [(master_doc, project, f"{project} Documentation", [author], 1)] diff --git a/test/test_pipelines/test_crnn_cinc2021_pipeline.py b/test/test_pipelines/test_crnn_cinc2021_pipeline.py index 07171053..c2a40077 100644 --- a/test/test_pipelines/test_crnn_cinc2021_pipeline.py +++ b/test/test_pipelines/test_crnn_cinc2021_pipeline.py @@ -621,8 +621,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: head_labels = all_labels[:head_num, ...] head_labels_classes = [np.array(classes)[np.where(row)] for row in head_labels] for n in range(head_num): - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" ---------------------------------------------- scalar prediction: {[round(n, 3) for n in head_scalar_preds[n].tolist()]} binary prediction: {head_bin_preds[n].tolist()} @@ -630,8 +629,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: predicted classes: {head_preds_classes[n].tolist()} label classes: {head_labels_classes[n].tolist()} ---------------------------------------------- - """ - ) + """) self.log_manager.log_message(msg) ( diff --git a/test/test_pipelines/test_mtl_cinc2022_pipeline.py b/test/test_pipelines/test_mtl_cinc2022_pipeline.py index 21d46270..847a3dd7 100644 --- a/test/test_pipelines/test_mtl_cinc2022_pipeline.py +++ b/test/test_pipelines/test_mtl_cinc2022_pipeline.py @@ -3329,8 +3329,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: ] log_head_num = min(log_head_num, len(head_scalar_preds)) for n in range(log_head_num): - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" ---------------------------------------------- murmur scalar prediction: {[round(item, 3) for item in head_scalar_preds[n].tolist()]} murmur binary prediction: {head_bin_preds[n].tolist()} @@ -3338,8 +3337,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: murmur predicted classes: {head_preds_classes[n].tolist()} murmur label classes: {head_labels_classes[n].tolist()} ---------------------------------------------- - """ - ) + """) self.log_manager.log_message(msg) if "outcome" in input_tensors: head_scalar_preds = all_outputs[0].outcome_output.prob[:log_head_num] @@ -3358,8 +3356,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: ] log_head_num = min(log_head_num, len(head_scalar_preds)) for n in range(log_head_num): - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" ---------------------------------------------- outcome scalar prediction: {[round(item, 3) for item in head_scalar_preds[n].tolist()]} outcome binary prediction: {head_bin_preds[n].tolist()} @@ -3367,8 +3364,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: outcome predicted classes: {head_preds_classes[n].tolist()} outcome label classes: {head_labels_classes[n].tolist()} ---------------------------------------------- - """ - ) + """) self.log_manager.log_message(msg) eval_res = compute_challenge_metrics( diff --git a/test/test_utils/test_misc.py b/test/test_utils/test_misc.py index f1b9267a..8ee72548 100644 --- a/test/test_utils/test_misc.py +++ b/test/test_utils/test_misc.py @@ -328,14 +328,12 @@ def test_dicts_equal(): def test_ReprMixin(): some_class = SomeClass(1, 2, 3) - string = textwrap.dedent( - """ + string = textwrap.dedent(""" SomeClass( aaa = 1, bb = 2 ) - """ - ).strip("\n") + """).strip("\n") assert str(some_class) == repr(some_class) assert str(some_class) == string @@ -447,9 +445,7 @@ def test_remove_parameters_returns_from_docstring(): parameters=["returns_indicator", "parameters_indicator"], returns="str", ) - assert ( - new_docstring - == """Remove parameters and/or returns from docstring, + assert new_docstring == """Remove parameters and/or returns from docstring, which is of the format of `numpydoc`. Parameters @@ -470,7 +466,6 @@ def test_remove_parameters_returns_from_docstring(): or add a line of `None` to the section. """ - ) def test_timeout(): diff --git a/test/test_utils/test_utils_signal_t.py b/test/test_utils/test_utils_signal_t.py index 424698b2..cfdc6f2d 100644 --- a/test/test_utils/test_utils_signal_t.py +++ b/test/test_utils/test_utils_signal_t.py @@ -7,7 +7,7 @@ def test_normalize(): - (b, l, s) = 2, 12, 20 + b, l, s = 2, 12, 20 nm_sig = normalize(torch.randn(b, l, s), method="min-max", inplace=True) for shape in [ (b,), diff --git a/torch_ecg/components/inputs.py b/torch_ecg/components/inputs.py index fc348c03..5094ea22 100644 --- a/torch_ecg/components/inputs.py +++ b/torch_ecg/components/inputs.py @@ -539,9 +539,7 @@ def extra_repr_keys(self) -> List[str]: class SpectrogramInput(_SpectralInput): - __doc__ = ( - _SpectralInput.__doc__ - + """ + __doc__ = _SpectralInput.__doc__ + """ Examples -------- @@ -567,7 +565,6 @@ class SpectrogramInput(_SpectralInput): True """ - ) __name__ = "SpectrogramInput" def _post_init(self) -> None: diff --git a/torch_ecg/components/metrics.py b/torch_ecg/components/metrics.py index 2c7fcae1..d15e5bc2 100644 --- a/torch_ecg/components/metrics.py +++ b/torch_ecg/components/metrics.py @@ -492,8 +492,7 @@ def set_macro(self, macro: bool) -> None: if macro: self.__prefix = "macro_" - @add_docstring( - f""" + @add_docstring(f""" Compute metrics for the task of ECG wave delineation (sensitivity, precision, f1_score, mean error and standard deviation of the mean errors) for multiple evaluations. @@ -527,8 +526,7 @@ def set_macro(self, macro: bool) -> None: self : WaveDelineationMetrics The metrics object itself with the computed metrics. - """ - ) + """) def compute( self, labels: Union[np.ndarray, Tensor], diff --git a/torch_ecg/components/outputs.py b/torch_ecg/components/outputs.py index 3c286c83..f86ce4d8 100644 --- a/torch_ecg/components/outputs.py +++ b/torch_ecg/components/outputs.py @@ -352,8 +352,7 @@ def required_fields(self) -> Set[str]: ] ) - @add_docstring( - f"""Compute metrics from the output + @add_docstring(f"""Compute metrics from the output Parameters ---------- @@ -375,8 +374,7 @@ def required_fields(self) -> Set[str]: metrics : WaveDelineationMetrics Metrics computed from the output - """ - ) + """) def compute_metrics( self, fs: int, diff --git a/torch_ecg/components/trainer.py b/torch_ecg/components/trainer.py index 127d23e6..6ae8131f 100644 --- a/torch_ecg/components/trainer.py +++ b/torch_ecg/components/trainer.py @@ -164,8 +164,7 @@ def train(self) -> OrderedDict: level=logging.WARNING, ) - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" Starting training: ------------------ Epochs: {self.n_epochs} @@ -177,8 +176,7 @@ def train(self) -> OrderedDict: Optimizer: {self.train_config.optimizer} Dataset classes: {self.train_config.classes} ----------------------------------------- - """ - ) + """) self.log_manager.log_message(msg) start_epoch = self.epoch @@ -239,12 +237,10 @@ def train(self) -> OrderedDict: self.log_manager.log_message(msg) break - msg = textwrap.dedent( - f""" + msg = textwrap.dedent(f""" best metric = {self.best_metric}, obtained at epoch {self.best_epoch} - """ - ) + """) self.log_manager.log_message(msg) # save checkpoint diff --git a/torch_ecg/databases/aux_data/aha.py b/torch_ecg/databases/aux_data/aha.py index 72fef94d..d74164c3 100644 --- a/torch_ecg/databases/aux_data/aha.py +++ b/torch_ecg/databases/aux_data/aha.py @@ -28,8 +28,7 @@ df_primary_statements = pd.read_csv( - io.StringIO( - """ + io.StringIO(""" CategoryCode,Category,Code,Description A,Overall interpretation,1,Normal ECG ,,2,Otherwise normal ECG @@ -148,8 +147,7 @@ ,,188,"Failure to inhibit, ventricular" ,,189,"Failure to pace, atrial" ,,190,"Failure to pace, ventricular" -""" - ), +"""), dtype=str, ) @@ -158,8 +156,7 @@ df_secondary_statements = pd.read_csv( - io.StringIO( - """ + io.StringIO(""" Group,Code,Description Suggests,200,Acute pericarditis ,201,Acute pulmonary embolism @@ -189,8 +186,7 @@ ,229,Pulmonary disease ,230,Dextrocardia ,231,Dextroposition -""" - ), +"""), dtype=str, ) @@ -199,8 +195,7 @@ df_modifiers = pd.read_csv( - io.StringIO( - """ + io.StringIO(""" Category,Code,Description, General,301,Borderline, ,303,Increased, @@ -249,8 +244,7 @@ ,366,Low amplitude, ,367,Inversion, ,369,Postpacing (anamnestic), -""" - ), +"""), dtype=str, ) @@ -259,8 +253,7 @@ df_comparison_statements = pd.read_csv( - io.StringIO( - """ + io.StringIO(""" Code,Statement,Criteria 400,No significant change,"Intervals (PR, QRS, QTc) remain normal or within 10% of a previously abnormal value" ,,No new or deleted diagnoses with the exception of normal variant diagnoses @@ -275,8 +268,7 @@ ,,Change in QTc >60 ms 405,Change in clinical status,"New or deleted diagnosis from Axis and Voltage, Chamber Hypertrophy, or Enlargement primary statement categories or “Suggests…” secondary statement category" 406,Change in interpretation without significant change in waveform,"Used when a primary or secondary statement is added or removed despite no real change in the tracing; ie, an interpretive disagreement exists between the readers of the first and second ECGs" -""" - ), +"""), dtype=str, ) @@ -285,22 +277,19 @@ df_convenience_statements = pd.read_csv( - io.StringIO( - """ + io.StringIO(""" Code,Statement 500,Nonspecific ST-T abnormality 501,ST elevation 502,ST depression 503,LVH with ST-T changes -""" - ), +"""), dtype=str, ) df_secondary_primary_statement_pairing_rules = pd.read_csv( - io.StringIO( - """ + io.StringIO(""" Secondary Code,May Accompany These Primary Codes 200,145-147 201,"21, 105, 109, 120, 131, 141, 145-147" @@ -330,15 +319,13 @@ 229,"109, 120, 122-123, 125, 128, 131, 141, 143" 230,"128, 131" 231,128 -""" - ), +"""), dtype=str, ) df_general_modifier_primary_statement_pairing_rules = pd.read_csv( - io.StringIO( - """ + io.StringIO(""" General Modifier Code,May (May Not) Accompany These Primary Codes or May Be Between Codes in These Categories or Groups of Categories,May/May Not,Location 301,"1-20, 24-76, 81, 83-106, 108, 122-124",May not,b 302,"1-3, 12-16, 80-82, 111-130, 145-152",May not,"b, i" @@ -359,8 +346,7 @@ 318,"C, D, E, F, G, N, H, I, J, K, L, M",May,i 319,"C, D, E, F, G, N, 100, J, K, L, M",May,i 321,"40, 55, 56, 145-147",May,b -""" - ), +"""), dtype=str, ) diff --git a/torch_ecg/databases/aux_data/cinc2020_aux_data.py b/torch_ecg/databases/aux_data/cinc2020_aux_data.py index 2786873a..a058fad5 100644 --- a/torch_ecg/databases/aux_data/cinc2020_aux_data.py +++ b/torch_ecg/databases/aux_data/cinc2020_aux_data.py @@ -69,8 +69,7 @@ dx_mapping_scored = pd.read_csv( - StringIO( - """Dx,SNOMED CT Code,Abbreviation,CPSC,CPSC-Extra,StPetersburg,PTB,PTB-XL,Georgia,Total,Notes + StringIO("""Dx,SNOMED CT Code,Abbreviation,CPSC,CPSC-Extra,StPetersburg,PTB,PTB-XL,Georgia,Total,Notes 1st degree av block,270492004,IAVB,722,106,0,0,797,769,2394, atrial fibrillation,164889003,AF,1221,153,2,15,1514,570,3475, atrial flutter,164890007,AFL,0,54,0,1,73,186,314, @@ -97,16 +96,14 @@ supraventricular premature beats,63593006,SVPB,0,53,4,0,157,1,215,We score 284470004 and 63593006 as the same diagnosis. t wave abnormal,164934002,TAb,0,22,0,0,2345,2306,4673, t wave inversion,59931005,TInv,0,5,1,0,294,812,1112, -ventricular premature beats,17338001,VPB,0,8,0,0,0,357,365,We score 427172004 and 17338001 as the same diagnosis.""" - ) +ventricular premature beats,17338001,VPB,0,8,0,0,0,357,365,We score 427172004 and 17338001 as the same diagnosis.""") ) dx_mapping_scored = dx_mapping_scored.fillna("") dx_mapping_scored["SNOMED CT Code"] = dx_mapping_scored["SNOMED CT Code"].apply(str) dx_mapping_unscored = pd.read_csv( - StringIO( - """Dx,SNOMED CT Code,Abbreviation,CPSC,CPSC-Extra,StPetersburg,PTB,PTB-XL,Georgia,Total + StringIO("""Dx,SNOMED CT Code,Abbreviation,CPSC,CPSC-Extra,StPetersburg,PTB,PTB-XL,Georgia,Total 2nd degree av block,195042002,IIAVB,0,21,0,0,14,23,58 abnormal QRS,164951009,abQRS,0,0,0,0,3389,0,3389 accelerated junctional rhythm,426664006,AJR,0,0,0,0,0,19,19 @@ -190,8 +187,7 @@ ventricular tachycardia,164895002,VTach,0,1,1,10,0,0,12 ventricular trigeminy,251180001,VTrig,0,4,4,0,20,1,29 wandering atrial pacemaker,195101003,WAP,0,0,0,0,0,7,7 -wolff parkinson white pattern,74390002,WPW,0,0,4,2,80,2,88""" - ) +wolff parkinson white pattern,74390002,WPW,0,0,4,2,80,2,88""") ) dx_mapping_unscored["SNOMED CT Code"] = dx_mapping_unscored["SNOMED CT Code"].apply(str) diff --git a/torch_ecg/databases/aux_data/cinc2021_aux_data.py b/torch_ecg/databases/aux_data/cinc2021_aux_data.py index 859aa61d..6a1adc60 100644 --- a/torch_ecg/databases/aux_data/cinc2021_aux_data.py +++ b/torch_ecg/databases/aux_data/cinc2021_aux_data.py @@ -109,8 +109,7 @@ def expand_equiv_classes(df: pd.DataFrame, sep: str = "|") -> pd.DataFrame: dx_mapping_scored = pd.read_csv( - StringIO( - """Dx,SNOMEDCTCode,Abbreviation,CPSC,CPSC_Extra,StPetersburg,PTB,PTB_XL,Georgia,Chapman_Shaoxing,Ningbo,Total,Notes + StringIO("""Dx,SNOMEDCTCode,Abbreviation,CPSC,CPSC_Extra,StPetersburg,PTB,PTB_XL,Georgia,Chapman_Shaoxing,Ningbo,Total,Notes atrial fibrillation,164889003,AF,1221,153,2,15,1514,570,1780,0,5255, atrial flutter,164890007,AFL,0,54,0,1,73,186,445,7615,8374, bundle branch block,6374002,BBB,0,0,1,20,0,116,0,385,522, @@ -140,8 +139,7 @@ def expand_equiv_classes(df: pd.DataFrame, sep: str = "|") -> pd.DataFrame: supraventricular premature beats,63593006,SVPB,0,53,4,0,157,1,0,9,224,We score 284470004 and 63593006 as the same diagnosis. t wave abnormal,164934002,TAb,0,22,0,0,2345,2306,1876,5167,11716, t wave inversion,59931005,TInv,0,5,1,0,294,812,157,2720,3989, -ventricular premature beats,17338001,VPB,0,8,0,0,0,357,294,0,659,We score 427172004 and 17338001 as the same diagnosis.""" - ) +ventricular premature beats,17338001,VPB,0,8,0,0,0,357,294,0,659,We score 427172004 and 17338001 as the same diagnosis.""") ) dx_mapping_scored = dx_mapping_scored.fillna("") dx_mapping_scored["SNOMEDCTCode"] = dx_mapping_scored["SNOMEDCTCode"].apply(str) @@ -154,8 +152,7 @@ def expand_equiv_classes(df: pd.DataFrame, sep: str = "|") -> pd.DataFrame: dx_mapping_unscored = pd.read_csv( - StringIO( - """Dx,SNOMEDCTCode,Abbreviation,CPSC,CPSC_Extra,StPetersburg,PTB,PTB_XL,Georgia,Chapman_Shaoxing,Ningbo,Total + StringIO("""Dx,SNOMEDCTCode,Abbreviation,CPSC,CPSC_Extra,StPetersburg,PTB,PTB_XL,Georgia,Chapman_Shaoxing,Ningbo,Total accelerated atrial escape rhythm,233892002,AAR,0,0,0,0,0,0,0,16,16 abnormal QRS,164951009,abQRS,0,0,0,0,3389,0,0,0,3389 atrial escape beat,251187003,AED,0,0,0,0,0,0,0,17,17 @@ -258,8 +255,7 @@ def expand_equiv_classes(df: pd.DataFrame, sep: str = "|") -> pd.DataFrame: ventricular tachycardia,164895002,VTach,0,1,1,10,0,0,0,0,12 ventricular trigeminy,251180001,VTrig,0,4,4,0,20,1,8,0,37 wandering atrial pacemaker,195101003,WAP,0,0,0,0,0,7,2,0,9 -wolff parkinson white pattern,74390002,WPW,0,0,4,2,80,2,4,68,160""" - ) +wolff parkinson white pattern,74390002,WPW,0,0,4,2,80,2,4,68,160""") ) dx_mapping_unscored["SNOMEDCTCode"] = dx_mapping_unscored["SNOMEDCTCode"].apply(str) dx_mapping_unscored["CUSPHNFH"] = dx_mapping_unscored["Chapman_Shaoxing"].values + dx_mapping_unscored["Ningbo"].values diff --git a/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py b/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py index 9f3a867b..e40e1531 100644 --- a/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py +++ b/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py @@ -391,14 +391,10 @@ def _train_test_split( is_valid = True train_file.write_text(json.dumps(train_set, ensure_ascii=False)) test_file.write_text(json.dumps(test_set, ensure_ascii=False)) - print( - textwrap.dedent( - f""" + print(textwrap.dedent(f""" train set saved to \n\042{str(train_file)}\042 test set saved to \n\042{str(test_file)}\042 - """ - ) - ) + """)) else: train_set = json.loads(train_file.read_text()) test_set = json.loads(test_file.read_text()) @@ -440,16 +436,12 @@ def _check_train_test_split_validity(self, train_set: List[str], test_set: List[ test_classes = set(list_sum([self.reader.get_labels(rec, fmt="a") for rec in test_set])) test_classes.intersection_update(all_classes) is_valid = len(all_classes) == len(train_classes) == len(test_classes) - print( - textwrap.dedent( - f""" + print(textwrap.dedent(f""" all_classes: {all_classes} train_classes: {train_classes} test_classes: {test_classes} is_valid: {is_valid} - """ - ) - ) + """)) return is_valid def persistence(self) -> None: diff --git a/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py b/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py index 136c67fc..489dcef5 100644 --- a/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py +++ b/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py @@ -1334,14 +1334,10 @@ def _train_test_split(self, train_ratio: float = 0.8, force_recompute: bool = Fa train_file.write_text(json.dumps(train_set, ensure_ascii=False)) test_file.write_text(json.dumps(test_set, ensure_ascii=False)) - print( - nildent( - f""" + print(nildent(f""" train set saved to \n\042{str(train_file)}\042 test set saved to \n\042{str(test_file)}\042 - """ - ) - ) + """)) else: train_set = json.loads(train_file.read_text()) test_set = json.loads(test_file.read_text()) diff --git a/torch_ecg/utils/ecg_arrhythmia_knowledge.py b/torch_ecg/utils/ecg_arrhythmia_knowledge.py index c093c610..26e473ae 100644 --- a/torch_ecg/utils/ecg_arrhythmia_knowledge.py +++ b/torch_ecg/utils/ecg_arrhythmia_knowledge.py @@ -38,6 +38,7 @@ NOTE that wikipedia is NOT listed in the References """ + from io import StringIO import pandas as pd @@ -624,9 +625,7 @@ ) -_dx_mapping = pd.read_csv( - StringIO( - """Dx,SNOMEDCTCode,Abbreviation +_dx_mapping = pd.read_csv(StringIO("""Dx,SNOMEDCTCode,Abbreviation atrial fibrillation,164889003,AF atrial flutter,164890007,AFL bundle branch block,6374002,BBB @@ -759,9 +758,7 @@ ventricular tachycardia,164895002,VTach ventricular trigeminy,251180001,VTrig wandering atrial pacemaker,195101003,WAP -wolff parkinson white pattern,74390002,WPW""" - ) -) +wolff parkinson white pattern,74390002,WPW""")) for ea_str in __all__: