Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ repos:
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 25.1.0
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 26.1.0
hooks:
- id: black
args: [--line-length=128, --extend-exclude=.ipynb, --verbose]
Expand All @@ -20,12 +20,12 @@ repos:
additional_dependencies: [pycodestyle>=2.11.0]
args: [--max-line-length=128, '--exclude=./.*,build,dist', '--ignore=E501,W503,E203,F841,E231,W604', --count, --statistics, --show-source]
- repo: https://github.com/pycqa/isort
rev: 6.0.1
rev: 7.0.0
hooks:
- id: isort
args: [--profile=black, --line-length=128]
- repo: https://github.com/kynan/nbstripout
rev: 0.8.1
rev: 0.9.0
hooks:
- id: nbstripout
- repo: https://github.com/nbQA-dev/nbQA
Expand Down
6 changes: 2 additions & 4 deletions benchmarks/train_crnn_cinc2020/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,17 +245,15 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
head_labels = all_labels[:head_num, ...]
head_labels_classes = [np.array(classes)[np.where(row)] for row in head_labels]
for n in range(head_num):
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
----------------------------------------------
scalar prediction: {[round(n, 3) for n in head_scalar_preds[n].tolist()]}
binary prediction: {head_bin_preds[n].tolist()}
labels: {head_labels[n].astype(int).tolist()}
predicted classes: {head_preds_classes[n].tolist()}
label classes: {head_labels_classes[n].tolist()}
----------------------------------------------
"""
)
""")
self.log_manager.log_message(msg)

(
Expand Down
16 changes: 4 additions & 12 deletions benchmarks/train_crnn_cinc2021/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,10 @@ def _train_test_split(self, train_ratio: float = 0.8, force_recompute: bool = Fa
)
train_file.write_text(json.dumps(train_set, ensure_ascii=False))
test_file.write_text(json.dumps(test_set, ensure_ascii=False))
print(
textwrap.dedent(
f"""
print(textwrap.dedent(f"""
train set saved to \n\042{str(train_file)}\042
test set saved to \n\042{str(test_file)}\042
"""
)
)
"""))
else:
train_set = json.loads(train_file.read_text())
test_set = json.loads(test_file.read_text())
Expand Down Expand Up @@ -362,16 +358,12 @@ def _check_train_test_split_validity(self, train_set: List[str], test_set: List[
test_classes = set(list_sum([self.reader.get_labels(rec, fmt="a") for rec in test_set]))
test_classes.intersection_update(all_classes)
is_valid = len(all_classes) == len(train_classes) == len(test_classes)
print(
textwrap.dedent(
f"""
print(textwrap.dedent(f"""
all_classes: {all_classes}
train_classes: {train_classes}
test_classes: {test_classes}
is_valid: {is_valid}
"""
)
)
"""))
return is_valid

def persistence(self) -> None:
Expand Down
6 changes: 2 additions & 4 deletions benchmarks/train_crnn_cinc2021/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,17 +233,15 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
head_labels = all_labels[:head_num, ...]
head_labels_classes = [np.array(classes)[np.where(row)] for row in head_labels]
for n in range(head_num):
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
----------------------------------------------
scalar prediction: {[round(n, 3) for n in head_scalar_preds[n].tolist()]}
binary prediction: {head_bin_preds[n].tolist()}
labels: {head_labels[n].astype(int).tolist()}
predicted classes: {head_preds_classes[n].tolist()}
label classes: {head_labels_classes[n].tolist()}
----------------------------------------------
"""
)
""")
self.log_manager.log_message(msg)

(
Expand Down
6 changes: 2 additions & 4 deletions benchmarks/train_crnn_cinc2023/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,17 +346,15 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
]
log_head_num = min(log_head_num, len(head_scalar_preds))
for n in range(log_head_num):
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
----------------------------------------------
cpc scalar prediction: {[round(item, 3) for item in head_scalar_preds[n].tolist()]}
cpc binary prediction: {head_bin_preds[n].tolist()}
cpc labels: {head_labels[n].astype(int).tolist()}
cpc predicted classes: {head_preds_classes[n].tolist()}
cpc label classes: {head_labels_classes[n].tolist()}
----------------------------------------------
"""
)
""")
self.log_manager.log_message(msg)

eval_res = compute_challenge_metrics(
Expand Down
42 changes: 14 additions & 28 deletions benchmarks/train_hybrid_cpsc2020/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ def train(

# max_itr = n_epochs * n_train

msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
Starting training:
------------------
Epochs: {n_epochs}
Expand All @@ -174,8 +173,7 @@ def train(
Device: {device.type}
Optimizer: {config.train_optimizer}
-----------------------------------------
"""
)
""")
# print(msg) # in case no logger
if logger:
logger.info(msg)
Expand Down Expand Up @@ -351,20 +349,17 @@ def train(
scheduler.step()

if debug:
eval_train_msg = textwrap.dedent(
f"""
eval_train_msg = textwrap.dedent(f"""
train/auroc: {eval_train_res[0]}
train/auprc: {eval_train_res[1]}
train/accuracy: {eval_train_res[2]}
train/f_measure: {eval_train_res[3]}
train/f_beta_measure: {eval_train_res[4]}
train/g_beta_measure: {eval_train_res[5]}
"""
)
""")
else:
eval_train_msg = ""
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
Train epoch_{epoch + 1}:
--------------------
train/epoch_loss: {epoch_loss}{eval_train_msg}
Expand All @@ -375,8 +370,7 @@ def train(
test/f_beta_measure: {eval_res[4]}
test/g_beta_measure: {eval_res[5]}
---------------------------------
"""
)
""")
elif config.model_name == "seq_lab":
eval_res = evaluate_seq_lab(model, val_loader, config, device, debug)
model.train()
Expand All @@ -403,8 +397,7 @@ def train(
scheduler.step()

if debug:
eval_train_msg = textwrap.dedent(
f"""
eval_train_msg = textwrap.dedent(f"""
train/total_loss: {eval_train_res.total_loss}
train/spb_loss: {eval_train_res.spb_loss}
train/pvc_loss: {eval_train_res.pvc_loss}
Expand All @@ -414,12 +407,10 @@ def train(
train/pvc_fp: {eval_train_res.pvc_fp}
train/spb_fn: {eval_train_res.spb_fn}
train/pvc_fn: {eval_train_res.pvc_fn}
"""
)
""")
else:
eval_train_msg = ""
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
Train epoch_{epoch + 1}:
--------------------
train/epoch_loss: {epoch_loss}{eval_train_msg}
Expand All @@ -433,8 +424,7 @@ def train(
test/spb_fn: {eval_res.spb_fn}
test/pvc_fn: {eval_res.pvc_fn}
---------------------------------
"""
)
""")

# print(msg) # in case no logger
if logger:
Expand All @@ -460,12 +450,10 @@ def train(
print(msg)
break

msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
best challenge metric = {best_challenge_metric},
obtained at epoch {best_epoch}
"""
)
""")
if logger:
logger.info(msg)
else:
Expand Down Expand Up @@ -615,17 +603,15 @@ def evaluate_crnn(
head_labels = all_labels[:head_num, ...]
head_labels_classes = [np.array(classes)[np.where(row)] for row in head_labels]
for n in range(head_num):
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
----------------------------------------------
scalar prediction: {[round(n, 3) for n in head_scalar_preds[n].tolist()]}
binary prediction: {head_bin_preds[n].tolist()}
labels: {head_labels[n].astype(int).tolist()}
predicted classes: {head_preds_classes[n].tolist()}
label classes: {head_labels_classes[n].tolist()}
----------------------------------------------
"""
)
""")
if logger:
logger.info(msg)
else:
Expand Down
8 changes: 2 additions & 6 deletions benchmarks/train_hybrid_cpsc2021/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,14 +1183,10 @@ def _train_test_split(self, train_ratio: float = 0.8, force_recompute: bool = Fa

train_file.write_text(json.dumps(train_set, ensure_ascii=False))
test_file.write_text(json.dumps(test_set, ensure_ascii=False))
print(
nildent(
f"""
print(nildent(f"""
train set saved to \n\042{str(train_file)}\042
test set saved to \n\042{str(test_file)}\042
"""
)
)
"""))
else:
train_set = json.loads(train_file.read_text())
test_set = json.loads(test_file.read_text())
Expand Down
5 changes: 1 addition & 4 deletions benchmarks/train_mtl_cinc2022/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,9 +404,7 @@ def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tens


class SpectralInput(_SpectralInput):
__doc__ = (
_SpectralInput.__doc__
+ """
__doc__ = _SpectralInput.__doc__ + """

Concatenation of 3 different types of spectrograms:
- Spectrogram
Expand Down Expand Up @@ -444,7 +442,6 @@ class SpectralInput(_SpectralInput):
(32, 3, 224, 2308)

"""
)

__name__ = "SpectralInput"

Expand Down
12 changes: 4 additions & 8 deletions benchmarks/train_mtl_cinc2022/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,17 +353,15 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
]
log_head_num = min(log_head_num, len(head_scalar_preds))
for n in range(log_head_num):
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
----------------------------------------------
murmur scalar prediction: {[round(item, 3) for item in head_scalar_preds[n].tolist()]}
murmur binary prediction: {head_bin_preds[n].tolist()}
murmur labels: {head_labels[n].astype(int).tolist()}
murmur predicted classes: {head_preds_classes[n].tolist()}
murmur label classes: {head_labels_classes[n].tolist()}
----------------------------------------------
"""
)
""")
self.log_manager.log_message(msg)
if "outcome" in input_tensors:
head_scalar_preds = all_outputs[0].outcome_output.prob[:log_head_num]
Expand All @@ -382,17 +380,15 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
]
log_head_num = min(log_head_num, len(head_scalar_preds))
for n in range(log_head_num):
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
----------------------------------------------
outcome scalar prediction: {[round(item, 3) for item in head_scalar_preds[n].tolist()]}
outcome binary prediction: {head_bin_preds[n].tolist()}
outcome labels: {head_labels[n].astype(int).tolist()}
outcome predicted classes: {head_preds_classes[n].tolist()}
outcome label classes: {head_labels_classes[n].tolist()}
----------------------------------------------
"""
)
""")
self.log_manager.log_message(msg)

eval_res = compute_challenge_metrics(
Expand Down
6 changes: 2 additions & 4 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,11 @@ def setup(app):
]

# https://sphinxcontrib-bibtex.readthedocs.io/en/latest/usage.html#latex-backend-fails-with-citations-in-figure-captions
latex_elements = {
"preamble": r"""
latex_elements = {"preamble": r"""
% make phantomsection empty inside figures
\usepackage{etoolbox}
\AtBeginEnvironment{figure}{\renewcommand{\phantomsection}{}}
"""
}
"""}

man_pages = [(master_doc, project, f"{project} Documentation", [author], 1)]

Expand Down
6 changes: 2 additions & 4 deletions test/test_pipelines/test_crnn_cinc2021_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,17 +621,15 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
head_labels = all_labels[:head_num, ...]
head_labels_classes = [np.array(classes)[np.where(row)] for row in head_labels]
for n in range(head_num):
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
----------------------------------------------
scalar prediction: {[round(n, 3) for n in head_scalar_preds[n].tolist()]}
binary prediction: {head_bin_preds[n].tolist()}
labels: {head_labels[n].astype(int).tolist()}
predicted classes: {head_preds_classes[n].tolist()}
label classes: {head_labels_classes[n].tolist()}
----------------------------------------------
"""
)
""")
self.log_manager.log_message(msg)

(
Expand Down
12 changes: 4 additions & 8 deletions test/test_pipelines/test_mtl_cinc2022_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3329,17 +3329,15 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
]
log_head_num = min(log_head_num, len(head_scalar_preds))
for n in range(log_head_num):
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
----------------------------------------------
murmur scalar prediction: {[round(item, 3) for item in head_scalar_preds[n].tolist()]}
murmur binary prediction: {head_bin_preds[n].tolist()}
murmur labels: {head_labels[n].astype(int).tolist()}
murmur predicted classes: {head_preds_classes[n].tolist()}
murmur label classes: {head_labels_classes[n].tolist()}
----------------------------------------------
"""
)
""")
self.log_manager.log_message(msg)
if "outcome" in input_tensors:
head_scalar_preds = all_outputs[0].outcome_output.prob[:log_head_num]
Expand All @@ -3358,17 +3356,15 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
]
log_head_num = min(log_head_num, len(head_scalar_preds))
for n in range(log_head_num):
msg = textwrap.dedent(
f"""
msg = textwrap.dedent(f"""
----------------------------------------------
outcome scalar prediction: {[round(item, 3) for item in head_scalar_preds[n].tolist()]}
outcome binary prediction: {head_bin_preds[n].tolist()}
outcome labels: {head_labels[n].astype(int).tolist()}
outcome predicted classes: {head_preds_classes[n].tolist()}
outcome label classes: {head_labels_classes[n].tolist()}
----------------------------------------------
"""
)
""")
self.log_manager.log_message(msg)

eval_res = compute_challenge_metrics(
Expand Down
Loading
Loading