diff --git a/.github/workflows/run-pytest.yml b/.github/workflows/run-pytest.yml
index 5820ddbc..9f3fdd88 100644
--- a/.github/workflows/run-pytest.yml
+++ b/.github/workflows/run-pytest.yml
@@ -14,6 +14,7 @@ on:
env:
PYTHON_PRIMARY_VERSION: '3.10'
+ SHHS_DATA_AVAILABLE: 'false'
jobs:
build:
@@ -66,36 +67,46 @@ jobs:
- name: List installed Python packages
run: |
python -m pip list
- - name: Install nsrr and download a samll part of SHHS to do test
+ - name: Install nsrr and download a small part of SHHS for testing
# ref. https://github.com/DeepPSP/nsrr-automate
- uses: gacts/run-and-post-run@v1
- with:
- # if ~/tmp/nsrr-data/shhs is empty (no files downloaded),
- # fail and terminate the workflow
- run: |
- gem install nsrr --no-document
- nsrr download shhs/polysomnography/edfs/shhs1/ --file="^shhs1\-20010.*\.edf" --token=${{ secrets.NSRR_TOKEN }}
- nsrr download shhs/polysomnography/annotations-events-nsrr/shhs1/ --file="^shhs1\-20010.*\-nsrr\.xml" --token=${{ secrets.NSRR_TOKEN }}
- nsrr download shhs/polysomnography/annotations-events-profusion/shhs1/ --file="^shhs1\-20010.*\-profusion\.xml" --token=${{ secrets.NSRR_TOKEN }}
- nsrr download shhs/polysomnography/annotations-rpoints/shhs1/ --file="^shhs1\-20010.*\-rpoint\.csv" --token=${{ secrets.NSRR_TOKEN }}
- nsrr download shhs/datasets/ --shallow --token=${{ secrets.NSRR_TOKEN }}
- nsrr download shhs/datasets/hrv-analysis/ --token=${{ secrets.NSRR_TOKEN }}
- mkdir -p ~/tmp/nsrr-data/
- mv shhs/ ~/tmp/nsrr-data/
- du -sh ~/tmp/nsrr-data/*
- if [ "$(find ~/tmp/nsrr-data/shhs -type f | wc -l)" -eq 0 ]; \
- then (echo "No files downloaded. Exiting..." && exit 1); \
- else echo "Found $(find ~/tmp/nsrr-data/shhs -type f | wc -l) files"; fi
- post: |
- cd ~/tmp/ && du -sh $(ls -A)
- rm -rf ~/tmp/nsrr-data/
- cd ~/tmp/ && du -sh $(ls -A)
+ # uses: gacts/run-and-post-run@v1
+ continue-on-error: true
+ run: |
+ set -u -o pipefail
+ gem install nsrr --no-document
+ nsrr download shhs/polysomnography/edfs/shhs1/ --file="^shhs1\-20010.*\.edf" --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true
+ nsrr download shhs/polysomnography/annotations-events-nsrr/shhs1/ --file="^shhs1\-20010.*\-nsrr\.xml" --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true
+ nsrr download shhs/polysomnography/annotations-events-profusion/shhs1/ --file="^shhs1\-20010.*\-profusion\.xml" --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true
+ nsrr download shhs/polysomnography/annotations-rpoints/shhs1/ --file="^shhs1\-20010.*\-rpoint\.csv" --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true
+ nsrr download shhs/datasets/ --shallow --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true
+ nsrr download shhs/datasets/hrv-analysis/ --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true
+
+ mkdir -p ~/tmp/nsrr-data/
+ mv shhs/ ~/tmp/nsrr-data/
+ du -sh ~/tmp/nsrr-data/* || true
+
+ EDF_COUNT=$(find ~/tmp/nsrr-data/shhs -type f -name "*.edf" 2>/dev/null | wc -l | tr -d ' ')
+ echo "Detected SHHS EDF file count: $EDF_COUNT"
+
+ if [ "$EDF_COUNT" -eq 0 ]; then
+ echo "::error title=No SHHS EDF files downloaded::No .edf files were downloaded (token may be invalid or pattern mismatch). SHHS tests will be skipped."
+ echo "No SHHS EDF files downloaded; setting SHHS_DATA_AVAILABLE=false"
+ echo "SHHS_DATA_AVAILABLE=false" >> $GITHUB_ENV
+ exit 1
+ else
+ echo "Found $EDF_COUNT SHHS EDF files; setting SHHS_DATA_AVAILABLE=true"
+ echo "SHHS_DATA_AVAILABLE=true" >> $GITHUB_ENV
+ fi
+ # post: |
+ # cd ~/tmp/ && du -sh $(ls -A)
+ # rm -rf ~/tmp/nsrr-data/
+ # cd ~/tmp/ && du -sh $(ls -A)
- name: Run test with pytest and collect coverage
run: |
- pytest -vv -s \
- --cov=torch_ecg \
- --ignore=test/test_pipelines \
- test
+ echo "SHHS_DATA_AVAILABLE at test step: $SHHS_DATA_AVAILABLE"
+ pytest --cov --junitxml=junit.xml -o junit_family=legacy
+ env:
+ SHHS_DATA_AVAILABLE: ${{ env.SHHS_DATA_AVAILABLE }}
- name: Upload coverage to Codecov
if: matrix.python-version == ${{ env.PYTHON_PRIMARY_VERSION }}
uses: codecov/codecov-action@v4
@@ -103,3 +114,14 @@ jobs:
fail_ci_if_error: true # optional (default = false)
verbose: true # optional (default = false)
token: ${{ secrets.CODECOV_TOKEN }} # required
+ - name: Upload test results to Codecov
+ if: ${{ !cancelled() }}
+ uses: codecov/test-results-action@v1
+ with:
+ token: ${{ secrets.CODECOV_TOKEN }}
+ - name: Cleanup SHHS temp data
+ if: always()
+ run: |
+ cd ~/tmp/ && du -sh $(ls -A)
+ rm -rf ~/tmp/nsrr-data/
+ cd ~/tmp/ && du -sh $(ls -A)
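For context on the gating above: the workflow sets a default `SHHS_DATA_AVAILABLE: 'false'` at the top level, and the download step overwrites it through `$GITHUB_ENV` once EDF files are actually found. The consumer side is a module-level `pytest.mark.skipif` (the real guard is the one added to `test/test_databases/test_shhs.py` further down in this diff); a minimal sketch, with a placeholder test name for illustration:

```python
import os

import pytest

# Skip the whole module unless the CI download step exported
# SHHS_DATA_AVAILABLE=true via $GITHUB_ENV.
pytestmark = pytest.mark.skipif(
    os.getenv("SHHS_DATA_AVAILABLE") != "true",
    reason="SHHS dataset not available (token invalid or download skipped)",
)


def test_reader_smoke():  # placeholder test name
    assert os.getenv("SHHS_DATA_AVAILABLE") == "true"
```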
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 0046e588..a4d9a63a 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -19,8 +19,13 @@ Changed
- Make the function `remove_spikes_naive` in `torch_ecg.utils.utils_signal`
support 2D and 3D input signals.
+- Use `save_file` and `load_file` from the `safetensors` package, in place of
+ `torch.save` and `torch.load`, for saving and loading model files in the
+ `CkptMixin` class in `torch_ecg.utils.utils_nn`.
- Add retry mechanism to the `http_get` function in
`torch_ecg.utils.download` module.
+- Add length verification to the `http_get` function in
+ `torch_ecg.utils.download` module.
Deprecated
~~~~~~~~~~
@@ -42,6 +47,14 @@ Fixed
- Fix potential errors when deepcopying a `torch_ecg.cfg.CFG` object:
previously, deepcopying such an object like `CFG({"a": {1: 0.1, 2: 0.2}})`
would result in an error.
+- Fix potential bugs in the contextmanager `torch_ecg.utils.timeout`: restore the
+ previously installed SIGALRM handler after use, cancel any pending alarm in a
+ `finally` block, and install no handler when the duration is non-positive. The old
+ implementation never reinstated the original handler, so later unrelated `signal.alarm`
+ calls could trigger spurious `TimeoutError` exceptions; this global side effect is removed.
+- Fix bugs in the utility function `torch_ecg.utils.make_serializable`: the previous
+ implementation did not correctly drop some types of unserializable items. Two new
+ parameters, `drop_unserializable` and `drop_paths`, are added.
Security
~~~~~~~~
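The `timeout` fix described above is easier to see as code. A minimal sketch of the described behaviour, assuming a SIGALRM-based implementation; it illustrates the pattern and is not the actual `torch_ecg.utils.timeout` source:

```python
import signal
from contextlib import contextmanager


@contextmanager
def timeout(duration: float):
    # Non-positive durations install no handler at all,
    # so the call has no global side effects.
    if duration <= 0:
        yield
        return

    def _handler(signum, frame):
        raise TimeoutError(f"block timed out after {duration} seconds")

    previous = signal.signal(signal.SIGALRM, _handler)
    signal.alarm(int(duration))
    try:
        yield
    finally:
        # Cancel any pending alarm and reinstate the previous handler, so
        # unrelated signal.alarm calls made later cannot hit _handler and
        # raise a spurious TimeoutError.
        signal.alarm(0)
        signal.signal(signal.SIGALRM, previous)
```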
diff --git a/benchmarks/train_crnn_cinc2020/model.py b/benchmarks/train_crnn_cinc2020/model.py
index 7df6941e..a94e8d03 100644
--- a/benchmarks/train_crnn_cinc2020/model.py
+++ b/benchmarks/train_crnn_cinc2020/model.py
@@ -96,8 +96,8 @@ def inference(
_input = _input.unsqueeze(0) # add a batch dimension
prob = self.sigmoid(self.forward(_input))
pred = (prob >= bin_pred_thr).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
for row_idx, row in enumerate(pred):
row_max_prob = prob[row_idx, ...].max()
if row_max_prob < ModelCfg.bin_pred_nsr_thr and nsr_cid is not None:
diff --git a/benchmarks/train_crnn_cinc2021/model.py b/benchmarks/train_crnn_cinc2021/model.py
index c828805c..f616548c 100644
--- a/benchmarks/train_crnn_cinc2021/model.py
+++ b/benchmarks/train_crnn_cinc2021/model.py
@@ -99,8 +99,8 @@ def inference(
# batch_size, channels, seq_len = _input.shape
prob = self.sigmoid(self.forward(_input))
pred = (prob >= bin_pred_thr).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
for row_idx, row in enumerate(pred):
row_max_prob = prob[row_idx, ...].max()
if row_max_prob < ModelCfg.bin_pred_nsr_thr and nsr_cid is not None:
diff --git a/benchmarks/train_hybrid_cpsc2020/model.py b/benchmarks/train_hybrid_cpsc2020/model.py
index f474d2db..5e22b94a 100644
--- a/benchmarks/train_hybrid_cpsc2020/model.py
+++ b/benchmarks/train_hybrid_cpsc2020/model.py
@@ -92,8 +92,8 @@ def inference(
_input = _input.unsqueeze(0) # add a batch dimension
prob = self.sigmoid(self.forward(_input))
pred = (prob >= bin_pred_thr).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
for row_idx, row in enumerate(pred):
row_max_prob = prob[row_idx, ...].max()
if row.sum() == 0:
@@ -190,14 +190,14 @@ def inference(
if self.n_classes == 2:
prob = self.sigmoid(prob) # (batch_size, seq_len, 2)
pred = (prob >= bin_pred_thr).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
# aux used to filter out potential simultaneous predictions of SPB and PVC
aux = (prob == np.max(prob, axis=2, keepdims=True)).astype(int)
pred = aux * pred
elif self.n_classes == 3:
prob = self.softmax(prob) # (batch_size, seq_len, 3)
- prob = prob.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
pred = np.argmax(prob, axis=2)
if rpeak_inds is not None:
diff --git a/benchmarks/train_hybrid_cpsc2021/entry_2021.py b/benchmarks/train_hybrid_cpsc2021/entry_2021.py
index 2854852c..2b70feab 100644
--- a/benchmarks/train_hybrid_cpsc2021/entry_2021.py
+++ b/benchmarks/train_hybrid_cpsc2021/entry_2021.py
@@ -379,12 +379,12 @@ def _detect_rpeaks(model, sig, siglen, overlap_len, config):
for idx in range(batch_size // _BATCH_SIZE):
pred = model.forward(sig[_BATCH_SIZE * idx : _BATCH_SIZE * (idx + 1), ...])
pred = model.sigmoid(pred)
- pred = pred.cpu().detach().numpy().squeeze(-1)
+ pred = pred.detach().cpu().numpy().squeeze(-1)
l_pred.append(pred)
if batch_size % _BATCH_SIZE != 0:
pred = model.forward(sig[batch_size // _BATCH_SIZE * _BATCH_SIZE :, ...])
pred = model.sigmoid(pred)
- pred = pred.cpu().detach().numpy().squeeze(-1)
+ pred = pred.detach().cpu().numpy().squeeze(-1)
l_pred.append(pred)
pred = np.concatenate(l_pred)
@@ -473,12 +473,12 @@ def _main_task(model, sig, siglen, overlap_len, rpeaks, config):
for idx in range(batch_size // _BATCH_SIZE):
pred = model.forward(sig[_BATCH_SIZE * idx : _BATCH_SIZE * (idx + 1), ...])
pred = model.sigmoid(pred)
- pred = pred.cpu().detach().numpy().squeeze(-1)
+ pred = pred.detach().cpu().numpy().squeeze(-1)
l_pred.append(pred)
if batch_size % _BATCH_SIZE != 0:
pred = model.forward(sig[batch_size // _BATCH_SIZE * _BATCH_SIZE :, ...])
pred = model.sigmoid(pred)
- pred = pred.cpu().detach().numpy().squeeze(-1)
+ pred = pred.detach().cpu().numpy().squeeze(-1)
l_pred.append(pred)
pred = np.concatenate(l_pred)
diff --git a/benchmarks/train_hybrid_cpsc2021/model.py b/benchmarks/train_hybrid_cpsc2021/model.py
index 2f50730d..9d02c102 100644
--- a/benchmarks/train_hybrid_cpsc2021/model.py
+++ b/benchmarks/train_hybrid_cpsc2021/model.py
@@ -169,7 +169,7 @@ def _inference_qrs_detection(
_input = _input.unsqueeze(0) # add a batch dimension
# batch_size, channels, seq_len = _input.shape
prob = self.sigmoid(self.forward(_input))
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
# prob --> qrs mask --> qrs intervals --> rpeaks
rpeaks = _qrs_detection_post_process(
@@ -226,7 +226,7 @@ def _inference_main_task(
_input = _input.unsqueeze(0) # add a batch dimension
batch_size, n_leads, seq_len = _input.shape
prob = self.sigmoid(self.forward(_input))
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
af_episodes, af_mask = _main_task_post_process(
prob=prob,
@@ -368,7 +368,7 @@ def _inference_qrs_detection(
_input = _input.unsqueeze(0) # add a batch dimension
# batch_size, channels, seq_len = _input.shape
prob = self.sigmoid(self.forward(_input))
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
# prob --> qrs mask --> qrs intervals --> rpeaks
rpeaks = _qrs_detection_post_process(
@@ -425,7 +425,7 @@ def _inference_main_task(
_input = _input.unsqueeze(0) # add a batch dimension
batch_size, n_leads, seq_len = _input.shape
prob = self.sigmoid(self.forward(_input))
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
af_episodes, af_mask = _main_task_post_process(
prob=prob,
@@ -567,7 +567,7 @@ def _inference_qrs_detection(
_input = _input.unsqueeze(0) # add a batch dimension
# batch_size, channels, seq_len = _input.shape
prob = self.sigmoid(self.forward(_input))
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
# prob --> qrs mask --> qrs intervals --> rpeaks
rpeaks = _qrs_detection_post_process(
@@ -624,7 +624,7 @@ def _inference_main_task(
_input = _input.unsqueeze(0) # add a batch dimension
batch_size, n_leads, seq_len = _input.shape
prob = self.sigmoid(self.forward(_input))
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
af_episodes, af_mask = _main_task_post_process(
prob=prob,
@@ -721,7 +721,7 @@ def inference(
prob = self.forward(_input)
if self.config.clf.name != "crf":
prob = self.sigmoid(prob)
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
af_episodes, af_mask = _main_task_post_process(
prob=prob,
diff --git a/benchmarks/train_mtl_cinc2022/models/crnn.py b/benchmarks/train_mtl_cinc2022/models/crnn.py
index 557e9380..b08d2e60 100644
--- a/benchmarks/train_mtl_cinc2022/models/crnn.py
+++ b/benchmarks/train_mtl_cinc2022/models/crnn.py
@@ -200,31 +200,31 @@ def inference(
prob = self.softmax(forward_output["murmur"])
pred = torch.argmax(prob, dim=-1)
bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int)
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
- bin_pred = bin_pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
+ bin_pred = bin_pred.detach().cpu().numpy()
murmur_output = ClassificationOutput(
classes=self.classes,
prob=prob,
pred=pred,
bin_pred=bin_pred,
- forward_output=forward_output["murmur"].cpu().detach().numpy(),
+ forward_output=forward_output["murmur"].detach().cpu().numpy(),
)
if forward_output.get("outcome", None) is not None:
prob = self.softmax(forward_output["outcome"])
pred = torch.argmax(prob, dim=-1)
bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int)
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
- bin_pred = bin_pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
+ bin_pred = bin_pred.detach().cpu().numpy()
outcome_output = ClassificationOutput(
classes=self.outcomes,
prob=prob,
pred=pred,
bin_pred=bin_pred,
- forward_output=forward_output["outcome"].cpu().detach().numpy(),
+ forward_output=forward_output["outcome"].detach().cpu().numpy(),
)
else:
outcome_output = None
@@ -238,13 +238,13 @@ def inference(
else:
prob = self.sigmoid(forward_output["segmentation"])
pred = (prob > seg_thr).int() * (prob == prob.max(dim=-1, keepdim=True).values).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
segmentation_output = SequenceLabellingOutput(
classes=self.states,
prob=prob,
pred=pred,
- forward_output=forward_output["segmentation"].cpu().detach().numpy(),
+ forward_output=forward_output["segmentation"].detach().cpu().numpy(),
)
else:
segmentation_output = None
diff --git a/benchmarks/train_mtl_cinc2022/models/model_ml.py b/benchmarks/train_mtl_cinc2022/models/model_ml.py
index b55305cb..cad80801 100644
--- a/benchmarks/train_mtl_cinc2022/models/model_ml.py
+++ b/benchmarks/train_mtl_cinc2022/models/model_ml.py
@@ -170,9 +170,10 @@ def get_model(self, model_name: str, params: Optional[dict] = None) -> BaseEstim
"""
model_cls = self.model_map[model_name]
+ params = params or {}
if model_cls in [GradientBoostingClassifier, SVC]:
params.pop("n_jobs", None)
- return model_cls(**(params or {}))
+ return model_cls(**params)
def save_model(
self,
@@ -198,9 +199,13 @@ def save_model(
path to save the model.
"""
+ if isinstance(model_path, bytes):
+ model_path = model_path.decode()
+ model_path = Path(model_path).expanduser().resolve()
+ model_path.parent.mkdir(parents=True, exist_ok=True)
_config = deepcopy(config)
_config.pop("db_dir", None)
- Path(model_path).write_bytes(
+ model_path.write_bytes(
pickle.dumps(
{
"config": _config,
diff --git a/benchmarks/train_mtl_cinc2022/models/seg.py b/benchmarks/train_mtl_cinc2022/models/seg.py
index d67eb684..d93a181b 100644
--- a/benchmarks/train_mtl_cinc2022/models/seg.py
+++ b/benchmarks/train_mtl_cinc2022/models/seg.py
@@ -138,8 +138,8 @@ def inference(
else:
prob = self.sigmoid(self.forward(_input)["segmentation"])
pred = (prob > bin_pred_threshold).int() * (prob == prob.max(dim=-1, keepdim=True).values).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
segmentation_output = SequenceLabellingOutput(
classes=self.classes,
@@ -264,8 +264,8 @@ def inference(
else:
prob = self.sigmoid(self.forward(_input)["segmentation"])
pred = (prob > bin_pred_threshold).int() * (prob == prob.max(dim=-1, keepdim=True).values).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
segmentation_output = SequenceLabellingOutput(
classes=self.classes,
diff --git a/benchmarks/train_mtl_cinc2022/models/wav2vec2.py b/benchmarks/train_mtl_cinc2022/models/wav2vec2.py
index 28e5c346..3886064b 100644
--- a/benchmarks/train_mtl_cinc2022/models/wav2vec2.py
+++ b/benchmarks/train_mtl_cinc2022/models/wav2vec2.py
@@ -279,31 +279,31 @@ def inference(
prob = self.softmax(forward_output["murmur"])
pred = torch.argmax(prob, dim=-1)
bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int)
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
- bin_pred = bin_pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
+ bin_pred = bin_pred.detach().cpu().numpy()
murmur_output = ClassificationOutput(
classes=self.classes,
prob=prob,
pred=pred,
bin_pred=bin_pred,
- forward_output=forward_output["murmur"].cpu().detach().numpy(),
+ forward_output=forward_output["murmur"].detach().cpu().numpy(),
)
if forward_output.get("outcome", None) is not None:
prob = self.softmax(forward_output["outcome"])
pred = torch.argmax(prob, dim=-1)
bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int)
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
- bin_pred = bin_pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
+ bin_pred = bin_pred.detach().cpu().numpy()
outcome_output = ClassificationOutput(
classes=self.outcomes,
prob=prob,
pred=pred,
bin_pred=bin_pred,
- forward_output=forward_output["outcome"].cpu().detach().numpy(),
+ forward_output=forward_output["outcome"].detach().cpu().numpy(),
)
else:
outcome_output = None
@@ -317,13 +317,13 @@ def inference(
else:
prob = self.sigmoid(forward_output["segmentation"])
pred = (prob > seg_thr).int() * (prob == prob.max(dim=-1, keepdim=True).values).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
segmentation_output = SequenceLabellingOutput(
classes=self.states,
prob=prob,
pred=pred,
- forward_output=forward_output["segmentation"].cpu().detach().numpy(),
+ forward_output=forward_output["segmentation"].detach().cpu().numpy(),
)
else:
segmentation_output = None
@@ -599,31 +599,31 @@ def inference(
prob = self.softmax(forward_output["murmur"])
pred = torch.argmax(prob, dim=-1)
bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int)
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
- bin_pred = bin_pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
+ bin_pred = bin_pred.detach().cpu().numpy()
murmur_output = ClassificationOutput(
classes=self.classes,
prob=prob,
pred=pred,
bin_pred=bin_pred,
- forward_output=forward_output["murmur"].cpu().detach().numpy(),
+ forward_output=forward_output["murmur"].detach().cpu().numpy(),
)
if forward_output.get("outcome", None) is not None:
prob = self.softmax(forward_output["outcome"])
pred = torch.argmax(prob, dim=-1)
bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int)
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
- bin_pred = bin_pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
+ bin_pred = bin_pred.detach().cpu().numpy()
outcome_output = ClassificationOutput(
classes=self.outcomes,
prob=prob,
pred=pred,
bin_pred=bin_pred,
- forward_output=forward_output["outcome"].cpu().detach().numpy(),
+ forward_output=forward_output["outcome"].detach().cpu().numpy(),
)
else:
outcome_output = None
@@ -637,13 +637,13 @@ def inference(
else:
prob = self.sigmoid(forward_output["segmentation"])
pred = (prob > seg_thr).int() * (prob == prob.max(dim=-1, keepdim=True).values).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
segmentation_output = SequenceLabellingOutput(
classes=self.states,
prob=prob,
pred=pred,
- forward_output=forward_output["segmentation"].cpu().detach().numpy(),
+ forward_output=forward_output["segmentation"].detach().cpu().numpy(),
)
else:
segmentation_output = None
diff --git a/benchmarks/train_mtl_cinc2022/team_code.py b/benchmarks/train_mtl_cinc2022/team_code.py
index dd40d991..bb47c482 100644
--- a/benchmarks/train_mtl_cinc2022/team_code.py
+++ b/benchmarks/train_mtl_cinc2022/team_code.py
@@ -424,8 +424,8 @@ def run_challenge_model(
# forward_output = main_model.clf(features) # shape (1, n_classes)
# probabilities = main_model.softmax(forward_output)
# labels = (probabilities == probabilities.max(dim=-1, keepdim=True).values).to(int)
- # probabilities = probabilities.squeeze(dim=0).cpu().detach().numpy()
- # labels = labels.squeeze(dim=0).cpu().detach().numpy()
+ # probabilities = probabilities.squeeze(dim=0).detach().cpu().numpy()
+ # labels = labels.squeeze(dim=0).detach().cpu().numpy()
# get final prediction for murmurs:
# strategy:
diff --git a/benchmarks/train_multi_cpsc2019/model.py b/benchmarks/train_multi_cpsc2019/model.py
index 93586dd7..10396142 100644
--- a/benchmarks/train_multi_cpsc2019/model.py
+++ b/benchmarks/train_multi_cpsc2019/model.py
@@ -112,7 +112,7 @@ def inference(
mode="linear",
align_corners=True,
).permute(0, 2, 1)
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
# prob --> qrs mask --> qrs intervals --> rpeaks
rpeaks = _inference_post_process(
@@ -132,7 +132,7 @@ def inference(
sampling_rate=self.config.fs,
tol=0.05,
)[0]
- for b_input, b_rpeaks in zip(_input.detach().numpy().squeeze(1), rpeaks)
+ for b_input, b_rpeaks in zip(_input.detach().cpu().numpy().squeeze(1), rpeaks)
]
return RPeaksDetectionOutput(
@@ -224,7 +224,7 @@ def inference(
_input = _input.unsqueeze(0) # add a batch dimension
batch_size, channels, seq_len = _input.shape
prob = self.sigmoid(self.forward(_input))
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
# prob --> qrs mask --> qrs intervals --> rpeaks
rpeaks = _inference_post_process(
@@ -244,7 +244,7 @@ def inference(
sampling_rate=self.config.fs,
tol=0.05,
)[0]
- for b_input, b_rpeaks in zip(_input.detach().numpy().squeeze(1), rpeaks)
+ for b_input, b_rpeaks in zip(_input.detach().cpu().numpy().squeeze(1), rpeaks)
]
return RPeaksDetectionOutput(
@@ -336,7 +336,7 @@ def inference(
_input = _input.unsqueeze(0) # add a batch dimension
batch_size, channels, seq_len = _input.shape
prob = self.sigmoid(self.forward(_input))
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
# prob --> qrs mask --> qrs intervals --> rpeaks
rpeaks = _inference_post_process(
@@ -356,7 +356,7 @@ def inference(
sampling_rate=self.config.fs,
tol=0.05,
)[0]
- for b_input, b_rpeaks in zip(_input.detach().numpy().squeeze(1), rpeaks)
+ for b_input, b_rpeaks in zip(_input.detach().cpu().numpy().squeeze(1), rpeaks)
]
return RPeaksDetectionOutput(
diff --git a/benchmarks/train_unet_ludb/model.py b/benchmarks/train_unet_ludb/model.py
index 0efb37ac..5e46cfcc 100644
--- a/benchmarks/train_unet_ludb/model.py
+++ b/benchmarks/train_unet_ludb/model.py
@@ -89,7 +89,7 @@ def inference(
prob = self.softmax(prob)
else:
prob = torch.sigmoid(prob)
- prob = prob.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
if "i" in self.classes:
mask = np.argmax(prob, axis=-1)
diff --git a/docs/source/_static/css/codeblock.css b/docs/source/_static/css/codeblock.css
new file mode 100644
index 00000000..dd8fa012
--- /dev/null
+++ b/docs/source/_static/css/codeblock.css
@@ -0,0 +1,96 @@
+div.highlight {
+ position: relative;
+ display: flex;
+ border: 1px solid #eee;
+ border-radius: 8px;
+ overflow: hidden;
+ background: #fff;
+}
+
+div.highlight .custom-linenos {
+ flex-shrink: 0;
+ background: #fafafa;
+ border-right: 1px solid #eee;
+ color: #999;
+ text-align: right;
+ padding: 12px 8px;
+ font-family: monospace;
+ user-select: none;
+ line-height: 1.5;
+ white-space: pre;
+ min-width: 36px;
+}
+
+html[data-theme="dark"] div.highlight .custom-linenos {
+ background: #2a2a2a;
+ border-right-color: #444;
+ color: #aaa;
+}
+
+div.highlight pre {
+ flex: 1;
+ margin: 0;
+ padding: 12px 16px;
+ overflow-x: auto;
+ white-space: pre;
+ background: transparent;
+ line-height: 1.5;
+}
+
+div.highlight pre.code-wrapped {
+ white-space: pre-wrap;
+ word-break: break-word;
+ overflow-x: hidden;
+}
+
+.code-toolbar {
+ position: absolute;
+ top: 8px;
+ right: 8px;
+ display: flex;
+ gap: 6px;
+ z-index: 50;
+}
+
+.code-icon-btn {
+ display: inline-flex;
+ align-items: center;
+ justify-content: center;
+ background-color: rgba(255, 255, 255, 0.9);
+ border: 1px solid #ddd;
+ border-radius: 4px;
+ padding: 4px;
+ cursor: pointer;
+ color: #666;
+ transition: all 0.1s ease;
+ width: 26px;
+ height: 26px;
+}
+.code-icon-btn:hover {
+ background-color: #f5f5f5;
+ border-color: #ccc;
+ color: #333;
+}
+.code-icon-btn.active {
+ background-color: #e8f4e8;
+ border-color: #8cc08c;
+ color: #2d7d2d;
+}
+
+.copy-tooltip {
+ position: fixed;
+ transform: translateX(-50%);
+ background: rgba(0, 0, 0, 0.9);
+ color: #fff;
+ font-size: 12px;
+ padding: 4px 8px;
+ border-radius: 4px;
+ opacity: 0;
+ transition: opacity 0.15s ease, transform 0.15s ease;
+ pointer-events: none;
+ z-index: 9999;
+}
+.copy-tooltip.visible {
+ opacity: 1;
+ transform: translateX(-50%) translateY(-3px);
+}
diff --git a/docs/source/_static/js/codeblock.js b/docs/source/_static/js/codeblock.js
new file mode 100644
index 00000000..1f6649ff
--- /dev/null
+++ b/docs/source/_static/js/codeblock.js
@@ -0,0 +1,125 @@
+document.addEventListener("DOMContentLoaded", () => {
+ document.querySelectorAll("div.highlight").forEach((highlightBlock) => {
+ const codePre = highlightBlock.querySelector("pre");
+ if (!codePre) return;
+
+ const hasSphinxLinenos = codePre.querySelector(".linenos") !== null;
+
+ if (hasSphinxLinenos) {
+ codePre.querySelectorAll("span.linenos").forEach((el) => el.remove());
+ codePre.dataset.linenos = "true";
+ }
+
+ const wrapBtn = document.createElement("button");
+ wrapBtn.className = "code-icon-btn code-wrap-btn";
+ wrapBtn.innerHTML = ``;
+ wrapBtn.title = "Toggle line wrap";
+
+ wrapBtn.addEventListener("click", () => {
+ const isWrapped = codePre.classList.toggle("code-wrapped");
+ wrapBtn.classList.toggle("active", isWrapped);
+
+ const lineDiv = highlightBlock.querySelector(".custom-linenos");
+ if (lineDiv) {
+ lineDiv.style.display = isWrapped ? "none" : "block";
+ }
+ });
+
+ const copyBtn = document.createElement("button");
+ copyBtn.className = "code-icon-btn code-copy-btn";
+ copyBtn.innerHTML = ``;
+ copyBtn.title = "Copy to clipboard";
+
+ copyBtn.addEventListener("click", async () => {
+ let codeText = codePre.textContent.trim();
+ try {
+ if (navigator.clipboard?.writeText) {
+ await navigator.clipboard.writeText(codeText);
+ } else {
+ fallbackCopyText(codeText);
+ }
+ showCopyTooltip(copyBtn, "Copied!");
+ } catch (err) {
+ console.error("Copy failed:", err);
+ showCopyTooltip(copyBtn, "Failed");
+ }
+ });
+
+ const toolbar = document.createElement("div");
+ toolbar.className = "code-toolbar";
+ toolbar.appendChild(wrapBtn);
+ toolbar.appendChild(copyBtn);
+ highlightBlock.style.position = "relative";
+ highlightBlock.appendChild(toolbar);
+
+ if (codePre.dataset.linenos === "true") {
+ const contentForCount = codePre.textContent.trimEnd();
+ const totalLines = contentForCount.split("\n").length;
+
+ const lineNumbers = Array.from({ length: totalLines }, (_, i) => i + 1).join(
+ "\n"
+ );
+ const lineDiv = document.createElement("div");
+ lineDiv.className = "custom-linenos";
+ lineDiv.textContent = lineNumbers;
+ highlightBlock.insertBefore(lineDiv, codePre);
+ }
+ });
+});
+
+function fallbackCopyText(text) {
+ const ta = document.createElement("textarea");
+ ta.value = text;
+ ta.style.position = "fixed";
+ ta.style.top = "-10000px";
+ document.body.appendChild(ta);
+ ta.focus();
+ ta.select();
+ document.execCommand("copy");
+ document.body.removeChild(ta);
+}
+
+function showCopyTooltip(btn, text) {
+ const tooltip = document.createElement("div");
+ tooltip.className = "copy-tooltip";
+ tooltip.textContent = text;
+ document.body.appendChild(tooltip);
+ const rect = btn.getBoundingClientRect();
+ let top = rect.top - 28;
+ if (top < 6) top = rect.bottom + 8;
+ tooltip.style.left = `${rect.left + rect.width / 2}px`;
+ tooltip.style.top = `${top}px`;
+ requestAnimationFrame(() => tooltip.classList.add("visible"));
+ setTimeout(() => {
+ tooltip.classList.remove("visible");
+ setTimeout(() => tooltip.remove(), 200);
+ }, 1200);
+}
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 89afcbb3..41bb8064 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -203,6 +203,8 @@ def setup(app):
# )
# app.add_transform(AutoStructify)
app.add_css_file("css/custom.css")
+ app.add_css_file("css/codeblock.css")
+ app.add_js_file("js/codeblock.js")
latex_documents = [
diff --git a/pyproject.toml b/pyproject.toml
index 172f9add..d4f4ad32 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
"pyEDFlib",
"PyWavelets",
"requests",
+ "safetensors",
"scikit-learn",
"scipy",
"soundfile",
@@ -54,7 +55,7 @@ dependencies = [
[project.optional-dependencies]
dev = [
- "black==24.10.0",
+ "black",
"flake8",
"gdown",
"librosa",
@@ -104,7 +105,7 @@ docs = [
"sphinxcontrib-tikz",
]
test = [
- "black==24.10.0",
+ "black",
"flake8",
"gdown",
"librosa",
@@ -130,3 +131,39 @@ path = "torch_ecg/version.py"
include = [
"/torch_ecg",
]
+
+# configuration for pytest and coverage
+[tool.pytest.ini_options]
+minversion = "7.0"
+addopts = "-vv -s --cov=torch_ecg --ignore=test/test_pipelines"
+testpaths = ["test"]
+
+[tool.coverage.run]
+source = ["torch_ecg"]
+omit = [
+ "torch_ecg/databases/nsrr_databases/shhs.py",
+]
+
+[tool.coverage.report]
+omit = [
+ # "torch_ecg/databases/nsrr_databases/shhs.py",
+ # NSRR databases are ignored temporarily, since data access is denied by the US government
+ "torch_ecg/databases/nsrr_databases/*",
+ # temporarily ignore torch_ecg/components/nas.py since it's not implemented completely
+ "torch_ecg/components/nas.py",
+ # temporarily ignore torch_ecg/models/grad_cam.py since it's not implemented completely
+ "torch_ecg/models/grad_cam.py",
+ # temporarily ignore torch_ecg/ssl since it's not implemented
+ "torch_ecg/ssl/*",
+]
+exclude_also = [
+ "raise NotImplementedError",
+ # Don't complain if non-runnable code isn't run:
+ "if __name__ == .__main__.:",
+ # Don't complain about abstract methods, they aren't run:
+ "@(abc\\.)?abstractmethod",
+ # base class of the NSRR databases are also ignored temporarily
+ "^class NSRRDataBase\\(.*\\):",
+]
+# show_missing = true
+# skip_covered = true
diff --git a/pyrightconfig.json b/pyrightconfig.json
new file mode 100644
index 00000000..5a8db90e
--- /dev/null
+++ b/pyrightconfig.json
@@ -0,0 +1,5 @@
+{
+ "diagnosticSeverityOverrides": {
+ "reportAttributeAccessIssue": "none"
+ }
+}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 3e04b0ae..cad67fcd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,7 @@ tensorboardX
tqdm
deprecated
deprecate-kwargs
+safetensors
pyEDFlib
PyWavelets
torch-optimizer
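`safetensors` is added here (and in `pyproject.toml`) because, per the CHANGELOG entry above, `CkptMixin` now persists weights with `save_file`/`load_file` instead of `torch.save`/`torch.load`. A minimal sketch of those two calls on a plain state dict, independent of `CkptMixin`; the module and file name are arbitrary examples:

```python
import torch
from safetensors.torch import load_file, save_file

model = torch.nn.Linear(12, 2)

# safetensors stores a flat {name: tensor} mapping, which is exactly what
# state_dict() yields for a simple module.
save_file(model.state_dict(), "linear.safetensors")

state = load_file("linear.safetensors", device="cpu")
model.load_state_dict(state)
```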
diff --git a/test/requirements.txt b/test/requirements.txt
index 566f53ba..fbf9f542 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -1,4 +1,4 @@
-black==24.10.0
+black
flake8
pytest
pytest-xdist
diff --git a/test/test_components/test_trainer.py b/test/test_components/test_trainer.py
index 9d7b2188..d690cf11 100644
--- a/test/test_components/test_trainer.py
+++ b/test/test_components/test_trainer.py
@@ -95,7 +95,7 @@ def __init__(self, n_leads: int, config: Optional[CFG] = None, **kwargs: Any) ->
super().__init__(ModelCfg.mask_classes, n_leads, model_config)
@torch.no_grad()
- def inference(
+ def inference( # type: ignore
self,
input: Union[Sequence[float], np.ndarray, Tensor],
bin_pred_thr: float = 0.5,
@@ -130,7 +130,7 @@ def inference(
prob = self.softmax(prob)
else:
prob = torch.sigmoid(prob)
- prob = prob.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
if "i" in self.classes:
mask = np.argmax(prob, axis=-1)
@@ -165,37 +165,37 @@ def __init__(
"""
Parameters
----------
- model: Module,
+ model : Module
the model to be trained
- model_config: dict,
+ model_config : dict
the configuration of the model,
used to keep a record in the checkpoints
- train_config: dict,
+ train_config : dict
the configuration of the training,
including configurations for the data loader, for the optimization, etc.
will also be recorded in the checkpoints.
`train_config` should at least contain the following keys:
- "monitor": str,
- "loss": str,
- "n_epochs": int,
- "batch_size": int,
- "learning_rate": float,
- "lr_scheduler": str,
- "lr_step_size": int, optional, depending on the scheduler
- "lr_gamma": float, optional, depending on the scheduler
- "max_lr": float, optional, depending on the scheduler
- "optimizer": str,
- "decay": float, optional, depending on the optimizer
- "momentum": float, optional, depending on the optimizer
- device: torch.device, optional,
+ - "monitor": str,
+ - "loss": str,
+ - "n_epochs": int,
+ - "batch_size": int,
+ - "learning_rate": float,
+ - "lr_scheduler": str,
+ - "lr_step_size": int, optional, depending on the scheduler
+ - "lr_gamma": float, optional, depending on the scheduler
+ - "max_lr": float, optional, depending on the scheduler
+ - "optimizer": str,
+ - "decay": float, optional, depending on the optimizer
+ - "momentum": float, optional, depending on the optimizer
+ device : torch.device, optional
the device to be used for training
- lazy: bool, default True,
+ lazy : bool, default True
whether to initialize the data loader lazily
"""
super().__init__(
model=model,
- dataset_cls=LUDBDataset,
+ dataset_cls=LUDBDataset, # type: ignore
model_config=model_config,
train_config=train_config,
device=device,
@@ -212,21 +212,21 @@ def _setup_dataloaders(
Parameters
----------
- train_dataset: Dataset, optional,
+ train_dataset : Dataset, optional
the training dataset
- val_dataset: Dataset, optional,
+ val_dataset : Dataset, optional
the validation dataset
"""
if train_dataset is None:
- train_dataset = self.dataset_cls(config=self.train_config, training=True, lazy=False)
+ train_dataset = self.dataset_cls(config=self.train_config, training=True, lazy=False) # type: ignore
if self.train_config.debug:
val_train_dataset = train_dataset
else:
val_train_dataset = None
if val_dataset is None:
- val_dataset = self.dataset_cls(config=self.train_config, training=False, lazy=False)
+ val_dataset = self.dataset_cls(config=self.train_config, training=False, lazy=False) # type: ignore
# https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813/4
if torch.cuda.is_available():
@@ -235,7 +235,7 @@ def _setup_dataloaders(
num_workers = 0
self.train_loader = DataLoader(
- dataset=train_dataset,
+ dataset=train_dataset, # type: ignore
batch_size=self.batch_size,
shuffle=True,
num_workers=num_workers,
@@ -246,7 +246,7 @@ def _setup_dataloaders(
if self.train_config.debug:
self.val_train_loader = DataLoader(
- dataset=val_train_dataset,
+ dataset=val_train_dataset, # type: ignore
batch_size=self.batch_size,
shuffle=True,
num_workers=num_workers,
@@ -257,7 +257,7 @@ def _setup_dataloaders(
else:
self.val_train_loader = None
self.val_loader = DataLoader(
- dataset=val_dataset,
+ dataset=val_dataset, # type: ignore
batch_size=self.batch_size,
shuffle=True,
num_workers=num_workers,
@@ -266,7 +266,7 @@ def _setup_dataloaders(
collate_fn=collate_fn,
)
- def run_one_step(self, *data: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ def run_one_step(self, *data: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: # type: ignore
"""
Parameters
----------
@@ -331,9 +331,9 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
# each scoring is a dict consisting of the following metrics:
# sensitivity, precision, f1_score, mean_error, standard_deviation
eval_res_split = compute_ludb_metrics(
- np.repeat(all_labels[:, np.newaxis, :], self.model_config.n_leads, axis=1),
- np.repeat(all_mask_preds[:, np.newaxis, :], self.model_config.n_leads, axis=1),
- self._cm,
+ np.repeat(all_labels[:, np.newaxis, :], self.model_config.n_leads, axis=1), # type: ignore
+ np.repeat(all_mask_preds[:, np.newaxis, :], self.model_config.n_leads, axis=1), # type: ignore
+ self._cm, # type: ignore
self.train_config.fs,
)
@@ -356,7 +356,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
self.model.train()
- return eval_res
+ return eval_res # type: ignore
@property
def _cm(self) -> Dict[str, str]:
@@ -401,9 +401,10 @@ def test_unet_trainer() -> None:
# ds_train_fl = LUDB(train_cfg_fl, training=True, lazy=False)
# ds_val_fl = LUDB(train_cfg_fl, training=False, lazy=False)
- train_cfg_fl.keep_checkpoint_max = 0
- train_cfg_fl.monitor = None
- train_cfg_fl.n_epochs = 2
+ train_cfg_fl.keep_checkpoint_max = 2
+ train_cfg_fl.monitor = "f1_score"
+ train_cfg_fl.n_epochs = 3
+ train_cfg_fl.flooding_level = 0.1
model_config = deepcopy(ModelCfg)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -423,6 +424,8 @@ def test_unet_trainer() -> None:
lazy=False,
)
+ print(repr(trainer))
+
bmd = trainer.train()
del model, trainer, bmd
@@ -430,7 +433,7 @@ def test_unet_trainer() -> None:
def test_base_trainer():
with pytest.raises(TypeError, match="Can't instantiate abstract class"):
- BaseTrainer(
+ BaseTrainer( # type: ignore
model=ECG_UNET_LUDB(12, ModelCfg),
dataset_cls=LUDBDataset,
train_config=LUDBTrainCfg,
diff --git a/test/test_databases/test_afdb.py b/test/test_databases/test_afdb.py
index 63267682..d0ccabae 100644
--- a/test/test_databases/test_afdb.py
+++ b/test/test_databases/test_afdb.py
@@ -29,7 +29,7 @@
with pytest.warns(RuntimeWarning):
- reader = AFDB(_CWD, verbose=3)
+ reader = AFDB(_CWD, verbose=1)
if len(reader) == 0:
reader.download()
diff --git a/test/test_databases/test_base.py b/test/test_databases/test_base.py
index aa789905..50b4ce1d 100644
--- a/test/test_databases/test_base.py
+++ b/test/test_databases/test_base.py
@@ -2,6 +2,7 @@
from pathlib import Path
+import numpy as np
import pytest
from torch_ecg.databases import AFDB, list_databases
@@ -24,25 +25,25 @@ def test_base_database():
TypeError,
match=f"Can't instantiate abstract class {_DataBase.__name__}",
):
- db = _DataBase()
+ db = _DataBase() # type: ignore[abstract]
with pytest.raises(
TypeError,
match=f"Can't instantiate abstract class {PhysioNetDataBase.__name__}",
):
- db = PhysioNetDataBase()
+ db = PhysioNetDataBase() # type: ignore[abstract]
with pytest.raises(
TypeError,
match=f"Can't instantiate abstract class {NSRRDataBase.__name__}",
):
- db = NSRRDataBase()
+ db = NSRRDataBase() # type: ignore[abstract]
with pytest.raises(
TypeError,
match=f"Can't instantiate abstract class {CPSCDataBase.__name__}",
):
- db = CPSCDataBase()
+ db = CPSCDataBase() # type: ignore[abstract]
def test_beat_ann():
@@ -88,6 +89,9 @@ def test_database_meta():
for k in WFDB_Rhythm_Annotations:
assert reader.helper(k) is None # printed: `{k}` stands for `{WFDB_Rhythm_Annotations[k]}`
+ with pytest.raises(NotImplementedError, match="not implemented for"):
+ reader._auto_infer_units(np.ones((10, 2)), sig_type="EEG")
+
def test_database_info():
with pytest.warns(RuntimeWarning, match="`db_dir` is not specified"):
diff --git a/test/test_databases/test_ludb.py b/test/test_databases/test_ludb.py
index 8d894d01..910b1f9f 100644
--- a/test/test_databases/test_ludb.py
+++ b/test/test_databases/test_ludb.py
@@ -60,7 +60,7 @@ def test_load_data(self):
data_1 = reader.load_data(0, leads=[1, 7])
assert data.shape[0] == 12
assert data_1.shape[0] == 2
- assert np.allclose(data[[1, 7], :], data_1)
+ assert np.allclose(data[[1, 7], :], data_1) # type: ignore
def test_load_ann(self):
ann = reader.load_ann(0)
@@ -114,7 +114,11 @@ def test_meta_data(self):
def test_plot(self):
reader.plot(0, leads=["I", 5], ticks_granularity=2)
data = reader.load_data(0, leads="III", data_format="flat")
- reader.plot(0, data=data, leads="III")
+ reader.plot(0, data=data, leads="III") # type: ignore
+
+ def test_get_absolute_path(self):
+ path = reader.get_absolute_path(0, extension="avf")
+ assert path.is_file() and path.suffix == ".avf"
config = deepcopy(LUDBTrainCfg)
diff --git a/test/test_databases/test_shhs.py b/test/test_databases/test_shhs.py
index e4f8b8f1..9f66e376 100644
--- a/test/test_databases/test_shhs.py
+++ b/test/test_databases/test_shhs.py
@@ -4,6 +4,7 @@
subsampling: accomplished
"""
+import os
import time
from numbers import Real
from pathlib import Path
@@ -14,6 +15,11 @@
from torch_ecg.databases import SHHS, DataBaseInfo
+pytestmark = pytest.mark.skipif(
+ os.getenv("SHHS_DATA_AVAILABLE") != "true", reason="SHHS dataset not available (token invalid or download skipped)"
+)
+
+
###############################################################################
# set paths
# 9 files are downloaded in the following directory using `nsrr`
@@ -80,28 +86,28 @@ def test_load_psg_data(self):
assert isinstance(value, tuple)
assert len(value) == 2
assert isinstance(value[0], np.ndarray)
- assert isinstance(value[1], Real) and value[1] > 0
+ assert isinstance(value[1], Real) and value[1] > 0 # type: ignore
available_signals = reader.get_available_signals(0)
- for signal in available_signals:
+ for signal in available_signals: # type: ignore
psg_data = reader.load_psg_data(0, channel=signal, physical=True)
assert isinstance(psg_data, tuple)
assert len(psg_data) == 2
assert isinstance(psg_data[0], np.ndarray)
- assert isinstance(psg_data[1], Real) and psg_data[1] > 0
+ assert isinstance(psg_data[1], Real) and psg_data[1] > 0 # type: ignore
def test_load_data(self):
data, fs = reader.load_data(0)
assert isinstance(data, np.ndarray)
assert data.ndim == 2
- assert isinstance(fs, Real) and fs > 0
+ assert isinstance(fs, Real) and fs > 0 # type: ignore
data_1, fs_1 = reader.load_data(0, fs=500, data_format="flat")
assert isinstance(data_1, np.ndarray)
assert data_1.ndim == 1
assert fs_1 == 500
data_1, fs_1 = reader.load_data(0, sampfrom=10, sampto=20, data_format="flat")
assert fs_1 == fs
- assert data_1.shape[0] == int(10 * fs)
- assert np.allclose(data_1, data[0, int(10 * fs) : int(20 * fs)])
+ assert data_1.shape[0] == int(10 * fs) # type: ignore
+ assert np.allclose(data_1, data[0, int(10 * fs) : int(20 * fs)]) # type: ignore
data_1 = reader.load_data(0, sampfrom=10, sampto=20, data_format="flat", return_fs=False)
assert isinstance(data_1, np.ndarray)
data_2, _ = reader.load_data(0, sampfrom=10, sampto=20, data_format="flat", units="uv")
@@ -244,7 +250,7 @@ def test_load_rpeak_ann(self):
ValueError,
match="`units` should be one of 's', 'ms', case insensitive, or None",
):
- reader.load_rpeak_ann(rec, units="invalid")
+ reader.load_rpeak_ann(rec, units="invalid") # type: ignore
def test_load_rr_ann(self):
rec = reader.rec_with_rpeaks_ann[0]
@@ -271,7 +277,7 @@ def test_load_rr_ann(self):
ValueError,
match="`units` should be one of 's', 'ms', case insensitive, or None",
):
- reader.load_rr_ann(rec, units="invalid")
+ reader.load_rr_ann(rec, units="invalid") # type: ignore
def test_load_nn_ann(self):
rec = reader.rec_with_rpeaks_ann[0]
@@ -331,7 +337,7 @@ def test_load_sleep_ann(self):
assert isinstance(ann["df_events"], pd.DataFrame) and len(ann["df_events"]) == 0
with pytest.raises(ValueError, match="Source `.+` not supported, "):
- reader.load_sleep_ann(rec, source="invalid")
+ reader.load_sleep_ann(rec, source="invalid") # type: ignore
def test_load_apnea_ann(self):
rec = reader.rec_with_event_ann[0]
@@ -352,7 +358,7 @@ def test_load_apnea_ann(self):
assert isinstance(ann, pd.DataFrame) and ann.empty
with pytest.raises(ValueError, match="Source `hrv` contains no apnea annotations"):
- reader.load_apnea_ann(rec, source="hrv")
+ reader.load_apnea_ann(rec, source="hrv") # type: ignore
def test_load_sleep_event_ann(self):
rec = reader.rec_with_event_ann[0]
@@ -379,7 +385,7 @@ def test_load_sleep_event_ann(self):
assert isinstance(ann, pd.DataFrame) and len(ann) == 0
with pytest.raises(ValueError, match="Source `.+` not supported, "):
- reader.load_sleep_event_ann(rec, source="invalid")
+ reader.load_sleep_event_ann(rec, source="invalid") # type: ignore
def test_load_sleep_stage_ann(self):
rec = reader.rec_with_event_ann[0]
@@ -402,7 +408,7 @@ def test_load_sleep_stage_ann(self):
assert isinstance(ann, pd.DataFrame) and len(ann) == 0
with pytest.raises(ValueError, match="Source `.+` not supported, "):
- reader.load_sleep_stage_ann(rec, source="invalid")
+ reader.load_sleep_stage_ann(rec, source="invalid") # type: ignore
def test_locate_abnormal_beats(self):
rec = reader.rec_with_rpeaks_ann[0]
@@ -439,12 +445,12 @@ def test_locate_abnormal_beats(self):
rec = reader.rec_with_rpeaks_ann[0]
with pytest.raises(ValueError, match="No abnormal type of `.+`"):
- reader.locate_abnormal_beats(rec, abnormal_type="AF")
+ reader.locate_abnormal_beats(rec, abnormal_type="AF") # type: ignore
with pytest.raises(
ValueError,
match="`units` should be one of 's', 'ms', case insensitive, or None",
):
- reader.locate_abnormal_beats(rec, units="invalid")
+ reader.locate_abnormal_beats(rec, units="invalid") # type: ignore
def test_locate_artifacts(self):
rec = reader.rec_with_rpeaks_ann[0]
@@ -473,7 +479,7 @@ def test_locate_artifacts(self):
ValueError,
match="`units` should be one of 's', 'ms', case insensitive, or None",
):
- reader.locate_artifacts(rec, units="invalid")
+ reader.locate_artifacts(rec, units="invalid") # type: ignore
def test_get_available_signals(self):
assert reader.get_available_signals(None) is None # no return
@@ -486,14 +492,14 @@ def test_get_available_signals(self):
def test_get_chn_num(self):
available_signals = reader.get_available_signals(0)
- for sig in available_signals:
+ for sig in available_signals: # type: ignore
chn_num = reader.get_chn_num(0, sig)
assert isinstance(chn_num, int)
- assert 0 <= chn_num < len(available_signals)
+ assert 0 <= chn_num < len(available_signals) # type: ignore
def test_match_channel(self):
available_signals = reader.get_available_signals(0)
- for sig in available_signals:
+ for sig in available_signals: # type: ignore
assert sig == reader.match_channel(sig.lower())
assert sig in reader.all_signals
@@ -501,13 +507,13 @@ def test_match_channel(self):
def test_get_fs(self):
available_signals = reader.get_available_signals(0)
- for sig in available_signals:
+ for sig in available_signals: # type: ignore
fs = reader.get_fs(0, sig)
- assert isinstance(fs, Real) and fs > 0
+ assert isinstance(fs, Real) and fs > 0 # type: ignore
rec = reader.rec_with_rpeaks_ann[0]
fs = reader.get_fs(rec, "rpeak")
- assert isinstance(fs, Real) and fs > 0
+ assert isinstance(fs, Real) and fs > 0 # type: ignore
rec = "shhs2-200001" # a record (both signal and ann. files) that does not exist
fs = reader.get_fs(rec)
@@ -629,7 +635,7 @@ def test_plot(self):
match="Unknown plot format `xxx`! `plot_format` can only be one of `span`, `hypnogram`",
):
rec = reader.rec_with_event_ann[0]
- reader.plot_ann(rec, event_source="event", plot_format="xxx")
+ reader.plot_ann(rec, event_source="event", plot_format="xxx") # type: ignore
with pytest.raises(ValueError, match="No input data"):
rec = reader.rec_with_event_ann[0]
diff --git a/test/test_databases/test_sph.py b/test/test_databases/test_sph.py
index 3b908053..0a04120d 100644
--- a/test/test_databases/test_sph.py
+++ b/test/test_databases/test_sph.py
@@ -66,6 +66,7 @@ def test_load_data(self):
data_1, data_1_fs = reader.load_data(rec, return_fs=True)
assert data_1_fs == reader.fs
+ rec = 0
with pytest.raises(AssertionError, match="Invalid data_format: `flat`"):
reader.load_data(rec, data_format="flat")
with pytest.raises(AssertionError, match="Invalid units: `kV`"):
@@ -82,6 +83,7 @@ def test_load_ann(self):
ann_1 = reader.load_ann(rec, ann_format="f", ignore_modifier=False)
assert len(ann) == len(ann_1)
+ rec = 0
with pytest.raises(ValueError, match="Unknown annotation format: `flat`"):
reader.load_ann(rec, ann_format="flat")
with pytest.raises(NotImplementedError, match="Abbreviations are not supported yet"):
@@ -92,23 +94,34 @@ def test_get_subject_info(self):
info = reader.get_subject_info(rec)
assert isinstance(info, dict)
assert info.keys() == {"age", "sex"}
+ info = reader.get_subject_info(0, items=["age"])
+ assert isinstance(info, dict)
+ assert info.keys() == {"age"}
def test_get_subject_id(self):
for rec in reader:
sid = reader.get_subject_id(rec)
assert isinstance(sid, str)
+ sid = reader.get_subject_id(0)
+ assert isinstance(sid, str)
def test_get_age(self):
for rec in reader:
age = reader.get_age(rec)
assert isinstance(age, int)
assert age > 0
+ age = reader.get_age(0)
+ assert isinstance(age, int)
+ assert age > 0
def test_get_sex(self):
for rec in reader:
sex = reader.get_sex(rec)
assert isinstance(sex, str)
assert sex in ["M", "F"]
+ sex = reader.get_sex(0)
+ assert isinstance(sex, str)
+ assert sex in ["M", "F"]
def test_get_siglen(self):
for rec in reader:
@@ -116,6 +129,10 @@ def test_get_siglen(self):
data = reader.load_data(rec)
assert isinstance(siglen, int)
assert siglen == data.shape[1]
+ siglen = reader.get_siglen(0)
+ data = reader.load_data(0)
+ assert isinstance(siglen, int)
+ assert siglen == data.shape[1]
def test_meta_data(self):
assert isinstance(reader.url, dict)
@@ -131,18 +148,18 @@ def test_plot(self):
"t_onsets": [150, 1150],
"t_offsets": [190, 1190],
}
- reader.plot(0, leads="II", ticks_granularity=2, waves=waves)
+ reader.plot(0, leads="II", ticks_granularity=2, waves=waves) # type: ignore
waves = {
"p_peaks": [105, 1105],
"q_peaks": [120, 1120],
"s_peaks": [125, 1125],
"t_peaks": [170, 1170],
}
- reader.plot(0, leads=["II", 7], ticks_granularity=1, waves=waves)
+ reader.plot(0, leads=["II", 7], ticks_granularity=1, waves=waves) # type: ignore
waves = {
"p_peaks": [105, 1105],
"r_peaks": [122, 1122],
"t_peaks": [170, 1170],
}
data = reader.load_data(0)
- reader.plot(0, data=data, ticks_granularity=0, waves=waves)
+ reader.plot(0, data=data, ticks_granularity=0, waves=waves) # type: ignore
diff --git a/test/test_models/test_ecg_crnn.py b/test/test_models/test_ecg_crnn.py
index 75db7109..5805cc1f 100644
--- a/test/test_models/test_ecg_crnn.py
+++ b/test/test_models/test_ecg_crnn.py
@@ -24,18 +24,18 @@ def test_ecg_crnn():
inp = torch.randn(2, n_leads, 2000).to(DEVICE)
grid = itertools.product(
- [cnn_name for cnn_name in ECG_CRNN_CONFIG.cnn.keys() if cnn_name != "name"],
- [rnn_name for rnn_name in ECG_CRNN_CONFIG.rnn.keys() if rnn_name != "name"] + ["none"],
- [attn_name for attn_name in ECG_CRNN_CONFIG.attn.keys() if attn_name != "name"] + ["none"],
+ [cnn_name for cnn_name in ECG_CRNN_CONFIG.cnn.keys() if cnn_name != "name"], # type: ignore
+ [rnn_name for rnn_name in ECG_CRNN_CONFIG.rnn.keys() if rnn_name != "name"] + ["none"], # type: ignore
+ [attn_name for attn_name in ECG_CRNN_CONFIG.attn.keys() if attn_name != "name"] + ["none"], # type: ignore
["none", "max", "avg"], # global pool
)
- total = (len(ECG_CRNN_CONFIG.cnn.keys()) - 1) * len(ECG_CRNN_CONFIG.rnn.keys()) * len(ECG_CRNN_CONFIG.attn.keys()) * 3
+ total = (len(ECG_CRNN_CONFIG.cnn.keys()) - 1) * len(ECG_CRNN_CONFIG.rnn.keys()) * len(ECG_CRNN_CONFIG.attn.keys()) * 3 # type: ignore
for cnn_name, rnn_name, attn_name, global_pool in tqdm(grid, total=total, mininterval=1):
config = deepcopy(ECG_CRNN_CONFIG)
- config.cnn.name = cnn_name
- config.rnn.name = rnn_name
- config.attn.name = attn_name
+ config.cnn.name = cnn_name # type: ignore
+ config.rnn.name = rnn_name # type: ignore
+ config.attn.name = attn_name # type: ignore
config.global_pool = global_pool
model = ECG_CRNN(classes=classes, n_leads=n_leads, config=config).to(DEVICE)
@@ -55,15 +55,15 @@ def test_ecg_crnn():
# load weights from v1
model.cnn.load_state_dict(model_v1.cnn.state_dict())
if model.rnn.__class__.__name__ != "Identity":
- model.rnn.load_state_dict(model_v1.rnn.state_dict())
+ model.rnn.load_state_dict(model_v1.rnn.state_dict()) # type: ignore
if model.attn.__class__.__name__ != "Identity":
- model.attn.load_state_dict(model_v1.attn.state_dict())
+ model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore
model.clf.load_state_dict(model_v1.clf.state_dict())
- doi = model.doi
+ doi = model.doi # type: ignore
assert isinstance(doi, list)
assert all([isinstance(d, str) for d in doi]), doi
- doi = model_v1.doi
+ doi = model_v1.doi # type: ignore
assert isinstance(doi, list)
assert all([isinstance(d, str) for d in doi]), doi
@@ -85,24 +85,24 @@ def test_warns_errors():
model_v1.inference(inp)
config = deepcopy(ECG_CRNN_CONFIG)
- config.cnn.name = "not_implemented"
- config.cnn.not_implemented = {}
+ config.cnn.name = "not_implemented" # type: ignore
+ config.cnn.not_implemented = {} # type: ignore
with pytest.raises(NotImplementedError, match="CNN \042.+\042 not implemented yet"):
ECG_CRNN(classes=classes, n_leads=n_leads, config=config)
with pytest.raises(NotImplementedError, match="CNN \042.+\042 not implemented yet"):
ECG_CRNN_v1(classes=classes, n_leads=n_leads, config=config)
config = deepcopy(ECG_CRNN_CONFIG)
- config.rnn.name = "not_implemented"
- config.rnn.not_implemented = {}
+ config.rnn.name = "not_implemented" # type: ignore
+ config.rnn.not_implemented = {} # type: ignore
with pytest.raises(NotImplementedError, match="RNN \042.+\042 not implemented yet"):
ECG_CRNN(classes=classes, n_leads=n_leads, config=config)
with pytest.raises(NotImplementedError, match="RNN \042.+\042 not implemented yet"):
ECG_CRNN_v1(classes=classes, n_leads=n_leads, config=config)
config = deepcopy(ECG_CRNN_CONFIG)
- config.attn.name = "not_implemented"
- config.attn.not_implemented = {}
+ config.attn.name = "not_implemented" # type: ignore
+ config.attn.not_implemented = {} # type: ignore
with pytest.raises(NotImplementedError, match="Attention \042.+\042 not implemented yet"):
ECG_CRNN(classes=classes, n_leads=n_leads, config=config)
with pytest.raises(NotImplementedError, match="Attention \042.+\042 not implemented yet"):
@@ -128,10 +128,10 @@ def test_from_v1():
n_leads = 12
classes = ["NSR", "AF", "PVC", "LBBB", "RBBB", "PAB", "VFL"]
model_v1 = ECG_CRNN_v1(classes=classes, n_leads=n_leads, config=config)
- model_v1.save(_TMP_DIR / "ecg_crnn_v1.pth", {"classes": classes, "n_leads": n_leads})
+ model_v1.save(_TMP_DIR / "ecg_crnn_v1.pth", {"classes": classes, "n_leads": n_leads}, use_safetensors=False) # type: ignore
model = ECG_CRNN.from_v1(_TMP_DIR / "ecg_crnn_v1.pth")
del model
- model, _ = ECG_CRNN.from_v1(_TMP_DIR / "ecg_crnn_v1.pth", return_config=True)
+ model, _ = ECG_CRNN.from_v1(_TMP_DIR / "ecg_crnn_v1.pth", return_config=True) # type: ignore
(_TMP_DIR / "ecg_crnn_v1.pth").unlink()
del model_v1, model
diff --git a/test/test_models/test_rr_lstm.py b/test/test_models/test_rr_lstm.py
index 6b75c49d..8e33b851 100644
--- a/test/test_models/test_rr_lstm.py
+++ b/test/test_models/test_rr_lstm.py
@@ -22,9 +22,9 @@ def test_rr_lstm():
inp_bf = torch.randn(2, in_channels, 100).to(DEVICE)
config = deepcopy(RR_LSTM_CONFIG)
- config.clf.name = "crf"
+ config.clf.name = "crf" # type: ignore
for attn_name in ["none"]:
- config.attn.name = attn_name
+ config.attn.name = attn_name # type: ignore
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
model = model.eval()
out = model(inp)
@@ -35,22 +35,22 @@ def test_rr_lstm():
model_v1.compute_output_shape(seq_len=inp.shape[0], batch_size=inp.shape[1])
model.lstm.load_state_dict(model_v1.lstm.state_dict())
if model.attn.__class__.__name__ != "Identity":
- model.attn.load_state_dict(model_v1.attn.state_dict())
+ model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore
model.clf.load_state_dict(model_v1.clf.state_dict())
config = deepcopy(RR_LSTM_CONFIG)
- config.clf.name = "crf"
+ config.clf.name = "crf" # type: ignore
config.batch_first = True
for attn_name in ["none"]:
- config.attn.name = attn_name
+ config.attn.name = attn_name # type: ignore
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
model = model.eval()
out = model(inp_bf)
assert out.shape == model.compute_output_shape(seq_len=inp_bf.shape[-1], batch_size=inp_bf.shape[0])
config = deepcopy(RR_LSTM_CONFIG)
- config.clf.name = "linear"
+ config.clf.name = "linear" # type: ignore
for attn_name in ["none", "gc", "nl", "se"]:
- config.attn.name = attn_name
+ config.attn.name = attn_name # type: ignore
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
model = model.eval()
out = model(inp)
@@ -61,13 +61,13 @@ def test_rr_lstm():
model_v1.compute_output_shape(seq_len=inp.shape[0], batch_size=inp.shape[1])
model.lstm.load_state_dict(model_v1.lstm.state_dict())
if model.attn.__class__.__name__ != "Identity":
- model.attn.load_state_dict(model_v1.attn.state_dict())
+ model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore
model.clf.load_state_dict(model_v1.clf.state_dict())
config = deepcopy(RR_LSTM_CONFIG)
- config.clf.name = "linear"
+ config.clf.name = "linear" # type: ignore
config.batch_first = True
for attn_name in ["none", "gc", "nl", "se"]:
- config.attn.name = attn_name
+ config.attn.name = attn_name # type: ignore
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
model = model.eval()
out = model(inp_bf)
@@ -75,7 +75,7 @@ def test_rr_lstm():
config = deepcopy(RR_AF_VANILLA_CONFIG)
for attn_name in ["none", "gc", "nl", "se"]:
- config.attn.name = attn_name
+ config.attn.name = attn_name # type: ignore
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
model = model.eval()
out = model(inp)
@@ -86,12 +86,12 @@ def test_rr_lstm():
model_v1.compute_output_shape(seq_len=inp.shape[0], batch_size=inp.shape[1])
model.lstm.load_state_dict(model_v1.lstm.state_dict())
if model.attn.__class__.__name__ != "Identity":
- model.attn.load_state_dict(model_v1.attn.state_dict())
+ model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore
model.clf.load_state_dict(model_v1.clf.state_dict())
config = deepcopy(RR_AF_VANILLA_CONFIG)
config.batch_first = True
for attn_name in ["none", "gc", "nl", "se"]:
- config.attn.name = attn_name
+ config.attn.name = attn_name # type: ignore
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
model = model.eval()
out = model(inp_bf)
@@ -99,7 +99,7 @@ def test_rr_lstm():
config = deepcopy(RR_AF_CRF_CONFIG)
for attn_name in ["none"]:
- config.attn.name = attn_name
+ config.attn.name = attn_name # type: ignore
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
model = model.eval()
out = model(inp)
@@ -110,21 +110,21 @@ def test_rr_lstm():
model_v1.compute_output_shape(seq_len=inp.shape[0], batch_size=inp.shape[1])
model.lstm.load_state_dict(model_v1.lstm.state_dict())
if model.attn.__class__.__name__ != "Identity":
- model.attn.load_state_dict(model_v1.attn.state_dict())
+ model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore
model.clf.load_state_dict(model_v1.clf.state_dict())
config = deepcopy(RR_AF_CRF_CONFIG)
config.batch_first = True
for attn_name in ["none"]:
- config.attn.name = attn_name
+ config.attn.name = attn_name # type: ignore
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
model = model.eval()
out = model(inp_bf)
assert out.shape == model.compute_output_shape(seq_len=inp_bf.shape[-1], batch_size=inp_bf.shape[0])
config = deepcopy(RR_LSTM_CONFIG)
- config.lstm.retseq = False
- config.clf.name = "linear"
- config.attn.name = "none"
+ config.lstm.retseq = False # type: ignore
+ config.clf.name = "linear" # type: ignore
+ config.attn.name = "none" # type: ignore
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
model = model.eval()
out = model(inp)
@@ -135,13 +135,13 @@ def test_rr_lstm():
model_v1.compute_output_shape(seq_len=inp.shape[0], batch_size=inp.shape[1])
model.lstm.load_state_dict(model_v1.lstm.state_dict())
if model.attn.__class__.__name__ != "Identity":
- model.attn.load_state_dict(model_v1.attn.state_dict())
+ model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore
model.clf.load_state_dict(model_v1.clf.state_dict())
config = deepcopy(RR_LSTM_CONFIG)
- config.lstm.retseq = False
- config.clf.name = "linear"
+ config.lstm.retseq = False # type: ignore
+ config.clf.name = "linear" # type: ignore
config.batch_first = True
- config.attn.name = "none"
+ config.attn.name = "none" # type: ignore
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
model = model.eval()
out = model(inp_bf)
@@ -167,9 +167,9 @@ def test_warns_errors():
model_v1 = RR_LSTM_v1(classes=classes).to(DEVICE)
config = deepcopy(RR_LSTM_CONFIG)
- config.lstm.retseq = False
- config.attn.name = "gc"
- config.clf.name = "linear"
+ config.lstm.retseq = False # type: ignore
+ config.attn.name = "gc" # type: ignore
+ config.clf.name = "linear" # type: ignore
with pytest.warns(
RuntimeWarning,
match="Attention is not supported when lstm is not returning sequences",
@@ -182,9 +182,9 @@ def test_warns_errors():
model_v1 = RR_LSTM_v1(classes=classes, config=config).to(DEVICE)
config = deepcopy(RR_LSTM_CONFIG)
- config.lstm.retseq = False
- config.attn.name = "none"
- config.clf.name = "crf"
+ config.lstm.retseq = False # type: ignore
+ config.attn.name = "none" # type: ignore
+ config.clf.name = "crf" # type: ignore
with pytest.warns(
RuntimeWarning,
match="CRF layer is not supported in non-sequence mode, using linear instead",
@@ -215,15 +215,15 @@ def test_warns_errors():
model_v1.inference(inp)
config = deepcopy(RR_LSTM_CONFIG)
- config.attn.name = "not_implemented"
- config.attn.not_implemented = {}
+ config.attn.name = "not_implemented" # type: ignore
+ config.attn.not_implemented = {} # type: ignore
with pytest.raises(NotImplementedError, match="Attn module \042.+\042 not implemented yet"):
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
with pytest.raises(NotImplementedError, match="Attn module \042.+\042 not implemented yet"):
model_v1 = RR_LSTM_v1(classes=classes, config=config).to(DEVICE)
config = deepcopy(RR_LSTM_CONFIG)
- config.clf.name = "linear"
+ config.clf.name = "linear" # type: ignore
config.global_pool = "not_supported"
with pytest.raises(NotImplementedError, match="Pooling type \042.+\042 not supported"):
model = RR_LSTM(classes=classes, config=config).to(DEVICE)
@@ -235,9 +235,9 @@ def test_from_v1():
config = deepcopy(RR_LSTM_CONFIG)
classes = ["NSR", "AF", "PVC", "LBBB", "RBBB", "PAB", "VFL"]
model_v1 = RR_LSTM_v1(classes=classes, config=config)
- model_v1.save(_TMP_DIR / "rr_lstm_v1.pth", {"classes": classes})
- model = RR_LSTM.from_v1(_TMP_DIR / "rr_lstm_v1.pth")
+ model_v1.save(_TMP_DIR / "rr_lstm_v1.pth", {"classes": classes}, use_safetensors=False) # type: ignore
+ model = RR_LSTM.from_v1(_TMP_DIR / "rr_lstm_v1.pth") # type: ignore
del model
- model, _ = RR_LSTM.from_v1(_TMP_DIR / "rr_lstm_v1.pth", return_config=True)
+ model, _ = RR_LSTM.from_v1(_TMP_DIR / "rr_lstm_v1.pth", return_config=True) # type: ignore
(_TMP_DIR / "rr_lstm_v1.pth").unlink()
del model_v1, model
diff --git a/test/test_models/test_seq_lab_net.py b/test/test_models/test_seq_lab_net.py
index 7f8d8cf9..f0204eed 100644
--- a/test/test_models/test_seq_lab_net.py
+++ b/test/test_models/test_seq_lab_net.py
@@ -33,32 +33,32 @@ def test_ecg_seq_lab_net():
for cnn, rnn, attn, recover_length in tqdm(grid, total=total, mininterval=1):
config = adjust_cnn_filter_lengths(ECG_SEQ_LAB_NET_CONFIG, fs)
- config.cnn.name = cnn
- config.rnn.name = rnn
- config.attn.name = attn
- config.recover_length = recover_length
+ config.cnn.name = cnn # type: ignore
+ config.rnn.name = rnn # type: ignore
+ config.attn.name = attn # type: ignore
+ config.recover_length = recover_length # type: ignore
- model = ECG_SEQ_LAB_NET(classes=classes, n_leads=12, config=config).to(DEVICE)
+ model = ECG_SEQ_LAB_NET(classes=classes, n_leads=12, config=config).to(DEVICE) # type: ignore
model = model.eval()
out = model(inp)
assert out.shape == model.compute_output_shape(seq_len=inp.shape[-1], batch_size=inp.shape[0])
if recover_length:
assert out.shape[1] == inp.shape[-1]
- model_v1 = ECG_SEQ_LAB_NET_v1(classes=classes, n_leads=12, config=config).to(DEVICE)
+ model_v1 = ECG_SEQ_LAB_NET_v1(classes=classes, n_leads=12, config=config).to(DEVICE) # type: ignore
model_v1 = model_v1.eval()
out_v1 = model_v1(inp)
model.cnn.load_state_dict(model_v1.cnn.state_dict())
if model.rnn.__class__.__name__ != "Identity":
- model.rnn.load_state_dict(model_v1.rnn.state_dict())
+ model.rnn.load_state_dict(model_v1.rnn.state_dict()) # type: ignore
if model.attn.__class__.__name__ != "Identity":
- model.attn.load_state_dict(model_v1.attn.state_dict())
+ model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore
model.clf.load_state_dict(model_v1.clf.state_dict())
- doi = model.doi
+ doi = model.doi # type: ignore
assert isinstance(doi, list)
assert all([isinstance(d, str) for d in doi]), doi
- doi = model_v1.doi
+ doi = model_v1.doi # type: ignore
assert isinstance(doi, list)
assert all([isinstance(d, str) for d in doi]), doi
@@ -92,9 +92,9 @@ def test_from_v1():
n_leads = 12
classes = ["N"]
model_v1 = ECG_SEQ_LAB_NET_v1(classes=classes, n_leads=n_leads, config=config)
- model_v1.save(_TMP_DIR / "ecg_seq_lab_net_v1.pth", {"classes": classes, "n_leads": n_leads})
- model = ECG_SEQ_LAB_NET.from_v1(_TMP_DIR / "ecg_seq_lab_net_v1.pth")
+ model_v1.save(_TMP_DIR / "ecg_seq_lab_net_v1.pth", {"classes": classes, "n_leads": n_leads}, use_safetensors=False) # type: ignore
+ model = ECG_SEQ_LAB_NET.from_v1(_TMP_DIR / "ecg_seq_lab_net_v1.pth") # type: ignore
del model
- model, _ = ECG_SEQ_LAB_NET.from_v1(_TMP_DIR / "ecg_seq_lab_net_v1.pth", return_config=True)
+ model, _ = ECG_SEQ_LAB_NET.from_v1(_TMP_DIR / "ecg_seq_lab_net_v1.pth", return_config=True) # type: ignore
(_TMP_DIR / "ecg_seq_lab_net_v1.pth").unlink()
del model_v1, model
diff --git a/test/test_pipelines/test_crnn_cinc2021_pipeline.py b/test/test_pipelines/test_crnn_cinc2021_pipeline.py
index 07171053..349dfb85 100644
--- a/test/test_pipelines/test_crnn_cinc2021_pipeline.py
+++ b/test/test_pipelines/test_crnn_cinc2021_pipeline.py
@@ -405,8 +405,8 @@ def inference(
# batch_size, channels, seq_len = _input.shape
prob = self.sigmoid(self.forward(_input))
pred = (prob >= bin_pred_thr).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
for row_idx, row in enumerate(pred):
row_max_prob = prob[row_idx, ...].max()
if row_max_prob < ModelCfg.bin_pred_nsr_thr and nsr_cid is not None:
diff --git a/test/test_pipelines/test_mtl_cinc2022_pipeline.py b/test/test_pipelines/test_mtl_cinc2022_pipeline.py
index 21d46270..420ae746 100644
--- a/test/test_pipelines/test_mtl_cinc2022_pipeline.py
+++ b/test/test_pipelines/test_mtl_cinc2022_pipeline.py
@@ -2212,31 +2212,31 @@ def inference(
prob = self.softmax(forward_output["murmur"])
pred = torch.argmax(prob, dim=-1)
bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int)
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
- bin_pred = bin_pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
+ bin_pred = bin_pred.detach().cpu().numpy()
murmur_output = ClassificationOutput(
classes=self.classes,
prob=prob,
pred=pred,
bin_pred=bin_pred,
- forward_output=forward_output["murmur"].cpu().detach().numpy(),
+ forward_output=forward_output["murmur"].detach().cpu().numpy(),
)
if forward_output.get("outcome", None) is not None:
prob = self.softmax(forward_output["outcome"])
pred = torch.argmax(prob, dim=-1)
bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int)
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
- bin_pred = bin_pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
+ bin_pred = bin_pred.detach().cpu().numpy()
outcome_output = ClassificationOutput(
classes=self.outcomes,
prob=prob,
pred=pred,
bin_pred=bin_pred,
- forward_output=forward_output["outcome"].cpu().detach().numpy(),
+ forward_output=forward_output["outcome"].detach().cpu().numpy(),
)
else:
outcome_output = None
@@ -2250,13 +2250,13 @@ def inference(
else:
prob = self.sigmoid(forward_output["segmentation"])
pred = (prob > seg_thr).int() * (prob == prob.max(dim=-1, keepdim=True).values).int()
- prob = prob.cpu().detach().numpy()
- pred = pred.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
+ pred = pred.detach().cpu().numpy()
segmentation_output = SequenceLabellingOutput(
classes=self.states,
prob=prob,
pred=pred,
- forward_output=forward_output["segmentation"].cpu().detach().numpy(),
+ forward_output=forward_output["segmentation"].detach().cpu().numpy(),
)
else:
segmentation_output = None
diff --git a/test/test_pipelines/test_seq_lab_cpsc2019_pipeline.py b/test/test_pipelines/test_seq_lab_cpsc2019_pipeline.py
index 0f0145ba..1a4e8287 100644
--- a/test/test_pipelines/test_seq_lab_cpsc2019_pipeline.py
+++ b/test/test_pipelines/test_seq_lab_cpsc2019_pipeline.py
@@ -149,7 +149,7 @@ def inference(
mode="linear",
align_corners=True,
).permute(0, 2, 1)
- prob = prob.cpu().detach().numpy().squeeze(-1)
+ prob = prob.detach().cpu().numpy().squeeze(-1)
# prob --> qrs mask --> qrs intervals --> rpeaks
rpeaks = _inference_post_process(
@@ -169,7 +169,7 @@ def inference(
sampling_rate=self.config.fs,
tol=0.05,
)[0]
- for b_input, b_rpeaks in zip(_input.detach().numpy().squeeze(1), rpeaks)
+ for b_input, b_rpeaks in zip(_input.detach().cpu().numpy().squeeze(1), rpeaks)
]
return RPeaksDetectionOutput(
diff --git a/test/test_pipelines/test_unet_ludb_pipeline.py b/test/test_pipelines/test_unet_ludb_pipeline.py
index add76d60..7773893c 100644
--- a/test/test_pipelines/test_unet_ludb_pipeline.py
+++ b/test/test_pipelines/test_unet_ludb_pipeline.py
@@ -128,7 +128,7 @@ def inference(
prob = self.softmax(prob)
else:
prob = torch.sigmoid(prob)
- prob = prob.cpu().detach().numpy()
+ prob = prob.detach().cpu().numpy()
if "i" in self.classes:
mask = np.argmax(prob, axis=-1)
diff --git a/test/test_utils/test_download.py b/test/test_utils/test_download.py
index 56e6ceed..bfda06f6 100644
--- a/test/test_utils/test_download.py
+++ b/test/test_utils/test_download.py
@@ -1,18 +1,14 @@
""" """
import shutil
+import subprocess
+import types
import urllib.parse
from pathlib import Path
import pytest
-from torch_ecg.utils.download import (
- _download_from_aws_s3_using_boto3,
- _download_from_google_drive,
- http_get,
- is_compressed_file,
- url_is_reachable,
-)
+import torch_ecg.utils.download as dl
_TMP_DIR = Path(__file__).resolve().parents[2] / "tmp" / "test_download"
_TMP_DIR.mkdir(parents=True, exist_ok=True)
@@ -22,11 +18,11 @@ def test_http_get():
# normally, direct downloading from dropbox with `dl=0` will not download the file
# http_get internally replaces `dl=0` with `dl=1` to force download
url = "https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=0"
- http_get(url, _TMP_DIR / "action-test-zip-extract", extract=True, filename="test.zip")
+ dl.http_get(url, _TMP_DIR / "action-test-zip-extract", extract=True, filename="test.zip")
shutil.rmtree(_TMP_DIR / "action-test-zip-extract")
- http_get(url, _TMP_DIR / "action-test-zip-extract", extract="auto", filename="test.zip")
+ dl.http_get(url, _TMP_DIR / "action-test-zip-extract", extract="auto", filename="test.zip")
shutil.rmtree(_TMP_DIR / "action-test-zip-extract")
- http_get(url, _TMP_DIR / "action-test-zip-extract", extract="auto")
+ dl.http_get(url, _TMP_DIR / "action-test-zip-extract", extract="auto")
shutil.rmtree(_TMP_DIR / "action-test-zip-extract")
url = (
@@ -41,13 +37,13 @@ def test_http_get():
"Automatic decompression is turned off\\."
),
):
- http_get(url, _TMP_DIR, extract=True, filename="test.txt")
+ dl.http_get(url, _TMP_DIR, extract=True, filename="test.txt")
with pytest.raises(FileExistsError, match="file already exists"):
- http_get(url, _TMP_DIR, extract=True, filename="test.txt")
+ dl.http_get(url, _TMP_DIR, extract=True, filename="test.txt")
(_TMP_DIR / "test.txt").unlink()
- http_get(url, _TMP_DIR, extract="auto", filename="test.txt")
+ dl.http_get(url, _TMP_DIR, extract="auto", filename="test.txt")
(_TMP_DIR / "test.txt").unlink()
- http_get(url, _TMP_DIR, extract="auto")
+ dl.http_get(url, _TMP_DIR, extract="auto")
with pytest.warns(
RuntimeWarning,
@@ -57,7 +53,7 @@ def test_http_get():
"The user is responsible for decompressing the file manually\\."
),
):
- http_get(url, _TMP_DIR, extract=True)
+ dl.http_get(url, _TMP_DIR, extract=True)
Path(_TMP_DIR / Path(url).name).unlink()
# test downloading from Google Drive
@@ -66,45 +62,246 @@ def test_http_get():
url_no_scheme = f"drive.google.com/file/d/{file_id}/view?usp=sharing"
url_xxx_schme = f"xxx://drive.google.com/file/d/{file_id}/view?usp=sharing"
with pytest.raises(AssertionError, match="filename can not be inferred from Google Drive URL"):
- http_get(url_no_scheme, _TMP_DIR)
+ dl.http_get(url_no_scheme, _TMP_DIR)
with pytest.raises(ValueError, match="Unsupported URL scheme"):
- http_get(url_xxx_schme, _TMP_DIR, extract=False, filename="torch-ecg-paper.bib")
- http_get(url, _TMP_DIR, filename="torch-ecg-paper.bib", extract=False)
+ dl.http_get(url_xxx_schme, _TMP_DIR, extract=False, filename="torch-ecg-paper.bib")
+ dl.http_get(url, _TMP_DIR, filename="torch-ecg-paper.bib", extract=False)
(_TMP_DIR / "torch-ecg-paper.bib").unlink()
- _download_from_google_drive(file_id, _TMP_DIR / "torch-ecg-paper.bib")
+ dl._download_from_google_drive(file_id, _TMP_DIR / "torch-ecg-paper.bib")
(_TMP_DIR / "torch-ecg-paper.bib").unlink()
- _download_from_google_drive(url_no_scheme, _TMP_DIR / "torch-ecg-paper.bib")
+ dl._download_from_google_drive(url_no_scheme, _TMP_DIR / "torch-ecg-paper.bib")
(_TMP_DIR / "torch-ecg-paper.bib").unlink()
# test downloading from AWS S3 (by default using AWS CLI)
(_TMP_DIR / "ludb").mkdir(exist_ok=True)
- http_get("s3://physionet-open/ludb/1.0.1/", _TMP_DIR / "ludb")
+ dl.http_get("s3://physionet-open/ludb/1.0.1/", _TMP_DIR / "ludb")
# test downloading from AWS S3 (by using boto3)
shutil.rmtree(_TMP_DIR / "ludb")
(_TMP_DIR / "ludb").mkdir(exist_ok=True)
- _download_from_aws_s3_using_boto3("s3://physionet-open/ludb/1.0.1/", _TMP_DIR / "ludb")
+ dl._download_from_aws_s3_using_boto3("s3://physionet-open/ludb/1.0.1/", _TMP_DIR / "ludb")
+
+ with pytest.raises(ValueError, match="Invalid S3 URL"):
+ dl.http_get("s3://xxx", _TMP_DIR / "ludb")
+
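+    # _stem is expected to accept a bytes URL and strip compound suffixes ("file.tar.gz" -> "file")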
+ assert dl._stem(b"https://example.com/path/to/file.tar.gz") == "file"
def test_url_is_reachable():
- assert url_is_reachable("https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=1")
- assert not url_is_reachable("https://www.some-unknown-domain.com/unknown-path/unknown-file.zip")
+ assert dl.url_is_reachable("https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=1")
+ assert not dl.url_is_reachable("https://www.some-unknown-domain.com/unknown-path/unknown-file.zip")
def test_is_compressed_file():
# check local files
- assert not is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.txt")
- assert not is_compressed_file(_TMP_DIR / "action-test-zip-extract")
- assert not is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test")
- assert not is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.pth.tar")
- assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.gz")
- assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tgz")
- assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.bz2")
- assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tbz2")
- assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.xz")
- assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.txz")
- assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.zip")
- assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.7z")
+ assert not dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.txt")
+ assert not dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract")
+ assert not dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test")
+ assert not dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.pth.tar")
+ assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.gz")
+ assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tgz")
+ assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.bz2")
+ assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tbz2")
+ assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.xz")
+ assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.txz")
+ assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.zip")
+ assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.7z")
# check remote files (by URL)
- assert is_compressed_file(urllib.parse.urlparse("https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=0").path)
+ assert dl.is_compressed_file(urllib.parse.urlparse("https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=0").path)
+
+
+class FakeResponseOK:
+ """A fake streaming HTTP response that will raise during iter_content."""
+
+ def __init__(self, raise_in_iter=False, raise_text=False):
+ self.status_code = 200
+ self.headers = {"Content-Length": "10"}
+ self._raise_in_iter = raise_in_iter
+ self._raise_text = raise_text
+ self._text_value = "FAKE_CONTENT_ABCDEFG" * 5
+
+ def iter_content(self, chunk_size=1024):
+ if self._raise_in_iter:
+ # simulate mid-download exception
+ raise RuntimeError("iter boom")
+ yield b"abc"
+ yield b"def"
+
+ @property
+ def text(self):
+ if self._raise_text:
+ raise ValueError("text boom")
+ return self._text_value
+
+
+class FakeSession:
+ def __init__(self, response: FakeResponseOK):
+ self._resp = response
+
+ def get(self, *a, **kw):
+ return self._resp
+
+
+@pytest.mark.parametrize("raise_text", [False, True], ids=["text_ok", "text_raises"])
+def test_http_get_iter_exception_triggers_runtime(monkeypatch, tmp_path, raise_text):
+ """Cover the outer except and the inner try/except that reads req.text."""
+ resp = FakeResponseOK(raise_in_iter=True, raise_text=raise_text)
+
+ def fake_retry_session():
+ return FakeSession(resp)
+
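+    # http_get is assumed to obtain its HTTP session via dl._requests_retry_session, so patching it injects the fake response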
+ monkeypatch.setattr(dl, "_requests_retry_session", lambda: types.SimpleNamespace(get=fake_retry_session().get))
+
+ target_dir = tmp_path / "download"
+ with pytest.raises(RuntimeError) as exc:
+ dl.http_get("https://example.com/file.dat", target_dir, extract=False, filename="f.bin")
+
+ msg = str(exc.value)
+ if raise_text:
+ # inner text extraction failed => snippet empty
+ assert "body[:300]=''" in msg or "body[:300]=''"[:10]
+ else:
+ assert "FAKE_CONTENT" in msg
+
+
+def test_http_get_status_403(monkeypatch, tmp_path):
+ """Force status 403 path (your code raises generic Exception then caught by outer except)."""
+
+ class Resp403:
+ status_code = 403
+ headers = {}
+ text = "Forbidden"
+
+ def iter_content(self, chunk_size=1024):
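+            # empty body; http_get is expected to fail on the 403 status before consuming any content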
+ yield b""
+
+ def fake_session():
+ return types.SimpleNamespace(get=lambda *a, **k: Resp403())
+
+ monkeypatch.setattr(dl, "_requests_retry_session", lambda: fake_session())
+
+ with pytest.raises(RuntimeError) as exc:
+ dl.http_get("https://example.com/forbidden.bin", tmp_path, extract=False, filename="f.bin")
+ assert "Failed to download" in str(exc.value)
+ assert "status=403" in str(exc.value)
+ assert "Forbidden" in str(exc.value)
+
+
+def test_download_from_aws_s3_using_boto3_empty_bucket(monkeypatch, tmp_path):
+ """Cover object_count == 0 -> ValueError."""
+
+ class FakePaginator:
+ def paginate(self, **kwargs):
+ # return an iterator with no 'Contents'
+ return iter([{"Other": 1}, {"Meta": 2}])
+
+ class FakeBoto3Client:
+ def get_paginator(self, name):
+ assert name == "list_objects_v2"
+ return FakePaginator()
+
+ def close(self):
+ pass
+
+ monkeypatch.setattr(dl, "boto3", types.SimpleNamespace(client=lambda *a, **k: FakeBoto3Client()))
+
+ with pytest.raises(ValueError, match="No objects found"):
+ dl._download_from_aws_s3_using_boto3("s3://fake-bucket/prefix/", tmp_path / "out")
+
+
+def test_download_from_aws_s3_using_awscli_subprocess_fail(monkeypatch, tmp_path):
+ """Force awscli sync subprocess to exit with non-zero code -> CalledProcessError."""
+ # Pretend aws exists
+ monkeypatch.setattr(dl.shutil, "which", lambda name: "/usr/bin/aws")
+ # Control object count
+ monkeypatch.setattr(dl, "count_aws_s3_bucket", lambda bucket, prefix: 3)
+
+ class FakeStdout:
+ def __init__(self, lines):
+ self._lines = lines # list[str]
+ self._idx = 0
+ self.closed = False
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self._idx < len(self._lines):
+ line = self._lines[self._idx]
+ self._idx += 1
+ return line
+ raise StopIteration
+
+ def close(self):
+ self.closed = True
+
+ class FakePopen:
+ def __init__(self, *a, **k):
+ self.stdout = FakeStdout(
+ [
+ "download: s3://bucket/file1 to file1\n",
+ "download: s3://bucket/file2 to file2\n",
+ "some other line\n",
+ ]
+ )
+ self._returncode = None
+ self._waited = False
+ self.args = a
+ self.kwargs = k
+
+ def poll(self):
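+            # report "still running" (None) until all output lines are consumed, then exit with a non-zero code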
+ if self.stdout._idx < len(self.stdout._lines):
+ return None
+ if self._returncode is None:
+ self._returncode = 2
+ return self._returncode
+
+ def wait(self):
+ self.stdout._idx = len(self.stdout._lines)
+ if self._returncode is None:
+ self._returncode = 2
+ self._waited = True
+ return self._returncode
+
+ def communicate(self):
+ combined = "".join(self.stdout._lines)
+ return (combined, "boom\n")
+
+ monkeypatch.setattr(dl.subprocess, "Popen", lambda *a, **k: FakePopen(*a, **k))
+
+ with pytest.raises(subprocess.CalledProcessError) as exc:
+ dl._download_from_aws_s3_using_awscli("s3://bucket/prefix/", tmp_path / "out", show_progress=False)
+
+ assert "download: s3://bucket/file1" in exc.value.output
+ assert exc.value.returncode != 0
+
+
+def test_url_is_reachable_exception(monkeypatch):
+ def boom(*a, **k):
+ raise RuntimeError("network down")
+
+ monkeypatch.setattr(dl.requests, "head", boom)
+ assert dl.url_is_reachable("https://whatever") is False
+
+
+def test_download_from_aws_awscli_missing(monkeypatch, tmp_path):
+ monkeypatch.setattr(dl.shutil, "which", lambda name: None)
+
+ with pytest.raises(RuntimeError, match="AWS cli is required to download from S3"):
+ dl._download_from_aws_s3_using_awscli("s3://bucket/prefix/", tmp_path)
+
+
+def test_download_from_aws_awscli_present_fast_path(monkeypatch, tmp_path):
+ monkeypatch.setattr(dl.shutil, "which", lambda name: "/usr/bin/aws")
+ monkeypatch.setattr(dl, "count_aws_s3_bucket", lambda bucket, prefix: 0)
+ dl._download_from_aws_s3_using_awscli("s3://bucket/prefix/", tmp_path, show_progress=False)
+
+
+def test_awscli_non_ci_branch(monkeypatch, tmp_path):
+ monkeypatch.delenv("CI", raising=False)
+
+ monkeypatch.setattr(dl.shutil, "which", lambda name: "/usr/bin/aws")
+ monkeypatch.setattr(dl, "count_aws_s3_bucket", lambda bucket, prefix: 0)
+
+ dl._download_from_aws_s3_using_awscli("s3://bucket/prefix/", tmp_path, show_progress=False)
diff --git a/test/test_utils/test_misc.py b/test/test_utils/test_misc.py
index f1b9267a..6485407a 100644
--- a/test/test_utils/test_misc.py
+++ b/test/test_utils/test_misc.py
@@ -1,6 +1,7 @@
""" """
import datetime
+import json
import textwrap
import time
from itertools import product
@@ -18,6 +19,7 @@
MovingAverage,
ReprMixin,
Timer,
+ _is_pathlike_string,
add_docstring,
add_kwargs,
dict_to_str,
@@ -128,7 +130,7 @@ def test_get_record_list_recursive3():
record_list = get_record_list_recursive3(path, rec_patterns_with_ext, with_suffix=True)
for tranche in list("EFG"):
# assert the records come with file extension
- assert all([p.endswith(".mat") for p in record_list[tranche]]), record_list[tranche]
+ assert all([p.endswith(".mat") for p in record_list[tranche]]), record_list[tranche] # type: ignore
def test_dict_to_str():
@@ -140,8 +142,8 @@ def test_dict_to_str():
def test_str2bool():
assert str2bool(True) is True
assert str2bool(False) is False
- assert str2bool("True") is True
- assert str2bool("False") is False
+ assert str2bool("True ") is True
+ assert str2bool(" False") is False
assert str2bool("true") is True
assert str2bool("false") is False
assert str2bool("1") is True
@@ -150,10 +152,16 @@ def test_str2bool():
assert str2bool("no") is False
assert str2bool("y") is True
assert str2bool("n") is False
+ assert str2bool(None) is False
+ assert str2bool(None, default=True) is True
+ assert str2bool("", strict=False) is False
with pytest.raises(ValueError, match="Boolean value expected"):
str2bool("abc")
with pytest.raises(ValueError, match="Boolean value expected"):
str2bool("2")
+ with pytest.raises(TypeError, match="Expected str|bool|None"):
+ str2bool(1) # type: ignore
+ assert str2bool(1, strict=False) is False # type: ignore
def test_diff_with_step():
@@ -192,7 +200,7 @@ def test_plot_single_lead():
n_samples = 5000
plot_single_lead(
t=np.arange(n_samples) / fs,
- sig=500 * DEFAULTS.RNG.normal(size=(n_samples,)),
+ sig=500 * DEFAULTS.RNG.normal(size=(n_samples,)), # type: ignore
ticks_granularity=2,
)
@@ -225,6 +233,7 @@ def test_read_log_txt():
dst_dir=str(_TMP_DIR),
extract=True,
filename="log.txt",
+ verify_length=False,
)
log_txt_file = str(_TMP_DIR / "log.txt")
log_txt = read_log_txt(log_txt_file)
@@ -260,15 +269,15 @@ def test_dicts_equal():
d2 = {"a": pd.DataFrame([{"hehe": 2, "haha": 2}])[["hehe", "haha"]]}
assert dicts_equal(d1, d2) is False
assert dicts_equal(d2, d1) is False
- d1["a"] = d1["a"]["hehe"]
- d2["a"] = d2["a"]["haha"]
+ d1["a"] = d1["a"]["hehe"] # type: ignore
+ d2["a"] = d2["a"]["haha"] # type: ignore
assert dicts_equal(d1, d2) is False
assert dicts_equal(d2, d1) is False
d1 = {"a": pd.DataFrame([{"hehe": 1, "haha": 2}])[["haha", "hehe"]]}
d2 = {"a": pd.DataFrame([{"hehe": 2, "haha": 2}])[["hehe", "haha"]]}
- d1["a"] = d1["a"]["hehe"]
- d2["a"] = d2["a"]["hehe"]
+ d1["a"] = d1["a"]["hehe"] # type: ignore
+ d2["a"] = d2["a"]["hehe"] # type: ignore
assert dicts_equal(d1, d2) is False
assert dicts_equal(d2, d1) is False
@@ -361,7 +370,7 @@ def test_CitationMixin():
def test_MovingAverage():
ma = MovingAverage(verbose=2)
- data = DEFAULTS.RNG.normal(size=(100,))
+ data = DEFAULTS.RNG.normal(size=(100,)) # type: ignore
new_data = ma(data, method="sma", window=7, center=True)
assert new_data.shape == data.shape
new_data = ma(data, method="ema", weight=0.7)
@@ -435,7 +444,7 @@ def func(a, b):
with pytest.raises(ValueError, match="mode `.+` is not supported"):
- @add_docstring("This is a new docstring.", mode="xxx")
+ @add_docstring("This is a new docstring.", mode="xxx") # type: ignore
def func(a, b):
"""This is a docstring."""
return a + b
@@ -443,7 +452,7 @@ def func(a, b):
def test_remove_parameters_returns_from_docstring():
new_docstring = remove_parameters_returns_from_docstring(
- remove_parameters_returns_from_docstring.__doc__,
+ remove_parameters_returns_from_docstring.__doc__, # type: ignore
parameters=["returns_indicator", "parameters_indicator"],
returns="str",
)
@@ -602,7 +611,17 @@ def test_make_serializable():
x = (np.array([1, 2, 3]), np.array([4, 5, 6]).mean())
obj = make_serializable(x)
assert obj == [[1, 2, 3], 5.0]
- assert isinstance(obj[1], float) and isinstance(x[1], np.float64)
+ assert isinstance(obj[1], float) and isinstance(x[1], np.float64) # type: ignore
+
+ obj = make_serializable(DEFAULTS, drop_unserializable=False, drop_paths=False)
+ assert isinstance(obj, dict) and set(["RNG", "DTYPE", "log_dir"]).issubset(set(obj))
+ json.dumps(obj) # should raise no error
+ obj = make_serializable(DEFAULTS, drop_unserializable=False, drop_paths=True)
+ assert isinstance(obj, dict) and set(["RNG", "DTYPE"]).issubset(set(obj)) and "log_dir" not in obj
+ json.dumps(obj) # should raise no error
+ obj = make_serializable(DEFAULTS, drop_unserializable=True, drop_paths=True)
+ assert isinstance(obj, dict) and set(["RNG", "DTYPE", "log_dir"]).intersection(set(obj)) == set()
+ json.dumps(obj) # should raise no error
def test_select_k():
@@ -674,3 +693,20 @@ def test_np_topk():
with pytest.raises(AssertionError, match="dim out of bounds"):
np_topk(arr1d, k=1, dim=1)
+
+
+def test_is_pathlike_string():
+ assert _is_pathlike_string("abc") is False
+ assert _is_pathlike_string("abc.txt") is True
+ assert _is_pathlike_string("/home/abc") is True
+ assert _is_pathlike_string("C:\\abc") is True
+ assert _is_pathlike_string("C:/abc") is True
+ assert _is_pathlike_string("./abc") is True
+ assert _is_pathlike_string("") is False
+ assert _is_pathlike_string("~/abc") is True
+ assert _is_pathlike_string("A:project") is False
+ assert _is_pathlike_string("README") is False
+ assert _is_pathlike_string("my.folder") is True
+ assert _is_pathlike_string(".") is True
+ assert _is_pathlike_string(["abc", "def"]) is False # type: ignore
+ assert _is_pathlike_string(123) is False # type: ignore
diff --git a/test/test_utils/test_utils_data.py b/test/test_utils/test_utils_data.py
index 205800c5..1349b69a 100644
--- a/test/test_utils/test_utils_data.py
+++ b/test/test_utils/test_utils_data.py
@@ -51,6 +51,9 @@ def test_get_mask():
assert intervals == mask_to_intervals(mask[idx], 1)
assert (get_mask(5000, np.arange(250, 5000 - 250, 400), 50, 50) == mask[0]).all()
+ with pytest.raises(ValueError, match="Unknown return_fmt. Expected 'mask' or 'intervals', but got"):
+ get_mask((12, 5000), np.arange(250, 5000 - 250, 400), 50, 50, return_fmt="xxx") # type: ignore
+
def test_mask_to_intervals():
mask = np.zeros(100, dtype=int)
@@ -126,7 +129,7 @@ def test_rdheader():
with pytest.raises(FileNotFoundError, match="file `not_exist_file\\.hea` not found"):
rdheader("not_exist_file")
with pytest.raises(TypeError, match="header_data must be str or sequence of str, but got"):
- rdheader(1)
+ rdheader(1) # type: ignore
def test_ensure_lead_fmt():
diff --git a/test/test_utils/test_utils_nn.py b/test/test_utils/test_utils_nn.py
index ef19e42d..d753f07a 100644
--- a/test/test_utils/test_utils_nn.py
+++ b/test/test_utils/test_utils_nn.py
@@ -107,9 +107,9 @@ def test_compute_output_shape():
num_filters=num_filters,
output_padding=0,
channel_last=channel_last,
- **conv_kw,
+ **conv_kw, # type: ignore
)
- conv_output_tensor = torch.nn.Conv1d(in_channels, num_filters, **conv_kw)(tensor_first)
+ conv_output_tensor = torch.nn.Conv1d(in_channels, num_filters, **conv_kw)(tensor_first) # type: ignore
if channel_last:
conv_output_tensor = conv_output_tensor.permute(0, 2, 1)
@@ -130,15 +130,15 @@ def test_compute_output_shape():
num_filters=num_filters,
output_padding=0,
channel_last=channel_last,
- **deconv_kw,
+ **deconv_kw, # type: ignore
)
- deconv_output_tensor = torch.nn.ConvTranspose1d(in_channels, num_filters, **deconv_kw)(tensor_first)
+ deconv_output_tensor = torch.nn.ConvTranspose1d(in_channels, num_filters, **deconv_kw)(tensor_first) # type: ignore
if channel_last:
deconv_output_tensor = deconv_output_tensor.permute(0, 2, 1)
assert deconv_output_shape == deconv_output_tensor.shape
compute_deconv_output_shape(
- input_shape=tensor.shape, num_filters=num_filters, output_padding=0, channel_last=channel_last, **deconv_kw
+ input_shape=tensor.shape, num_filters=num_filters, output_padding=0, channel_last=channel_last, **deconv_kw # type: ignore
)
# maxpool
@@ -155,9 +155,9 @@ def test_compute_output_shape():
)
for tensor, channel_last in zip([tensor_first, tensor_last], [False, True]):
maxpool_output_shape = compute_output_shape(
- "maxpool", input_shape=tensor.shape, num_filters=1, output_padding=0, channel_last=channel_last, **maxpool_kw
+ "maxpool", input_shape=tensor.shape, num_filters=1, output_padding=0, channel_last=channel_last, **maxpool_kw # type: ignore
)
- maxpool_output_tensor = torch.nn.MaxPool1d(**maxpool_kw)(tensor_first)
+ maxpool_output_tensor = torch.nn.MaxPool1d(**maxpool_kw)(tensor_first) # type: ignore
if channel_last:
maxpool_output_tensor = maxpool_output_tensor.permute(0, 2, 1)
@@ -176,9 +176,9 @@ def test_compute_output_shape():
)
for tensor, channel_last in zip([tensor_first, tensor_last], [False, True]):
avgpool_output_shape = compute_output_shape(
- "avgpool", input_shape=tensor.shape, num_filters=1, output_padding=0, channel_last=channel_last, **avgpool_kw
+ "avgpool", input_shape=tensor.shape, num_filters=1, output_padding=0, channel_last=channel_last, **avgpool_kw # type: ignore
)
- avgpool_output_tensor = torch.nn.AvgPool1d(**avgpool_kw)(tensor_first)
+ avgpool_output_tensor = torch.nn.AvgPool1d(**avgpool_kw)(tensor_first) # type: ignore
if channel_last:
avgpool_output_tensor = avgpool_output_tensor.permute(0, 2, 1)
@@ -186,7 +186,7 @@ def test_compute_output_shape():
shape_1 = compute_output_shape("conv", [None, None, 224, 224], padding=[4, 8])
shape_2 = compute_output_shape("conv", [None, None, 224, 224], padding=[4, 8], asymmetric_padding=[1, 3])
- assert shape_2[2:] == (shape_1[2] + 1 + 3, shape_1[3] + 1 + 3)
+ assert shape_2[2:] == (shape_1[2] + 1 + 3, shape_1[3] + 1 + 3) # type: ignore
shape_1 = compute_output_shape("conv", [None, None, 224, 224], padding=[4, 8])
shape_2 = compute_output_shape(
"conv",
@@ -194,7 +194,7 @@ def test_compute_output_shape():
padding=[4, 8],
asymmetric_padding=[[1, 3], [0, 2]],
)
- assert shape_2[2:] == (shape_1[2] + 1 + 3, shape_1[3] + 0 + 2)
+ assert shape_2[2:] == (shape_1[2] + 1 + 3, shape_1[3] + 0 + 2) # type: ignore
shape_1 = compute_output_shape(
"conv",
@@ -217,7 +217,7 @@ def test_compute_output_shape():
with pytest.raises(AssertionError, match="`num_filters` should be `None` or positive integer"):
compute_output_shape("conv", tensor_first.shape, num_filters=-12)
with pytest.raises(AssertionError, match="`kernel_size` should contain only positive integers"):
- compute_output_shape("conv", tensor_first.shape, kernel_size=2.5)
+ compute_output_shape("conv", tensor_first.shape, kernel_size=2.5) # type: ignore
with pytest.raises(AssertionError, match="`kernel_size` should contain only positive integers"):
compute_output_shape("conv", tensor_first.shape, kernel_size=0)
with pytest.raises(AssertionError, match="`stride` should contain only positive integers"):
@@ -249,7 +249,7 @@ def test_compute_output_shape():
with pytest.raises(ValueError, match="input has 1 dimensions, while `padding` has 2 dimensions,"):
compute_output_shape("conv", tensor_first.shape, padding=[1, 1])
with pytest.raises(AssertionError, match="Invalid `asymmetric_padding`"):
- compute_output_shape("conv", tensor_first.shape, asymmetric_padding=2)
+ compute_output_shape("conv", tensor_first.shape, asymmetric_padding=2) # type: ignore
with pytest.raises(AssertionError, match="Invalid `asymmetric_padding`"):
compute_output_shape(
"conv",
@@ -314,32 +314,32 @@ def test_default_collate_fn():
batch_data = [
(
- DEFAULTS.RNG.uniform(size=shape_1),
- DEFAULTS.RNG.uniform(size=shape_2),
- DEFAULTS.RNG.uniform(size=shape_3),
+ DEFAULTS.RNG.uniform(size=shape_1), # type: ignore
+ DEFAULTS.RNG.uniform(size=shape_2), # type: ignore
+ DEFAULTS.RNG.uniform(size=shape_3), # type: ignore
)
for _ in range(batch_size)
]
tensor_1, tensor_2, tensor_3 = default_collate_fn(batch_data)
- assert tensor_1.shape == (batch_size, *shape_1)
- assert tensor_2.shape == (batch_size, *shape_2)
- assert tensor_3.shape == (batch_size, *shape_3)
+ assert tensor_1.shape == (batch_size, *shape_1) # type: ignore
+ assert tensor_2.shape == (batch_size, *shape_2) # type: ignore
+ assert tensor_3.shape == (batch_size, *shape_3) # type: ignore
batch_data = [
dict(
- tensor_1=DEFAULTS.RNG.uniform(size=shape_1),
- tensor_2=DEFAULTS.RNG.uniform(size=shape_2),
- tensor_3=DEFAULTS.RNG.uniform(size=shape_3),
+ tensor_1=DEFAULTS.RNG.uniform(size=shape_1), # type: ignore
+ tensor_2=DEFAULTS.RNG.uniform(size=shape_2), # type: ignore
+ tensor_3=DEFAULTS.RNG.uniform(size=shape_3), # type: ignore
)
for _ in range(batch_size)
]
tensors = default_collate_fn(batch_data)
- assert tensors["tensor_1"].shape == (batch_size, *shape_1)
- assert tensors["tensor_2"].shape == (batch_size, *shape_2)
- assert tensors["tensor_3"].shape == (batch_size, *shape_3)
+ assert tensors["tensor_1"].shape == (batch_size, *shape_1) # type: ignore
+ assert tensors["tensor_2"].shape == (batch_size, *shape_2) # type: ignore
+ assert tensors["tensor_3"].shape == (batch_size, *shape_3) # type: ignore
with pytest.raises(ValueError, match="Invalid batch"):
- default_collate_fn([1])
+ default_collate_fn([1]) # type: ignore
with pytest.raises(ValueError, match="No data"):
default_collate_fn([tuple()])
@@ -516,17 +516,40 @@ def test_mixin_classes():
assert isinstance(model_1d.dtype_, str)
assert isinstance(model_1d.device_, str)
+ # test pth/pt file
save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_mixin.pth"
+ # convert save_path to bytes to cover bytes path handling code
+ save_path = str(save_path).encode()
+ model_1d.save(save_path, CFG(dict(n_leads=12)), extra_items={"xxx": {"ones": torch.ones((2, 2))}}, use_safetensors=False)
+    assert Path(save_path.decode()).is_file()
+ loaded_model, _ = Model1D.from_checkpoint(save_path)
+ assert repr(model_1d) == repr(loaded_model)
+ save_path = Path(save_path.decode())
+ save_path.unlink()
- model_1d.save(save_path, dict(n_leads=12))
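+    # saving to a .pth path without use_safetensors=False is expected to warn and write a .safetensors file instead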
+ with pytest.warns(RuntimeWarning, match="`safetensors` is used by default."):
+ model_1d.save(save_path, CFG(dict(n_leads=12)), extra_items={"xxx": {"ones": torch.ones((2, 2))}})
+ assert save_path.with_suffix(".safetensors").is_file()
+ save_path.with_suffix(".safetensors").unlink()
+ # test single safetensors file
+ save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_mixin.safetensors"
+ model_1d.save(save_path, CFG(dict(n_leads=12)), extra_items={"xxx": {"ones": torch.ones((2, 2))}})
assert save_path.is_file()
-
loaded_model, _ = Model1D.from_checkpoint(save_path)
assert repr(model_1d) == repr(loaded_model)
-
save_path.unlink()
+ # test directory with safetensors
+ save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_mixin.safetensors"
+ model_1d.save(
+ save_path, CFG(dict(n_leads=12)), extra_items={"xxx": {"ones": torch.ones((2, 2))}}, safetensors_single_file=False
+ )
+ assert save_path.with_suffix("").is_dir()
+ loaded_model, _ = Model1D.from_checkpoint(save_path.with_suffix(""))
+ assert repr(model_1d) == repr(loaded_model)
+ shutil.rmtree(save_path.with_suffix(""))
+
# test remote un-compressed model
save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_remote_model"
save_path.mkdir(exist_ok=True, parents=True)
diff --git a/torch_ecg/_preprocessors/bandpass.py b/torch_ecg/_preprocessors/bandpass.py
index a371386a..5e81a796 100644
--- a/torch_ecg/_preprocessors/bandpass.py
+++ b/torch_ecg/_preprocessors/bandpass.py
@@ -1,9 +1,8 @@
"""BandPass filter preprocessor."""
-from numbers import Real
-from typing import Any, List, Literal, Optional, Tuple
+from typing import Any, List, Literal, Optional, Tuple, Union
-import numpy as np
+from numpy.typing import NDArray
from .base import PreProcessor, preprocess_multi_lead_signal
@@ -17,9 +16,9 @@ class BandPass(PreProcessor):
Parameters
----------
- lowcut : numbers.Real, optional
+ lowcut : int or float, optional
Low cutoff frequency
- highcut : numbers.Real, optional
+ highcut : int or float, optional
High cutoff frequency.
filter_type : {"butter", "fir"}, , default "butter"
Type of the bandpass filter.
@@ -43,8 +42,8 @@ class BandPass(PreProcessor):
def __init__(
self,
- lowcut: Optional[Real] = 0.5,
- highcut: Optional[Real] = 45,
+ lowcut: Optional[Union[int, float]] = 0.5,
+ highcut: Optional[Union[int, float]] = 45,
filter_type: Literal["butter", "fir"] = "butter",
filter_order: Optional[int] = None,
**kwargs: Any,
@@ -59,17 +58,17 @@ def __init__(
self.filter_type = filter_type
self.filter_order = filter_order
- def apply(self, sig: np.ndarray, fs: int) -> Tuple[np.ndarray, int]:
+ def apply(self, sig: NDArray, fs: Union[int, float]) -> Tuple[NDArray, Union[int, float]]:
"""Apply the preprocessor to `sig`.
Parameters
----------
sig : numpy.ndarray
The ECG signal, can be
- - 1d array, which is a single-lead ECG;
- - 2d array, which is a multi-lead ECG of "lead_first" format;
- - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
- fs : int
+ - 1d array, which is a single-lead ECG;
+ - 2d array, which is a multi-lead ECG of "lead_first" format;
+ - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
+ fs : int or float
Sampling frequency of the ECG signal.
Returns
@@ -84,8 +83,8 @@ def apply(self, sig: np.ndarray, fs: int) -> Tuple[np.ndarray, int]:
filtered_sig = preprocess_multi_lead_signal(
raw_sig=sig,
fs=fs,
- band_fs=[self.lowcut, self.highcut],
- filter_type=self.filter_type,
+ band_fs=[self.lowcut, self.highcut], # type: ignore
+ filter_type=self.filter_type, # type: ignore
filter_order=self.filter_order,
)
return filtered_sig, fs
diff --git a/torch_ecg/_preprocessors/base.py b/torch_ecg/_preprocessors/base.py
index ca74241d..ae2f0999 100644
--- a/torch_ecg/_preprocessors/base.py
+++ b/torch_ecg/_preprocessors/base.py
@@ -2,11 +2,11 @@
from abc import ABC, abstractmethod
from itertools import repeat
-from numbers import Real
-from typing import List, Literal, Optional, Tuple
+from typing import List, Literal, Optional, Tuple, Union
import numpy as np
from biosppy.signals.tools import filter_signal
+from numpy.typing import NDArray
from scipy.ndimage import median_filter
from ..cfg import DEFAULTS
@@ -30,37 +30,37 @@ class PreProcessor(ReprMixin, ABC):
__name__ = "PreProcessor"
@abstractmethod
- def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]:
+ def apply(self, sig: NDArray, fs: Union[int, float]) -> Tuple[NDArray, Union[int, float]]:
"""Apply the preprocessor to `sig`.
Parameters
----------
sig : numpy.ndarray
The ECG signal, can be
- - 1d array, which is a single-lead ECG;
- - 2d array, which is a multi-lead ECG of "lead_first" format;
- - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
- fs : numbers.Real
+ - 1d array, which is a single-lead ECG;
+ - 2d array, which is a multi-lead ECG of "lead_first" format;
+ - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
+ fs : int or float
Sampling frequency of the ECG signal.
"""
raise NotImplementedError
- @add_docstring(apply)
- def __call__(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]:
+ @add_docstring(apply.__doc__) # type: ignore
+ def __call__(self, sig: NDArray, fs: Union[int, float]) -> Tuple[NDArray, Union[int, float]]:
"""alias of :meth:`self.apply`."""
return self.apply(sig, fs)
- def _check_sig(self, sig: np.ndarray) -> None:
+ def _check_sig(self, sig: NDArray) -> None:
"""Check validity of the signal.
Parameters
----------
sig : numpy.ndarray
The ECG signal, can be
- - 1d array, which is a single-lead ECG;
- - 2d array, which is a multi-lead ECG of "lead_first" format;
- - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
+ - 1d array, which is a single-lead ECG;
+ - 2d array, which is a multi-lead ECG of "lead_first" format;
+ - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
"""
if sig.ndim not in [1, 2, 3]:
@@ -73,14 +73,14 @@ def _check_sig(self, sig: np.ndarray) -> None:
def preprocess_multi_lead_signal(
- raw_sig: np.ndarray,
- fs: Real,
+ raw_sig: NDArray,
+ fs: Union[int, float],
sig_fmt: Literal["channel_first", "lead_first", "channel_last", "lead_last"] = "channel_first",
- bl_win: Optional[List[Real]] = None,
- band_fs: Optional[List[Real]] = None,
+ bl_win: Optional[List[Union[int, float]]] = None,
+ band_fs: Optional[List[Union[int, float]]] = None,
filter_type: Literal["butter", "fir"] = "butter",
filter_order: Optional[int] = None,
-) -> np.ndarray:
+) -> NDArray:
"""Perform preprocessing for multi-lead ECG signal (with units in mV).
preprocessing may include median filter, bandpass filter, and rpeaks detection, etc.
@@ -90,19 +90,19 @@ def preprocess_multi_lead_signal(
----------
raw_sig : numpy.ndarray
The raw ECG signal, with units in mV.
- fs : numbers.Real
+ fs : int or float
Sampling frequency of `raw_sig`.
sig_fmt : str, default "channel_first"
Format of the multi-lead ECG signal,
"channel_last" (alias "lead_last"), or
"channel_first" (alias "lead_first").
- bl_win : List[numbers.Real], optional
+ bl_win : List[Union[int, float]], optional
Window (units in second) of baseline removal
using :meth:`~scipy.ndimage.median_filter`,
the first is the shorter one, the second the longer one,
a typical pair is ``[0.2, 0.6]``.
If is None or empty, baseline removal will not be performed.
- band_fs : List[numbers.Real], optional
+ band_fs : List[Union[int, float]], optional
Frequency band of the bandpass filter,
a typical pair is ``[0.5, 45]``.
Be careful when detecting paced rhythm.
@@ -130,9 +130,9 @@ def preprocess_multi_lead_signal(
], f"multi-lead signal format `{sig_fmt}` not supported"
if sig_fmt.lower() in ["channel_last", "lead_last"]:
# might have a batch dimension at the first axis
- filtered_ecg = np.moveaxis(raw_sig, -2, -1).astype(DEFAULTS.np_dtype)
+ filtered_ecg = np.moveaxis(raw_sig, -2, -1).astype(DEFAULTS.np_dtype) # type: ignore
else:
- filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype)
+ filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype) # type: ignore
# remove baseline
if bl_win:
@@ -186,13 +186,13 @@ def preprocess_multi_lead_signal(
def preprocess_single_lead_signal(
- raw_sig: np.ndarray,
- fs: Real,
- bl_win: Optional[List[Real]] = None,
- band_fs: Optional[List[Real]] = None,
+ raw_sig: NDArray,
+ fs: Union[int, float],
+ bl_win: Optional[List[Union[int, float]]] = None,
+ band_fs: Optional[List[Union[int, float]]] = None,
filter_type: Literal["butter", "fir"] = "butter",
filter_order: Optional[int] = None,
-) -> np.ndarray:
+) -> NDArray:
"""Perform preprocessing for single lead ECG signal (with units in mV).
Preprocessing may include median filter, bandpass filter, and rpeaks detection, etc.
@@ -201,15 +201,15 @@ def preprocess_single_lead_signal(
----------
raw_sig : numpy.ndarray
Raw ECG signal, with units in mV.
- fs : numbers.Real
+ fs : int or float
Sampling frequency of `raw_sig`.
- bl_win : list (of 2 numbers.Real), optional
+ bl_win : list (of 2 int or float), optional
Window (units in second) of baseline removal
using :meth:`~scipy.ndimage.median_filter`,
the first is the shorter one, the second the longer one,
a typical pair is ``[0.2, 0.6]``.
If is None or empty, baseline removal will not be performed.
- band_fs : list of numbers.Real, optional
+ band_fs : list of int or float, optional
Frequency band of the bandpass filter,
a typical pair is ``[0.5, 45]``.
Be careful when detecting paced rhythm.
@@ -230,7 +230,7 @@ def preprocess_single_lead_signal(
e.g. :meth:`~torch_ecg.utils.butter_bandpass_filter`.
"""
- filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype)
+ filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype) # type: ignore
assert filtered_ecg.ndim == 1, "single-lead signal should be 1d array"
# remove baseline
diff --git a/torch_ecg/_preprocessors/baseline_remove.py b/torch_ecg/_preprocessors/baseline_remove.py
index 1a46ed47..129d2fbb 100644
--- a/torch_ecg/_preprocessors/baseline_remove.py
+++ b/torch_ecg/_preprocessors/baseline_remove.py
@@ -4,10 +4,9 @@
"""
import warnings
-from numbers import Real
-from typing import Any, List, Tuple
+from typing import Any, List, Tuple, Union
-import numpy as np
+from numpy.typing import NDArray
from .base import PreProcessor, preprocess_multi_lead_signal
@@ -46,24 +45,24 @@ def __init__(self, window1: float = 0.2, window2: float = 0.6, **kwargs: Any) ->
self.window1, self.window2 = self.window2, self.window1
warnings.warn("values of `window1` and `window2` are switched", RuntimeWarning)
- def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]:
+ def apply(self, sig: NDArray, fs: Union[float, int]) -> Tuple[NDArray, Union[float, int]]:
"""Apply the preprocessor to `sig`.
Parameters
----------
sig : numpy.ndarray
The ECG signal, can be
- - 1d array, which is a single-lead ECG;
- - 2d array, which is a multi-lead ECG of "lead_first" format;
- - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
- fs : numbers.Real
+ - 1d array, which is a single-lead ECG;
+ - 2d array, which is a multi-lead ECG of "lead_first" format;
+ - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
+ fs : float or int
Sampling frequency of the ECG signal.
Returns
-------
filtered_sig : :class:`numpy.ndarray`
The median filtered (hence baseline removed) ECG signal.
- fs : :class:`int`
+ fs : float or int
Sampling frequency of the filtered ECG signal.
"""
diff --git a/torch_ecg/_preprocessors/normalize.py b/torch_ecg/_preprocessors/normalize.py
index dfc8e6d7..bcbca694 100644
--- a/torch_ecg/_preprocessors/normalize.py
+++ b/torch_ecg/_preprocessors/normalize.py
@@ -3,7 +3,7 @@
from numbers import Real
from typing import Any, List, Literal, Tuple, Union
-import numpy as np
+from numpy.typing import NDArray
from ..cfg import DEFAULTS
from ..utils.utils_signal import normalize
@@ -39,11 +39,11 @@ class Normalize(PreProcessor):
----------
method : {"naive", "min-max", "z-score"}, default "z-score"
Normalization method, case insensitive.
- mean : numbers.Real or numpy.ndarray, default 0.0
+ mean : float or int or numpy.ndarray, default 0.0
Mean value of the normalized signal,
or mean values for each lead of the normalized signal.
Useless if `method` is "min-max".
- std : numbers.Real or numpy.ndarray, default 1.0
+ std : float or int or numpy.ndarray, default 1.0
Standard deviation of the normalized signal,
or standard deviations for each lead of the normalized signal.
Useless if `method` is "min-max".
@@ -66,8 +66,8 @@ class Normalize(PreProcessor):
def __init__(
self,
method: Literal["naive", "min-max", "z-score"] = "z-score",
- mean: Union[Real, np.ndarray] = 0.0,
- std: Union[Real, np.ndarray] = 1.0,
+ mean: Union[float, int, NDArray] = 0.0,
+ std: Union[float, int, NDArray] = 1.0,
per_channel: bool = False,
**kwargs: Any,
) -> None:
@@ -89,17 +89,17 @@ def __init__(
std, Real
), "mean and std should be real numbers in the non per-channel setting"
- def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]:
+ def apply(self, sig: NDArray, fs: Union[float, int]) -> Tuple[NDArray, Union[float, int]]:
"""Apply the preprocessor to `sig`.
Parameters
----------
sig : numpy.ndarray
The ECG signal, can be
- - 1d array, which is a single-lead ECG;
- - 2d array, which is a multi-lead ECG of "lead_first" format;
- - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
- fs : numbers.Real
+ - 1d array, which is a single-lead ECG;
+ - 2d array, which is a multi-lead ECG of "lead_first" format;
+ - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
+ fs : float or int
Sampling frequency of the ECG signal.
**NOT** used currently.
@@ -114,7 +114,7 @@ def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]:
self._check_sig(sig)
normalized_sig = normalize(
sig=sig.astype(DEFAULTS.np_dtype),
- method=self.method,
+ method=self.method, # type: ignore
mean=self.mean,
std=self.std,
sig_fmt="channel_first",
@@ -181,9 +181,9 @@ class NaiveNormalize(Normalize):
Parameters
----------
- mean : numbers.Real or numpy.ndarray, default 0.0
+ mean : float or int or numpy.ndarray, default 0.0
Value(s) to be subtracted.
- std : numbers.Real or numpy.ndarray, default 1.0
+ std : float or int or numpy.ndarray, default 1.0
Value(s) to be divided.
per_channel : bool, default False
If True, normalization will be done per channel.
@@ -203,8 +203,8 @@ class NaiveNormalize(Normalize):
def __init__(
self,
- mean: Union[Real, np.ndarray] = 0.0,
- std: Union[Real, np.ndarray] = 1.0,
+ mean: Union[float, int, NDArray] = 0.0,
+ std: Union[float, int, NDArray] = 1.0,
per_channel: bool = False,
**kwargs: Any,
) -> None:
@@ -234,10 +234,10 @@ class ZScoreNormalize(Normalize):
Parameters
----------
- mean : numbers.Real or numpy.ndarray, default 0.0
+ mean : float or int or numpy.ndarray, default 0.0
Mean value of the normalized signal,
or mean values for each lead of the normalized signal.
- std : numbers.Real or numpy.ndarray, default 1.0
+ std : float or int or numpy.ndarray, default 1.0
Standard deviation of the normalized signal,
or standard deviations for each lead of the normalized signal.
per_channel : bool, default False
@@ -258,8 +258,8 @@ class ZScoreNormalize(Normalize):
def __init__(
self,
- mean: Union[Real, np.ndarray] = 0.0,
- std: Union[Real, np.ndarray] = 1.0,
+ mean: Union[float, int, NDArray] = 0.0,
+ std: Union[float, int, NDArray] = 1.0,
per_channel: bool = False,
**kwargs: Any,
) -> None:
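
As a sanity check on the retyped `Normalize.apply` above, a minimal usage sketch (assuming `Normalize` is re-exported from `torch_ecg._preprocessors`; the array shape and sampling rate are purely illustrative):

    import numpy as np
    from torch_ecg._preprocessors import Normalize

    sig = np.random.randn(12, 5000).astype(np.float32)  # hypothetical 12-lead ECG, 10 s at 500 Hz
    norm = Normalize(method="z-score", mean=0.0, std=1.0)
    normalized_sig, fs = norm.apply(sig, fs=500)  # fs is passed through alongside the normalized signal
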
diff --git a/torch_ecg/_preprocessors/preproc_manager.py b/torch_ecg/_preprocessors/preproc_manager.py
index 95946643..faf9f22e 100644
--- a/torch_ecg/_preprocessors/preproc_manager.py
+++ b/torch_ecg/_preprocessors/preproc_manager.py
@@ -4,7 +4,7 @@
from random import sample
from typing import List, Optional, Tuple
-import numpy as np
+from numpy.typing import NDArray
from ..utils.misc import ReprMixin
from .bandpass import BandPass
@@ -101,7 +101,7 @@ def _add_resample(self, **config: dict) -> None:
"""
self._preprocessors.append(Resample(**config))
- def __call__(self, sig: np.ndarray, fs: int) -> Tuple[np.ndarray, int]:
+ def __call__(self, sig: NDArray, fs: int) -> Tuple[NDArray, int]:
"""The main function of the manager, which applies the preprocessors
Parameters
diff --git a/torch_ecg/_preprocessors/resample.py b/torch_ecg/_preprocessors/resample.py
index e58d43e7..d4f67187 100644
--- a/torch_ecg/_preprocessors/resample.py
+++ b/torch_ecg/_preprocessors/resample.py
@@ -1,10 +1,9 @@
"""Resample the signal into fixed sampling frequency or length."""
-from numbers import Real
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, Union
-import numpy as np
import scipy.signal as SS
+from numpy.typing import NDArray
from ..cfg import DEFAULTS
from .base import PreProcessor
@@ -46,18 +45,17 @@ def __init__(self, fs: Optional[int] = None, siglen: Optional[int] = None, **kwa
self.siglen = siglen
assert sum([bool(self.fs), bool(self.siglen)]) == 1, "one and only one of `fs` and `siglen` should be set"
- def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]:
+ def apply(self, sig: NDArray, fs: Union[float, int]) -> Tuple[NDArray, Union[float, int]]:
"""Apply the preprocessor to `sig`.
Parameters
----------
sig : numpy.ndarray
The ECG signal, can be
-
- - 1d array, which is a single-lead ECG;
- - 2d array, which is a multi-lead ECG of "lead_first" format;
- - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
- fs : numbers.Real
+ - 1d array, which is a single-lead ECG;
+ - 2d array, which is a multi-lead ECG of "lead_first" format;
+ - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``.
+ fs : float or int
Sampling frequency of the ECG signal.
Returns
@@ -75,7 +73,7 @@ def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]:
else: # self.siglen is not None
rsmp_sig = SS.resample(sig.astype(DEFAULTS.np_dtype), num=self.siglen, axis=-1)
new_fs = int(round(self.siglen / sig.shape[-1] * fs))
- return rsmp_sig, new_fs
+ return rsmp_sig, new_fs # type: ignore
def extra_repr_keys(self) -> List[str]:
return [
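
The `new_fs` arithmetic in `Resample.apply` can be traced outside the class; a small worked sketch (the values are illustrative):

    import numpy as np
    import scipy.signal as SS

    fs, target_len = 500, 2000               # original sampling rate, target signal length
    sig = np.random.randn(12, 5000)          # 10 s of a 12-lead ECG
    rsmp_sig = SS.resample(sig, num=target_len, axis=-1)
    new_fs = int(round(target_len / sig.shape[-1] * fs))  # 2000 / 5000 * 500 = 200 Hz
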
diff --git a/torch_ecg/augmenters/baseline_wander.py b/torch_ecg/augmenters/baseline_wander.py
index 0a4bd186..57e48dcb 100644
--- a/torch_ecg/augmenters/baseline_wander.py
+++ b/torch_ecg/augmenters/baseline_wander.py
@@ -2,12 +2,12 @@
import multiprocessing as mp
from itertools import repeat
-from numbers import Real
from random import randint
from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
+from numpy.typing import NDArray
from torch import Tensor
from ..cfg import DEFAULTS
@@ -88,9 +88,9 @@ class BaselineWanderAugmenter(Augmenter):
def __init__(
self,
fs: int,
- bw_fs: Optional[np.ndarray] = None,
- ampl_ratio: Optional[np.ndarray] = None,
- gaussian: Optional[np.ndarray] = None,
+ bw_fs: Optional[NDArray] = None,
+ ampl_ratio: Optional[NDArray] = None,
+ gaussian: Optional[NDArray] = None,
prob: float = 0.5,
inplace: bool = True,
**kwargs: Any,
@@ -186,8 +186,8 @@ def forward(
if not self.inplace:
sig = sig.clone()
if self.prob > 0:
- sig.add_(gen_baseline_wander(sig, self.fs, self.bw_fs, self.ampl_ratio, self.gaussian))
- return (sig, label, *extra_tensors)
+ sig.add_(gen_baseline_wander(sig, self.fs, self.bw_fs, self.ampl_ratio, self.gaussian)) # type: ignore
+ return (sig, label, *extra_tensors) # type: ignore
def extra_repr_keys(self) -> List[str]:
return [
@@ -223,7 +223,7 @@ def _get_ampl(sig: Tensor, fs: int) -> Tensor:
return ampl
-def _gen_gaussian_noise(siglen: int, mean: Real = 0, std: Real = 0) -> np.ndarray:
+def _gen_gaussian_noise(siglen: int, mean: Union[float, int] = 0, std: Union[float, int] = 0) -> NDArray:
"""Generate 1d Gaussian noise of given
length, mean, and standard deviation.
@@ -231,9 +231,9 @@ def _gen_gaussian_noise(siglen: int, mean: Real = 0, std: Real = 0) -> np.ndarra
----------
siglen : int
Length of the noise signal.
- mean : numbers.Real, default 0
+ mean : float or int, default 0
Mean value of the noise.
- std : numbers.Real, default 0
+ std : float or int, default 0
Standard deviation of the noise.
Returns
@@ -248,12 +248,12 @@ def _gen_gaussian_noise(siglen: int, mean: Real = 0, std: Real = 0) -> np.ndarra
def _gen_sinusoidal_noise(
siglen: int,
- start_phase: Real,
- end_phase: Real,
- amplitude: Real,
- amplitude_mean: Real = 0,
- amplitude_std: Real = 0,
-) -> np.ndarray:
+ start_phase: Union[float, int],
+ end_phase: Union[float, int],
+ amplitude: Union[float, int],
+ amplitude_mean: Union[float, int] = 0,
+ amplitude_std: Union[float, int] = 0,
+) -> NDArray:
"""Generate 1d sinusoidal noise of given
length, amplitude, start phase, and end phase.
@@ -261,15 +261,15 @@ def _gen_sinusoidal_noise(
----------
siglen : int
Length of the (noise) signal.
- start_phase : numbers.Real
+ start_phase : float or int
Start phase, with units in degrees.
- end_phase : numbers.Real
+ end_phase : float or int
End phase, with units in degrees.
- amplitude : numbers.Real
+ amplitude : float or int
Amplitude of the sinusoidal curve.
- amplitude_mean : numbers.Real
+ amplitude_mean : float or int, default 0
Mean amplitude of an extra Gaussian noise.
- amplitude_std : numbers.Real, default 0
+ amplitude_std : float or int, default 0
Standard deviation of an extra Gaussian noise
Returns
@@ -286,11 +286,11 @@ def _gen_sinusoidal_noise(
def _gen_baseline_wander(
siglen: int,
- fs: Real,
- bw_fs: Union[Real, Sequence[Real]],
- amplitude: Union[Real, Sequence[Real]],
- amplitude_gaussian: Sequence[Real] = [0, 0],
-) -> np.ndarray:
+ fs: Union[float, int],
+ bw_fs: Union[float, int, Sequence[Union[float, int]]],
+ amplitude: Union[float, int, Sequence[Union[float, int]]],
+ amplitude_gaussian: Sequence[Union[float, int]] = [0, 0],
+) -> NDArray:
"""Generate 1d baseline wander of given
length, amplitude, and frequency.
@@ -298,14 +298,14 @@ def _gen_baseline_wander(
----------
siglen : int
Length of the (noise) signal.
- fs : numbers.Real
+ fs : float or int
Sampling frequency of the original signal.
- bw_fs : numbers.Real, or list of numbers.Real
+ bw_fs : float or int, or list of float or int
Frequency (Frequencies) of the baseline wander.
- amplitude : numbers.Real, or list of numbers.Real
+ amplitude : float or int, or list of float or int
Amplitude of the baseline wander (corr. to each frequency band).
- amplitude_gaussian : Tuple[numbers.Real], default [0,0]
- 2-tuple of :class:`~numbers.Real`.
+ amplitude_gaussian : sequence of float or int, default [0, 0]
+ 2-tuple of real numbers:
Mean and std of amplitude of an extra Gaussian noise.
Returns
@@ -319,11 +319,11 @@ def _gen_baseline_wander(
"""
bw = _gen_gaussian_noise(siglen, amplitude_gaussian[0], amplitude_gaussian[1])
- if isinstance(bw_fs, Real):
+ if isinstance(bw_fs, (int, float)):
_bw_fs = [bw_fs]
else:
_bw_fs = bw_fs
- if isinstance(amplitude, Real):
+ if isinstance(amplitude, (int, float)):
_amplitude = list(repeat(amplitude, len(_bw_fs)))
else:
_amplitude = amplitude
@@ -338,11 +338,11 @@ def _gen_baseline_wander(
def gen_baseline_wander(
sig: Tensor,
- fs: Real,
- bw_fs: Union[Real, Sequence[Real]],
- ampl_ratio: np.ndarray,
- gaussian: np.ndarray,
-) -> np.ndarray:
+ fs: Union[float, int],
+ bw_fs: Union[float, int, Sequence[Union[float, int]]],
+ ampl_ratio: NDArray,
+ gaussian: NDArray,
+) -> Tensor:
"""Generate 1d baseline wander of given
length, amplitude, and frequency.
@@ -350,9 +350,9 @@ def gen_baseline_wander(
----------
sig : torch.Tensor
Batched ECGs to be augmented, of shape (batch, lead, siglen).
- fs : numbers.Real
+ fs : float or int
Sampling frequency of the original signal.
- bw_fs : numbers.Real, or list of numbers.Real,
+ bw_fs : float or int, or list of float or int,
Frequency (Frequencies) of the baseline wander.
ampl_ratio : numpy.ndarray, optional
Candidate ratios of noise amplitudes compared to the original ECGs for each `fs`,
@@ -363,7 +363,7 @@ def gen_baseline_wander(
Returns
-------
- bw : numpy.ndarray
+ bw : torch.Tensor
Baseline wander of given length, amplitude, frequency,
of shape ``(batch, lead, siglen)``.
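
For intuition about what the helpers above assemble, here is a rough sketch of a baseline wander built as a sum of low-frequency sinusoids over a Gaussian floor (the frequencies and amplitudes are illustrative, not the module's defaults):

    import numpy as np

    def sketch_baseline_wander(siglen: int, fs: float, bw_fs=(0.05, 0.1, 0.33), ampl=(0.04, 0.02, 0.01)):
        bw = np.random.normal(0.0, 0.0, siglen)          # Gaussian floor (std 0 here, i.e. zeros)
        t = np.arange(siglen) / fs
        for f, a in zip(bw_fs, ampl):
            phase = np.random.uniform(0, 2 * np.pi)      # random phase per frequency component
            bw += a * np.sin(2 * np.pi * f * t + phase)
        return bw

    wander = sketch_baseline_wander(siglen=5000, fs=500)
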
diff --git a/torch_ecg/augmenters/cutmix.py b/torch_ecg/augmenters/cutmix.py
index c0c7edf3..60bd98f2 100644
--- a/torch_ecg/augmenters/cutmix.py
+++ b/torch_ecg/augmenters/cutmix.py
@@ -1,12 +1,12 @@
""" """
from copy import deepcopy
-from numbers import Real
from random import shuffle
from typing import Any, List, Optional, Sequence, Tuple
import numpy as np
import torch
+from numpy.typing import NDArray
from torch import Tensor
from ..cfg import DEFAULTS
@@ -67,8 +67,8 @@ def __init__(
self,
fs: Optional[int] = None,
num_mix: int = 1,
- alpha: Real = 0.5,
- beta: Optional[Real] = None,
+ alpha: float = 0.5,
+ beta: Optional[float] = None,
prob: float = 0.5,
inplace: bool = True,
**kwargs: Any,
@@ -174,7 +174,7 @@ def extra_repr_keys(self) -> List[str]:
] + super().extra_repr_keys()
-def _make_intervals(lam: Tensor, siglen: int) -> np.ndarray:
+def _make_intervals(lam: Tensor, siglen: int) -> NDArray:
"""Make intervals for cutmix.
Parameters
diff --git a/torch_ecg/augmenters/stretch_compress.py b/torch_ecg/augmenters/stretch_compress.py
index d6bb8af7..1337dc0e 100644
--- a/torch_ecg/augmenters/stretch_compress.py
+++ b/torch_ecg/augmenters/stretch_compress.py
@@ -8,6 +8,7 @@
import scipy.signal as SS
import torch
import torch.nn.functional as F
+from numpy.typing import NDArray
from torch import Tensor
from ..cfg import DEFAULTS
@@ -383,10 +384,10 @@ def __init__(
def generate(
self,
seglen: int,
- sig: np.ndarray,
- *labels: Sequence[np.ndarray],
+ sig: NDArray,
+ *labels: Sequence[NDArray],
critical_points: Optional[Sequence[int]] = None,
- ) -> List[Tuple[Union[np.ndarray, int], ...]]:
+ ) -> List[Tuple[Union[NDArray, int], ...]]:
"""Generate stretched or compressed segments from the ECGs.
Parameters
@@ -465,11 +466,11 @@ def generate(
def __generate_segment(
self,
seglen: int,
- sig: np.ndarray,
- *labels: Sequence[np.ndarray],
+ sig: NDArray,
+ *labels: Sequence[NDArray],
start_idx: Optional[int] = None,
end_idx: Optional[int] = None,
- ) -> Tuple[Union[np.ndarray, int], ...]:
+ ) -> Tuple[Union[NDArray, int], ...]:
"""Internal function to generate a stretched or compressed segment.
Parameters
@@ -564,10 +565,10 @@ def _sample_ratio(self) -> float:
def __call__(
self,
seglen: int,
- sig: np.ndarray,
- *labels: Sequence[np.ndarray],
+ sig: NDArray,
+ *labels: Sequence[NDArray],
critical_points: Optional[Sequence[int]] = None,
- ) -> List[Tuple[np.ndarray, ...]]:
+ ) -> List[Tuple[NDArray, ...]]:
return self.generate(seglen, sig, *labels, critical_points=critical_points)
def extra_repr_keys(self) -> List[str]:
diff --git a/torch_ecg/cfg.py b/torch_ecg/cfg.py
index 0ffaed0a..70037ff0 100644
--- a/torch_ecg/cfg.py
+++ b/torch_ecg/cfg.py
@@ -84,7 +84,7 @@ def __setattr__(self, name: str, value: Any) -> None:
__setitem__ = __setattr__
- def update(self, new_cfg: Optional[MutableMapping] = None, **kwargs: Any) -> None:
+ def update(self, new_cfg: Optional[MutableMapping] = None, **kwargs: Any) -> None: # type: ignore
"""The new hierarchical update method.
Parameters
@@ -157,9 +157,9 @@ class DTYPE:
"""
STR: str
- NP: np.dtype = None
- TORCH: torch.dtype = None
- INT: int = None # int representation of the dtype, mainly used for `wfdb.rdrecord`
+ NP: np.dtype = None # type: ignore
+ TORCH: torch.dtype = None # type: ignore
+ INT: int = None # type: ignore  # int representation of the dtype, mainly used for `wfdb.rdrecord`
def __post_init__(self) -> None:
"""check consistency"""
@@ -168,12 +168,12 @@ def __post_init__(self) -> None:
if self.TORCH is None:
self.TORCH = eval(f"torch.{self.STR}")
if self.INT is None:
- self.INT = int(re.search("\\d+", self.STR).group(0))
+ self.INT = int(re.search("\\d+", self.STR).group(0)) # type: ignore
assert all(
[
self.NP == getattr(np, self.STR),
self.TORCH == getattr(torch, self.STR),
- self.INT == int(re.search("\\d+", self.STR).group(0)),
+ self.INT == int(re.search("\\d+", self.STR).group(0)), # type: ignore
]
), "inconsistent dtype"
diff --git a/torch_ecg/components/inputs.py b/torch_ecg/components/inputs.py
index fc348c03..12f07478 100644
--- a/torch_ecg/components/inputs.py
+++ b/torch_ecg/components/inputs.py
@@ -6,9 +6,9 @@
from copy import deepcopy
from typing import List, Literal, Sequence, Tuple, Union
-import numpy as np
import torch
from einops.layers.torch import Rearrange
+from numpy.typing import NDArray
from torch.nn import functional as F
from ..cfg import CFG, DEFAULTS
@@ -109,7 +109,7 @@ def __init__(self, config: InputConfig) -> None:
self._device = self._config.get("device", DEFAULTS.device)
self._post_init()
- def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def __call__(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
"""Method to transform the waveform to the input tensor.
Parameters
@@ -126,11 +126,11 @@ def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
return self.from_waveform(waveform)
@abstractmethod
- def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def _from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
"""Internal method to convert the waveform to the input tensor."""
raise NotImplementedError
- def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
"""Transform the waveform to the input tensor.
Parameters
@@ -289,12 +289,12 @@ def _post_init(self) -> None:
"""Make sure the input type is `waveform`."""
assert self.input_type == "waveform", "`input_type` must be `waveform`"
- def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def _from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
"""Internal method to convert the waveform to the input tensor."""
self._values = torch.as_tensor(waveform).to(self.device, self.dtype)
return self._values
- def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
"""Converts the input :class:`~numpy.ndarray` or
:class:`~torch.Tensor` waveform to a :class:`~torch.Tensor`.
@@ -320,7 +320,7 @@ def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tens
return super().from_waveform(waveform)
@add_docstring(from_waveform.__doc__)
- def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def __call__(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
""" """
return self.from_waveform(waveform)
@@ -329,16 +329,18 @@ class FFTInput(BaseInput):
"""Inputs from the FFT, via concatenating the amplitudes and the phases.
One can set the following optional parameters for initialization:
- - nfft: int
- the number of FFT bins.
- If nfft is None, the number of FFT bins is computed from the input shape.
- - drop_dc: bool, default True
- Whether to drop the zero frequency bin (the DC component).
- - norm: str, optional
- The normalization of the FFT, can be
- - "forward"
- - "backward"
- - "ortho"
+
+ - nfft: int
+ the number of FFT bins.
+ If nfft is None, the number of FFT bins is computed from the input shape.
+ - drop_dc: bool, default True
+ Whether to drop the zero frequency bin (the DC component).
+ - norm: str, optional
+ The normalization of the FFT, can be
+
+ - "forward"
+ - "backward"
+ - "ortho"
Examples
--------
@@ -400,7 +402,7 @@ def _post_init(self) -> None:
"ortho",
], f"`norm` must be one of [`forward`, `backward`, `ortho`], got {self.norm}"
- def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def _from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
"""Internal method to convert the waveform to the input tensor."""
self._values = torch.fft.rfft(
torch.as_tensor(waveform).to(self.device, self.dtype),
@@ -413,7 +415,7 @@ def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Ten
self._values = torch.cat([torch.abs(self._values), torch.angle(self._values)], dim=-2)
return self._values
- def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
"""Converts the input :class:`~numpy.ndarray` or
:class:`~torch.Tensor` waveform to a :class:`~torch.Tensor` of FFTs.
@@ -441,7 +443,7 @@ def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tens
return super().from_waveform(waveform)
@add_docstring(from_waveform.__doc__)
- def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def __call__(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
return self.from_waveform(waveform)
def extra_repr_keys(self) -> List[str]:
@@ -452,25 +454,26 @@ class _SpectralInput(BaseInput):
"""Inputs from the spectro-temporal domain.
One has to set the following parameters for initialization:
- - n_bins : int
- The number of frequency bins.
- - fs (or sample_rate) : int
- The sample rate of the waveform.
+
+ - n_bins : int
+ The number of frequency bins.
+ - fs (or sample_rate) : int
+ The sample rate of the waveform.
with the following optional parameters with default values:
- - window_size : float, default: 1 / 20
- The size of the window in seconds.
- - overlap_size : float, default: 1 / 40
- The overlap of the windows in seconds.
- - feature_fs : None or float,
- The sample rate of the features.
- If specified, the features will be resampled
- against `fs` to this sample rate.
- - to1d : bool, default False
- Whether to convert the features to 1D.
- NOTE that if `to1d` is True,
- then if the convolutions with ``groups=1`` applied to the `input`
- acts on all the bins, which is "global"
- w.r.t. the `bins` dimension of the corresponding 2d input.
+ - window_size : float, default: 1 / 20
+ The size of the window in seconds.
+ - overlap_size : float, default: 1 / 40
+ The overlap of the windows in seconds.
+ - feature_fs : None or float,
+ The sample rate of the features.
+ If specified, the features will be resampled
+ against `fs` to this sample rate.
+ - to1d : bool, default False
+ Whether to convert the features to 1D.
+ NOTE that if `to1d` is True,
+ then convolutions with ``groups=1`` applied to the `input`
+ act on all the bins, i.e. are "global"
+ w.r.t. the `bins` dimension of the corresponding 2d input.
"""
@@ -589,7 +592,7 @@ def _post_init(self) -> None:
Rearrange("... channel n_bins time -> ... (channel n_bins) time"),
)
- def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def _from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
"""Internal method to convert the waveform to the input tensor."""
self._values = self._transform(torch.as_tensor(waveform).to(self.device, self.dtype))
if self.feature_fs is not None:
@@ -605,7 +608,7 @@ def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Ten
self._values = F.interpolate(self._values, scale_factor=scale_factor, recompute_scale_factor=True)
return self._values
- def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
r"""Converts the input :class:`~numpy.ndarray` or
:class:`~torch.Tensor` waveform to a :class:`~torch.Tensor` of spectrograms.
@@ -635,5 +638,5 @@ def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tens
return super().from_waveform(waveform)
@add_docstring(from_waveform.__doc__)
- def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+ def __call__(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor:
return self.from_waveform(waveform)
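
The `FFTInput` conversion shown above amounts to stacking amplitudes and phases along the channel axis; a standalone sketch (the shapes and the `norm` choice are assumptions):

    import torch

    waveform = torch.randn(2, 12, 5000)                    # (batch, lead, siglen)
    spectrum = torch.fft.rfft(waveform, dim=-1, norm="ortho")
    fft_input = torch.cat([torch.abs(spectrum), torch.angle(spectrum)], dim=-2)
    # fft_input has shape (2, 24, 2501): amplitudes and phases stacked on the lead dimension
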
diff --git a/torch_ecg/components/metrics.py b/torch_ecg/components/metrics.py
index 2c7fcae1..fa3094e5 100644
--- a/torch_ecg/components/metrics.py
+++ b/torch_ecg/components/metrics.py
@@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
import numpy as np
+from numpy.typing import NDArray
from torch import Tensor
from ..utils.misc import ReprMixin, add_docstring
@@ -55,10 +56,10 @@ class ClassificationMetrics(Metrics):
.. code-block:: python
def extra_metrics(
- labels : np.ndarray
- outputs : np.ndarray
+ labels : NDArray
+ outputs : NDArray
num_classes : Optional[int]=None
- weights : Optional[np.ndarray]=None
+ weights : Optional[NDArray]=None
) -> dict
"""
@@ -140,10 +141,10 @@ def set_macro(self, macro: bool) -> None:
)
def compute(
self,
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
- weights: Optional[np.ndarray] = None,
+ weights: Optional[NDArray] = None,
thr: float = 0.5,
) -> "ClassificationMetrics":
labels, outputs = one_hot_pair(labels, outputs, num_classes)
@@ -164,154 +165,154 @@ def compute(
@add_docstring(compute.__doc__)
def __call__(
self,
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
- weights: Optional[np.ndarray] = None,
+ weights: Optional[NDArray] = None,
thr: float = 0.5,
) -> "ClassificationMetrics":
return self.compute(labels, outputs, num_classes, weights)
@property
- def sensitivity(self) -> Union[float, np.ndarray]:
+ def sensitivity(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}sens"]
@property
- def recall(self) -> Union[float, np.ndarray]:
+ def recall(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}sens"]
@property
- def hit_rate(self) -> Union[float, np.ndarray]:
+ def hit_rate(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}sens"]
@property
- def true_positive_rate(self) -> Union[float, np.ndarray]:
+ def true_positive_rate(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}sens"]
@property
- def specificity(self) -> Union[float, np.ndarray]:
+ def specificity(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}spec"]
@property
- def selectivity(self) -> Union[float, np.ndarray]:
+ def selectivity(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}spec"]
@property
- def true_negative_rate(self) -> Union[float, np.ndarray]:
+ def true_negative_rate(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}spec"]
@property
- def precision(self) -> Union[float, np.ndarray]:
+ def precision(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}prec"]
@property
- def positive_predictive_value(self) -> Union[float, np.ndarray]:
+ def positive_predictive_value(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}prec"]
@property
- def negative_predictive_value(self) -> Union[float, np.ndarray]:
+ def negative_predictive_value(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}npv"]
@property
- def jaccard_index(self) -> Union[float, np.ndarray]:
+ def jaccard_index(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}jac"]
@property
- def threat_score(self) -> Union[float, np.ndarray]:
+ def threat_score(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}jac"]
@property
- def critical_success_index(self) -> Union[float, np.ndarray]:
+ def critical_success_index(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}jac"]
@property
- def accuracy(self) -> Union[float, np.ndarray]:
+ def accuracy(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}acc"]
@property
- def phi_coefficient(self) -> Union[float, np.ndarray]:
+ def phi_coefficient(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}phi"]
@property
- def matthews_correlation_coefficient(self) -> Union[float, np.ndarray]:
+ def matthews_correlation_coefficient(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}phi"]
@property
- def false_negative_rate(self) -> Union[float, np.ndarray]:
+ def false_negative_rate(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}fnr"]
@property
- def miss_rate(self) -> Union[float, np.ndarray]:
+ def miss_rate(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}fnr"]
@property
- def false_positive_rate(self) -> Union[float, np.ndarray]:
+ def false_positive_rate(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}fpr"]
@property
- def fall_out(self) -> Union[float, np.ndarray]:
+ def fall_out(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}fpr"]
@property
- def false_discovery_rate(self) -> Union[float, np.ndarray]:
+ def false_discovery_rate(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}fdr"]
@property
- def false_omission_rate(self) -> Union[float, np.ndarray]:
+ def false_omission_rate(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}for"]
@property
- def positive_likelihood_ratio(self) -> Union[float, np.ndarray]:
+ def positive_likelihood_ratio(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}plr"]
@property
- def negative_likelihood_ratio(self) -> Union[float, np.ndarray]:
+ def negative_likelihood_ratio(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}nlr"]
@property
- def prevalence_threshold(self) -> Union[float, np.ndarray]:
+ def prevalence_threshold(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}pt"]
@property
- def balanced_accuracy(self) -> Union[float, np.ndarray]:
+ def balanced_accuracy(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}ba"]
@property
- def f1_measure(self) -> Union[float, np.ndarray]:
+ def f1_measure(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}f1"]
@property
- def fowlkes_mallows_index(self) -> Union[float, np.ndarray]:
+ def fowlkes_mallows_index(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}fm"]
@property
- def bookmaker_informedness(self) -> Union[float, np.ndarray]:
+ def bookmaker_informedness(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}bm"]
@property
- def markedness(self) -> Union[float, np.ndarray]:
+ def markedness(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}mk"]
@property
- def diagnostic_odds_ratio(self) -> Union[float, np.ndarray]:
+ def diagnostic_odds_ratio(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}dor"]
@property
def area_under_the_receiver_operater_characteristic_curve(
self,
- ) -> Union[float, np.ndarray]:
+ ) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}auroc"]
@property
- def auroc(self) -> Union[float, np.ndarray]:
+ def auroc(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}auroc"]
@property
- def area_under_the_precision_recall_curve(self) -> Union[float, np.ndarray]:
+ def area_under_the_precision_recall_curve(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}auprc"]
@property
- def auprc(self) -> Union[float, np.ndarray]:
+ def auprc(self) -> Union[float, NDArray]:
return self._metrics[f"{self.__prefix}auprc"]
@property
@@ -348,8 +349,8 @@ class RPeaksDetectionMetrics(Metrics):
.. code-block:: python
def extra_metrics(
- labels : Sequence[Union[Sequence[int], np.ndarray]],
- outputs : Sequence[Union[Sequence[int], np.ndarray]],
+ labels : Sequence[Union[Sequence[int], NDArray]],
+ outputs : Sequence[Union[Sequence[int], NDArray]],
fs : int
) -> dict
@@ -395,8 +396,8 @@ def __init__(
)
def compute(
self,
- labels: Sequence[Union[Sequence[int], np.ndarray]],
- outputs: Sequence[Union[Sequence[int], np.ndarray]],
+ labels: Sequence[Union[Sequence[int], NDArray]],
+ outputs: Sequence[Union[Sequence[int], NDArray]],
fs: int,
thr: Optional[float] = None,
) -> "RPeaksDetectionMetrics":
@@ -410,8 +411,8 @@ def compute(
@add_docstring(compute.__doc__)
def __call__(
self,
- labels: Sequence[Union[Sequence[int], np.ndarray]],
- outputs: Sequence[Union[Sequence[int], np.ndarray]],
+ labels: Sequence[Union[Sequence[int], NDArray]],
+ outputs: Sequence[Union[Sequence[int], NDArray]],
fs: int,
thr: Optional[float] = None,
) -> "RPeaksDetectionMetrics":
@@ -446,8 +447,8 @@ class WaveDelineationMetrics(Metrics):
.. code-block:: python
def extra_metrics(
- labels: Sequence[Union[Sequence[int], np.ndarray]],
- outputs: Sequence[Union[Sequence[int], np.ndarray]],
+ labels: Sequence[Union[Sequence[int], NDArray]],
+ outputs: Sequence[Union[Sequence[int], NDArray]],
fs: int
) -> dict
@@ -531,8 +532,8 @@ def set_macro(self, macro: bool) -> None:
)
def compute(
self,
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
class_map: Dict[str, int],
fs: int,
mask_format: str = "channel_first",
@@ -592,8 +593,8 @@ def compute(
@add_docstring(compute.__doc__)
def __call__(
self,
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
class_map: Dict[str, int],
fs: int,
mask_format: str = "channel_first",
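
A minimal usage sketch of the retyped `ClassificationMetrics.compute` (a sketch only: the constructor is left at its defaults, and the toy labels/outputs are one-hot and probability arrays respectively):

    import numpy as np
    from torch_ecg.components.metrics import ClassificationMetrics

    labels = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    outputs = np.array([[0.9, 0.05, 0.05], [0.2, 0.7, 0.1], [0.1, 0.2, 0.7]])
    cm = ClassificationMetrics().compute(labels, outputs, num_classes=3)
    print(cm.accuracy, cm.f1_measure)
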
diff --git a/torch_ecg/components/trainer.py b/torch_ecg/components/trainer.py
index 127d23e6..c3e20b62 100644
--- a/torch_ecg/components/trainer.py
+++ b/torch_ecg/components/trainer.py
@@ -10,7 +10,7 @@
from collections import OrderedDict, deque
from copy import deepcopy
from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -53,18 +53,18 @@ class BaseTrainer(ReprMixin, ABC):
Will also be recorded in the checkpoints.
`train_config` should at least contain the following keys:
- - "monitor": str
- - "loss": str
- - "n_epochs": int
- - "batch_size": int
- - "learning_rate": float
- - "lr_scheduler": str
- - "lr_step_size": int, optional, depending on the scheduler
- - "lr_gamma": float, optional, depending on the scheduler
- - "max_lr": float, optional, depending on the scheduler
- - "optimizer": str
- - "decay": float, optional, depending on the optimizer
- - "momentum": float, optional, depending on the optimizer
+ - "monitor": str
+ - "loss": str
+ - "n_epochs": int
+ - "batch_size": int
+ - "learning_rate": float
+ - "lr_scheduler": str
+ - "lr_step_size": int, optional, depending on the scheduler
+ - "lr_gamma": float, optional, depending on the scheduler
+ - "max_lr": float, optional, depending on the scheduler
+ - "optimizer": str
+ - "decay": float, optional, depending on the optimizer
+ - "momentum": float, optional, depending on the optimizer
collate_fn : callable, optional
The collate function for the data loader,
defaults to :meth:`default_collate_fn`.
@@ -93,7 +93,7 @@ def __init__(
dataset_cls: Dataset,
model_config: dict,
train_config: dict,
- collate_fn: Optional[callable] = None,
+ collate_fn: Optional[Callable] = None,
device: Optional[torch.device] = None,
lazy: bool = False,
) -> None:
@@ -108,7 +108,7 @@ def __init__(
self.dataset_cls = dataset_cls
self.model_config = CFG(deepcopy(model_config))
self._train_config = CFG(deepcopy(train_config))
- self._train_config.checkpoints = Path(self._train_config.checkpoints)
+ self._train_config.checkpoints = Path(self._train_config.checkpoints) # type: ignore
self.device = device or next(self._model.parameters()).device
self.dtype = next(self._model.parameters()).dtype
self.model.to(self.device)
@@ -150,12 +150,12 @@ def train(self) -> OrderedDict:
self._setup_criterion()
- if self.train_config.monitor is not None:
+ if self.train_config.monitor is not None: # type: ignore
# if monitor is set but val_loader is None, use train_loader for validation
# and choose the best model based on the metrics on the train set
if self.val_loader is None and self.val_train_loader is None:
self.val_train_loader = self.train_loader
- self.log_manager.log_message(
+ self.log_manager.log_message( # type: ignore
(
"No separate validation set is provided, while monitor is set. "
"The training set will be used for validation, "
@@ -179,7 +179,7 @@ def train(self) -> OrderedDict:
-----------------------------------------
"""
)
- self.log_manager.log_message(msg)
+ self.log_manager.log_message(msg) # type: ignore
start_epoch = self.epoch
for _ in range(start_epoch, self.n_epochs):
@@ -193,15 +193,15 @@ def train(self) -> OrderedDict:
dynamic_ncols=True,
mininterval=1.0,
) as pbar:
- self.log_manager.epoch_start(self.epoch)
+ self.log_manager.epoch_start(self.epoch) # type: ignore
# train one epoch
self.train_one_epoch(pbar)
# evaluate on train set, if debug is True
if self.val_train_loader is not None:
eval_train_res = self.evaluate(self.val_train_loader)
- self.log_manager.log_metrics(
- metrics=eval_train_res,
+ self.log_manager.log_metrics( # type: ignore
+ metrics=eval_train_res, # type: ignore
step=self.global_step,
epoch=self.epoch,
part="train",
@@ -211,8 +211,8 @@ def train(self) -> OrderedDict:
# evaluate on val set
if self.val_loader is not None:
eval_res = self.evaluate(self.val_loader)
- self.log_manager.log_metrics(
- metrics=eval_res,
+ self.log_manager.log_metrics( # type: ignore
+ metrics=eval_res, # type: ignore
step=self.global_step,
epoch=self.epoch,
part="val",
@@ -224,19 +224,19 @@ def train(self) -> OrderedDict:
eval_res = {}
# update best model and best metric if monitor is set
- if self.train_config.monitor is not None:
- if eval_res[self.train_config.monitor] > self.best_metric:
- self.best_metric = eval_res[self.train_config.monitor]
+ if self.train_config.monitor is not None: # type: ignore
+ if eval_res[self.train_config.monitor] > self.best_metric: # type: ignore
+ self.best_metric = eval_res[self.train_config.monitor] # type: ignore
self.best_state_dict = self._model.state_dict()
self.best_eval_res = deepcopy(eval_res)
self.best_epoch = self.epoch
self.pseudo_best_epoch = self.epoch
- elif self.train_config.early_stopping:
- if eval_res[self.train_config.monitor] >= self.best_metric - self.train_config.early_stopping.min_delta:
+ elif self.train_config.early_stopping: # type: ignore
+ if eval_res[self.train_config.monitor] >= self.best_metric - self.train_config.early_stopping.min_delta: # type: ignore
self.pseudo_best_epoch = self.epoch
- elif self.epoch - self.pseudo_best_epoch >= self.train_config.early_stopping.patience:
+ elif self.epoch - self.pseudo_best_epoch >= self.train_config.early_stopping.patience: # type: ignore
msg = f"early stopping is triggered at epoch {self.epoch}"
- self.log_manager.log_message(msg)
+ self.log_manager.log_message(msg) # type: ignore
break
msg = textwrap.dedent(
@@ -245,62 +245,65 @@ def train(self) -> OrderedDict:
obtained at epoch {self.best_epoch}
"""
)
- self.log_manager.log_message(msg)
+ self.log_manager.log_message(msg) # type: ignore
# save checkpoint
- save_suffix = f"epochloss_{self.epoch_loss:.5f}_metric_{eval_res[self.train_config.monitor]:.2f}"
+ save_suffix = f"epochloss_{self.epoch_loss:.5f}_metric_{eval_res[self.train_config.monitor]:.2f}" # type: ignore
else:
save_suffix = f"epochloss_{self.epoch_loss:.5f}"
- save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"
- save_path = self.train_config.checkpoints / save_filename
- if self.train_config.keep_checkpoint_max != 0:
+ # save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar"
+ save_folder = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}"
+ save_path = self.train_config.checkpoints / save_folder # type: ignore
+ if self.train_config.keep_checkpoint_max != 0: # type: ignore
self.save_checkpoint(str(save_path))
self.saved_models.append(save_path)
# remove outdated models
- if len(self.saved_models) > self.train_config.keep_checkpoint_max > 0:
+ if len(self.saved_models) > self.train_config.keep_checkpoint_max > 0: # type: ignore
model_to_remove = self.saved_models.popleft()
try:
os.remove(model_to_remove)
except Exception:
- self.log_manager.log_message(f"failed to remove {str(model_to_remove)}")
+ self.log_manager.log_message(f"failed to remove {str(model_to_remove)}") # type: ignore
# update learning rate using lr_scheduler
- if self.train_config.lr_scheduler.lower() == "plateau":
+ if self.train_config.lr_scheduler.lower() == "plateau": # type: ignore
self._update_lr(eval_res)
- self.log_manager.epoch_end(self.epoch)
+ self.log_manager.epoch_end(self.epoch) # type: ignore
self.epoch += 1
# save the best model
if self.best_metric > -np.inf:
- if self.train_config.final_model_name:
- save_filename = self.train_config.final_model_name
+ if self.train_config.final_model_name: # type: ignore
+ save_folder = self.train_config.final_model_name # type: ignore
else:
- save_suffix = f"metric_{self.best_eval_res[self.train_config.monitor]:.2f}"
- save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar"
- save_path = self.train_config.model_dir / save_filename
+ save_suffix = f"metric_{self.best_eval_res[self.train_config.monitor]:.2f}" # type: ignore
+ # save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar"
+ save_folder = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}"
+ save_path = self.train_config.model_dir / save_folder # type: ignore
# self.save_checkpoint(path=str(save_path))
self._model.save(path=str(save_path), train_config=self.train_config)
- self.log_manager.log_message(f"best model is saved at {save_path}")
- elif self.train_config.monitor is None:
- self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model")
+ self.log_manager.log_message(f"best model is saved at {save_path}") # type: ignore
+ elif self.train_config.monitor is None: # type: ignore
+ self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model") # type: ignore
self.best_state_dict = self._model.state_dict()
- save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar"
- save_path = self.train_config.model_dir / save_filename
+ # save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar"
+ save_folder = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}"
+ save_path = self.train_config.model_dir / save_folder # type: ignore
# self.save_checkpoint(path=str(save_path))
self._model.save(path=str(save_path), train_config=self.train_config)
else:
raise ValueError("No best model found!")
- self.log_manager.close()
+ self.log_manager.close() # type: ignore
if not self.best_state_dict:
# in case no best model is found,
# e.g. monitor is not set, or keep_checkpoint_max is 0
self.best_state_dict = self._model.state_dict()
- return self.best_state_dict
+ return self.best_state_dict # type: ignore
def train_one_epoch(self, pbar: tqdm) -> None:
"""Train one epoch, and update the progress bar
@@ -311,16 +314,16 @@ def train_one_epoch(self, pbar: tqdm) -> None:
The progress bar for training.
"""
- for epoch_step, data in enumerate(self.train_loader):
+ for epoch_step, data in enumerate(self.train_loader): # type: ignore
self.global_step += 1
# data is assumed to be a tuple of tensors, of the following order:
# signals, labels, *extra_tensors
- data = self.augmenter_manager(*data)
+ data = self.augmenter_manager(*data) # type: ignore
out_tensors = self.run_one_step(*data)
loss = self.criterion(*out_tensors).to(self.dtype)
- if self.train_config.flooding_level > 0:
- flood = (loss - self.train_config.flooding_level).abs() + self.train_config.flooding_level
+ if self.train_config.flooding_level > 0: # type: ignore
+ flood = (loss - self.train_config.flooding_level).abs() + self.train_config.flooding_level # type: ignore
self.epoch_loss += loss.item()
self.optimizer.zero_grad()
flood.backward()
@@ -331,7 +334,7 @@ def train_one_epoch(self, pbar: tqdm) -> None:
self.optimizer.step()
self._update_lr()
- if self.global_step % self.train_config.log_step == 0:
+ if self.global_step % self.train_config.log_step == 0: # type: ignore
train_step_metrics = {"loss": loss.item()}
if self.scheduler:
train_step_metrics.update({"lr": self.scheduler.get_last_lr()[0]})
@@ -347,9 +350,9 @@ def train_one_epoch(self, pbar: tqdm) -> None:
"loss (batch)": loss.item(),
}
)
- if self.train_config.flooding_level > 0:
- train_step_metrics.update({"flood": flood.item()})
- self.log_manager.log_metrics(
+ if self.train_config.flooding_level > 0: # type: ignore
+ train_step_metrics.update({"flood": flood.item()}) # type: ignore
+ self.log_manager.log_metrics( # type: ignore
metrics=train_step_metrics,
step=self.global_step,
epoch=self.epoch,
@@ -458,22 +461,22 @@ def _update_lr(self, eval_res: Optional[dict] = None) -> None:
The evaluation results (metrics).
"""
- if self.train_config.lr_scheduler.lower() == "none":
+ if self.train_config.lr_scheduler.lower() == "none": # type: ignore
pass
- elif self.train_config.lr_scheduler.lower() == "plateau":
+ elif self.train_config.lr_scheduler.lower() == "plateau": # type: ignore
if eval_res is None:
return
- metrics = eval_res[self.train_config.monitor]
+ metrics = eval_res[self.train_config.monitor] # type: ignore
if isinstance(metrics, torch.Tensor):
metrics = metrics.item()
- self.scheduler.step(metrics)
- elif self.train_config.lr_scheduler.lower() == "step":
- self.scheduler.step()
- elif self.train_config.lr_scheduler.lower() in [
+ self.scheduler.step(metrics) # type: ignore
+ elif self.train_config.lr_scheduler.lower() == "step": # type: ignore
+ self.scheduler.step() # type: ignore
+ elif self.train_config.lr_scheduler.lower() in [ # type: ignore
"one_cycle",
"onecycle",
]:
- self.scheduler.step()
+ self.scheduler.step() # type: ignore
def _setup_from_config(self, train_config: dict) -> None:
"""Setup the trainer from the training configuration.
@@ -492,14 +495,14 @@ def _setup_from_config(self, train_config: dict) -> None:
self._validate_train_config()
# set aliases
- self.n_epochs = self.train_config.n_epochs
- self.batch_size = self.train_config.batch_size
- self.lr = self.train_config.learning_rate
+ self.n_epochs = self.train_config.n_epochs # type: ignore
+ self.batch_size = self.train_config.batch_size # type: ignore
+ self.lr = self.train_config.learning_rate # type: ignore
# setup log manager first
self._setup_log_manager()
msg = f"training configurations are as follows:\n{dict_to_str(self.train_config)}"
- self.log_manager.log_message(msg)
+ self.log_manager.log_message(msg) # type: ignore
# setup directories
self._setup_directories()
@@ -517,7 +520,7 @@ def _setup_from_config(self, train_config: dict) -> None:
def extra_log_suffix(self) -> str:
"""Extra suffix for the log file name."""
model_name = self._model.__name__ if hasattr(self._model, "__name__") else self._model.__class__.__name__
- return f"{model_name}_{self.train_config.optimizer}_LR_{self.lr}_BS_{self.batch_size}"
+ return f"{model_name}_{self.train_config.optimizer}_LR_{self.lr}_BS_{self.batch_size}" # type: ignore
def _setup_log_manager(self) -> None:
"""Setup the log manager."""
@@ -528,27 +531,27 @@ def _setup_log_manager(self) -> None:
def _setup_directories(self) -> None:
"""Setup the directories for saving checkpoints and logs."""
if not self.train_config.get("model_dir", None):
- self._train_config.model_dir = self.train_config.checkpoints
- self._train_config.model_dir = Path(self._train_config.model_dir)
- self.train_config.checkpoints.mkdir(parents=True, exist_ok=True)
- self.train_config.model_dir.mkdir(parents=True, exist_ok=True)
+ self._train_config.model_dir = self.train_config.checkpoints # type: ignore
+ self._train_config.model_dir = Path(self._train_config.model_dir) # type: ignore
+ self.train_config.checkpoints.mkdir(parents=True, exist_ok=True) # type: ignore
+ self.train_config.model_dir.mkdir(parents=True, exist_ok=True) # type: ignore
def _setup_callbacks(self) -> None:
"""Setup the callbacks."""
self._train_config.monitor = self.train_config.get("monitor", None)
- if self.train_config.monitor is None:
+ if self.train_config.monitor is None: # type: ignore
assert (
- self.train_config.lr_scheduler.lower() != "plateau"
+ self.train_config.lr_scheduler.lower() != "plateau" # type: ignore
), "monitor is not specified, lr_scheduler should not be ReduceLROnPlateau"
self._train_config.keep_checkpoint_max = self.train_config.get("keep_checkpoint_max", 1)
- if self._train_config.keep_checkpoint_max < 0:
+ if self._train_config.keep_checkpoint_max < 0: # type: ignore
self._train_config.keep_checkpoint_max = -1
- self.log_manager.log_message(
+ self.log_manager.log_message( # type: ignore
msg="keep_checkpoint_max is set to -1, all checkpoints will be kept",
level=logging.WARNING,
)
- elif self._train_config.keep_checkpoint_max == 0:
- self.log_manager.log_message(
+ elif self._train_config.keep_checkpoint_max == 0: # type: ignore
+ self.log_manager.log_message( # type: ignore
msg="keep_checkpoint_max is set to 0, no checkpoint will be kept",
level=logging.WARNING,
)
@@ -617,7 +620,7 @@ def n_val(self) -> int:
def _setup_optimizer(self) -> None:
"""Setup the optimizer."""
- if self.train_config.optimizer.lower() == "adam":
+ if self.train_config.optimizer.lower() == "adam": # type: ignore
optimizer_kwargs = get_kwargs(optim.Adam)
optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()})
optimizer_kwargs.update(dict(lr=self.lr))
@@ -625,20 +628,20 @@ def _setup_optimizer(self) -> None:
params=self.model.parameters(),
**optimizer_kwargs,
)
- elif self.train_config.optimizer.lower() in ["adamw", "adamw_amsgrad"]:
+ elif self.train_config.optimizer.lower() in ["adamw", "adamw_amsgrad"]: # type: ignore
optimizer_kwargs = get_kwargs(optim.AdamW)
optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()})
optimizer_kwargs.update(
dict(
lr=self.lr,
- amsgrad=self.train_config.optimizer.lower().endswith("amsgrad"),
+ amsgrad=self.train_config.optimizer.lower().endswith("amsgrad"), # type: ignore
)
)
self.optimizer = optim.AdamW(
params=self.model.parameters(),
**optimizer_kwargs,
)
- elif self.train_config.optimizer.lower() == "sgd":
+ elif self.train_config.optimizer.lower() == "sgd": # type: ignore
optimizer_kwargs = get_kwargs(optim.SGD)
optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()})
optimizer_kwargs.update(dict(lr=self.lr))
@@ -648,44 +651,42 @@ def _setup_optimizer(self) -> None:
)
else:
raise NotImplementedError(
- f"optimizer `{self.train_config.optimizer}` not implemented! "
+ f"optimizer `{self.train_config.optimizer}` not implemented! " # type: ignore
"Please use one of the following: `adam`, `adamw`, `adamw_amsgrad`, `sgd`, "
"or override this method to setup your own optimizer."
)
def _setup_scheduler(self) -> None:
"""Setup the learning rate scheduler."""
- if self.train_config.lr_scheduler is None or self.train_config.lr_scheduler.lower() == "none":
+ if self.train_config.lr_scheduler is None or self.train_config.lr_scheduler.lower() == "none": # type: ignore
self.train_config.lr_scheduler = "none"
self.scheduler = None
- elif self.train_config.lr_scheduler.lower() == "plateau":
+ elif self.train_config.lr_scheduler.lower() == "plateau": # type: ignore
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
"max",
patience=2,
- verbose=False,
)
- elif self.train_config.lr_scheduler.lower() == "step":
+ elif self.train_config.lr_scheduler.lower() == "step": # type: ignore
self.scheduler = optim.lr_scheduler.StepLR(
self.optimizer,
- self.train_config.lr_step_size,
- self.train_config.lr_gamma,
- # verbose=False,
+ self.train_config.lr_step_size, # type: ignore
+ self.train_config.lr_gamma, # type: ignore
)
- elif self.train_config.lr_scheduler.lower() in [
+ elif self.train_config.lr_scheduler.lower() in [ # type: ignore
"one_cycle",
"onecycle",
]:
self.scheduler = optim.lr_scheduler.OneCycleLR(
optimizer=self.optimizer,
- max_lr=self.train_config.max_lr,
+ max_lr=self.train_config.max_lr, # type: ignore
epochs=self.n_epochs,
- steps_per_epoch=len(self.train_loader),
+ steps_per_epoch=len(self.train_loader), # type: ignore
# verbose=False,
)
else: # TODO: add linear and linear with warmup schedulers
raise NotImplementedError(
- f"lr scheduler `{self.train_config.lr_scheduler.lower()}` not implemented for training! "
+ f"lr scheduler `{self.train_config.lr_scheduler.lower()}` not implemented for training! " # type: ignore
"Please use one of the following: `none`, `plateau`, `step`, `one_cycle`, "
"or override this method to setup your own lr scheduler."
)
@@ -696,7 +697,7 @@ def _setup_criterion(self) -> None:
for k, v in loss_kw.items():
if isinstance(v, torch.Tensor):
loss_kw[k] = v.to(device=self.device, dtype=self.dtype)
- self.criterion = setup_criterion(self.train_config.loss, **loss_kw)
+ self.criterion = setup_criterion(self.train_config.loss, **loss_kw) # type: ignore
self.criterion.to(self.device)
def _check_model_config_compatability(self, model_config: dict) -> bool:
@@ -765,16 +766,30 @@ def save_checkpoint(self, path: str) -> None:
Path to save the checkpoint
"""
- torch.save(
- {
- "model_state_dict": self._model.state_dict(),
- "optimizer_state_dict": self.optimizer.state_dict(),
- "model_config": make_safe_globals(self.model_config),
- "train_config": make_safe_globals(self.train_config),
- "epoch": self.epoch,
- },
- path,
- )
+ # if self._model has method `save`, then use it
+ if hasattr(self._model, "save"):
+ self._model.save(
+ path=path,
+ train_config=self.train_config,
+ extra_items={
+ "optimizer_state_dict": self.optimizer.state_dict(),
+ "epoch": self.epoch,
+ },
+ use_safetensors=True,
+ )
+ else:
+ if not str(path).endswith(".pth.tar"):
+ path = Path(str(path) + ".pth.tar")  # type: ignore  # append, since with_suffix would clip names containing "."
+ torch.save(
+ {
+ "model_state_dict": self._model.state_dict(),
+ "optimizer_state_dict": self.optimizer.state_dict(),
+ "model_config": make_safe_globals(self.model_config),
+ "train_config": make_safe_globals(self.train_config),
+ "epoch": self.epoch,
+ },
+ path,
+ )
def extra_repr_keys(self) -> List[str]:
return [
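
The reworked `save_checkpoint` prefers the model's own `save` and only falls back to `torch.save` with an enforced `.pth.tar` suffix; a stripped-down sketch of that branching (the `extra_items` shape is taken from the diff above, everything else is illustrative):

    from pathlib import Path
    import torch

    def save_checkpoint_sketch(model, optimizer, path: str, epoch: int) -> None:
        if hasattr(model, "save"):
            # the model serializes itself (safetensors in the diff above)
            model.save(path=path, extra_items={"optimizer_state_dict": optimizer.state_dict(), "epoch": epoch})
        else:
            if not str(path).endswith(".pth.tar"):
                path = str(Path(str(path) + ".pth.tar"))
            torch.save({"model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "epoch": epoch}, path)
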
diff --git a/torch_ecg/databases/aux_data/aha.py b/torch_ecg/databases/aux_data/aha.py
index 72fef94d..070d8a2b 100644
--- a/torch_ecg/databases/aux_data/aha.py
+++ b/torch_ecg/databases/aux_data/aha.py
@@ -11,10 +11,7 @@
import pandas as pd
-try:
- pd.set_option("future.no_silent_downcasting", True)
-except Exception: # pandas._config.config.OptionError: "No such keys(s): 'future.no_silent_downcasting'"
- pass
+pd.set_option("future.no_silent_downcasting", True)
__all__ = [
"df_primary_statements",
diff --git a/torch_ecg/databases/aux_data/cinc2020_aux_data.py b/torch_ecg/databases/aux_data/cinc2020_aux_data.py
index 2786873a..2bfae5e0 100644
--- a/torch_ecg/databases/aux_data/cinc2020_aux_data.py
+++ b/torch_ecg/databases/aux_data/cinc2020_aux_data.py
@@ -8,8 +8,8 @@
from numbers import Real
from typing import Dict, Literal, Optional, Sequence, Union
-import numpy as np
import pandas as pd
+from numpy.typing import NDArray
from ...cfg import CFG
@@ -248,7 +248,7 @@
def load_weights(
classes: Sequence[Union[int, str]] = None, return_fmt: Literal["np", "pd"] = "np"
-) -> Union[np.ndarray, pd.DataFrame]:
+) -> Union[NDArray, pd.DataFrame]:
"""Load the weight matrix of the `classes`.
Parameters
@@ -342,7 +342,7 @@ def get_class_count(
exclude_classes: Optional[Sequence[str]] = None,
scored_only: bool = False,
normalize: bool = True,
- threshold: Optional[Real] = 0,
+ threshold: Union[float, int] = 0,
fmt: str = "a",
) -> Dict[str, int]:
"""Get the number of classes in the `tranches`.
@@ -359,14 +359,15 @@ def get_class_count(
normalize : bool, default True
whether collapse equivalent classes into one or not,
used only when `scored_only` is True.
- threshold : numbers.Real, default 0
+ threshold : float or int, default 0
Minimum ratio (0-1) or absolute number (>1) of a class to be counted.
fmt : str, default "a"
Format of the names of the classes in the returned dict,
can be one of the following (case insensitive):
- - "a", abbreviations
- - "f", full names
- - "s", SNOMED CT Code
+
+ - "a", abbreviations
+ - "f", full names
+ - "s", SNOMED CT Code
Returns
-------
@@ -464,8 +465,8 @@ def get_class_weight(
Returns:
--------
class_weight : dict
- - key: class in the format of `fmt`
- - value: weight of a class in `tranches`
+ key: class in the format of `fmt`,
+ value: weight of a class in `tranches`.
"""
class_count = get_class_count(
diff --git a/torch_ecg/databases/aux_data/cinc2021_aux_data.py b/torch_ecg/databases/aux_data/cinc2021_aux_data.py
index 859aa61d..f291ca9e 100644
--- a/torch_ecg/databases/aux_data/cinc2021_aux_data.py
+++ b/torch_ecg/databases/aux_data/cinc2021_aux_data.py
@@ -8,8 +8,8 @@
from numbers import Real
from typing import Dict, List, Literal, Optional, Sequence, Union
-import numpy as np
import pandas as pd
+from numpy.typing import NDArray
from ...cfg import _DATA_CACHE, CFG
@@ -336,7 +336,7 @@ def load_weights(
classes: Sequence[Union[int, str]] = None,
equivalent_classes: Optional[Union[Dict[str, str], List[List[str]]]] = None,
return_fmt: Literal["np", "pd"] = "np",
-) -> Union[np.ndarray, pd.DataFrame]:
+) -> Union[NDArray, pd.DataFrame]:
"""Load the weight matrix of the `classes`.
Parameters
@@ -430,7 +430,7 @@ def get_class_count(
exclude_classes: Optional[Sequence[str]] = None,
scored_only: bool = False,
normalize: bool = True,
- threshold: Optional[Real] = 0,
+ threshold: Union[float, int] = 0,
fmt: str = "a",
) -> Dict[str, int]:
"""Get the number of classes in the `tranches`.
@@ -447,20 +447,21 @@ def get_class_count(
normalize : bool, default True
Whether collapse equivalent classes into one or not,
used only when `scored_only` is True.
- threshold : numbers.Real
+ threshold : int or float, default 0
Minimum ratio (0-1) or absolute number (>1) of a class to be counted.
fmt : str, default "a"
Format of the names of the classes in the returned dict,
can be one of the following (case insensitive):
- - "a", abbreviations
- - "f", full names
- - "s", SNOMEDCTCode
+
+ - "a", abbreviations
+ - "f", full names
+ - "s", SNOMEDCTCode
Returns
-------
class_count : dict
- - key: class in the format of `fmt`
- - value: count of a class in `tranches`
+ key: class in the format of `fmt`,
+ value: count of a class in `tranches`.
"""
assert threshold >= 0
@@ -543,9 +544,9 @@ def get_class_weight(
fmt : str, default "a"
Format of the names of the classes in the returned dict,
can be one of the following (case insensitive):
- - "a", abbreviations
- - "f", full names
- - "s", SNOMED CT Code
+ - "a", abbreviations
+ - "f", full names
+ - "s", SNOMED CT Code
min_weight : numbers.Real, default 0.5
Minimum value of the weight of all classes,
or equivalently the weight of the largest class.
@@ -553,8 +554,8 @@ def get_class_weight(
Returns:
--------
class_weight : dict
- - key: class in the format of `fmt`
- - value: weight of a class in `tranches`
+ key: class in the format of `fmt`,
+ value: weight of a class in `tranches`.
"""
class_count = get_class_count(
diff --git a/torch_ecg/databases/base.py b/torch_ecg/databases/base.py
index 7b0cf887..34a51afa 100644
--- a/torch_ecg/databases/base.py
+++ b/torch_ecg/databases/base.py
@@ -2,10 +2,10 @@
"""
Base classes for datasets from different sources:
- - PhysioNet
- - NSRR
- - CPSC
- - Other databases
+- PhysioNet
+- NSRR
+- CPSC
+- Other databases
Remarks
-------
@@ -36,6 +36,7 @@
import requests
import scipy.signal as SS
import wfdb
+from numpy.typing import NDArray
from pyedflib import EdfReader
from ..cfg import _DATA_CACHE, CFG, DEFAULTS
@@ -135,11 +136,13 @@ class _DataBase(ReprMixin, ABC):
"""Universal abstract base class for all databases.
Abstract methods that should be implemented by the subclass:
+
- `_ls_rec`: Find all records in the database.
- `load_data`: Load data from a record.
- `load_ann`: Load annotations of a record.
Abstract properties that should be implemented by the subclass:
+
- `database_info`: The :class:`DataBaseInfo` object of the database.
- `url`: URL(s) for downloading the database.
@@ -172,7 +175,7 @@ def __init__(
f"`db_dir` is not specified, " f"using default `{db_dir}` as the storage path",
RuntimeWarning,
)
- self.db_dir = Path(db_dir).expanduser().resolve()
+ self.db_dir = Path(db_dir).expanduser().resolve() # type: ignore
if not self.db_dir.exists():
self.db_dir.mkdir(parents=True, exist_ok=True)
warnings.warn(
@@ -182,7 +185,7 @@ def __init__(
"please use the `download()` method.",
RuntimeWarning,
)
- self.working_dir = Path(working_dir or DEFAULTS.working_dir).expanduser().resolve().absolute() / self.db_name
+ self.working_dir = Path(working_dir or DEFAULTS.working_dir).expanduser().resolve().absolute() / self.db_name # type: ignore
self.working_dir.mkdir(parents=True, exist_ok=True)
self.logger = kwargs.get("logger", None)
@@ -256,7 +259,7 @@ def get_citation(self, format: Optional[str] = None, style: Optional[str] = None
"""
self.database_info.get_citation(lookup=True, format=format, style=style, timeout=10.0, print_result=True)
- def _auto_infer_units(self, sig: np.ndarray, sig_type: str = "ECG") -> str:
+ def _auto_infer_units(self, sig: NDArray, sig_type: str = "ECG") -> str:
"""Automatically infer the units of the signal.
It is assumed that `sig` is not raw signal, but with baseline removed.
@@ -286,7 +289,7 @@ def _auto_infer_units(self, sig: np.ndarray, sig_type: str = "ECG") -> str:
return units
@property
- def all_records(self) -> List[str]:
+ def all_records(self) -> Union[List[str], None]:
if self._all_records is None:
self._ls_rec()
return self._all_records
@@ -312,6 +315,7 @@ def get_absolute_path(self, rec: Union[str, int], extension: Optional[str] = Non
path = self._df_records.loc[rec].path
if extension is not None:
path = path.with_suffix(extension if extension.startswith(".") else f".{extension}")
+ path = Path(path).expanduser().resolve() # type: ignore
return path
def _normalize_leads(
@@ -347,24 +351,24 @@ def _normalize_leads(
all_leads = self.all_leads
err_msg = (
f"`leads` should be a subset of {all_leads} or non-negative integers "
- f"less than {len(all_leads)}, but got {leads}"
+ f"less than {len(all_leads)}, but got {leads}" # type: ignore
)
if leads is None or (isinstance(leads, str) and leads.lower() == "all"):
_leads = all_leads
elif isinstance(leads, str):
_leads = [leads]
elif isinstance(leads, int):
- assert len(all_leads) > leads >= 0, err_msg
- _leads = [all_leads[leads]]
+ assert len(all_leads) > leads >= 0, err_msg # type: ignore
+ _leads = [all_leads[leads]] # type: ignore
else:
try:
- _leads = [ld if isinstance(ld, str) else all_leads[ld] for ld in leads]
+ _leads = [ld if isinstance(ld, str) else all_leads[ld] for ld in leads] # type: ignore
except Exception:
raise AssertionError(err_msg)
- assert set(_leads).issubset(all_leads), err_msg
+ assert set(_leads).issubset(all_leads), err_msg # type: ignore
if numeric:
- _leads = [all_leads.index(ld) for ld in _leads]
- return _leads
+ _leads = [all_leads.index(ld) for ld in _leads] # type: ignore
+ return _leads # type: ignore
@classmethod
def get_arrhythmia_knowledge(cls, arrhythmias: Union[str, List[str]]) -> None:
@@ -416,7 +420,9 @@ def __len__(self) -> int:
Number of records in the database.
"""
- return len(self.all_records)
+ if self._all_records is None:
+ return 0
+ return len(self.all_records) # type: ignore
def __getitem__(self, index: int) -> str:
"""Get the record name by index.
@@ -434,7 +440,7 @@ def __getitem__(self, index: int) -> str:
Record name.
"""
- return self.all_records[index]
+ return self.all_records[index] # type: ignore
class PhysioNetDataBase(_DataBase):
@@ -490,13 +496,13 @@ def __init__(
self.df_all_db_info = get_physionet_dbs(local=(self.verbose <= 2))
if self.db_name not in self.df_all_db_info["db_name"].values:
if self.verbose <= 2:
- self.logger.warning(
+ self.logger.warning( # type: ignore
f"Database `{self.db_name}` is not found in the local database list. "
"Please check if the database name is correct, "
"or call method `_update_db_list()` to update the database list."
)
else:
- self.logger.warning(
+ self.logger.warning( # type: ignore
f"Database `{self.db_name}` is not found in the database list on PhysioNet server. "
"Please check if the database name is correct."
)
@@ -552,7 +558,7 @@ def _ls_rec(self, db_name: Optional[str] = None, local: bool = True) -> None:
len(self._df_records),
max(1, int(round(self._subsample * len(self._df_records)))),
)
- self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`")
+ self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`") # type: ignore
self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
self._all_records = self._df_records.index.tolist()
except Exception:
@@ -572,7 +578,7 @@ def _ls_rec_local(self) -> None:
len(self._df_records),
max(1, int(round(self._subsample * len(self._df_records)))),
)
- self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`")
+ self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`") # type: ignore
self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
self._df_records["path"] = self._df_records["record"].apply(lambda x: (self.db_dir / x).resolve())
self._df_records = self._df_records[self._df_records["path"].apply(lambda x: x.is_file())]
@@ -581,16 +587,16 @@ def _ls_rec_local(self) -> None:
if len(self._df_records) == 0:
print("Please wait patiently to let the reader find " "all records of the database from local storage...")
start = time.time()
- self._df_records["path"] = get_record_list_recursive(self.db_dir, self.data_ext, relative=False)
+ self._df_records["path"] = get_record_list_recursive(self.db_dir, self.data_ext, relative=False) # type: ignore
if self._subsample is not None:
size = min(
len(self._df_records),
max(1, int(round(self._subsample * len(self._df_records)))),
)
- self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`")
+ self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`") # type: ignore
self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
- self._df_records["path"] = self._df_records["path"].apply(lambda x: Path(x))
- self.logger.info(f"Done in {time.time() - start:.3f} seconds!")
+ self._df_records["path"] = self._df_records["path"].apply(lambda x: Path(x)) # type: ignore
+ self.logger.info(f"Done in {time.time() - start:.3f} seconds!") # type: ignore
self._df_records["record"] = self._df_records["path"].apply(lambda x: x.name)
self._df_records.set_index("record", inplace=True)
self._all_records = self._df_records.index.values.tolist()
@@ -611,17 +617,17 @@ def get_subject_id(self, rec: Union[str, int]) -> int:
"""
raise NotImplementedError
- def load_data(
+ def load_data( # type: ignore
self,
rec: Union[str, int],
leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
data_format: str = "channel_first",
- units: Union[str, type(None)] = "mV",
+ units: Union[str, None] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
"""Load physical (converted from digital) ECG data,
which is more understandable for humans;
or load digital signal directly.
@@ -705,9 +711,9 @@ def load_data(
sampto=sampto,
physical=units is not None,
return_res=DEFAULTS.DTYPE.INT,
- channels=[all_leads.index(ld) for ld in _leads],
+ channels=[all_leads.index(ld) for ld in _leads], # type: ignore
) # use `channels` instead of `channel_names` since there're exceptional cases where `channel_names` has duplicates
- wfdb_rec = wfdb.rdrecord(fp, **rdrecord_kwargs)
+ wfdb_rec = wfdb.rdrecord(fp, **rdrecord_kwargs) # type: ignore
# p_signal or d_signal is in the format of "lead_last", and with units in "mV"
if units is None:
@@ -715,7 +721,7 @@ def load_data(
elif units.lower() == "mv":
data = wfdb_rec.p_signal
elif units.lower() in ["μv", "uv", "muv"]:
- data = 1000 * wfdb_rec.p_signal
+ data = 1000 * wfdb_rec.p_signal # type: ignore
if fs is not None:
data_fs = fs
@@ -724,18 +730,18 @@ def load_data(
else:
data_fs = wfdb_rec.fs
if data_fs != wfdb_rec.fs:
- data = SS.resample_poly(data, data_fs, wfdb_rec.fs, axis=0).astype(data.dtype)
+ data = SS.resample_poly(data, data_fs, wfdb_rec.fs, axis=0).astype(data.dtype) # type: ignore
if data_format.lower() in ["channel_first", "lead_first"]:
- data = data.T
+ data = data.T # type: ignore
elif data_format.lower() in ["flat", "plain"]:
- data = data.flatten()
+ data = data.flatten() # type: ignore
if return_fs:
- return data, data_fs
- return data
+ return data, data_fs # type: ignore
+ return data # type: ignore
- def helper(self, items: Union[List[str], str, type(None)] = None) -> None:
+ def helper(self, items: Union[List[str], str, None] = None) -> None:
"""Print corr. meanings of symbols belonging to `items`.
More details can be found
@@ -837,11 +843,11 @@ def get_file_download_url(self, file_name: Union[str, bytes, os.PathLike]) -> st
URL of the file to be downloaded.
"""
- url = posixpath.join(
+ url = posixpath.join( # type: ignore
wfdb.io.download.PN_INDEX_URL,
self.db_name,
self.version,
- file_name,
+ file_name, # type: ignore
)
return url
@@ -879,7 +885,7 @@ def s3_url(self) -> str:
return f"s3://physionet-open/{self.db_name}/{self.version}/"
@property
- def url_(self) -> Union[str, type(None)]:
+ def url_(self) -> Union[str, None]:
"""URL of the compressed database file for downloading."""
if self._url_compressed is not None:
return self._url_compressed
@@ -888,7 +894,7 @@ def url_(self) -> Union[str, type(None)]:
try:
db_desc = self.df_all_db_info[self.df_all_db_info["db_name"] == self.db_name].iloc[0]["db_description"]
except IndexError:
- self.logger.info(f"\042{self.db_name}\042 is not in the database list hosted at PhysioNet!")
+ self.logger.info(f"\042{self.db_name}\042 is not in the database list hosted at PhysioNet!") # type: ignore
return None
db_desc = re.sub(f"[{punct}]+", "", db_desc).lower()
db_desc = re.sub("[\\s:]+", "-", db_desc)
@@ -924,7 +930,7 @@ def download(self, compressed: bool = True, use_s3: bool = True) -> None:
"""
if shutil.which("aws") is None:
use_s3 = False
- self.logger.warning("AWS CLI is not available! Downloading the database from PhysioNet...")
+ self.logger.warning("AWS CLI is not available! Downloading the database from PhysioNet...") # type: ignore
if use_s3:
http_get(self.s3_url, self.db_dir)
elif compressed:
@@ -933,7 +939,7 @@ def download(self, compressed: bool = True, use_s3: bool = True) -> None:
self._ls_rec()
return
else:
- self.logger.info("No compressed database available! Downloading the uncompressed version...")
+ self.logger.info("No compressed database available! Downloading the uncompressed version...") # type: ignore
else:
wfdb.dl_database(
self.db_name,
@@ -1085,7 +1091,7 @@ def show_rec_stats(self, rec: Union[str, int]) -> None:
"""
raise NotImplementedError
- def helper(self, items: Union[List[str], str, type(None)] = None) -> None:
+ def helper(self, items: Union[List[str], str, None] = None) -> None:
"""Print corr. meanings of symbols belonging to `items`.
Parameters
@@ -1181,7 +1187,7 @@ def get_subject_id(self, rec: Union[str, int]) -> int:
"""
raise NotImplementedError
- def helper(self, items: Union[List[str], str, type(None)] = None) -> None:
+ def helper(self, items: Union[List[str], str, None] = None) -> None:
"""Print corr. meanings of symbols belonging to `items`.
Parameters
@@ -1331,9 +1337,9 @@ def format_database_docstring(self, indent: Optional[str] = None) -> str:
docstring = f"{self.status}\n\n{docstring}"
lookup = os.getenv("DB_BIB_LOOKUP", False)
- citation = self.get_citation(lookup=lookup, print_result=False)
- if citation.startswith("@"):
- citation = textwrap.indent(citation, indent)
+ citation = self.get_citation(lookup=lookup, print_result=False) # type: ignore
+ if citation.startswith("@"): # type: ignore
+ citation = textwrap.indent(citation, indent) # type: ignore
citation = textwrap.indent(f"""Citation\n--------\n.. code-block:: bibtex\n\n{citation}""", indent)
docstring = f"{docstring}\n\n{citation}\n"
elif not lookup:
@@ -1359,7 +1365,7 @@ def sleep_stage_intervals_to_mask(
fs: Optional[int] = None,
granularity: int = 30,
class_map: Optional[Dict[str, int]] = None,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Convert sleep stage intervals to sleep stage mask.
Parameters
@@ -1391,7 +1397,7 @@ def sleep_stage_intervals_to_mask(
assert class_map is not None, "`class_map` must be provided"
else:
class_map = class_map or {k: len(self.sleep_stage_names) - i - 1 for i, k in enumerate(self.sleep_stage_names)}
- intervals = {
+ intervals = { # type: ignore
class_map[k]: [[int(round(s / fs / granularity)), int(round(e / fs / granularity))] for s, e in v]
for k, v in intervals.items()
}
@@ -1406,7 +1412,7 @@ def sleep_stage_intervals_to_mask(
def plot_hypnogram(
self,
- mask: np.ndarray,
+ mask: NDArray,
granularity: int = 30,
class_map: Optional[Dict[str, int]] = None,
**kwargs,
@@ -1435,14 +1441,13 @@ def plot_hypnogram(
Axes object.
"""
+ import matplotlib.pyplot as plt
+
if not hasattr(self, "sleep_stage_names"):
pass
else:
class_map = class_map or {k: len(self.sleep_stage_names) - i - 1 for i, k in enumerate(self.sleep_stage_names)}
- if "plt" not in globals():
- import matplotlib.pyplot as plt
-
fig_width = len(mask) * granularity / 3600 / 6 * 20 # stardard width is 20 for 6 hours
fig, ax = plt.subplots(figsize=(fig_width, 4))
@@ -1457,8 +1462,8 @@ def plot_hypnogram(
ax.set_xlabel("Time", fontsize=18)
ax.set_xlim(0, len(mask))
# yticks to the format of sleep stages
- yticks = sorted(class_map.values())
- yticklabels = [k for k, v in sorted(class_map.items(), key=lambda x: x[1])]
+ yticks = sorted(class_map.values()) # type: ignore
+ yticklabels = [k for k, v in sorted(class_map.items(), key=lambda x: x[1])] # type: ignore
ax.set_yticks(yticks)
ax.set_yticklabels(yticklabels, fontsize=14)
ax.set_ylabel("Sleep Stage", fontsize=18)
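The `np.ndarray` -> `NDArray` substitutions above rely on the `numpy.typing` aliases; a small standalone sketch (independent of torch_ecg) of how such annotations read, optionally parameterized by dtype:

import numpy as np
from numpy.typing import NDArray

def remove_baseline(sig: NDArray[np.float64]) -> NDArray[np.float64]:
    """Subtract the per-lead mean; the NDArray annotation also documents the dtype."""
    return sig - sig.mean(axis=-1, keepdims=True)

ecg: NDArray[np.float64] = np.random.randn(12, 5000)
print(remove_baseline(ecg).shape)  # (12, 5000)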
diff --git a/torch_ecg/databases/cpsc_databases/cpsc2018.py b/torch_ecg/databases/cpsc_databases/cpsc2018.py
index a1d92e96..4ce15724 100644
--- a/torch_ecg/databases/cpsc_databases/cpsc2018.py
+++ b/torch_ecg/databases/cpsc_databases/cpsc2018.py
@@ -9,6 +9,7 @@
import numpy as np
import pandas as pd
+from numpy.typing import NDArray
from scipy.io import loadmat
from ...cfg import DEFAULTS
@@ -264,7 +265,7 @@ def load_data(
data_format="channel_first",
units: str = "mV",
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
"""Load the ECG data of a record.
Parameters
@@ -332,10 +333,9 @@ def load_ann(self, rec: Union[str, int], ann_format: str = "n") -> List[str]:
Record name or index of the record in :attr:`all_records`.
ann_format : str, default "n"
Format of labels, one of the following (case insensitive):
-
- - "a", abbreviations
- - "f", full names
- - "n", numeric codes
+ - "a", abbreviations
+ - "f", full names
+ - "n", numeric codes
Returns
-------
diff --git a/torch_ecg/databases/cpsc_databases/cpsc2019.py b/torch_ecg/databases/cpsc_databases/cpsc2019.py
index a49a5b63..fc229b76 100644
--- a/torch_ecg/databases/cpsc_databases/cpsc2019.py
+++ b/torch_ecg/databases/cpsc_databases/cpsc2019.py
@@ -9,6 +9,7 @@
import numpy as np
import pandas as pd
import scipy.signal as SS
+from numpy.typing import NDArray
from scipy.io import loadmat
from ...cfg import DEFAULTS
@@ -222,7 +223,7 @@ def load_data(
units: str = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
"""Load the ECG data of the record `rec`.
Parameters
@@ -273,7 +274,7 @@ def load_data(
return data, data_fs
return data
- def load_ann(self, rec: Union[int, str]) -> np.ndarray:
+ def load_ann(self, rec: Union[int, str]) -> NDArray:
"""Load the annotations (indices of R peaks) of the record `rec`.
Parameters
@@ -292,14 +293,14 @@ def load_ann(self, rec: Union[int, str]) -> np.ndarray:
return ann
@add_docstring(load_ann.__doc__)
- def load_rpeaks(self, rec: Union[int, str]) -> np.ndarray:
+ def load_rpeaks(self, rec: Union[int, str]) -> NDArray:
"""
alias of `self.load_ann`
"""
return self.load_ann(rec=rec)
@add_docstring(load_rpeaks.__doc__)
- def load_rpeak_indices(self, rec: Union[int, str]) -> np.ndarray:
+ def load_rpeak_indices(self, rec: Union[int, str]) -> NDArray:
"""
alias of `self.load_rpeaks`
"""
@@ -308,8 +309,8 @@ def load_rpeak_indices(self, rec: Union[int, str]) -> np.ndarray:
def plot(
self,
rec: Union[int, str],
- data: Optional[np.ndarray] = None,
- ann: Optional[np.ndarray] = None,
+ data: Optional[NDArray] = None,
+ ann: Optional[NDArray] = None,
ticks_granularity: int = 0,
**kwargs: Any,
) -> None:
@@ -397,8 +398,8 @@ def webpage(self) -> str:
def compute_metrics(
- rpeaks_truths: Sequence[Union[np.ndarray, Sequence[int]]],
- rpeaks_preds: Sequence[Union[np.ndarray, Sequence[int]]],
+ rpeaks_truths: Sequence[Union[NDArray, Sequence[int]]],
+ rpeaks_preds: Sequence[Union[NDArray, Sequence[int]]],
fs: Real,
thr: float = 0.075,
verbose: int = 0,
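A rough illustration of the +/- `thr`-second matching that the `thr=0.075` and `fs` parameters of `compute_metrics` suggest; the tolerance semantics here are an assumption, not the package's actual scoring code:

import numpy as np

def rpeaks_within_tolerance(rpeaks_true, rpeaks_pred, fs, thr=0.075):
    """Count predicted R peaks falling within +/- thr seconds of some true peak.

    Illustrative only; the real challenge metric applies further rules.
    """
    tol = int(round(thr * fs))
    rpeaks_true = np.asarray(rpeaks_true)
    return int(sum(np.any(np.abs(rpeaks_true - p) <= tol) for p in np.asarray(rpeaks_pred)))

print(rpeaks_within_tolerance([500, 1000, 1500], [505, 990, 1800], fs=500))  # 2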
diff --git a/torch_ecg/databases/cpsc_databases/cpsc2020.py b/torch_ecg/databases/cpsc_databases/cpsc2020.py
index 2139d699..f36f029f 100644
--- a/torch_ecg/databases/cpsc_databases/cpsc2020.py
+++ b/torch_ecg/databases/cpsc_databases/cpsc2020.py
@@ -9,6 +9,7 @@
import numpy as np
import pandas as pd
import scipy.signal as SS
+from numpy.typing import NDArray
from scipy.io import loadmat
from ...cfg import CFG, DEFAULTS
@@ -380,7 +381,7 @@ def load_data(
units: str = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
"""Load the ECG data of the record `rec`.
Parameters
@@ -444,7 +445,7 @@ def load_ann(
rec: Union[int, str],
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
- ) -> Dict[str, np.ndarray]:
+ ) -> Dict[str, NDArray]:
"""Load the annotations of the record `rec`.
Parameters
@@ -587,12 +588,12 @@ def locate_premature_beats(
def plot(
self,
rec: Union[int, str],
- data: Optional[np.ndarray] = None,
- ann: Optional[Dict[str, np.ndarray]] = None,
+ data: Optional[NDArray] = None,
+ ann: Optional[Dict[str, NDArray]] = None,
ticks_granularity: int = 0,
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
- rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None,
+ rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None,
) -> None:
"""Plot the ECG signal of a record.
@@ -750,10 +751,10 @@ def webpage(self) -> str:
def compute_metrics(
- sbp_true: List[np.ndarray],
- pvc_true: List[np.ndarray],
- sbp_pred: List[np.ndarray],
- pvc_pred: List[np.ndarray],
+ sbp_true: List[NDArray],
+ pvc_true: List[NDArray],
+ sbp_pred: List[NDArray],
+ pvc_pred: List[NDArray],
verbose: int = 0,
) -> Union[Tuple[int], dict]:
"""Score Function for all (test) records.
@@ -771,10 +772,10 @@ def compute_metrics(
Tuple of (negative) scores for each ectopic beat type (SBP, PVC),
or dict of more scoring details, including
- - total_loss: sum of loss of each ectopic beat type (PVC and SPB)
- - true_positive: number of true positives of each ectopic beat type
- - false_positive: number of false positives of each ectopic beat type
- - false_negative: number of false negatives of each ectopic beat type
+ - total_loss: sum of loss of each ectopic beat type (PVC and SPB)
+ - true_positive: number of true positives of each ectopic beat type
+ - false_positive: number of false positives of each ectopic beat type
+ - false_negative: number of false negatives of each ectopic beat type
"""
BaseCfg = CFG()
diff --git a/torch_ecg/databases/cpsc_databases/cpsc2021.py b/torch_ecg/databases/cpsc_databases/cpsc2021.py
index 76aa9d72..990c4f59 100644
--- a/torch_ecg/databases/cpsc_databases/cpsc2021.py
+++ b/torch_ecg/databases/cpsc_databases/cpsc2021.py
@@ -13,6 +13,7 @@
import pandas as pd
import scipy.io as sio
import wfdb
+from numpy.typing import NDArray
from ...cfg import CFG, DEFAULTS
from ...utils.misc import add_docstring, get_record_list_recursive3, ms2samples
@@ -40,50 +41,50 @@
7. classification of a record is stored in corresponding .hea file, which can be accessed via the attribute `comments` of a wfdb Record obtained using :func:`wfdb.rdheader`, :func:`wfdb.rdrecord`, and :func:`wfdb.rdsamp`; beat annotations and rhythm annotations can be accessed using the attributes `symbol`, `aux_note` of a ``wfdb`` Annotation obtained using :func:`wfdb.rdann`, corresponding indices in the signal can be accessed via the attribute `sample`
8. challenge task:
- - clasification of rhythm types: non-AF rhythm (N), persistent AF rhythm (AFf) and paroxysmal AF rhythm (AFp)
- - locating of the onset and offset for any AF episode prediction
+        - classification of rhythm types: non-AF rhythm (N), persistent AF rhythm (AFf) and paroxysmal AF rhythm (AFp)
+ - locating of the onset and offset for any AF episode prediction
9. challenge metrics:
- - metrics (Ur, scoring matrix) for classification:
+ - metrics (Ur, scoring matrix) for classification:
- .. tikz:: The scoring matrix for the recording-level classification result.
- :align: center
- :libs: positioning
+ .. tikz:: The scoring matrix for the recording-level classification result.
+ :align: center
+ :libs: positioning
- \tikzstyle{rect} = [rectangle, text width = 50, text centered, inner sep = 3pt, minimum height = 50]
- \tikzstyle{txt} = [rectangle, text centered, inner sep = 3pt, minimum height = 1.5]
+ \tikzstyle{rect} = [rectangle, text width = 50, text centered, inner sep = 3pt, minimum height = 50]
+ \tikzstyle{txt} = [rectangle, text centered, inner sep = 3pt, minimum height = 1.5]
- \node[rect, fill = green!25] at (0,0) (31) {$-0.5$};
- \node[rect, fill = green!10, right = 0 of 31] (32) {$0$};
- \node[rect, fill = red!30, right = 0 of 32] (33) {$+1$};
- \node[rect, fill = green!40, above = 0 of 31] (21) {$-1$};
- \node[rect, fill = red!30, above = 0 of 32] (22) {$+1$};
- \node[rect, fill = green!10, above = 0 of 33] (23) {$0$};
- \node[rect, fill = red!30, above = 0 of 21] (11) {$+1$};
- \node[rect, fill = green!60, above = 0 of 22] (12) {$-2$};
- \node[rect, fill = green!40, above = 0 of 23] (13) {$-1$};
+ \node[rect, fill = green!25] at (0,0) (31) {$-0.5$};
+ \node[rect, fill = green!10, right = 0 of 31] (32) {$0$};
+ \node[rect, fill = red!30, right = 0 of 32] (33) {$+1$};
+ \node[rect, fill = green!40, above = 0 of 31] (21) {$-1$};
+ \node[rect, fill = red!30, above = 0 of 32] (22) {$+1$};
+ \node[rect, fill = green!10, above = 0 of 33] (23) {$0$};
+ \node[rect, fill = red!30, above = 0 of 21] (11) {$+1$};
+ \node[rect, fill = green!60, above = 0 of 22] (12) {$-2$};
+ \node[rect, fill = green!40, above = 0 of 23] (13) {$-1$};
- \node[txt, below = 0 of 31] {N};
- \node[txt, below = 0 of 32] (anchor_h) {AF$_{\text{f}}$};
- \node[txt, below = 0 of 33] {AF$_{\text{p}}$};
- \node[txt, left = 0 of 31] {AF$_{\text{p}}$};
- \node[txt, left = 0 of 21] (anchor_v) {AF$_{\text{f}}$};
- \node[txt, left = 0 of 11] {N};
+ \node[txt, below = 0 of 31] {N};
+ \node[txt, below = 0 of 32] (anchor_h) {AF$_{\text{f}}$};
+ \node[txt, below = 0 of 33] {AF$_{\text{p}}$};
+ \node[txt, left = 0 of 31] {AF$_{\text{p}}$};
+ \node[txt, left = 0 of 21] (anchor_v) {AF$_{\text{f}}$};
+ \node[txt, left = 0 of 11] {N};
- \node[txt, below = 0 of anchor_h] {\large\textbf{Annotation (Label)}};
- \node[txt, left = 0.6 of anchor_v, rotate = 90, anchor = north] {\large\textbf{Prediction}};
+ \node[txt, below = 0 of anchor_h] {\large\textbf{Annotation (Label)}};
+ \node[txt, left = 0.6 of anchor_v, rotate = 90, anchor = north] {\large\textbf{Prediction}};
- - metric (Ue) for detecting onsets and offsets for AF events (episodes): +1 if the detected onset (or offset) is within ±1 beat of the annotated position, and +0.5 if within ±2 beats.
- - final score (U):
+ - metric (Ue) for detecting onsets and offsets for AF events (episodes): +1 if the detected onset (or offset) is within ±1 beat of the annotated position, and +0.5 if within ±2 beats.
+ - final score (U):
- .. math::
+ .. math::
U = \dfrac{1}{N} \sum\limits_{i=1}^N \left( Ur_i + \dfrac{Ma_i}{\max\{Mr_i, Ma_i\}} \right)
- where :math:`N` is the number of records,
- :math:`Ma` is the number of annotated AF episodes,
- :math:`Mr` is the number of predicted AF episodes.
+ where :math:`N` is the number of records,
+ :math:`Ma` is the number of annotated AF episodes,
+ :math:`Mr` is the number of predicted AF episodes.
10. Challenge official website [1]_. Webpage of the database on PhysioNet [2]_.
""",
@@ -405,7 +406,7 @@ def load_ann(
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
**kwargs: Any,
- ) -> Union[dict, np.ndarray, List[List[int]], str]:
+ ) -> Union[dict, NDArray, List[List[int]], str]:
"""Load annotations of the record.
Parameters
@@ -424,12 +425,9 @@ def load_ann(
Key word arguments for functions
loading rpeaks, af_episodes, and label respectively,
including:
-
- - fs: int, optional,
- the resampling frequency
- - fmt: str,
- format of af_episodes, or format of label,
- for more details, ref. corresponding functions.
+ - fs: int, optional, the resampling frequency
+ - fmt: str, format of af_episodes, or format of label,
+ for more details, ref. corresponding functions.
Used only when `field` is specified (not None).
@@ -479,7 +477,7 @@ def load_rpeaks(
keep_original: bool = False,
valid_only: bool = True,
fs: Optional[Real] = None,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load position (in terms of samples) of rpeaks.
Parameters
@@ -542,7 +540,7 @@ def load_rpeak_indices(
keep_original: bool = False,
valid_only: bool = True,
fs: Optional[Real] = None,
- ) -> np.ndarray:
+ ) -> NDArray:
"""alias of `self.load_rpeaks`"""
return self.load_rpeaks(
rec=rec,
@@ -563,7 +561,7 @@ def load_af_episodes(
keep_original: bool = False,
fs: Optional[Real] = None,
fmt: Literal["intervals", "mask", "c_intervals"] = "intervals",
- ) -> Union[List[List[int]], np.ndarray]:
+ ) -> Union[List[List[int]], NDArray]:
"""Load the episodes of atrial fibrillation,
in terms of intervals or mask.
@@ -666,9 +664,9 @@ def load_label(
The three classes are:
- - "non atrial fibrillation",
- - "paroxysmal atrial fibrillation",
- - "persistent atrial fibrillation".
+ - "non atrial fibrillation",
+ - "paroxysmal atrial fibrillation",
+ - "persistent atrial fibrillation".
Parameters
----------
@@ -682,11 +680,10 @@ def load_label(
Not used, to keep in accordance with other methods.
fmt : str, default "a"
Format of the label, case in-sensitive, can be one of
-
- - "f", "fullname": the full name of the label
- - "a", "abbr", "abbrevation": abbreviation for the label
- - "n", "num", "number": class number of the label
- (in accordance with the settings of the offical class map)
+ - "f", "fullname": the full name of the label
+            - "a", "abbr", "abbreviation": abbreviation for the label
+            - "n", "num", "number": class number of the label
+              (in accordance with the settings of the official class map)
Returns
-------
@@ -709,7 +706,7 @@ def gen_endpoint_score_mask(
rec: Union[str, int],
bias: dict = {1: 1, 2: 0.5},
verbose: Optional[int] = None,
- ) -> Tuple[np.ndarray, np.ndarray]:
+ ) -> Tuple[NDArray, NDArray]:
"""Generate the scoring mask for the onsets and offsets of af episodes.
Parameters
@@ -749,8 +746,8 @@ def gen_endpoint_score_mask(
def plot(
self,
rec: Union[str, int],
- data: Optional[np.ndarray] = None,
- ann: Optional[Dict[str, np.ndarray]] = None,
+ data: Optional[NDArray] = None,
+ ann: Optional[Dict[str, NDArray]] = None,
ticks_granularity: int = 0,
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
@@ -1304,7 +1301,7 @@ def gen_endpoint_score_mask(
af_intervals: Sequence[Sequence[int]],
bias: dict = {1: 1, 2: 0.5},
verbose: int = 0,
-) -> Tuple[np.ndarray, np.ndarray]:
+) -> Tuple[NDArray, NDArray]:
"""
generate the scoring mask for the onsets and offsets of af episodes,
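A small sketch of the final score :math:`U` defined in the docstring above, evaluated on made-up per-record values (`ur`, `ma`, `mr` are illustrative names, not part of the package):

# U = (1/N) * sum_i ( Ur_i + Ma_i / max(Mr_i, Ma_i) )
def challenge_score(ur, ma, mr):
    """ur: per-record classification scores; ma / mr: annotated / predicted AF episode counts."""
    total = 0.0
    for u, a, r in zip(ur, ma, mr):
        episode_term = a / max(a, r) if max(a, r) > 0 else 0.0  # guard the 0/0 case
        total += u + episode_term
    return total / len(ur)

print(challenge_score(ur=[1.0, -0.5, 1.0], ma=[2, 0, 3], mr=[2, 1, 4]))  # ~1.083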
diff --git a/torch_ecg/databases/datasets/cinc2020/cinc2020_dataset.py b/torch_ecg/databases/datasets/cinc2020/cinc2020_dataset.py
index cf137e02..c9061819 100644
--- a/torch_ecg/databases/datasets/cinc2020/cinc2020_dataset.py
+++ b/torch_ecg/databases/datasets/cinc2020/cinc2020_dataset.py
@@ -7,6 +7,7 @@
import numpy as np
import torch
+from numpy.typing import NDArray
from torch.utils.data.dataset import Dataset
from tqdm.auto import tqdm
@@ -125,7 +126,7 @@ def _load_all_data(self) -> None:
self._signals = np.concatenate(self._signals, axis=0).astype(self.dtype)
self._labels = np.concatenate(self._labels, axis=0)
- def _load_one_record(self, rec: str) -> Tuple[np.ndarray, np.ndarray]:
+ def _load_one_record(self, rec: str) -> Tuple[NDArray, NDArray]:
"""Load a record from the database using database reader.
NOTE
@@ -165,20 +166,20 @@ def _load_one_record(self, rec: str) -> Tuple[np.ndarray, np.ndarray]:
return values, labels
@property
- def signals(self) -> np.ndarray:
+ def signals(self) -> NDArray:
"""Cached signals, only available when `lazy=False`
or preloading is performed manually.
"""
return self._signals
@property
- def labels(self) -> np.ndarray:
+ def labels(self) -> NDArray:
"""Cached labels, only available when `lazy=False`
or preloading is performed manually.
"""
return self._labels
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]:
return self.signals[index], self.labels[index]
def __len__(self) -> int:
@@ -396,7 +397,7 @@ def __init__(
def __len__(self) -> int:
return len(self.records)
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]:
if isinstance(index, slice):
return collate_fn([self[i] for i in range(*index.indices(len(self)))])
rec = self.records[index]
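Several of the `__getitem__` implementations in these dataset diffs route slice indices through `collate_fn`; a standalone toy sketch of the same pattern (the dataset and `collate` helper below are illustrative, not torch_ecg code):

import numpy as np

def collate(batch):
    # stack per-sample (signal, label) pairs into batched arrays
    signals, labels = zip(*batch)
    return np.stack(signals), np.stack(labels)

class ToyDataset:
    def __init__(self, n=10):
        self.items = [(np.random.randn(12, 400), np.array([i % 2])) for i in range(n)]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # expand the slice into explicit indices, then batch the items
            return collate([self[i] for i in range(*index.indices(len(self)))])
        return self.items[index]

sigs, labs = ToyDataset()[2:5]
print(sigs.shape, labs.shape)  # (3, 12, 400) (3, 1)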
diff --git a/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py b/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py
index 9f3a867b..2c9cab4e 100644
--- a/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py
+++ b/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py
@@ -8,6 +8,7 @@
import numpy as np
import torch
+from numpy.typing import NDArray
from torch.utils.data.dataset import Dataset
from tqdm.auto import tqdm
@@ -141,7 +142,7 @@ def _load_all_data(self) -> None:
self._signals = np.concatenate(self._signals, axis=0).astype(self.dtype)
self._labels = np.concatenate(self._labels, axis=0)
- def _load_one_record(self, rec: str) -> Tuple[np.ndarray, np.ndarray]:
+ def _load_one_record(self, rec: str) -> Tuple[NDArray, NDArray]:
"""Load one record from the database using data reader.
NOTE
@@ -277,20 +278,20 @@ def reload_from_extern(self, ext_ds: "CINC2021Dataset") -> None:
self._labels = ext_ds._labels.copy()
@property
- def signals(self) -> np.ndarray:
+ def signals(self) -> NDArray:
"""Cached signals, only available when `lazy=False`
or preloading is performed manually.
"""
return self._signals
@property
- def labels(self) -> np.ndarray:
+ def labels(self) -> NDArray:
"""Cached labels, only available when `lazy=False`
or preloading is performed manually.
"""
return self._labels
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]:
return self.signals[index], self.labels[index]
def __len__(self) -> int:
@@ -534,7 +535,7 @@ def __init__(
def __len__(self) -> int:
return len(self.records)
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]:
if isinstance(index, slice):
return collate_fn([self[i] for i in range(*index.indices(len(self)))])
rec = self.records[index]
diff --git a/torch_ecg/databases/datasets/cpsc2019/cpsc2019_dataset.py b/torch_ecg/databases/datasets/cpsc2019/cpsc2019_dataset.py
index 6b48e943..6ffb248d 100644
--- a/torch_ecg/databases/datasets/cpsc2019/cpsc2019_dataset.py
+++ b/torch_ecg/databases/datasets/cpsc2019/cpsc2019_dataset.py
@@ -6,6 +6,7 @@
from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
+from numpy.typing import NDArray
from torch.utils.data.dataset import Dataset
from tqdm.auto import tqdm
@@ -83,7 +84,7 @@ def __init__(
if not self.lazy:
self._load_all_data()
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]:
if self.lazy:
signal, label = self.fdr[index]
else:
@@ -110,14 +111,14 @@ def _load_all_data(self) -> None:
self._labels = np.array(self._labels)
@property
- def signals(self) -> np.ndarray:
+ def signals(self) -> NDArray:
"""Cached signals, only available when `lazy=False`
or preloading is performed manually.
"""
return self._signals
@property
- def labels(self) -> np.ndarray:
+ def labels(self) -> NDArray:
"""Cached labels, only available when `lazy=False`
or preloading is performed manually.
"""
@@ -205,7 +206,7 @@ def __init__(
def __len__(self) -> int:
return len(self.records)
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]:
if isinstance(index, slice):
return collate_fn([self[i] for i in range(*index.indices(len(self)))])
rec_name = self.records[index]
diff --git a/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py b/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py
index 136c67fc..6051a9d2 100644
--- a/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py
+++ b/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py
@@ -46,6 +46,7 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
+from numpy.typing import NDArray
from scipy import signal as SS
from scipy.io import loadmat, savemat
from torch.utils.data.dataset import Dataset
@@ -78,19 +79,19 @@ class CPSC2021Dataset(ReprMixin, Dataset):
The returned values (tuple) of :meth:`__getitem__` depends on the task:
- 1. "qrs_detection": (`data`, `qrs_mask`, None)
- 2. "rr_lstm": (`rr_seq`, `rr_af_mask`, `rr_weight_mask`)
- 3. "main": (`data`, `af_mask`, `weight_mask`)
+ 1. "qrs_detection": (`data`, `qrs_mask`, None)
+ 2. "rr_lstm": (`rr_seq`, `rr_af_mask`, `rr_weight_mask`)
+ 3. "main": (`data`, `af_mask`, `weight_mask`)
where
- - `data` shape: ``(n_lead, n_sample)``
- - `qrs_mask` shape: ``(n_sample, 1)``
- - `af_mask` shape: ``(n_sample, 1)``
- - `weight_mask` shape: ``(n_sample, 1)``
- - `rr_seq` shape: ``(n_rr, 1)``
- - `rr_af_mask` shape: ``(n_rr, 1)``
- - `rr_weight_mask` shape: ``(n_rr, 1)``
+ - `data` shape: ``(n_lead, n_sample)``
+ - `qrs_mask` shape: ``(n_sample, 1)``
+ - `af_mask` shape: ``(n_sample, 1)``
+ - `weight_mask` shape: ``(n_sample, 1)``
+ - `rr_seq` shape: ``(n_rr, 1)``
+ - `rr_af_mask` shape: ``(n_rr, 1)``
+ - `rr_weight_mask` shape: ``(n_rr, 1)``
Typical values of ``n_sample`` and ``n_rr`` are 6000 and 30, respectively.
@@ -386,7 +387,7 @@ def all_rr_seq(self) -> CFG:
def __len__(self) -> int:
return len(self.fdr)
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, ...]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, ...]:
if self.lazy:
if self.task in ["qrs_detection"]:
return self.fdr[index][:2]
@@ -426,7 +427,7 @@ def _get_seg_ann_path(self, seg: str) -> Path:
fp = self.segments_dirs.ann[subject] / f"{seg}.{self.segment_ext}"
return fp
- def _load_seg_data(self, seg: str) -> np.ndarray:
+ def _load_seg_data(self, seg: str) -> NDArray:
"""Load the data of the segment.
Parameters
@@ -456,19 +457,18 @@ def _load_seg_ann(self, seg: str) -> dict:
-------
seg_ann : dict
Annotations of the segment, containing:
-
- - rpeaks: indices of rpeaks of the segment
- - qrs_mask: mask of qrs complexes of the segment
- - af_mask: mask of af episodes of the segment
- - interval: interval ([start_idx, end_idx]) in
- the original ECG record of the segment
+ - rpeaks: indices of rpeaks of the segment
+ - qrs_mask: mask of qrs complexes of the segment
+ - af_mask: mask of af episodes of the segment
+ - interval: interval ([start_idx, end_idx]) in
+ the original ECG record of the segment
"""
seg_ann_fp = self._get_seg_ann_path(seg)
seg_ann = {k: v.flatten() for k, v in loadmat(str(seg_ann_fp)).items() if not k.startswith("__")}
return seg_ann
- def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+ def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[NDArray, Dict[str, NDArray]]:
"""Load the mask(s) of segment.
Parameters
@@ -510,7 +510,7 @@ def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[np.ndarr
seg_mask = seg_mask["af_mask"]
return seg_mask
- def _load_seg_seq_lab(self, seg: str, reduction: int) -> np.ndarray:
+ def _load_seg_seq_lab(self, seg: str, reduction: int) -> NDArray:
"""Load sequence labeling annotations of the segment.
Parameters
@@ -561,7 +561,7 @@ def _get_rr_seq_path(self, seq_name: str) -> Path:
fp = self.rr_seq_dirs[subject] / f"{seq_name}.{self.rr_seq_ext}"
return fp
- def _load_rr_seq(self, seq_name: str) -> Dict[str, np.ndarray]:
+ def _load_rr_seq(self, seq_name: str) -> Dict[str, NDArray]:
"""Load the metadata of the rr_seq.
Parameters
@@ -573,13 +573,12 @@ def _load_rr_seq(self, seq_name: str) -> Dict[str, np.ndarray]:
-------
dict
metadata of sequence of rr intervals, including
-
- - rr: the sequence of rr intervals, with units in seconds,
- of shape ``(self.seglen, 1)``
- - label: label of the rr intervals, 0 for normal, 1 for af,
- of shape ``(self.seglen, self.n_classes)``
- - interval: interval of the current rr sequence in the whole
- rr sequence in the original record
+ - rr: the sequence of rr intervals, with units in seconds,
+ of shape ``(self.seglen, 1)``
+ - label: label of the rr intervals, 0 for normal, 1 for af,
+ of shape ``(self.seglen, self.n_classes)``
+ - interval: interval of the current rr sequence in the whole
+ rr sequence in the original record
"""
rr_seq_path = self._get_rr_seq_path(seq_name)
@@ -680,7 +679,7 @@ def _preprocess_one_record(self, rec: str, force_recompute: bool = False, verbos
pps, _ = self.ppm(self.reader.load_data(rec), self.config.fs)
savemat(save_fp, {"ecg": pps}, format="5")
- def load_preprocessed_data(self, rec: str) -> np.ndarray:
+ def load_preprocessed_data(self, rec: str) -> NDArray:
"""Load the preprocessed data of the record.
Parameters
@@ -860,7 +859,7 @@ def _slice_one_record(
def __generate_segment(
self,
rec: str,
- data: np.ndarray,
+ data: NDArray,
start_idx: Optional[int] = None,
end_idx: Optional[int] = None,
) -> CFG:
@@ -885,12 +884,12 @@ def __generate_segment(
dict
Segments (meta-)data, containing:
- - data: values of the segment, with units in mV
- - rpeaks: indices of rpeaks of the segment
- - qrs_mask: mask of qrs complexes of the segment
- - af_mask: mask of af episodes of the segment
- - interval: interval ([start_idx, end_idx]) in the
- original ECG record of the segment
+ - data: values of the segment, with units in mV
+ - rpeaks: indices of rpeaks of the segment
+ - qrs_mask: mask of qrs complexes of the segment
+ - af_mask: mask of af episodes of the segment
+ - interval: interval ([start_idx, end_idx]) in the
+ original ECG record of the segment
"""
assert not all([start_idx is None, end_idx is None]), "at least one of `start_idx` and `end_idx` should be set"
@@ -1439,7 +1438,7 @@ def __init__(
"main": "af_mask",
}
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, ...]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, ...]:
if isinstance(index, slice):
return collate_fn([self[i] for i in range(*index.indices(len(self)))])
if self.task in [
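The segment annotations above carry AF episodes both as `[start_idx, end_idx]` intervals and as per-sample masks of shape ``(n_sample, 1)``; an illustrative conversion between the two (the helper name and the half-open interval convention are assumptions):

import numpy as np

def intervals_to_mask(intervals, n_sample):
    """Binary mask of shape (n_sample, 1): 1 inside each [start, end) interval."""
    mask = np.zeros((n_sample, 1), dtype=np.int8)
    for start, end in intervals:
        mask[start:end, 0] = 1
    return mask

print(intervals_to_mask([[2, 5], [8, 10]], n_sample=12).ravel())
# [0 0 1 1 1 0 0 0 1 1 0 0]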
diff --git a/torch_ecg/databases/datasets/ludb/ludb_dataset.py b/torch_ecg/databases/datasets/ludb/ludb_dataset.py
index cddd4043..10cbd1b1 100644
--- a/torch_ecg/databases/datasets/ludb/ludb_dataset.py
+++ b/torch_ecg/databases/datasets/ludb/ludb_dataset.py
@@ -7,6 +7,7 @@
from typing import Any, List, Optional, Sequence, Tuple, Union
import numpy as np
+from numpy.typing import NDArray
from torch.utils.data.dataset import Dataset
from tqdm.auto import tqdm
@@ -83,7 +84,7 @@ def __len__(self) -> int:
return len(self.leads) * len(self.records)
return len(self.records)
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]:
if isinstance(index, slice):
return collate_fn([self[i] for i in range(*index.indices(len(self)))])
if self.config.use_single_lead:
@@ -124,14 +125,14 @@ def _load_all_data(self) -> None:
self._labels = np.array(self._labels)
@property
- def signals(self) -> np.ndarray:
+ def signals(self) -> NDArray:
"""Cached signals, only available when `lazy=False`
or preloading is performed manually.
"""
return self._signals
@property
- def labels(self) -> np.ndarray:
+ def labels(self) -> NDArray:
"""Cached labels, only available when `lazy=False`
or preloading is performed manually.
"""
@@ -230,7 +231,7 @@ def __init__(
def __len__(self) -> int:
return len(self.records)
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]:
if isinstance(index, slice):
return collate_fn([self[i] for i in range(*index.indices(len(self)))])
rec = self.records[index]
diff --git a/torch_ecg/databases/datasets/mitdb/mitdb_dataset.py b/torch_ecg/databases/datasets/mitdb/mitdb_dataset.py
index aa1c6c76..7d324fba 100644
--- a/torch_ecg/databases/datasets/mitdb/mitdb_dataset.py
+++ b/torch_ecg/databases/datasets/mitdb/mitdb_dataset.py
@@ -8,6 +8,7 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
+from numpy.typing import NDArray
from scipy import signal as SS
from scipy.io import loadmat, savemat
from torch.utils.data.dataset import Dataset
@@ -359,7 +360,7 @@ def __len__(self) -> int:
return len(self._all_data)
return len(self.fdr)
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, ...]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, ...]:
if self.task in ["beat_classification"]:
return self._all_data[index], self._all_labels[index]
if self.lazy:
@@ -401,7 +402,7 @@ def _get_seg_ann_path(self, seg: str) -> Path:
fp = self.segments_dirs.ann[rec] / f"{seg}.{self.segment_ext}"
return fp
- def _load_seg_data(self, seg: str) -> np.ndarray:
+ def _load_seg_data(self, seg: str) -> NDArray:
"""Load data of the segment.
Parameters
@@ -432,18 +433,18 @@ def _load_seg_ann(self, seg: str) -> dict:
dict
A dictionay of annotations of the segment, including
- - rpeaks: indices of rpeaks of the segment
- - qrs_mask: mask of qrs complexes of the segment
- - rhythm_mask: mask of rhythms of the segment
- - interval: interval ([start_idx, end_idx]) in the
- original ECG record of the segment
+ - rpeaks: indices of rpeaks of the segment
+ - qrs_mask: mask of qrs complexes of the segment
+ - rhythm_mask: mask of rhythms of the segment
+ - interval: interval ([start_idx, end_idx]) in the
+ original ECG record of the segment
"""
seg_ann_fp = self._get_seg_ann_path(seg)
seg_ann = {k: v.flatten() for k, v in loadmat(str(seg_ann_fp)).items() if not k.startswith("__")}
return seg_ann
- def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[np.ndarray, Dict[str, np.ndarray]]:
+ def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[NDArray, Dict[str, NDArray]]:
"""Load mask(s) of the segment.
Parameters
@@ -483,7 +484,7 @@ def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[np.ndarr
seg_mask = seg_mask["rhythm_mask"]
return seg_mask
- def _load_seg_seq_lab(self, seg: str, reduction: int) -> np.ndarray:
+ def _load_seg_seq_lab(self, seg: str, reduction: int) -> NDArray:
"""Load sequence label of the segment.
Parameters
@@ -534,7 +535,7 @@ def _get_rr_seq_path(self, seq_name: str) -> Path:
fp = self.rr_seq_dirs[rec] / f"{seq_name}.{self.rr_seq_ext}"
return fp
- def _load_rr_seq(self, seq_name: str) -> Dict[str, np.ndarray]:
+ def _load_rr_seq(self, seq_name: str) -> Dict[str, NDArray]:
"""Load metadata of sequence of rr intervals.
Parameters
@@ -547,12 +548,12 @@ def _load_rr_seq(self, seq_name: str) -> Dict[str, np.ndarray]:
dict
Metadata of sequence of rr intervals, including
- - rr: the sequence of rr intervals, with units in seconds,
- of shape ``(self.seglen, 1)``
- - label: label of the rr intervals,
- of shape ``(self.seglen, self.n_classes)``
- - interval: interval of the current rr sequence
- in the whole rr sequence in the original record
+ - rr: the sequence of rr intervals, with units in seconds,
+ of shape ``(self.seglen, 1)``.
+ - label: label of the rr intervals,
+ of shape ``(self.seglen, self.n_classes)``.
+ - interval: interval of the current rr sequence
+ in the whole rr sequence in the original record.
"""
rr_seq_path = self._get_rr_seq_path(seq_name)
@@ -772,7 +773,7 @@ def _slice_one_record(
def __generate_segment(
self,
rec: str,
- data: np.ndarray,
+ data: NDArray,
start_idx: Optional[int] = None,
end_idx: Optional[int] = None,
) -> CFG:
@@ -1311,7 +1312,7 @@ def __init__(
"af_event": "rhythm_mask", # segmentation of AF events
}
- def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, ...]:
+ def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, ...]:
if isinstance(index, slice):
return collate_fn([self[i] for i in range(*index.indices(len(self)))])
if self.task in [
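The rr-sequence metadata documented above stores intervals in seconds with shape ``(seglen, 1)``; a small sketch of deriving such a sequence from R-peak sample indices (`rpeaks_to_rr_seq` is a made-up helper, not part of the package):

import numpy as np

def rpeaks_to_rr_seq(rpeak_indices, fs):
    """Convert R-peak sample indices into rr intervals in seconds, shaped (n_rr, 1)."""
    rpeak_indices = np.asarray(rpeak_indices, dtype=float)
    rr = np.diff(rpeak_indices) / fs
    return rr.reshape(-1, 1)

rpeaks = [360, 720, 1100, 1470]  # sample indices at fs = 360 Hz
print(rpeaks_to_rr_seq(rpeaks, fs=360).ravel())  # approx. [1.0, 1.056, 1.028]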
diff --git a/torch_ecg/databases/nsrr_databases/shhs.py b/torch_ecg/databases/nsrr_databases/shhs.py
index 2b911edf..c43d3762 100644
--- a/torch_ecg/databases/nsrr_databases/shhs.py
+++ b/torch_ecg/databases/nsrr_databases/shhs.py
@@ -5,7 +5,6 @@
import re
import warnings
from datetime import datetime
-from numbers import Real
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
@@ -13,6 +12,7 @@
import pandas as pd
import scipy.signal as SS
import xmltodict as xtd
+from numpy.typing import NDArray
from tqdm.auto import tqdm
from ...cfg import DEFAULTS
@@ -78,12 +78,14 @@
2. Obstructive Apnea Index (OAI):
- - There is one OAI index in the data set. It reflects obstructive events associated with a 4% desaturation or arousal. Nearly 30% of the cohort has a zero value for this variable
+ - There is one OAI index in the data set. It reflects obstructive events associated with a 4% desaturation or arousal.
+ Nearly 30% of the cohort has a zero value for this variable
- Dichotomization is suggested (e.g. >=3 or >=4 events per hour indicates positive)
3. Central Apnea Index (CAI):
- - Several variables describe central breathing events, with different thresholds for desaturation and requirement/non-requirement of arousals. ~58% of the cohort have zero values
+ - Several variables describe central breathing events, with different thresholds for desaturation and
+ requirement/non-requirement of arousals. ~58% of the cohort have zero values
- Dichotomization is suggested (e.g. >=3 or >=4 events per hour indicates positive)
4. Sleep Stages:
@@ -387,7 +389,7 @@ def _ls_rec(self) -> None:
"""Find all records in the database directory
and store them (path, metadata, etc.) in some private attributes.
"""
- self.logger.info("Finding `edf` records....")
+ self.logger.info("Finding `edf` records....") # type: ignore
self._df_records = pd.DataFrame()
self._df_records["path"] = sorted(self.db_dir.rglob("*.edf"))
@@ -396,7 +398,7 @@ def _ls_rec(self) -> None:
len(self._df_records),
max(1, int(round(self._subsample * len(self._df_records)))),
)
- self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False)
+ self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False) # type: ignore
# if self._df_records is non-empty, call `form_paths` again if necessary
# typically path for a record is like:
@@ -423,15 +425,15 @@ def _ls_rec(self) -> None:
self._all_records = self._df_records.index.tolist()
# update `current_version`
- if self.ann_path.is_dir():
- for file in self.ann_path.iterdir():
+ if self.ann_path.is_dir(): # type: ignore
+ for file in self.ann_path.iterdir(): # type: ignore
if file.is_file() and len(re.findall(self.version_pattern, file.name)) > 0:
self.current_version = re.findall(self.version_pattern, file.name)[0]
break
- self.logger.info("Loading tables....")
+ self.logger.info("Loading tables....") # type: ignore
# gather tables in self.ann_path and in self.hrv_ann_path
- for file in itertools.chain(self.ann_path.glob("*.csv"), self.hrv_ann_path.glob("*.csv")):
+ for file in itertools.chain(self.ann_path.glob("*.csv"), self.hrv_ann_path.glob("*.csv")): # type: ignore
if not file.suffix == ".csv":
continue
table_name = file.stem.replace(f"-{self.current_version}", "")
@@ -440,7 +442,7 @@ def _ls_rec(self) -> None:
except UnicodeDecodeError:
self._tables[table_name] = pd.read_csv(file, low_memory=False, encoding="latin-1")
- self.logger.info("Finding records with HRV annotations....")
+ self.logger.info("Finding records with HRV annotations....") # type: ignore
# find records with hrv annotations
self.rec_with_hrv_summary_ann = []
for table_name in ["shhs1-hrv-summary", "shhs2-hrv-summary"]:
@@ -457,17 +459,17 @@ def _ls_rec(self) -> None:
)
self.rec_with_hrv_detailed_ann = sorted(list(set(self.rec_with_hrv_detailed_ann)))
- self.logger.info("Finding records with rpeaks annotations....")
+ self.logger.info("Finding records with rpeaks annotations....") # type: ignore
# find available rpeak annotation files
self.rec_with_rpeaks_ann = sorted(
- [f.stem.replace("-rpoint", "") for f in self.wave_deli_path.rglob("shhs*-rpoint.csv")]
+ [f.stem.replace("-rpoint", "") for f in self.wave_deli_path.rglob("shhs*-rpoint.csv")] # type: ignore
)
- self.logger.info("Finding records with event annotations....")
+ self.logger.info("Finding records with event annotations....") # type: ignore
# find available event annotation files
- self.rec_with_event_ann = sorted([f.stem.replace("-nsrr", "") for f in self.event_ann_path.rglob("shhs*-nsrr.xml")])
+ self.rec_with_event_ann = sorted([f.stem.replace("-nsrr", "") for f in self.event_ann_path.rglob("shhs*-nsrr.xml")]) # type: ignore
self.rec_with_event_profusion_ann = sorted(
- [f.stem.replace("-profusion", "") for f in self.event_profusion_ann_path.rglob("shhs*-profusion.xml")]
+ [f.stem.replace("-profusion", "") for f in self.event_profusion_ann_path.rglob("shhs*-profusion.xml")] # type: ignore
)
self._df_records["available_signals"] = None
@@ -568,11 +570,11 @@ def get_available_signals(self, rec: Union[str, int, None]) -> Union[List[str],
mininterval=1.0,
disable=(self.verbose < 1),
):
- rec = row.name
- if self._df_records.loc[rec, "available_signals"] is not None:
+ rec = row.name # type: ignore
+ if self._df_records.loc[rec, "available_signals"] is not None: # type: ignore
continue
available_signals = self.get_available_signals(rec)
- self._df_records.at[rec, "available_signals"] = available_signals
+ self._df_records.at[rec, "available_signals"] = available_signals # type: ignore
return
if isinstance(rec, int):
@@ -580,8 +582,8 @@ def get_available_signals(self, rec: Union[str, int, None]) -> Union[List[str],
if rec in self._df_records.index:
available_signals = self._df_records.loc[rec, "available_signals"]
- if available_signals is not None and len(available_signals) > 0:
- return available_signals
+ if available_signals is not None and len(available_signals) > 0: # type: ignore
+ return available_signals # type: ignore
frp = self.get_absolute_path(rec)
try:
@@ -590,10 +592,10 @@ def get_available_signals(self, rec: Union[str, int, None]) -> Union[List[str],
self.safe_edf_file_operation("open", frp)
except OSError:
return None
- available_signals = [s.lower() for s in self.file_opened.getSignalLabels()]
+ available_signals = [s.lower() for s in self.file_opened.getSignalLabels()] # type: ignore
self.safe_edf_file_operation("close")
self._df_records.at[rec, "available_signals"] = available_signals
- self.all_signals = self.all_signals.union(set(available_signals))
+ self.all_signals = self.all_signals.union(set(available_signals)) # type: ignore
else:
available_signals = []
return available_signals
@@ -639,7 +641,7 @@ def get_visitnumber(self, rec: Union[str, int]) -> int:
Visit number extracted from `rec`.
"""
- return self.split_rec_name(rec)["visitnumber"]
+ return self.split_rec_name(rec)["visitnumber"] # type: ignore
def get_tranche(self, rec: Union[str, int]) -> str:
"""Get ``tranche`` ("shhs1" or "shhs2") from `rec`.
@@ -656,7 +658,7 @@ def get_tranche(self, rec: Union[str, int]) -> str:
Tranche extracted from `rec`.
"""
- return self.split_rec_name(rec)["tranche"]
+ return self.split_rec_name(rec)["tranche"] # type: ignore
def get_nsrrid(self, rec: Union[str, int]) -> int:
"""Get ``nsrrid`` from `rec`.
@@ -673,14 +675,14 @@ def get_nsrrid(self, rec: Union[str, int]) -> int:
``nsrrid`` extracted from `rec`.
"""
- return self.split_rec_name(rec)["nsrrid"]
+ return self.split_rec_name(rec)["nsrrid"] # type: ignore
def get_fs(
self,
rec: Union[str, int],
sig: str = "ECG",
rec_path: Optional[Union[str, bytes, os.PathLike]] = None,
- ) -> Real:
+ ) -> Union[float, int]:
"""Get the sampling frequency of a signal of a record.
Parameters
@@ -698,7 +700,7 @@ def get_fs(
Returns
-------
- fs : numbers.Real
+ fs : float or int
Sampling frequency of the signal `sig` of the record `rec`.
If corresponding signal (.edf) file is not available,
or the signal file does not contain the signal `sig`,
@@ -708,26 +710,26 @@ def get_fs(
if isinstance(rec, int):
rec = self[rec]
sig = self.match_channel(sig, raise_error=False)
- assert sig in self.all_signals.union({"rpeak"}), f"Invalid signal name: `{sig}`"
+ assert sig in self.all_signals.union({"rpeak"}), f"Invalid signal name: `{sig}`" # type: ignore
if sig.lower() == "rpeak":
df_rpeaks_with_type_info = self.load_wave_delineation_ann(rec)
if df_rpeaks_with_type_info.empty:
- self.logger.info(f"Rpeak annotation file corresponding to `{rec}` is not available.")
+ self.logger.info(f"Rpeak annotation file corresponding to `{rec}` is not available.") # type: ignore
return -1
return df_rpeaks_with_type_info.iloc[0]["samplingrate"]
frp = self.get_absolute_path(rec, rec_path)
if not frp.exists():
- self.logger.info(f"Signal (.edf) file corresponding to `{rec}` is not available.")
+ self.logger.info(f"Signal (.edf) file corresponding to `{rec}` is not available.") # type: ignore
return -1
self.safe_edf_file_operation("open", frp)
sig = self.match_channel(sig)
- available_signals = [s.lower() for s in self.file_opened.getSignalLabels()]
+ available_signals = [s.lower() for s in self.file_opened.getSignalLabels()] # type: ignore
if sig not in available_signals:
- self.logger.info(f"Signal `{sig}` is not available in signal file corresponding to `{rec}`.")
+ self.logger.info(f"Signal `{sig}` is not available in signal file corresponding to `{rec}`.") # type: ignore
return -1
chn_num = available_signals.index(sig)
- fs = self.file_opened.getSampleFrequency(chn_num)
+ fs = self.file_opened.getSampleFrequency(chn_num) # type: ignore
self.safe_edf_file_operation("close")
return fs
@@ -761,15 +763,15 @@ def get_chn_num(
"""
sig = self.match_channel(sig)
available_signals = self.get_available_signals(rec)
- if sig not in available_signals:
+ if sig not in available_signals: # type: ignore
if isinstance(rec, int):
rec = self[rec]
- self.logger.info(
+ self.logger.info( # type: ignore
f"Signal (.edf) file corresponding to `{rec}` is not available, or"
f"signal `{sig}` is not available in signal file corresponding to `{rec}`."
)
return -1
- chn_num = available_signals.index(self.match_channel(sig))
+ chn_num = available_signals.index(self.match_channel(sig)) # type: ignore
return chn_num
def match_channel(self, channel: str, raise_error: bool = True) -> str:
@@ -797,7 +799,7 @@ def match_channel(self, channel: str, raise_error: bool = True) -> str:
raise ValueError(f"No channel named `{channel}`")
return channel
- def get_absolute_path(
+ def get_absolute_path( # type: ignore
self,
rec: Union[str, int],
rec_path: Optional[Union[str, bytes, os.PathLike]] = None,
@@ -823,7 +825,7 @@ def get_absolute_path(
"""
if rec_path is not None:
- rp = Path(rec_path)
+ rp = Path(rec_path) # type: ignore
return rp
assert rec_type in self.folder_or_file, (
@@ -835,7 +837,7 @@ def get_absolute_path(
tranche, nsrrid = [self.split_rec_name(rec)[k] for k in ["tranche", "nsrrid"]]
# rp = self._df_records.loc[rec, rec_type]
- rp = self.folder_or_file[rec_type] / tranche / f"{rec}{self.extension[rec_type]}"
+ rp = self.folder_or_file[rec_type] / tranche / f"{rec}{self.extension[rec_type]}" # type: ignore
return rp
def database_stats(self) -> None:
@@ -856,12 +858,12 @@ def show_rec_stats(self, rec: Union[str, int], rec_path: Optional[Union[str, byt
"""
frp = self.get_absolute_path(rec, rec_path, rec_type="psg")
self.safe_edf_file_operation("open", frp)
- for chn, lb in enumerate(self.file_opened.getSignalLabels()):
+ for chn, lb in enumerate(self.file_opened.getSignalLabels()): # type: ignore
print("SignalLabel:", lb)
- print("Prefilter:", self.file_opened.getPrefilter(chn))
- print("Transducer:", self.file_opened.getTransducer(chn))
- print("PhysicalDimension:", self.file_opened.getPhysicalDimension(chn))
- print("SampleFrequency:", self.file_opened.getSampleFrequency(chn))
+ print("Prefilter:", self.file_opened.getPrefilter(chn)) # type: ignore
+ print("Transducer:", self.file_opened.getTransducer(chn)) # type: ignore
+ print("PhysicalDimension:", self.file_opened.getPhysicalDimension(chn)) # type: ignore
+ print("SampleFrequency:", self.file_opened.getSampleFrequency(chn)) # type: ignore
print("*" * 40)
self.safe_edf_file_operation("close")
@@ -870,11 +872,11 @@ def load_psg_data(
rec: Union[str, int],
channel: str = "all",
rec_path: Optional[Union[str, bytes, os.PathLike]] = None,
- sampfrom: Optional[Real] = None,
- sampto: Optional[Real] = None,
- fs: Optional[int] = None,
+ sampfrom: Optional[Union[float, int]] = None,
+ sampto: Optional[Union[float, int]] = None,
+ fs: Optional[Union[float, int]] = None,
physical: bool = True,
- ) -> Union[Dict[str, Tuple[np.ndarray, Real]], Tuple[np.ndarray, Real]]:
+ ) -> Union[Dict[str, Tuple[NDArray, Union[float, int]]], Tuple[NDArray, Union[float, int]]]:
"""Load PSG data of the record.
Parameters
@@ -888,13 +890,13 @@ def load_psg_data(
rec_path : `path-like`, optional
Path of the file which contains the PSG data.
If is None, default path will be used.
- sampfrom : numbers.Real, optional
+ sampfrom : float or int, optional
Start time (units in seconds) of the data to be loaded,
valid only when `channel` is some specific channel.
- sampto : numbers.Real, optional
+ sampto : float or int, optional
End time (units in seconds) of the data to be loaded,
valid only when `channel` is some specific channel
- fs : numbers.Real, optional
+ fs : float or int, optional
Sampling frequency of the loaded data.
If not None, the loaded data will be resampled to this frequency,
otherwise, the original sampling frequency will be used.
@@ -908,11 +910,11 @@ def load_psg_data(
dict or tuple
If `channel` is "all", then a dictionary will be returned:
- - keys: PSG channel names;
- - values: PSG data and sampling frequency
+ - keys: PSG channel names;
+ - values: PSG data and sampling frequency
Otherwise, a 2-tuple will be returned:
- (:class:`numpy.ndarray`, :class:`numbers.Real`), which is the
+ (:class:`numpy.ndarray`, :class:`int` or :class:`float`), which is the
PSG data of the channel `channel` and its sampling frequency.
"""
@@ -923,17 +925,17 @@ def load_psg_data(
if chn == "all":
ret_data = {
k: (
- self.file_opened.readSignal(idx, digital=not physical),
- self.file_opened.getSampleFrequency(idx),
+ self.file_opened.readSignal(idx, digital=not physical), # type: ignore
+ self.file_opened.getSampleFrequency(idx), # type: ignore
)
- for idx, k in enumerate(self.file_opened.getSignalLabels())
+ for idx, k in enumerate(self.file_opened.getSignalLabels()) # type: ignore
}
else:
- all_signals = [s.lower() for s in self.file_opened.getSignalLabels()]
- assert chn in all_signals, f"`channel` should be one of `{self.file_opened.getSignalLabels()}`, but got `{chn}`"
+ all_signals = [s.lower() for s in self.file_opened.getSignalLabels()] # type: ignore
+ assert chn in all_signals, f"`channel` should be one of `{self.file_opened.getSignalLabels()}`, but got `{chn}`" # type: ignore
idx = all_signals.index(chn)
- data_fs = self.file_opened.getSampleFrequency(idx)
- data = self.file_opened.readSignal(idx, digital=not physical)
+ data_fs = self.file_opened.getSampleFrequency(idx) # type: ignore
+ data = self.file_opened.readSignal(idx, digital=not physical) # type: ignore
# the `readSignal` method of `EdfReader` does NOT treat
# the parameters `start` and `n` correctly
# so we have to do it manually
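A hedged sketch of the manual slicing this comment refers to (the helper name and rounding details are illustrative, not the method's verbatim code): second-based `sampfrom`/`sampto` are converted to sample indices at the channel's own sampling frequency, and the fully read signal is sliced.

from typing import Optional

import numpy as np

def slice_by_seconds(
    data: np.ndarray,
    fs: float,
    sampfrom: Optional[float] = None,
    sampto: Optional[float] = None,
) -> np.ndarray:
    """Slice a 1-D signal by start/end times given in seconds."""
    start = int(round((sampfrom or 0) * fs))
    stop = int(round(sampto * fs)) if sampto is not None else data.shape[0]
    return data[start:stop]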
@@ -962,10 +964,10 @@ def load_ecg_data(
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
data_format: str = "channel_first",
- units: Union[str, type(None)] = "mV",
- fs: Optional[int] = None,
+ units: Union[str, None] = "mV",
+ fs: Optional[Union[float, int]] = None,
return_fs: bool = True,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Union[float, int]]]:
"""Load ECG data of the record.
Parameters
@@ -988,7 +990,7 @@ def load_ecg_data(
units : str or None, default "mV"
Units of the output signal, can also be "μV" (aliases "uV", "muV").
None for digital data, without digital-to-physical conversion.
- fs : numbers.Real, optional
+ fs : float or int, optional
Sampling frequency of the loaded data.
If not None, the loaded data will be resampled to this frequency,
otherwise, the original sampling frequency will be used.
@@ -999,7 +1001,7 @@ def load_ecg_data(
-------
data : numpy.ndarray
The loaded ECG data.
- data_fs : numbers.Real
+ data_fs : float or int
Sampling frequency of the loaded ECG data.
Returned if `return_fs` is True.
@@ -1029,7 +1031,7 @@ def load_ecg_data(
fs=fs,
physical=units is not None,
)
- data = data.astype(DEFAULTS.DTYPE.NP)
+ data = data.astype(DEFAULTS.DTYPE.NP) # type: ignore
if units is not None and units.lower() in ["μv", "uv", "muv"]:
data *= 1e3
@@ -1039,25 +1041,25 @@ def load_ecg_data(
data = data[:, np.newaxis]
if return_fs:
- return data, data_fs
+ return data, data_fs # type: ignore
return data
@add_docstring(
" " * 8 + "NOTE: one should call `load_psg_data` to load other channels.",
mode="append",
)
- @add_docstring(load_ecg_data.__doc__)
- def load_data(
+ @add_docstring(load_ecg_data.__doc__) # type: ignore
+ def load_data( # type: ignore
self,
rec: Union[str, int],
rec_path: Optional[Union[str, bytes, os.PathLike]] = None,
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
data_format: str = "channel_first",
- units: Union[str, type(None)] = "mV",
+ units: Union[str, None] = "mV",
fs: Optional[int] = None,
return_fs: bool = True,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Union[float, int]]]:
"""alias of `load_ecg_data`"""
return self.load_ecg_data(
rec=rec,
@@ -1070,13 +1072,13 @@ def load_data(
return_fs=return_fs,
)
- def load_ann(
+ def load_ann( # type: ignore
self,
rec: Union[str, int],
ann_type: str,
ann_path: Optional[Union[str, bytes, os.PathLike]] = None,
**kwargs: Any,
- ) -> Union[np.ndarray, pd.DataFrame, dict]:
+ ) -> Union[NDArray, pd.DataFrame, dict, None]:
"""Load annotations of specific type of the record.
Parameters
@@ -1160,7 +1162,7 @@ def load_event_ann(
df_events["EventType"] = df_events["EventType"].apply(lambda s: s.split("|")[1])
df_events["EventConcept"] = df_events["EventConcept"].apply(lambda s: s.split("|")[1])
for c in ["Start", "Duration", "SpO2Nadir", "SpO2Baseline"]:
- df_events[c] = df_events[c].apply(self.str_to_real_number)
+ df_events[c] = df_events[c].apply(self.str_to_real_number) # type: ignore
return df_events
@@ -1200,7 +1202,7 @@ def load_event_profusion_ann(
sleep_stage_list = [int(ss) for ss in doc["CMPStudyConfig"]["SleepStages"]["SleepStage"]]
df_events = pd.DataFrame(doc["CMPStudyConfig"]["ScoredEvents"]["ScoredEvent"])
for c in ["Start", "Duration", "LowestSpO2", "Desaturation"]:
- df_events[c] = df_events[c].apply(self.str_to_real_number)
+ df_events[c] = df_events[c].apply(self.str_to_real_number) # type: ignore
ret = {"sleep_stage_list": sleep_stage_list, "df_events": df_events}
return ret
@@ -1320,7 +1322,7 @@ def load_sleep_ann(
df_sleep_ann = df_hrv_ann[self.sleep_ann_keys_from_hrv].reset_index(drop=True)
else:
df_sleep_ann = pd.DataFrame(columns=self.sleep_ann_keys_from_hrv)
- self.logger.debug(
+ self.logger.debug( # type: ignore
f"record `{rec}` has `{len(df_sleep_ann)}` sleep annotations from corresponding "
f"hrv-5min (detailed) annotation file, with `{len(self.sleep_ann_keys_from_hrv)}` column(s)"
)
@@ -1331,7 +1333,7 @@ def load_sleep_ann(
df_sleep_ann = df_event_ann[_cols]
else:
df_sleep_ann = pd.DataFrame(columns=_cols)
- self.logger.debug(
+ self.logger.debug( # type: ignore
f"record `{rec}` has `{len(df_sleep_ann)}` sleep annotations from corresponding "
f"event-nsrr annotation file, with `{len(_cols)}` column(s)"
)
@@ -1340,7 +1342,7 @@ def load_sleep_ann(
# temporarily finished
# later to make improvements
df_sleep_ann = dict_event_ann
- self.logger.debug(
+ self.logger.debug( # type: ignore
f"record `{rec}` has `{len(df_sleep_ann['df_events'])}` sleep event annotations "
"from corresponding event-profusion annotation file, "
f"with `{len(df_sleep_ann['df_events'].columns)}` column(s)"
@@ -1449,11 +1451,11 @@ def load_sleep_stage_ann(
)
if source.lower() != "event_profusion":
- self.logger.debug(
- f"record `{rec}` has `{len(df_tmp)}` raw (epoch_len = 5min) sleep stage annotations, "
+ self.logger.debug( # type: ignore
+ f"record `{rec}` has `{len(df_tmp)}` raw (epoch_len = 5min) sleep stage annotations, " # type: ignore
f"with `{len(self.sleep_stage_ann_keys_from_hrv)}` column(s)"
)
- self.logger.debug(
+ self.logger.debug( # type: ignore
f"after being transformed (epoch_len = 30sec), record `{rec}` has {len(df_sleep_stage_ann)} "
f"sleep stage annotations, with `{len(self.sleep_stage_keys)}` column(s)"
)
@@ -1509,7 +1511,7 @@ def load_sleep_event_ann(
for _, row in df_sleep_ann.iterrows():
if row["hasrespevent"] == 0:
continue
- l_events = row[self.sleep_event_ann_keys_from_hrv[1:-1]].values.reshape(
+ l_events = row[self.sleep_event_ann_keys_from_hrv[1:-1]].values.reshape( # type: ignore
(len(self.sleep_event_ann_keys_from_hrv) // 2 - 1, 2)
)
l_events = l_events[~np.isnan(l_events[:, 0])]
@@ -1521,11 +1523,11 @@ def load_sleep_event_ann(
)
df_sleep_event_ann = df_sleep_event_ann[self.sleep_event_keys]
- self.logger.debug(
+ self.logger.debug( # type: ignore
f"record `{rec}` has `{len(df_sleep_ann)}` raw (epoch_len = 5min) sleep event "
f"annotations from hrv, with `{len(self.sleep_event_ann_keys_from_hrv)}` column(s)"
)
- self.logger.debug(f"after being transformed, record `{rec}` has `{len(df_sleep_event_ann)}` sleep event(s)")
+ self.logger.debug(f"after being transformed, record `{rec}` has `{len(df_sleep_event_ann)}` sleep event(s)") # type: ignore
elif source.lower() == "event":
if event_types is None:
event_types = ["respiratory", "arousal"]
@@ -1570,7 +1572,7 @@ def load_sleep_event_ann(
_cols = _cols | set(self.long_event_names_from_event[3:4])
_cols = list(_cols)
- self.logger.debug(f"for record `{rec}`, _cols = `{_cols}`")
+ self.logger.debug(f"for record `{rec}`, _cols = `{_cols}`") # type: ignore
df_sleep_event_ann = df_sleep_ann[df_sleep_ann["EventConcept"].isin(_cols)].reset_index(drop=True)
df_sleep_event_ann = df_sleep_event_ann.rename(
@@ -1631,7 +1633,7 @@ def load_sleep_event_ann(
_cols = _cols | set(self.event_names_from_event_profusion[3:4])
_cols = list(_cols)
- self.logger.debug(f"for record `{rec}`, _cols = `{_cols}`")
+ self.logger.debug(f"for record `{rec}`, _cols = `{_cols}`") # type: ignore
df_sleep_event_ann = df_sleep_ann[df_sleep_ann["Name"].isin(_cols)].reset_index(drop=True)
df_sleep_event_ann = df_sleep_event_ann.rename(
@@ -1649,7 +1651,7 @@ def load_sleep_event_ann(
else:
raise ValueError(f"Source `{source}` not supported, " "only `hrv`, `event`, `event_profusion` are supported")
- return df_sleep_event_ann
+ return df_sleep_event_ann # type: ignore
def load_apnea_ann(
self,
@@ -1727,7 +1729,7 @@ def load_wave_delineation_ann(
file_path = self.get_absolute_path(rec, wave_deli_path, rec_type="wave_delineation")
if not file_path.is_file():
- self.logger.debug(
+ self.logger.debug( # type: ignore
f"The annotation file of wave delineation of record `{rec}` has not been downloaded yet. "
f"Or the path `{str(file_path)}` is not correct. "
f"Or `{rec}` does not have `rpeak.csv` annotation file. Please check!"
@@ -1746,7 +1748,7 @@ def load_rpeak_ann(
exclude_abnormal_beats: bool = True,
units: Optional[Literal["s", "ms"]] = None,
**kwargs: Any,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load annotations on R peaks of the record.
Parameters
@@ -1788,7 +1790,7 @@ def load_rpeak_ann(
rpeaks = df_rpeaks_with_type_info[~df_rpeaks_with_type_info["Type"].isin(exclude_beat_types)]["rpointadj"].values
if units is None:
- rpeaks = (np.round(rpeaks)).astype(int)
+ rpeaks = (np.round(rpeaks)).astype(int) # type: ignore
elif units.lower() == "s":
fs = df_rpeaks_with_type_info.iloc[0]["samplingrate"]
rpeaks = rpeaks / fs
@@ -1811,7 +1813,7 @@ def load_rr_ann(
rpeak_ann_path: Optional[Union[str, bytes, os.PathLike]] = None,
units: Literal["s", "ms", None] = "s",
**kwargs: Any,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load annotations on RR intervals of the record.
Parameters
@@ -1852,7 +1854,7 @@ def load_nn_ann(
rpeak_ann_path: Optional[Union[str, bytes, os.PathLike]] = None,
units: Union[str, None] = "s",
**kwargs: Any,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load annotations on NN intervals of the record.
Parameters
@@ -1912,7 +1914,7 @@ def locate_artifacts(
rec: Union[str, int],
wave_deli_path: Optional[Union[str, bytes, os.PathLike]] = None,
units: Optional[Literal["s", "ms"]] = None,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Locate "artifacts" in the record.
Parameters
@@ -1941,7 +1943,7 @@ def locate_artifacts(
return np.array([], dtype=dtype)
# df_rpeaks_with_type_info = df_rpeaks_with_type_info[["Type", "rpointadj"]]
- artifacts = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 0]["rpointadj"].values)).astype(int)
+ artifacts = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 0]["rpointadj"].values)).astype(int) # type: ignore
if units is not None:
fs = df_rpeaks_with_type_info.iloc[0]["samplingrate"]
@@ -1965,7 +1967,7 @@ def locate_abnormal_beats(
wave_deli_path: Optional[Union[str, bytes, os.PathLike]] = None,
abnormal_type: Optional[Literal["VE", "SVE"]] = None,
units: Optional[Literal["s", "ms"]] = None,
- ) -> Union[Dict[str, np.ndarray], np.ndarray]:
+ ) -> Union[Dict[str, NDArray], NDArray, None]:
"""Locate "abnormal beats" in the record.
Parameters
@@ -2005,8 +2007,8 @@ def locate_abnormal_beats(
if not df_rpeaks_with_type_info.empty:
# df_rpeaks_with_type_info = df_rpeaks_with_type_info[["Type", "rpointadj"]]
# 2 = VE, 3 = SVE
- ve = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 2]["rpointadj"].values)).astype(int)
- sve = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 3]["rpointadj"].values)).astype(int)
+ ve = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 2]["rpointadj"].values)).astype(int) # type: ignore
+ sve = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 3]["rpointadj"].values)).astype(int) # type: ignore
abnormal_rpeaks = {"VE": ve, "SVE": sve}
else:
dtype = int if units is None or units.lower() != "s" else float
@@ -2043,7 +2045,7 @@ def load_eeg_band_ann(
rec: Union[str, int],
eeg_band_ann_path: Optional[Union[str, bytes, os.PathLike]] = None,
**kwargs: Any,
- ) -> pd.DataFrame:
+ ) -> pd.DataFrame: # type: ignore
"""Load annotations on EEG bands of the record.
Parameters
@@ -2062,7 +2064,7 @@ def load_eeg_band_ann(
"""
if self.current_version >= "0.15.0":
- self.logger.info(f"EEG spectral summary variables are removed in version {self.current_version}")
+ self.logger.info(f"EEG spectral summary variables are removed in version {self.current_version}") # type: ignore
else:
raise NotImplementedError
@@ -2071,7 +2073,7 @@ def load_eeg_spectral_ann(
rec: Union[str, int],
eeg_spectral_ann_path: Optional[Union[str, bytes, os.PathLike]] = None,
**kwargs: Any,
- ) -> pd.DataFrame:
+ ) -> pd.DataFrame: # type: ignore
"""Load annotations on EEG spectral summary of the record.
Parameters
@@ -2090,7 +2092,7 @@ def load_eeg_spectral_ann(
"""
if self.current_version >= "0.15.0":
- self.logger.info(f"EEG spectral summary variables are removed in version {self.current_version}")
+ self.logger.info(f"EEG spectral summary variables are removed in version {self.current_version}") # type: ignore
else:
raise NotImplementedError
@@ -2215,8 +2217,8 @@ def _plot_ann(
raise NotImplementedError("Plotting of some types of events in `df_sleep_event` has not been implemented yet!")
if plot_format.lower() == "hypnogram":
- stage_mask = df_sleep_stage["sleep_stage"].values
- stage_mask = len(self.sleep_stage_names) - 1 - stage_mask
+ stage_mask = df_sleep_stage["sleep_stage"].values # type: ignore
+ stage_mask = len(self.sleep_stage_names) - 1 - stage_mask # type: ignore
fig, ax = self.plot_hypnogram(stage_mask, granularity=30)
return
@@ -2238,42 +2240,42 @@ def _plot_ann(
ax_events.set_xlabel("Time", fontsize=16)
# ax_events.set_ylabel("Events", fontsize=16)
else:
- ax_stages, ax_events = axes
+ ax_stages, ax_events = axes # type: ignore
ax_stages.set_title("Sleep Stages and Events", fontsize=24)
ax_events.set_xlabel("Time", fontsize=16)
if ax_stages is not None:
- for k, v in sleep_stages.items():
+ for k, v in sleep_stages.items(): # type: ignore
for itv in v:
ax_stages.axvspan(
- datetime.fromtimestamp(itv[0]),
- datetime.fromtimestamp(itv[1]),
+ datetime.fromtimestamp(itv[0]), # type: ignore
+ datetime.fromtimestamp(itv[1]), # type: ignore
color=self.palette[k],
alpha=plot_alpha,
)
ax_stages.legend(
- handles=[patches[k] for k in self.all_sleep_stage_names if k in sleep_stages.keys()],
+ handles=[patches[k] for k in self.all_sleep_stage_names if k in sleep_stages.keys()], # type: ignore
loc="best",
) # keep ordering
plt.setp(ax_stages.get_yticklabels(), visible=False)
ax_stages.tick_params(axis="y", which="both", length=0)
if ax_events is not None:
- for _, row in df_sleep_event.iterrows():
+ for _, row in df_sleep_event.iterrows(): # type: ignore
ax_events.axvspan(
- datetime.fromtimestamp(row["event_start"]),
- datetime.fromtimestamp(row["event_end"]),
+ datetime.fromtimestamp(row["event_start"]), # type: ignore
+ datetime.fromtimestamp(row["event_end"]), # type: ignore
color=self.palette[row["event_name"]],
alpha=plot_alpha,
)
ax_events.legend(
- handles=[patches[k] for k in current_legal_events if k in set(df_sleep_event["event_name"])],
+ handles=[patches[k] for k in current_legal_events if k in set(df_sleep_event["event_name"])], # type: ignore
loc="best",
) # keep ordering
plt.setp(ax_events.get_yticklabels(), visible=False)
ax_events.tick_params(axis="y", which="both", length=0)
- def str_to_real_number(self, s: Union[str, Real]) -> Real:
+ def str_to_real_number(self, s: Union[str, float, int]) -> Union[float, int]:
"""Convert a string to a real number.
Some columns in the annotations might incorrectly
@@ -2560,7 +2562,7 @@ def folder_or_file(self) -> Dict[str, Path]:
"wave_delineation": self.wave_deli_path,
"event": self.event_ann_path,
"event_profusion": self.event_profusion_ann_path,
- }
+ } # type: ignore
@property
def url(self) -> str:
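Stepping back from the per-file hunks: the recurring change in this diff is swapping `numbers.Real` for `Union[float, int]` and `np.ndarray` for `numpy.typing.NDArray` in signatures, so static type checkers can reason about them without extra suppressions. A minimal, self-contained sketch of the new annotation style follows; the function itself is a toy, not library code.

from typing import Optional, Tuple, Union

import numpy as np
from numpy.typing import NDArray

def maybe_resample(
    data: NDArray,
    fs_in: Union[float, int],
    fs_out: Optional[Union[float, int]] = None,
) -> Tuple[NDArray, Union[float, int]]:
    """Return a 1-D signal (naively interpolated if `fs_out` differs) and its sampling frequency."""
    if fs_out is None or fs_out == fs_in:
        return data, fs_in
    n_out = int(round(data.shape[0] * fs_out / fs_in))
    x_old = np.linspace(0.0, 1.0, num=data.shape[0], endpoint=False)
    x_new = np.linspace(0.0, 1.0, num=n_out, endpoint=False)
    return np.interp(x_new, x_old, data), fs_out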
diff --git a/torch_ecg/databases/other_databases/cachet_cadb.py b/torch_ecg/databases/other_databases/cachet_cadb.py
index 22726d67..6855c475 100644
--- a/torch_ecg/databases/other_databases/cachet_cadb.py
+++ b/torch_ecg/databases/other_databases/cachet_cadb.py
@@ -13,6 +13,7 @@
import numpy as np
import pandas as pd
import scipy.signal as SS
+from numpy.typing import NDArray
from ...cfg import DEFAULTS
from ...utils.download import http_get
@@ -349,7 +350,7 @@ def load_data(
units: Union[str, type(None)] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
"""Load physical (converted from digital) ECG data,
or load digital signal directly.
@@ -434,7 +435,7 @@ def load_context_data(
channels: Optional[Union[str, int, List[str], List[int]]] = None,
units: Optional[str] = None,
fs: Optional[Real] = None,
- ) -> Union[np.ndarray, pd.DataFrame]:
+ ) -> Union[NDArray, pd.DataFrame]:
"""Load context data (e.g. accelerometer, heart rate, etc.).
Parameters
@@ -544,7 +545,7 @@ def load_context_data(
def load_ann(
self, rec: Union[str, int], ann_format: str = "pd"
- ) -> Union[pd.DataFrame, np.ndarray, Dict[Union[int, str], np.ndarray]]:
+ ) -> Union[pd.DataFrame, NDArray, Dict[Union[int, str], NDArray]]:
"""Load annotation from the metadata file.
Parameters
diff --git a/torch_ecg/databases/other_databases/sph.py b/torch_ecg/databases/other_databases/sph.py
index 0e566a03..5aabc0d3 100644
--- a/torch_ecg/databases/other_databases/sph.py
+++ b/torch_ecg/databases/other_databases/sph.py
@@ -9,6 +9,7 @@
import h5py
import numpy as np
import pandas as pd
+from numpy.typing import NDArray
from ...cfg import DEFAULTS
from ...utils import EAK
@@ -172,7 +173,7 @@ def load_data(
data_format: str = "channel_first",
units: str = "mV",
return_fs: bool = False,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load ECG data from h5 file of the record.
Parameters
@@ -233,10 +234,9 @@ def load_ann(self, rec: Union[str, int], ann_format: str = "c", ignore_modifier:
Record name or index of the record in :attr:`all_records`.
ann_format : str, default "a"
Format of labels, one of the following (case insensitive):
-
- - "a": abbreviations
- - "f": full names
- - "c": AHACode
+ - "a": abbreviations
+ - "f": full names
+ - "c": AHACode
ignore_modifier : bool, default True
Whether to ignore the modifiers of the annotations or not.
For example, "60+310" will be converted to "60".
@@ -390,7 +390,7 @@ def download(self, files: Optional[Union[str, Sequence[str]]]) -> None:
def plot(
self,
rec: Union[str, int],
- data: Optional[np.ndarray] = None,
+ data: Optional[NDArray] = None,
ann: Optional[Sequence[str]] = None,
ticks_granularity: int = 0,
leads: Optional[Union[str, int, List[Union[str, int]]]] = None,
diff --git a/torch_ecg/databases/physionet_databases/afdb.py b/torch_ecg/databases/physionet_databases/afdb.py
index 5adf2122..e67a7cda 100644
--- a/torch_ecg/databases/physionet_databases/afdb.py
+++ b/torch_ecg/databases/physionet_databases/afdb.py
@@ -8,6 +8,7 @@
import numpy as np
import wfdb
+from numpy.typing import NDArray
from ...cfg import CFG
from ...utils.misc import add_docstring, get_record_list_recursive
@@ -29,10 +30,10 @@
3. signals are sampled at 250 samples per second with 12-bit resolution over a range of ±10 millivolts, with a typical recording bandwidth of approximately 0.1 Hz to 40 Hz
4. 4 classes of rhythms are annotated:
- - AFIB: atrial fibrillation
- - AFL: atrial flutter
- - J: AV junctional rhythm
- - N: all other rhythms
+ - AFIB: atrial fibrillation
+ - AFL: atrial flutter
+ - J: AV junctional rhythm
+ - N: all other rhythms
5. rhythm annotations almost all start with "(N", except for 4 which start with '(AFIB', which are all within 1 second (250 samples)
6. Webpage of the database on PhysioNet [1]_. Paper describing the database [2]_.
@@ -138,7 +139,7 @@ def load_ann(
sampto: Optional[int] = None,
ann_format: Literal["intervals", "mask"] = "intervals",
keep_original: bool = False,
- ) -> Union[Dict[str, list], np.ndarray]:
+ ) -> Union[Dict[str, list], NDArray]:
"""Load annotations (header) from the .hea files.
Parameters
@@ -203,7 +204,7 @@ def load_beat_ann(
sampto: Optional[int] = None,
use_manual: bool = True,
keep_original: bool = False,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load beat annotations from corresponding annotation files.
Parameters
@@ -253,7 +254,7 @@ def load_rpeak_indices(
sampto: Optional[int] = None,
use_manual: bool = True,
keep_original: bool = False,
- ) -> np.ndarray:
+ ) -> NDArray:
"""
alias of `self.load_beat_ann`
"""
@@ -262,9 +263,9 @@ def load_rpeak_indices(
def plot(
self,
rec: Union[str, int],
- data: Optional[np.ndarray] = None,
- ann: Optional[Dict[str, np.ndarray]] = None,
- rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None,
+ data: Optional[NDArray] = None,
+ ann: Optional[Dict[str, NDArray]] = None,
+ rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None,
ticks_granularity: int = 0,
leads: Optional[Union[str, int, List[str], List[int]]] = None,
sampfrom: Optional[int] = None,
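Before moving on to the next file: the `ann_format` choices ("intervals" vs "mask") above describe the same rhythm annotations in two shapes. A hedged sketch of that conversion follows; the helper and the class indices are illustrative, not AFDB's actual mapping.

import numpy as np
from numpy.typing import NDArray

def intervals_to_mask(intervals: dict, siglen: int, class_map: dict) -> NDArray:
    """intervals: {"AFIB": [[start, end], ...], ...} in samples; returns a sample-wise class mask."""
    mask = np.full(siglen, fill_value=class_map["N"], dtype=int)
    for rhythm, spans in intervals.items():
        for start, end in spans:
            mask[start:end] = class_map[rhythm]
    return mask

# toy usage
mask = intervals_to_mask(
    {"AFIB": [[250, 1250]]},
    siglen=2500,
    class_map={"N": 0, "AFIB": 1, "AFL": 2, "J": 3},
)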
diff --git a/torch_ecg/databases/physionet_databases/apnea_ecg.py b/torch_ecg/databases/physionet_databases/apnea_ecg.py
index 1fd07a37..28677d3e 100644
--- a/torch_ecg/databases/physionet_databases/apnea_ecg.py
+++ b/torch_ecg/databases/physionet_databases/apnea_ecg.py
@@ -8,6 +8,7 @@
import numpy as np
import pandas as pd
import wfdb
+from numpy.typing import NDArray
from ...cfg import DEFAULTS
from ...utils import add_docstring
@@ -200,7 +201,7 @@ def load_data(
units: Union[str, type(None)] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
return super().load_data(rec, leads, sampfrom, sampto, data_format, units, fs, return_fs)
@add_docstring(PhysioNetDataBase.load_data.__doc__)
@@ -213,7 +214,7 @@ def load_ecg_data(
units: Union[str, type(None)] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
if isinstance(rec, int):
rec = self[rec]
if rec not in self.ecg_records:
@@ -244,7 +245,7 @@ def load_rsp_data(
units: Union[str, type(None)] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> np.ndarray:
+ ) -> NDArray:
if rec not in self.rsp_records:
raise ValueError(f"`{rec}` is not a record of RSP signals")
data = self.load_data(
diff --git a/torch_ecg/databases/physionet_databases/cinc2017.py b/torch_ecg/databases/physionet_databases/cinc2017.py
index 799fc7e2..aec1ed8b 100644
--- a/torch_ecg/databases/physionet_databases/cinc2017.py
+++ b/torch_ecg/databases/physionet_databases/cinc2017.py
@@ -9,6 +9,7 @@
import numpy as np
import pandas as pd
+from numpy.typing import NDArray
from ...cfg import DEFAULTS
from ...utils.misc import add_docstring, get_record_list_recursive3
@@ -208,9 +209,8 @@ def load_ann(self, rec: Union[str, int], version: Optional[int] = None, ann_form
Version of the annotation file, by default the latest version.
ann_format : {"a", "f"}, optional
Format of returned annotation, by default "a".
-
- - "a" - abbreviation
- - "f" - full name
+ - "a", abbreviation
+ - "f", full name
Returns
-------
@@ -236,10 +236,10 @@ def load_ann(self, rec: Union[str, int], version: Optional[int] = None, ann_form
def plot(
self,
rec: Union[str, int],
- data: Optional[np.ndarray] = None,
+ data: Optional[NDArray] = None,
ann: Optional[str] = None,
ticks_granularity: int = 0,
- rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None,
+ rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None,
) -> None:
"""Plot the ECG signal of the record.
diff --git a/torch_ecg/databases/physionet_databases/cinc2018.py b/torch_ecg/databases/physionet_databases/cinc2018.py
index b0470907..a7460b89 100644
--- a/torch_ecg/databases/physionet_databases/cinc2018.py
+++ b/torch_ecg/databases/physionet_databases/cinc2018.py
@@ -6,10 +6,10 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
-import numpy as np
import pandas as pd
import scipy.signal as SS
import wfdb
+from numpy.typing import NDArray
from tqdm.auto import tqdm
from ...cfg import DEFAULTS
@@ -96,9 +96,8 @@ class CINC2018(PhysioNetDataBase, PSGDataBaseMixin):
Level of logging verbosity.
kwargs : dict, optional
Auxiliary keyword arguments, including:
-
- - `subset` : {"training", "test"}, default "training"
- The subset of the database to use.
+ - `subset` : {"training", "test"}, default "training".
+ The subset of the database to use.
"""
@@ -308,7 +307,7 @@ def load_psg_data(
physical: bool = True,
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
"""Load PSG data of the record.
Parameters
@@ -407,7 +406,7 @@ def load_data(
units: Union[str, type(None)] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
"""Load ECG data of the record.
Parameters
@@ -479,7 +478,7 @@ def load_ecg_data(
units: Union[str, type(None)] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
"""alias of `load_data`"""
return self.load_data(
rec=rec,
diff --git a/torch_ecg/databases/physionet_databases/cinc2020.py b/torch_ecg/databases/physionet_databases/cinc2020.py
index c4794630..f296bdbd 100644
--- a/torch_ecg/databases/physionet_databases/cinc2020.py
+++ b/torch_ecg/databases/physionet_databases/cinc2020.py
@@ -16,6 +16,7 @@
import pandas as pd
import scipy.signal as SS
import wfdb
+from numpy.typing import NDArray
from scipy.io import loadmat
from ...cfg import CFG, DEFAULTS
@@ -501,7 +502,7 @@ def load_data(
units: Literal["mV", "μV", "uV", None] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
"""Load physical (converted from digital) ECG data,
which is more understandable for humans;
or load digital signal directly.
@@ -858,10 +859,9 @@ def get_labels(
in the CINC2020 official phase.
fmt : str, default "s"
Format of labels, one of the following (case insensitive):
-
- - "a", abbreviations
- - "f", full names
- - "s", SNOMED CT Code
+ - "a", abbreviations
+ - "f", full names
+ - "s", SNOMED CT Code
normalize : bool, default True
If True, the labels will be transformed into their equivalents,
@@ -945,8 +945,8 @@ def get_subject_info(self, rec: Union[str, int], items: Optional[List[str]] = No
def plot(
self,
rec: Union[str, int],
- data: Optional[np.ndarray] = None,
- ann: Optional[Dict[str, np.ndarray]] = None,
+ data: Optional[NDArray] = None,
+ ann: Optional[Dict[str, NDArray]] = None,
ticks_granularity: int = 0,
leads: Optional[Union[str, List[str]]] = None,
same_range: bool = False,
@@ -1193,7 +1193,7 @@ def load_resampled_data(
rec: Union[str, int],
data_format: str = "channel_first",
siglen: Optional[int] = None,
- ) -> np.ndarray:
+ ) -> NDArray:
"""
Resample the data of `rec` to 500 Hz,
or load the resampled data at 500 Hz if the corresponding data file already exists
@@ -1248,7 +1248,7 @@ def load_resampled_data(
data = np.moveaxis(data, -1, -2)
return data
- def load_raw_data(self, rec: Union[str, int], backend: Literal["scipy", "wfdb"] = "scipy") -> np.ndarray:
+ def load_raw_data(self, rec: Union[str, int], backend: Literal["scipy", "wfdb"] = "scipy") -> NDArray:
"""Load raw data from corresponding files with no further processing,
in order to facilitate feeding data into the `run_12ECG_classifier` function.
@@ -1406,7 +1406,7 @@ def compute_all_metrics(classes: List[str], truth: Sequence, binary_pred: Sequen
)
-def _compute_accuracy(labels: np.ndarray, outputs: np.ndarray) -> float:
+def _compute_accuracy(labels: NDArray, outputs: NDArray) -> float:
"""Compute recording-wise accuracy.
Parameters
@@ -1434,7 +1434,7 @@ def _compute_accuracy(labels: np.ndarray, outputs: np.ndarray) -> float:
return float(num_correct_recordings) / float(num_recordings)
-def _compute_confusion_matrices(labels: np.ndarray, outputs: np.ndarray, normalize: bool = False) -> np.ndarray:
+def _compute_confusion_matrices(labels: NDArray, outputs: NDArray, normalize: bool = False) -> NDArray:
"""Compute confusion matrices.
Compute a binary confusion matrix for each class k:
@@ -1498,7 +1498,7 @@ def _compute_confusion_matrices(labels: np.ndarray, outputs: np.ndarray, normali
return A
-def _compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> float:
+def _compute_f_measure(labels: NDArray, outputs: NDArray) -> float:
"""Compute macro-averaged F1 score.
Parameters
@@ -1533,7 +1533,7 @@ def _compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> float:
return macro_f_measure
-def _compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real) -> Tuple[float, float]:
+def _compute_beta_measures(labels: NDArray, outputs: NDArray, beta: Real) -> Tuple[float, float]:
"""Compute F-beta and G-beta measures.
Parameters
@@ -1578,7 +1578,7 @@ def _compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real)
return macro_f_beta_measure, macro_g_beta_measure
-def _compute_auc(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, float]:
+def _compute_auc(labels: NDArray, outputs: NDArray) -> Tuple[float, float]:
"""Compute macro-averaged AUROC and macro-averaged AUPRC.
Parameters
@@ -1674,7 +1674,7 @@ def _compute_auc(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, float]
return macro_auroc, macro_auprc
-def _compute_modified_confusion_matrix(labels: np.ndarray, outputs: np.ndarray) -> np.ndarray:
+def _compute_modified_confusion_matrix(labels: NDArray, outputs: NDArray) -> NDArray:
"""
Compute a binary multi-class, multi-label confusion matrix,
where the rows are the labels and the columns are the outputs.
@@ -1712,9 +1712,9 @@ def _compute_modified_confusion_matrix(labels: np.ndarray, outputs: np.ndarray)
def compute_challenge_metric(
- weights: np.ndarray,
- labels: np.ndarray,
- outputs: np.ndarray,
+ weights: NDArray,
+ labels: NDArray,
+ outputs: NDArray,
classes: List[str],
normal_class: str,
) -> float:
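A hedged, self-contained sketch of the recording-wise accuracy described by `_compute_accuracy` above: a recording counts as correct only when its whole binary label vector equals the binary output vector. The array shapes are assumptions, not taken verbatim from the source.

import numpy as np
from numpy.typing import NDArray

def recording_accuracy(labels: NDArray, outputs: NDArray) -> float:
    """labels, outputs: (num_recordings, num_classes) binary arrays."""
    num_correct_recordings = int(np.all(labels == outputs, axis=1).sum())
    return float(num_correct_recordings) / float(labels.shape[0])

# toy usage
y_true = np.array([[1, 0, 1], [0, 1, 0]])
y_pred = np.array([[1, 0, 1], [0, 0, 0]])
print(recording_accuracy(y_true, y_pred))  # 0.5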
diff --git a/torch_ecg/databases/physionet_databases/cinc2021.py b/torch_ecg/databases/physionet_databases/cinc2021.py
index 22088a9b..51a57344 100644
--- a/torch_ecg/databases/physionet_databases/cinc2021.py
+++ b/torch_ecg/databases/physionet_databases/cinc2021.py
@@ -17,6 +17,7 @@
import pandas as pd
import scipy.signal as SS
import wfdb
+from numpy.typing import NDArray
from scipy.io import loadmat
from tqdm.auto import tqdm
@@ -88,8 +89,8 @@
Each recording is 10 seconds long with a sampling frequency of 500 Hz.
This tranche contains two subsets:
- - Chapman_Shaoxing: "JS00001" - "JS10646"
- - Ningbo: "JS10647" - "JS45551"
+ - Chapman_Shaoxing: "JS00001" - "JS10646"
+ - Ningbo: "JS10647" - "JS45551"
All files can be downloaded from [8]_ or [9]_.
@@ -115,6 +116,7 @@
... leads = ann["df_leads"]["lead_name"].values.tolist()
... if leads not in set_leads:
... set_leads.append(leads)
+
5. Challenge official website [1]_. Webpage of the database on PhysioNet [2]_.
""",
@@ -670,7 +672,7 @@ def load_data(
units: Literal["mV", "μV", "uV", "muV", None] = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
"""Load physical (converted from digital) ECG data.
Parameters
@@ -1049,10 +1051,9 @@ def get_labels(
in the CINC2021 official phase.
fmt : str, default "s"
Format of labels, one of the following (case insensitive):
-
- - "a", abbreviations
- - "f", full names
- - "s", SNOMED CT Code
+ - "a", abbreviations
+ - "f", full names
+ - "s", SNOMED CT Code
normalize : bool, default True
If True, the labels will be transformed into their equivalents,
@@ -1153,7 +1154,7 @@ def get_subject_info(self, rec: Union[str, int], items: Optional[List[str]] = No
def plot(
self,
rec: Union[str, int],
- data: Optional[np.ndarray] = None,
+ data: Optional[NDArray] = None,
ann: Optional[Dict[str, Sequence[str]]] = None,
ticks_granularity: int = 0,
leads: Optional[Union[str, Sequence[str]]] = None,
@@ -1402,7 +1403,7 @@ def load_resampled_data(
leads: Optional[Union[str, List[str]]] = None,
data_format: str = "channel_first",
siglen: Optional[int] = None,
- ) -> np.ndarray:
+ ) -> NDArray:
"""
Resample the data of `rec` to 500 Hz,
or load the resampled data at 500 Hz if the corresponding data file already exists
@@ -1472,7 +1473,7 @@ def load_resampled_data(
data = np.moveaxis(data, -1, -2)
return data
- def load_raw_data(self, rec: Union[str, int], backend: Literal["wfdb", "scipy"] = "scipy") -> np.ndarray:
+ def load_raw_data(self, rec: Union[str, int], backend: Literal["wfdb", "scipy"] = "scipy") -> NDArray:
"""Load raw data from corresponding files with no further processing.
This method facilitates feeding data into the `run_12ECG_classifier` function.
@@ -1700,7 +1701,7 @@ def database_info(self) -> DataBaseInfo:
def compute_all_metrics_detailed(
classes: List[str], truth: Sequence, binary_pred: Sequence, scalar_pred: Sequence
-) -> Tuple[Union[float, np.ndarray]]:
+) -> Tuple[Union[float, NDArray]]:
"""Compute detailed metrics for each class.
Parameters
@@ -1781,7 +1782,7 @@ def compute_all_metrics_detailed(
def compute_all_metrics(
classes: List[str], truth: Sequence, binary_pred: Sequence, scalar_pred: Sequence
-) -> Tuple[Union[float, np.ndarray]]:
+) -> Tuple[Union[float, NDArray]]:
"""Simplified version of :func:`compute_all_metrics_detailed`.
This function does not produce per-class scores.
@@ -1841,7 +1842,7 @@ def compute_all_metrics(
)
-def _compute_accuracy(labels: np.ndarray, outputs: np.ndarray) -> float:
+def _compute_accuracy(labels: NDArray, outputs: NDArray) -> float:
"""Compute recording-wise accuracy.
Parameters
@@ -1869,7 +1870,7 @@ def _compute_accuracy(labels: np.ndarray, outputs: np.ndarray) -> float:
return float(num_correct_recordings) / float(num_recordings)
-def _compute_confusion_matrices(labels: np.ndarray, outputs: np.ndarray, normalize: bool = False) -> np.ndarray:
+def _compute_confusion_matrices(labels: NDArray, outputs: NDArray, normalize: bool = False) -> NDArray:
"""Compute confusion matrices.
Compute a binary confusion matrix for each class k:
@@ -1933,7 +1934,7 @@ def _compute_confusion_matrices(labels: np.ndarray, outputs: np.ndarray, normali
return A
-def _compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, np.ndarray]:
+def _compute_f_measure(labels: NDArray, outputs: NDArray) -> Tuple[float, NDArray]:
"""Compute macro-averaged F1 score, and F1 score per class.
Parameters
@@ -1973,7 +1974,7 @@ def _compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float,
return macro_f_measure, f_measure
-def _compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real) -> Tuple[float, float]:
+def _compute_beta_measures(labels: NDArray, outputs: NDArray, beta: Real) -> Tuple[float, float]:
"""Compute F-beta and G-beta measures.
Parameters
@@ -2018,7 +2019,7 @@ def _compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real)
return macro_f_beta_measure, macro_g_beta_measure
-def _compute_auc(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, float, np.ndarray, np.ndarray]:
+def _compute_auc(labels: NDArray, outputs: NDArray) -> Tuple[float, float, NDArray, NDArray]:
"""Compute macro-averaged AUROC and macro-averaged AUPRC.
Parameters
@@ -2125,7 +2126,7 @@ def _compute_auc(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, float,
# Compute modified confusion matrix for multi-class, multi-label tasks.
-def _compute_modified_confusion_matrix(labels: np.ndarray, outputs: np.ndarray) -> np.ndarray:
+def _compute_modified_confusion_matrix(labels: NDArray, outputs: NDArray) -> NDArray:
"""
Compute a binary multi-class, multi-label confusion matrix,
where the rows are the labels and the columns are the outputs.
@@ -2165,9 +2166,9 @@ def _compute_modified_confusion_matrix(labels: np.ndarray, outputs: np.ndarray)
# Compute the evaluation metric for the Challenge.
def compute_challenge_metric(
- weights: np.ndarray,
- labels: np.ndarray,
- outputs: np.ndarray,
+ weights: NDArray,
+ labels: NDArray,
+ outputs: NDArray,
classes: List[str],
sinus_rhythm: str,
) -> float:
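A hedged sketch of the per-class and macro-averaged F1 returned by `_compute_f_measure` above, computed from per-class binary counts; classes with no support in either truth or prediction get NaN and are skipped by the macro average. These details are assumptions for illustration, not the source implementation.

from typing import Tuple

import numpy as np
from numpy.typing import NDArray

def f_measures(labels: NDArray, outputs: NDArray) -> Tuple[float, NDArray]:
    """labels, outputs: (num_recordings, num_classes) binary arrays."""
    tp = np.sum((labels == 1) & (outputs == 1), axis=0).astype(float)
    fp = np.sum((labels == 0) & (outputs == 1), axis=0).astype(float)
    fn = np.sum((labels == 1) & (outputs == 0), axis=0).astype(float)
    denom = 2 * tp + fp + fn
    f_per_class = np.where(denom > 0, 2 * tp / np.where(denom > 0, denom, 1.0), np.nan)
    macro_f_measure = float(np.nanmean(f_per_class))
    return macro_f_measure, f_per_class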
diff --git a/torch_ecg/databases/physionet_databases/ltafdb.py b/torch_ecg/databases/physionet_databases/ltafdb.py
index 44638d34..4d458ef7 100644
--- a/torch_ecg/databases/physionet_databases/ltafdb.py
+++ b/torch_ecg/databases/physionet_databases/ltafdb.py
@@ -9,6 +9,7 @@
import numpy as np
import wfdb
+from numpy.typing import NDArray
from ...cfg import CFG
from ...utils.misc import add_docstring
@@ -160,7 +161,7 @@ def load_data(
units: str = "mV",
fs: Optional[Real] = None,
return_fs: bool = False,
- ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]:
+ ) -> Union[NDArray, Tuple[NDArray, Real]]:
return super().load_data(rec, leads, sampfrom, sampto, data_format, units, fs, return_fs)
def load_ann(
@@ -234,7 +235,7 @@ def load_rhythm_ann(
sampto: Optional[int] = None,
rhythm_format: Literal["intervals", "mask"] = "intervals",
keep_original: bool = False,
- ) -> Union[Dict[str, list], np.ndarray]:
+ ) -> Union[Dict[str, list], NDArray]:
"""Load rhythm annotations of the record.
Rhythm annotations are stored in the `aux_note` attribute
@@ -323,7 +324,7 @@ def load_beat_ann(
sampto: Optional[int] = None,
beat_format: Literal["beat", "dict"] = "beat",
keep_original: bool = False,
- ) -> Union[Dict[str, np.ndarray], List[BeatAnn]]:
+ ) -> Union[Dict[str, NDArray], List[BeatAnn]]:
"""Load beat annotations of the record.
Beat annotations are stored in the `symbol` attribute
@@ -390,7 +391,7 @@ def load_rpeak_indices(
sampto: Optional[int] = None,
use_manual: bool = True,
keep_original: bool = False,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load rpeak indices of the record.
Rpeak indices, or equivalently qrs complex locations,
@@ -437,10 +438,10 @@ def load_rpeak_indices(
def plot(
self,
rec: Union[str, int],
- data: Optional[np.ndarray] = None,
- ann: Optional[Dict[str, np.ndarray]] = None,
- beat_ann: Optional[Dict[str, np.ndarray]] = None,
- rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None,
+ data: Optional[NDArray] = None,
+ ann: Optional[Dict[str, NDArray]] = None,
+ beat_ann: Optional[Dict[str, NDArray]] = None,
+ rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None,
ticks_granularity: int = 0,
leads: Optional[Union[int, List[int]]] = None,
sampfrom: Optional[int] = None,
diff --git a/torch_ecg/databases/physionet_databases/ludb.py b/torch_ecg/databases/physionet_databases/ludb.py
index 994e79d8..89fd86dc 100644
--- a/torch_ecg/databases/physionet_databases/ludb.py
+++ b/torch_ecg/databases/physionet_databases/ludb.py
@@ -2,13 +2,13 @@
import os
from copy import deepcopy
-from numbers import Real
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
import wfdb
+from numpy.typing import NDArray
from ...cfg import CFG, DEFAULTS
from ...utils import EAK
@@ -293,8 +293,6 @@ def __init__(
verbose=verbose,
**kwargs,
)
- if self.version == "1.0.0":
- self.logger.info("Version of LUDB 1.0.0 has bugs, make sure that version 1.0.1 or higher is used")
self.fs = 500
self.spacing = 1000 / self.fs
self.data_ext = "dat"
@@ -320,7 +318,7 @@ def __init__(
self._df_subject_info = None
self._ls_rec()
- def _ls_rec(self) -> None:
+ def _ls_rec(self) -> None: # type: ignore
"""Find all records in the database directory
and store them (path, metadata, etc.) in some private attributes.
"""
@@ -389,7 +387,7 @@ def get_absolute_path(self, rec: Union[str, int], extension: Optional[str] = Non
extension = f".{extension}"
return self.db_dir / "data" / f"{rec}{extension or ''}"
- def load_ann(
+ def load_ann( # type: ignore
self,
rec: Union[str, int],
leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
@@ -430,28 +428,28 @@ def load_ann(
df_lead_ann["onset"] = np.nan
df_lead_ann["offset"] = np.nan
for i, row in df_lead_ann.iterrows():
- peak_idx = peak_inds[i]
+ peak_idx = peak_inds[i] # type: ignore
if peak_idx == 0:
- df_lead_ann.loc[i, "onset"] = row["peak"]
+ df_lead_ann.loc[i, "onset"] = row["peak"] # type: ignore
if symbols[peak_idx + 1] == ")":
- df_lead_ann.loc[i, "offset"] = ann.sample[peak_idx + 1]
+ df_lead_ann.loc[i, "offset"] = ann.sample[peak_idx + 1] # type: ignore
else:
- df_lead_ann.loc[i, "offset"] = row["peak"]
+ df_lead_ann.loc[i, "offset"] = row["peak"] # type: ignore
elif peak_idx == len(symbols) - 1:
- df_lead_ann.loc[i, "offset"] = row["peak"]
+ df_lead_ann.loc[i, "offset"] = row["peak"] # type: ignore
if symbols[peak_idx - 1] == "(":
- df_lead_ann.loc[i, "onset"] = ann.sample[peak_idx - 1]
+ df_lead_ann.loc[i, "onset"] = ann.sample[peak_idx - 1] # type: ignore
else:
- df_lead_ann.loc[i, "onset"] = row["peak"]
+ df_lead_ann.loc[i, "onset"] = row["peak"] # type: ignore
else:
if symbols[peak_idx - 1] == "(":
- df_lead_ann.loc[i, "onset"] = ann.sample[peak_idx - 1]
+ df_lead_ann.loc[i, "onset"] = ann.sample[peak_idx - 1] # type: ignore
else:
- df_lead_ann.loc[i, "onset"] = row["peak"]
+ df_lead_ann.loc[i, "onset"] = row["peak"] # type: ignore
if symbols[peak_idx + 1] == ")":
- df_lead_ann.loc[i, "offset"] = ann.sample[peak_idx + 1]
+ df_lead_ann.loc[i, "offset"] = ann.sample[peak_idx + 1] # type: ignore
else:
- df_lead_ann.loc[i, "offset"] = row["peak"]
+ df_lead_ann.loc[i, "offset"] = row["peak"] # type: ignore
# df_lead_ann["onset"] = ann.sample[np.where(symbols=="(")[0]]
# df_lead_ann["offset"] = ann.sample[np.where(symbols==")")[0]]
@@ -501,7 +499,7 @@ def load_masks(
leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
mask_format: str = "channel_first",
class_map: Optional[Dict[str, int]] = None,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load the wave delineation in the form of masks.
Parameters
@@ -543,11 +541,11 @@ def load_masks(
def from_masks(
self,
- masks: np.ndarray,
+ masks: NDArray,
mask_format: str = "channel_first",
leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
class_map: Optional[Dict[str, int]] = None,
- fs: Optional[Real] = None,
+ fs: Optional[Union[float, int]] = None,
) -> Dict[str, List[ECGWaveForm]]:
"""Convert masks into lists of waveforms.
@@ -565,7 +563,7 @@ def from_masks(
class_map : dict, optional
Custom class map.
If not set, `self.class_map` will be used.
- fs : numbers.Real, optional
+ fs : float or int, optional
Sampling frequency of the signal corresponding to the `masks`,
If is None, `self.fs` will be used,
to compute `duration` of the ECG waveforms.
@@ -655,15 +653,15 @@ def _load_header(self, rec: Union[str, int]) -> dict:
header_dict["adc_gain"] = header_reader.adc_gain
header_dict["record_fmt"] = header_reader.fmt
try:
- header_dict["age"] = int([line for line in header_reader.comments if "" in line][0].split(": ")[-1])
+ header_dict["age"] = int([line for line in header_reader.comments if "" in line][0].split(": ")[-1]) # type: ignore
except Exception:
header_dict["age"] = np.nan
try:
- header_dict["sex"] = [line for line in header_reader.comments if "" in line][0].split(": ")[-1]
+ header_dict["sex"] = [line for line in header_reader.comments if "" in line][0].split(": ")[-1] # type: ignore
except Exception:
header_dict["sex"] = ""
- d_start = [idx for idx, line in enumerate(header_reader.comments) if "<diagnoses>" in line][0] + 1
- header_dict["diagnoses"] = header_reader.comments[d_start:]
+ d_start = [idx for idx, line in enumerate(header_reader.comments) if "<diagnoses>" in line][0] + 1 # type: ignore
+ header_dict["diagnoses"] = header_reader.comments[d_start:] # type: ignore
return header_dict
def load_subject_info(self, rec: Union[str, int], fields: Optional[Union[str, Sequence[str]]] = None) -> Union[dict, str]:
@@ -685,26 +683,26 @@ def load_subject_info(self, rec: Union[str, int], fields: Optional[Union[str, Se
"""
if isinstance(rec, int):
rec = self[rec]
- row = self._df_subject_info[self._df_subject_info.ID == rec]
+ row = self._df_subject_info[self._df_subject_info.ID == rec] # type: ignore
if row.empty:
return {}
row = row.iloc[0]
info = row.to_dict()
if fields is not None:
if isinstance(fields, str):
- assert fields in self._df_subject_info.columns, f"No field `{fields}`"
+ assert fields in self._df_subject_info.columns, f"No field `{fields}`" # type: ignore
info = info[fields]
else:
assert set(fields).issubset(
- set(self._df_subject_info.columns)
- ), f"No field(s) {set(fields).difference(set(self._df_subject_info.columns))}"
+ set(self._df_subject_info.columns) # type: ignore
+ ), f"No field(s) {set(fields).difference(set(self._df_subject_info.columns))}" # type: ignore
info = {k: v for k, v in info.items() if k in fields}
return info
def plot(
self,
rec: Union[str, int],
- data: Optional[np.ndarray] = None,
+ data: Optional[NDArray] = None,
ticks_granularity: int = 0,
leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None,
same_range: bool = False,
@@ -749,29 +747,30 @@ def plot(
Contributors: Jeethan, and WEN Hao
"""
+ import matplotlib.pyplot as plt
+ from matplotlib.ticker import MultipleLocator
+
+ MultipleLocator.MAXTICKS = 3000
+
if isinstance(rec, int):
rec = self[rec]
- if "plt" not in dir():
- import matplotlib.pyplot as plt
-
- plt.MultipleLocator.MAXTICKS = 3000
if data is not None:
assert leads is not None, "`leads` must be specified when `data` is given"
data = np.atleast_2d(data)
_leads = self._normalize_leads(leads)
- _lead_indices = [self.all_leads.index(ld) for ld in _leads]
+ _lead_indices = [self.all_leads.index(ld) for ld in _leads] # type: ignore
assert len(_leads) == data.shape[0], "number of leads must match data"
units = self._auto_infer_units(data)
- self.logger.info(f"input data is auto detected to have units in {units}")
+ self.logger.info(f"input data is auto detected to have units in {units}") # type: ignore
if units.lower() == "mv":
_data = 1000 * data
else:
_data = data
else:
_leads = self._normalize_leads(leads)
- _lead_indices = [self.all_leads.index(ld) for ld in _leads]
- _data = self.load_data(rec, data_format="channel_first", units="μV")[_lead_indices]
+ _lead_indices = [self.all_leads.index(ld) for ld in _leads] # type: ignore
+ _data = self.load_data(rec, data_format="channel_first", units="μV")[_lead_indices] # type: ignore
if same_range:
y_ranges = np.ones((_data.shape[0],)) * np.max(np.abs(_data)) + 100
@@ -828,12 +827,12 @@ def plot(
axes[idx].axhline(y=0, linestyle="-", linewidth="1.0", color="red")
# NOTE that `Locator` has default `MAXTICKS` equal to 1000
if ticks_granularity >= 1:
- axes[idx].xaxis.set_major_locator(plt.MultipleLocator(0.2))
- axes[idx].yaxis.set_major_locator(plt.MultipleLocator(500))
+ axes[idx].xaxis.set_major_locator(MultipleLocator(0.2))
+ axes[idx].yaxis.set_major_locator(MultipleLocator(500))
axes[idx].grid(which="major", linestyle="-", linewidth="0.5", color="red")
if ticks_granularity >= 2:
- axes[idx].xaxis.set_minor_locator(plt.MultipleLocator(0.04))
- axes[idx].yaxis.set_minor_locator(plt.MultipleLocator(100))
+ axes[idx].xaxis.set_minor_locator(MultipleLocator(0.04))
+ axes[idx].yaxis.set_minor_locator(MultipleLocator(100))
axes[idx].grid(which="minor", linestyle=":", linewidth="0.5", color="black")
# add extra info. to legend
# https://stackoverflow.com/questions/16826711/is-it-possible-to-add-a-string-as-a-legend-item-in-matplotlib
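A self-contained sketch of the ECG-paper style grid configured above, using the same locator steps as the diff (0.2 s / 500 μV major, 0.04 s / 100 μV minor); the toy sine trace is only there to make the script runnable.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

MultipleLocator.MAXTICKS = 3000  # default Locator.MAXTICKS (1000) is too low for long traces

fs = 500
t = np.arange(10 * fs) / fs
sig = 600 * np.sin(2 * np.pi * 1.2 * t)  # toy trace in μV

fig, ax = plt.subplots(figsize=(12, 3))
ax.plot(t, sig, color="black", linewidth=0.8)
ax.xaxis.set_major_locator(MultipleLocator(0.2))
ax.yaxis.set_major_locator(MultipleLocator(500))
ax.grid(which="major", linestyle="-", linewidth=0.5, color="red")
ax.xaxis.set_minor_locator(MultipleLocator(0.04))
ax.yaxis.set_minor_locator(MultipleLocator(100))
ax.grid(which="minor", linestyle=":", linewidth=0.5, color="black")
plt.show()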
@@ -868,10 +867,10 @@ def database_info(self) -> DataBaseInfo:
def compute_metrics(
- truth_masks: Sequence[np.ndarray],
- pred_masks: Sequence[np.ndarray],
+ truth_masks: Sequence[NDArray],
+ pred_masks: Sequence[NDArray],
class_map: Dict[str, int],
- fs: Real,
+ fs: Union[float, int],
mask_format: str = "channel_first",
) -> Dict[str, Dict[str, float]]:
"""Compute metrics for the wave delineation task.
@@ -892,7 +891,7 @@ def compute_metrics(
class_map : Dict[str, int]
Class map, mapping names to waves to numbers from 0 to n_classes-1,
the keys should contain "pwave", "qrs", "twave".
- fs : numbers.Real
+ fs : float or int
Sampling frequency of the signal corresponding to the masks,
used to compute the duration of each waveform,
hence the error and standard deviations of errors.
@@ -932,7 +931,7 @@ def compute_metrics(
def compute_metrics_waveform(
truth_waveforms: Sequence[Sequence[ECGWaveForm]],
pred_waveforms: Sequence[Sequence[ECGWaveForm]],
- fs: Real,
+ fs: Union[float, int],
) -> Dict[str, Dict[str, float]]:
"""
Compute the sensitivity, precision, f1_score, mean error
@@ -948,7 +947,7 @@ def compute_metrics_waveform(
The predictions corresponding to `truth_waveforms`.
Each element is a sequence of :class:`ECGWaveForm`
from the same sample.
- fs : numbers.Real
+ fs : float or int
Sampling frequency of the signal corresponding to the waveforms,
used to compute the duration of each waveform,
hence the error and standard deviations of errors.
@@ -1047,7 +1046,7 @@ def compute_metrics_waveform(
def _compute_metrics_waveform(
- truths: Sequence[ECGWaveForm], preds: Sequence[ECGWaveForm], fs: Real
+ truths: Sequence[ECGWaveForm], preds: Sequence[ECGWaveForm], fs: Union[float, int]
) -> Dict[str, Dict[str, float]]:
"""
compute the sensitivity, precision, f1_score, mean error
@@ -1060,7 +1059,7 @@ def _compute_metrics_waveform(
The ground truth.
preds : Sequence[ECGWaveForm]
The predictions corresponding to `truths`,
- fs : numbers.Real
+ fs : float or int
Sampling frequency of the signal corresponding to the waveforms,
used to compute the duration of each waveform,
hence the error and standard deviations of errors.
@@ -1132,17 +1131,17 @@ def _compute_metrics_waveform(
def _compute_metrics_base(
- truths: Sequence[Real], preds: Sequence[Real], fs: Real
+ truths: Sequence[Union[float, int]], preds: Sequence[Union[float, int]], fs: Union[float, int]
) -> Tuple[int, int, int, List[float], float, float, float, float, float]:
"""The base function for computing the metrics.
Parameters
----------
- truths : Sequence[Real]
+ truths : Sequence[Union[float, int]]
Ground truth of indices of corresponding critical points.
- preds : Sequence[Real]
+ preds : Sequence[Union[float, int]]
Predicted indices of corresponding critical points.
- fs : numbers.Real
+ fs : float or int
Sampling frequency of the signal corresponding to the critical points,
used to compute the duration of each waveform,
hence the error and standard deviations of errors.
@@ -1189,7 +1188,7 @@ def _compute_metrics_base(
mean_error = np.mean(errors) * 1000 / fs
standard_deviation = np.std(errors) * 1000 / fs
- return (
+ return ( # type: ignore
truth_positive,
false_negative,
false_positive,
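A hedged sketch of the millisecond conversion used for the error statistics above: delineation errors are differences of sample indices, scaled by 1000 / fs. The pairing of truths and predictions is simplified here to equal-length sequences, which is an assumption, not the library's matching logic.

from typing import Sequence, Tuple

import numpy as np

def error_stats_ms(truths: Sequence[float], preds: Sequence[float], fs: float) -> Tuple[float, float]:
    """Mean error and standard deviation of (pred - truth), in milliseconds."""
    errors = np.asarray(preds, dtype=float) - np.asarray(truths, dtype=float)
    mean_error = float(np.mean(errors) * 1000 / fs)
    standard_deviation = float(np.std(errors) * 1000 / fs)
    return mean_error, standard_deviation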
diff --git a/torch_ecg/databases/physionet_databases/mitdb.py b/torch_ecg/databases/physionet_databases/mitdb.py
index e52774b2..e01b8559 100644
--- a/torch_ecg/databases/physionet_databases/mitdb.py
+++ b/torch_ecg/databases/physionet_databases/mitdb.py
@@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
import wfdb
+from numpy.typing import NDArray
from tqdm.auto import tqdm
from ...cfg import CFG, DEFAULTS
@@ -325,7 +326,7 @@ def load_rhythm_ann(
rhythm_format: Literal["intervals", "mask"] = "intervals",
rhythm_types: Optional[Sequence[str]] = None,
keep_original: bool = False,
- ) -> Union[Dict[str, list], np.ndarray]:
+ ) -> Union[Dict[str, list], NDArray]:
"""Load rhythm annotations of the record.
Rhythm annotations are stored in the `aux_note` attribute
@@ -374,7 +375,7 @@ def load_beat_ann(
beat_format: Literal["beat", "dict"] = "beat",
beat_types: Optional[Sequence[str]] = None,
keep_original: bool = False,
- ) -> Union[Dict[str, np.ndarray], List[BeatAnn]]:
+ ) -> Union[Dict[str, NDArray], List[BeatAnn]]:
"""Load beat annotations of the record.
Beat annotations are stored in the `symbol` attribute
@@ -420,7 +421,7 @@ def load_rpeak_indices(
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
keep_original: bool = False,
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load rpeak indices of the record.
Rpeak indices, or equivalently qrs complex locations,
@@ -560,10 +561,10 @@ def rhythm_types_records(self) -> Dict[str, List[str]]:
def plot(
self,
rec: Union[str, int],
- data: Optional[np.ndarray] = None,
- ann: Optional[Dict[str, np.ndarray]] = None,
- beat_ann: Optional[Dict[str, np.ndarray]] = None,
- rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None,
+ data: Optional[NDArray] = None,
+ ann: Optional[Dict[str, NDArray]] = None,
+ beat_ann: Optional[Dict[str, NDArray]] = None,
+ rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None,
ticks_granularity: int = 0,
leads: Optional[Union[int, List[int]]] = None,
sampfrom: Optional[int] = None,
diff --git a/torch_ecg/databases/physionet_databases/ptb_xl.py b/torch_ecg/databases/physionet_databases/ptb_xl.py
index 0287ce22..541f4170 100644
--- a/torch_ecg/databases/physionet_databases/ptb_xl.py
+++ b/torch_ecg/databases/physionet_databases/ptb_xl.py
@@ -6,6 +6,7 @@
import numpy as np
import pandas as pd
import wfdb
+from numpy.typing import NDArray
from tqdm.auto import tqdm
from ...cfg import DEFAULTS
@@ -396,7 +397,7 @@ def _ls_rec(self) -> None:
# Fix potential bugs in the database
self._fix_bugs()
- def load_data(self, rec: Union[str, int], source: Literal["12sl", "unig"] = "12sl") -> np.ndarray:
+ def load_data(self, rec: Union[str, int], source: Literal["12sl", "unig"] = "12sl") -> NDArray:
"""Load the data of a record.
Parameters
@@ -426,7 +427,7 @@ def load_data(self, rec: Union[str, int], source: Literal["12sl", "unig"] = "12s
return wfdb.rdrecord(path).p_signal
@add_docstring(load_data.__doc__)
- def load_median_beats(self, rec: Union[str, int], source: str = "12sl") -> np.ndarray:
+ def load_median_beats(self, rec: Union[str, int], source: str = "12sl") -> NDArray:
"""alias of `load_data`."""
return self.load_data(rec, source)
diff --git a/torch_ecg/databases/physionet_databases/qtdb.py b/torch_ecg/databases/physionet_databases/qtdb.py
index 4185446c..d3795b81 100644
--- a/torch_ecg/databases/physionet_databases/qtdb.py
+++ b/torch_ecg/databases/physionet_databases/qtdb.py
@@ -5,6 +5,7 @@
import numpy as np
import wfdb
+from numpy.typing import NDArray
from ...cfg import CFG
from ...utils.misc import add_docstring
@@ -280,7 +281,7 @@ def load_wave_ann(
keep_original: bool = False,
ignore_beat_types: bool = True,
extension: str = "q1c",
- ) -> np.ndarray:
+ ) -> NDArray:
"""alias of self.load_ann"""
return self.load_ann(
rec,
@@ -299,7 +300,7 @@ def load_wave_masks(
mask_format: str = "channel_first",
class_map: Optional[Dict[str, int]] = None,
extension: str = "q1c",
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load the wave delineation in the form of masks.
Parameters
@@ -341,7 +342,7 @@ def load_rhythm_ann(
rhythm_types: Optional[Sequence[str]] = None,
keep_original: bool = False,
extension: Literal["atr", "man"] = "atr",
- ) -> Union[Dict[str, list], np.ndarray]:
+ ) -> Union[Dict[str, list], NDArray]:
"""Load rhythm annotations of a record.
Rhythm annotations are stored in the `aux_note` attribute
@@ -385,7 +386,7 @@ def load_beat_ann(
beat_types: Optional[Sequence[str]] = None,
keep_original: bool = False,
extension: Literal["atr", "man"] = "atr",
- ) -> Union[Dict[str, np.ndarray], List[BeatAnn]]:
+ ) -> Union[Dict[str, NDArray], List[BeatAnn]]:
"""Load beat annotations of the record.
Beat annotations are stored in the `symbol` attribute
@@ -453,7 +454,7 @@ def load_rpeak_indices(
sampto: Optional[int] = None,
keep_original: bool = False,
extension: Literal["atr", "man"] = "atr",
- ) -> np.ndarray:
+ ) -> NDArray:
"""Load rpeak indices of the record.
Rpeak indices, or equivalently qrs complex locations,
@@ -509,15 +510,15 @@ def load_rpeak_indices(
def plot(
self,
rec: Union[str, int],
- data: Optional[np.ndarray] = None,
+ data: Optional[NDArray] = None,
ticks_granularity: int = 0,
leads: Optional[Union[str, int, List[str], List[int]]] = None,
sampfrom: Optional[int] = None,
sampto: Optional[int] = None,
same_range: bool = False,
waves: Optional[ECGWaveForm] = None,
- beat_ann: Optional[Dict[str, np.ndarray]] = None,
- rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None,
+ beat_ann: Optional[Dict[str, NDArray]] = None,
+ rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None,
**kwargs: Any,
) -> None:
"""
diff --git a/torch_ecg/models/_nets.py b/torch_ecg/models/_nets.py
index 6e4ac52f..4a92825e 100644
--- a/torch_ecg/models/_nets.py
+++ b/torch_ecg/models/_nets.py
@@ -8,7 +8,7 @@
from itertools import repeat
from math import sqrt
from numbers import Real
-from typing import Any, List, Literal, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@@ -105,7 +105,7 @@
Activations.relu6 = nn.ReLU6
Activations.rrelu = nn.RReLU
Activations.leaky = nn.LeakyReLU
-Activations.leaky_relu = Activations.leaky
+Activations.leaky_relu = Activations.leaky # type: ignore
Activations.gelu = nn.GELU
Activations.silu = nn.SiLU
Activations.elu = nn.ELU
@@ -121,7 +121,7 @@
# Activations.linear = None
-def get_activation(act: Union[str, nn.Module, type(None)], kw_act: Optional[dict] = None) -> Optional[nn.Module]:
+def get_activation(act: Union[str, nn.Module, None], kw_act: Optional[dict] = None) -> Optional[nn.Module]:
"""Get the class or instance of the activation.
Parameters
@@ -159,7 +159,7 @@ def get_activation(act: Union[str, nn.Module, type(None)], kw_act: Optional[dict
else:
raise ValueError(f"activation `{act}` not supported")
if kw_act is None:
- return _act
+ return _act # type: ignore
return _act(**kw_act)
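A brief usage sketch of the signature above (import path as in this diff; it assumes string keys are looked up in the `Activations` mapping shown earlier, and that the `None` case yields `None`, as the `Optional` return annotation suggests):

    import torch.nn as nn
    from torch_ecg.models._nets import get_activation  # path as in this diff

    relu_cls = get_activation("relu")                              # no kw_act -> the class itself
    leaky = get_activation("leaky_relu", {"negative_slope": 0.1})  # kw_act -> an instantiated module
    nothing = get_activation(None)                                 # expected to pass None through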
@@ -167,15 +167,15 @@ def get_activation(act: Union[str, nn.Module, type(None)], kw_act: Optional[dict
# normalizations
Normalizations = CFG()
Normalizations.batch_norm = nn.BatchNorm1d
-Normalizations.batch_normalization = Normalizations.batch_norm
+Normalizations.batch_normalization = Normalizations.batch_norm # type: ignore
Normalizations.group_norm = nn.GroupNorm
-Normalizations.group_normalization = Normalizations.group_norm
+Normalizations.group_normalization = Normalizations.group_norm # type: ignore
Normalizations.layer_norm = nn.LayerNorm
-Normalizations.layer_normalization = Normalizations.layer_norm
+Normalizations.layer_normalization = Normalizations.layer_norm # type: ignore
Normalizations.instance_norm = nn.InstanceNorm1d
-Normalizations.instance_normalization = Normalizations.instance_norm
+Normalizations.instance_normalization = Normalizations.instance_norm # type: ignore
Normalizations.local_response_norm = nn.LocalResponseNorm
-Normalizations.local_response_normalization = Normalizations.local_response_norm
+Normalizations.local_response_normalization = Normalizations.local_response_norm # type: ignore
# other normalizations:
# weight normalization
# batch re-normalization
@@ -189,7 +189,7 @@ def get_activation(act: Union[str, nn.Module, type(None)], kw_act: Optional[dict
# problem: parameters of different normalizations are different
-def get_normalization(norm: Union[str, nn.Module, type(None)], kw_norm: Optional[dict] = None) -> Optional[nn.Module]:
+def get_normalization(norm: Union[str, nn.Module, None], kw_norm: Optional[dict] = None) -> Optional[nn.Module]:
"""Get the class or instance of the normalization.
Parameters
@@ -224,12 +224,12 @@ def get_normalization(norm: Union[str, nn.Module, type(None)], kw_norm: Optional
else:
raise ValueError(f"normalization `{norm}` not supported")
if kw_norm is None:
- return _norm
+ return _norm # type: ignore
if "num_channels" in get_required_args(_norm) and "num_features" in kw_norm:
# for some normalizations, the argument name is `num_channels`
# instead of `num_features`, e.g., `torch.nn.GroupNorm`
kw_norm["num_channels"] = kw_norm.pop("num_features")
- return _norm(**kw_norm)
+ return _norm(**kw_norm) # type: ignore
# ---------------------------------------------
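The `num_features` to `num_channels` remapping above accounts for a naming difference in PyTorch itself, for example:

    import torch.nn as nn

    # BatchNorm1d names its channel argument `num_features`,
    # while GroupNorm names it `num_channels`:
    bn = nn.BatchNorm1d(num_features=64)
    gn = nn.GroupNorm(num_groups=8, num_channels=64)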
@@ -440,7 +440,7 @@ def __init__(
groups: int = 1,
batch_norm: Union[bool, str, nn.Module] = True,
activation: Optional[Union[str, nn.Module]] = None,
- kernel_initializer: Optional[Union[str, callable]] = None,
+ kernel_initializer: Optional[Union[str, Callable]] = None,
bias: bool = True,
ordering: Optional[str] = None,
**kwargs: Any,
@@ -602,28 +602,28 @@ def __init__(
if bn_layer:
self.add_module("batch_norm", bn_layer)
if act_layer:
- self.add_module(act_name, act_layer)
+ self.add_module(act_name, act_layer) # type: ignore
elif self.__ordering in ["cab"]:
self.add_module("conv1d", conv_layer)
if self.__stride == 1 and self.__kernel_size % 2 == 0:
self.__asymmetric_padding = (1, 0)
self.add_module("zero_pad", ZeroPad1d(self.__asymmetric_padding))
if act_layer:
- self.add_module(act_name, act_layer)
+ self.add_module(act_name, act_layer) # type: ignore
if bn_layer:
self.add_module("batch_norm", bn_layer)
elif self.__ordering in ["bac", "bc"]:
if bn_layer:
self.add_module("batch_norm", bn_layer)
if act_layer:
- self.add_module(act_name, act_layer)
+ self.add_module(act_name, act_layer) # type: ignore
self.add_module("conv1d", conv_layer)
if self.__stride == 1 and self.__kernel_size % 2 == 0:
self.__asymmetric_padding = (1, 0)
self.add_module("zero_pad", ZeroPad1d(self.__asymmetric_padding))
elif self.__ordering in ["acb", "ac"]:
if act_layer:
- self.add_module(act_name, act_layer)
+ self.add_module(act_name, act_layer) # type: ignore
self.add_module("conv1d", conv_layer)
if self.__stride == 1 and self.__kernel_size % 2 == 0:
self.__asymmetric_padding = (1, 0)
@@ -638,7 +638,7 @@ def __init__(
self.__asymmetric_padding = (1, 0)
self.add_module("zero_pad", ZeroPad1d(self.__asymmetric_padding))
if act_layer:
- self.add_module(act_name, act_layer)
+ self.add_module(act_name, act_layer) # type: ignore
elif self.__ordering in ["c"]:
# only convolution
self.add_module("conv1d", conv_layer)
@@ -745,6 +745,8 @@ def compute_output_shape(
*output_shape[:-1],
output_shape[-1] + sum(self.__asymmetric_padding),
)
+ else:
+ output_shape = [None]
return output_shape
@add_docstring(_COMPUTE_RECEPTIVE_FIELD_DOC.replace("layer", "block"))
@@ -754,7 +756,7 @@ def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[
strides=self.__stride,
dilations=self.__dilation,
input_len=input_len,
- fs=fs,
+ fs=fs, # type: ignore
)
@property
@@ -883,7 +885,7 @@ def __init__(
if isinstance(dropouts, (Real, dict)):
_dropouts = list(repeat(dropouts, self.__num_convs))
else:
- _dropouts = list(dropouts)
+ _dropouts = list(dropouts) # type: ignore
assert (
len(_dropouts) == self.__num_convs
), f"`dropouts` must be a real number or dict or sequence of real numbers of length {self.__num_convs}"
@@ -896,7 +898,7 @@ def __init__(
len(_dilations) == self.__num_convs
), f"`dilations` must be of type int or sequence of int of length {self.__num_convs}"
- __ordering = self.config.ordering.lower()
+ __ordering = self.config.ordering.lower() # type: ignore
if "a" in __ordering and __ordering.index("a") < __ordering.index("c"):
in_activation = out_activation
out_activation = True
@@ -905,7 +907,7 @@ def __init__(
conv_in_channels = self.__in_channels
for idx, (oc, ks, sd, dl, dp) in enumerate(zip(self.__out_channels, kernel_sizes, strides, _dilations, _dropouts)):
- activation = self.config.activation
+ activation = self.config.activation # type: ignore
if idx == 0 and not in_activation:
activation = None
if idx == self.__num_convs - 1 and not out_activation:
@@ -921,15 +923,15 @@ def __init__(
groups=groups,
norm=self.config.get("norm", self.config.get("batch_norm")),
activation=activation,
- kw_activation=self.config.kw_activation,
- kernel_initializer=self.config.kernel_initializer,
- kw_initializer=self.config.kw_initializer,
- ordering=self.config.ordering,
- conv_type=self.config.conv_type,
- width_multiplier=self.config.width_multiplier,
+ kw_activation=self.config.kw_activation, # type: ignore
+ kernel_initializer=self.config.kernel_initializer, # type: ignore
+ kw_initializer=self.config.kw_initializer, # type: ignore
+ ordering=self.config.ordering, # type: ignore
+ conv_type=self.config.conv_type, # type: ignore
+ width_multiplier=self.config.width_multiplier, # type: ignore
),
)
- conv_in_channels = int(oc * self.config.width_multiplier)
+ conv_in_channels = int(oc * self.config.width_multiplier) # type: ignore
if isinstance(dp, dict):
if dp["type"] == "1d" and dp["p"] > 0:
self.add_module(
@@ -956,7 +958,7 @@ def compute_output_shape(
if hasattr(module, "__name__") and module.__name__ == Conv_Bn_Activation.__name__:
output_shape = module.compute_output_shape(_seq_len, batch_size)
_, _, _seq_len = output_shape
- return output_shape
+ return output_shape # type: ignore
@add_docstring(_COMPUTE_RECEPTIVE_FIELD_DOC.replace("layer", "block"))
def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[Real] = None) -> Union[int, float]:
@@ -971,7 +973,7 @@ def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[
strides=strides,
dilations=dilations,
input_len=input_len,
- fs=fs,
+ fs=fs, # type: ignore
)
@property
@@ -1052,7 +1054,7 @@ def __init__(
if isinstance(dropouts, (Real, dict)):
_dropouts = list(repeat(dropouts, self.__num_branches))
else:
- _dropouts = list(dropouts)
+ _dropouts = list(dropouts) # type: ignore
assert (
len(_dropouts) == self.__num_branches
), f"`dropouts` must be a real number or dict or sequence of real numbers of length {self.__num_branches}"
@@ -1203,7 +1205,7 @@ def __init__(
padding: Optional[int] = None,
dilation: int = 1,
groups: int = 1,
- kernel_initializer: Optional[Union[str, callable]] = None,
+ kernel_initializer: Optional[Union[str, Callable]] = None,
bias: bool = True,
**kwargs: Any,
) -> None:
@@ -1320,7 +1322,7 @@ def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[
strides=[self.__stride, 1],
dilations=[self.__dilation, 1],
input_len=input_len,
- fs=fs,
+ fs=fs, # type: ignore
)
@property
@@ -1658,7 +1660,7 @@ def compute_output_shape(
stride=self.__down_scale,
padding=self.__padding,
)[-1]
- output_shape = (batch_size, self.__out_channels, out_seq_len)
+ output_shape = (batch_size, self.__out_channels, out_seq_len) # type: ignore
return output_shape
@property
@@ -1691,7 +1693,7 @@ def __init__(self, padding: Union[int, Sequence[int]]) -> None:
and all([i >= 0 for i in padding])
), "`padding` must be non-negative int or a 2-sequence of non-negative int"
padding = list(repeat(padding, 2)) if isinstance(padding, int) else padding
- super().__init__(padding, 0.0)
+ super().__init__(padding, 0.0) # type: ignore
def compute_output_shape(
self,
@@ -1854,7 +1856,7 @@ def _get_pad_layer(self) -> nn.Module:
PadLayer = ZeroPad1d
else:
raise NotImplementedError(f"Padding type of `{self.__pad_type}` is not implemented")
- return PadLayer(self.__pad_sizes)
+ return PadLayer(self.__pad_sizes) # type: ignore
@add_docstring(_COMPUTE_OUTPUT_SHAPE_DOC)
def compute_output_shape(
@@ -2196,10 +2198,11 @@ def forward(
module.flatten_parameters()
output, _hx = module(output, _hx)
if self.return_sequences:
- final_output = output # seq_len, batch_size, n_direction*hidden_size
+ final_output = output # seq_len, batch_size, n_direction * hidden_size
else:
- final_output = output[-1, ...] # batch_size, n_direction*hidden_size
- return final_output
+ # batch_size, n_direction * hidden_size
+ final_output = output[-1, ...] # type: ignore
+ return final_output # type: ignore
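The two return modes above can be illustrated with a plain `nn.LSTM` using the same (seq_len, batch_size, channels) layout (sizes assumed):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=12, hidden_size=32, bidirectional=True)  # batch_first=False
    x = torch.randn(50, 4, 12)        # (seq_len, batch_size, channels)
    output, _ = lstm(x)
    print(output.shape)               # torch.Size([50, 4, 64]) -- return_sequences=True
    print(output[-1, ...].shape)      # torch.Size([4, 64])     -- return_sequences=False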
@add_docstring(_COMPUTE_OUTPUT_SHAPE_DOC)
def compute_output_shape(
@@ -2244,12 +2247,12 @@ def __init__(self, in_channels: int, bias: bool = True, initializer: str = "glor
self.init(self.W)
self.u = Parameter(torch.Tensor(in_channels))
- Initializers.constant(self.u, 1 / in_channels)
+ Initializers.constant(self.u, 1 / in_channels) # type: ignore
# self.init(self.u)
if self.bias:
self.b = Parameter(torch.Tensor(in_channels))
- Initializers.zeros(self.b)
+ Initializers.zeros(self.b) # type: ignore
# Initializers["zeros"](self.b)
else:
self.register_parameter("b", None)
@@ -2346,7 +2349,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tens
return output
-@add_docstring(nn.MultiheadAttention.__doc__, "append")
+@add_docstring(nn.MultiheadAttention.__doc__, "append") # type: ignore
class MultiHeadAttention(nn.MultiheadAttention, SizeMixin):
"""Multi-head attention.
@@ -2617,7 +2620,7 @@ def forward(self, input: Tensor) -> Tensor:
input = input.permute(0, 2, 1) # -> (batch_size, seq_len, n_channels
scores = self.dropout(input)
scores = self.mid_linear(scores) # -> (batch_size, seq_len, n_channels)
- scores = self.activation(scores) # -> (batch_size, seq_len, n_channels)
+ scores = self.activation(scores) # -> (batch_size, seq_len, n_channels) # type: ignore
scores = self.contraction(scores) # -> (batch_size, seq_len, 1)
scores = scores.squeeze(-1) # -> (batch_size, seq_len)
scores = self.softmax(scores) # -> (batch_size, seq_len)
@@ -2779,9 +2782,9 @@ def __init__(
self.__dropouts = [dropouts]
else:
self.__dropouts = dropouts
- assert len(self.__dropouts) == self.__num_layers, (
+ assert len(self.__dropouts) == self.__num_layers, ( # type: ignore
f"`out_channels` indicates `{self.__num_layers}` linear layers, "
- f"while `dropouts` indicates `{len(self.__dropouts)}`"
+ f"while `dropouts` indicates `{len(self.__dropouts)}`" # type: ignore
)
self.__skip_last_activation = kwargs.get("skip_last_activation", False)
@@ -2803,10 +2806,10 @@ def __init__(
f"act_{idx}",
act_layer(**kw_activation),
)
- if self.__dropouts[idx] > 0:
+ if self.__dropouts[idx] > 0: # type: ignore
self.add_module(
f"dropout_{idx}",
- nn.Dropout(self.__dropouts[idx]),
+ nn.Dropout(self.__dropouts[idx]), # type: ignore
)
lin_in_channels = self.__out_channels[idx]
@@ -3091,10 +3094,10 @@ def __init__(self, in_channels: int, reduction: int = 16, **config) -> None:
SeqLin(
in_channels=self.__in_channels,
out_channels=[self.__mid_channels, self.__out_channels],
- activation=self.config.activation,
- kw_activation=self.config.kw_activation,
- bias=self.config.bias,
- dropouts=self.config.dropouts,
+ activation=self.config.activation, # type: ignore
+ kw_activation=self.config.kw_activation, # type: ignore
+ bias=self.config.bias, # type: ignore
+ dropouts=self.config.dropouts, # type: ignore
skip_last_activation=True,
),
nn.Sigmoid(),
@@ -3287,7 +3290,7 @@ def spatial_pool(self, x: Tensor) -> Tensor:
context = context.squeeze(1) # --> (batch_size, n_channels, 1)
elif self.__pooling_type == "avg":
context = self.avg_pool(x) # --> (batch_size, n_channels, 1)
- return context
+ return context # type: ignore
def forward(self, input: Tensor) -> Tensor:
"""
@@ -3416,7 +3419,7 @@ def __init__(
self.__gate_channels // reduction,
self.__gate_channels,
],
- activation=activation,
+ activation=activation, # type: ignore
skip_last_activation=True,
)
# self.channel_gate_act = nn.Sigmoid()
@@ -3466,7 +3469,7 @@ def _fwd_channel_gate(self, input: Tensor) -> Tensor:
else:
channel_att_sum = channel_att_sum + channel_att_raw
# scale = torch.sigmoid(channel_att_sum)
- scale = self.channel_gate_act(channel_att_sum)
+ scale = self.channel_gate_act(channel_att_sum) # type: ignore
output = scale.unsqueeze(-1) * input
return output
@@ -3656,8 +3659,8 @@ def __repr__(self) -> str:
def neg_log_likelihood(
self,
emissions: Tensor,
- tags: torch.LongTensor,
- mask: Optional[torch.ByteTensor] = None,
+ tags: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
reduction: str = "sum",
) -> Tensor:
"""
@@ -3668,9 +3671,9 @@ def neg_log_likelihood(
----------
emissions: Tensor,
emission score tensor of shape (seq_len, batch_size, num_tags)
- tags: torch.LongTensor,
+ tags: torch.Tensor,
sequence of tags tensor of shape (seq_len, batch_size)
- mask: torch.ByteTensor,
+ mask: torch.Tensor,
mask tensor of shape (seq_len, batch_size)
reduction: str, default "sum",
specifies the reduction to apply to the output:
@@ -3720,7 +3723,7 @@ def neg_log_likelihood(
nll = nll.sum() / mask.float().sum()
return nll
- def forward(self, emissions: Tensor, mask: Optional[torch.ByteTensor] = None) -> Tensor:
+ def forward(self, emissions: Tensor, mask: Optional[torch.Tensor] = None) -> Tensor:
"""
Find the most likely tag sequence using Viterbi algorithm.
@@ -3730,7 +3733,7 @@ def forward(self, emissions: Tensor, mask: Optional[torch.ByteTensor] = None) ->
emission score tensor,
of shape (seq_len, batch_size, num_tags) if batch_first is False,
of shape (batch_size, seq_len, num_tags) if batch_first is True.
- mask: torch.ByteTensor
+ mask: torch.Tensor
mask tensor of shape (seq_len, batch_size) if batch_first is False,
of shape (batch_size, seq_len) if batch_first is True.
@@ -3756,8 +3759,8 @@ def forward(self, emissions: Tensor, mask: Optional[torch.ByteTensor] = None) ->
def _validate(
self,
emissions: Tensor,
- tags: Optional[torch.LongTensor] = None,
- mask: Optional[torch.ByteTensor] = None,
+ tags: Optional[torch.Tensor] = None,
+ mask: Optional[torch.Tensor] = None,
) -> None:
"""Check validity of input :class:`~torch.Tensor`."""
if emissions.dim() != 3:
@@ -3783,7 +3786,7 @@ def _validate(
if not no_empty_seq and not no_empty_seq_bf:
raise ValueError("mask of the first timestep must all be on")
- def _compute_score(self, emissions: Tensor, tags: torch.LongTensor, mask: torch.ByteTensor) -> Tensor:
+ def _compute_score(self, emissions: Tensor, tags: torch.Tensor, mask: torch.Tensor) -> Tensor:
"""
# emissions: (seq_len, batch_size, num_tags)
# tags: (seq_len, batch_size)
@@ -3827,7 +3830,7 @@ def _compute_score(self, emissions: Tensor, tags: torch.LongTensor, mask: torch.
return score
- def _compute_normalizer(self, emissions: torch.Tensor, mask: torch.ByteTensor) -> Tensor:
+ def _compute_normalizer(self, emissions: torch.Tensor, mask: torch.Tensor) -> Tensor:
"""
# emissions: (seq_len, batch_size, num_tags)
# mask: (seq_len, batch_size)
@@ -3884,7 +3887,7 @@ def _compute_normalizer(self, emissions: torch.Tensor, mask: torch.ByteTensor) -
# shape: (batch_size,)
return torch.logsumexp(score, dim=1)
- def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor) -> List[List[int]]:
+ def _viterbi_decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> List[List[int]]:
"""
# emissions: (seq_len, batch_size, num_tags)
# mask: (seq_len, batch_size)
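Although the annotations above are relaxed to plain `torch.Tensor`, the inputs the CRF works with are still integer tags and a boolean (or byte) mask. Illustrative shapes with batch_first=False, as in the docstrings above (sizes assumed):

    import torch

    seq_len, batch_size, num_tags = 10, 4, 5
    emissions = torch.randn(seq_len, batch_size, num_tags)
    tags = torch.randint(0, num_tags, (seq_len, batch_size), dtype=torch.long)
    mask = torch.ones(seq_len, batch_size, dtype=torch.bool)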
@@ -4251,18 +4254,18 @@ def __init__(
self.wordvec_proj = nn.Identity()
self.decoder.duplicate_pooling = Parameter(Tensor(decoder_embedding, 1))
self.decoder.duplicate_pooling_bias = Parameter(Tensor(1))
- self.decoder.duplicate_factor = 1
+ self.decoder.duplicate_factor = 1 # type: ignore
else:
# group fully-connected
- self.decoder.out_channels = out_channels
- self.decoder.duplicate_factor = int(out_channels / embed_len_decoder + 0.999)
+ self.decoder.out_channels = out_channels # type: ignore
+ self.decoder.duplicate_factor = int(out_channels / embed_len_decoder + 0.999) # type: ignore
self.decoder.duplicate_pooling = Parameter(
Tensor(embed_len_decoder, decoder_embedding, self.decoder.duplicate_factor)
)
self.decoder.duplicate_pooling_bias = Parameter(Tensor(out_channels))
nn.init.xavier_normal_(self.decoder.duplicate_pooling)
nn.init.constant_(self.decoder.duplicate_pooling_bias, 0)
- self.decoder.group_fc = _GroupFC(embed_len_decoder)
+ self.decoder.group_fc = _GroupFC(embed_len_decoder) # type: ignore
self.train_wordvecs = None
self.test_wordvecs = None
@@ -4427,16 +4430,16 @@ def make_attention_layer(in_channels: int, **config: dict) -> nn.Module:
"""
key = "name" if "name" in config else "type"
assert key in config, "config must contain key 'name' or 'type'"
- name = config[key].lower()
+ name = config[key].lower() # type: ignore
config.pop(key)
if name in ["se"]:
- return SEBlock(in_channels, **config)
+ return SEBlock(in_channels, **config) # type: ignore
elif name in ["gc"]:
- return GlobalContextBlock(in_channels, **config)
+ return GlobalContextBlock(in_channels, **config) # type: ignore
elif name in ["nl", "non-local", "nonlocal", "non_local"]:
- return NonLocalBlock(in_channels, **config)
+ return NonLocalBlock(in_channels, **config) # type: ignore
elif name in ["cbam"]:
- return CBAMBlock(in_channels, **config)
+ return CBAMBlock(in_channels, **config) # type: ignore
elif name in ["ca"]:
# NOT IMPLEMENTED
return CoordAttention(in_channels, **config)
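A usage sketch of the factory above (import path as in this diff; the channel count and extra keyword arguments are assumed values):

    from torch_ecg.models._nets import make_attention_layer

    se = make_attention_layer(64, name="se", reduction=16)  # squeeze-and-excitation block
    gc = make_attention_layer(64, type="gc")                 # "type" is accepted in place of "name"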
diff --git a/torch_ecg/models/cnn/densenet.py b/torch_ecg/models/cnn/densenet.py
index 333ff894..b774602c 100644
--- a/torch_ecg/models/cnn/densenet.py
+++ b/torch_ecg/models/cnn/densenet.py
@@ -585,41 +585,40 @@ class DenseNet(nn.Sequential, SizeMixin, CitationMixin):
config : dict
Other hyper-parameters of the Module, ref. corresponding config file.
Keyword arguments that must be set are as follows:
-
- - num_layers: sequence of int,
- number of building block layers of each dense (macro) block
- - init_num_filters: sequence of int,
- number of filters of the first convolutional layer
- - init_filter_length: sequence of int,
- filter length (kernel size) of the first convolutional layer
- - init_conv_stride: int,
- stride of the first convolutional layer
- - init_pool_size: int,
- pooling kernel size of the first pooling layer
- - init_pool_stride: int,
- pooling stride of the first pooling layer
- - growth_rates: int or sequence of int or sequence of sequences of int,
- growth rates of the building blocks,
- with granularity to the whole network, or to each dense (macro) block,
- or to each building block
- - filter_lengths: int or sequence of int or sequence of sequences of int,
- filter length(s) (kernel size(s)) of the convolutions,
- with granularity to the whole network, or to each macro block,
- or to each building block
- - subsample_lengths: int or sequence of int,
- subsampling length(s) (ratio(s)) of the transition blocks
- - compression: float,
- compression factor of the transition blocks
- - bn_size: int,
- bottleneck base width, used only when building block is :class:`DenseBottleNeck`
- - dropouts: float or dict,
- dropout ratio of each building block
- - groups: int,
- connection pattern (of channels) of the inputs and outputs
- - block: dict,
- other parameters that can be set for the building blocks
-
- For a full list of configurable parameters, ref. corr. config file
+ - num_layers: sequence of int,
+ number of building block layers of each dense (macro) block.
+ - init_num_filters: sequence of int,
+ number of filters of the first convolutional layer.
+ - init_filter_length: sequence of int,
+ filter length (kernel size) of the first convolutional layer.
+ - init_conv_stride: int,
+ stride of the first convolutional layer.
+ - init_pool_size: int,
+ pooling kernel size of the first pooling layer.
+ - init_pool_stride: int,
+ pooling stride of the first pooling layer.
+ - growth_rates: int or sequence of int or sequence of sequences of int,
+ growth rates of the building blocks,
+ with granularity to the whole network, or to each dense (macro) block,
+ or to each building block.
+ - filter_lengths: int or sequence of int or sequence of sequences of int,
+ filter length(s) (kernel size(s)) of the convolutions,
+ with granularity to the whole network, or to each macro block,
+ or to each building block.
+ - subsample_lengths: int or sequence of int,
+ subsampling length(s) (ratio(s)) of the transition blocks.
+ - compression: float,
+ compression factor of the transition blocks.
+ - bn_size: int,
+ bottleneck base width, used only when building block is :class:`DenseBottleNeck`.
+ - dropouts: float or dict,
+ dropout ratio of each building block.
+ - groups: int,
+ connection pattern (of channels) of the inputs and outputs.
+ - block: dict,
+ other parameters that can be set for the building blocks.
+
+ For a full list of configurable parameters, ref. corr. config file.
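For illustration only, a config dict carrying the keyword arguments enumerated above might look as follows (all values are assumed, not the defaults from the library's config file):

    densenet_cfg = dict(
        num_layers=[6, 12, 24, 16],
        init_num_filters=[64],
        init_filter_length=[25],
        init_conv_stride=2,
        init_pool_size=3,
        init_pool_stride=2,
        growth_rates=16,
        filter_lengths=15,
        subsample_lengths=2,
        compression=0.5,
        bn_size=4,
        dropouts=0.0,
        groups=1,
        block={},
    )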
NOTE
----
diff --git a/torch_ecg/models/cnn/mobilenet.py b/torch_ecg/models/cnn/mobilenet.py
index b50a270e..86fee0a9 100644
--- a/torch_ecg/models/cnn/mobilenet.py
+++ b/torch_ecg/models/cnn/mobilenet.py
@@ -214,19 +214,18 @@ class MobileNetV1(nn.Sequential, SizeMixin, CitationMixin):
Other hyper-parameters of the Module, ref. corresponding config file.
key word arguments that have to be set in 3 sub-dict,
namely in "entry_flow", "middle_flow", and "exit_flow", including
-
- - out_channels: int,
- number of channels of the output.
- - kernel_size: int,
- kernel size of down sampling.
- If not specified, defaults to `down_scale`.
- - groups: int,
- connection pattern (of channels) of the inputs and outputs.
- - padding: int,
- zero-padding added to both sides of the input.
- - batch_norm: bool or Module,
- batch normalization, the Module itself
- or (if is bool) whether or not to use :class:`torch.nn.BatchNorm1d`.
+ - out_channels: int,
+ number of channels of the output.
+ - kernel_size: int,
+ kernel size of down sampling.
+ If not specified, defaults to `down_scale`.
+ - groups: int,
+ connection pattern (of channels) of the inputs and outputs.
+ - padding: int,
+ zero-padding added to both sides of the input.
+ - batch_norm: bool or Module,
+ batch normalization, the Module itself
+ or (if is bool) whether or not to use :class:`torch.nn.BatchNorm1d`.
References
----------
@@ -457,9 +456,9 @@ class InvertedResidual(nn.Module, SizeMixin):
If is None, no attention mechanism is used.
Keys:
- - "name": str, can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc.
- - "pos": int, position of the attention mechanism,
- other keys are specific to the attention mechanism.
+ - "name": str, can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc.
+ - "pos": int, position of the attention mechanism,
+ other keys are specific to the attention mechanism.
"""
@@ -669,36 +668,36 @@ class MobileNetV2(nn.Sequential, SizeMixin, CitationMixin):
- stem: CFG,
config of the stem block, with the following keys:
- - num_filters: int or Sequence[int],
- number of filters in the first convolutional layer(s).
- - filter_lengths: int or Sequence[int],
- filter lengths (kernel sizes) in the first convolutional layer(s).
- - subsample_lengths: int or Sequence[int],
- subsample lengths (strides) in the first convolutional layer(s).
+ - num_filters: int or Sequence[int],
+ number of filters in the first convolutional layer(s).
+ - filter_lengths: int or Sequence[int],
+ filter lengths (kernel sizes) in the first convolutional layer(s).
+ - subsample_lengths: int or Sequence[int],
+ subsample lengths (strides) in the first convolutional layer(s).
- inv_res: CFG,
Config of the inverted residual blocks, with the following keys:
- - expansions: Sequence[int],
- expansion ratios of the inverted residual blocks.
- - out_channels: Sequence[int],
- number of output channels in each block.
- - n_blocks: Sequence[int],
- number of inverted residual blocks.
- - strides: Sequence[int],
- strides of the inverted residual blocks.
- - filter_lengths: Sequence[int],
- filter lengths (kernel sizes) in each block.
+ - expansions: Sequence[int],
+ expansion ratios of the inverted residual blocks.
+ - out_channels: Sequence[int],
+ number of output channels in each block.
+ - n_blocks: Sequence[int],
+ number of inverted residual blocks.
+ - strides: Sequence[int],
+ strides of the inverted residual blocks.
+ - filter_lengths: Sequence[int],
+ filter lengths (kernel sizes) in each block.
- exit_flow: CFG,
Config of the exit flow blocks, with the following keys:
- - num_filters: int or Sequence[int],
- number of filters in the final convolutional layer(s).
- - filter_lengths: int or Sequence[int],
- filter lengths (kernel sizes) in the final convolutional layer(s).
- - subsample_lengths: int or Sequence[int],
- subsample lengths (strides) in the final convolutional layer(s).
+ - num_filters: int or Sequence[int],
+ number of filters in the final convolutional layer(s).
+ - filter_lengths: int or Sequence[int],
+ filter lengths (kernel sizes) in the final convolutional layer(s).
+ - subsample_lengths: int or Sequence[int],
+ subsample lengths (strides) in the final convolutional layer(s).
References
----------
@@ -1000,12 +999,12 @@ class MobileNetV3_STEM(nn.Sequential, SizeMixin):
config : CFG, optional
Config of the stem block, with the following items:
- - num_filters: int or Sequence[int],
- number of filters in the first convolutional layer(s).
- - filter_lengths: int or Sequence[int],
- filter lengths (kernel sizes) in the first convolutional layer(s).
- - subsample_lengths: int or Sequence[int],
- subsample lengths (strides) in the first convolutional layer(s).
+ - num_filters: int or Sequence[int],
+ number of filters in the first convolutional layer(s).
+ - filter_lengths: int or Sequence[int],
+ filter lengths (kernel sizes) in the first convolutional layer(s).
+ - subsample_lengths: int or Sequence[int],
+ subsample lengths (strides) in the first convolutional layer(s).
"""
@@ -1088,64 +1087,64 @@ class MobileNetV3(nn.Sequential, SizeMixin, CitationMixin):
Other hyper-parameters of the Module, ref. corresponding config file.
Keyword arguments that must be set:
- - groups: int,
- number of groups in the convolutional layer(s) other than depthwise convolutions.
- - norm: bool or str or Module,
- normalization layer.
- - bias: bool,
- whether to use bias in the convolutional layer(s).
- - width_multiplier: float,
- multiplier of the number of output channels of the pointwise convolution.
- - stem: CFG,
- config of the stem block, with the following keys:
-
- - num_filters: int or Sequence[int],
- number of filters in the first convolutional layer(s).
- - filter_lengths: int or Sequence[int],
- filter lengths (kernel sizes) in the first convolutional layer(s).
- - subsample_lengths: int or Sequence[int],
- subsample lengths (strides) in the first convolutional layer(s).
-
- - inv_res: CFG,
- config of the inverted residual blocks, with the following keys:
-
- - in_channels: Sequence[int],
- number of input channels.
- - n_blocks: Sequence[int],
- number of inverted residual blocks.
- - expansions: sequence of floats or sequence of sequence of floats,
- expansion ratios of the inverted residual blocks.
- - filter_lengths: sequence of ints or sequence of sequence of ints,
- filter length of the depthwise convolution in the inverted residual blocks.
- - stride: sequence of ints or sequence of sequence of ints, optional,
- stride of the depthwise convolution in the inverted residual blocks,
- defaults to ``[2] + [1] * (n_blocks - 1)``.
- - groups: int, default 1,
- number of groups in the expansion and pointwise convolution
- in the inverted residual blocks.
- - dilation: sequence of ints or sequence of sequence of ints, optional,
- dilation of the depthwise convolution in the inverted residual blocks.
- - batch_norm: bool or str or nn.Module, default True,
- normalization layer to use, defaults to batch normalization.
- - activation: str or nn.Module or sequence of str or torch.nn.Module,
- activation function to use.
- - width_multiplier: float or sequence of floats, default 1.0,
- width multiplier of the inverted residual blocks.
- - out_channels: sequence of ints or sequence of Sequence[int], optional,
- number of output channels of the inverted residual blocks,
- defaults to ``2 * in_channels``.
- - attn: sequence of CFG or sequence of sequence of CFG, optional,
- config of attention layer to use, defaults to None.
-
- - exit_flow: CFG,
- config of the exit flow blocks, with the following keys:
-
- - num_filters: int or Sequence[int],
- number of filters in the final convolutional layer(s).
- - filter_lengths: int or Sequence[int],
- filter lengths (kernel sizes) in the final convolutional layer(s).
- - subsample_lengths: int or Sequence[int],
- subsample lengths (strides) in the final convolutional layer(s).
+ - groups: int,
+ number of groups in the convolutional layer(s) other than depthwise convolutions.
+ - norm: bool or str or Module,
+ normalization layer.
+ - bias: bool,
+ whether to use bias in the convolutional layer(s).
+ - width_multiplier: float,
+ multiplier of the number of output channels of the pointwise convolution.
+ - stem: CFG,
+ config of the stem block, with the following keys:
+
+ - num_filters: int or Sequence[int],
+ number of filters in the first convolutional layer(s).
+ - filter_lengths: int or Sequence[int],
+ filter lengths (kernel sizes) in the first convolutional layer(s).
+ - subsample_lengths: int or Sequence[int],
+ subsample lengths (strides) in the first convolutional layer(s).
+
+ - inv_res: CFG,
+ config of the inverted residual blocks, with the following keys:
+
+ - in_channels: Sequence[int],
+ number of input channels.
+ - n_blocks: Sequence[int],
+ number of inverted residual blocks.
+ - expansions: sequence of floats or sequence of sequence of floats,
+ expansion ratios of the inverted residual blocks.
+ - filter_lengths: sequence of ints or sequence of sequence of ints,
+ filter length of the depthwise convolution in the inverted residual blocks.
+ - stride: sequence of ints or sequence of sequence of ints, optional,
+ stride of the depthwise convolution in the inverted residual blocks,
+ defaults to ``[2] + [1] * (n_blocks - 1)``.
+ - groups: int, default 1,
+ number of groups in the expansion and pointwise convolution
+ in the inverted residual blocks.
+ - dilation: sequence of ints or sequence of sequence of ints, optional,
+ dilation of the depthwise convolution in the inverted residual blocks.
+ - batch_norm: bool or str or nn.Module, default True,
+ normalization layer to use, defaults to batch normalization.
+ - activation: str or nn.Module or sequence of str or torch.nn.Module,
+ activation function to use.
+ - width_multiplier: float or sequence of floats, default 1.0,
+ width multiplier of the inverted residual blocks.
+ - out_channels: sequence of ints or sequence of Sequence[int], optional,
+ number of output channels of the inverted residual blocks,
+ defaults to ``2 * in_channels``.
+ - attn: sequence of CFG or sequence of sequence of CFG, optional,
+ config of attention layer to use, defaults to None.
+
+ - exit_flow: CFG,
+ config of the exit flow blocks, with the following keys:
+
+ - num_filters: int or Sequence[int],
+ number of filters in the final convolutional layer(s).
+ - filter_lengths: int or Sequence[int],
+ filter lengths (kernel sizes) in the final convolutional layer(s).
+ - subsample_lengths: int or Sequence[int],
+ subsample lengths (strides) in the final convolutional layer(s).
References
----------
diff --git a/torch_ecg/models/cnn/multi_scopic.py b/torch_ecg/models/cnn/multi_scopic.py
index af2c81b0..fb5b61bf 100644
--- a/torch_ecg/models/cnn/multi_scopic.py
+++ b/torch_ecg/models/cnn/multi_scopic.py
@@ -368,28 +368,28 @@ class MultiScopicCNN(nn.Module, SizeMixin, CitationMixin):
Other hyper-parameters of the Module, ref. corr. config file.
Key word arguments that must be set:
- - scopes: sequence of sequences of sequences of :obj:`int`,
- scopes (in terms of dilation) of each convolution.
- - num_filters: sequence of sequences (of :obj:`int` or of sequences of :obj:`int`),
- number of filters of the convolutional layers,
- with granularity to each block of each branch,
- or to each convolution of each block of each branch.
- - filter_lengths: sequence of sequences (of :obj:`int` or of sequences of :obj:`int`),
- filter length(s) (kernel size(s)) of the convolutions,
- with granularity to each block of each branch,
- or to each convolution of each block of each branch.
- - subsample_lengths: sequence of :obj:`int` or sequence of sequences of :obj:`int`,
- subsampling length(s) (ratio(s)) of all blocks,
- with granularity to each branch or to each block of each branch,
- each subsamples after the last convolution of each block.
- - dropouts: sequence of :obj:`int` or sequence of sequences of :obj:`int`,
- dropout rates of all blocks,
- with granularity to each branch or to each block of each branch,
- each dropouts at the last of each block.
- - groups: :obj:`int`,
- connection pattern (of channels) of the inputs and outputs.
- - block: :obj:`dict`,
- other parameters that can be set for the building blocks.
+ - scopes: sequence of sequences of sequences of :obj:`int`,
+ scopes (in terms of dilation) of each convolution.
+ - num_filters: sequence of sequences (of :obj:`int` or of sequences of :obj:`int`),
+ number of filters of the convolutional layers,
+ with granularity to each block of each branch,
+ or to each convolution of each block of each branch.
+ - filter_lengths: sequence of sequences (of :obj:`int` or of sequences of :obj:`int`),
+ filter length(s) (kernel size(s)) of the convolutions,
+ with granularity to each block of each branch,
+ or to each convolution of each block of each branch.
+ - subsample_lengths: sequence of :obj:`int` or sequence of sequences of :obj:`int`,
+ subsampling length(s) (ratio(s)) of all blocks,
+ with granularity to each branch or to each block of each branch,
+ each subsamples after the last convolution of each block.
+ - dropouts: sequence of :obj:`int` or sequence of sequences of :obj:`int`,
+ dropout rates of all blocks,
+ with granularity to each branch or to each block of each branch,
+ each dropouts at the last of each block.
+ - groups: :obj:`int`,
+ connection pattern (of channels) of the inputs and outputs.
+ - block: :obj:`dict`,
+ other parameters that can be set for the building blocks.
For a full list of configurable parameters, ref. corr. config file.
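To make the nested granularities above concrete, an assumed (non-default) example for a network with 3 branches, each with 2 blocks of 2 convolutions:

    scopes = [
        [[1, 1], [2, 2]],   # branch 0: dilation of each conv, per block
        [[1, 2], [2, 4]],   # branch 1
        [[1, 4], [4, 8]],   # branch 2
    ]
    num_filters = [[32, 64], [32, 64], [32, 64]]    # per block of each branch
    filter_lengths = [[11, 7], [11, 7], [11, 7]]    # per block of each branch
    subsample_lengths = [2, 2, 2]                   # per branch
    dropouts = [0.0, 0.0, 0.0]                      # per branch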
diff --git a/torch_ecg/models/cnn/regnet.py b/torch_ecg/models/cnn/regnet.py
index 9469c21a..b0e3795b 100644
--- a/torch_ecg/models/cnn/regnet.py
+++ b/torch_ecg/models/cnn/regnet.py
@@ -42,49 +42,50 @@ class AnyStage(nn.Sequential, SizeMixin):
Index of the stage in the whole :class:`RegNet`.
block_config: dict,
(optional) configs for the blocks, including
- - block: str or torch.nn.Module,
- the block class, can be one of
- "bottleneck", "bottle_neck", :class:`ResNetBottleNeck`, etc.
- - expansion: int,
- the expansion factor for the bottleneck block.
- - increase_channels_method: str,
- the method to increase the number of channels,
- can be one of {"conv", "zero_padding"}.
- - subsample_mode: str,
- the mode of subsampling, can be one of
- {:class:`DownSample`.__MODES__},
- - activation: str or torch.nn.Module,
- the activation function, can be one of
- {:class:`Activations`}.
- - kw_activation: dict,
- keyword arguments for the activation function.
- - kernel_initializer: str,
- the kernel initializer, can be one of
- {:class:`Initializers`}.
- - kw_initializer: dict,
- keyword arguments for the kernel initializer.
- - bias: bool,
- whether to use bias in the convolution.
- - dilation: int,
- the dilation factor for the convolution.
- - base_width: int,
- number of filters per group for the neck conv layer
- usually number of filters of the initial conv layer
- of the whole :class:`RegNet`.
- - base_groups: int,
- pattern of connections between inputs and outputs of
- conv layers at the two ends, which should divide `groups`.
- - base_filter_length: int,
- lengths (sizes) of the filter kernels for conv layers at the two ends.
- - attn: dict,
- attention mechanism for the neck conv layer.
- If is None, no attention mechanism is used.
- If is not None, it should be a dict with the following items:
-
- - name: str, can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc.
- - pos: int, position of the attention mechanism.
-
- Other keys are specific to the attention mechanism.
+
+ - block: str or torch.nn.Module,
+ the block class, can be one of
+ "bottleneck", "bottle_neck", :class:`ResNetBottleNeck`, etc.
+ - expansion: int,
+ the expansion factor for the bottleneck block.
+ - increase_channels_method: str,
+ the method to increase the number of channels,
+ can be one of {"conv", "zero_padding"}.
+ - subsample_mode: str,
+ the mode of subsampling, can be one of
+ {:class:`DownSample`.__MODES__},
+ - activation: str or torch.nn.Module,
+ the activation function, can be one of
+ {:class:`Activations`}.
+ - kw_activation: dict,
+ keyword arguments for the activation function.
+ - kernel_initializer: str,
+ the kernel initializer, can be one of
+ {:class:`Initializers`}.
+ - kw_initializer: dict,
+ keyword arguments for the kernel initializer.
+ - bias: bool,
+ whether to use bias in the convolution.
+ - dilation: int,
+ the dilation factor for the convolution.
+ - base_width: int,
+ number of filters per group for the neck conv layer
+ usually number of filters of the initial conv layer
+ of the whole :class:`RegNet`.
+ - base_groups: int,
+ pattern of connections between inputs and outputs of
+ conv layers at the two ends, which should divide `groups`.
+ - base_filter_length: int,
+ lengths (sizes) of the filter kernels for conv layers at the two ends.
+ - attn: dict,
+ attention mechanism for the neck conv layer.
+ If is None, no attention mechanism is used.
+ If is not None, it should be a dict with the following items:
+
+ - name: str, can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc.
+ - pos: int, position of the attention mechanism.
+
+ Other keys are specific to the attention mechanism.
"""
@@ -283,31 +284,31 @@ class RegNet(nn.Sequential, SizeMixin, CitationMixin):
Hyper-parameters of the Module, ref. corr. config file.
Keyword arguments that must be set:
- - filter_lengths: int or sequence of int,
- filter length(s) (kernel size(s)) of the convolutions,
- with granularity to the whole network, to each stage.
- - subsample_lengths: int or sequence of int,
- subsampling length(s) (ratio(s)) of all blocks,
- with granularity to the whole network, to each stage.
- - tot_blocks: int,
- the total number of building blocks.
- - w_a, w_0, w_m: float,
- the parameters for the widths generating function.
- - group_widths: int or sequence of int,
- the number of channels in each group,
- with granularity to the whole network, to each stage.
- - num_blocks: sequence of int, optional,
- the number of blocks in each stage,
- if not given, will be computed from tot_blocks
- and `w_a`, `w_0`, `w_m`.
- - num_filters: int or sequence of int, optional,
- the number of filters in each stage.
- If not given, will be computed from tot_blocks
- and `w_a`, `w_0`, `w_m`.
- - stem: dict,
- the config of the input stem.
- - block: dict,
- other parameters that can be set for the building blocks.
+ - filter_lengths: int or sequence of int,
+ filter length(s) (kernel size(s)) of the convolutions,
+ with granularity to the whole network, to each stage.
+ - subsample_lengths: int or sequence of int,
+ subsampling length(s) (ratio(s)) of all blocks,
+ with granularity to the whole network, to each stage.
+ - tot_blocks: int,
+ the total number of building blocks.
+ - w_a, w_0, w_m: float,
+ the parameters for the widths generating function.
+ - group_widths: int or sequence of int,
+ the number of channels in each group,
+ with granularity to the whole network, to each stage.
+ - num_blocks: sequence of int, optional,
+ the number of blocks in each stage,
+ if not given, will be computed from tot_blocks
+ and `w_a`, `w_0`, `w_m`.
+ - num_filters: int or sequence of int, optional,
+ the number of filters in each stage.
+ If not given, will be computed from tot_blocks
+ and `w_a`, `w_0`, `w_m`.
+ - stem: dict,
+ the config of the input stem.
+ - block: dict,
+ other parameters that can be set for the building blocks.
"""
diff --git a/torch_ecg/models/cnn/resnet.py b/torch_ecg/models/cnn/resnet.py
index 551216cc..39122731 100644
--- a/torch_ecg/models/cnn/resnet.py
+++ b/torch_ecg/models/cnn/resnet.py
@@ -63,10 +63,10 @@ class ResNetBasicBlock(nn.Module, SizeMixin):
If is None, no attention mechanism is used.
keys:
- - name: str,
- can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc.
- - pos: int,
- position of the attention mechanism.
+ - name: str,
+ can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc.
+ - pos: int,
+ position of the attention mechanism.
Other keys are specific to the attention mechanism.
config : dict
@@ -297,10 +297,10 @@ class ResNetBottleNeck(nn.Module, SizeMixin):
If is None, no attention mechanism is used.
Keys:
- - name: str,
- can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc.
- - pos: int,
- position of the attention mechanism.
+ - name: str,
+ can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc.
+ - pos: int,
+ position of the attention mechanism.
Other keys are specific to the attention mechanism.
config : dict
@@ -643,25 +643,25 @@ class ResNet(nn.Sequential, SizeMixin, CitationMixin):
Hyper-parameters of the Module, ref. corr. config file.
keyword arguments that must be set:
- - bias: bool,
- if True, each convolution will have a bias term.
- - num_blocks: sequence of int,
- number of building blocks in each macro block.
- - filter_lengths: int or sequence of int or sequence of sequences of int,
- filter length(s) (kernel size(s)) of the convolutions,
- with granularity to the whole network, to each macro block,
- or to each building block.
- - subsample_lengths: int or sequence of int or sequence of sequences of int,
- subsampling length(s) (ratio(s)) of all blocks,
- with granularity to the whole network, to each macro block,
- or to each building block,
- the former 2 subsample at the first building block.
- - groups: int,
- connection pattern (of channels) of the inputs and outputs.
- - stem: dict,
- other parameters that can be set for the input stem.
- - block: dict,
- other parameters that can be set for the building blocks.
+ - bias: bool,
+ if True, each convolution will have a bias term.
+ - num_blocks: sequence of int,
+ number of building blocks in each macro block.
+ - filter_lengths: int or sequence of int or sequence of sequences of int,
+ filter length(s) (kernel size(s)) of the convolutions,
+ with granularity to the whole network, to each macro block,
+ or to each building block.
+ - subsample_lengths: int or sequence of int or sequence of sequences of int,
+ subsampling length(s) (ratio(s)) of all blocks,
+ with granularity to the whole network, to each macro block,
+ or to each building block,
+ the former 2 subsample at the first building block.
+ - groups: int,
+ connection pattern (of channels) of the inputs and outputs.
+ - stem: dict,
+ other parameters that can be set for the input stem.
+ - block: dict,
+ other parameters that can be set for the building blocks.
For a full list of configurable parameters, ref. corr. config file.
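The granularity of `filter_lengths` (and likewise `subsample_lengths`) described above admits three shapes; assumed values for a ResNet with num_blocks = [2, 2, 2, 2]:

    filter_lengths_net = 15                                          # one value for the whole network
    filter_lengths_macro = [15, 15, 11, 11]                          # one per macro block
    filter_lengths_block = [[15, 15], [15, 15], [11, 11], [11, 11]]  # one per building block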
diff --git a/torch_ecg/models/cnn/vgg.py b/torch_ecg/models/cnn/vgg.py
index e0fc8ea0..cad9e128 100644
--- a/torch_ecg/models/cnn/vgg.py
+++ b/torch_ecg/models/cnn/vgg.py
@@ -140,14 +140,14 @@ class VGG16(nn.Sequential, SizeMixin, CitationMixin):
and more for :class:`VGGBlock`.
Key word arguments that have to be set:
- - num_convs: sequence of int,
- number of convolutional layers for each :class:`VGGBlock`.
- - num_filters: sequence of int,
- number of filters for each :class:`VGGBlock`.
- - groups: int,
- connection pattern (of channels) of the inputs and outputs.
- - block: dict,
- other parameters that can be set for :class:`VGGBlock`.
+ - num_convs: sequence of int,
+ number of convolutional layers for each :class:`VGGBlock`.
+ - num_filters: sequence of int,
+ number of filters for each :class:`VGGBlock`.
+ - groups: int,
+ connection pattern (of channels) of the inputs and outputs.
+ - block: dict,
+ other parameters that can be set for :class:`VGGBlock`.
For a full list of configurable parameters, ref. corr. config file.
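Assumed (non-default) values for the required keys listed above, following the classic five-block VGG16 layout:

    vgg16_cfg = dict(
        num_convs=[2, 2, 3, 3, 3],
        num_filters=[64, 128, 256, 512, 512],
        groups=1,
        block={},
    )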
diff --git a/torch_ecg/models/ecg_crnn.py b/torch_ecg/models/ecg_crnn.py
index 9ef14340..85627d4c 100644
--- a/torch_ecg/models/ecg_crnn.py
+++ b/torch_ecg/models/ecg_crnn.py
@@ -2,14 +2,15 @@
C(R)NN structure models, for classifying ECG arrhythmias, and other tasks.
"""
+import os
import warnings
from copy import deepcopy
from typing import Any, List, Optional, Sequence, Tuple, Union
-import numpy as np
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
+from numpy.typing import NDArray
from torch import Tensor, nn
from ..cfg import CFG
@@ -79,8 +80,8 @@ def __init__(
warnings.warn("No config is provided, using default config.", RuntimeWarning)
self.config.update(deepcopy(config) or {})
- cnn_choice = self.config.cnn.name.lower()
- cnn_config = self.config.cnn[self.config.cnn.name]
+ cnn_choice = self.config.cnn.name.lower() # type: ignore
+ cnn_config = self.config.cnn[self.config.cnn.name] # type: ignore
if "resnet" in cnn_choice or "resnext" in cnn_choice:
self.cnn = ResNet(self.n_leads, **cnn_config)
elif "regnet" in cnn_choice:
@@ -106,114 +107,118 @@ def __init__(
raise NotImplementedError(f"CNN \042{cnn_choice}\042 not implemented yet")
rnn_input_size = self.cnn.compute_output_shape(None, None)[1]
- if self.config.rnn.name.lower() == "none":
+ if self.config.rnn.name.lower() == "none": # type: ignore
self.rnn_in_rearrange = Rearrange("batch_size channels seq_len -> seq_len batch_size channels")
self.rnn = nn.Identity()
self.__rnn_seqlen_dim = 0
self.rnn_out_rearrange = nn.Identity()
attn_input_size = rnn_input_size
- elif self.config.rnn.name.lower() == "lstm":
+ elif self.config.rnn.name.lower() == "lstm": # type: ignore
self.rnn_in_rearrange = Rearrange("batch_size channels seq_len -> seq_len batch_size channels")
self.rnn = StackedLSTM(
- input_size=rnn_input_size,
- hidden_sizes=self.config.rnn.lstm.hidden_sizes,
- bias=self.config.rnn.lstm.bias,
- dropouts=self.config.rnn.lstm.dropouts,
- bidirectional=self.config.rnn.lstm.bidirectional,
- return_sequences=self.config.rnn.lstm.retseq,
+ input_size=rnn_input_size, # type: ignore
+ hidden_sizes=self.config.rnn.lstm.hidden_sizes, # type: ignore
+ bias=self.config.rnn.lstm.bias, # type: ignore
+ dropouts=self.config.rnn.lstm.dropouts, # type: ignore
+ bidirectional=self.config.rnn.lstm.bidirectional, # type: ignore
+ return_sequences=self.config.rnn.lstm.retseq, # type: ignore
)
self.__rnn_seqlen_dim = 0
self.rnn_out_rearrange = nn.Identity()
attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
- elif self.config.rnn.name.lower() == "linear":
+ elif self.config.rnn.name.lower() == "linear": # type: ignore
# abuse of notation, to put before the global attention module
self.rnn_in_rearrange = Rearrange("batch_size channels seq_len -> batch_size seq_len channels")
self.rnn = MLP(
- in_channels=rnn_input_size,
- out_channels=self.config.rnn.linear.out_channels,
- activation=self.config.rnn.linear.activation,
- bias=self.config.rnn.linear.bias,
- dropouts=self.config.rnn.linear.dropouts,
+ in_channels=rnn_input_size, # type: ignore
+ out_channels=self.config.rnn.linear.out_channels, # type: ignore
+ activation=self.config.rnn.linear.activation, # type: ignore
+ bias=self.config.rnn.linear.bias, # type: ignore
+ dropouts=self.config.rnn.linear.dropouts, # type: ignore
)
self.__rnn_seqlen_dim = 1
self.rnn_out_rearrange = Rearrange("batch_size seq_len channels -> seq_len batch_size channels")
attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
else:
- raise NotImplementedError(f"RNN \042{self.config.rnn.name}\042 not implemented yet")
+ raise NotImplementedError(f"RNN \042{self.config.rnn.name}\042 not implemented yet") # type: ignore
# attention
- if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq:
+ if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: # type: ignore
self.attn_in_rearrange = nn.Identity()
self.attn = nn.Identity()
self.__attn_seqlen_dim = 0
self.attn_out_rearrange = nn.Identity()
clf_input_size = attn_input_size
- if self.config.attn.name.lower() != "none":
+ if self.config.attn.name.lower() != "none": # type: ignore
warnings.warn(
- f"since `retseq` of rnn is False, hence attention `{self.config.attn.name}` is ignored",
+ f"since `retseq` of rnn is False, hence attention `{self.config.attn.name}` is ignored", # type: ignore
RuntimeWarning,
)
- elif self.config.attn.name.lower() == "none":
+ elif self.config.attn.name.lower() == "none": # type: ignore
self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
self.attn = nn.Identity()
self.__attn_seqlen_dim = -1
self.attn_out_rearrange = nn.Identity()
clf_input_size = attn_input_size
- elif self.config.attn.name.lower() == "nl": # non_local
+ elif self.config.attn.name.lower() == "nl": # type: ignore
+ # non_local
self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
self.attn = NonLocalBlock(
- in_channels=attn_input_size,
- filter_lengths=self.config.attn.nl.filter_lengths,
- subsample_length=self.config.attn.nl.subsample_length,
- batch_norm=self.config.attn.nl.batch_norm,
+ in_channels=attn_input_size, # type: ignore
+ filter_lengths=self.config.attn.nl.filter_lengths, # type: ignore
+ subsample_length=self.config.attn.nl.subsample_length, # type: ignore
+ batch_norm=self.config.attn.nl.batch_norm, # type: ignore
)
self.__attn_seqlen_dim = -1
self.attn_out_rearrange = nn.Identity()
clf_input_size = self.attn.compute_output_shape(None, None)[1]
- elif self.config.attn.name.lower() == "se": # squeeze_exitation
+ elif self.config.attn.name.lower() == "se": # type: ignore
+            # squeeze_excitation
self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
self.attn = SEBlock(
- in_channels=attn_input_size,
- reduction=self.config.attn.se.reduction,
- activation=self.config.attn.se.activation,
- kw_activation=self.config.attn.se.kw_activation,
- bias=self.config.attn.se.bias,
+ in_channels=attn_input_size, # type: ignore
+ reduction=self.config.attn.se.reduction, # type: ignore
+ activation=self.config.attn.se.activation, # type: ignore
+ kw_activation=self.config.attn.se.kw_activation, # type: ignore
+ bias=self.config.attn.se.bias, # type: ignore
)
self.__attn_seqlen_dim = -1
self.attn_out_rearrange = nn.Identity()
clf_input_size = self.attn.compute_output_shape(None, None)[1]
- elif self.config.attn.name.lower() == "gc": # global_context
+ elif self.config.attn.name.lower() == "gc": # type: ignore
+ # global_context
self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
self.attn = GlobalContextBlock(
- in_channels=attn_input_size,
- ratio=self.config.attn.gc.ratio,
- reduction=self.config.attn.gc.reduction,
- pooling_type=self.config.attn.gc.pooling_type,
- fusion_types=self.config.attn.gc.fusion_types,
+ in_channels=attn_input_size, # type: ignore
+ ratio=self.config.attn.gc.ratio, # type: ignore
+ reduction=self.config.attn.gc.reduction, # type: ignore
+ pooling_type=self.config.attn.gc.pooling_type, # type: ignore
+ fusion_types=self.config.attn.gc.fusion_types, # type: ignore
)
self.__attn_seqlen_dim = -1
self.attn_out_rearrange = nn.Identity()
clf_input_size = self.attn.compute_output_shape(None, None)[1]
- elif self.config.attn.name.lower() == "sa": # self_attention
+ elif self.config.attn.name.lower() == "sa": # type: ignore
+ # self_attention
# NOTE: this branch NOT tested
self.attn_in_rearrange = nn.Identity()
- self.attn = SelfAttention(
+ self.attn = SelfAttention( # type: ignore
embed_dim=attn_input_size,
- num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")),
- dropout=self.config.attn.sa.dropout,
- bias=self.config.attn.sa.bias,
+ num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")), # type: ignore
+ dropout=self.config.attn.sa.dropout, # type: ignore
+ bias=self.config.attn.sa.bias, # type: ignore
)
self.__attn_seqlen_dim = 0
self.attn_out_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len")
clf_input_size = self.attn.compute_output_shape(None, None)[-1]
- elif self.config.attn.name.lower() == "transformer":
+ elif self.config.attn.name.lower() == "transformer": # type: ignore
self.attn = Transformer(
- input_size=attn_input_size,
- hidden_size=self.config.attn.transformer.hidden_size,
- num_layers=self.config.attn.transformer.num_layers,
- num_heads=self.config.attn.transformer.num_heads,
- dropout=self.config.attn.transformer.dropout,
- activation=self.config.attn.transformer.activation,
+ input_size=attn_input_size, # type: ignore
+ hidden_size=self.config.attn.transformer.hidden_size, # type: ignore
+ num_layers=self.config.attn.transformer.num_layers, # type: ignore
+ num_heads=self.config.attn.transformer.num_heads, # type: ignore
+ dropout=self.config.attn.transformer.dropout, # type: ignore
+ activation=self.config.attn.transformer.activation, # type: ignore
)
if self.attn.batch_first:
self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size seq_len channels")
@@ -225,44 +230,44 @@ def __init__(
self.__attn_seqlen_dim = 0
clf_input_size = self.attn.compute_output_shape(None, None)[-1]
else:
- raise NotImplementedError(f"Attention \042{self.config.attn.name}\042 not implemented yet")
+ raise NotImplementedError(f"Attention \042{self.config.attn.name}\042 not implemented yet") # type: ignore
# global pooling
- if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq:
+ if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: # type: ignore
self.pool = nn.Identity()
- if self.config.global_pool.lower() != "none":
+ if self.config.global_pool.lower() != "none": # type: ignore
warnings.warn(
- f"since `retseq` of rnn is False, hence global pooling `{self.config.global_pool}` is ignored",
+                    f"since `retseq` of rnn is False, global pooling `{self.config.global_pool}` is ignored",  # type: ignore
RuntimeWarning,
)
self.pool_rearrange = nn.Identity()
self.__clf_input_seq = False
- elif self.config.global_pool.lower() == "max":
- self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False)
- clf_input_size *= self.config.global_pool_size
+ elif self.config.global_pool.lower() == "max": # type: ignore
+ self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False) # type: ignore
+ clf_input_size *= self.config.global_pool_size # type: ignore
self.pool_rearrange = Rearrange("batch_size channels pool_size -> batch_size (channels pool_size)")
self.__clf_input_seq = False
- elif self.config.global_pool.lower() == "avg":
- self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,))
- clf_input_size *= self.config.global_pool_size
+ elif self.config.global_pool.lower() == "avg": # type: ignore
+ self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,)) # type: ignore
+ clf_input_size *= self.config.global_pool_size # type: ignore
self.pool_rearrange = Rearrange("batch_size channels pool_size -> batch_size (channels pool_size)")
self.__clf_input_seq = False
- elif self.config.global_pool.lower() == "attn":
+ elif self.config.global_pool.lower() == "attn": # type: ignore
raise NotImplementedError("Attentive pooling not implemented yet!")
- elif self.config.global_pool.lower() == "none":
+ elif self.config.global_pool.lower() == "none": # type: ignore
self.pool = nn.Identity()
self.pool_rearrange = Rearrange("batch_size channels seq_len -> batch_size seq_len channels")
self.__clf_input_seq = True
else:
- raise NotImplementedError(f"Global Pooling \042{self.config.global_pool}\042 not implemented yet!")
+ raise NotImplementedError(f"Global Pooling \042{self.config.global_pool}\042 not implemented yet!") # type: ignore
# input of `self.clf` has shape: batch_size, channels
self.clf = MLP(
- in_channels=clf_input_size,
- out_channels=self.config.clf.out_channels + [self.n_classes],
- activation=self.config.clf.activation,
- bias=self.config.clf.bias,
- dropouts=self.config.clf.dropouts,
+ in_channels=clf_input_size, # type: ignore
+ out_channels=self.config.clf.out_channels + [self.n_classes], # type: ignore
+ activation=self.config.clf.activation, # type: ignore
+ bias=self.config.clf.bias, # type: ignore
+ dropouts=self.config.clf.dropouts, # type: ignore
skip_last_activation=True,
)
@@ -335,7 +340,7 @@ def forward(self, input: Tensor) -> Tensor:
@torch.no_grad()
def inference(
self,
- input: Union[np.ndarray, Tensor],
+ input: Union[NDArray, Tensor],
class_names: bool = False,
bin_pred_thr: float = 0.5,
) -> BaseOutput:
@@ -356,11 +361,10 @@ def inference(
-------
output : BaseOutput
The output of the inference method, including the following items:
-
- - prob: numpy.ndarray or torch.Tensor,
- scalar predictions, (and binary predictions if `class_names` is True).
- - pred: numpy.ndarray or torch.Tensor,
- the array (with values 0, 1 for each class) of binary prediction.
+            - prob: numpy.ndarray or torch.Tensor,
+              scalar predictions (and binary predictions if `class_names` is True).
+            - pred: numpy.ndarray or torch.Tensor,
+              the array (with values 0, 1 for each class) of binary predictions.
"""
raise NotImplementedError("Implement a task-specific inference method.")
@@ -403,10 +407,10 @@ def doi(self) -> List[str]:
new_candidates = []
for candidate in candidates:
if hasattr(candidate, "doi"):
- if isinstance(candidate.doi, str):
- doi.append(candidate.doi)
+ if isinstance(candidate.doi, str): # type: ignore
+ doi.append(candidate.doi) # type: ignore
else:
- doi.extend(list(candidate.doi))
+ doi.extend(list(candidate.doi)) # type: ignore
for k, v in candidate.items():
if isinstance(v, CFG):
new_candidates.append(v)
@@ -416,13 +420,13 @@ def doi(self) -> List[str]:
@classmethod
def from_v1(
- cls, v1_ckpt: str, device: Optional[torch.device] = None, return_config: bool = False
+ cls, v1_ckpt: Union[str, bytes, os.PathLike], device: Optional[torch.device] = None, return_config: bool = False
) -> Union["ECG_CRNN", Tuple["ECG_CRNN", dict]]:
"""Restore an instance of the model from a v1 checkpoint.
Parameters
----------
- v1_ckpt : str
+ v1_ckpt : path_like
Path to the v1 checkpoint file.
device : torch.device, optional
The device to load the model to.
@@ -512,8 +516,8 @@ def __init__(
warnings.warn("No config is provided, using default config.", RuntimeWarning)
self.config.update(deepcopy(config) or {})
- cnn_choice = self.config.cnn.name.lower()
- cnn_config = self.config.cnn[self.config.cnn.name]
+ cnn_choice = self.config.cnn.name.lower() # type: ignore
+ cnn_config = self.config.cnn[self.config.cnn.name] # type: ignore
if "resnet" in cnn_choice or "resnext" in cnn_choice:
self.cnn = ResNet(self.n_leads, **cnn_config)
elif "regnet" in cnn_choice:
@@ -539,119 +543,123 @@ def __init__(
raise NotImplementedError(f"CNN \042{cnn_choice}\042 not implemented yet")
rnn_input_size = self.cnn.compute_output_shape(None, None)[1]
- if self.config.rnn.name.lower() == "none":
+ if self.config.rnn.name.lower() == "none": # type: ignore
self.rnn = None
attn_input_size = rnn_input_size
- elif self.config.rnn.name.lower() == "lstm":
+ elif self.config.rnn.name.lower() == "lstm": # type: ignore
self.rnn = StackedLSTM(
- input_size=rnn_input_size,
- hidden_sizes=self.config.rnn.lstm.hidden_sizes,
- bias=self.config.rnn.lstm.bias,
- dropouts=self.config.rnn.lstm.dropouts,
- bidirectional=self.config.rnn.lstm.bidirectional,
- return_sequences=self.config.rnn.lstm.retseq,
+ input_size=rnn_input_size, # type: ignore
+ hidden_sizes=self.config.rnn.lstm.hidden_sizes, # type: ignore
+ bias=self.config.rnn.lstm.bias, # type: ignore
+ dropouts=self.config.rnn.lstm.dropouts, # type: ignore
+ bidirectional=self.config.rnn.lstm.bidirectional, # type: ignore
+ return_sequences=self.config.rnn.lstm.retseq, # type: ignore
)
attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
- elif self.config.rnn.name.lower() == "linear":
+ elif self.config.rnn.name.lower() == "linear": # type: ignore
# abuse of notation, to put before the global attention module
self.rnn = MLP(
- in_channels=rnn_input_size,
- out_channels=self.config.rnn.linear.out_channels,
- activation=self.config.rnn.linear.activation,
- bias=self.config.rnn.linear.bias,
- dropouts=self.config.rnn.linear.dropouts,
+ in_channels=rnn_input_size, # type: ignore
+ out_channels=self.config.rnn.linear.out_channels, # type: ignore
+ activation=self.config.rnn.linear.activation, # type: ignore
+ bias=self.config.rnn.linear.bias, # type: ignore
+ dropouts=self.config.rnn.linear.dropouts, # type: ignore
)
attn_input_size = self.rnn.compute_output_shape(None, None)[-1]
else:
- raise NotImplementedError(f"RNN \042{self.config.rnn.name}\042 not implemented yet")
+ raise NotImplementedError(f"RNN \042{self.config.rnn.name}\042 not implemented yet") # type: ignore
# attention
- if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq:
+ if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: # type: ignore
self.attn = None
clf_input_size = attn_input_size
- if self.config.attn.name.lower() != "none":
+ if self.config.attn.name.lower() != "none": # type: ignore
warnings.warn(
- f"since `retseq` of rnn is False, hence attention `{self.config.attn.name}` is ignored",
+                    f"since `retseq` of rnn is False, attention `{self.config.attn.name}` is ignored",  # type: ignore
RuntimeWarning,
)
- elif self.config.attn.name.lower() == "none":
+ elif self.config.attn.name.lower() == "none": # type: ignore
self.attn = None
clf_input_size = attn_input_size
- elif self.config.attn.name.lower() == "nl": # non_local
+ elif self.config.attn.name.lower() == "nl": # type: ignore
+ # non_local
self.attn = NonLocalBlock(
- in_channels=attn_input_size,
- filter_lengths=self.config.attn.nl.filter_lengths,
- subsample_length=self.config.attn.nl.subsample_length,
- batch_norm=self.config.attn.nl.batch_norm,
+ in_channels=attn_input_size, # type: ignore
+ filter_lengths=self.config.attn.nl.filter_lengths, # type: ignore
+ subsample_length=self.config.attn.nl.subsample_length, # type: ignore
+ batch_norm=self.config.attn.nl.batch_norm, # type: ignore
)
clf_input_size = self.attn.compute_output_shape(None, None)[1]
- elif self.config.attn.name.lower() == "se": # squeeze_exitation
+ elif self.config.attn.name.lower() == "se": # type: ignore
+            # squeeze_excitation
self.attn = SEBlock(
- in_channels=attn_input_size,
- reduction=self.config.attn.se.reduction,
- activation=self.config.attn.se.activation,
- kw_activation=self.config.attn.se.kw_activation,
- bias=self.config.attn.se.bias,
+ in_channels=attn_input_size, # type: ignore
+ reduction=self.config.attn.se.reduction, # type: ignore
+ activation=self.config.attn.se.activation, # type: ignore
+ kw_activation=self.config.attn.se.kw_activation, # type: ignore
+ bias=self.config.attn.se.bias, # type: ignore
)
clf_input_size = self.attn.compute_output_shape(None, None)[1]
- elif self.config.attn.name.lower() == "gc": # global_context
+ elif self.config.attn.name.lower() == "gc": # type: ignore
+ # global_context
self.attn = GlobalContextBlock(
- in_channels=attn_input_size,
- ratio=self.config.attn.gc.ratio,
- reduction=self.config.attn.gc.reduction,
- pooling_type=self.config.attn.gc.pooling_type,
- fusion_types=self.config.attn.gc.fusion_types,
+ in_channels=attn_input_size, # type: ignore
+ ratio=self.config.attn.gc.ratio, # type: ignore
+ reduction=self.config.attn.gc.reduction, # type: ignore
+ pooling_type=self.config.attn.gc.pooling_type, # type: ignore
+ fusion_types=self.config.attn.gc.fusion_types, # type: ignore
)
clf_input_size = self.attn.compute_output_shape(None, None)[1]
- elif self.config.attn.name.lower() == "sa": # self_attention
+ elif self.config.attn.name.lower() == "sa": # type: ignore
+ # self_attention
# NOTE: this branch NOT tested
- self.attn = SelfAttention(
+ self.attn = SelfAttention( # type: ignore
embed_dim=attn_input_size,
- num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")),
- dropout=self.config.attn.sa.dropout,
- bias=self.config.attn.sa.bias,
+ num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")), # type: ignore
+ dropout=self.config.attn.sa.dropout, # type: ignore
+ bias=self.config.attn.sa.bias, # type: ignore
)
clf_input_size = self.attn.compute_output_shape(None, None)[-1]
- elif self.config.attn.name.lower() == "transformer":
+ elif self.config.attn.name.lower() == "transformer": # type: ignore
self.attn = Transformer(
- input_size=attn_input_size,
- hidden_size=self.config.attn.transformer.hidden_size,
- num_layers=self.config.attn.transformer.num_layers,
- num_heads=self.config.attn.transformer.num_heads,
- dropout=self.config.attn.transformer.dropout,
- activation=self.config.attn.transformer.activation,
+ input_size=attn_input_size, # type: ignore
+ hidden_size=self.config.attn.transformer.hidden_size, # type: ignore
+ num_layers=self.config.attn.transformer.num_layers, # type: ignore
+ num_heads=self.config.attn.transformer.num_heads, # type: ignore
+ dropout=self.config.attn.transformer.dropout, # type: ignore
+ activation=self.config.attn.transformer.activation, # type: ignore
)
clf_input_size = self.attn.compute_output_shape(None, None)[-1]
else:
- raise NotImplementedError(f"Attention \042{self.config.attn.name}\042 not implemented yet")
+ raise NotImplementedError(f"Attention \042{self.config.attn.name}\042 not implemented yet") # type: ignore
- if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq:
+ if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: # type: ignore
self.pool = None
- if self.config.global_pool.lower() != "none":
+ if self.config.global_pool.lower() != "none": # type: ignore
warnings.warn(
- f"since `retseq` of rnn is False, hence global pooling `{self.config.global_pool}` is ignored",
+                    f"since `retseq` of rnn is False, global pooling `{self.config.global_pool}` is ignored",  # type: ignore
RuntimeWarning,
)
- elif self.config.global_pool.lower() == "max":
- self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False)
- clf_input_size *= self.config.global_pool_size
- elif self.config.global_pool.lower() == "avg":
- self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,))
- clf_input_size *= self.config.global_pool_size
- elif self.config.global_pool.lower() == "attn":
+ elif self.config.global_pool.lower() == "max": # type: ignore
+ self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False) # type: ignore
+ clf_input_size *= self.config.global_pool_size # type: ignore
+ elif self.config.global_pool.lower() == "avg": # type: ignore
+ self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,)) # type: ignore
+ clf_input_size *= self.config.global_pool_size # type: ignore
+ elif self.config.global_pool.lower() == "attn": # type: ignore
raise NotImplementedError("Attentive pooling not implemented yet!")
- elif self.config.global_pool.lower() == "none":
+ elif self.config.global_pool.lower() == "none": # type: ignore
self.pool = None
else:
- raise NotImplementedError(f"Global Pooling \042{self.config.global_pool}\042 not implemented yet!")
+ raise NotImplementedError(f"Global Pooling \042{self.config.global_pool}\042 not implemented yet!") # type: ignore
# input of `self.clf` has shape: batch_size, channels
self.clf = MLP(
- in_channels=clf_input_size,
- out_channels=self.config.clf.out_channels + [self.n_classes],
- activation=self.config.clf.activation,
- bias=self.config.clf.bias,
- dropouts=self.config.clf.dropouts,
+ in_channels=clf_input_size, # type: ignore
+ out_channels=self.config.clf.out_channels + [self.n_classes], # type: ignore
+ activation=self.config.clf.activation, # type: ignore
+ bias=self.config.clf.bias, # type: ignore
+ dropouts=self.config.clf.dropouts, # type: ignore
skip_last_activation=True,
)
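
For orientation, the `__init__` above reads a nested config object. The sketch below shows only the field layout the branches consult; the concrete values are placeholders and the required CNN sub-config (e.g. the ResNet kwargs) is omitted, so this is an illustrative skeleton rather than a complete working config.

from torch_ecg.cfg import CFG

# field layout consulted by the branches above; values are placeholders
model_config = CFG(
    rnn=CFG(
        name="lstm",  # one of "none" | "lstm" | "linear"
        lstm=CFG(hidden_sizes=[128], bias=True, dropouts=0.2, bidirectional=True, retseq=True),
    ),
    attn=CFG(
        name="se",  # one of "none" | "nl" | "se" | "gc" | "sa" | "transformer"
        se=CFG(reduction=8, activation="relu", kw_activation={}, bias=True),
    ),
    global_pool="max",  # "max" | "avg" | "none" ("attn" raises NotImplementedError)
    global_pool_size=1,
    clf=CFG(out_channels=[256], activation="relu", bias=True, dropouts=0.2),
)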
@@ -683,14 +691,16 @@ def extract_features(self, input: Tensor) -> Tensor:
features = self.cnn(input) # batch_size, channels, seq_len
# RNN (optional)
- if self.config.rnn.name.lower() in ["lstm"]:
+ if self.config.rnn.name.lower() in ["lstm"]: # type: ignore
# (batch_size, channels, seq_len) --> (seq_len, batch_size, channels)
features = features.permute(2, 0, 1)
- features = self.rnn(features) # (seq_len, batch_size, channels) or (batch_size, channels)
- elif self.config.rnn.name.lower() in ["linear"]:
+ features = self.rnn(features) # type: ignore
+ # (seq_len, batch_size, channels) or (batch_size, channels)
+ elif self.config.rnn.name.lower() in ["linear"]: # type: ignore
# (batch_size, channels, seq_len) --> (batch_size, seq_len, channels)
features = features.permute(0, 2, 1)
- features = self.rnn(features) # (batch_size, seq_len, channels)
+ # (batch_size, seq_len, channels)
+ features = self.rnn(features) # type: ignore
# (batch_size, seq_len, channels) --> (seq_len, batch_size, channels)
features = features.permute(1, 0, 2)
else:
@@ -701,16 +711,18 @@ def extract_features(self, input: Tensor) -> Tensor:
if self.attn is None and features.ndim == 3:
# (seq_len, batch_size, channels) --> (batch_size, channels, seq_len)
features = features.permute(1, 2, 0)
- elif self.config.attn.name.lower() in ["nl", "se", "gc"]:
+ elif self.config.attn.name.lower() in ["nl", "se", "gc"]: # type: ignore
# (seq_len, batch_size, channels) --> (batch_size, channels, seq_len)
features = features.permute(1, 2, 0)
- features = self.attn(features) # (batch_size, channels, seq_len)
- elif self.config.attn.name.lower() in ["sa"]:
- features = self.attn(features) # (seq_len, batch_size, channels)
+ # (batch_size, channels, seq_len)
+ features = self.attn(features) # type: ignore
+ elif self.config.attn.name.lower() in ["sa"]: # type: ignore
+ # (seq_len, batch_size, channels)
+ features = self.attn(features) # type: ignore
# (seq_len, batch_size, channels) -> (batch_size, channels, seq_len)
features = features.permute(1, 2, 0)
- elif self.config.attn.name.lower() in ["transformer"]:
- features = self.attn(features)
+ elif self.config.attn.name.lower() in ["transformer"]: # type: ignore
+ features = self.attn(features) # type: ignore
# (seq_len, batch_size, channels) -> (batch_size, channels, seq_len)
features = features.permute(1, 2, 0)
return features
@@ -753,7 +765,7 @@ def forward(self, input: Tensor) -> Tensor:
@torch.no_grad()
def inference(
self,
- input: Union[np.ndarray, Tensor],
+ input: Union[NDArray, Tensor],
class_names: bool = False,
bin_pred_thr: float = 0.5,
) -> BaseOutput:
@@ -774,11 +786,10 @@ def inference(
-------
output : BaseOutput
The output of the inference method, including the following items:
-
- - prob: numpy.ndarray or torch.Tensor,
- scalar predictions, (and binary predictions if `class_names` is True).
- - pred: numpy.ndarray or torch.Tensor,
- the array (with values 0, 1 for each class) of binary prediction.
+            - prob: numpy.ndarray or torch.Tensor,
+              scalar predictions (and binary predictions if `class_names` is True).
+            - pred: numpy.ndarray or torch.Tensor,
+              the array (with values 0, 1 for each class) of binary predictions.
"""
raise NotImplementedError("Implement a task-specific inference method.")
@@ -824,10 +835,10 @@ def doi(self) -> List[str]:
new_candidates = []
for candidate in candidates:
if hasattr(candidate, "doi"):
- if isinstance(candidate.doi, str):
- doi.append(candidate.doi)
+ if isinstance(candidate.doi, str): # type: ignore
+ doi.append(candidate.doi) # type: ignore
else:
- doi.extend(list(candidate.doi))
+ doi.extend(list(candidate.doi)) # type: ignore
for k, v in candidate.items():
if isinstance(v, CFG):
new_candidates.append(v)
diff --git a/torch_ecg/models/grad_cam.py b/torch_ecg/models/grad_cam.py
index 973e89b9..d5cb5bca 100644
--- a/torch_ecg/models/grad_cam.py
+++ b/torch_ecg/models/grad_cam.py
@@ -134,7 +134,7 @@ def __call__(self, input: Tensor, index: Optional[int] = None):
n_classes = output.shape[-1]
if index is None:
- index = np.argmax(output.cpu().detach().numpy()[0])
+ index = np.argmax(output.detach().cpu().numpy()[0])
one_hot = np.zeros((1, n_classes), dtype=np.float32)
one_hot[0][index] = 1
@@ -145,12 +145,12 @@ def __call__(self, input: Tensor, index: Optional[int] = None):
self.model.zero_grad()
one_hot.backward(retain_graph=True)
- grads_val = self.extractor.get_gradients()[-1].cpu().detach().numpy()
+ grads_val = self.extractor.get_gradients()[-1].detach().cpu().numpy()
# of shape (batch_size (=1), channels, seq_len) or (batch_size (=1), seq_len, channels)
target = features[-1]
# of shape (channels, seq_len) or (seq_len, channels)
- target = target.cpu().detach().numpy()[0, :]
+ target = target.detach().cpu().numpy()[0, :]
if self.target_channel_last:
weights = np.mean(grads_val, axis=-2)[0, :]
diff --git a/torch_ecg/models/loss.py b/torch_ecg/models/loss.py
index d872e150..66ecc341 100644
--- a/torch_ecg/models/loss.py
+++ b/torch_ecg/models/loss.py
@@ -302,7 +302,7 @@ class FocalLoss(nn.modules.loss._WeightedLoss):
Where:
- - :math:`p_t` is the model's estimated probability for each class.
+ - :math:`p_t` is the model's estimated probability for each class.
Parameters
----------
diff --git a/torch_ecg/models/unets/ecg_subtract_unet.py b/torch_ecg/models/unets/ecg_subtract_unet.py
index 12839784..f52dad11 100644
--- a/torch_ecg/models/unets/ecg_subtract_unet.py
+++ b/torch_ecg/models/unets/ecg_subtract_unet.py
@@ -15,8 +15,8 @@
from itertools import repeat
from typing import List, Optional, Sequence, Union
-import numpy as np
import torch
+from numpy.typing import NDArray
from torch import Tensor, nn
from ...cfg import CFG
@@ -642,7 +642,7 @@ def forward(self, input: Tensor) -> Tensor:
return output
@torch.no_grad()
- def inference(self, input: Union[np.ndarray, Tensor], bin_pred_thr: float = 0.5) -> Tensor:
+ def inference(self, input: Union[NDArray, Tensor], bin_pred_thr: float = 0.5) -> Tensor:
"""Method for making inference on a single input."""
raise NotImplementedError("Implement a task-specific inference method.")
diff --git a/torch_ecg/utils/_ecg_plot.py b/torch_ecg/utils/_ecg_plot.py
index 94deae8c..e610ec85 100644
--- a/torch_ecg/utils/_ecg_plot.py
+++ b/torch_ecg/utils/_ecg_plot.py
@@ -42,6 +42,7 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import AutoMinorLocator
+from numpy.typing import NDArray
__all__ = ["ecg_plot"]
@@ -99,7 +100,7 @@ def inches_to_dots(value: float, resolution: int) -> float:
return value * resolution
-def create_signal_dictionary(signal: np.ndarray, full_leads: List[str]) -> Dict[str, np.ndarray]:
+def create_signal_dictionary(signal: NDArray, full_leads: List[str]) -> Dict[str, NDArray]:
record_dict = {}
for k in range(len(full_leads)):
record_dict[full_leads[k]] = signal[k]
@@ -107,7 +108,7 @@ def create_signal_dictionary(signal: np.ndarray, full_leads: List[str]) -> Dict[
def ecg_plot(
- ecg: Dict[str, np.ndarray],
+ ecg: Dict[str, NDArray],
sample_rate: int,
columns: int,
rec_file_name: Union[str, bytes, os.PathLike],
@@ -140,7 +141,7 @@ def ecg_plot(
Parameters
----------
- ecg : Dict[str, np.ndarray]
+ ecg : Dict[str, NDArray]
Dictionary of ECG signals with lead names as keys,
values as 1D numpy arrays.
sample_rate : int
diff --git a/torch_ecg/utils/_edr.py b/torch_ecg/utils/_edr.py
index 02e23223..a084ccec 100644
--- a/torch_ecg/utils/_edr.py
+++ b/torch_ecg/utils/_edr.py
@@ -8,6 +8,7 @@
from typing import Sequence
import numpy as np
+from numpy.typing import NDArray
__all__ = [
"phs_edr",
@@ -23,7 +24,7 @@ def phs_edr(
return_with_time: bool = True,
mode: str = "complex",
verbose: int = 0,
-) -> np.ndarray:
+) -> NDArray:
"""
computes the respiratory rate from single-lead ECG signals.
@@ -53,7 +54,7 @@ def phs_edr(
Returns
-------
- np.ndarray,
+ NDArray,
1d, if `return_with_time` is set False,
2d in the form of [ts, val] (ts in milliseconds), if `return_with_time` is set True
diff --git a/torch_ecg/utils/_preproc.py b/torch_ecg/utils/_preproc.py
index f4f9ee91..6899fcad 100644
--- a/torch_ecg/utils/_preproc.py
+++ b/torch_ecg/utils/_preproc.py
@@ -28,6 +28,7 @@
# from scipy.signal import medfilt
# https://github.com/scipy/scipy/issues/9680
from biosppy.signals.tools import filter_signal
+from numpy.typing import NDArray
from scipy.ndimage.filters import median_filter
from ..cfg import CFG
@@ -62,14 +63,14 @@
def preprocess_multi_lead_signal(
- raw_sig: np.ndarray,
+ raw_sig: NDArray,
fs: Real,
sig_fmt: str = "channel_first",
bl_win: Optional[List[Real]] = None,
band_fs: Optional[List[Real]] = None,
rpeak_fn: Optional[str] = None,
verbose: int = 0,
-) -> Dict[str, np.ndarray]:
+) -> Dict[str, NDArray]:
"""
perform preprocessing for multi-lead ECG signal (with units in mV),
preprocessing may include median filter, bandpass filter, and rpeaks detection, etc.
@@ -146,13 +147,13 @@ def preprocess_multi_lead_signal(
def preprocess_single_lead_signal(
- raw_sig: np.ndarray,
+ raw_sig: NDArray,
fs: Real,
bl_win: Optional[List[Real]] = None,
band_fs: Optional[List[Real]] = None,
rpeak_fn: Optional[str] = None,
verbose: int = 0,
-) -> Dict[str, np.ndarray]:
+) -> Dict[str, NDArray]:
"""
perform preprocessing for single lead ECG signal (with units in mV),
preprocessing may include median filter, bandpass filter, and rpeaks detection, etc.
@@ -225,12 +226,12 @@ def preprocess_single_lead_signal(
def rpeaks_detect_multi_leads(
- sig: np.ndarray,
+ sig: NDArray,
fs: Real,
sig_fmt: str = "channel_first",
rpeak_fn: str = "xqrs",
verbose: int = 0,
-) -> np.ndarray:
+) -> NDArray:
"""
detect rpeaks from the filtered multi-lead ECG signal (with units in mV)
@@ -252,7 +253,7 @@ def rpeaks_detect_multi_leads(
Returns
-------
- rpeaks: np.ndarray,
+ rpeaks: NDArray,
array of indices of the detected rpeaks of the multi-lead ECG signal
"""
@@ -273,7 +274,7 @@ def rpeaks_detect_multi_leads(
return rpeaks
-def merge_rpeaks(rpeaks_candidates: List[np.ndarray], sig: np.ndarray, fs: Real, verbose: int = 0) -> np.ndarray:
+def merge_rpeaks(rpeaks_candidates: List[NDArray], sig: NDArray, fs: Real, verbose: int = 0) -> NDArray:
"""
merge rpeaks that are detected from each lead of multi-lead signals (with units in mV),
using certain criterion merging qrs masks from each lead
@@ -291,7 +292,7 @@ def merge_rpeaks(rpeaks_candidates: List[np.ndarray], sig: np.ndarray, fs: Real,
Returns
-------
- final_rpeaks: np.ndarray
+ final_rpeaks: NDArray
the final rpeaks obtained by merging the rpeaks from all the leads
"""
diff --git a/torch_ecg/utils/download.py b/torch_ecg/utils/download.py
index f726af59..069c8137 100644
--- a/torch_ecg/utils/download.py
+++ b/torch_ecg/utils/download.py
@@ -19,7 +19,7 @@
import warnings
import zipfile
from pathlib import Path
-from typing import Any, Iterable, Literal, Optional, Union
+from typing import Any, Iterable, Literal, Optional, Tuple, Union
import boto3
import requests
@@ -27,6 +27,8 @@
from botocore.client import Config
from tqdm.auto import tqdm
+from .misc import str2bool
+
__all__ = [
"http_get",
]
@@ -36,10 +38,10 @@
def _requests_retry_session(
- retries=5,
- backoff_factor=0.5,
- status_forcelist=(500, 502, 503, 504),
- session=None,
+ retries: int = 5,
+ backoff_factor: float = 0.5,
+ status_forcelist: Tuple[int, ...] = (429, 500, 502, 503, 504, 403),
+ session: Optional[requests.Session] = None,
) -> requests.Session:
"""Get a requests session with retry strategy.
@@ -50,7 +52,7 @@ def _requests_retry_session(
backoff_factor : float, default 0.5
A backoff factor to apply between attempts.
A backoff factor of 0.5 will sleep for [0.5s, 1s, 2s, ...] between retries.
- status_forcelist : tuple, default (500, 502, 503, 504)
+ status_forcelist : tuple, default (429, 500, 502, 503, 504, 403)
A set of HTTP status codes that we should force a retry on.
session : requests.Session, optional
An existing requests session.
@@ -63,17 +65,23 @@ def _requests_retry_session(
"""
session = session or requests.Session()
- retry = requests.packages.urllib3.util.retry.Retry(
+ from requests.adapters import HTTPAdapter
+ from requests.packages.urllib3.util.retry import Retry # type: ignore
+
+ retry = Retry(
total=retries,
read=retries,
connect=retries,
+ status=retries,
backoff_factor=backoff_factor,
status_forcelist=status_forcelist,
- allowed_methods=["HEAD", "GET", "OPTIONS"],
+ allowed_methods=frozenset(["HEAD", "GET", "OPTIONS"]),
+ raise_on_redirect=True,
+ raise_on_status=False,
)
- adapter = requests.adapters.HTTPAdapter(max_retries=retry)
- session.mount("http://", adapter)
+ adapter = HTTPAdapter(max_retries=retry)
session.mount("https://", adapter)
+ session.mount("http://", adapter)
return session
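
A minimal sketch of how the retry session above might be exercised; the URL is a placeholder, and the retry/backoff values mirror the documented defaults.

from torch_ecg.utils.download import _requests_retry_session

# placeholder URL; 429/5xx (and 403) responses are retried with exponential backoff
session = _requests_retry_session(retries=3, backoff_factor=0.5)
resp = session.get("https://example.com/some/archive.zip", stream=True, timeout=(5.0, 60.0))
resp.raise_for_status()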
@@ -83,6 +91,9 @@ def http_get(
proxies: Optional[dict] = None,
extract: Literal[True, False, "auto"] = "auto",
filename: Optional[str] = None,
+ *,
+ timeout: Union[Tuple[float, float], float] = (5.0, 60.0),
+ verify_length: bool = True,
) -> Path:
"""Download contents of a URL and save to a file.
@@ -108,6 +119,16 @@ def http_get(
which is set to `dst_dir`, and `filename` is only the downloaded file name.
.. versionadded:: 0.0.20
+ timeout : float or tuple, default (5.0, 60.0)
+ How many seconds to wait for the server to send data
+ before giving up, as a float, or a (connect timeout, read timeout) tuple.
+
+ .. versionadded:: 0.0.32
+ verify_length : bool, default True
+ Whether to verify the length of the downloaded file
+ with the `Content-Length` header in the HTTP response.
+
+ .. versionadded:: 0.0.32
Returns
-------
@@ -119,8 +140,8 @@ def http_get(
.. [1] https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py
"""
- Path(dst_dir).mkdir(parents=True, exist_ok=True)
- if filename is not None and (Path(dst_dir) / filename).exists():
+ Path(dst_dir).mkdir(parents=True, exist_ok=True) # type: ignore
+ if filename is not None and (Path(dst_dir) / filename).exists(): # type: ignore
raise FileExistsError("file already exists")
url_parsed = urllib.parse.urlparse(url)
if url_parsed.scheme == "":
@@ -134,7 +155,7 @@ def http_get(
if url_parsed.scheme == "s3":
_download_from_aws_s3_using_awscli(url, dst_dir)
- return Path(dst_dir)
+ return Path(dst_dir) # type: ignore
if url_parsed.netloc == "www.dropbox.com" and url_parsed.query == "dl=0":
url_parsed = url_parsed._replace(query="dl=1")
@@ -163,7 +184,7 @@ def http_get(
delete=False,
)
_download_from_google_drive(url, downloaded_file.name)
- df_suffix = _suffix(filename)
+ df_suffix = _suffix(filename) # type: ignore
downloaded_file.close()
else:
print(f"Downloading {url}.")
@@ -186,7 +207,7 @@ def http_get(
RuntimeWarning,
)
extract = False
- parent_dir = Path(dst_dir).parent
+ parent_dir = Path(dst_dir).parent # type: ignore
df_suffix = _suffix(pure_url) if filename is None else _suffix(filename)
downloaded_file = tempfile.NamedTemporaryFile(
dir=parent_dir,
@@ -194,19 +215,39 @@ def http_get(
delete=False,
)
# req = requests.get(url, stream=True, proxies=proxies)
- req = _requests_retry_session().get(url, stream=True, proxies=proxies)
+ req = _requests_retry_session().get(url, stream=True, proxies=proxies, timeout=timeout)
+ try:
+ req.raise_for_status()
+ except Exception:
+ snippet = ""
+ try:
+ snippet = req.text[:300]
+ except Exception:
+ pass
+ raise RuntimeError(f"Failed to download {url}, status={req.status_code}, body[:300]={snippet!r}")
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
if req.status_code in [403, 404]:
raise Exception(f"Could not reach {url}.")
- progress = tqdm(unit="B", unit_scale=True, total=total, dynamic_ncols=True, mininterval=1.0)
+ if str2bool(os.environ.get("CI")):
+ mininterval = 10.0
+ disable = True
+ else:
+ mininterval = 1.0
+ disable = False
+ progress = tqdm(unit="B", unit_scale=True, total=total, dynamic_ncols=True, mininterval=mininterval, disable=disable)
+ downloaded_size = 0
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
downloaded_file.write(chunk)
+ downloaded_size += len(chunk)
progress.close()
downloaded_file.close()
+ if verify_length and total is not None and downloaded_size != total:
+ raise IOError(f"Size mismatch for {url}. Expected {total} bytes, got {downloaded_size} bytes.")
+
# add a delay to avoid the error "process cannot access the file because it is being used by another process"
time.sleep(0.1)
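
An illustrative call using the new keyword-only arguments; the URL and destination directory are placeholders.

from torch_ecg.utils.download import http_get

downloaded = http_get(
    "https://example.com/data/records.tar.gz",
    dst_dir="/tmp/torch-ecg-downloads",
    extract="auto",
    timeout=(5.0, 60.0),   # (connect, read) timeouts forwarded to requests
    verify_length=True,    # compare bytes written against the Content-Length header
)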
@@ -225,20 +266,20 @@ def http_get(
_folder = Path(url).name.replace(_suffix(url), "")
else:
_folder = _stem(Path(filename))
- if _folder in os.listdir(dst_dir):
+ if (Path(dst_dir) / _folder).exists(): # type: ignore
tmp_folder = str(dst_dir).rstrip(os.sep) + "_tmp"
# move (rename) the dst_dir to a temporary folder
os.rename(dst_dir, tmp_folder)
# move (rename) the extracted folder to the destination folder
os.rename(Path(tmp_folder) / _folder, dst_dir)
shutil.rmtree(tmp_folder)
- final_dst = Path(dst_dir)
+ final_dst = Path(dst_dir) # type: ignore
else:
- Path(dst_dir).mkdir(parents=True, exist_ok=True)
+ Path(dst_dir).mkdir(parents=True, exist_ok=True) # type: ignore
if filename is None:
- final_dst = Path(dst_dir) / Path(pure_url).name
+ final_dst = Path(dst_dir) / Path(pure_url).name # type: ignore
else:
- final_dst = Path(dst_dir) / filename
+ final_dst = Path(dst_dir) / filename # type: ignore
shutil.copyfile(downloaded_file.name, final_dst)
os.remove(downloaded_file.name)
return final_dst
@@ -258,7 +299,9 @@ def _stem(path: Union[str, bytes, os.PathLike]) -> str:
Filename without extension.
"""
- ret = Path(path).stem
+ if isinstance(path, bytes):
+ path = path.decode()
+ ret = Path(path).stem # type: ignore
if Path(ret).suffix in [".tar", ".gz", ".tz", ".lz", ".bz2", ".xz", ".zip", ".7z"]:
return _stem(ret)
return ret
@@ -342,8 +385,8 @@ def _untar_file(path_to_tar_file: Union[str, bytes, os.PathLike], dst_dir: Union
"""
print(f"Extracting file {path_to_tar_file} to {dst_dir}.")
- mode = Path(path_to_tar_file).suffix.replace(".", "r:").replace("tar", "").strip(":")
- with tarfile.open(str(path_to_tar_file), mode) as tar_ref:
+ mode = Path(path_to_tar_file).suffix.replace(".", "r:").replace("tar", "").strip(":") # type: ignore
+ with tarfile.open(str(path_to_tar_file), mode) as tar_ref: # type: ignore
# tar_ref.extractall(str(dst_dir))
# CVE-2007-4559 (related to CVE-2001-1267):
# directory traversal vulnerability in `extract` and `extractall` in `tarfile` module
@@ -370,7 +413,7 @@ def _is_within_directory(directory: Union[str, bytes, os.PathLike], target: Unio
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
- prefix = os.path.commonprefix([abs_directory, abs_target])
+ prefix = os.path.commonprefix([abs_directory, abs_target]) # type: ignore
return prefix == abs_directory
@@ -409,7 +452,7 @@ def _safe_tar_extract(
"""
for member in members or tar.getmembers():
- member_path = os.path.join(dst_dir, member.name)
+ member_path = os.path.join(dst_dir, member.name) # type: ignore
if not _is_within_directory(dst_dir, member_path):
raise Exception("Attempted Path Traversal in Tar File")
@@ -510,11 +553,12 @@ def count_aws_s3_bucket(bucket_name: str, prefix: str = "") -> int:
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
object_count = 0
- for page in page_iterator:
- if "Contents" in page:
- object_count += len(page["Contents"])
-
- s3.close()
+ try:
+ for page in page_iterator:
+ if "Contents" in page:
+ object_count += len(page.get("Contents", []))
+ finally:
+ s3.close()
return object_count
@@ -562,7 +606,7 @@ def _download_from_aws_s3_using_boto3(url: str, dst_dir: Union[str, bytes, os.Pa
if "Contents" in page:
for obj in page["Contents"]:
key = obj["Key"]
- dst_file = Path(dst_dir) / key
+ dst_file = Path(dst_dir) / key # type: ignore
if not dst_file.parent.exists():
dst_file.parent.mkdir(parents=True, exist_ok=True)
s3.download_file(bucket_name, key, str(dst_file))
@@ -571,7 +615,13 @@ def _download_from_aws_s3_using_boto3(url: str, dst_dir: Union[str, bytes, os.Pa
s3.close()
-def _download_from_aws_s3_using_awscli(url: str, dst_dir: Union[str, bytes, os.PathLike]) -> None:
+def _download_from_aws_s3_using_awscli(
+ url: str,
+ dst_dir: Union[str, bytes, os.PathLike],
+ *,
+ show_progress: bool = True,
+ env: Optional[dict] = None,
+) -> None:
"""Download a file from AWS S3 using awscli.
Parameters
@@ -581,51 +631,78 @@ def _download_from_aws_s3_using_awscli(url: str, dst_dir: Union[str, bytes, os.P
For example, "s3://bucket-name/files/pre-fix".
dst_dir : `path-like`
The output directory.
+ show_progress : bool, default True
+ Whether to display tqdm progress bar.
+ env : dict, optional
+ Custom environment.
Returns
-------
None
"""
- assert shutil.which("aws") is not None, "AWS cli is required to download from S3."
+ if shutil.which("aws") is None:
+ raise RuntimeError("AWS cli is required to download from S3 (please install AWS CLI v2).")
    pattern = "^s3://(?P<bucket_name>[^/]+)/(?P<prefix>.+)$"
match = re.match(pattern, url)
if match is None:
raise ValueError(f"Invalid S3 URL: {url}")
+
bucket_name = match.group("bucket_name")
prefix = match.group("prefix")
if prefix.startswith("files/"):
- prefix = prefix.replace("files/", "")
- object_count = count_aws_s3_bucket(bucket_name, prefix)
- print(f"Downloading from S3 bucket: {bucket_name}, prefix: {prefix}, total files: {object_count}")
+ prefix = prefix[len("files/") :]
+ total_files = count_aws_s3_bucket(bucket_name, prefix)
- pbar = tqdm(total=object_count, dynamic_ncols=True, mininterval=1.0)
+ if total_files == 0:
+ print(f"[S3 sync] No objects found at {bucket_name}/{prefix}, skipping.")
+ return
+
+ print(f"Downloading from S3 bucket: {bucket_name}, prefix: {prefix}, total files: {total_files}")
+
+ dst_dir = Path(dst_dir).expanduser().resolve() # type: ignore
+ dst_dir.mkdir(parents=True, exist_ok=True)
+
+ if str2bool(os.environ.get("CI")):
+ mininterval = 10.0
+ disable = True
+ else:
+ mininterval = 1.0
+ disable = not show_progress
+ pbar = tqdm(total=total_files, dynamic_ncols=True, mininterval=mininterval, disable=disable)
download_count = 0
- command = f"aws s3 sync --no-sign-request {url} {dst_dir}"
- process = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- debug_stdout = collections.deque(maxlen=10)
- while 1:
- line = process.stdout.readline().decode("utf-8", errors="replace")
- if line.rstrip():
+    # --no-progress (rather than --only-show-errors) keeps the per-file "download: s3://..." lines counted below
+    command = f"aws s3 sync --no-sign-request --no-progress {url} {shlex.quote(str(dst_dir))}"
+ process = subprocess.Popen(
+ shlex.split(command),
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ bufsize=1,
+ text=True,
+ env=env if env is not None else os.environ.copy(),
+ )
+ debug_stdout = collections.deque(maxlen=50)
+ try:
+ assert process.stdout is not None # for type checker
+ for line in process.stdout:
+ line = line.rstrip("\n")
debug_stdout.append(line)
- if "download: s3:" in line:
+ if "download: s3://" in line:
download_count += 1
pbar.update(1)
- exitcode = process.poll()
- if exitcode is not None:
- for line in process.stdout:
- debug_stdout.append(line.decode("utf-8", errors="replace"))
- if exitcode is not None and exitcode != 0:
- error_msg = "\n".join(debug_stdout)
- process.communicate()
- process.stdout.close()
- raise subprocess.CalledProcessError(exitcode, error_msg)
- else:
- break
- process.communicate()
- process.stdout.close()
- # object_count - download_count files skipped for they already exist
- pbar.update(object_count - download_count)
- pbar.close()
- print(f"Downloaded {download_count} files from S3.")
+
+ retcode = process.wait()
+ if retcode != 0:
+ raise subprocess.CalledProcessError(
+ retcode,
+ command,
+ output="\n".join(debug_stdout),
+ )
+ finally:
+ if download_count < total_files:
+ pbar.update(total_files - download_count)
+ pbar.close()
+ if process.stdout and not process.stdout.closed:
+ process.stdout.close()
+
+ print(f"[S3 sync] Completed. Reported downloads: {download_count}/{total_files} (existing files are skipped silently).")
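
A hypothetical invocation of the reworked helper; it is a private function, and the bucket and destination below are placeholders used only to illustrate the new keyword arguments.

from torch_ecg.utils.download import _download_from_aws_s3_using_awscli

_download_from_aws_s3_using_awscli(
    "s3://example-bucket/files/some-dataset/",
    "/tmp/some-dataset",
    show_progress=False,  # disable the tqdm bar explicitly, independent of CI detection
)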
diff --git a/torch_ecg/utils/misc.py b/torch_ecg/utils/misc.py
index 8579811d..e453a95a 100644
--- a/torch_ecg/utils/misc.py
+++ b/torch_ecg/utils/misc.py
@@ -14,14 +14,15 @@
from copy import deepcopy
from functools import reduce, wraps
from glob import glob
-from numbers import Number, Real
-from pathlib import Path
+from numbers import Number
+from pathlib import Path, PurePath
from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
from bib_lookup import CitationMixin as _CitationMixin
from deprecated import deprecated
+from numpy.typing import NDArray
from ..cfg import _DATA_CACHE, DEFAULTS
@@ -85,10 +86,10 @@ def get_record_list_recursive(db_dir: Union[str, bytes, os.PathLike], rec_ext: s
"""
if not rec_ext.startswith("."):
- res = Path(db_dir).rglob(f"*.{rec_ext}")
+ res = Path(db_dir).rglob(f"*.{rec_ext}") # type: ignore
else:
- res = Path(db_dir).rglob(f"*{rec_ext}")
- res = [str((item.relative_to(db_dir) if relative else item).with_suffix("")) for item in res if str(item).endswith(rec_ext)]
+ res = Path(db_dir).rglob(f"*{rec_ext}") # type: ignore
+ res = [str((item.relative_to(db_dir) if relative else item).with_suffix("")) for item in res if str(item).endswith(rec_ext)] # type: ignore
res = sorted(res)
return res
@@ -173,32 +174,32 @@ def get_record_list_recursive3(
res = []
elif isinstance(rec_patterns, dict):
res = {k: [] for k in rec_patterns.keys()}
- _db_dir = Path(db_dir).resolve() # make absolute
+ _db_dir = Path(db_dir).resolve() # type: ignore
roots = [_db_dir]
while len(roots) > 0:
new_roots = []
for r in roots:
tmp = os.listdir(r)
if isinstance(rec_patterns, str):
- res += [r / item for item in filter(re.compile(rec_patterns).search, tmp)]
+ res += [r / item for item in filter(re.compile(rec_patterns).search, tmp)] # type: ignore
elif isinstance(rec_patterns, dict):
for k in rec_patterns.keys():
- res[k] += [r / item for item in filter(re.compile(rec_patterns[k]).search, tmp)]
+ res[k] += [r / item for item in filter(re.compile(rec_patterns[k]).search, tmp)] # type: ignore
new_roots += [r / item for item in tmp if (r / item).is_dir()]
roots = deepcopy(new_roots)
if isinstance(rec_patterns, str):
if with_suffix:
- res = [str((item.relative_to(_db_dir) if relative else item)) for item in res]
+ res = [str((item.relative_to(_db_dir) if relative else item)) for item in res] # type: ignore
else:
- res = [str((item.relative_to(_db_dir) if relative else item).with_suffix("")) for item in res]
+ res = [str((item.relative_to(_db_dir) if relative else item).with_suffix("")) for item in res] # type: ignore
res = sorted(res)
elif isinstance(rec_patterns, dict):
for k in rec_patterns.keys():
if with_suffix:
- res[k] = [str((item.relative_to(_db_dir) if relative else item)) for item in res[k]]
+ res[k] = [str((item.relative_to(_db_dir) if relative else item)) for item in res[k]] # type: ignore
else:
- res[k] = [str((item.relative_to(_db_dir) if relative else item).with_suffix("")) for item in res[k]]
- res[k] = sorted(res[k])
+ res[k] = [str((item.relative_to(_db_dir) if relative else item).with_suffix("")) for item in res[k]] # type: ignore
+ res[k] = sorted(res[k]) # type: ignore
return res
@@ -282,7 +283,11 @@ def dict_to_str(d: Union[dict, list, tuple], current_depth: int = 1, indent_spac
return s
-def str2bool(v: Union[str, bool]) -> bool:
+_TRUE_SET = {"yes", "true", "t", "y", "1"}
+_FALSE_SET = {"no", "false", "f", "n", "0"}
+
+
+def str2bool(v: Union[str, bool, None], *, default: bool = False, strict: bool = True) -> bool:
"""Converts a "boolean" value possibly
in the format of :class:`str` to :class:`bool`.
@@ -290,8 +295,17 @@ def str2bool(v: Union[str, bool]) -> bool:
Parameters
----------
- v : str or bool
+ v : str or bool or None
The "boolean" value.
+ default : bool, default False
+ The default value to return if `v` is ``None``,
+ or if `strict` is ``False`` and `v` could not be converted.
+
+ .. versionadded:: 0.0.32
+ strict : bool, default True
+ Whether to raise error if `v` could not be converted.
+
+ .. versionadded:: 0.0.32
Returns
-------
@@ -304,18 +318,25 @@ def str2bool(v: Union[str, bool]) -> bool:
"""
if isinstance(v, bool):
- b = v
- elif v.lower() in ("yes", "true", "t", "y", "1"):
- b = True
- elif v.lower() in ("no", "false", "f", "n", "0"):
- b = False
- else:
- raise ValueError("Boolean value expected.")
- return b
+ return v
+ if v is None:
+ return default
+ if not isinstance(v, str):
+ if strict:
+ raise TypeError(f"Expected str|bool|None, got {type(v)}")
+ return default
+ v_norm = v.strip().lower()
+ if v_norm in _TRUE_SET:
+ return True
+ if v_norm in _FALSE_SET:
+ return False
+ if strict:
+ raise ValueError(f"Boolean value expected, got {v!r}")
+ return default
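
A short behaviour sketch of the reworked helper, following the code above; this is what lets the `CI` check in `download.py` pass `os.environ.get("CI")` (possibly ``None``) straight through.

from torch_ecg.utils.misc import str2bool

assert str2bool("Yes") is True
assert str2bool(" false ") is False                # whitespace stripped, case-insensitive
assert str2bool(None) is False                     # None falls back to `default`
assert str2bool("maybe", strict=False) is False    # non-strict: unparsable values return `default`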
@deprecated("Use `np.diff` instead.")
-def diff_with_step(a: np.ndarray, step: int = 1) -> np.ndarray:
+def diff_with_step(a: NDArray, step: int = 1) -> NDArray:
"""Compute ``a[n+step] - a[n]`` for all valid `n`.
Parameters
@@ -337,14 +358,14 @@ def diff_with_step(a: np.ndarray, step: int = 1) -> np.ndarray:
return d
-def ms2samples(t: Real, fs: Real) -> int:
+def ms2samples(t: Union[float, int], fs: Union[float, int]) -> int:
"""Convert time duration in ms to number of samples.
Parameters
----------
- t : numbers.Real
+ t : float or int
Time duration in ms.
- fs : numbers.Real
+ fs : float or int
Sampling frequency.
Returns
@@ -354,23 +375,23 @@ def ms2samples(t: Real, fs: Real) -> int:
with sampling frequency `fs`.
"""
- n_samples = t * fs // 1000
+ n_samples = int(t * fs / 1000)
return n_samples
-def samples2ms(n_samples: int, fs: Real) -> Real:
+def samples2ms(n_samples: int, fs: int) -> float:
"""Convert number of samples to time duration in ms.
Parameters
----------
n_samples : int
Number of sample points.
- fs : numbers.Real
+ fs : int
Sampling frequency.
Returns
-------
- t : numbers.Real
+ t : float
Time duration in ms converted from `n_samples`,
with sampling frequency `fs`.
@@ -380,8 +401,8 @@ def samples2ms(n_samples: int, fs: Real) -> Real:
def plot_single_lead(
- t: np.ndarray,
- sig: np.ndarray,
+ t: NDArray,
+ sig: NDArray,
ax: Optional[Any] = None,
ticks_granularity: int = 0,
**kwargs,
@@ -405,8 +426,9 @@ def plot_single_lead(
None
"""
- if "plt" not in dir():
- import matplotlib.pyplot as plt
+ import matplotlib.pyplot as plt
+ from matplotlib.ticker import MultipleLocator
+
palette = {
"p_waves": "cyan",
"qrs": "green",
@@ -426,12 +448,12 @@ def plot_single_lead(
ax.axhline(y=0, linestyle="-", linewidth="1.0", color="red")
# NOTE that `Locator` has default `MAXTICKS` equal to 1000
if ticks_granularity >= 1:
- ax.xaxis.set_major_locator(plt.MultipleLocator(0.2))
- ax.yaxis.set_major_locator(plt.MultipleLocator(500))
+ ax.xaxis.set_major_locator(MultipleLocator(0.2))
+ ax.yaxis.set_major_locator(MultipleLocator(500))
ax.grid(which="major", linestyle="-", linewidth="0.5", color="red")
if ticks_granularity >= 2:
- ax.xaxis.set_minor_locator(plt.MultipleLocator(0.04))
- ax.yaxis.set_minor_locator(plt.MultipleLocator(100))
+ ax.xaxis.set_minor_locator(MultipleLocator(0.04))
+ ax.yaxis.set_minor_locator(MultipleLocator(100))
ax.grid(which="minor", linestyle=":", linewidth="0.5", color="black")
waves = kwargs.get("waves", {"p_waves": [], "qrs": [], "t_waves": []})
@@ -485,13 +507,13 @@ def init_logger(
log_file = None
else:
if log_file is None:
- log_file = f"{DEFAULTS.prefix}-log-{get_date_str()}.txt"
- log_dir = Path(log_dir).expanduser().resolve() if log_dir is not None else DEFAULTS.log_dir
+ log_file = f"{DEFAULTS.prefix}-log-{get_date_str()}.txt" # type: ignore
+ log_dir = Path(log_dir).expanduser().resolve() if log_dir is not None else DEFAULTS.log_dir # type: ignore
log_dir.mkdir(parents=True, exist_ok=True)
- log_file = log_dir / log_file
+ log_file = log_dir / log_file # type: ignore
print(f"log file path: {str(log_file)}")
- log_name = (log_name or DEFAULTS.prefix) + (f"-{suffix}" if suffix else "")
+ log_name = (log_name or DEFAULTS.prefix) + (f"-{suffix}" if suffix else "") # type: ignore
# if a logger with the same name already exists, remove it
if log_name in logging.root.manager.loggerDict:
logging.getLogger(log_name).handlers = []
@@ -506,21 +528,21 @@ def init_logger(
c_handler.setLevel(logging.DEBUG)
if log_file is not None:
# print("level of `f_handler` is set DEBUG")
- f_handler.setLevel(logging.DEBUG)
+ f_handler.setLevel(logging.DEBUG) # type: ignore
logger.setLevel(logging.DEBUG)
elif verbose >= 1:
# print("level of `c_handler` is set INFO")
c_handler.setLevel(logging.INFO)
if log_file is not None:
# print("level of `f_handler` is set DEBUG")
- f_handler.setLevel(logging.DEBUG)
+ f_handler.setLevel(logging.DEBUG) # type: ignore
logger.setLevel(logging.DEBUG)
else:
# print("level of `c_handler` is set WARNING")
c_handler.setLevel(logging.WARNING)
if log_file is not None:
# print("level of `f_handler` is set INFO")
- f_handler.setLevel(logging.INFO)
+ f_handler.setLevel(logging.INFO) # type: ignore
logger.setLevel(logging.INFO)
c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
@@ -529,8 +551,8 @@ def init_logger(
if log_file is not None:
f_format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
- f_handler.setFormatter(f_format)
- logger.addHandler(f_handler)
+ f_handler.setFormatter(f_format) # type: ignore
+ logger.addHandler(f_handler) # type: ignore
return logger
@@ -598,7 +620,7 @@ def read_log_txt(
Scalars summary, in the format of a :class:`~pandas.DataFrame`.
"""
- content = Path(fp).read_text().splitlines()
+ content = Path(fp).read_text().splitlines() # type: ignore
if isinstance(scalar_startswith, str):
field_pattern = f"({scalar_startswith})"
else:
@@ -615,7 +637,8 @@ def read_log_txt(
field, val = line.split(":")[-2:]
field = field.strip()
val = float(val.strip())
- new_line[field] = val
+ if new_line:
+ new_line[field] = val
summary.append(new_line)
summary = pd.DataFrame(summary)
return summary
@@ -641,7 +664,7 @@ def read_event_scalars(
"""
try:
- from tensorflow.python.summary.event_accumulator import EventAccumulator
+ from tensorflow.python.summary.event_accumulator import EventAccumulator # type: ignore
except Exception:
try:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
@@ -662,7 +685,7 @@ def read_event_scalars(
df.columns = ["wall_time", "step", "value"]
summary[k] = df
if isinstance(keys, str):
- summary = summary[k]
+ summary = summary[k] # type: ignore
return summary
@@ -809,15 +832,15 @@ def default_class_repr(c: object, align: str = "center", depth: int = 1) -> str:
closing_indent = 4 * (depth - 1) * " "
if not hasattr(c, "extra_repr_keys"):
return repr(c)
- elif len(c.extra_repr_keys()) > 0:
- max_len = max([len(k) for k in c.extra_repr_keys()])
+ elif len(c.extra_repr_keys()) > 0: # type: ignore
+ max_len = max([len(k) for k in c.extra_repr_keys()]) # type: ignore
extra_str = (
"(\n"
+ ",\n".join(
[
f"""{indent}{k.ljust(max_len, " ") if align.lower() in ["center", "c"] else k} = {default_class_repr(eval(f"c.{k}"),align,depth+1)}"""
for k in c.__dir__()
- if k in c.extra_repr_keys()
+ if k in c.extra_repr_keys() # type: ignore
]
)
+ f"{closing_indent}\n)"
@@ -874,7 +897,7 @@ def get_citation(
style: Optional[str] = None,
timeout: Optional[float] = None,
print_result: bool = True,
- ) -> Union[str, type(None)]:
+ ) -> Union[str, None]:
"""Get bib citation from DOIs.
Overrides the default method to make the `print_result` argument
@@ -937,7 +960,7 @@ def __init__(self, data: Optional[Sequence] = None, **kwargs: Any) -> None:
self.data = np.array(data)
self.verbose = kwargs.get("verbose", 0)
- def __call__(self, data: Optional[Sequence] = None, method: str = "ema", **kwargs: Any) -> np.ndarray:
+ def __call__(self, data: Optional[Sequence] = None, method: str = "ema", **kwargs: Any) -> NDArray:
"""Compute moving average.
Parameters
@@ -946,13 +969,11 @@ def __call__(self, data: Optional[Sequence] = None, method: str = "ema", **kwarg
The series data to compute its moving average.
method : str
method for computing moving average, can be one of
-
- - "sma", "simple", "simple moving average";
- - "ema", "ewma", "exponential", "exponential weighted",
- "exponential moving average", "exponential weighted moving average";
- - "cma", "cumulative", "cumulative moving average";
- - "wma", "weighted", "weighted moving average".
-
+ - "sma", "simple", "simple moving average";
+ - "ema", "ewma", "exponential", "exponential weighted",
+ "exponential moving average", "exponential weighted moving average";
+ - "cma", "cumulative", "cumulative moving average";
+ - "wma", "weighted", "weighted moving average".
kwargs : dict, optional
Keyword arguments for the specific moving average method.
@@ -984,7 +1005,7 @@ def __call__(self, data: Optional[Sequence] = None, method: str = "ema", **kwarg
self.data = np.array(data)
return func(**kwargs)
- def _sma(self, window: int = 5, center: bool = False, **kwargs: Any) -> np.ndarray:
+ def _sma(self, window: int = 5, center: bool = False, **kwargs: Any) -> NDArray:
"""Simple moving average.
Parameters
@@ -1020,13 +1041,13 @@ def _sma(self, window: int = 5, center: bool = False, **kwargs: Any) -> np.ndarr
smoothed.append(s)
smoothed = np.array(smoothed)
if center:
- smoothed[hw:-hw] = smoothed[window - 1 :]
- for n in range(hw):
- smoothed[n] = np.mean(self.data[: n + hw + 1])
- smoothed[-n - 1] = np.mean(self.data[-n - hw - 1 :])
+ smoothed[hw:-hw] = smoothed[window - 1 :] # type: ignore
+ for n in range(hw): # type: ignore
+ smoothed[n] = np.mean(self.data[: n + hw + 1]) # type: ignore
+ smoothed[-n - 1] = np.mean(self.data[-n - hw - 1 :]) # type: ignore
return smoothed
- def _ema(self, weight: float = 0.6, **kwargs: Any) -> np.ndarray:
+ def _ema(self, weight: float = 0.6, **kwargs: Any) -> NDArray:
"""Exponential moving average
This is also the function used in Tensorboard Scalar panel,
@@ -1057,7 +1078,7 @@ def _ema(self, weight: float = 0.6, **kwargs: Any) -> np.ndarray:
smoothed = np.array(smoothed)
return smoothed
- def _cma(self, **kwargs) -> np.ndarray:
+ def _cma(self, **kwargs) -> NDArray:
"""Cumulative moving average.
Parameters
@@ -1084,7 +1105,7 @@ def _cma(self, **kwargs) -> np.ndarray:
smoothed = np.array(smoothed)
return smoothed
- def _wma(self, window: int = 5, **kwargs: Any) -> np.ndarray:
+ def _wma(self, window: int = 5, **kwargs: Any) -> NDArray:
"""Weighted moving average.
Parameters
@@ -1199,7 +1220,7 @@ def remove_parameters_returns_from_docstring(
if start_idx is not None and len(line.strip()) == 0:
indices2remove.extend(list(range(start_idx, idx)))
start_idx = None
- if parameters_starts and len(line.lstrip()) == len(line) - len(parameters_indent):
+ if parameters_starts and len(line.lstrip()) == len(line) - len(parameters_indent): # type: ignore
if any([line.lstrip().startswith(p) for p in parameters]):
if start_idx is not None:
indices2remove.extend(list(range(start_idx, idx)))
@@ -1210,7 +1231,7 @@ def remove_parameters_returns_from_docstring(
else:
indices2remove.extend(list(range(start_idx, idx)))
start_idx = None
- if returns_starts and len(line.lstrip()) == len(line) - len(returns_indent):
+ if returns_starts and len(line.lstrip()) == len(line) - len(returns_indent): # type: ignore
if any([line.lstrip().startswith(p) for p in returns]):
if start_idx is not None:
indices2remove.extend(list(range(start_idx, idx)))
@@ -1226,7 +1247,7 @@ def remove_parameters_returns_from_docstring(
@contextmanager
-def timeout(duration: float):
+def timeout(duration: Union[float, int]):
"""A context manager that raises a
:class:`TimeoutError` after a specified time (in seconds).
@@ -1247,16 +1268,25 @@ def timeout(duration: float):
duration = 0
elif duration < 0:
raise ValueError("`duration` must be non-negative")
- elif duration > 0: # granularity is 1 second, so round up
+ elif duration > 0:
duration = max(1, int(duration))
+ if duration == 0:
+ yield
+ return
+
+ old_handler = signal.getsignal(signal.SIGALRM)
+
def timeout_handler(signum, frame):
raise TimeoutError(f"block timedout after `{duration}` seconds")
signal.signal(signal.SIGALRM, timeout_handler)
- signal.alarm(duration)
- yield
- signal.alarm(0)
+ signal.alarm(duration) # type: ignore
+ try:
+ yield
+ finally:
+ signal.alarm(0)
+ signal.signal(signal.SIGALRM, old_handler)
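
# A hedged usage sketch of the `timeout` context manager above: it is SIGALRM-based,
# so it only works on Unix in the main thread, with one-second granularity.
#
#     with timeout(2):
#         do_something_fast()   # hypothetical call; finishes before the alarm fires
#
#     with timeout(1):
#         time.sleep(5)         # raises TimeoutError after roughly 1 second
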
class Timer(ReprMixin):
@@ -1365,7 +1395,7 @@ def extra_repr_keys(self) -> List[str]:
return ["name", "verbose"]
-def get_kwargs(func_or_cls: callable, kwonly: bool = False) -> Dict[str, Any]:
+def get_kwargs(func_or_cls: Callable, kwonly: bool = False) -> Dict[str, Any]:
"""Get the kwargs of a function or class.
Parameters
@@ -1403,7 +1433,7 @@ def get_kwargs(func_or_cls: callable, kwonly: bool = False) -> Dict[str, Any]:
return kwargs
-def get_required_args(func_or_cls: callable) -> List[str]:
+def get_required_args(func_or_cls: Callable) -> List[str]:
"""Get the required positional arguments of a function or class.
Parameters
@@ -1427,7 +1457,7 @@ def get_required_args(func_or_cls: callable) -> List[str]:
return required_args
-def add_kwargs(func: callable, **kwargs: Any) -> callable:
+def add_kwargs(func: Callable, **kwargs: Any) -> Callable:
"""Add keyword arguments to a function.
This function is used to add keyword arguments to a function
@@ -1466,18 +1496,18 @@ def add_kwargs(func: callable, **kwargs: Any) -> callable:
# move the VAR_POSITIONAL and VAR_KEYWORD in `func_parameters` to the end
for k, v in func_parameters.items():
if v.kind == inspect.Parameter.VAR_POSITIONAL:
- func_parameters.move_to_end(k)
+ func_parameters.move_to_end(k) # type: ignore
break
for k, v in func_parameters.items():
if v.kind == inspect.Parameter.VAR_KEYWORD:
- func_parameters.move_to_end(k)
+ func_parameters.move_to_end(k) # type: ignore
break
if isinstance(func, types.MethodType):
# can not assign `__signature__` to a bound method directly
- func.__func__.__signature__ = func_signature.replace(parameters=func_parameters.values())
+ func.__func__.__signature__ = func_signature.replace(parameters=func_parameters.values()) # type: ignore
else:
- func.__signature__ = func_signature.replace(parameters=func_parameters.values())
+ func.__signature__ = func_signature.replace(parameters=func_parameters.values()) # type: ignore
# docstring is automatically copied by `functools.wraps`
@@ -1492,62 +1522,114 @@ def wrapper(*args: Any, **kwargs_: Any) -> Any:
return wrapper
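
# A minimal sketch of the signature surgery `add_kwargs` performs above: a new
# keyword-only parameter is spliced in before any VAR_KEYWORD parameter and
# exposed through `__signature__` (the names below are illustrative).
import inspect

def f(a, **kwargs):
    return a, kwargs

sig = inspect.signature(f)
params = [p for p in sig.parameters.values() if p.kind != inspect.Parameter.VAR_KEYWORD]
params.append(inspect.Parameter("b", kind=inspect.Parameter.KEYWORD_ONLY, default=1))
params.extend(p for p in sig.parameters.values() if p.kind == inspect.Parameter.VAR_KEYWORD)
f.__signature__ = sig.replace(parameters=params)
# inspect.signature(f) now reads "(a, *, b=1, **kwargs)"
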
-def make_serializable(x: Union[np.ndarray, np.generic, dict, list, tuple]) -> Union[list, dict, Number]:
- """Make an object serializable.
+def _is_pathlike_string(s: str) -> bool:
+ """Heuristically check if a string looks like a filesystem path."""
+ if not isinstance(s, str):
+ return False
- This function is used to convert all numpy arrays to list in an object,
- and also convert numpy data types to python data types in the object,
- so that it can be serialized by :mod:`json`.
+ p = PurePath(s)
+ if os.sep in s or (os.altsep and os.altsep in s):
+ return True
+ if s.startswith((".", "~")) or p.is_absolute():
+ return True
+ if p.suffix != "":
+ return True
+ if len(s) > 2 and s[1] == ":" and s[0].isalpha() and s[2] in ("/", "\\"):
+ return True
+ return False
+
+
+def make_serializable(
+ x: Any, drop_unserializable: bool = True, drop_paths: bool = False
+) -> Optional[Union[list, dict, str, int, float, bool]]:
+ """Recursively convert object into JSON-serializable form.
+
+ Rules
+ -----
+ - NDArray → list
+ - np.generic → Python scalar
+ - dict → new dict with only serializable values
+ - list/tuple → list with only serializable values
+ - str/int/float/bool/None → kept
+ - Path → str (dropped entirely if drop_paths=True)
+ - if drop_unserializable:
+ anything else (e.g. custom classes) → dropped (return None)
+ else:
+ fallback to str(x)
Parameters
----------
- x : Union[numpy.ndarray, numpy.generic, dict, list, tuple]
- Input data, which can be numpy array (or numpy data type),
- or dict, list, tuple containing numpy arrays (or numpy data type).
+ x : Any
+ Input object to be converted.
+ drop_unserializable : bool, default=True
+ Whether to drop unserializable objects (return None),
+ or convert them to string with str(x).
+
+ .. versionadded:: 0.0.32
+ drop_paths : bool, default=False
+ If True, drop all filesystem paths (Path objects and strings
+ that look like paths).
+
+ .. versionadded:: 0.0.32
Returns
-------
- Union[list, dict, numbers.Number]
- Converted data.
+ Optional[Union[list, dict, str, int, float, bool]]
+ A JSON-serializable object, or None if dropped.
Examples
--------
>>> import numpy as np
- >>> from fl_sim.utils.misc import make_serializable
- >>> x = np.array([1, 2, 3])
- >>> make_serializable(x)
+ >>> make_serializable(np.array([1, 2, 3]))
[1, 2, 3]
- >>> x = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])}
- >>> make_serializable(x)
- {'a': [1, 2, 3], 'b': [4, 5, 6]}
- >>> x = [np.array([1, 2, 3]), np.array([4, 5, 6])]
- >>> make_serializable(x)
- [[1, 2, 3], [4, 5, 6]]
- >>> x = (np.array([1, 2, 3]), np.array([4, 5, 6]).mean())
- >>> obj = make_serializable(x)
- >>> obj
- [[1, 2, 3], 5.0]
- >>> type(obj[1]), type(x[1])
- (float, numpy.float64)
+ >>> make_serializable({"a": np.float64(3.14), "b": Path("file.txt")})
+ {'a': 3.14, 'b': 'file.txt'}
+ >>> make_serializable({"a": np.float64(3.14), "b": Path("file.txt")}, drop_paths=True)
+ {'a': 3.14}
"""
+
if isinstance(x, np.ndarray):
- return x.tolist()
- elif isinstance(x, (list, tuple)):
- # to avoid cases where the list contains numpy data types
- return [make_serializable(v) for v in x]
- elif isinstance(x, dict):
- for k, v in x.items():
- x[k] = make_serializable(v)
+ return make_serializable(x.tolist(), drop_unserializable=drop_unserializable, drop_paths=drop_paths)
+
elif isinstance(x, np.generic):
return x.item()
- # the other types will be returned directly
- return x
+
+ elif isinstance(x, dict):
+ result = {}
+ for k, v in x.items():
+ v_serial = make_serializable(v, drop_unserializable=drop_unserializable, drop_paths=drop_paths)
+ if v_serial is not None:
+ result[k] = v_serial
+ return result if result else None
+
+ elif isinstance(x, (list, tuple)):
+ result = []
+ for v in x:
+ v_serial = make_serializable(v, drop_unserializable=drop_unserializable, drop_paths=drop_paths)
+ if v_serial is not None:
+ result.append(v_serial)
+ return result if result else None
+
+ elif isinstance(x, (str, int, float, bool, type(None))):
+ if isinstance(x, str) and drop_paths and _is_pathlike_string(x):
+ return None
+ return x
+
+ elif isinstance(x, Path):
+ if drop_paths:
+ return None
+ return str(x)
+
+ else:
+ if drop_unserializable:
+ return None
+ else:
+ return str(x)
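
# A hedged usage sketch of `make_serializable` as defined above: numpy containers
# become plain Python, Path objects become strings (or are dropped with
# drop_paths=True), and unserializable values are dropped by default.
#
#     import json
#     import numpy as np
#     from pathlib import Path
#
#     payload = {"scores": np.arange(3), "lr": np.float64(0.001), "log_dir": Path("runs/exp1")}
#     json.dumps(make_serializable(payload))
#     # -> '{"scores": [0, 1, 2], "lr": 0.001, "log_dir": "runs/exp1"}'
#     make_serializable(payload, drop_paths=True)
#     # -> {'scores': [0, 1, 2], 'lr': 0.001}
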
def select_k(
- arr: np.ndarray, k: Union[int, List[int], np.ndarray], dim: int = -1, largest: bool = True, sorted: bool = True
-) -> Tuple[np.ndarray, np.ndarray]:
+ arr: NDArray, k: Union[int, List[int], NDArray], dim: int = -1, largest: bool = True, sorted: bool = True
+) -> Tuple[NDArray, NDArray]:
"""Select elements from an array along a specified axis of specific rankings.
Parameters
@@ -1602,7 +1684,7 @@ def select_k(
return values, indices
-def np_topk(arr: np.ndarray, k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+def np_topk(arr: NDArray, k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[NDArray, NDArray]:
"""Find the k largest elements of an array along a specified axis.
Parameters
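
# A minimal numpy sketch of the top-k idea behind `select_k` / `np_topk` above:
# argpartition finds the k largest in O(n), argsort then orders only those k.
import numpy as np

def topk_1d(arr: np.ndarray, k: int, largest: bool = True):
    sign = -1 if largest else 1
    idx = np.argpartition(sign * arr, k - 1)[:k]   # unordered indices of the top k
    idx = idx[np.argsort(sign * arr[idx])]         # order the selected k values
    return arr[idx], idx

# topk_1d(np.array([3, 1, 4, 1, 5]), k=2) -> (array([5, 4]), array([4, 2]))
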
diff --git a/torch_ecg/utils/rpeaks.py b/torch_ecg/utils/rpeaks.py
index 635fd724..b4163346 100644
--- a/torch_ecg/utils/rpeaks.py
+++ b/torch_ecg/utils/rpeaks.py
@@ -16,10 +16,10 @@
"""
-from numbers import Real
+from typing import Union
import biosppy.signals.ecg as BSE
-import numpy as np
+from numpy.typing import NDArray
from wfdb.processing.qrs import gqrs_detect as _gqrs_detect
from wfdb.processing.qrs import xqrs_detect as _xqrs_detect
@@ -38,7 +38,7 @@
# algorithms from wfdb
-def xqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
+def xqrs_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray:
"""XQRS algorithm.
default kwargs:
@@ -60,7 +60,7 @@ def xqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
return rpeaks
-def gqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
+def gqrs_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray:
"""GQRS algorithm.
default kwargs:
@@ -102,7 +102,7 @@ def gqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
# ---------------------------------------------------------------------
# algorithms from biosppy
-def hamilton_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
+def hamilton_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray:
"""The default detector used by `biosppy`.
This algorithm is based on [#ham]_.
@@ -126,7 +126,7 @@ def hamilton_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
return rpeaks
-def ssf_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
+def ssf_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray:
"""Slope Sum Function (SSF)
This algorithm is originally proposed for blood pressure (BP)
@@ -157,7 +157,7 @@ def ssf_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
return rpeaks
-def christov_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
+def christov_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray:
"""Christov detector.
Detector proposed in [#chr]_.
@@ -181,7 +181,7 @@ def christov_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
return rpeaks
-def engzee_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
+def engzee_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray:
"""Detector proposed by Engelse and Zeelenberg.
This algorithm is originally proposed in [#ez1]_,
@@ -213,7 +213,7 @@ def engzee_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
return rpeaks
-def gamboa_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray:
+def gamboa_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray:
"""Detector proposed by Gamboa.
This algorithm is proposed in a PhD thesis [#gam]_.
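
# A hedged usage sketch for the detectors in this module: each shares the
# `detector(sig, fs, **kwargs) -> NDArray` interface and returns R-peak indices
# in samples (`record_lead` below is a hypothetical single-lead ECG array).
#
#     import numpy as np
#     fs = 360
#     sig = np.asarray(record_lead, dtype=float)
#     rpeaks = xqrs_detect(sig, fs=fs)
#     rr_intervals_ms = np.diff(rpeaks) / fs * 1000
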
diff --git a/torch_ecg/utils/utils_data.py b/torch_ecg/utils/utils_data.py
index c3d7530e..08565a8c 100644
--- a/torch_ecg/utils/utils_data.py
+++ b/torch_ecg/utils/utils_data.py
@@ -7,12 +7,12 @@
from collections import Counter
from copy import deepcopy
from dataclasses import dataclass
-from numbers import Real
from pathlib import Path
from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
+from numpy.typing import NDArray
from sklearn.utils import compute_class_weight
from torch import Tensor, from_numpy
from torch.nn.functional import interpolate
@@ -41,11 +41,11 @@
def get_mask(
shape: Union[int, Sequence[int]],
- critical_points: np.ndarray,
+ critical_points: NDArray,
left_bias: int,
right_bias: int,
return_fmt: Literal["mask", "intervals"] = "mask",
-) -> Union[np.ndarray, list]:
+) -> Union[NDArray, list]:
"""Get the mask around the given critical points.
Parameters
@@ -93,19 +93,21 @@ def get_mask(
mask[..., itv[0] : itv[1]] = 1
elif return_fmt.lower() == "intervals":
mask = l_itv
+ else:
+ raise ValueError(f"Unknown return_fmt. Expected 'mask' or 'intervals', but got {return_fmt}")
return mask
def class_weight_to_sample_weight(
- y: np.ndarray, class_weight: Union[str, dict, List[float], np.ndarray] = "balanced"
-) -> np.ndarray:
+ y: Union[NDArray, Sequence], class_weight: Union[str, dict, List[float], NDArray, None] = "balanced"
+) -> NDArray:
"""Transform class weight to sample weight.
Parameters
----------
- y : numpy.ndarray
+ y : numpy.ndarray or Sequence
The label (class) of each sample.
- class_weight : str or dict or List[float] or numpy.ndarray, default "balanced"
+ class_weight : str or dict or List[float] or numpy.ndarray or None, default "balanced"
The weight for each sample class.
If is "balanced", the class weight will automatically be given by
the inverse of the class frequency.
@@ -138,6 +140,7 @@ def class_weight_to_sample_weight(
sample_weight = np.ones_like(y, dtype=DEFAULTS.np_dtype)
return sample_weight
+ y = np.asarray(y)
try:
sample_weight = np.array(y.copy()).astype(int)
except ValueError:
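
# A minimal numpy sketch of the "balanced" rule referenced above: each class is
# weighted in proportion to the inverse of its frequency (sklearn's formula is
# used here), and every sample inherits the weight of its class.
import numpy as np

y = np.array([0, 0, 0, 1])                      # 3 negatives, 1 positive
classes, counts = np.unique(y, return_counts=True)
class_w = len(y) / (len(classes) * counts)      # n_samples / (n_classes * n_c)
sample_w = class_w[np.searchsorted(classes, y)]
# class_w -> [0.667, 2.0]; sample_w -> [0.667, 0.667, 0.667, 2.0]
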
@@ -249,7 +252,7 @@ def rdheader(header_data: Union[Path, str, Sequence[str]]) -> Union[Record, Mult
return record
-def ensure_lead_fmt(values: np.ndarray, n_leads: int = 12, fmt: str = "lead_first") -> np.ndarray:
+def ensure_lead_fmt(values: NDArray, n_leads: int = 12, fmt: str = "lead_first") -> NDArray:
"""Ensure the multi-lead (ECG) signal to be of specified format.
Parameters
@@ -299,11 +302,11 @@ def ensure_lead_fmt(values: np.ndarray, n_leads: int = 12, fmt: str = "lead_firs
def ensure_siglen(
- values: np.ndarray,
+ values: NDArray,
siglen: int,
fmt: str = "lead_first",
tolerance: Optional[float] = None,
-) -> np.ndarray:
+) -> NDArray:
"""Ensure the (ECG) signal to be of specified length.
Strategy:
@@ -398,16 +401,16 @@ class ECGWaveForm:
----------
name : str
Name of the wave, e.g. "N", "p", "t", etc.
- onset : numbers.Real
+ onset : float or int
Onset index of the wave,
:class:`~numpy.nan` for unknown/unannotated onset.
- offset : numbers.Real
+ offset : float or int
Offset index of the wave,
:class:`~numpy.nan` for unknown/unannotated offset.
- peak : numbers.Real
+ peak : float or int
Peak index of the wave,
:class:`~numpy.nan` for unknown/unannotated peak.
- duration : numbers.Real
+ duration : float or int
Duration of the wave, with units in milliseconds,
:class:`~numpy.nan` for unknown/unannotated duration.
@@ -418,13 +421,13 @@ class ECGWaveForm:
"""
name: str
- onset: Real
- offset: Real
- peak: Real
- duration: Real
+ onset: Union[float, int]
+ offset: Union[float, int]
+ peak: Union[float, int]
+ duration: Union[float, int]
@property
- def duration_(self) -> Real:
+ def duration_(self) -> Union[float, int]:
"""Duration of the wave, with units in number of samples."""
try:
return self.offset - self.onset
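
# A hedged usage sketch of the ECGWaveForm dataclass above: indices are in samples,
# `duration` is in milliseconds, and `duration_` is the same span in samples.
#
#     qrs = ECGWaveForm(name="qrs", onset=120, offset=160, peak=138, duration=80.0)
#     qrs.duration_   # -> 40 samples, i.e. 80 ms at fs = 500 Hz
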
@@ -445,9 +448,9 @@ def duration_(self) -> Real:
def masks_to_waveforms(
- masks: np.ndarray,
+ masks: NDArray,
class_map: Dict[str, int],
- fs: Real,
+ fs: Union[float, int],
mask_format: str = "channel_first",
leads: Optional[Sequence[str]] = None,
) -> Dict[str, List[ECGWaveForm]]:
@@ -461,7 +464,7 @@ def masks_to_waveforms(
class_map : dict
Class map, mapping names of waves to numbers from 0 to n_classes-1,
the keys should contain "pwave", "qrs", "twave".
- fs : numbers.Real
+ fs : float or int
Sampling frequency of the signal corresponding to the `masks`,
used to compute the duration of each waveform.
mask_format : str, default "channel_first"
@@ -542,7 +545,7 @@ def masks_to_waveforms(
def mask_to_intervals(
- mask: np.ndarray,
+ mask: NDArray,
vals: Optional[Union[int, Sequence[int]]] = None,
right_inclusive: bool = False,
) -> Union[list, dict]:
@@ -614,14 +617,14 @@ def mask_to_intervals(
return intervals
-def uniform(low: Real, high: Real, num: int) -> List[float]:
+def uniform(low: Union[float, int], high: Union[float, int], num: int) -> List[float]:
"""Generate a list of numbers uniformly distributed.
Parameters
----------
- low : numbers.Real
+ low : float or int
Lower bound of the interval of the uniform distribution.
- high : numbers.Real
+ high : float or int
Upper bound of the interval of the uniform distribution.
num : int
Number of random numbers to generate.
@@ -694,7 +697,7 @@ def stratified_train_test_split(
try:
df_inspection = df[stratified_cols].copy().map(str)
except AttributeError:
- df_inspection = df[stratified_cols].copy().applymap(str)
+ df_inspection = df[stratified_cols].copy().applymap(str) # type: ignore
for item in stratified_cols:
all_entities = df_inspection[item].unique().tolist()
entities_dict = {e: str(i) for i, e in enumerate(all_entities)}
@@ -704,7 +707,7 @@ def stratified_train_test_split(
df_inspection[inspection_col_name] = ""
for idx, row in df_inspection.iterrows():
cn = "-".join([row[sc] for sc in stratified_cols])
- df_inspection.loc[idx, inspection_col_name] = cn
+ df_inspection.loc[idx, inspection_col_name] = cn # type: ignore
item_names = df_inspection[inspection_col_name].unique().tolist()
item_indices = {n: df_inspection.index[df_inspection[inspection_col_name] == n].tolist() for n in item_names}
for n in item_names:
@@ -723,8 +726,8 @@ def stratified_train_test_split(
def one_hot_encode(
- cls_array: Union[np.ndarray, Tensor, Sequence[Sequence[int]]], num_classes: Optional[int] = None, dtype: type = np.float32
-) -> np.ndarray:
+ cls_array: Union[NDArray, Tensor, Sequence[Sequence[int]]], num_classes: Optional[int] = None, dtype: type = np.float32
+) -> NDArray:
"""Convert a categorical array to a one-hot array.
Convert a categorical (class indices) array of shape ``(num_samples,)``
@@ -768,42 +771,42 @@ def one_hot_encode(
else: # sequence of sequences of class indices
num_classes = max([max(c) for c in cls_array]) + 1
if isinstance(cls_array, np.ndarray) and cls_array.ndim == 1:
- assert num_classes > 0 and num_classes >= cls_array.max() + 1, (
+ assert num_classes > 0 and num_classes >= cls_array.max() + 1, ( # type: ignore
"num_classes must be greater than 0 and greater than or equal to "
"the max value of `cls_array` if `cls_array` is 1D and `num_classes` is specified"
)
elif isinstance(cls_array, Sequence):
assert all(
- [max(c) < num_classes for c in cls_array]
+ [max(c) < num_classes for c in cls_array] # type: ignore
), "all values in the multi-class `cls_array` should be less than `num_classes`"
if isinstance(cls_array, np.ndarray) and cls_array.ndim == 2 and cls_array.shape[1] == num_classes:
bin_array = cls_array
else:
shape = (len(cls_array), num_classes)
- bin_array = np.zeros(shape)
+ bin_array = np.zeros(shape) # type: ignore
for i in range(shape[0]):
bin_array[i, cls_array[i]] = 1
return bin_array.astype(dtype)
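
# A minimal numpy sketch of the 1D branch above: class indices become rows of a
# one-hot matrix of shape (num_samples, num_classes).
import numpy as np

cls_array = np.array([0, 2, 1])
num_classes = 3
bin_array = np.zeros((len(cls_array), num_classes), dtype=np.float32)
bin_array[np.arange(len(cls_array)), cls_array] = 1
# bin_array -> [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
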
-@add_docstring(one_hot_encode.__doc__.replace("one_hot_encode", "cls_to_bin"))
+@add_docstring(one_hot_encode.__doc__.replace("one_hot_encode", "cls_to_bin")) # type: ignore
def cls_to_bin(
- cls_array: Union[np.ndarray, Tensor, Sequence[Sequence[int]]], num_classes: Optional[int] = None, dtype: type = np.float32
-) -> np.ndarray:
+ cls_array: Union[NDArray, Tensor, Sequence[Sequence[int]]], num_classes: Optional[int] = None, dtype: type = np.float32
+) -> NDArray:
"""Alias of `one_hot_encode`."""
warnings.warn("`cls_to_bin` is deprecated, use `one_hot_encode` instead", DeprecationWarning)
return one_hot_encode(cls_array, num_classes, dtype)
def generate_weight_mask(
- target_mask: np.ndarray,
- fg_weight: Real,
- fs: Real,
- reduction: Real,
- radius: Real,
- boundary_weight: Real,
+ target_mask: NDArray,
+ fg_weight: Union[float, int],
+ fs: Union[float, int],
+ reduction: Union[float, int],
+ radius: Union[float, int],
+ boundary_weight: Union[float, int],
plot: bool = False,
-) -> np.ndarray:
+) -> NDArray:
"""Generate weight mask for a binary target mask,
accounting for the foreground weight and the boundary weight.
@@ -811,15 +814,15 @@ def generate_weight_mask(
----------
target_mask : numpy.ndarray
The target mask, assumed to be 1D and binary.
- fg_weight: numbers.Real
+ fg_weight: float or int
Foreground (value 1) weight, usually > 1.
- fs : numbers.Real
+ fs : float or int
Sampling frequency of the signal.
- reduction : numbers.Real
+ reduction : float or int
Reduction ratio of the mask w.r.t. the signal.
- radius : numbers.Real
+ radius : float or int
Radius of the boundary, with units in seconds.
- boundary_weight : numbers.Real
+ boundary_weight : float or int
Weight for the boundaries (positions where values change)
of the target map.
plot : bool, default False
@@ -855,7 +858,7 @@ def generate_weight_mask(
"""
assert target_mask.ndim == 1, "`target_mask` should be 1D"
assert set(np.unique(target_mask)).issubset({0, 1}), "`target_mask` should be binary"
- assert isinstance(reduction, Real) and reduction >= 1, "`reduction` should be a real number greater than 1"
+ assert isinstance(reduction, (int, float)) and reduction >= 1, "`reduction` should be a real number greater than or equal to 1"
if reduction > 1:
# downsample the target mask
target_mask = (
diff --git a/torch_ecg/utils/utils_interval.py b/torch_ecg/utils/utils_interval.py
index b9b3c91e..5eb48aa8 100644
--- a/torch_ecg/utils/utils_interval.py
+++ b/torch_ecg/utils/utils_interval.py
@@ -20,6 +20,7 @@
from typing import Any, List, Literal, Sequence, Tuple, Union
import numpy as np
+from numpy.typing import NDArray
__all__ = [
"overlaps",
@@ -51,9 +52,9 @@ def overlaps(interval: Interval, another: Interval) -> int:
The amount of overlap, in bp, between `interval` and `another` is returned.
- - If > 0, the number of bp of overlap
- - If 0, they are book-ended
- - If < 0, the distance in bp between them
+ - If > 0, the number of bp of overlap
+ - If 0, they are book-ended
+ - If < 0, the distance in bp between them
Parameters
----------
@@ -104,8 +105,8 @@ def validate_interval(
tuple
2-tuple consisting of
- - bool: indicating whether `interval` is a valid interval
- - an interval (can be empty)
+ - bool: indicating whether `interval` is a valid interval
+ - an interval (can be empty)
Examples
--------
@@ -770,9 +771,9 @@ def find_max_cont_len(sublist: Interval, tot_rng: Real) -> dict:
dict
A dictionary containing the following keys:
- - "max_cont_len"
- - "max_cont_sublist_start"
- - "max_cont_sublist"
+ - "max_cont_len"
+ - "max_cont_sublist_start"
+ - "max_cont_sublist"
Examples
--------
@@ -857,7 +858,7 @@ def generalized_interval_len(generalized_interval: GeneralizedInterval) -> Real:
return gi_len
-def find_extrema(signal: Union[np.ndarray, Sequence], mode: Literal["max", "min", "both"] = "both") -> np.ndarray:
+def find_extrema(signal: Union[NDArray, Sequence], mode: Literal["max", "min", "both"] = "both") -> NDArray:
"""Locate local extrema points in a 1D signal.
This function is based on Fermat's Theorem.
diff --git a/torch_ecg/utils/utils_metrics.py b/torch_ecg/utils/utils_metrics.py
index 7be1ff78..e7bed79c 100644
--- a/torch_ecg/utils/utils_metrics.py
+++ b/torch_ecg/utils/utils_metrics.py
@@ -12,6 +12,7 @@
import einops
import numpy as np
import torch
+from numpy.typing import NDArray
from torch import Tensor
from ..cfg import DEFAULTS
@@ -29,8 +30,8 @@
def top_n_accuracy(
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
n: Union[int, Sequence[int]] = 1,
) -> Union[float, Dict[str, float]]:
"""Compute top n accuracy.
@@ -85,10 +86,10 @@ def top_n_accuracy(
def confusion_matrix(
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
-) -> np.ndarray:
+) -> NDArray:
"""Compute a binary confusion matrix
The columns are ground truth labels and rows are predicted labels.
@@ -129,10 +130,10 @@ def confusion_matrix(
def one_vs_rest_confusion_matrix(
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
-) -> np.ndarray:
+) -> NDArray:
"""Compute binary one-vs-rest confusion matrices.
Columns are ground truth labels and rows are predicted labels.
@@ -216,13 +217,13 @@ def one_vs_rest_confusion_matrix(
"prepend",
)
def metrics_from_confusion_matrix(
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
- weights: Optional[Union[np.ndarray, Tensor]] = None,
+ weights: Optional[Union[NDArray, Tensor]] = None,
thr: float = 0.5,
fillna: Union[bool, float] = 0.0,
-) -> Dict[str, Union[float, np.ndarray]]:
+) -> Dict[str, Union[float, NDArray]]:
"""
Returns
-------
@@ -434,13 +435,13 @@ def metrics_from_confusion_matrix(
"prepend",
)
def f_measure(
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
- weights: Optional[Union[np.ndarray, Tensor]] = None,
+ weights: Optional[Union[NDArray, Tensor]] = None,
thr: float = 0.5,
fillna: Union[bool, float] = 0.0,
-) -> Tuple[float, np.ndarray]:
+) -> Tuple[float, NDArray]:
"""
Returns
-------
@@ -460,13 +461,13 @@ def f_measure(
"prepend",
)
def sensitivity(
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
- weights: Optional[Union[np.ndarray, Tensor]] = None,
+ weights: Optional[Union[NDArray, Tensor]] = None,
thr: float = 0.5,
fillna: Union[bool, float] = 0.0,
-) -> Tuple[float, np.ndarray]:
+) -> Tuple[float, NDArray]:
"""
Returns
-------
@@ -492,13 +493,13 @@ def sensitivity(
"prepend",
)
def precision(
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
- weights: Optional[Union[np.ndarray, Tensor]] = None,
+ weights: Optional[Union[NDArray, Tensor]] = None,
thr: float = 0.5,
fillna: Union[bool, float] = 0.0,
-) -> Tuple[float, np.ndarray]:
+) -> Tuple[float, NDArray]:
"""
Returns
-------
@@ -522,13 +523,13 @@ def precision(
"prepend",
)
def specificity(
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
- weights: Optional[Union[np.ndarray, Tensor]] = None,
+ weights: Optional[Union[NDArray, Tensor]] = None,
thr: float = 0.5,
fillna: Union[bool, float] = 0.0,
-) -> Tuple[float, np.ndarray]:
+) -> Tuple[float, NDArray]:
"""
Returns
-------
@@ -553,13 +554,13 @@ def specificity(
"prepend",
)
def auc(
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
- weights: Optional[Union[np.ndarray, Tensor]] = None,
+ weights: Optional[Union[NDArray, Tensor]] = None,
thr: float = 0.5,
fillna: Union[bool, float] = 0.0,
-) -> Tuple[float, float, np.ndarray, np.ndarray]:
+) -> Tuple[float, float, NDArray, NDArray]:
"""
Returns
-------
@@ -585,10 +586,10 @@ def auc(
"prepend",
)
def accuracy(
- labels: Union[np.ndarray, Tensor],
- outputs: Union[np.ndarray, Tensor],
+ labels: Union[NDArray, Tensor],
+ outputs: Union[NDArray, Tensor],
num_classes: Optional[int] = None,
- weights: Optional[Union[np.ndarray, Tensor]] = None,
+ weights: Optional[Union[NDArray, Tensor]] = None,
thr: float = 0.5,
fillna: Union[bool, float] = 0.0,
) -> float:
@@ -607,8 +608,8 @@ def accuracy(
def QRS_score(
- rpeaks_truths: Sequence[Union[np.ndarray, Sequence[int]]],
- rpeaks_preds: Sequence[Union[np.ndarray, Sequence[int]]],
+ rpeaks_truths: Sequence[Union[NDArray, Sequence[int]]],
+ rpeaks_preds: Sequence[Union[NDArray, Sequence[int]]],
fs: Real,
thr: float = 0.075,
) -> float:
@@ -679,10 +680,10 @@ def QRS_score(
def one_hot_pair(
- labels: Union[np.ndarray, Tensor, Sequence[Sequence[int]]],
- outputs: Union[np.ndarray, Tensor, Sequence[Sequence[int]]],
+ labels: Union[NDArray, Tensor, Sequence[Sequence[int]]],
+ outputs: Union[NDArray, Tensor, Sequence[Sequence[int]]],
num_classes: Optional[int] = None,
-) -> Tuple[np.ndarray, np.ndarray]:
+) -> Tuple[NDArray, NDArray]:
"""Convert categorical (of shape ``(n_samples,)``) labels and outputs
to binary (of shape ``(n_samples, n_classes)``) labels and outputs if applicable.
@@ -729,7 +730,7 @@ def one_hot_pair(
return labels, outputs
-def _one_hot_pair(cls_array: Union[np.ndarray, Sequence[Sequence[int]]], shape: Tuple[int]) -> np.ndarray:
+def _one_hot_pair(cls_array: Union[NDArray, Sequence[Sequence[int]]], shape: Tuple[int]) -> NDArray:
"""Convert categorical array to binary array.
Parameters
@@ -757,18 +758,18 @@ def _one_hot_pair(cls_array: Union[np.ndarray, Sequence[Sequence[int]]], shape:
@add_docstring(one_hot_pair.__doc__)
def cls_to_bin(
- labels: Union[np.ndarray, Tensor, Sequence[Sequence[int]]],
- outputs: Union[np.ndarray, Tensor, Sequence[Sequence[int]]],
+ labels: Union[NDArray, Tensor, Sequence[Sequence[int]]],
+ outputs: Union[NDArray, Tensor, Sequence[Sequence[int]]],
num_classes: Optional[int] = None,
-) -> Tuple[np.ndarray, np.ndarray]:
+) -> Tuple[NDArray, NDArray]:
"""Alias of `one_hot_pair`."""
warnings.warn("`cls_to_bin` is deprecated, use `one_hot_pair` instead", DeprecationWarning)
return one_hot_pair(labels, outputs, num_classes)
def compute_wave_delineation_metrics(
- truth_masks: Sequence[np.ndarray],
- pred_masks: Sequence[np.ndarray],
+ truth_masks: Sequence[NDArray],
+ pred_masks: Sequence[NDArray],
class_map: Dict[str, int],
fs: Real,
mask_format: str = "channel_first",
diff --git a/torch_ecg/utils/utils_nn.py b/torch_ecg/utils/utils_nn.py
index 22987ded..8b9c4804 100644
--- a/torch_ecg/utils/utils_nn.py
+++ b/torch_ecg/utils/utils_nn.py
@@ -3,6 +3,7 @@
"""
+import json
import os
import pickle
import re
@@ -17,6 +18,9 @@
import numpy as np
import torch
from easydict import EasyDict
+from numpy.typing import NDArray
+from safetensors import safe_open
+from safetensors.torch import load_file, save_file
from torch import Tensor, nn
from ..cfg import CFG, DEFAULTS, DTYPE
@@ -52,17 +56,23 @@ def _get_np_dtypes():
_safe_globals = [
CFG, DTYPE, EasyDict,
Path, PosixPath, WindowsPath,
- np.core.multiarray._reconstruct,
+ # np.core.multiarray._reconstruct,
np.ndarray, np.dtype,
np.float32, np.float64, np.int32, np.int64, np.uint8, np.int8,
] + _get_np_dtypes()
+ if hasattr(np, "core"):
+ _safe_globals.append(np.core.multiarray._reconstruct) # type: ignore
+ else:
+ _safe_globals.append(np._core.multiarray._reconstruct) # type: ignore
# fmt: on
if hasattr(torch.serialization, "add_safe_globals"):
torch.serialization.add_safe_globals(_safe_globals)
-def extend_predictions(preds: Sequence, classes: List[str], extended_classes: List[str]) -> np.ndarray:
+def extend_predictions(
+ preds: Union[Sequence, NDArray, torch.Tensor], classes: List[str], extended_classes: List[str]
+) -> NDArray:
"""Extend the prediction arrays to prediction arrays in larger range of classes
Parameters
@@ -140,8 +150,8 @@ def compute_output_shape(
output_padding: Union[Sequence[int], int] = 0,
dilation: Union[Sequence[int], int] = 1,
channel_last: bool = False,
- asymmetric_padding: Union[Sequence[int], Sequence[Sequence[int]]] = None,
-) -> Tuple[Union[int, None]]:
+ asymmetric_padding: Optional[Union[Sequence[int], Sequence[Sequence[int]]]] = None,
+) -> Tuple[Union[int, None], ...]:
"""Compute the output shape of a (transpose) convolution/maxpool/avgpool layer.
This function is based on the discussion [#disc]_.
@@ -287,17 +297,17 @@ def check_output_validity(shape):
none_dim_msg = "spatial dimensions should be all `None`, or all not `None`"
if channel_last:
if all([n is None for n in input_shape[1:-1]]):
- if out_channels is None:
+ if out_channels is None: # type: ignore
raise ValueError("out channel dimension and spatial dimensions are all `None`")
- output_shape = tuple(list(input_shape[:-1]) + [out_channels])
+ output_shape = tuple(list(input_shape[:-1]) + [out_channels]) # type: ignore
return check_output_validity(output_shape)
elif any([n is None for n in input_shape[1:-1]]):
raise ValueError(none_dim_msg)
else:
if all([n is None for n in input_shape[2:]]):
- if out_channels is None:
+ if out_channels is None: # type: ignore
raise ValueError("out channel dimension and spatial dimensions are all `None`")
- output_shape = tuple([input_shape[0], out_channels] + list(input_shape[2:]))
+ output_shape = tuple([input_shape[0], out_channels] + list(input_shape[2:])) # type: ignore
return check_output_validity(output_shape)
elif any([n is None for n in input_shape[2:]]):
raise ValueError(none_dim_msg)
@@ -349,12 +359,12 @@ def check_output_validity(shape):
_asymmetric_padding = list(repeat(asymmetric_padding, dim))
else:
assert len(asymmetric_padding) == dim and all(
- len(ap) == 2 and all(isinstance(p, int) for p in ap) for ap in asymmetric_padding
+ len(ap) == 2 and all(isinstance(p, int) for p in ap) for ap in asymmetric_padding # type: ignore
), "Invalid `asymmetric_padding`"
_asymmetric_padding = asymmetric_padding
for idx in range(dim):
- _padding[idx][0] += _asymmetric_padding[idx][0]
- _padding[idx][1] += _asymmetric_padding[idx][1]
+ _padding[idx][0] += _asymmetric_padding[idx][0] # type: ignore
+ _padding[idx][1] += _asymmetric_padding[idx][1] # type: ignore
if isinstance(output_padding, int):
_output_padding = list(repeat(output_padding, dim))
@@ -400,13 +410,13 @@ def check_output_validity(shape):
]
else:
output_shape = [
- floor(((i + sum(p) - minus_term(d, k)) / s) + 1)
+ floor(((i + sum(p) - minus_term(d, k)) / s) + 1) # type: ignore
for i, p, d, k, s in zip(_input_shape, _padding, _dilation, _kernel_size, _stride)
]
if channel_last:
- output_shape = tuple([input_shape[0]] + output_shape + [out_channels])
+ output_shape = tuple([input_shape[0]] + output_shape + [out_channels]) # type: ignore
else:
- output_shape = tuple([input_shape[0], out_channels] + output_shape)
+ output_shape = tuple([input_shape[0], out_channels] + output_shape) # type: ignore
return check_output_validity(output_shape)
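
# A worked instance of the standard convolution output-length formula that the
# expression above implements for one spatial dimension (assuming the usual
# minus term dilation * (kernel - 1) + 1, with sum(p) the total padding):
from math import floor

L_in, kernel, stride, total_padding, dilation = 2000, 11, 2, 10, 1
L_out = floor((L_in + total_padding - dilation * (kernel - 1) - 1) / stride + 1)
# L_out == 1000, matching torch.nn.Conv1d(kernel_size=11, stride=2, padding=5) on length 2000
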
@@ -419,8 +429,8 @@ def compute_conv_output_shape(
padding: Union[Sequence[int], int] = 0,
dilation: Union[Sequence[int], int] = 1,
channel_last: bool = False,
- asymmetric_padding: Union[Sequence[int], Sequence[Sequence[int]]] = None,
-) -> Tuple[Union[int, None]]:
+ asymmetric_padding: Optional[Union[Sequence[int], Sequence[Sequence[int]]]] = None,
+) -> Tuple[Union[int, None], ...]:
"""Compute the output shape of a convolution layer.
Parameters
@@ -477,7 +487,7 @@ def compute_maxpool_output_shape(
padding: Union[Sequence[int], int] = 0,
dilation: Union[Sequence[int], int] = 1,
channel_last: bool = False,
-) -> Tuple[Union[int, None]]:
+) -> Tuple[Union[int, None], ...]:
"""Compute the output shape of a maxpool layer.
Parameters
@@ -527,7 +537,7 @@ def compute_avgpool_output_shape(
stride: Union[Sequence[int], int] = 1,
padding: Union[Sequence[int], int] = 0,
channel_last: bool = False,
-) -> Tuple[Union[int, None]]:
+) -> Tuple[Union[int, None], ...]:
"""Compute the output shape of a avgpool layer.
Parameters
@@ -578,8 +588,8 @@ def compute_deconv_output_shape(
output_padding: Union[Sequence[int], int] = 0,
dilation: Union[Sequence[int], int] = 1,
channel_last: bool = False,
- asymmetric_padding: Union[Sequence[int], Sequence[Sequence[int]]] = None,
-) -> Tuple[Union[int, None]]:
+ asymmetric_padding: Optional[Union[Sequence[int], Sequence[Sequence[int]]]] = None,
+) -> Tuple[Union[int, None], ...]:
"""Compute the output shape of a transpose convolution layer
Parameters
@@ -657,10 +667,12 @@ def compute_sequential_output_shape(
"""Compute the output shape of a sequential model."""
assert issubclass(type(model), nn.Sequential), f"model should be nn.Sequential, but got {type(model)}"
_seq_len = seq_len
+ if len(model) == 0:
+ raise AssertionError("model has no modules")
for module in model:
output_shape = module.compute_output_shape(_seq_len, batch_size)
_, _, _seq_len = output_shape
- return output_shape
+ return output_shape # type: ignore
def compute_module_size(
@@ -760,7 +772,7 @@ def compute_receptive_field(
strides: Union[Sequence[int], int] = 1,
dilations: Union[Sequence[int], int] = 1,
input_len: Optional[int] = None,
- fs: Optional[Real] = None,
+ fs: Optional[Union[int, float]] = None,
) -> Union[int, float]:
"""Compute the receptive field of several types of
:class:`~torch.nn.Module`.
@@ -855,11 +867,11 @@ def compute_receptive_field(
receptive_field = min(receptive_field, input_len)
if fs is not None:
receptive_field /= fs
- return make_serializable(receptive_field)
+ return make_serializable(receptive_field) # type: ignore
def default_collate_fn(
- batch: Sequence[Union[Tuple[np.ndarray, ...], Dict[str, np.ndarray]]],
+ batch: Sequence[Union[Tuple[NDArray, ...], Dict[str, NDArray]]],
) -> Union[Tuple[Tensor, ...], Dict[str, Tensor]]:
"""Default collate functions for model training.
@@ -883,13 +895,13 @@ def default_collate_fn(
"""
if isinstance(batch[0], dict):
keys = batch[0].keys()
- collated = _default_collate_fn([tuple(b[k] for k in keys) for b in batch])
+ collated = _default_collate_fn([tuple(b[k] for k in keys) for b in batch]) # type: ignore
return {k: collated[i] for i, k in enumerate(keys)}
else:
- return _default_collate_fn(batch)
+ return _default_collate_fn(batch) # type: ignore
-def _default_collate_fn(batch: Sequence[Tuple[np.ndarray, ...]]) -> Tuple[Tensor, ...]:
+def _default_collate_fn(batch: Sequence[Tuple[NDArray, ...]]) -> Tuple[Tensor, ...]:
"""Collate functions for tuples of tensors.
The data generator (:class:`~torch.utils.data.Dataset`) should
@@ -970,7 +982,7 @@ def _adjust_cnn_filter_lengths(
]
elif isinstance(v, Real):
# DO NOT use `int`, which might not work for numpy array elements
- if v > 1:
+ if v > 1: # type: ignore
config[k] = int(round(v * fs / config["fs"]))
if ensure_odd:
config[k] = config[k] - config[k] % 2 + 1
@@ -1018,27 +1030,27 @@ class SizeMixin(object):
@property
def module_size(self) -> int:
"""Size of trainable parameters in the model in terms of number of parameters."""
- return compute_module_size(self)
+ return compute_module_size(self) # type: ignore
@property
def module_size_(self) -> str:
"""Size of trainable parameters in the model in terms of memory capacity."""
- return compute_module_size(self, human=True)
+ return compute_module_size(self, human=True) # type: ignore
@property
def sizeof(self) -> int:
"""Size of the model in terms of number of parameters, including non-trainable parameters and buffers."""
- return compute_module_size(self, requires_grad=False, include_buffers=True, human=False)
+ return compute_module_size(self, requires_grad=False, include_buffers=True, human=False) # type: ignore
@property
def sizeof_(self) -> str:
"""Size of the model in terms of memory capacity, including non-trainable parameters and buffers."""
- return compute_module_size(self, requires_grad=False, include_buffers=True, human=True)
+ return compute_module_size(self, requires_grad=False, include_buffers=True, human=True) # type: ignore
@property
def dtype(self) -> torch.dtype:
try:
- return next(self.parameters()).dtype
+ return next(self.parameters()).dtype # type: ignore
except StopIteration:
return torch.float32
except Exception as err:
@@ -1047,7 +1059,7 @@ def dtype(self) -> torch.dtype:
@property
def device(self) -> torch.device:
try:
- return next(self.parameters()).device
+ return next(self.parameters()).device # type: ignore
except StopIteration:
return torch.device("cpu")
except Exception as err:
@@ -1103,7 +1115,16 @@ def make_safe_globals(obj: CFG, remove_paths: bool = True) -> CFG:
sg = None
elif isinstance(sg, (str, bytes)) and os.path.exists(sg):
sg = None
- return sg
+ return sg # type: ignore
+
+
+_SFT_META_MODEL_CFG = "model_config"
+_SFT_META_TRAIN_CFG = "train_config"
+_SFT_META_FORMAT = "__format__"
+_SFT_META_VERSION = "torch_ecg.ckpt.v1"
+_SFT_META_EXTRA_JSON_PREFIX = "__extra_json__/" # JSON-serialized extras
+_SFT_EXTRA_TENSOR_PREFIX = "__extra__/" # Tensor extras grouped under this prefix
+_SFT_META_EXTRA_TENSOR_GROUPS = "__extra_tensor_groups__" # JSON list of groups
class CkptMixin(object):
@@ -1122,8 +1143,9 @@ def from_checkpoint(
----------
path : `path-like`
Path to the checkpoint.
- If it is a directory, then this directory should contain only one checkpoint file
- (with the extension `.pth` or `.pt`).
+ If it is a directory, then this directory should be one of the following cases:
+ - contain a `model.safetensors` file, a `model_config.json` file, and a `train_config.json` file.
+ - contain only one checkpoint file (with the extension `.pth` or `.pt`).
device : torch.device, optional
Map location of the model parameters,
defaults to "cuda" if available, otherwise "cpu".
@@ -1140,16 +1162,77 @@ def from_checkpoint(
Auxiliary configs that are needed for data preprocessing, etc.
"""
- if Path(path).is_dir():
- candidates = list(Path(path).glob("*.pth")) + list(Path(path).glob("*.pt"))
- assert len(candidates) == 1, "The directory should contain only one checkpoint file"
- path = candidates[0]
- _device = device or DEFAULTS.device
+ _device = device or DEFAULTS.device # type: ignore
if weights_only == "auto":
if hasattr(torch.serialization, "add_safe_globals"):
weights_only = True
else:
weights_only = False
+
+ if isinstance(path, bytes):
+ path = path.decode()
+ path = Path(path).expanduser().resolve() # type: ignore
+
+ if path.is_dir():
+ candidates = list(path.glob("*.pth")) + list(path.glob("*.pt"))
+ assert len(candidates) in [0, 1], "The directory should contain only one checkpoint file"
+ if len(candidates) == 0:
+ # the directory should contain a `model.safetensors` file, a `model_config.json` file,
+ # and a `train_config.json` file
+ model_path = path / "model.safetensors"
+ train_config_path = path / "train_config.json"
+ model_config_path = path / "model_config.json"
+ assert model_path.exists(), "model.safetensors file not found"
+ assert train_config_path.exists(), "train_config.json file not found"
+ assert model_config_path.exists(), "model_config.json file not found"
+ train_config = json.loads(train_config_path.read_text())
+ model_config = json.loads(model_config_path.read_text())
+ aux_config = train_config
+ kwargs = dict(config=model_config)
+ if "classes" in aux_config:
+ kwargs["classes"] = aux_config["classes"]
+ if "n_leads" in aux_config:
+ kwargs["n_leads"] = aux_config["n_leads"]
+ model = cls(**kwargs)
+ model.load_state_dict(load_file(model_path, device="cpu")) # type: ignore
+ model.to(_device) # type: ignore
+ return model, aux_config # type: ignore
+
+ # we have only one checkpoint file in the directory
+ # and we will use `torch.load` to load the model
+ path = candidates[0]
+
+ if path.suffix == ".safetensors": # load safetensors format file
+ try:
+ with safe_open(str(path), framework="pt", device="cpu") as f: # type: ignore
+ meta = f.metadata() or {}
+ if _SFT_META_MODEL_CFG in meta and _SFT_META_TRAIN_CFG in meta:
+ model_config = json.loads(meta[_SFT_META_MODEL_CFG])
+ train_config = json.loads(meta[_SFT_META_TRAIN_CFG])
+ aux_config = train_config
+ kwargs = dict(config=model_config)
+ if "classes" in aux_config:
+ kwargs["classes"] = aux_config["classes"]
+ if "n_leads" in aux_config:
+ kwargs["n_leads"] = aux_config["n_leads"]
+ model = cls(**kwargs)
+
+ # load only model weights, and ignores extra tensors
+ state_dict = {}
+ for key in f.keys():
+ if not key.startswith(_SFT_EXTRA_TENSOR_PREFIX):
+ state_dict[key] = f.get_tensor(key)
+ model.load_state_dict(state_dict, strict=False) # type: ignore
+ model.to(_device) # type: ignore
+ return model, aux_config # type: ignore
+ except Exception as e:
+ # failed to load the safetensors file
+ raise RuntimeError(
+ "Failed to load the safetensors file. "
+ "The file may be corrupted, or the version of safetensors is incompatible. "
+ "Try updating safetensors to the latest version, or check the file integrity."
+ ) from e
+
try:
ckpt = torch.load(path, map_location=_device, weights_only=weights_only)
except pickle.UnpicklingError as pue:
@@ -1160,16 +1243,14 @@ def from_checkpoint(
) from pue
aux_config = ckpt.get("train_config", None) or ckpt.get("config", None)
assert aux_config is not None, "input checkpoint has no sufficient data to recover a model"
- kwargs = dict(
- config=ckpt["model_config"],
- )
+ kwargs = dict(config=ckpt["model_config"])
if "classes" in aux_config:
kwargs["classes"] = aux_config["classes"]
if "n_leads" in aux_config:
kwargs["n_leads"] = aux_config["n_leads"]
model = cls(**kwargs)
- model.load_state_dict(ckpt["model_state_dict"])
- return model, aux_config
+ model.load_state_dict(ckpt["model_state_dict"]) # type: ignore
+ return model, aux_config # type: ignore
@classmethod
def from_remote(
@@ -1209,9 +1290,22 @@ def from_remote(
model_path_or_dir = http_get(url, model_dir, extract="auto", filename=filename)
return cls.from_checkpoint(model_path_or_dir, device=device, weights_only=weights_only)
- def save(self, path: Union[str, bytes, os.PathLike], train_config: CFG) -> None:
+ def save(
+ self,
+ path: Union[str, bytes, os.PathLike],
+ train_config: CFG,
+ extra_items: Optional[dict] = None,
+ use_safetensors: bool = True,
+ safetensors_single_file: bool = True,
+ ) -> None:
"""Save the model to disk.
+ .. note::
+
+ `safetensors` is used by default to save the model.
+ If one wants to save the model in `.pth` or `.pt` format,
+ ``use_safetensors=False`` must be set explicitly.
+
Parameters
----------
path : `path-like`
@@ -1219,22 +1313,98 @@ def save(self, path: Union[str, bytes, os.PathLike], train_config: CFG) -> None:
train_config : CFG
Config for training the model,
used when one restores the model.
+ extra_items : dict, optional
+ Extra items to save along with the model.
+ The values should be serializable: either JSON-serializable,
+ or a dict of torch tensors.
+
+ .. versionadded:: 0.0.32
+ use_safetensors : bool, default True
+ Whether to use `safetensors` to save the model.
+ This will be overridden by the suffix of `path`:
+ if it is `.safetensors`, then `use_safetensors` is set to True;
+ if it is `.pth` or `.pt`, then if `use_safetensors` is True,
+ the suffix is changed to `.safetensors`, otherwise it is unchanged.
+
+ .. versionadded:: 0.0.32
+ safetensors_single_file : bool, default True
+ Whether to save the metadata along with the state dict into one file.
+
+ .. versionadded:: 0.0.32
Returns
-------
None
"""
- path = Path(path)
+ if isinstance(path, bytes):
+ path = path.decode()
+ path = Path(path).expanduser().resolve() # type: ignore
if not path.parent.exists():
path.parent.mkdir(parents=True)
- _model_config = make_safe_globals(self.config)
+ extra_items = extra_items or {}
+
+ _model_config = make_safe_globals(self.config) # type: ignore
_train_config = make_safe_globals(train_config)
+
+ if path.suffix in [".pth", ".pt"]:
+ if use_safetensors:
+ path = path.with_suffix(".safetensors")
+ warnings.warn(
+ f"`safetensors` is used by default. The saved file name is changed to {path.name}", RuntimeWarning
+ )
+ elif path.suffix == ".safetensors":
+ use_safetensors = True
+
+ if use_safetensors and safetensors_single_file:
+ tensors = dict(self.state_dict()) # type: ignore
+
+ tensor_groups = []
+ for key, val in extra_items.items():
+ if isinstance(val, dict) and all(isinstance(v, torch.Tensor) for v in val.values()):
+ tensor_groups.append(key)
+ for tname, ten in val.items():
+ tensors[f"{_SFT_EXTRA_TENSOR_PREFIX}{key}/{tname}"] = ten
+
+ meta = {
+ _SFT_META_FORMAT: _SFT_META_VERSION,
+ _SFT_META_MODEL_CFG: json.dumps(make_serializable(_model_config), ensure_ascii=False),
+ _SFT_META_TRAIN_CFG: json.dumps(make_serializable(_train_config), ensure_ascii=False),
+ _SFT_META_EXTRA_TENSOR_GROUPS: json.dumps(sorted(tensor_groups)),
+ }
+ for key, val in extra_items.items():
+ if isinstance(val, dict) and all(isinstance(v, torch.Tensor) for v in val.values()):
+ continue
+ meta[f"{_SFT_META_EXTRA_JSON_PREFIX}{key}"] = json.dumps(make_serializable(val), ensure_ascii=False)
+
+ save_file(tensors, path.with_suffix(".safetensors"), metadata=meta)
+ return
+
+ if use_safetensors: # not single file
+ # save the model with safetensors into a directory named after `path` (suffix stripped);
+ # `model_config` and `train_config` are saved as json files in that directory
+ path = path.with_suffix("")
+ path.mkdir(exist_ok=True)
+ _model_config = make_serializable(_model_config)
+ _train_config = make_serializable(_train_config)
+ (path / "model_config.json").write_text(json.dumps(_model_config, ensure_ascii=False))
+ (path / "train_config.json").write_text(json.dumps(_train_config, ensure_ascii=False))
+ save_file(self.state_dict(), path / "model.safetensors") # type: ignore
+ # save extra items
+ for key, val in extra_items.items():
+ # if val is a dict of torch tensors, save them as safetensors
+ if isinstance(val, dict) and all(isinstance(v, torch.Tensor) for v in val.values()):
+ save_file(val, path / f"{key}.safetensors")
+ else:
+ (path / f"{key}.json").write_text(json.dumps(make_serializable(val), ensure_ascii=False))
+ return
+
torch.save(
{
- "model_state_dict": self.state_dict(),
+ "model_state_dict": self.state_dict(), # type: ignore
"model_config": _model_config,
"train_config": _train_config,
+ **extra_items,
},
path,
)
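
# A hedged sketch of the single-file safetensors layout written above: the weights
# are stored as tensors, while the configs ride along as JSON strings in the file
# metadata (metadata values must be strings). `model` and the configs below are
# hypothetical placeholders.
#
#     import json
#     from safetensors import safe_open
#     from safetensors.torch import save_file
#
#     meta = {
#         "model_config": json.dumps({"n_leads": 12}),
#         "train_config": json.dumps({"classes": ["N", "AF"]}),
#     }
#     save_file(dict(model.state_dict()), "model.safetensors", metadata=meta)
#
#     with safe_open("model.safetensors", framework="pt", device="cpu") as f:
#         model_config = json.loads(f.metadata()["model_config"])
#         state_dict = {k: f.get_tensor(k) for k in f.keys()}
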
diff --git a/torch_ecg/utils/utils_signal.py b/torch_ecg/utils/utils_signal.py
index e3c41d07..305e5588 100644
--- a/torch_ecg/utils/utils_signal.py
+++ b/torch_ecg/utils/utils_signal.py
@@ -5,10 +5,10 @@
import warnings
from copy import deepcopy
-from numbers import Real
from typing import Iterable, Literal, Optional, Sequence, Tuple, Union
import numpy as np
+from numpy.typing import NDArray
from scipy import interpolate
from scipy.signal import butter, filtfilt, peak_prominences
@@ -26,12 +26,12 @@
def smooth(
- x: np.ndarray,
+ x: NDArray,
window_len: int = 11,
window: Literal["flat", "hanning", "hamming", "bartlett", "blackman"] = "hanning",
mode: str = "valid",
keep_dtype: bool = True,
-) -> np.ndarray:
+) -> NDArray:
"""Smooth the 1d data using a window with requested size.
This method is originally from [#smooth]_,
@@ -119,7 +119,7 @@ def smooth(
else:
w = eval("np." + window + "(radius)")
- y = np.convolve(w / w.sum(), s, mode=mode)
+ y = np.convolve(w / w.sum(), s, mode=mode) # type: ignore
y = y[(radius // 2 - 1) : -(radius // 2) - 1]
assert len(x) == len(y)
@@ -130,14 +130,14 @@ def smooth(
def resample_irregular_timeseries(
- sig: np.ndarray,
- output_fs: Optional[Real] = None,
+ sig: NDArray,
+ output_fs: Optional[Union[float, int]] = None,
method: Literal["spline", "interp1d"] = "interp1d",
return_with_time: bool = False,
- tnew: Optional[np.ndarray] = None,
+ tnew: Optional[NDArray] = None,
interp_kw: dict = {},
verbose: int = 0,
-) -> np.ndarray:
+) -> NDArray:
"""
Resample the 2d irregular timeseries `sig` into a 1d or 2d
regular time series with frequency `output_fs`,
@@ -149,7 +149,7 @@ def resample_irregular_timeseries(
sig : numpy.ndarray
The 2d irregular timeseries.
Each row is ``[time, value]``.
- output_fs : numbers.Real, optional
+ output_fs : float or int, optional
the frequency of the output 1d regular timeseries,
one and only one of `output_fs` and `tnew` should be specified
method : {"spline", "interp1d"}, default "interp1d"
@@ -203,7 +203,7 @@ def resample_irregular_timeseries(
dtype = sig.dtype
time_series = np.atleast_2d(sig).astype(dtype)
if tnew is None:
- step_ts = 1000 / output_fs
+ step_ts = 1000 / output_fs # type: ignore
tot_len = int((time_series[-1][0] - time_series[0][0]) / step_ts) + 1
xnew = time_series[0][0] + np.arange(0, tot_len * step_ts, step_ts)
else:
@@ -234,19 +234,19 @@ def resample_irregular_timeseries(
regular_timeseries = f(xnew)
if return_with_time:
- return np.column_stack((xnew, regular_timeseries)).astype(dtype)
+ return np.column_stack((xnew, regular_timeseries)).astype(dtype) # type: ignore
else:
- return regular_timeseries.astype(dtype)
+ return regular_timeseries.astype(dtype) # type: ignore
def detect_peaks(
x: Sequence,
- mph: Optional[Real] = None,
+ mph: Optional[Union[float, int]] = None,
mpd: int = 1,
- threshold: Real = 0,
- left_threshold: Real = 0,
- right_threshold: Real = 0,
- prominence: Optional[Real] = None,
+ threshold: Union[float, int] = 0,
+ left_threshold: Union[float, int] = 0,
+ right_threshold: Union[float, int] = 0,
+ prominence: Optional[Union[float, int]] = None,
prominence_wlen: Optional[int] = None,
edge: Union[str, None] = "rising",
kpsh: bool = False,
@@ -254,7 +254,7 @@ def detect_peaks(
show: bool = False,
ax=None,
verbose: int = 0,
-) -> np.ndarray:
+) -> NDArray:
"""Detect peaks in data based on their amplitude and other features.
Parameters
@@ -384,7 +384,7 @@ def detect_peaks(
# handle NaN's
if ind.size and indnan.size:
# NaN's and values close to NaN's cannot be peaks
- ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan - 1, indnan + 1))), invert=True)]
+ ind = ind[np.isin(ind, np.unique(np.hstack((indnan, indnan - 1, indnan + 1))), invert=True)]
if verbose >= 1:
print(f"after handling nan values, ind = {ind.tolist()}")
@@ -455,7 +455,7 @@ def detect_peaks(
return ind
-def remove_spikes_naive(sig: np.ndarray, threshold: Real = 20, inplace: bool = True) -> np.ndarray:
+def remove_spikes_naive(sig: NDArray, threshold: Union[float, int] = 20, inplace: bool = True) -> NDArray:
"""Remove signal spikes using a naive method.
This is a method proposed in entry 0416 of CPSC2019.
@@ -470,7 +470,7 @@ def remove_spikes_naive(sig: np.ndarray, threshold: Real = 20, inplace: bool = T
1D, 2D or 3D signal with potential spikes.
The last dimension is the time dimension.
The signal can be single-lead, multi-lead, or batched signals.
- threshold : numbers.Real, optional
+ threshold : float or int, optional
Values of `sig` that are larger than `threshold` will be removed.
inplace : bool, optional
Whether to modify `sig` in place or not.
@@ -512,16 +512,18 @@ def remove_spikes_naive(sig: np.ndarray, threshold: Real = 20, inplace: bool = T
return sig.astype(dtype)
-def butter_bandpass(lowcut: Real, highcut: Real, fs: Real, order: int, verbose: int = 0) -> Tuple[np.ndarray, np.ndarray]:
+def butter_bandpass(
+ lowcut: Union[float, int], highcut: Union[float, int], fs: Union[float, int], order: int, verbose: int = 0
+) -> Tuple[NDArray, NDArray]:
"""Butterworth Bandpass Filter Design.
Parameters
----------
- lowcut : numbers.Real
+ lowcut : float or int
Low cutoff frequency.
- highcut : numbers.Real
+ highcut : float or int
High cutoff frequency.
- fs : numbers.Real
+ fs : float or int
Sampling frequency of `data`.
order : int,
Order of the filter.
@@ -569,19 +571,19 @@ def butter_bandpass(lowcut: Real, highcut: Real, fs: Real, order: int, verbose:
if verbose >= 1:
print(f"by the setup of lowcut and highcut, the filter type falls to {btype}, with Wn = {Wn}")
- b, a = butter(order, Wn, btype=btype)
+ b, a = butter(order, Wn, btype=btype) # type: ignore
return b, a
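
# A minimal scipy sketch of the design above plus zero-phase application: the
# cutoffs are normalized by the Nyquist frequency and the filter is applied
# with filtfilt (`sig` is a hypothetical 1D/2D ECG array).
from scipy.signal import butter, filtfilt

fs, lowcut, highcut, order = 500, 0.5, 45.0, 5
nyq = 0.5 * fs
b, a = butter(order, [lowcut / nyq, highcut / nyq], btype="band")
# filtered = filtfilt(b, a, sig, axis=-1)
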
def butter_bandpass_filter(
- data: np.ndarray,
- lowcut: Real,
- highcut: Real,
- fs: Real,
+ data: NDArray,
+ lowcut: Union[float, int],
+ highcut: Union[float, int],
+ fs: Union[float, int],
order: int,
btype: Optional[Literal["lohi", "hilo"]] = None,
verbose: int = 0,
-) -> np.ndarray:
+) -> NDArray:
"""Butterworth bandpass filtering the signals.
Apply a Butterworth bandpass filter to the signal.
@@ -639,19 +641,19 @@ def butter_bandpass_filter(
def get_ampl(
- sig: np.ndarray,
- fs: Real,
+ sig: NDArray,
+ fs: Union[float, int],
fmt: str = "lead_first",
- window: Real = 0.2,
+ window: Union[float, int] = 0.2,
critical_points: Optional[Sequence] = None,
-) -> Union[float, np.ndarray]:
+) -> Union[float, NDArray]:
"""Get amplitude of a signal (near critical points if given).
Parameters
----------
sig : numpy.ndarray
(ECG) signal.
- fs : numbers.Real
+ fs : float or int
Sampling frequency of the signal
fmt : str, default "lead_first"
Format of the signal, can be
@@ -713,13 +715,13 @@ def get_ampl(
def normalize(
- sig: np.ndarray,
+ sig: NDArray,
method: Literal["naive", "min-max", "z-score"],
- mean: Union[Real, Iterable[Real]] = 0.0,
- std: Union[Real, Iterable[Real]] = 1.0,
+ mean: Union[Union[float, int], Iterable[Union[float, int]]] = 0.0,
+ std: Union[Union[float, int], Iterable[Union[float, int]]] = 1.0,
sig_fmt: str = "channel_first",
per_channel: bool = False,
-) -> np.ndarray:
+) -> NDArray:
"""Normalize a signal.
Perform z-score normalization on `sig`,
@@ -742,11 +744,11 @@ def normalize(
The signal to be normalized.
method : {"naive", "min-max", "z-score"}
Normalization method, case insensitive.
- mean : numbers.Real or array_like, default 0.0
+ mean : float or int or array_like, default 0.0
Mean value of the normalized signal,
or mean values for each lead of the normalized signal.
Useless if `method` is "min-max".
- std : numbers.Real or array_like, default 1.0
+ std : float or int or array_like, default 1.0
Standard deviation of the normalized signal,
or standard deviations for each lead of the normalized signal.
Useless if `method` is "min-max".
@@ -786,17 +788,17 @@ def normalize(
], f"unknown normalization method `{method}`"
if not per_channel:
if sig.ndim == 2:
- assert isinstance(mean, Real) and isinstance(
- std, Real
+ assert isinstance(mean, (float, int)) and isinstance(
+ std, (float, int)
), "`mean` and `std` should be real numbers in the non per-channel setting for 2d signal"
else: # sig.ndim == 3
- assert (isinstance(mean, Real) or np.shape(mean) == (sig.shape[0],)) and (
- isinstance(std, Real) or np.shape(std) == (sig.shape[0],)
+ assert (isinstance(mean, (float, int)) or np.shape(mean) == (sig.shape[0],)) and ( # type: ignore
+ isinstance(std, (float, int)) or np.shape(std) == (sig.shape[0],) # type: ignore
), (
f"`mean` and `std` should be real numbers or have shape ({sig.shape[0]},) "
"in the non per-channel setting for 3d signal"
)
- if isinstance(std, Real):
+ if isinstance(std, (float, int)):
assert std > 0, "standard deviation should be positive"
else:
assert (np.array(std) > 0).all(), "standard deviations should all be positive"
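# Illustrative sketch of the checks above (hypothetical shapes): without
# per-channel normalization, a batched 3d signal accepts scalar `mean`/`std`
# or one value per signal in the batch.
import numpy as np

batched = np.random.randn(4, 12, 5000)                                        # (batch, lead, sample)
nm1 = normalize(batched, method="z-score", mean=0.0, std=1.0)                 # scalars pass the assertion
nm2 = normalize(batched, method="z-score", mean=np.zeros(4), std=np.ones(4))  # shape (4,) passes as well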
@@ -810,10 +812,10 @@ def normalize(
if isinstance(mean, Iterable):
assert sig.ndim in [2, 3], "`mean` should be a real number for 1d signal"
if sig.ndim == 2:
- assert np.shape(mean) in [
+ assert np.shape(mean) in [ # type: ignore
(sig.shape[0],),
(sig.shape[-1],),
- ], f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}"
+ ], f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}" # type: ignore
if sig_fmt.lower() in [
"channel_first",
"lead_first",
@@ -826,40 +828,40 @@ def normalize(
"channel_first",
"lead_first",
]:
- if np.shape(mean) == (sig.shape[0],):
+ if np.shape(mean) == (sig.shape[0],): # type: ignore
_mean = np.array(mean, dtype=dtype)[..., np.newaxis, np.newaxis]
- elif np.shape(mean) == (sig.shape[1],):
+ elif np.shape(mean) == (sig.shape[1],): # type: ignore
_mean = np.repeat(
np.array(mean, dtype=dtype)[np.newaxis, ..., np.newaxis],
sig.shape[0],
axis=0,
)
- elif np.shape(mean) == sig.shape[:2]:
+ elif np.shape(mean) == sig.shape[:2]: # type: ignore
_mean = np.array(mean, dtype=dtype)[..., np.newaxis]
else:
- raise AssertionError(f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}")
+ raise AssertionError(f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}") # type: ignore
else: # "channel_last" or "lead_last"
- if np.shape(mean) == (sig.shape[0],):
+ if np.shape(mean) == (sig.shape[0],): # type: ignore
_mean = np.array(mean, dtype=dtype)[..., np.newaxis, np.newaxis]
- elif np.shape(mean) == (sig.shape[-1],):
+ elif np.shape(mean) == (sig.shape[-1],): # type: ignore
_mean = np.repeat(
np.array(mean, dtype=dtype)[np.newaxis, np.newaxis, ...],
sig.shape[0],
axis=0,
)
- elif np.shape(mean) == (sig.shape[0], sig.shape[-1]):
+ elif np.shape(mean) == (sig.shape[0], sig.shape[-1]): # type: ignore
_mean = np.expand_dims(np.array(mean, dtype=dtype), axis=1)
else:
- raise AssertionError(f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}")
+ raise AssertionError(f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}") # type: ignore
else:
_mean = mean
if isinstance(std, Iterable):
assert sig.ndim in [2, 3], "`std` should be a real number for 1d signal"
if sig.ndim == 2:
- assert np.shape(std) in [
+ assert np.shape(std) in [ # type: ignore
(sig.shape[0],),
(sig.shape[-1],),
- ], f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}"
+ ], f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}" # type: ignore
if sig_fmt.lower() in [
"channel_first",
"lead_first",
@@ -872,31 +874,31 @@ def normalize(
"channel_first",
"lead_first",
]:
- if np.shape(std) == (sig.shape[0],):
+ if np.shape(std) == (sig.shape[0],): # type: ignore
_std = np.array(std, dtype=dtype)[..., np.newaxis, np.newaxis]
- elif np.shape(std) == (sig.shape[1],):
+ elif np.shape(std) == (sig.shape[1],): # type: ignore
_std = np.repeat(
np.array(std, dtype=dtype)[np.newaxis, ..., np.newaxis],
sig.shape[0],
axis=0,
)
- elif np.shape(std) == sig.shape[:2]:
+ elif np.shape(std) == sig.shape[:2]: # type: ignore
_std = np.array(std, dtype=dtype)[..., np.newaxis]
else:
- raise AssertionError(f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}")
+ raise AssertionError(f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}") # type: ignore
else: # "channel_last" or "lead_last"
- if np.shape(std) == (sig.shape[0],):
+ if np.shape(std) == (sig.shape[0],): # type: ignore
_std = np.array(std, dtype=dtype)[..., np.newaxis, np.newaxis]
- elif np.shape(std) == (sig.shape[-1],):
+ elif np.shape(std) == (sig.shape[-1],): # type: ignore
_std = np.repeat(
np.array(std, dtype=dtype)[np.newaxis, np.newaxis, ...],
sig.shape[0],
axis=0,
)
- elif np.shape(std) == (sig.shape[0], sig.shape[-1]):
+ elif np.shape(std) == (sig.shape[0], sig.shape[-1]): # type: ignore
_std = np.expand_dims(np.array(std, dtype=dtype), axis=1)
else:
- raise AssertionError(f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}")
+ raise AssertionError(f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}") # type: ignore
else:
_std = std
@@ -927,7 +929,7 @@ def normalize(
options = dict(axis=0, keepdims=True)
if _method == "z-score":
- nm_sig = ((sig - np.mean(sig, dtype=dtype, **options)) / (np.std(sig, dtype=dtype, **options) + eps)) * _std + _mean
+ nm_sig = ((sig - np.mean(sig, dtype=dtype, **options)) / (np.std(sig, dtype=dtype, **options) + eps)) * _std + _mean # type: ignore
elif _method == "min-max":
- nm_sig = (sig - np.amin(sig, **options)) / (np.amax(sig, **options) - np.amin(sig, **options) + eps)
- return nm_sig.astype(dtype)
+ nm_sig = (sig - np.amin(sig, **options)) / (np.amax(sig, **options) - np.amin(sig, **options) + eps) # type: ignore
+ return nm_sig.astype(dtype) # type: ignore
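# Illustrative numeric sketch of the two formulas above (eps ignored):
# z-score: ((sig - sig.mean()) / sig.std()) * std + mean
# min-max: (sig - sig.min()) / (sig.max() - sig.min())
import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0])
z = (x - x.mean()) / x.std()                 # approx. [-1.34, -0.45, 0.45, 1.34]
mm = (x - x.min()) / (x.max() - x.min())     # [0.0, 0.333..., 0.666..., 1.0]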