diff --git a/.github/workflows/run-pytest.yml b/.github/workflows/run-pytest.yml index 5820ddbc..9f3fdd88 100644 --- a/.github/workflows/run-pytest.yml +++ b/.github/workflows/run-pytest.yml @@ -14,6 +14,7 @@ on: env: PYTHON_PRIMARY_VERSION: '3.10' + SHHS_DATA_AVAILABLE: 'false' jobs: build: @@ -66,36 +67,46 @@ jobs: - name: List installed Python packages run: | python -m pip list - - name: Install nsrr and download a samll part of SHHS to do test + - name: Install nsrr and download a small part of SHHS to do test # ref. https://github.com/DeepPSP/nsrr-automate - uses: gacts/run-and-post-run@v1 - with: - # if ~/tmp/nsrr-data/shhs is empty (no files downloaded), - # fail and terminate the workflow - run: | - gem install nsrr --no-document - nsrr download shhs/polysomnography/edfs/shhs1/ --file="^shhs1\-20010.*\.edf" --token=${{ secrets.NSRR_TOKEN }} - nsrr download shhs/polysomnography/annotations-events-nsrr/shhs1/ --file="^shhs1\-20010.*\-nsrr\.xml" --token=${{ secrets.NSRR_TOKEN }} - nsrr download shhs/polysomnography/annotations-events-profusion/shhs1/ --file="^shhs1\-20010.*\-profusion\.xml" --token=${{ secrets.NSRR_TOKEN }} - nsrr download shhs/polysomnography/annotations-rpoints/shhs1/ --file="^shhs1\-20010.*\-rpoint\.csv" --token=${{ secrets.NSRR_TOKEN }} - nsrr download shhs/datasets/ --shallow --token=${{ secrets.NSRR_TOKEN }} - nsrr download shhs/datasets/hrv-analysis/ --token=${{ secrets.NSRR_TOKEN }} - mkdir -p ~/tmp/nsrr-data/ - mv shhs/ ~/tmp/nsrr-data/ - du -sh ~/tmp/nsrr-data/* - if [ "$(find ~/tmp/nsrr-data/shhs -type f | wc -l)" -eq 0 ]; \ - then (echo "No files downloaded. Exiting..." && exit 1); \ - else echo "Found $(find ~/tmp/nsrr-data/shhs -type f | wc -l) files"; fi - post: | - cd ~/tmp/ && du -sh $(ls -A) - rm -rf ~/tmp/nsrr-data/ - cd ~/tmp/ && du -sh $(ls -A) + # uses: gacts/run-and-post-run@v1 + continue-on-error: true + run: | + set -u -o pipefail + gem install nsrr --no-document + nsrr download shhs/polysomnography/edfs/shhs1/ --file="^shhs1\-20010.*\.edf" --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true + nsrr download shhs/polysomnography/annotations-events-nsrr/shhs1/ --file="^shhs1\-20010.*\-nsrr\.xml" --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true + nsrr download shhs/polysomnography/annotations-events-profusion/shhs1/ --file="^shhs1\-20010.*\-profusion\.xml" --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true + nsrr download shhs/polysomnography/annotations-rpoints/shhs1/ --file="^shhs1\-20010.*\-rpoint\.csv" --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true + nsrr download shhs/datasets/ --shallow --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true + nsrr download shhs/datasets/hrv-analysis/ --token=${{ secrets.NSRR_TOKEN }} 2>&1 | sed -E '/^[[:space:]]*[Ss]kipped([[:space:]]|$)/d' || true + + mkdir -p ~/tmp/nsrr-data/ + mv shhs/ ~/tmp/nsrr-data/ + du -sh ~/tmp/nsrr-data/* || true + + EDF_COUNT=$(find ~/tmp/nsrr-data/shhs -type f -name "*.edf" 2>/dev/null | wc -l | tr -d ' ') + echo "Detected SHHS EDF file count: $EDF_COUNT" + + if [ "$EDF_COUNT" -eq 0 ]; then + echo "::error title=No SHHS EDF files downloaded::No .edf files were downloaded (token may be invalid or pattern mismatch). SHHS tests will be skipped." 
+ echo "No SHHS EDF files downloaded; setting SHHS_DATA_AVAILABLE=false" + echo "SHHS_DATA_AVAILABLE=false" >> $GITHUB_ENV + exit 1 + else + echo "Found $EDF_COUNT SHHS EDF files; setting SHHS_DATA_AVAILABLE=true" + echo "SHHS_DATA_AVAILABLE=true" >> $GITHUB_ENV + fi + # post: | + # cd ~/tmp/ && du -sh $(ls -A) + # rm -rf ~/tmp/nsrr-data/ + # cd ~/tmp/ && du -sh $(ls -A) - name: Run test with pytest and collect coverage run: | - pytest -vv -s \ - --cov=torch_ecg \ - --ignore=test/test_pipelines \ - test + echo "SHHS_DATA_AVAILABLE at test step: $SHHS_DATA_AVAILABLE" + pytest --cov --junitxml=junit.xml -o junit_family=legacy + env: + SHHS_DATA_AVAILABLE: ${{ env.SHHS_DATA_AVAILABLE }} - name: Upload coverage to Codecov if: matrix.python-version == ${{ env.PYTHON_PRIMARY_VERSION }} uses: codecov/codecov-action@v4 @@ -103,3 +114,14 @@ jobs: fail_ci_if_error: true # optional (default = false) verbose: true # optional (default = false) token: ${{ secrets.CODECOV_TOKEN }} # required + - name: Upload test results to Codecov + if: ${{ !cancelled() }} + uses: codecov/test-results-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + - name: Cleanup SHHS temp data + if: always() && true + run: | + cd ~/tmp/ && du -sh $(ls -A) + rm -rf ~/tmp/nsrr-data/ + cd ~/tmp/ && du -sh $(ls -A) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 0046e588..a4d9a63a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -19,8 +19,13 @@ Changed - Make the function `remove_spikes_naive` in `torch_ecg.utils.utils_signal` support 2D and 3D input signals. +- Use `save_file` and `load_file` from the `safetensors` package for saving + and loading files in place of `torch.save` and `torch.load` in the `CkptMixin` + class in `torch_ecg.utils.utils_nn`. - Add retry mechanism to the `http_get` function in `torch_ecg.utils.download` module. +- Add length verification in the `http_get` function in + `torch_ecg.utils.download` module. Deprecated ~~~~~~~~~~ @@ -42,6 +47,14 @@ Fixed - Fix potential errors when deepcopying a `torch_ecg.cfg.CFG` object: previously, deepcopying such an object like `CFG({"a": {1: 0.1, 2: 0.2}})` would result in an error. +- Fix potential bugs in contextmanager `torch_ecg.utils.timeout`: restore the previously + installed SIGALRM handler after use, cancel any pending alarm reliably in a finally block, + avoid installing a handler when duration <= 0 (preventing unintended global side-effects), + and thereby eliminate spurious `TimeoutError` exceptions that could be triggered later by + unrelated signal.alarm calls due to the old implementation not reinstating the original handler. +- Fix bugs in utility function `torch_ecg.utils.make_serializable`: the previous implementation + does not drop some types of unserializable items correctly. Two additional parameters + `drop_unserializable` and `drop_paths` are added. 
Security ~~~~~~~~ diff --git a/benchmarks/train_crnn_cinc2020/model.py b/benchmarks/train_crnn_cinc2020/model.py index 7df6941e..a94e8d03 100644 --- a/benchmarks/train_crnn_cinc2020/model.py +++ b/benchmarks/train_crnn_cinc2020/model.py @@ -96,8 +96,8 @@ def inference( _input = _input.unsqueeze(0) # add a batch dimension prob = self.sigmoid(self.forward(_input)) pred = (prob >= bin_pred_thr).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() for row_idx, row in enumerate(pred): row_max_prob = prob[row_idx, ...].max() if row_max_prob < ModelCfg.bin_pred_nsr_thr and nsr_cid is not None: diff --git a/benchmarks/train_crnn_cinc2021/model.py b/benchmarks/train_crnn_cinc2021/model.py index c828805c..f616548c 100644 --- a/benchmarks/train_crnn_cinc2021/model.py +++ b/benchmarks/train_crnn_cinc2021/model.py @@ -99,8 +99,8 @@ def inference( # batch_size, channels, seq_len = _input.shape prob = self.sigmoid(self.forward(_input)) pred = (prob >= bin_pred_thr).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() for row_idx, row in enumerate(pred): row_max_prob = prob[row_idx, ...].max() if row_max_prob < ModelCfg.bin_pred_nsr_thr and nsr_cid is not None: diff --git a/benchmarks/train_hybrid_cpsc2020/model.py b/benchmarks/train_hybrid_cpsc2020/model.py index f474d2db..5e22b94a 100644 --- a/benchmarks/train_hybrid_cpsc2020/model.py +++ b/benchmarks/train_hybrid_cpsc2020/model.py @@ -92,8 +92,8 @@ def inference( _input = _input.unsqueeze(0) # add a batch dimension prob = self.sigmoid(self.forward(_input)) pred = (prob >= bin_pred_thr).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() for row_idx, row in enumerate(pred): row_max_prob = prob[row_idx, ...].max() if row.sum() == 0: @@ -190,14 +190,14 @@ def inference( if self.n_classes == 2: prob = self.sigmoid(prob) # (batch_size, seq_len, 2) pred = (prob >= bin_pred_thr).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() # aux used to filter out potential simultaneous predictions of SPB and PVC aux = (prob == np.max(prob, axis=2, keepdims=True)).astype(int) pred = aux * pred elif self.n_classes == 3: prob = self.softmax(prob) # (batch_size, seq_len, 3) - prob = prob.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() pred = np.argmax(prob, axis=2) if rpeak_inds is not None: diff --git a/benchmarks/train_hybrid_cpsc2021/entry_2021.py b/benchmarks/train_hybrid_cpsc2021/entry_2021.py index 2854852c..2b70feab 100644 --- a/benchmarks/train_hybrid_cpsc2021/entry_2021.py +++ b/benchmarks/train_hybrid_cpsc2021/entry_2021.py @@ -379,12 +379,12 @@ def _detect_rpeaks(model, sig, siglen, overlap_len, config): for idx in range(batch_size // _BATCH_SIZE): pred = model.forward(sig[_BATCH_SIZE * idx : _BATCH_SIZE * (idx + 1), ...]) pred = model.sigmoid(pred) - pred = pred.cpu().detach().numpy().squeeze(-1) + pred = pred.detach().cpu().numpy().squeeze(-1) l_pred.append(pred) if batch_size % _BATCH_SIZE != 0: pred = model.forward(sig[batch_size // _BATCH_SIZE * _BATCH_SIZE :, ...]) pred = model.sigmoid(pred) - pred = pred.cpu().detach().numpy().squeeze(-1) + pred = pred.detach().cpu().numpy().squeeze(-1) l_pred.append(pred) pred = np.concatenate(l_pred) @@ 
-473,12 +473,12 @@ def _main_task(model, sig, siglen, overlap_len, rpeaks, config): for idx in range(batch_size // _BATCH_SIZE): pred = model.forward(sig[_BATCH_SIZE * idx : _BATCH_SIZE * (idx + 1), ...]) pred = model.sigmoid(pred) - pred = pred.cpu().detach().numpy().squeeze(-1) + pred = pred.detach().cpu().numpy().squeeze(-1) l_pred.append(pred) if batch_size % _BATCH_SIZE != 0: pred = model.forward(sig[batch_size // _BATCH_SIZE * _BATCH_SIZE :, ...]) pred = model.sigmoid(pred) - pred = pred.cpu().detach().numpy().squeeze(-1) + pred = pred.detach().cpu().numpy().squeeze(-1) l_pred.append(pred) pred = np.concatenate(l_pred) diff --git a/benchmarks/train_hybrid_cpsc2021/model.py b/benchmarks/train_hybrid_cpsc2021/model.py index 2f50730d..9d02c102 100644 --- a/benchmarks/train_hybrid_cpsc2021/model.py +++ b/benchmarks/train_hybrid_cpsc2021/model.py @@ -169,7 +169,7 @@ def _inference_qrs_detection( _input = _input.unsqueeze(0) # add a batch dimension # batch_size, channels, seq_len = _input.shape prob = self.sigmoid(self.forward(_input)) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) # prob --> qrs mask --> qrs intervals --> rpeaks rpeaks = _qrs_detection_post_process( @@ -226,7 +226,7 @@ def _inference_main_task( _input = _input.unsqueeze(0) # add a batch dimension batch_size, n_leads, seq_len = _input.shape prob = self.sigmoid(self.forward(_input)) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) af_episodes, af_mask = _main_task_post_process( prob=prob, @@ -368,7 +368,7 @@ def _inference_qrs_detection( _input = _input.unsqueeze(0) # add a batch dimension # batch_size, channels, seq_len = _input.shape prob = self.sigmoid(self.forward(_input)) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) # prob --> qrs mask --> qrs intervals --> rpeaks rpeaks = _qrs_detection_post_process( @@ -425,7 +425,7 @@ def _inference_main_task( _input = _input.unsqueeze(0) # add a batch dimension batch_size, n_leads, seq_len = _input.shape prob = self.sigmoid(self.forward(_input)) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) af_episodes, af_mask = _main_task_post_process( prob=prob, @@ -567,7 +567,7 @@ def _inference_qrs_detection( _input = _input.unsqueeze(0) # add a batch dimension # batch_size, channels, seq_len = _input.shape prob = self.sigmoid(self.forward(_input)) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) # prob --> qrs mask --> qrs intervals --> rpeaks rpeaks = _qrs_detection_post_process( @@ -624,7 +624,7 @@ def _inference_main_task( _input = _input.unsqueeze(0) # add a batch dimension batch_size, n_leads, seq_len = _input.shape prob = self.sigmoid(self.forward(_input)) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) af_episodes, af_mask = _main_task_post_process( prob=prob, @@ -721,7 +721,7 @@ def inference( prob = self.forward(_input) if self.config.clf.name != "crf": prob = self.sigmoid(prob) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) af_episodes, af_mask = _main_task_post_process( prob=prob, diff --git a/benchmarks/train_mtl_cinc2022/models/crnn.py b/benchmarks/train_mtl_cinc2022/models/crnn.py index 557e9380..b08d2e60 100644 --- a/benchmarks/train_mtl_cinc2022/models/crnn.py +++ b/benchmarks/train_mtl_cinc2022/models/crnn.py @@ -200,31 +200,31 
@@ def inference( prob = self.softmax(forward_output["murmur"]) pred = torch.argmax(prob, dim=-1) bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int) - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() - bin_pred = bin_pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() + bin_pred = bin_pred.detach().cpu().numpy() murmur_output = ClassificationOutput( classes=self.classes, prob=prob, pred=pred, bin_pred=bin_pred, - forward_output=forward_output["murmur"].cpu().detach().numpy(), + forward_output=forward_output["murmur"].detach().cpu().numpy(), ) if forward_output.get("outcome", None) is not None: prob = self.softmax(forward_output["outcome"]) pred = torch.argmax(prob, dim=-1) bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int) - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() - bin_pred = bin_pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() + bin_pred = bin_pred.detach().cpu().numpy() outcome_output = ClassificationOutput( classes=self.outcomes, prob=prob, pred=pred, bin_pred=bin_pred, - forward_output=forward_output["outcome"].cpu().detach().numpy(), + forward_output=forward_output["outcome"].detach().cpu().numpy(), ) else: outcome_output = None @@ -238,13 +238,13 @@ def inference( else: prob = self.sigmoid(forward_output["segmentation"]) pred = (prob > seg_thr).int() * (prob == prob.max(dim=-1, keepdim=True).values).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() segmentation_output = SequenceLabellingOutput( classes=self.states, prob=prob, pred=pred, - forward_output=forward_output["segmentation"].cpu().detach().numpy(), + forward_output=forward_output["segmentation"].detach().cpu().numpy(), ) else: segmentation_output = None diff --git a/benchmarks/train_mtl_cinc2022/models/model_ml.py b/benchmarks/train_mtl_cinc2022/models/model_ml.py index b55305cb..cad80801 100644 --- a/benchmarks/train_mtl_cinc2022/models/model_ml.py +++ b/benchmarks/train_mtl_cinc2022/models/model_ml.py @@ -170,9 +170,10 @@ def get_model(self, model_name: str, params: Optional[dict] = None) -> BaseEstim """ model_cls = self.model_map[model_name] + params = params or {} if model_cls in [GradientBoostingClassifier, SVC]: params.pop("n_jobs", None) - return model_cls(**(params or {})) + return model_cls(**params) def save_model( self, @@ -198,9 +199,13 @@ def save_model( path to save the model. 
""" + if isinstance(model_path, bytes): + model_path = model_path.decode() + model_path = Path(model_path).expanduser().resolve() + model_path.parent.mkdir(parents=True, exist_ok=True) _config = deepcopy(config) _config.pop("db_dir", None) - Path(model_path).write_bytes( + model_path.write_bytes( pickle.dumps( { "config": _config, diff --git a/benchmarks/train_mtl_cinc2022/models/seg.py b/benchmarks/train_mtl_cinc2022/models/seg.py index d67eb684..d93a181b 100644 --- a/benchmarks/train_mtl_cinc2022/models/seg.py +++ b/benchmarks/train_mtl_cinc2022/models/seg.py @@ -138,8 +138,8 @@ def inference( else: prob = self.sigmoid(self.forward(_input)["segmentation"]) pred = (prob > bin_pred_threshold).int() * (prob == prob.max(dim=-1, keepdim=True).values).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() segmentation_output = SequenceLabellingOutput( classes=self.classes, @@ -264,8 +264,8 @@ def inference( else: prob = self.sigmoid(self.forward(_input)["segmentation"]) pred = (prob > bin_pred_threshold).int() * (prob == prob.max(dim=-1, keepdim=True).values).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() segmentation_output = SequenceLabellingOutput( classes=self.classes, diff --git a/benchmarks/train_mtl_cinc2022/models/wav2vec2.py b/benchmarks/train_mtl_cinc2022/models/wav2vec2.py index 28e5c346..3886064b 100644 --- a/benchmarks/train_mtl_cinc2022/models/wav2vec2.py +++ b/benchmarks/train_mtl_cinc2022/models/wav2vec2.py @@ -279,31 +279,31 @@ def inference( prob = self.softmax(forward_output["murmur"]) pred = torch.argmax(prob, dim=-1) bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int) - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() - bin_pred = bin_pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() + bin_pred = bin_pred.detach().cpu().numpy() murmur_output = ClassificationOutput( classes=self.classes, prob=prob, pred=pred, bin_pred=bin_pred, - forward_output=forward_output["murmur"].cpu().detach().numpy(), + forward_output=forward_output["murmur"].detach().cpu().numpy(), ) if forward_output.get("outcome", None) is not None: prob = self.softmax(forward_output["outcome"]) pred = torch.argmax(prob, dim=-1) bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int) - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() - bin_pred = bin_pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() + bin_pred = bin_pred.detach().cpu().numpy() outcome_output = ClassificationOutput( classes=self.outcomes, prob=prob, pred=pred, bin_pred=bin_pred, - forward_output=forward_output["outcome"].cpu().detach().numpy(), + forward_output=forward_output["outcome"].detach().cpu().numpy(), ) else: outcome_output = None @@ -317,13 +317,13 @@ def inference( else: prob = self.sigmoid(forward_output["segmentation"]) pred = (prob > seg_thr).int() * (prob == prob.max(dim=-1, keepdim=True).values).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() segmentation_output = SequenceLabellingOutput( classes=self.states, prob=prob, pred=pred, - forward_output=forward_output["segmentation"].cpu().detach().numpy(), + 
forward_output=forward_output["segmentation"].detach().cpu().numpy(), ) else: segmentation_output = None @@ -599,31 +599,31 @@ def inference( prob = self.softmax(forward_output["murmur"]) pred = torch.argmax(prob, dim=-1) bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int) - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() - bin_pred = bin_pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() + bin_pred = bin_pred.detach().cpu().numpy() murmur_output = ClassificationOutput( classes=self.classes, prob=prob, pred=pred, bin_pred=bin_pred, - forward_output=forward_output["murmur"].cpu().detach().numpy(), + forward_output=forward_output["murmur"].detach().cpu().numpy(), ) if forward_output.get("outcome", None) is not None: prob = self.softmax(forward_output["outcome"]) pred = torch.argmax(prob, dim=-1) bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int) - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() - bin_pred = bin_pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() + bin_pred = bin_pred.detach().cpu().numpy() outcome_output = ClassificationOutput( classes=self.outcomes, prob=prob, pred=pred, bin_pred=bin_pred, - forward_output=forward_output["outcome"].cpu().detach().numpy(), + forward_output=forward_output["outcome"].detach().cpu().numpy(), ) else: outcome_output = None @@ -637,13 +637,13 @@ def inference( else: prob = self.sigmoid(forward_output["segmentation"]) pred = (prob > seg_thr).int() * (prob == prob.max(dim=-1, keepdim=True).values).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() segmentation_output = SequenceLabellingOutput( classes=self.states, prob=prob, pred=pred, - forward_output=forward_output["segmentation"].cpu().detach().numpy(), + forward_output=forward_output["segmentation"].detach().cpu().numpy(), ) else: segmentation_output = None diff --git a/benchmarks/train_mtl_cinc2022/team_code.py b/benchmarks/train_mtl_cinc2022/team_code.py index dd40d991..bb47c482 100644 --- a/benchmarks/train_mtl_cinc2022/team_code.py +++ b/benchmarks/train_mtl_cinc2022/team_code.py @@ -424,8 +424,8 @@ def run_challenge_model( # forward_output = main_model.clf(features) # shape (1, n_classes) # probabilities = main_model.softmax(forward_output) # labels = (probabilities == probabilities.max(dim=-1, keepdim=True).values).to(int) - # probabilities = probabilities.squeeze(dim=0).cpu().detach().numpy() - # labels = labels.squeeze(dim=0).cpu().detach().numpy() + # probabilities = probabilities.squeeze(dim=0).detach().cpu().numpy() + # labels = labels.squeeze(dim=0).detach().cpu().numpy() # get final prediction for murmurs: # strategy: diff --git a/benchmarks/train_multi_cpsc2019/model.py b/benchmarks/train_multi_cpsc2019/model.py index 93586dd7..10396142 100644 --- a/benchmarks/train_multi_cpsc2019/model.py +++ b/benchmarks/train_multi_cpsc2019/model.py @@ -112,7 +112,7 @@ def inference( mode="linear", align_corners=True, ).permute(0, 2, 1) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) # prob --> qrs mask --> qrs intervals --> rpeaks rpeaks = _inference_post_process( @@ -132,7 +132,7 @@ def inference( sampling_rate=self.config.fs, tol=0.05, )[0] - for b_input, b_rpeaks in zip(_input.detach().numpy().squeeze(1), rpeaks) + for b_input, b_rpeaks in 
zip(_input.detach().cpu().numpy().squeeze(1), rpeaks) ] return RPeaksDetectionOutput( @@ -224,7 +224,7 @@ def inference( _input = _input.unsqueeze(0) # add a batch dimension batch_size, channels, seq_len = _input.shape prob = self.sigmoid(self.forward(_input)) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) # prob --> qrs mask --> qrs intervals --> rpeaks rpeaks = _inference_post_process( @@ -244,7 +244,7 @@ def inference( sampling_rate=self.config.fs, tol=0.05, )[0] - for b_input, b_rpeaks in zip(_input.detach().numpy().squeeze(1), rpeaks) + for b_input, b_rpeaks in zip(_input.detach().cpu().numpy().squeeze(1), rpeaks) ] return RPeaksDetectionOutput( @@ -336,7 +336,7 @@ def inference( _input = _input.unsqueeze(0) # add a batch dimension batch_size, channels, seq_len = _input.shape prob = self.sigmoid(self.forward(_input)) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) # prob --> qrs mask --> qrs intervals --> rpeaks rpeaks = _inference_post_process( @@ -356,7 +356,7 @@ def inference( sampling_rate=self.config.fs, tol=0.05, )[0] - for b_input, b_rpeaks in zip(_input.detach().numpy().squeeze(1), rpeaks) + for b_input, b_rpeaks in zip(_input.detach().cpu().numpy().squeeze(1), rpeaks) ] return RPeaksDetectionOutput( diff --git a/benchmarks/train_unet_ludb/model.py b/benchmarks/train_unet_ludb/model.py index 0efb37ac..5e46cfcc 100644 --- a/benchmarks/train_unet_ludb/model.py +++ b/benchmarks/train_unet_ludb/model.py @@ -89,7 +89,7 @@ def inference( prob = self.softmax(prob) else: prob = torch.sigmoid(prob) - prob = prob.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() if "i" in self.classes: mask = np.argmax(prob, axis=-1) diff --git a/docs/source/_static/css/codeblock.css b/docs/source/_static/css/codeblock.css new file mode 100644 index 00000000..dd8fa012 --- /dev/null +++ b/docs/source/_static/css/codeblock.css @@ -0,0 +1,96 @@ +div.highlight { + position: relative; + display: flex; + border: 1px solid #eee; + border-radius: 8px; + overflow: hidden; + background: #fff; +} + +div.highlight .custom-linenos { + flex-shrink: 0; + background: #fafafa; + border-right: 1px solid #eee; + color: #999; + text-align: right; + padding: 12px 8px; + font-family: monospace; + user-select: none; + line-height: 1.5; + white-space: pre; + min-width: 36px; +} + +html[data-theme="dark"] div.highlight .custom-linenos { + background: #2a2a2a; + border-right-color: #444; + color: #aaa; +} + +div.highlight pre { + flex: 1; + margin: 0; + padding: 12px 16px; + overflow-x: auto; + white-space: pre; + background: transparent; + line-height: 1.5; +} + +div.highlight pre.code-wrapped { + white-space: pre-wrap; + word-break: break-word; + overflow-x: hidden; +} + +.code-toolbar { + position: absolute; + top: 8px; + right: 8px; + display: flex; + gap: 6px; + z-index: 50; +} + +.code-icon-btn { + display: inline-flex; + align-items: center; + justify-content: center; + background-color: rgba(255, 255, 255, 0.9); + border: 1px solid #ddd; + border-radius: 4px; + padding: 4px; + cursor: pointer; + color: #666; + transition: all 0.1s ease; + width: 26px; + height: 26px; +} +.code-icon-btn:hover { + background-color: #f5f5f5; + border-color: #ccc; + color: #333; +} +.code-icon-btn.active { + background-color: #e8f4e8; + border-color: #8cc08c; + color: #2d7d2d; +} + +.copy-tooltip { + position: fixed; + transform: translateX(-50%); + background: rgba(0, 0, 0, 0.9); + color: #fff; + font-size: 12px; + padding: 4px 
8px; + border-radius: 4px; + opacity: 0; + transition: opacity 0.15s ease, transform 0.15s ease; + pointer-events: none; + z-index: 9999; +} +.copy-tooltip.visible { + opacity: 1; + transform: translateX(-50%) translateY(-3px); +} diff --git a/docs/source/_static/js/codeblock.js b/docs/source/_static/js/codeblock.js new file mode 100644 index 00000000..1f6649ff --- /dev/null +++ b/docs/source/_static/js/codeblock.js @@ -0,0 +1,125 @@ +document.addEventListener("DOMContentLoaded", () => { + document.querySelectorAll("div.highlight").forEach((highlightBlock) => { + const codePre = highlightBlock.querySelector("pre"); + if (!codePre) return; + + const hasSphinxLinenos = codePre.querySelector(".linenos") !== null; + + if (hasSphinxLinenos) { + codePre.querySelectorAll("span.linenos").forEach((el) => el.remove()); + codePre.dataset.linenos = "true"; + } + + const wrapBtn = document.createElement("button"); + wrapBtn.className = "code-icon-btn code-wrap-btn"; + wrapBtn.innerHTML = ` + + `; + wrapBtn.title = "Toggle line wrap"; + + wrapBtn.addEventListener("click", () => { + const isWrapped = codePre.classList.toggle("code-wrapped"); + wrapBtn.classList.toggle("active", isWrapped); + + const lineDiv = highlightBlock.querySelector(".custom-linenos"); + if (lineDiv) { + lineDiv.style.display = isWrapped ? "none" : "block"; + } + }); + + const copyBtn = document.createElement("button"); + copyBtn.className = "code-icon-btn code-copy-btn"; + copyBtn.innerHTML = ` + + `; + copyBtn.title = "Copy to clipboard"; + + copyBtn.addEventListener("click", async () => { + let codeText = codePre.textContent.trim(); + try { + if (navigator.clipboard?.writeText) { + await navigator.clipboard.writeText(codeText); + } else { + fallbackCopyText(codeText); + } + showCopyTooltip(copyBtn, "Copied!"); + } catch (err) { + console.error("Copy failed:", err); + showCopyTooltip(copyBtn, "Failed"); + } + }); + + const toolbar = document.createElement("div"); + toolbar.className = "code-toolbar"; + toolbar.appendChild(wrapBtn); + toolbar.appendChild(copyBtn); + highlightBlock.style.position = "relative"; + highlightBlock.appendChild(toolbar); + + if (codePre.dataset.linenos === "true") { + const contentForCount = codePre.textContent.trimEnd(); + const totalLines = contentForCount.split("\n").length; + + const lineNumbers = Array.from({ length: totalLines }, (_, i) => i + 1).join( + "\n" + ); + const lineDiv = document.createElement("div"); + lineDiv.className = "custom-linenos"; + lineDiv.textContent = lineNumbers; + highlightBlock.insertBefore(lineDiv, codePre); + } + }); +}); + +function fallbackCopyText(text) { + const ta = document.createElement("textarea"); + ta.value = text; + ta.style.position = "fixed"; + ta.style.top = "-10000px"; + document.body.appendChild(ta); + ta.focus(); + ta.select(); + document.execCommand("copy"); + document.body.removeChild(ta); +} + +function showCopyTooltip(btn, text) { + const tooltip = document.createElement("div"); + tooltip.className = "copy-tooltip"; + tooltip.textContent = text; + document.body.appendChild(tooltip); + const rect = btn.getBoundingClientRect(); + let top = rect.top - 28; + if (top < 6) top = rect.bottom + 8; + tooltip.style.left = `${rect.left + rect.width / 2}px`; + tooltip.style.top = `${top}px`; + requestAnimationFrame(() => tooltip.classList.add("visible")); + setTimeout(() => { + tooltip.classList.remove("visible"); + setTimeout(() => tooltip.remove(), 200); + }, 1200); +} diff --git a/docs/source/conf.py b/docs/source/conf.py index 89afcbb3..41bb8064 100644 --- 
a/docs/source/conf.py +++ b/docs/source/conf.py @@ -203,6 +203,8 @@ def setup(app): # ) # app.add_transform(AutoStructify) app.add_css_file("css/custom.css") + app.add_css_file("css/codeblock.css") + app.add_js_file("js/codeblock.js") latex_documents = [ diff --git a/pyproject.toml b/pyproject.toml index 172f9add..d4f4ad32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "pyEDFlib", "PyWavelets", "requests", + "safetensors", "scikit-learn", "scipy", "soundfile", @@ -54,7 +55,7 @@ dependencies = [ [project.optional-dependencies] dev = [ - "black==24.10.0", + "black", "flake8", "gdown", "librosa", @@ -104,7 +105,7 @@ docs = [ "sphinxcontrib-tikz", ] test = [ - "black==24.10.0", + "black", "flake8", "gdown", "librosa", @@ -130,3 +131,39 @@ path = "torch_ecg/version.py" include = [ "/torch_ecg", ] + +# configuration for pytest and coverage +[tool.pytest.ini_options] +minversion = "7.0" +addopts = "-vv -s --cov=torch_ecg --ignore=test/test_pipelines" +testpaths = ["test"] + +[tool.coverage.run] +source = ["torch_ecg"] +omit = [ + "torch_ecg/databases/nsrr_databases/shhs.py", +] + +[tool.coverage.report] +omit = [ + # "torch_ecg/databases/nsrr_databases/shhs.py", + # NSRR databases are ignored temporarily, since data access is denied by the US government + "torch_ecg/databases/nsrr_databases/*", + # temporarily ignore torch_ecg/components/nas.py since it's not implemented completely + "torch_ecg/components/nas.py", + # temporarily ignore torch_ecg/models/grad_cam.py since it's not implemented completely + "torch_ecg/models/grad_cam.py", + # temporarily ignore torch_ecg/ssl since it's not implemented + "torch_ecg/ssl/*", +] +exclude_also = [ + "raise NotImplementedError", + # Don't complain if non-runnable code isn't run: + "if __name__ == .__main__.:", + # Don't complain about abstract methods, they aren't run: + "@(abc\\.)?abstractmethod", + # base class of the NSRR databases are also ignored temporarily + "^class NSRRDataBase\\(.*\\):", +] +# show_missing = true +# skip_covered = true diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 00000000..5a8db90e --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,5 @@ +{ + "diagnosticSeverityOverrides": { + "reportAttributeAccessIssue": "none" + } +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 3e04b0ae..cad67fcd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ tensorboardX tqdm deprecated deprecate-kwargs +safetensors pyEDFlib PyWavelets torch-optimizer diff --git a/test/requirements.txt b/test/requirements.txt index 566f53ba..fbf9f542 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -1,4 +1,4 @@ -black==24.10.0 +black flake8 pytest pytest-xdist diff --git a/test/test_components/test_trainer.py b/test/test_components/test_trainer.py index 9d7b2188..d690cf11 100644 --- a/test/test_components/test_trainer.py +++ b/test/test_components/test_trainer.py @@ -95,7 +95,7 @@ def __init__(self, n_leads: int, config: Optional[CFG] = None, **kwargs: Any) -> super().__init__(ModelCfg.mask_classes, n_leads, model_config) @torch.no_grad() - def inference( + def inference( # type: ignore self, input: Union[Sequence[float], np.ndarray, Tensor], bin_pred_thr: float = 0.5, @@ -130,7 +130,7 @@ def inference( prob = self.softmax(prob) else: prob = torch.sigmoid(prob) - prob = prob.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() if "i" in self.classes: mask = np.argmax(prob, axis=-1) @@ -165,37 +165,37 @@ def __init__( """ 
Parameters ---------- - model: Module, + model : Module the model to be trained - model_config: dict, + model_config : dict the configuration of the model, used to keep a record in the checkpoints - train_config: dict, + train_config : dict the configuration of the training, including configurations for the data loader, for the optimization, etc. will also be recorded in the checkpoints. `train_config` should at least contain the following keys: - "monitor": str, - "loss": str, - "n_epochs": int, - "batch_size": int, - "learning_rate": float, - "lr_scheduler": str, - "lr_step_size": int, optional, depending on the scheduler - "lr_gamma": float, optional, depending on the scheduler - "max_lr": float, optional, depending on the scheduler - "optimizer": str, - "decay": float, optional, depending on the optimizer - "momentum": float, optional, depending on the optimizer - device: torch.device, optional, + - "monitor": str, + - "loss": str, + - "n_epochs": int, + - "batch_size": int, + - "learning_rate": float, + - "lr_scheduler": str, + - "lr_step_size": int, optional, depending on the scheduler + - "lr_gamma": float, optional, depending on the scheduler + - "max_lr": float, optional, depending on the scheduler + - "optimizer": str, + - "decay": float, optional, depending on the optimizer + - "momentum": float, optional, depending on the optimizer + device : torch.device, optional the device to be used for training - lazy: bool, default True, + lazy : bool, default True whether to initialize the data loader lazily """ super().__init__( model=model, - dataset_cls=LUDBDataset, + dataset_cls=LUDBDataset, # type: ignore model_config=model_config, train_config=train_config, device=device, @@ -212,21 +212,21 @@ def _setup_dataloaders( Parameters ---------- - train_dataset: Dataset, optional, + train_dataset : Dataset, optional the training dataset - val_dataset: Dataset, optional, + val_dataset : Dataset, optional the validation dataset """ if train_dataset is None: - train_dataset = self.dataset_cls(config=self.train_config, training=True, lazy=False) + train_dataset = self.dataset_cls(config=self.train_config, training=True, lazy=False) # type: ignore if self.train_config.debug: val_train_dataset = train_dataset else: val_train_dataset = None if val_dataset is None: - val_dataset = self.dataset_cls(config=self.train_config, training=False, lazy=False) + val_dataset = self.dataset_cls(config=self.train_config, training=False, lazy=False) # type: ignore # https://discuss.pytorch.org/t/guidelines-for-assigning-num-workers-to-dataloader/813/4 if torch.cuda.is_available(): @@ -235,7 +235,7 @@ def _setup_dataloaders( num_workers = 0 self.train_loader = DataLoader( - dataset=train_dataset, + dataset=train_dataset, # type: ignore batch_size=self.batch_size, shuffle=True, num_workers=num_workers, @@ -246,7 +246,7 @@ def _setup_dataloaders( if self.train_config.debug: self.val_train_loader = DataLoader( - dataset=val_train_dataset, + dataset=val_train_dataset, # type: ignore batch_size=self.batch_size, shuffle=True, num_workers=num_workers, @@ -257,7 +257,7 @@ def _setup_dataloaders( else: self.val_train_loader = None self.val_loader = DataLoader( - dataset=val_dataset, + dataset=val_dataset, # type: ignore batch_size=self.batch_size, shuffle=True, num_workers=num_workers, @@ -266,7 +266,7 @@ def _setup_dataloaders( collate_fn=collate_fn, ) - def run_one_step(self, *data: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + def run_one_step(self, *data: Tuple[torch.Tensor, 
torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: # type: ignore """ Parameters ---------- @@ -331,9 +331,9 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: # each scoring is a dict consisting of the following metrics: # sensitivity, precision, f1_score, mean_error, standard_deviation eval_res_split = compute_ludb_metrics( - np.repeat(all_labels[:, np.newaxis, :], self.model_config.n_leads, axis=1), - np.repeat(all_mask_preds[:, np.newaxis, :], self.model_config.n_leads, axis=1), - self._cm, + np.repeat(all_labels[:, np.newaxis, :], self.model_config.n_leads, axis=1), # type: ignore + np.repeat(all_mask_preds[:, np.newaxis, :], self.model_config.n_leads, axis=1), # type: ignore + self._cm, # type: ignore self.train_config.fs, ) @@ -356,7 +356,7 @@ def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: self.model.train() - return eval_res + return eval_res # type: ignore @property def _cm(self) -> Dict[str, str]: @@ -401,9 +401,10 @@ def test_unet_trainer() -> None: # ds_train_fl = LUDB(train_cfg_fl, training=True, lazy=False) # ds_val_fl = LUDB(train_cfg_fl, training=False, lazy=False) - train_cfg_fl.keep_checkpoint_max = 0 - train_cfg_fl.monitor = None - train_cfg_fl.n_epochs = 2 + train_cfg_fl.keep_checkpoint_max = 2 + train_cfg_fl.monitor = "f1_score" + train_cfg_fl.n_epochs = 3 + train_cfg_fl.flooding_level = 0.1 model_config = deepcopy(ModelCfg) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -423,6 +424,8 @@ def test_unet_trainer() -> None: lazy=False, ) + print(repr(trainer)) + bmd = trainer.train() del model, trainer, bmd @@ -430,7 +433,7 @@ def test_unet_trainer() -> None: def test_base_trainer(): with pytest.raises(TypeError, match="Can't instantiate abstract class"): - BaseTrainer( + BaseTrainer( # type: ignore model=ECG_UNET_LUDB(12, ModelCfg), dataset_cls=LUDBDataset, train_config=LUDBTrainCfg, diff --git a/test/test_databases/test_afdb.py b/test/test_databases/test_afdb.py index 63267682..d0ccabae 100644 --- a/test/test_databases/test_afdb.py +++ b/test/test_databases/test_afdb.py @@ -29,7 +29,7 @@ with pytest.warns(RuntimeWarning): - reader = AFDB(_CWD, verbose=3) + reader = AFDB(_CWD, verbose=1) if len(reader) == 0: reader.download() diff --git a/test/test_databases/test_base.py b/test/test_databases/test_base.py index aa789905..50b4ce1d 100644 --- a/test/test_databases/test_base.py +++ b/test/test_databases/test_base.py @@ -2,6 +2,7 @@ from pathlib import Path +import numpy as np import pytest from torch_ecg.databases import AFDB, list_databases @@ -24,25 +25,25 @@ def test_base_database(): TypeError, match=f"Can't instantiate abstract class {_DataBase.__name__}", ): - db = _DataBase() + db = _DataBase() # type: ignore[abstract] with pytest.raises( TypeError, match=f"Can't instantiate abstract class {PhysioNetDataBase.__name__}", ): - db = PhysioNetDataBase() + db = PhysioNetDataBase() # type: ignore[abstract] with pytest.raises( TypeError, match=f"Can't instantiate abstract class {NSRRDataBase.__name__}", ): - db = NSRRDataBase() + db = NSRRDataBase() # type: ignore[abstract] with pytest.raises( TypeError, match=f"Can't instantiate abstract class {CPSCDataBase.__name__}", ): - db = CPSCDataBase() + db = CPSCDataBase() # type: ignore[abstract] def test_beat_ann(): @@ -88,6 +89,9 @@ def test_database_meta(): for k in WFDB_Rhythm_Annotations: assert reader.helper(k) is None # printed: `{k}` stands for `{WFDB_Rhythm_Annotations[k]}` + with pytest.raises(NotImplementedError, match="not implemented for"): + 
reader._auto_infer_units(np.ones((10, 2)), sig_type="EEG") + def test_database_info(): with pytest.warns(RuntimeWarning, match="`db_dir` is not specified"): diff --git a/test/test_databases/test_ludb.py b/test/test_databases/test_ludb.py index 8d894d01..910b1f9f 100644 --- a/test/test_databases/test_ludb.py +++ b/test/test_databases/test_ludb.py @@ -60,7 +60,7 @@ def test_load_data(self): data_1 = reader.load_data(0, leads=[1, 7]) assert data.shape[0] == 12 assert data_1.shape[0] == 2 - assert np.allclose(data[[1, 7], :], data_1) + assert np.allclose(data[[1, 7], :], data_1) # type: ignore def test_load_ann(self): ann = reader.load_ann(0) @@ -114,7 +114,11 @@ def test_meta_data(self): def test_plot(self): reader.plot(0, leads=["I", 5], ticks_granularity=2) data = reader.load_data(0, leads="III", data_format="flat") - reader.plot(0, data=data, leads="III") + reader.plot(0, data=data, leads="III") # type: ignore + + def test_get_absolute_path(self): + path = reader.get_absolute_path(0, extension="avf") + assert path.is_file() and path.suffix == ".avf" config = deepcopy(LUDBTrainCfg) diff --git a/test/test_databases/test_shhs.py b/test/test_databases/test_shhs.py index e4f8b8f1..9f66e376 100644 --- a/test/test_databases/test_shhs.py +++ b/test/test_databases/test_shhs.py @@ -4,6 +4,7 @@ subsampling: accomplished """ +import os import time from numbers import Real from pathlib import Path @@ -14,6 +15,11 @@ from torch_ecg.databases import SHHS, DataBaseInfo +pytestmark = pytest.mark.skipif( + os.getenv("SHHS_DATA_AVAILABLE") != "true", reason="SHHS dataset not available (token invalid or download skipped)" +) + + ############################################################################### # set paths # 9 files are downloaded in the following directory using `nsrr` @@ -80,28 +86,28 @@ def test_load_psg_data(self): assert isinstance(value, tuple) assert len(value) == 2 assert isinstance(value[0], np.ndarray) - assert isinstance(value[1], Real) and value[1] > 0 + assert isinstance(value[1], Real) and value[1] > 0 # type: ignore available_signals = reader.get_available_signals(0) - for signal in available_signals: + for signal in available_signals: # type: ignore psg_data = reader.load_psg_data(0, channel=signal, physical=True) assert isinstance(psg_data, tuple) assert len(psg_data) == 2 assert isinstance(psg_data[0], np.ndarray) - assert isinstance(psg_data[1], Real) and psg_data[1] > 0 + assert isinstance(psg_data[1], Real) and psg_data[1] > 0 # type: ignore def test_load_data(self): data, fs = reader.load_data(0) assert isinstance(data, np.ndarray) assert data.ndim == 2 - assert isinstance(fs, Real) and fs > 0 + assert isinstance(fs, Real) and fs > 0 # type: ignore data_1, fs_1 = reader.load_data(0, fs=500, data_format="flat") assert isinstance(data_1, np.ndarray) assert data_1.ndim == 1 assert fs_1 == 500 data_1, fs_1 = reader.load_data(0, sampfrom=10, sampto=20, data_format="flat") assert fs_1 == fs - assert data_1.shape[0] == int(10 * fs) - assert np.allclose(data_1, data[0, int(10 * fs) : int(20 * fs)]) + assert data_1.shape[0] == int(10 * fs) # type: ignore + assert np.allclose(data_1, data[0, int(10 * fs) : int(20 * fs)]) # type: ignore data_1 = reader.load_data(0, sampfrom=10, sampto=20, data_format="flat", return_fs=False) assert isinstance(data_1, np.ndarray) data_2, _ = reader.load_data(0, sampfrom=10, sampto=20, data_format="flat", units="uv") @@ -244,7 +250,7 @@ def test_load_rpeak_ann(self): ValueError, match="`units` should be one of 's', 'ms', case insensitive, or None", ): - 
reader.load_rpeak_ann(rec, units="invalid") + reader.load_rpeak_ann(rec, units="invalid") # type: ignore def test_load_rr_ann(self): rec = reader.rec_with_rpeaks_ann[0] @@ -271,7 +277,7 @@ def test_load_rr_ann(self): ValueError, match="`units` should be one of 's', 'ms', case insensitive, or None", ): - reader.load_rr_ann(rec, units="invalid") + reader.load_rr_ann(rec, units="invalid") # type: ignore def test_load_nn_ann(self): rec = reader.rec_with_rpeaks_ann[0] @@ -331,7 +337,7 @@ def test_load_sleep_ann(self): assert isinstance(ann["df_events"], pd.DataFrame) and len(ann["df_events"]) == 0 with pytest.raises(ValueError, match="Source `.+` not supported, "): - reader.load_sleep_ann(rec, source="invalid") + reader.load_sleep_ann(rec, source="invalid") # type: ignore def test_load_apnea_ann(self): rec = reader.rec_with_event_ann[0] @@ -352,7 +358,7 @@ def test_load_apnea_ann(self): assert isinstance(ann, pd.DataFrame) and ann.empty with pytest.raises(ValueError, match="Source `hrv` contains no apnea annotations"): - reader.load_apnea_ann(rec, source="hrv") + reader.load_apnea_ann(rec, source="hrv") # type: ignore def test_load_sleep_event_ann(self): rec = reader.rec_with_event_ann[0] @@ -379,7 +385,7 @@ def test_load_sleep_event_ann(self): assert isinstance(ann, pd.DataFrame) and len(ann) == 0 with pytest.raises(ValueError, match="Source `.+` not supported, "): - reader.load_sleep_event_ann(rec, source="invalid") + reader.load_sleep_event_ann(rec, source="invalid") # type: ignore def test_load_sleep_stage_ann(self): rec = reader.rec_with_event_ann[0] @@ -402,7 +408,7 @@ def test_load_sleep_stage_ann(self): assert isinstance(ann, pd.DataFrame) and len(ann) == 0 with pytest.raises(ValueError, match="Source `.+` not supported, "): - reader.load_sleep_stage_ann(rec, source="invalid") + reader.load_sleep_stage_ann(rec, source="invalid") # type: ignore def test_locate_abnormal_beats(self): rec = reader.rec_with_rpeaks_ann[0] @@ -439,12 +445,12 @@ def test_locate_abnormal_beats(self): rec = reader.rec_with_rpeaks_ann[0] with pytest.raises(ValueError, match="No abnormal type of `.+`"): - reader.locate_abnormal_beats(rec, abnormal_type="AF") + reader.locate_abnormal_beats(rec, abnormal_type="AF") # type: ignore with pytest.raises( ValueError, match="`units` should be one of 's', 'ms', case insensitive, or None", ): - reader.locate_abnormal_beats(rec, units="invalid") + reader.locate_abnormal_beats(rec, units="invalid") # type: ignore def test_locate_artifacts(self): rec = reader.rec_with_rpeaks_ann[0] @@ -473,7 +479,7 @@ def test_locate_artifacts(self): ValueError, match="`units` should be one of 's', 'ms', case insensitive, or None", ): - reader.locate_artifacts(rec, units="invalid") + reader.locate_artifacts(rec, units="invalid") # type: ignore def test_get_available_signals(self): assert reader.get_available_signals(None) is None # no return @@ -486,14 +492,14 @@ def test_get_available_signals(self): def test_get_chn_num(self): available_signals = reader.get_available_signals(0) - for sig in available_signals: + for sig in available_signals: # type: ignore chn_num = reader.get_chn_num(0, sig) assert isinstance(chn_num, int) - assert 0 <= chn_num < len(available_signals) + assert 0 <= chn_num < len(available_signals) # type: ignore def test_match_channel(self): available_signals = reader.get_available_signals(0) - for sig in available_signals: + for sig in available_signals: # type: ignore assert sig == reader.match_channel(sig.lower()) assert sig in reader.all_signals @@ -501,13 +507,13 @@ def 
test_match_channel(self): def test_get_fs(self): available_signals = reader.get_available_signals(0) - for sig in available_signals: + for sig in available_signals: # type: ignore fs = reader.get_fs(0, sig) - assert isinstance(fs, Real) and fs > 0 + assert isinstance(fs, Real) and fs > 0 # type: ignore rec = reader.rec_with_rpeaks_ann[0] fs = reader.get_fs(rec, "rpeak") - assert isinstance(fs, Real) and fs > 0 + assert isinstance(fs, Real) and fs > 0 # type: ignore rec = "shhs2-200001" # a record (both signal and ann. files) that does not exist fs = reader.get_fs(rec) @@ -629,7 +635,7 @@ def test_plot(self): match="Unknown plot format `xxx`! `plot_format` can only be one of `span`, `hypnogram`", ): rec = reader.rec_with_event_ann[0] - reader.plot_ann(rec, event_source="event", plot_format="xxx") + reader.plot_ann(rec, event_source="event", plot_format="xxx") # type: ignore with pytest.raises(ValueError, match="No input data"): rec = reader.rec_with_event_ann[0] diff --git a/test/test_databases/test_sph.py b/test/test_databases/test_sph.py index 3b908053..0a04120d 100644 --- a/test/test_databases/test_sph.py +++ b/test/test_databases/test_sph.py @@ -66,6 +66,7 @@ def test_load_data(self): data_1, data_1_fs = reader.load_data(rec, return_fs=True) assert data_1_fs == reader.fs + rec = 0 with pytest.raises(AssertionError, match="Invalid data_format: `flat`"): reader.load_data(rec, data_format="flat") with pytest.raises(AssertionError, match="Invalid units: `kV`"): @@ -82,6 +83,7 @@ def test_load_ann(self): ann_1 = reader.load_ann(rec, ann_format="f", ignore_modifier=False) assert len(ann) == len(ann_1) + rec = 0 with pytest.raises(ValueError, match="Unknown annotation format: `flat`"): reader.load_ann(rec, ann_format="flat") with pytest.raises(NotImplementedError, match="Abbreviations are not supported yet"): @@ -92,23 +94,34 @@ def test_get_subject_info(self): info = reader.get_subject_info(rec) assert isinstance(info, dict) assert info.keys() == {"age", "sex"} + info = reader.get_subject_info(0, items=["age"]) + assert isinstance(info, dict) + assert info.keys() == {"age"} def test_get_subject_id(self): for rec in reader: sid = reader.get_subject_id(rec) assert isinstance(sid, str) + sid = reader.get_subject_id(0) + assert isinstance(sid, str) def test_get_age(self): for rec in reader: age = reader.get_age(rec) assert isinstance(age, int) assert age > 0 + age = reader.get_age(0) + assert isinstance(age, int) + assert age > 0 def test_get_sex(self): for rec in reader: sex = reader.get_sex(rec) assert isinstance(sex, str) assert sex in ["M", "F"] + sex = reader.get_sex(0) + assert isinstance(sex, str) + assert sex in ["M", "F"] def test_get_siglen(self): for rec in reader: @@ -116,6 +129,10 @@ def test_get_siglen(self): data = reader.load_data(rec) assert isinstance(siglen, int) assert siglen == data.shape[1] + siglen = reader.get_siglen(0) + data = reader.load_data(0) + assert isinstance(siglen, int) + assert siglen == data.shape[1] def test_meta_data(self): assert isinstance(reader.url, dict) @@ -131,18 +148,18 @@ def test_plot(self): "t_onsets": [150, 1150], "t_offsets": [190, 1190], } - reader.plot(0, leads="II", ticks_granularity=2, waves=waves) + reader.plot(0, leads="II", ticks_granularity=2, waves=waves) # type: ignore waves = { "p_peaks": [105, 1105], "q_peaks": [120, 1120], "s_peaks": [125, 1125], "t_peaks": [170, 1170], } - reader.plot(0, leads=["II", 7], ticks_granularity=1, waves=waves) + reader.plot(0, leads=["II", 7], ticks_granularity=1, waves=waves) # type: ignore waves = { 
"p_peaks": [105, 1105], "r_peaks": [122, 1122], "t_peaks": [170, 1170], } data = reader.load_data(0) - reader.plot(0, data=data, ticks_granularity=0, waves=waves) + reader.plot(0, data=data, ticks_granularity=0, waves=waves) # type: ignore diff --git a/test/test_models/test_ecg_crnn.py b/test/test_models/test_ecg_crnn.py index 75db7109..5805cc1f 100644 --- a/test/test_models/test_ecg_crnn.py +++ b/test/test_models/test_ecg_crnn.py @@ -24,18 +24,18 @@ def test_ecg_crnn(): inp = torch.randn(2, n_leads, 2000).to(DEVICE) grid = itertools.product( - [cnn_name for cnn_name in ECG_CRNN_CONFIG.cnn.keys() if cnn_name != "name"], - [rnn_name for rnn_name in ECG_CRNN_CONFIG.rnn.keys() if rnn_name != "name"] + ["none"], - [attn_name for attn_name in ECG_CRNN_CONFIG.attn.keys() if attn_name != "name"] + ["none"], + [cnn_name for cnn_name in ECG_CRNN_CONFIG.cnn.keys() if cnn_name != "name"], # type: ignore + [rnn_name for rnn_name in ECG_CRNN_CONFIG.rnn.keys() if rnn_name != "name"] + ["none"], # type: ignore + [attn_name for attn_name in ECG_CRNN_CONFIG.attn.keys() if attn_name != "name"] + ["none"], # type: ignore ["none", "max", "avg"], # global pool ) - total = (len(ECG_CRNN_CONFIG.cnn.keys()) - 1) * len(ECG_CRNN_CONFIG.rnn.keys()) * len(ECG_CRNN_CONFIG.attn.keys()) * 3 + total = (len(ECG_CRNN_CONFIG.cnn.keys()) - 1) * len(ECG_CRNN_CONFIG.rnn.keys()) * len(ECG_CRNN_CONFIG.attn.keys()) * 3 # type: ignore for cnn_name, rnn_name, attn_name, global_pool in tqdm(grid, total=total, mininterval=1): config = deepcopy(ECG_CRNN_CONFIG) - config.cnn.name = cnn_name - config.rnn.name = rnn_name - config.attn.name = attn_name + config.cnn.name = cnn_name # type: ignore + config.rnn.name = rnn_name # type: ignore + config.attn.name = attn_name # type: ignore config.global_pool = global_pool model = ECG_CRNN(classes=classes, n_leads=n_leads, config=config).to(DEVICE) @@ -55,15 +55,15 @@ def test_ecg_crnn(): # load weights from v1 model.cnn.load_state_dict(model_v1.cnn.state_dict()) if model.rnn.__class__.__name__ != "Identity": - model.rnn.load_state_dict(model_v1.rnn.state_dict()) + model.rnn.load_state_dict(model_v1.rnn.state_dict()) # type: ignore if model.attn.__class__.__name__ != "Identity": - model.attn.load_state_dict(model_v1.attn.state_dict()) + model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore model.clf.load_state_dict(model_v1.clf.state_dict()) - doi = model.doi + doi = model.doi # type: ignore assert isinstance(doi, list) assert all([isinstance(d, str) for d in doi]), doi - doi = model_v1.doi + doi = model_v1.doi # type: ignore assert isinstance(doi, list) assert all([isinstance(d, str) for d in doi]), doi @@ -85,24 +85,24 @@ def test_warns_errors(): model_v1.inference(inp) config = deepcopy(ECG_CRNN_CONFIG) - config.cnn.name = "not_implemented" - config.cnn.not_implemented = {} + config.cnn.name = "not_implemented" # type: ignore + config.cnn.not_implemented = {} # type: ignore with pytest.raises(NotImplementedError, match="CNN \042.+\042 not implemented yet"): ECG_CRNN(classes=classes, n_leads=n_leads, config=config) with pytest.raises(NotImplementedError, match="CNN \042.+\042 not implemented yet"): ECG_CRNN_v1(classes=classes, n_leads=n_leads, config=config) config = deepcopy(ECG_CRNN_CONFIG) - config.rnn.name = "not_implemented" - config.rnn.not_implemented = {} + config.rnn.name = "not_implemented" # type: ignore + config.rnn.not_implemented = {} # type: ignore with pytest.raises(NotImplementedError, match="RNN \042.+\042 not implemented yet"): ECG_CRNN(classes=classes, 
n_leads=n_leads, config=config) with pytest.raises(NotImplementedError, match="RNN \042.+\042 not implemented yet"): ECG_CRNN_v1(classes=classes, n_leads=n_leads, config=config) config = deepcopy(ECG_CRNN_CONFIG) - config.attn.name = "not_implemented" - config.attn.not_implemented = {} + config.attn.name = "not_implemented" # type: ignore + config.attn.not_implemented = {} # type: ignore with pytest.raises(NotImplementedError, match="Attention \042.+\042 not implemented yet"): ECG_CRNN(classes=classes, n_leads=n_leads, config=config) with pytest.raises(NotImplementedError, match="Attention \042.+\042 not implemented yet"): @@ -128,10 +128,10 @@ def test_from_v1(): n_leads = 12 classes = ["NSR", "AF", "PVC", "LBBB", "RBBB", "PAB", "VFL"] model_v1 = ECG_CRNN_v1(classes=classes, n_leads=n_leads, config=config) - model_v1.save(_TMP_DIR / "ecg_crnn_v1.pth", {"classes": classes, "n_leads": n_leads}) + model_v1.save(_TMP_DIR / "ecg_crnn_v1.pth", {"classes": classes, "n_leads": n_leads}, use_safetensors=False) # type: ignore model = ECG_CRNN.from_v1(_TMP_DIR / "ecg_crnn_v1.pth") del model - model, _ = ECG_CRNN.from_v1(_TMP_DIR / "ecg_crnn_v1.pth", return_config=True) + model, _ = ECG_CRNN.from_v1(_TMP_DIR / "ecg_crnn_v1.pth", return_config=True) # type: ignore (_TMP_DIR / "ecg_crnn_v1.pth").unlink() del model_v1, model diff --git a/test/test_models/test_rr_lstm.py b/test/test_models/test_rr_lstm.py index 6b75c49d..8e33b851 100644 --- a/test/test_models/test_rr_lstm.py +++ b/test/test_models/test_rr_lstm.py @@ -22,9 +22,9 @@ def test_rr_lstm(): inp_bf = torch.randn(2, in_channels, 100).to(DEVICE) config = deepcopy(RR_LSTM_CONFIG) - config.clf.name = "crf" + config.clf.name = "crf" # type: ignore for attn_name in ["none"]: - config.attn.name = attn_name + config.attn.name = attn_name # type: ignore model = RR_LSTM(classes=classes, config=config).to(DEVICE) model = model.eval() out = model(inp) @@ -35,22 +35,22 @@ def test_rr_lstm(): model_v1.compute_output_shape(seq_len=inp.shape[0], batch_size=inp.shape[1]) model.lstm.load_state_dict(model_v1.lstm.state_dict()) if model.attn.__class__.__name__ != "Identity": - model.attn.load_state_dict(model_v1.attn.state_dict()) + model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore model.clf.load_state_dict(model_v1.clf.state_dict()) config = deepcopy(RR_LSTM_CONFIG) - config.clf.name = "crf" + config.clf.name = "crf" # type: ignore config.batch_first = True for attn_name in ["none"]: - config.attn.name = attn_name + config.attn.name = attn_name # type: ignore model = RR_LSTM(classes=classes, config=config).to(DEVICE) model = model.eval() out = model(inp_bf) assert out.shape == model.compute_output_shape(seq_len=inp_bf.shape[-1], batch_size=inp_bf.shape[0]) config = deepcopy(RR_LSTM_CONFIG) - config.clf.name = "linear" + config.clf.name = "linear" # type: ignore for attn_name in ["none", "gc", "nl", "se"]: - config.attn.name = attn_name + config.attn.name = attn_name # type: ignore model = RR_LSTM(classes=classes, config=config).to(DEVICE) model = model.eval() out = model(inp) @@ -61,13 +61,13 @@ def test_rr_lstm(): model_v1.compute_output_shape(seq_len=inp.shape[0], batch_size=inp.shape[1]) model.lstm.load_state_dict(model_v1.lstm.state_dict()) if model.attn.__class__.__name__ != "Identity": - model.attn.load_state_dict(model_v1.attn.state_dict()) + model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore model.clf.load_state_dict(model_v1.clf.state_dict()) config = deepcopy(RR_LSTM_CONFIG) - config.clf.name = "linear" + 
config.clf.name = "linear" # type: ignore config.batch_first = True for attn_name in ["none", "gc", "nl", "se"]: - config.attn.name = attn_name + config.attn.name = attn_name # type: ignore model = RR_LSTM(classes=classes, config=config).to(DEVICE) model = model.eval() out = model(inp_bf) @@ -75,7 +75,7 @@ def test_rr_lstm(): config = deepcopy(RR_AF_VANILLA_CONFIG) for attn_name in ["none", "gc", "nl", "se"]: - config.attn.name = attn_name + config.attn.name = attn_name # type: ignore model = RR_LSTM(classes=classes, config=config).to(DEVICE) model = model.eval() out = model(inp) @@ -86,12 +86,12 @@ def test_rr_lstm(): model_v1.compute_output_shape(seq_len=inp.shape[0], batch_size=inp.shape[1]) model.lstm.load_state_dict(model_v1.lstm.state_dict()) if model.attn.__class__.__name__ != "Identity": - model.attn.load_state_dict(model_v1.attn.state_dict()) + model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore model.clf.load_state_dict(model_v1.clf.state_dict()) config = deepcopy(RR_AF_VANILLA_CONFIG) config.batch_first = True for attn_name in ["none", "gc", "nl", "se"]: - config.attn.name = attn_name + config.attn.name = attn_name # type: ignore model = RR_LSTM(classes=classes, config=config).to(DEVICE) model = model.eval() out = model(inp_bf) @@ -99,7 +99,7 @@ def test_rr_lstm(): config = deepcopy(RR_AF_CRF_CONFIG) for attn_name in ["none"]: - config.attn.name = attn_name + config.attn.name = attn_name # type: ignore model = RR_LSTM(classes=classes, config=config).to(DEVICE) model = model.eval() out = model(inp) @@ -110,21 +110,21 @@ def test_rr_lstm(): model_v1.compute_output_shape(seq_len=inp.shape[0], batch_size=inp.shape[1]) model.lstm.load_state_dict(model_v1.lstm.state_dict()) if model.attn.__class__.__name__ != "Identity": - model.attn.load_state_dict(model_v1.attn.state_dict()) + model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore model.clf.load_state_dict(model_v1.clf.state_dict()) config = deepcopy(RR_AF_CRF_CONFIG) config.batch_first = True for attn_name in ["none"]: - config.attn.name = attn_name + config.attn.name = attn_name # type: ignore model = RR_LSTM(classes=classes, config=config).to(DEVICE) model = model.eval() out = model(inp_bf) assert out.shape == model.compute_output_shape(seq_len=inp_bf.shape[-1], batch_size=inp_bf.shape[0]) config = deepcopy(RR_LSTM_CONFIG) - config.lstm.retseq = False - config.clf.name = "linear" - config.attn.name = "none" + config.lstm.retseq = False # type: ignore + config.clf.name = "linear" # type: ignore + config.attn.name = "none" # type: ignore model = RR_LSTM(classes=classes, config=config).to(DEVICE) model = model.eval() out = model(inp) @@ -135,13 +135,13 @@ def test_rr_lstm(): model_v1.compute_output_shape(seq_len=inp.shape[0], batch_size=inp.shape[1]) model.lstm.load_state_dict(model_v1.lstm.state_dict()) if model.attn.__class__.__name__ != "Identity": - model.attn.load_state_dict(model_v1.attn.state_dict()) + model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore model.clf.load_state_dict(model_v1.clf.state_dict()) config = deepcopy(RR_LSTM_CONFIG) - config.lstm.retseq = False - config.clf.name = "linear" + config.lstm.retseq = False # type: ignore + config.clf.name = "linear" # type: ignore config.batch_first = True - config.attn.name = "none" + config.attn.name = "none" # type: ignore model = RR_LSTM(classes=classes, config=config).to(DEVICE) model = model.eval() out = model(inp_bf) @@ -167,9 +167,9 @@ def test_warns_errors(): model_v1 = RR_LSTM_v1(classes=classes).to(DEVICE) config 
= deepcopy(RR_LSTM_CONFIG) - config.lstm.retseq = False - config.attn.name = "gc" - config.clf.name = "linear" + config.lstm.retseq = False # type: ignore + config.attn.name = "gc" # type: ignore + config.clf.name = "linear" # type: ignore with pytest.warns( RuntimeWarning, match="Attention is not supported when lstm is not returning sequences", @@ -182,9 +182,9 @@ def test_warns_errors(): model_v1 = RR_LSTM_v1(classes=classes, config=config).to(DEVICE) config = deepcopy(RR_LSTM_CONFIG) - config.lstm.retseq = False - config.attn.name = "none" - config.clf.name = "crf" + config.lstm.retseq = False # type: ignore + config.attn.name = "none" # type: ignore + config.clf.name = "crf" # type: ignore with pytest.warns( RuntimeWarning, match="CRF layer is not supported in non-sequence mode, using linear instead", @@ -215,15 +215,15 @@ def test_warns_errors(): model_v1.inference(inp) config = deepcopy(RR_LSTM_CONFIG) - config.attn.name = "not_implemented" - config.attn.not_implemented = {} + config.attn.name = "not_implemented" # type: ignore + config.attn.not_implemented = {} # type: ignore with pytest.raises(NotImplementedError, match="Attn module \042.+\042 not implemented yet"): model = RR_LSTM(classes=classes, config=config).to(DEVICE) with pytest.raises(NotImplementedError, match="Attn module \042.+\042 not implemented yet"): model_v1 = RR_LSTM_v1(classes=classes, config=config).to(DEVICE) config = deepcopy(RR_LSTM_CONFIG) - config.clf.name = "linear" + config.clf.name = "linear" # type: ignore config.global_pool = "not_supported" with pytest.raises(NotImplementedError, match="Pooling type \042.+\042 not supported"): model = RR_LSTM(classes=classes, config=config).to(DEVICE) @@ -235,9 +235,9 @@ def test_from_v1(): config = deepcopy(RR_LSTM_CONFIG) classes = ["NSR", "AF", "PVC", "LBBB", "RBBB", "PAB", "VFL"] model_v1 = RR_LSTM_v1(classes=classes, config=config) - model_v1.save(_TMP_DIR / "rr_lstm_v1.pth", {"classes": classes}) - model = RR_LSTM.from_v1(_TMP_DIR / "rr_lstm_v1.pth") + model_v1.save(_TMP_DIR / "rr_lstm_v1.pth", {"classes": classes}, use_safetensors=False) # type: ignore + model = RR_LSTM.from_v1(_TMP_DIR / "rr_lstm_v1.pth") # type: ignore del model - model, _ = RR_LSTM.from_v1(_TMP_DIR / "rr_lstm_v1.pth", return_config=True) + model, _ = RR_LSTM.from_v1(_TMP_DIR / "rr_lstm_v1.pth", return_config=True) # type: ignore (_TMP_DIR / "rr_lstm_v1.pth").unlink() del model_v1, model diff --git a/test/test_models/test_seq_lab_net.py b/test/test_models/test_seq_lab_net.py index 7f8d8cf9..f0204eed 100644 --- a/test/test_models/test_seq_lab_net.py +++ b/test/test_models/test_seq_lab_net.py @@ -33,32 +33,32 @@ def test_ecg_seq_lab_net(): for cnn, rnn, attn, recover_length in tqdm(grid, total=total, mininterval=1): config = adjust_cnn_filter_lengths(ECG_SEQ_LAB_NET_CONFIG, fs) - config.cnn.name = cnn - config.rnn.name = rnn - config.attn.name = attn - config.recover_length = recover_length + config.cnn.name = cnn # type: ignore + config.rnn.name = rnn # type: ignore + config.attn.name = attn # type: ignore + config.recover_length = recover_length # type: ignore - model = ECG_SEQ_LAB_NET(classes=classes, n_leads=12, config=config).to(DEVICE) + model = ECG_SEQ_LAB_NET(classes=classes, n_leads=12, config=config).to(DEVICE) # type: ignore model = model.eval() out = model(inp) assert out.shape == model.compute_output_shape(seq_len=inp.shape[-1], batch_size=inp.shape[0]) if recover_length: assert out.shape[1] == inp.shape[-1] - model_v1 = ECG_SEQ_LAB_NET_v1(classes=classes, n_leads=12, 
config=config).to(DEVICE) + model_v1 = ECG_SEQ_LAB_NET_v1(classes=classes, n_leads=12, config=config).to(DEVICE) # type: ignore model_v1 = model_v1.eval() out_v1 = model_v1(inp) model.cnn.load_state_dict(model_v1.cnn.state_dict()) if model.rnn.__class__.__name__ != "Identity": - model.rnn.load_state_dict(model_v1.rnn.state_dict()) + model.rnn.load_state_dict(model_v1.rnn.state_dict()) # type: ignore if model.attn.__class__.__name__ != "Identity": - model.attn.load_state_dict(model_v1.attn.state_dict()) + model.attn.load_state_dict(model_v1.attn.state_dict()) # type: ignore model.clf.load_state_dict(model_v1.clf.state_dict()) - doi = model.doi + doi = model.doi # type: ignore assert isinstance(doi, list) assert all([isinstance(d, str) for d in doi]), doi - doi = model_v1.doi + doi = model_v1.doi # type: ignore assert isinstance(doi, list) assert all([isinstance(d, str) for d in doi]), doi @@ -92,9 +92,9 @@ def test_from_v1(): n_leads = 12 classes = ["N"] model_v1 = ECG_SEQ_LAB_NET_v1(classes=classes, n_leads=n_leads, config=config) - model_v1.save(_TMP_DIR / "ecg_seq_lab_net_v1.pth", {"classes": classes, "n_leads": n_leads}) - model = ECG_SEQ_LAB_NET.from_v1(_TMP_DIR / "ecg_seq_lab_net_v1.pth") + model_v1.save(_TMP_DIR / "ecg_seq_lab_net_v1.pth", {"classes": classes, "n_leads": n_leads}, use_safetensors=False) # type: ignore + model = ECG_SEQ_LAB_NET.from_v1(_TMP_DIR / "ecg_seq_lab_net_v1.pth") # type: ignore del model - model, _ = ECG_SEQ_LAB_NET.from_v1(_TMP_DIR / "ecg_seq_lab_net_v1.pth", return_config=True) + model, _ = ECG_SEQ_LAB_NET.from_v1(_TMP_DIR / "ecg_seq_lab_net_v1.pth", return_config=True) # type: ignore (_TMP_DIR / "ecg_seq_lab_net_v1.pth").unlink() del model_v1, model diff --git a/test/test_pipelines/test_crnn_cinc2021_pipeline.py b/test/test_pipelines/test_crnn_cinc2021_pipeline.py index 07171053..349dfb85 100644 --- a/test/test_pipelines/test_crnn_cinc2021_pipeline.py +++ b/test/test_pipelines/test_crnn_cinc2021_pipeline.py @@ -405,8 +405,8 @@ def inference( # batch_size, channels, seq_len = _input.shape prob = self.sigmoid(self.forward(_input)) pred = (prob >= bin_pred_thr).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() for row_idx, row in enumerate(pred): row_max_prob = prob[row_idx, ...].max() if row_max_prob < ModelCfg.bin_pred_nsr_thr and nsr_cid is not None: diff --git a/test/test_pipelines/test_mtl_cinc2022_pipeline.py b/test/test_pipelines/test_mtl_cinc2022_pipeline.py index 21d46270..420ae746 100644 --- a/test/test_pipelines/test_mtl_cinc2022_pipeline.py +++ b/test/test_pipelines/test_mtl_cinc2022_pipeline.py @@ -2212,31 +2212,31 @@ def inference( prob = self.softmax(forward_output["murmur"]) pred = torch.argmax(prob, dim=-1) bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int) - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() - bin_pred = bin_pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() + bin_pred = bin_pred.detach().cpu().numpy() murmur_output = ClassificationOutput( classes=self.classes, prob=prob, pred=pred, bin_pred=bin_pred, - forward_output=forward_output["murmur"].cpu().detach().numpy(), + forward_output=forward_output["murmur"].detach().cpu().numpy(), ) if forward_output.get("outcome", None) is not None: prob = self.softmax(forward_output["outcome"]) pred = torch.argmax(prob, dim=-1) bin_pred = (prob == prob.max(dim=-1, keepdim=True).values).to(int) - prob 
= prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() - bin_pred = bin_pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() + bin_pred = bin_pred.detach().cpu().numpy() outcome_output = ClassificationOutput( classes=self.outcomes, prob=prob, pred=pred, bin_pred=bin_pred, - forward_output=forward_output["outcome"].cpu().detach().numpy(), + forward_output=forward_output["outcome"].detach().cpu().numpy(), ) else: outcome_output = None @@ -2250,13 +2250,13 @@ def inference( else: prob = self.sigmoid(forward_output["segmentation"]) pred = (prob > seg_thr).int() * (prob == prob.max(dim=-1, keepdim=True).values).int() - prob = prob.cpu().detach().numpy() - pred = pred.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() + pred = pred.detach().cpu().numpy() segmentation_output = SequenceLabellingOutput( classes=self.states, prob=prob, pred=pred, - forward_output=forward_output["segmentation"].cpu().detach().numpy(), + forward_output=forward_output["segmentation"].detach().cpu().numpy(), ) else: segmentation_output = None diff --git a/test/test_pipelines/test_seq_lab_cpsc2019_pipeline.py b/test/test_pipelines/test_seq_lab_cpsc2019_pipeline.py index 0f0145ba..1a4e8287 100644 --- a/test/test_pipelines/test_seq_lab_cpsc2019_pipeline.py +++ b/test/test_pipelines/test_seq_lab_cpsc2019_pipeline.py @@ -149,7 +149,7 @@ def inference( mode="linear", align_corners=True, ).permute(0, 2, 1) - prob = prob.cpu().detach().numpy().squeeze(-1) + prob = prob.detach().cpu().numpy().squeeze(-1) # prob --> qrs mask --> qrs intervals --> rpeaks rpeaks = _inference_post_process( @@ -169,7 +169,7 @@ def inference( sampling_rate=self.config.fs, tol=0.05, )[0] - for b_input, b_rpeaks in zip(_input.detach().numpy().squeeze(1), rpeaks) + for b_input, b_rpeaks in zip(_input.detach().cpu().numpy().squeeze(1), rpeaks) ] return RPeaksDetectionOutput( diff --git a/test/test_pipelines/test_unet_ludb_pipeline.py b/test/test_pipelines/test_unet_ludb_pipeline.py index add76d60..7773893c 100644 --- a/test/test_pipelines/test_unet_ludb_pipeline.py +++ b/test/test_pipelines/test_unet_ludb_pipeline.py @@ -128,7 +128,7 @@ def inference( prob = self.softmax(prob) else: prob = torch.sigmoid(prob) - prob = prob.cpu().detach().numpy() + prob = prob.detach().cpu().numpy() if "i" in self.classes: mask = np.argmax(prob, axis=-1) diff --git a/test/test_utils/test_download.py b/test/test_utils/test_download.py index 56e6ceed..bfda06f6 100644 --- a/test/test_utils/test_download.py +++ b/test/test_utils/test_download.py @@ -1,18 +1,14 @@ """ """ import shutil +import subprocess +import types import urllib.parse from pathlib import Path import pytest -from torch_ecg.utils.download import ( - _download_from_aws_s3_using_boto3, - _download_from_google_drive, - http_get, - is_compressed_file, - url_is_reachable, -) +import torch_ecg.utils.download as dl _TMP_DIR = Path(__file__).resolve().parents[2] / "tmp" / "test_download" _TMP_DIR.mkdir(parents=True, exist_ok=True) @@ -22,11 +18,11 @@ def test_http_get(): # normally, direct downloading from dropbox with `dl=0` will not download the file # http_get internally replaces `dl=0` with `dl=1` to force download url = "https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=0" - http_get(url, _TMP_DIR / "action-test-zip-extract", extract=True, filename="test.zip") + dl.http_get(url, _TMP_DIR / "action-test-zip-extract", extract=True, filename="test.zip") shutil.rmtree(_TMP_DIR / "action-test-zip-extract") - http_get(url, _TMP_DIR / 
"action-test-zip-extract", extract="auto", filename="test.zip") + dl.http_get(url, _TMP_DIR / "action-test-zip-extract", extract="auto", filename="test.zip") shutil.rmtree(_TMP_DIR / "action-test-zip-extract") - http_get(url, _TMP_DIR / "action-test-zip-extract", extract="auto") + dl.http_get(url, _TMP_DIR / "action-test-zip-extract", extract="auto") shutil.rmtree(_TMP_DIR / "action-test-zip-extract") url = ( @@ -41,13 +37,13 @@ def test_http_get(): "Automatic decompression is turned off\\." ), ): - http_get(url, _TMP_DIR, extract=True, filename="test.txt") + dl.http_get(url, _TMP_DIR, extract=True, filename="test.txt") with pytest.raises(FileExistsError, match="file already exists"): - http_get(url, _TMP_DIR, extract=True, filename="test.txt") + dl.http_get(url, _TMP_DIR, extract=True, filename="test.txt") (_TMP_DIR / "test.txt").unlink() - http_get(url, _TMP_DIR, extract="auto", filename="test.txt") + dl.http_get(url, _TMP_DIR, extract="auto", filename="test.txt") (_TMP_DIR / "test.txt").unlink() - http_get(url, _TMP_DIR, extract="auto") + dl.http_get(url, _TMP_DIR, extract="auto") with pytest.warns( RuntimeWarning, @@ -57,7 +53,7 @@ def test_http_get(): "The user is responsible for decompressing the file manually\\." ), ): - http_get(url, _TMP_DIR, extract=True) + dl.http_get(url, _TMP_DIR, extract=True) Path(_TMP_DIR / Path(url).name).unlink() # test downloading from Google Drive @@ -66,45 +62,246 @@ def test_http_get(): url_no_scheme = f"drive.google.com/file/d/{file_id}/view?usp=sharing" url_xxx_schme = f"xxx://drive.google.com/file/d/{file_id}/view?usp=sharing" with pytest.raises(AssertionError, match="filename can not be inferred from Google Drive URL"): - http_get(url_no_scheme, _TMP_DIR) + dl.http_get(url_no_scheme, _TMP_DIR) with pytest.raises(ValueError, match="Unsupported URL scheme"): - http_get(url_xxx_schme, _TMP_DIR, extract=False, filename="torch-ecg-paper.bib") - http_get(url, _TMP_DIR, filename="torch-ecg-paper.bib", extract=False) + dl.http_get(url_xxx_schme, _TMP_DIR, extract=False, filename="torch-ecg-paper.bib") + dl.http_get(url, _TMP_DIR, filename="torch-ecg-paper.bib", extract=False) (_TMP_DIR / "torch-ecg-paper.bib").unlink() - _download_from_google_drive(file_id, _TMP_DIR / "torch-ecg-paper.bib") + dl._download_from_google_drive(file_id, _TMP_DIR / "torch-ecg-paper.bib") (_TMP_DIR / "torch-ecg-paper.bib").unlink() - _download_from_google_drive(url_no_scheme, _TMP_DIR / "torch-ecg-paper.bib") + dl._download_from_google_drive(url_no_scheme, _TMP_DIR / "torch-ecg-paper.bib") (_TMP_DIR / "torch-ecg-paper.bib").unlink() # test downloading from AWS S3 (by default using AWS CLI) (_TMP_DIR / "ludb").mkdir(exist_ok=True) - http_get("s3://physionet-open/ludb/1.0.1/", _TMP_DIR / "ludb") + dl.http_get("s3://physionet-open/ludb/1.0.1/", _TMP_DIR / "ludb") # test downloading from AWS S3 (by using boto3) shutil.rmtree(_TMP_DIR / "ludb") (_TMP_DIR / "ludb").mkdir(exist_ok=True) - _download_from_aws_s3_using_boto3("s3://physionet-open/ludb/1.0.1/", _TMP_DIR / "ludb") + dl._download_from_aws_s3_using_boto3("s3://physionet-open/ludb/1.0.1/", _TMP_DIR / "ludb") + + with pytest.raises(ValueError, match="Invalid S3 URL"): + dl.http_get("s3://xxx", _TMP_DIR / "ludb") + + assert dl._stem(b"https://example.com/path/to/file.tar.gz") == "file" def test_url_is_reachable(): - assert url_is_reachable("https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=1") - assert not url_is_reachable("https://www.some-unknown-domain.com/unknown-path/unknown-file.zip") + assert 
dl.url_is_reachable("https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=1") + assert not dl.url_is_reachable("https://www.some-unknown-domain.com/unknown-path/unknown-file.zip") def test_is_compressed_file(): # check local files - assert not is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.txt") - assert not is_compressed_file(_TMP_DIR / "action-test-zip-extract") - assert not is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test") - assert not is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.pth.tar") - assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.gz") - assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tgz") - assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.bz2") - assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tbz2") - assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.xz") - assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.txz") - assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.zip") - assert is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.7z") + assert not dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.txt") + assert not dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract") + assert not dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test") + assert not dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.pth.tar") + assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.gz") + assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tgz") + assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.bz2") + assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tbz2") + assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.tar.xz") + assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.txz") + assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.zip") + assert dl.is_compressed_file(_TMP_DIR / "action-test-zip-extract" / "test.7z") # check remote files (by URL) - assert is_compressed_file(urllib.parse.urlparse("https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=0").path) + assert dl.is_compressed_file(urllib.parse.urlparse("https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=0").path) + + +class FakeResponseOK: + """A fake streaming HTTP response that will raise during iter_content.""" + + def __init__(self, raise_in_iter=False, raise_text=False): + self.status_code = 200 + self.headers = {"Content-Length": "10"} + self._raise_in_iter = raise_in_iter + self._raise_text = raise_text + self._text_value = "FAKE_CONTENT_ABCDEFG" * 5 + + def iter_content(self, chunk_size=1024): + if self._raise_in_iter: + # simulate mid-download exception + raise RuntimeError("iter boom") + yield b"abc" + yield b"def" + + @property + def text(self): + if self._raise_text: + raise ValueError("text boom") + return self._text_value + + +class FakeSession: + def __init__(self, response: FakeResponseOK): + self._resp = response + + def get(self, *a, **kw): + return self._resp + + +@pytest.mark.parametrize("raise_text", [False, True], ids=["text_ok", "text_raises"]) +def test_http_get_iter_exception_triggers_runtime(monkeypatch, tmp_path, raise_text): + """Cover the outer except and the inner try/except that reads req.text.""" + resp = FakeResponseOK(raise_in_iter=True, 
raise_text=raise_text) + + def fake_retry_session(): + return FakeSession(resp) + + monkeypatch.setattr(dl, "_requests_retry_session", lambda: types.SimpleNamespace(get=fake_retry_session().get)) + + target_dir = tmp_path / "download" + with pytest.raises(RuntimeError) as exc: + dl.http_get("https://example.com/file.dat", target_dir, extract=False, filename="f.bin") + + msg = str(exc.value) + if raise_text: + # inner text extraction failed => snippet empty + assert "body[:300]=''" in msg or "body[:300]=''"[:10] + else: + assert "FAKE_CONTENT" in msg + + +def test_http_get_status_403(monkeypatch, tmp_path): + """Force status 403 path (your code raises generic Exception then caught by outer except).""" + + class Resp403: + status_code = 403 + headers = {} + text = "Forbidden" + + def iter_content(self, chunk_size=1024): + yield b"" + + def fake_session(): + return types.SimpleNamespace(get=lambda *a, **k: Resp403()) + + monkeypatch.setattr(dl, "_requests_retry_session", lambda: fake_session()) + + with pytest.raises(RuntimeError) as exc: + dl.http_get("https://example.com/forbidden.bin", tmp_path, extract=False, filename="f.bin") + assert "Failed to download" in str(exc.value) + assert "status=403" in str(exc.value) + assert "Forbidden" in str(exc.value) + + +def test_download_from_aws_s3_using_boto3_empty_bucket(monkeypatch, tmp_path): + """Cover object_count == 0 -> ValueError.""" + + class FakePaginator: + def paginate(self, **kwargs): + # return an iterator with no 'Contents' + return iter([{"Other": 1}, {"Meta": 2}]) + + class FakeBoto3Client: + def get_paginator(self, name): + assert name == "list_objects_v2" + return FakePaginator() + + def close(self): + pass + + monkeypatch.setattr(dl, "boto3", types.SimpleNamespace(client=lambda *a, **k: FakeBoto3Client())) + + with pytest.raises(ValueError, match="No objects found"): + dl._download_from_aws_s3_using_boto3("s3://fake-bucket/prefix/", tmp_path / "out") + + +def test_download_from_aws_s3_using_awscli_subprocess_fail(monkeypatch, tmp_path): + """Force awscli sync subprocess to exit with non-zero code -> CalledProcessError.""" + # Pretend aws exists + monkeypatch.setattr(dl.shutil, "which", lambda name: "/usr/bin/aws") + # Control object count + monkeypatch.setattr(dl, "count_aws_s3_bucket", lambda bucket, prefix: 3) + + class FakeStdout: + def __init__(self, lines): + self._lines = lines # list[str] + self._idx = 0 + self.closed = False + + def __iter__(self): + return self + + def __next__(self): + if self._idx < len(self._lines): + line = self._lines[self._idx] + self._idx += 1 + return line + raise StopIteration + + def close(self): + self.closed = True + + class FakePopen: + def __init__(self, *a, **k): + self.stdout = FakeStdout( + [ + "download: s3://bucket/file1 to file1\n", + "download: s3://bucket/file2 to file2\n", + "some other line\n", + ] + ) + self._returncode = None + self._waited = False + self.args = a + self.kwargs = k + + def poll(self): + if self.stdout._idx < len(self.stdout._lines): + return None + if self._returncode is None: + self._returncode = 2 + return self._returncode + + def wait(self): + self.stdout._idx = len(self.stdout._lines) + if self._returncode is None: + self._returncode = 2 + self._waited = True + return self._returncode + + def communicate(self): + combined = "".join(self.stdout._lines) + return (combined, "boom\n") + + monkeypatch.setattr(dl.subprocess, "Popen", lambda *a, **k: FakePopen(*a, **k)) + + with pytest.raises(subprocess.CalledProcessError) as exc: + 
dl._download_from_aws_s3_using_awscli("s3://bucket/prefix/", tmp_path / "out", show_progress=False) + + assert "download: s3://bucket/file1" in exc.value.output + assert exc.value.returncode != 0 + + +def test_url_is_reachable_exception(monkeypatch): + def boom(*a, **k): + raise RuntimeError("network down") + + monkeypatch.setattr(dl.requests, "head", boom) + assert dl.url_is_reachable("https://whatever") is False + + +def test_download_from_aws_awscli_missing(monkeypatch, tmp_path): + monkeypatch.setattr(dl.shutil, "which", lambda name: None) + + with pytest.raises(RuntimeError, match="AWS cli is required to download from S3"): + dl._download_from_aws_s3_using_awscli("s3://bucket/prefix/", tmp_path) + + +def test_download_from_aws_awscli_present_fast_path(monkeypatch, tmp_path): + monkeypatch.setattr(dl.shutil, "which", lambda name: "/usr/bin/aws") + monkeypatch.setattr(dl, "count_aws_s3_bucket", lambda bucket, prefix: 0) + dl._download_from_aws_s3_using_awscli("s3://bucket/prefix/", tmp_path, show_progress=False) + + +def test_awscli_non_ci_branch(monkeypatch, tmp_path): + monkeypatch.delenv("CI", raising=False) + + monkeypatch.setattr(dl.shutil, "which", lambda name: "/usr/bin/aws") + monkeypatch.setattr(dl, "count_aws_s3_bucket", lambda bucket, prefix: 0) + + dl._download_from_aws_s3_using_awscli("s3://bucket/prefix/", tmp_path, show_progress=False) diff --git a/test/test_utils/test_misc.py b/test/test_utils/test_misc.py index f1b9267a..6485407a 100644 --- a/test/test_utils/test_misc.py +++ b/test/test_utils/test_misc.py @@ -1,6 +1,7 @@ """ """ import datetime +import json import textwrap import time from itertools import product @@ -18,6 +19,7 @@ MovingAverage, ReprMixin, Timer, + _is_pathlike_string, add_docstring, add_kwargs, dict_to_str, @@ -128,7 +130,7 @@ def test_get_record_list_recursive3(): record_list = get_record_list_recursive3(path, rec_patterns_with_ext, with_suffix=True) for tranche in list("EFG"): # assert the records come with file extension - assert all([p.endswith(".mat") for p in record_list[tranche]]), record_list[tranche] + assert all([p.endswith(".mat") for p in record_list[tranche]]), record_list[tranche] # type: ignore def test_dict_to_str(): @@ -140,8 +142,8 @@ def test_dict_to_str(): def test_str2bool(): assert str2bool(True) is True assert str2bool(False) is False - assert str2bool("True") is True - assert str2bool("False") is False + assert str2bool("True ") is True + assert str2bool(" False") is False assert str2bool("true") is True assert str2bool("false") is False assert str2bool("1") is True @@ -150,10 +152,16 @@ def test_str2bool(): assert str2bool("no") is False assert str2bool("y") is True assert str2bool("n") is False + assert str2bool(None) is False + assert str2bool(None, default=True) is True + assert str2bool("", strict=False) is False with pytest.raises(ValueError, match="Boolean value expected"): str2bool("abc") with pytest.raises(ValueError, match="Boolean value expected"): str2bool("2") + with pytest.raises(TypeError, match="Expected str|bool|None"): + str2bool(1) # type: ignore + assert str2bool(1, strict=False) is False # type: ignore def test_diff_with_step(): @@ -192,7 +200,7 @@ def test_plot_single_lead(): n_samples = 5000 plot_single_lead( t=np.arange(n_samples) / fs, - sig=500 * DEFAULTS.RNG.normal(size=(n_samples,)), + sig=500 * DEFAULTS.RNG.normal(size=(n_samples,)), # type: ignore ticks_granularity=2, ) @@ -225,6 +233,7 @@ def test_read_log_txt(): dst_dir=str(_TMP_DIR), extract=True, filename="log.txt", + verify_length=False, ) 
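# Assumption (not stated in the patch): `verify_length=False` is passed above to opt out of the
# download-size check in `http_get`, since this remotely hosted log file may report a
# Content-Length that does not match the file actually written to disk. With a server that
# reports the length correctly, the call could presumably keep the check enabled, e.g.:
#     http_get(url, dst_dir=str(_TMP_DIR), extract=True, filename="log.txt")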
log_txt_file = str(_TMP_DIR / "log.txt") log_txt = read_log_txt(log_txt_file) @@ -260,15 +269,15 @@ def test_dicts_equal(): d2 = {"a": pd.DataFrame([{"hehe": 2, "haha": 2}])[["hehe", "haha"]]} assert dicts_equal(d1, d2) is False assert dicts_equal(d2, d1) is False - d1["a"] = d1["a"]["hehe"] - d2["a"] = d2["a"]["haha"] + d1["a"] = d1["a"]["hehe"] # type: ignore + d2["a"] = d2["a"]["haha"] # type: ignore assert dicts_equal(d1, d2) is False assert dicts_equal(d2, d1) is False d1 = {"a": pd.DataFrame([{"hehe": 1, "haha": 2}])[["haha", "hehe"]]} d2 = {"a": pd.DataFrame([{"hehe": 2, "haha": 2}])[["hehe", "haha"]]} - d1["a"] = d1["a"]["hehe"] - d2["a"] = d2["a"]["hehe"] + d1["a"] = d1["a"]["hehe"] # type: ignore + d2["a"] = d2["a"]["hehe"] # type: ignore assert dicts_equal(d1, d2) is False assert dicts_equal(d2, d1) is False @@ -361,7 +370,7 @@ def test_CitationMixin(): def test_MovingAverage(): ma = MovingAverage(verbose=2) - data = DEFAULTS.RNG.normal(size=(100,)) + data = DEFAULTS.RNG.normal(size=(100,)) # type: ignore new_data = ma(data, method="sma", window=7, center=True) assert new_data.shape == data.shape new_data = ma(data, method="ema", weight=0.7) @@ -435,7 +444,7 @@ def func(a, b): with pytest.raises(ValueError, match="mode `.+` is not supported"): - @add_docstring("This is a new docstring.", mode="xxx") + @add_docstring("This is a new docstring.", mode="xxx") # type: ignore def func(a, b): """This is a docstring.""" return a + b @@ -443,7 +452,7 @@ def func(a, b): def test_remove_parameters_returns_from_docstring(): new_docstring = remove_parameters_returns_from_docstring( - remove_parameters_returns_from_docstring.__doc__, + remove_parameters_returns_from_docstring.__doc__, # type: ignore parameters=["returns_indicator", "parameters_indicator"], returns="str", ) @@ -602,7 +611,17 @@ def test_make_serializable(): x = (np.array([1, 2, 3]), np.array([4, 5, 6]).mean()) obj = make_serializable(x) assert obj == [[1, 2, 3], 5.0] - assert isinstance(obj[1], float) and isinstance(x[1], np.float64) + assert isinstance(obj[1], float) and isinstance(x[1], np.float64) # type: ignore + + obj = make_serializable(DEFAULTS, drop_unserializable=False, drop_paths=False) + assert isinstance(obj, dict) and set(["RNG", "DTYPE", "log_dir"]).issubset(set(obj)) + json.dumps(obj) # should raise no error + obj = make_serializable(DEFAULTS, drop_unserializable=False, drop_paths=True) + assert isinstance(obj, dict) and set(["RNG", "DTYPE"]).issubset(set(obj)) and "log_dir" not in obj + json.dumps(obj) # should raise no error + obj = make_serializable(DEFAULTS, drop_unserializable=True, drop_paths=True) + assert isinstance(obj, dict) and set(["RNG", "DTYPE", "log_dir"]).intersection(set(obj)) == set() + json.dumps(obj) # should raise no error def test_select_k(): @@ -674,3 +693,20 @@ def test_np_topk(): with pytest.raises(AssertionError, match="dim out of bounds"): np_topk(arr1d, k=1, dim=1) + + +def test_is_pathlike_string(): + assert _is_pathlike_string("abc") is False + assert _is_pathlike_string("abc.txt") is True + assert _is_pathlike_string("/home/abc") is True + assert _is_pathlike_string("C:\\abc") is True + assert _is_pathlike_string("C:/abc") is True + assert _is_pathlike_string("./abc") is True + assert _is_pathlike_string("") is False + assert _is_pathlike_string("~/abc") is True + assert _is_pathlike_string("A:project") is False + assert _is_pathlike_string("README") is False + assert _is_pathlike_string("my.folder") is True + assert _is_pathlike_string(".") is True + assert 
_is_pathlike_string(["abc", "def"]) is False # type: ignore + assert _is_pathlike_string(123) is False # type: ignore diff --git a/test/test_utils/test_utils_data.py b/test/test_utils/test_utils_data.py index 205800c5..1349b69a 100644 --- a/test/test_utils/test_utils_data.py +++ b/test/test_utils/test_utils_data.py @@ -51,6 +51,9 @@ def test_get_mask(): assert intervals == mask_to_intervals(mask[idx], 1) assert (get_mask(5000, np.arange(250, 5000 - 250, 400), 50, 50) == mask[0]).all() + with pytest.raises(ValueError, match="Unknown return_fmt. Expected 'mask' or 'intervals', but got"): + get_mask((12, 5000), np.arange(250, 5000 - 250, 400), 50, 50, return_fmt="xxx") # type: ignore + def test_mask_to_intervals(): mask = np.zeros(100, dtype=int) @@ -126,7 +129,7 @@ def test_rdheader(): with pytest.raises(FileNotFoundError, match="file `not_exist_file\\.hea` not found"): rdheader("not_exist_file") with pytest.raises(TypeError, match="header_data must be str or sequence of str, but got"): - rdheader(1) + rdheader(1) # type: ignore def test_ensure_lead_fmt(): diff --git a/test/test_utils/test_utils_nn.py b/test/test_utils/test_utils_nn.py index ef19e42d..d753f07a 100644 --- a/test/test_utils/test_utils_nn.py +++ b/test/test_utils/test_utils_nn.py @@ -107,9 +107,9 @@ def test_compute_output_shape(): num_filters=num_filters, output_padding=0, channel_last=channel_last, - **conv_kw, + **conv_kw, # type: ignore ) - conv_output_tensor = torch.nn.Conv1d(in_channels, num_filters, **conv_kw)(tensor_first) + conv_output_tensor = torch.nn.Conv1d(in_channels, num_filters, **conv_kw)(tensor_first) # type: ignore if channel_last: conv_output_tensor = conv_output_tensor.permute(0, 2, 1) @@ -130,15 +130,15 @@ def test_compute_output_shape(): num_filters=num_filters, output_padding=0, channel_last=channel_last, - **deconv_kw, + **deconv_kw, # type: ignore ) - deconv_output_tensor = torch.nn.ConvTranspose1d(in_channels, num_filters, **deconv_kw)(tensor_first) + deconv_output_tensor = torch.nn.ConvTranspose1d(in_channels, num_filters, **deconv_kw)(tensor_first) # type: ignore if channel_last: deconv_output_tensor = deconv_output_tensor.permute(0, 2, 1) assert deconv_output_shape == deconv_output_tensor.shape compute_deconv_output_shape( - input_shape=tensor.shape, num_filters=num_filters, output_padding=0, channel_last=channel_last, **deconv_kw + input_shape=tensor.shape, num_filters=num_filters, output_padding=0, channel_last=channel_last, **deconv_kw # type: ignore ) # maxpool @@ -155,9 +155,9 @@ def test_compute_output_shape(): ) for tensor, channel_last in zip([tensor_first, tensor_last], [False, True]): maxpool_output_shape = compute_output_shape( - "maxpool", input_shape=tensor.shape, num_filters=1, output_padding=0, channel_last=channel_last, **maxpool_kw + "maxpool", input_shape=tensor.shape, num_filters=1, output_padding=0, channel_last=channel_last, **maxpool_kw # type: ignore ) - maxpool_output_tensor = torch.nn.MaxPool1d(**maxpool_kw)(tensor_first) + maxpool_output_tensor = torch.nn.MaxPool1d(**maxpool_kw)(tensor_first) # type: ignore if channel_last: maxpool_output_tensor = maxpool_output_tensor.permute(0, 2, 1) @@ -176,9 +176,9 @@ def test_compute_output_shape(): ) for tensor, channel_last in zip([tensor_first, tensor_last], [False, True]): avgpool_output_shape = compute_output_shape( - "avgpool", input_shape=tensor.shape, num_filters=1, output_padding=0, channel_last=channel_last, **avgpool_kw + "avgpool", input_shape=tensor.shape, num_filters=1, output_padding=0, channel_last=channel_last, 
**avgpool_kw # type: ignore ) - avgpool_output_tensor = torch.nn.AvgPool1d(**avgpool_kw)(tensor_first) + avgpool_output_tensor = torch.nn.AvgPool1d(**avgpool_kw)(tensor_first) # type: ignore if channel_last: avgpool_output_tensor = avgpool_output_tensor.permute(0, 2, 1) @@ -186,7 +186,7 @@ def test_compute_output_shape(): shape_1 = compute_output_shape("conv", [None, None, 224, 224], padding=[4, 8]) shape_2 = compute_output_shape("conv", [None, None, 224, 224], padding=[4, 8], asymmetric_padding=[1, 3]) - assert shape_2[2:] == (shape_1[2] + 1 + 3, shape_1[3] + 1 + 3) + assert shape_2[2:] == (shape_1[2] + 1 + 3, shape_1[3] + 1 + 3) # type: ignore shape_1 = compute_output_shape("conv", [None, None, 224, 224], padding=[4, 8]) shape_2 = compute_output_shape( "conv", @@ -194,7 +194,7 @@ def test_compute_output_shape(): padding=[4, 8], asymmetric_padding=[[1, 3], [0, 2]], ) - assert shape_2[2:] == (shape_1[2] + 1 + 3, shape_1[3] + 0 + 2) + assert shape_2[2:] == (shape_1[2] + 1 + 3, shape_1[3] + 0 + 2) # type: ignore shape_1 = compute_output_shape( "conv", @@ -217,7 +217,7 @@ def test_compute_output_shape(): with pytest.raises(AssertionError, match="`num_filters` should be `None` or positive integer"): compute_output_shape("conv", tensor_first.shape, num_filters=-12) with pytest.raises(AssertionError, match="`kernel_size` should contain only positive integers"): - compute_output_shape("conv", tensor_first.shape, kernel_size=2.5) + compute_output_shape("conv", tensor_first.shape, kernel_size=2.5) # type: ignore with pytest.raises(AssertionError, match="`kernel_size` should contain only positive integers"): compute_output_shape("conv", tensor_first.shape, kernel_size=0) with pytest.raises(AssertionError, match="`stride` should contain only positive integers"): @@ -249,7 +249,7 @@ def test_compute_output_shape(): with pytest.raises(ValueError, match="input has 1 dimensions, while `padding` has 2 dimensions,"): compute_output_shape("conv", tensor_first.shape, padding=[1, 1]) with pytest.raises(AssertionError, match="Invalid `asymmetric_padding`"): - compute_output_shape("conv", tensor_first.shape, asymmetric_padding=2) + compute_output_shape("conv", tensor_first.shape, asymmetric_padding=2) # type: ignore with pytest.raises(AssertionError, match="Invalid `asymmetric_padding`"): compute_output_shape( "conv", @@ -314,32 +314,32 @@ def test_default_collate_fn(): batch_data = [ ( - DEFAULTS.RNG.uniform(size=shape_1), - DEFAULTS.RNG.uniform(size=shape_2), - DEFAULTS.RNG.uniform(size=shape_3), + DEFAULTS.RNG.uniform(size=shape_1), # type: ignore + DEFAULTS.RNG.uniform(size=shape_2), # type: ignore + DEFAULTS.RNG.uniform(size=shape_3), # type: ignore ) for _ in range(batch_size) ] tensor_1, tensor_2, tensor_3 = default_collate_fn(batch_data) - assert tensor_1.shape == (batch_size, *shape_1) - assert tensor_2.shape == (batch_size, *shape_2) - assert tensor_3.shape == (batch_size, *shape_3) + assert tensor_1.shape == (batch_size, *shape_1) # type: ignore + assert tensor_2.shape == (batch_size, *shape_2) # type: ignore + assert tensor_3.shape == (batch_size, *shape_3) # type: ignore batch_data = [ dict( - tensor_1=DEFAULTS.RNG.uniform(size=shape_1), - tensor_2=DEFAULTS.RNG.uniform(size=shape_2), - tensor_3=DEFAULTS.RNG.uniform(size=shape_3), + tensor_1=DEFAULTS.RNG.uniform(size=shape_1), # type: ignore + tensor_2=DEFAULTS.RNG.uniform(size=shape_2), # type: ignore + tensor_3=DEFAULTS.RNG.uniform(size=shape_3), # type: ignore ) for _ in range(batch_size) ] tensors = default_collate_fn(batch_data) - assert 
tensors["tensor_1"].shape == (batch_size, *shape_1) - assert tensors["tensor_2"].shape == (batch_size, *shape_2) - assert tensors["tensor_3"].shape == (batch_size, *shape_3) + assert tensors["tensor_1"].shape == (batch_size, *shape_1) # type: ignore + assert tensors["tensor_2"].shape == (batch_size, *shape_2) # type: ignore + assert tensors["tensor_3"].shape == (batch_size, *shape_3) # type: ignore with pytest.raises(ValueError, match="Invalid batch"): - default_collate_fn([1]) + default_collate_fn([1]) # type: ignore with pytest.raises(ValueError, match="No data"): default_collate_fn([tuple()]) @@ -516,17 +516,40 @@ def test_mixin_classes(): assert isinstance(model_1d.dtype_, str) assert isinstance(model_1d.device_, str) + # test pth/pt file save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_mixin.pth" + # convert save_path to bytes to cover bytes path handling code + save_path = str(save_path).encode() + model_1d.save(save_path, CFG(dict(n_leads=12)), extra_items={"xxx": {"ones": torch.ones((2, 2))}}, use_safetensors=False) + assert save_path.is_file() + loaded_model, _ = Model1D.from_checkpoint(save_path) + assert repr(model_1d) == repr(loaded_model) + save_path = Path(save_path.decode()) + save_path.unlink() - model_1d.save(save_path, dict(n_leads=12)) + with pytest.warns(RuntimeWarning, match="`safetensors` is used by default."): + model_1d.save(save_path, CFG(dict(n_leads=12)), extra_items={"xxx": {"ones": torch.ones((2, 2))}}) + assert save_path.with_suffix(".safetensors").is_file() + save_path.with_suffix(".safetensors").unlink() + # test single safetensors file + save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_mixin.safetensors" + model_1d.save(save_path, CFG(dict(n_leads=12)), extra_items={"xxx": {"ones": torch.ones((2, 2))}}) assert save_path.is_file() - loaded_model, _ = Model1D.from_checkpoint(save_path) assert repr(model_1d) == repr(loaded_model) - save_path.unlink() + # test directory with safetensors + save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_mixin.safetensors" + model_1d.save( + save_path, CFG(dict(n_leads=12)), extra_items={"xxx": {"ones": torch.ones((2, 2))}}, safetensors_single_file=False + ) + assert save_path.with_suffix("").is_dir() + loaded_model, _ = Model1D.from_checkpoint(save_path.with_suffix("")) + assert repr(model_1d) == repr(loaded_model) + shutil.rmtree(save_path.with_suffix("")) + # test remote un-compressed model save_path = Path(__file__).resolve().parents[1] / "tmp" / "test_remote_model" save_path.mkdir(exist_ok=True, parents=True) diff --git a/torch_ecg/_preprocessors/bandpass.py b/torch_ecg/_preprocessors/bandpass.py index a371386a..5e81a796 100644 --- a/torch_ecg/_preprocessors/bandpass.py +++ b/torch_ecg/_preprocessors/bandpass.py @@ -1,9 +1,8 @@ """BandPass filter preprocessor.""" -from numbers import Real -from typing import Any, List, Literal, Optional, Tuple +from typing import Any, List, Literal, Optional, Tuple, Union -import numpy as np +from numpy.typing import NDArray from .base import PreProcessor, preprocess_multi_lead_signal @@ -17,9 +16,9 @@ class BandPass(PreProcessor): Parameters ---------- - lowcut : numbers.Real, optional + lowcut : int or float, optional Low cutoff frequency - highcut : numbers.Real, optional + highcut : int or float, optional High cutoff frequency. filter_type : {"butter", "fir"}, , default "butter" Type of the bandpass filter. 
@@ -43,8 +42,8 @@ class BandPass(PreProcessor): def __init__( self, - lowcut: Optional[Real] = 0.5, - highcut: Optional[Real] = 45, + lowcut: Optional[Union[int, float]] = 0.5, + highcut: Optional[Union[int, float]] = 45, filter_type: Literal["butter", "fir"] = "butter", filter_order: Optional[int] = None, **kwargs: Any, @@ -59,17 +58,17 @@ def __init__( self.filter_type = filter_type self.filter_order = filter_order - def apply(self, sig: np.ndarray, fs: int) -> Tuple[np.ndarray, int]: + def apply(self, sig: NDArray, fs: Union[int, float]) -> Tuple[NDArray, Union[int, float]]: """Apply the preprocessor to `sig`. Parameters ---------- sig : numpy.ndarray The ECG signal, can be - - 1d array, which is a single-lead ECG; - - 2d array, which is a multi-lead ECG of "lead_first" format; - - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. - fs : int + - 1d array, which is a single-lead ECG; + - 2d array, which is a multi-lead ECG of "lead_first" format; + - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. + fs : int or float Sampling frequency of the ECG signal. Returns @@ -84,8 +83,8 @@ def apply(self, sig: np.ndarray, fs: int) -> Tuple[np.ndarray, int]: filtered_sig = preprocess_multi_lead_signal( raw_sig=sig, fs=fs, - band_fs=[self.lowcut, self.highcut], - filter_type=self.filter_type, + band_fs=[self.lowcut, self.highcut], # type: ignore + filter_type=self.filter_type, # type: ignore filter_order=self.filter_order, ) return filtered_sig, fs diff --git a/torch_ecg/_preprocessors/base.py b/torch_ecg/_preprocessors/base.py index ca74241d..ae2f0999 100644 --- a/torch_ecg/_preprocessors/base.py +++ b/torch_ecg/_preprocessors/base.py @@ -2,11 +2,11 @@ from abc import ABC, abstractmethod from itertools import repeat -from numbers import Real -from typing import List, Literal, Optional, Tuple +from typing import List, Literal, Optional, Tuple, Union import numpy as np from biosppy.signals.tools import filter_signal +from numpy.typing import NDArray from scipy.ndimage import median_filter from ..cfg import DEFAULTS @@ -30,37 +30,37 @@ class PreProcessor(ReprMixin, ABC): __name__ = "PreProcessor" @abstractmethod - def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]: + def apply(self, sig: NDArray, fs: Union[int, float]) -> Tuple[NDArray, Union[int, float]]: """Apply the preprocessor to `sig`. Parameters ---------- sig : numpy.ndarray The ECG signal, can be - - 1d array, which is a single-lead ECG; - - 2d array, which is a multi-lead ECG of "lead_first" format; - - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. - fs : numbers.Real + - 1d array, which is a single-lead ECG; + - 2d array, which is a multi-lead ECG of "lead_first" format; + - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. + fs : int or float Sampling frequency of the ECG signal. """ raise NotImplementedError - @add_docstring(apply) - def __call__(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]: + @add_docstring(apply.__doc__) # type: ignore + def __call__(self, sig: NDArray, fs: Union[int, float]) -> Tuple[NDArray, Union[int, float]]: """alias of :meth:`self.apply`.""" return self.apply(sig, fs) - def _check_sig(self, sig: np.ndarray) -> None: + def _check_sig(self, sig: NDArray) -> None: """Check validity of the signal. 
Parameters ---------- sig : numpy.ndarray The ECG signal, can be - - 1d array, which is a single-lead ECG; - - 2d array, which is a multi-lead ECG of "lead_first" format; - - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. + - 1d array, which is a single-lead ECG; + - 2d array, which is a multi-lead ECG of "lead_first" format; + - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. """ if sig.ndim not in [1, 2, 3]: @@ -73,14 +73,14 @@ def _check_sig(self, sig: np.ndarray) -> None: def preprocess_multi_lead_signal( - raw_sig: np.ndarray, - fs: Real, + raw_sig: NDArray, + fs: Union[int, float], sig_fmt: Literal["channel_first", "lead_first", "channel_last", "lead_last"] = "channel_first", - bl_win: Optional[List[Real]] = None, - band_fs: Optional[List[Real]] = None, + bl_win: Optional[List[Union[int, float]]] = None, + band_fs: Optional[List[Union[int, float]]] = None, filter_type: Literal["butter", "fir"] = "butter", filter_order: Optional[int] = None, -) -> np.ndarray: +) -> NDArray: """Perform preprocessing for multi-lead ECG signal (with units in mV). preprocessing may include median filter, bandpass filter, and rpeaks detection, etc. @@ -90,19 +90,19 @@ def preprocess_multi_lead_signal( ---------- raw_sig : numpy.ndarray The raw ECG signal, with units in mV. - fs : numbers.Real + fs : int or float Sampling frequency of `raw_sig`. sig_fmt : str, default "channel_first" Format of the multi-lead ECG signal, "channel_last" (alias "lead_last"), or "channel_first" (alias "lead_first"). - bl_win : List[numbers.Real], optional + bl_win : List[Union[int, float]], optional Window (units in second) of baseline removal using :meth:`~scipy.ndimage.median_filter`, the first is the shorter one, the second the longer one, a typical pair is ``[0.2, 0.6]``. If is None or empty, baseline removal will not be performed. - band_fs : List[numbers.Real], optional + band_fs : List[Union[int, float]], optional Frequency band of the bandpass filter, a typical pair is ``[0.5, 45]``. Be careful when detecting paced rhythm. @@ -130,9 +130,9 @@ def preprocess_multi_lead_signal( ], f"multi-lead signal format `{sig_fmt}` not supported" if sig_fmt.lower() in ["channel_last", "lead_last"]: # might have a batch dimension at the first axis - filtered_ecg = np.moveaxis(raw_sig, -2, -1).astype(DEFAULTS.np_dtype) + filtered_ecg = np.moveaxis(raw_sig, -2, -1).astype(DEFAULTS.np_dtype) # type: ignore else: - filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype) + filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype) # type: ignore # remove baseline if bl_win: @@ -186,13 +186,13 @@ def preprocess_multi_lead_signal( def preprocess_single_lead_signal( - raw_sig: np.ndarray, - fs: Real, - bl_win: Optional[List[Real]] = None, - band_fs: Optional[List[Real]] = None, + raw_sig: NDArray, + fs: Union[int, float], + bl_win: Optional[List[Union[int, float]]] = None, + band_fs: Optional[List[Union[int, float]]] = None, filter_type: Literal["butter", "fir"] = "butter", filter_order: Optional[int] = None, -) -> np.ndarray: +) -> NDArray: """Perform preprocessing for single lead ECG signal (with units in mV). Preprocessing may include median filter, bandpass filter, and rpeaks detection, etc. @@ -201,15 +201,15 @@ def preprocess_single_lead_signal( ---------- raw_sig : numpy.ndarray Raw ECG signal, with units in mV. - fs : numbers.Real + fs : int or float Sampling frequency of `raw_sig`. 
- bl_win : list (of 2 numbers.Real), optional + bl_win : list (of 2 int or float), optional Window (units in second) of baseline removal using :meth:`~scipy.ndimage.median_filter`, the first is the shorter one, the second the longer one, a typical pair is ``[0.2, 0.6]``. If is None or empty, baseline removal will not be performed. - band_fs : list of numbers.Real, optional + band_fs : list of int or float, optional Frequency band of the bandpass filter, a typical pair is ``[0.5, 45]``. Be careful when detecting paced rhythm. @@ -230,7 +230,7 @@ def preprocess_single_lead_signal( e.g. :meth:`~torch_ecg.utils.butter_bandpass_filter`. """ - filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype) + filtered_ecg = np.asarray(raw_sig, dtype=DEFAULTS.np_dtype) # type: ignore assert filtered_ecg.ndim == 1, "single-lead signal should be 1d array" # remove baseline diff --git a/torch_ecg/_preprocessors/baseline_remove.py b/torch_ecg/_preprocessors/baseline_remove.py index 1a46ed47..129d2fbb 100644 --- a/torch_ecg/_preprocessors/baseline_remove.py +++ b/torch_ecg/_preprocessors/baseline_remove.py @@ -4,10 +4,9 @@ """ import warnings -from numbers import Real -from typing import Any, List, Tuple +from typing import Any, List, Tuple, Union -import numpy as np +from numpy.typing import NDArray from .base import PreProcessor, preprocess_multi_lead_signal @@ -46,24 +45,24 @@ def __init__(self, window1: float = 0.2, window2: float = 0.6, **kwargs: Any) -> self.window1, self.window2 = self.window2, self.window1 warnings.warn("values of `window1` and `window2` are switched", RuntimeWarning) - def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]: + def apply(self, sig: NDArray, fs: Union[float, int]) -> Tuple[NDArray, Union[float, int]]: """Apply the preprocessor to `sig`. Parameters ---------- sig : numpy.ndarray The ECG signal, can be - - 1d array, which is a single-lead ECG; - - 2d array, which is a multi-lead ECG of "lead_first" format; - - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. - fs : numbers.Real + - 1d array, which is a single-lead ECG; + - 2d array, which is a multi-lead ECG of "lead_first" format; + - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. + fs : float or int Sampling frequency of the ECG signal. Returns ------- filtered_sig : :class:`numpy.ndarray` The median filtered (hence baseline removed) ECG signal. - fs : :class:`int` + fs : float or int Sampling frequency of the filtered ECG signal. """ diff --git a/torch_ecg/_preprocessors/normalize.py b/torch_ecg/_preprocessors/normalize.py index dfc8e6d7..bcbca694 100644 --- a/torch_ecg/_preprocessors/normalize.py +++ b/torch_ecg/_preprocessors/normalize.py @@ -3,7 +3,7 @@ from numbers import Real from typing import Any, List, Literal, Tuple, Union -import numpy as np +from numpy.typing import NDArray from ..cfg import DEFAULTS from ..utils.utils_signal import normalize @@ -39,11 +39,11 @@ class Normalize(PreProcessor): ---------- method : {"naive", "min-max", "z-score"}, default "z-score" Normalization method, case insensitive. - mean : numbers.Real or numpy.ndarray, default 0.0 + mean : float or int or numpy.ndarray, default 0.0 Mean value of the normalized signal, or mean values for each lead of the normalized signal. Useless if `method` is "min-max". 
- std : numbers.Real or numpy.ndarray, default 1.0 + std : float or int or numpy.ndarray, default 1.0 Standard deviation of the normalized signal, or standard deviations for each lead of the normalized signal. Useless if `method` is "min-max". @@ -66,8 +66,8 @@ class Normalize(PreProcessor): def __init__( self, method: Literal["naive", "min-max", "z-score"] = "z-score", - mean: Union[Real, np.ndarray] = 0.0, - std: Union[Real, np.ndarray] = 1.0, + mean: Union[float, int, NDArray] = 0.0, + std: Union[float, int, NDArray] = 1.0, per_channel: bool = False, **kwargs: Any, ) -> None: @@ -89,17 +89,17 @@ def __init__( std, Real ), "mean and std should be real numbers in the non per-channel setting" - def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]: + def apply(self, sig: NDArray, fs: Union[float, int]) -> Tuple[NDArray, Union[float, int]]: """Apply the preprocessor to `sig`. Parameters ---------- sig : numpy.ndarray The ECG signal, can be - - 1d array, which is a single-lead ECG; - - 2d array, which is a multi-lead ECG of "lead_first" format; - - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. - fs : numbers.Real + - 1d array, which is a single-lead ECG; + - 2d array, which is a multi-lead ECG of "lead_first" format; + - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. + fs : float or int Sampling frequency of the ECG signal. **NOT** used currently. @@ -114,7 +114,7 @@ def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]: self._check_sig(sig) normalized_sig = normalize( sig=sig.astype(DEFAULTS.np_dtype), - method=self.method, + method=self.method, # type: ignore mean=self.mean, std=self.std, sig_fmt="channel_first", @@ -181,9 +181,9 @@ class NaiveNormalize(Normalize): Parameters ---------- - mean : numbers.Real or numpy.ndarray, default 0.0 + mean : float or int or numpy.ndarray, default 0.0 Value(s) to be subtracted. - std : numbers.Real or numpy.ndarray, default 1.0 + std : float or int or numpy.ndarray, default 1.0 Value(s) to be divided. per_channel : bool, default False If True, normalization will be done per channel. @@ -203,8 +203,8 @@ class NaiveNormalize(Normalize): def __init__( self, - mean: Union[Real, np.ndarray] = 0.0, - std: Union[Real, np.ndarray] = 1.0, + mean: Union[float, int, NDArray] = 0.0, + std: Union[float, int, NDArray] = 1.0, per_channel: bool = False, **kwargs: Any, ) -> None: @@ -234,10 +234,10 @@ class ZScoreNormalize(Normalize): Parameters ---------- - mean : numbers.Real or numpy.ndarray, default 0.0 + mean : float or int or numpy.ndarray, default 0.0 Mean value of the normalized signal, or mean values for each lead of the normalized signal. - std : numbers.Real or numpy.ndarray, default 1.0 + std : float or int or numpy.ndarray, default 1.0 Standard deviation of the normalized signal, or standard deviations for each lead of the normalized signal. 
per_channel : bool, default False @@ -258,8 +258,8 @@ class ZScoreNormalize(Normalize): def __init__( self, - mean: Union[Real, np.ndarray] = 0.0, - std: Union[Real, np.ndarray] = 1.0, + mean: Union[float, int, NDArray] = 0.0, + std: Union[float, int, NDArray] = 1.0, per_channel: bool = False, **kwargs: Any, ) -> None: diff --git a/torch_ecg/_preprocessors/preproc_manager.py b/torch_ecg/_preprocessors/preproc_manager.py index 95946643..faf9f22e 100644 --- a/torch_ecg/_preprocessors/preproc_manager.py +++ b/torch_ecg/_preprocessors/preproc_manager.py @@ -4,7 +4,7 @@ from random import sample from typing import List, Optional, Tuple -import numpy as np +from numpy.typing import NDArray from ..utils.misc import ReprMixin from .bandpass import BandPass @@ -101,7 +101,7 @@ def _add_resample(self, **config: dict) -> None: """ self._preprocessors.append(Resample(**config)) - def __call__(self, sig: np.ndarray, fs: int) -> Tuple[np.ndarray, int]: + def __call__(self, sig: NDArray, fs: int) -> Tuple[NDArray, int]: """The main function of the manager, which applies the preprocessors Parameters diff --git a/torch_ecg/_preprocessors/resample.py b/torch_ecg/_preprocessors/resample.py index e58d43e7..d4f67187 100644 --- a/torch_ecg/_preprocessors/resample.py +++ b/torch_ecg/_preprocessors/resample.py @@ -1,10 +1,9 @@ """Resample the signal into fixed sampling frequency or length.""" -from numbers import Real -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union -import numpy as np import scipy.signal as SS +from numpy.typing import NDArray from ..cfg import DEFAULTS from .base import PreProcessor @@ -46,18 +45,17 @@ def __init__(self, fs: Optional[int] = None, siglen: Optional[int] = None, **kwa self.siglen = siglen assert sum([bool(self.fs), bool(self.siglen)]) == 1, "one and only one of `fs` and `siglen` should be set" - def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]: + def apply(self, sig: NDArray, fs: Union[float, int]) -> Tuple[NDArray, Union[float, int]]: """Apply the preprocessor to `sig`. Parameters ---------- sig : numpy.ndarray The ECG signal, can be - - - 1d array, which is a single-lead ECG; - - 2d array, which is a multi-lead ECG of "lead_first" format; - - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. - fs : numbers.Real + - 1d array, which is a single-lead ECG; + - 2d array, which is a multi-lead ECG of "lead_first" format; + - 3d array, which is a tensor of several ECGs, of shape ``(batch, lead, siglen)``. + fs : float or int Sampling frequency of the ECG signal. 
Returns @@ -75,7 +73,7 @@ def apply(self, sig: np.ndarray, fs: Real) -> Tuple[np.ndarray, int]: else: # self.siglen is not None rsmp_sig = SS.resample(sig.astype(DEFAULTS.np_dtype), num=self.siglen, axis=-1) new_fs = int(round(self.siglen / sig.shape[-1] * fs)) - return rsmp_sig, new_fs + return rsmp_sig, new_fs # type: ignore def extra_repr_keys(self) -> List[str]: return [ diff --git a/torch_ecg/augmenters/baseline_wander.py b/torch_ecg/augmenters/baseline_wander.py index 0a4bd186..57e48dcb 100644 --- a/torch_ecg/augmenters/baseline_wander.py +++ b/torch_ecg/augmenters/baseline_wander.py @@ -2,12 +2,12 @@ import multiprocessing as mp from itertools import repeat -from numbers import Real from random import randint from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np import torch +from numpy.typing import NDArray from torch import Tensor from ..cfg import DEFAULTS @@ -88,9 +88,9 @@ class BaselineWanderAugmenter(Augmenter): def __init__( self, fs: int, - bw_fs: Optional[np.ndarray] = None, - ampl_ratio: Optional[np.ndarray] = None, - gaussian: Optional[np.ndarray] = None, + bw_fs: Optional[NDArray] = None, + ampl_ratio: Optional[NDArray] = None, + gaussian: Optional[NDArray] = None, prob: float = 0.5, inplace: bool = True, **kwargs: Any, @@ -186,8 +186,8 @@ def forward( if not self.inplace: sig = sig.clone() if self.prob > 0: - sig.add_(gen_baseline_wander(sig, self.fs, self.bw_fs, self.ampl_ratio, self.gaussian)) - return (sig, label, *extra_tensors) + sig.add_(gen_baseline_wander(sig, self.fs, self.bw_fs, self.ampl_ratio, self.gaussian)) # type: ignore + return (sig, label, *extra_tensors) # type: ignore def extra_repr_keys(self) -> List[str]: return [ @@ -223,7 +223,7 @@ def _get_ampl(sig: Tensor, fs: int) -> Tensor: return ampl -def _gen_gaussian_noise(siglen: int, mean: Real = 0, std: Real = 0) -> np.ndarray: +def _gen_gaussian_noise(siglen: int, mean: Union[float, int] = 0, std: Union[float, int] = 0) -> NDArray: """Generate 1d Gaussian noise of given length, mean, and standard deviation. @@ -231,9 +231,9 @@ def _gen_gaussian_noise(siglen: int, mean: Real = 0, std: Real = 0) -> np.ndarra ---------- siglen : int Length of the noise signal. - mean : numbers.Real, default 0 + mean : float or int, default 0 Mean value of the noise. - std : numbers.Real, default 0 + std : float or int, default 0 Standard deviation of the noise. Returns @@ -248,12 +248,12 @@ def _gen_gaussian_noise(siglen: int, mean: Real = 0, std: Real = 0) -> np.ndarra def _gen_sinusoidal_noise( siglen: int, - start_phase: Real, - end_phase: Real, - amplitude: Real, - amplitude_mean: Real = 0, - amplitude_std: Real = 0, -) -> np.ndarray: + start_phase: Union[float, int], + end_phase: Union[float, int], + amplitude: Union[float, int], + amplitude_mean: Union[float, int] = 0, + amplitude_std: Union[float, int] = 0, +) -> NDArray: """Generate 1d sinusoidal noise of given length, amplitude, start phase, and end phase. @@ -261,15 +261,15 @@ def _gen_sinusoidal_noise( ---------- siglen : int Length of the (noise) signal. - start_phase : numbers.Real + start_phase : float or int Start phase, with units in degrees. - end_phase : numbers.Real + end_phase : float or int End phase, with units in degrees. - amplitude : numbers.Real + amplitude : float or int Amplitude of the sinusoidal curve. - amplitude_mean : numbers.Real + amplitude_mean : float or int Mean amplitude of an extra Gaussian noise. 
- amplitude_std : numbers.Real, default 0 + amplitude_std : float or int, default 0 Standard deviation of an extra Gaussian noise. Returns @@ -286,11 +286,11 @@ def _gen_sinusoidal_noise( def _gen_baseline_wander( siglen: int, - fs: Real, - bw_fs: Union[Real, Sequence[Real]], - amplitude: Union[Real, Sequence[Real]], - amplitude_gaussian: Sequence[Real] = [0, 0], -) -> np.ndarray: + fs: Union[float, int], + bw_fs: Union[float, int, Sequence[Union[float, int]]], + amplitude: Union[float, int, Sequence[Union[float, int]]], + amplitude_gaussian: Sequence[Union[float, int]] = [0, 0], +) -> NDArray: """Generate 1d baseline wander of given length, amplitude, and frequency. @@ -298,14 +298,14 @@ def _gen_baseline_wander( ---------- siglen : int Length of the (noise) signal. - fs : numbers.Real + fs : float or int Sampling frequency of the original signal. - bw_fs : numbers.Real, or list of numbers.Real + bw_fs : float or int, or list of float or int Frequency (Frequencies) of the baseline wander. - amplitude : numbers.Real, or list of numbers.Real + amplitude : float or int, or list of float or int Amplitude of the baseline wander (corr. to each frequency band). - amplitude_gaussian : Tuple[numbers.Real], default [0,0] - 2-tuple of :class:`~numbers.Real`. + amplitude_gaussian : Tuple[float or int], default [0,0] + 2-tuple of float or int. Mean and std of amplitude of an extra Gaussian noise. Returns @@ -319,11 +319,11 @@ def _gen_baseline_wander( """ bw = _gen_gaussian_noise(siglen, amplitude_gaussian[0], amplitude_gaussian[1]) - if isinstance(bw_fs, Real): + if isinstance(bw_fs, (int, float)): _bw_fs = [bw_fs] else: _bw_fs = bw_fs - if isinstance(amplitude, Real): + if isinstance(amplitude, (int, float)): _amplitude = list(repeat(amplitude, len(_bw_fs))) else: _amplitude = amplitude @@ -338,11 +338,11 @@ def _gen_baseline_wander( def gen_baseline_wander( sig: Tensor, - fs: Real, - bw_fs: Union[Real, Sequence[Real]], - ampl_ratio: np.ndarray, - gaussian: np.ndarray, -) -> np.ndarray: + fs: Union[float, int], + bw_fs: Union[float, int, Sequence[Union[float, int]]], + ampl_ratio: NDArray, + gaussian: NDArray, +) -> Tensor: """Generate 1d baseline wander of given length, amplitude, and frequency. @@ -350,9 +350,9 @@ def gen_baseline_wander( ---------- sig : torch.Tensor Batched ECGs to be augmented, of shape (batch, lead, siglen). - fs : numbers.Real + fs : float or int Sampling frequency of the original signal. - bw_fs : numbers.Real, or list of numbers.Real, + bw_fs : float or int, or list of float or int, Frequency (Frequencies) of the baseline wander. ampl_ratio : numpy.ndarray, optional Candidate ratios of noise amplitudes compared to the original ECGs for each `fs`, @@ -363,7 +363,7 @@ Returns ------- - bw : numpy.ndarray + bw : torch.Tensor Baseline wander of given length, amplitude, frequency, of shape ``(batch, lead, siglen)``.
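The baseline_wander.py hunks above retype `BaselineWanderAugmenter` and the `gen_baseline_wander` helpers from `numbers.Real` / `np.ndarray` to `Union[float, int]` / `NDArray`, with `gen_baseline_wander` now returning a `Tensor`. Below is a minimal usage sketch of the augmenter under the constructor signature shown above; the sampling frequency, wander frequencies, shapes, and label layout are illustrative assumptions, not values taken from this patch.

import numpy as np
import torch

from torch_ecg.augmenters.baseline_wander import BaselineWanderAugmenter

# Illustrative configuration (assumed values); passing None for bw_fs / ampl_ratio / gaussian
# falls back to the module defaults.
blw = BaselineWanderAugmenter(
    fs=500,                                   # sampling frequency of the ECGs in Hz (assumed)
    bw_fs=np.array([0.33, 0.1, 0.05, 0.01]),  # baseline-wander frequencies in Hz (assumed)
    prob=0.7,                                 # probability of augmenting each sample
    inplace=False,                            # keep the input tensor untouched
)

sig = torch.randn(2, 12, 5000)                # dummy ECGs of shape (batch, lead, siglen)
label = torch.randint(0, 2, (2, 26)).float()  # dummy labels, passed through unchanged
aug_sig, aug_label = blw(sig, label)          # forward returns (sig, label, *extra_tensors)
assert aug_sig.shape == sig.shape             # wander is added sample-wise, so the shape is preserved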
diff --git a/torch_ecg/augmenters/cutmix.py b/torch_ecg/augmenters/cutmix.py index c0c7edf3..60bd98f2 100644 --- a/torch_ecg/augmenters/cutmix.py +++ b/torch_ecg/augmenters/cutmix.py @@ -1,12 +1,12 @@ """ """ from copy import deepcopy -from numbers import Real from random import shuffle from typing import Any, List, Optional, Sequence, Tuple import numpy as np import torch +from numpy.typing import NDArray from torch import Tensor from ..cfg import DEFAULTS @@ -67,8 +67,8 @@ def __init__( self, fs: Optional[int] = None, num_mix: int = 1, - alpha: Real = 0.5, - beta: Optional[Real] = None, + alpha: float = 0.5, + beta: Optional[float] = None, prob: float = 0.5, inplace: bool = True, **kwargs: Any, @@ -174,7 +174,7 @@ def extra_repr_keys(self) -> List[str]: ] + super().extra_repr_keys() -def _make_intervals(lam: Tensor, siglen: int) -> np.ndarray: +def _make_intervals(lam: Tensor, siglen: int) -> NDArray: """Make intervals for cutmix. Parameters diff --git a/torch_ecg/augmenters/stretch_compress.py b/torch_ecg/augmenters/stretch_compress.py index d6bb8af7..1337dc0e 100644 --- a/torch_ecg/augmenters/stretch_compress.py +++ b/torch_ecg/augmenters/stretch_compress.py @@ -8,6 +8,7 @@ import scipy.signal as SS import torch import torch.nn.functional as F +from numpy.typing import NDArray from torch import Tensor from ..cfg import DEFAULTS @@ -383,10 +384,10 @@ def __init__( def generate( self, seglen: int, - sig: np.ndarray, - *labels: Sequence[np.ndarray], + sig: NDArray, + *labels: Sequence[NDArray], critical_points: Optional[Sequence[int]] = None, - ) -> List[Tuple[Union[np.ndarray, int], ...]]: + ) -> List[Tuple[Union[NDArray, int], ...]]: """Generate stretched or compressed segments from the ECGs. Parameters @@ -465,11 +466,11 @@ def generate( def __generate_segment( self, seglen: int, - sig: np.ndarray, - *labels: Sequence[np.ndarray], + sig: NDArray, + *labels: Sequence[NDArray], start_idx: Optional[int] = None, end_idx: Optional[int] = None, - ) -> Tuple[Union[np.ndarray, int], ...]: + ) -> Tuple[Union[NDArray, int], ...]: """Internal function to generate a stretched or compressed segment. Parameters @@ -564,10 +565,10 @@ def _sample_ratio(self) -> float: def __call__( self, seglen: int, - sig: np.ndarray, - *labels: Sequence[np.ndarray], + sig: NDArray, + *labels: Sequence[NDArray], critical_points: Optional[Sequence[int]] = None, - ) -> List[Tuple[np.ndarray, ...]]: + ) -> List[Tuple[NDArray, ...]]: return self.generate(seglen, sig, *labels, critical_points=critical_points) def extra_repr_keys(self) -> List[str]: diff --git a/torch_ecg/cfg.py b/torch_ecg/cfg.py index 0ffaed0a..70037ff0 100644 --- a/torch_ecg/cfg.py +++ b/torch_ecg/cfg.py @@ -84,7 +84,7 @@ def __setattr__(self, name: str, value: Any) -> None: __setitem__ = __setattr__ - def update(self, new_cfg: Optional[MutableMapping] = None, **kwargs: Any) -> None: + def update(self, new_cfg: Optional[MutableMapping] = None, **kwargs: Any) -> None: # type: ignore """The new hierarchical update method. 
Parameters @@ -157,9 +157,9 @@ class DTYPE: """ STR: str - NP: np.dtype = None - TORCH: torch.dtype = None - INT: int = None # int representation of the dtype, mainly used for `wfdb.rdrecord` + NP: np.dtype = None # type: ignore + TORCH: torch.dtype = None # type: ignore + INT: int = None # int representation of the dtype, mainly used for `wfdb.rdrecord` # type: ignore def __post_init__(self) -> None: """check consistency""" @@ -168,12 +168,12 @@ def __post_init__(self) -> None: if self.TORCH is None: self.TORCH = eval(f"torch.{self.STR}") if self.INT is None: - self.INT = int(re.search("\\d+", self.STR).group(0)) + self.INT = int(re.search("\\d+", self.STR).group(0)) # type: ignore assert all( [ self.NP == getattr(np, self.STR), self.TORCH == getattr(torch, self.STR), - self.INT == int(re.search("\\d+", self.STR).group(0)), + self.INT == int(re.search("\\d+", self.STR).group(0)), # type: ignore ] ), "inconsistent dtype" diff --git a/torch_ecg/components/inputs.py b/torch_ecg/components/inputs.py index fc348c03..12f07478 100644 --- a/torch_ecg/components/inputs.py +++ b/torch_ecg/components/inputs.py @@ -6,9 +6,9 @@ from copy import deepcopy from typing import List, Literal, Sequence, Tuple, Union -import numpy as np import torch from einops.layers.torch import Rearrange +from numpy.typing import NDArray from torch.nn import functional as F from ..cfg import CFG, DEFAULTS @@ -109,7 +109,7 @@ def __init__(self, config: InputConfig) -> None: self._device = self._config.get("device", DEFAULTS.device) self._post_init() - def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def __call__(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: """Method to transform the waveform to the input tensor. Parameters @@ -126,11 +126,11 @@ def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: return self.from_waveform(waveform) @abstractmethod - def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def _from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: """Internal method to convert the waveform to the input tensor.""" raise NotImplementedError - def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: """Transform the waveform to the input tensor. Parameters @@ -289,12 +289,12 @@ def _post_init(self) -> None: """Make sure the input type is `waveform`.""" assert self.input_type == "waveform", "`input_type` must be `waveform`" - def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def _from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: """Internal method to convert the waveform to the input tensor.""" self._values = torch.as_tensor(waveform).to(self.device, self.dtype) return self._values - def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: """Converts the input :class:`~numpy.ndarray` or :class:`~torch.Tensor` waveform to a :class:`~torch.Tensor`. 
@@ -320,7 +320,7 @@ def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tens return super().from_waveform(waveform) @add_docstring(from_waveform.__doc__) - def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def __call__(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: """ """ return self.from_waveform(waveform) @@ -329,16 +329,18 @@ class FFTInput(BaseInput): """Inputs from the FFT, via concatenating the amplitudes and the phases. One can set the following optional parameters for initialization: - - nfft: int - the number of FFT bins. - If nfft is None, the number of FFT bins is computed from the input shape. - - drop_dc: bool, default True - Whether to drop the zero frequency bin (the DC component). - - norm: str, optional - The normalization of the FFT, can be - - "forward" - - "backward" - - "ortho" + + - nfft: int + the number of FFT bins. + If nfft is None, the number of FFT bins is computed from the input shape. + - drop_dc: bool, default True + Whether to drop the zero frequency bin (the DC component). + - norm: str, optional + The normalization of the FFT, can be + + - "forward" + - "backward" + - "ortho" Examples -------- @@ -400,7 +402,7 @@ def _post_init(self) -> None: "ortho", ], f"`norm` must be one of [`forward`, `backward`, `ortho`], got {self.norm}" - def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def _from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: """Internal method to convert the waveform to the input tensor.""" self._values = torch.fft.rfft( torch.as_tensor(waveform).to(self.device, self.dtype), @@ -413,7 +415,7 @@ def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Ten self._values = torch.cat([torch.abs(self._values), torch.angle(self._values)], dim=-2) return self._values - def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: """Converts the input :class:`~numpy.ndarray` or :class:`~torch.Tensor` waveform to a :class:`~torch.Tensor` of FFTs. @@ -441,7 +443,7 @@ def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tens return super().from_waveform(waveform) @add_docstring(from_waveform.__doc__) - def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def __call__(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: return self.from_waveform(waveform) def extra_repr_keys(self) -> List[str]: @@ -452,25 +454,26 @@ class _SpectralInput(BaseInput): """Inputs from the spectro-temporal domain. One has to set the following parameters for initialization: - - n_bins : int - The number of frequency bins. - - fs (or sample_rate) : int - The sample rate of the waveform. + + - n_bins : int + The number of frequency bins. + - fs (or sample_rate) : int + The sample rate of the waveform. with the following optional parameters with default values: - - window_size : float, default: 1 / 20 - The size of the window in seconds. - - overlap_size : float, default: 1 / 40 - The overlap of the windows in seconds. - - feature_fs : None or float, - The sample rate of the features. - If specified, the features will be resampled - against `fs` to this sample rate. - - to1d : bool, default False - Whether to convert the features to 1D. 
- NOTE that if `to1d` is True, - then if the convolutions with ``groups=1`` applied to the `input` - acts on all the bins, which is "global" - w.r.t. the `bins` dimension of the corresponding 2d input. + - window_size : float, default: 1 / 20 + The size of the window in seconds. + - overlap_size : float, default: 1 / 40 + The overlap of the windows in seconds. + - feature_fs : None or float, + The sample rate of the features. + If specified, the features will be resampled + against `fs` to this sample rate. + - to1d : bool, default False + Whether to convert the features to 1D. + NOTE that if `to1d` is True, + then if the convolutions with ``groups=1`` applied to the `input` + acts on all the bins, which is "global" + w.r.t. the `bins` dimension of the corresponding 2d input. """ @@ -589,7 +592,7 @@ def _post_init(self) -> None: Rearrange("... channel n_bins time -> ... (channel n_bins) time"), ) - def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def _from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: """Internal method to convert the waveform to the input tensor.""" self._values = self._transform(torch.as_tensor(waveform).to(self.device, self.dtype)) if self.feature_fs is not None: @@ -605,7 +608,7 @@ def _from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Ten self._values = F.interpolate(self._values, scale_factor=scale_factor, recompute_scale_factor=True) return self._values - def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def from_waveform(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: r"""Converts the input :class:`~numpy.ndarray` or :class:`~torch.Tensor` waveform to a :class:`~torch.Tensor` of spectrograms. @@ -635,5 +638,5 @@ def from_waveform(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tens return super().from_waveform(waveform) @add_docstring(from_waveform.__doc__) - def __call__(self, waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + def __call__(self, waveform: Union[NDArray, torch.Tensor]) -> torch.Tensor: return self.from_waveform(waveform) diff --git a/torch_ecg/components/metrics.py b/torch_ecg/components/metrics.py index 2c7fcae1..fa3094e5 100644 --- a/torch_ecg/components/metrics.py +++ b/torch_ecg/components/metrics.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union import numpy as np +from numpy.typing import NDArray from torch import Tensor from ..utils.misc import ReprMixin, add_docstring @@ -55,10 +56,10 @@ class ClassificationMetrics(Metrics): .. 
code-block:: python def extra_metrics( - labels : np.ndarray - outputs : np.ndarray + labels : NDArray + outputs : NDArray num_classes : Optional[int]=None - weights : Optional[np.ndarray]=None + weights : Optional[NDArray]=None ) -> dict """ @@ -140,10 +141,10 @@ def set_macro(self, macro: bool) -> None: ) def compute( self, - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, - weights: Optional[np.ndarray] = None, + weights: Optional[NDArray] = None, thr: float = 0.5, ) -> "ClassificationMetrics": labels, outputs = one_hot_pair(labels, outputs, num_classes) @@ -164,154 +165,154 @@ def compute( @add_docstring(compute.__doc__) def __call__( self, - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, - weights: Optional[np.ndarray] = None, + weights: Optional[NDArray] = None, thr: float = 0.5, ) -> "ClassificationMetrics": return self.compute(labels, outputs, num_classes, weights) @property - def sensitivity(self) -> Union[float, np.ndarray]: + def sensitivity(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}sens"] @property - def recall(self) -> Union[float, np.ndarray]: + def recall(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}sens"] @property - def hit_rate(self) -> Union[float, np.ndarray]: + def hit_rate(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}sens"] @property - def true_positive_rate(self) -> Union[float, np.ndarray]: + def true_positive_rate(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}sens"] @property - def specificity(self) -> Union[float, np.ndarray]: + def specificity(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}spec"] @property - def selectivity(self) -> Union[float, np.ndarray]: + def selectivity(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}spec"] @property - def true_negative_rate(self) -> Union[float, np.ndarray]: + def true_negative_rate(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}spec"] @property - def precision(self) -> Union[float, np.ndarray]: + def precision(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}prec"] @property - def positive_predictive_value(self) -> Union[float, np.ndarray]: + def positive_predictive_value(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}prec"] @property - def negative_predictive_value(self) -> Union[float, np.ndarray]: + def negative_predictive_value(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}npv"] @property - def jaccard_index(self) -> Union[float, np.ndarray]: + def jaccard_index(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}jac"] @property - def threat_score(self) -> Union[float, np.ndarray]: + def threat_score(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}jac"] @property - def critical_success_index(self) -> Union[float, np.ndarray]: + def critical_success_index(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}jac"] @property - def accuracy(self) -> Union[float, np.ndarray]: + def accuracy(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}acc"] @property - def phi_coefficient(self) -> Union[float, np.ndarray]: + def phi_coefficient(self) -> Union[float, 
NDArray]: return self._metrics[f"{self.__prefix}phi"] @property - def matthews_correlation_coefficient(self) -> Union[float, np.ndarray]: + def matthews_correlation_coefficient(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}phi"] @property - def false_negative_rate(self) -> Union[float, np.ndarray]: + def false_negative_rate(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}fnr"] @property - def miss_rate(self) -> Union[float, np.ndarray]: + def miss_rate(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}fnr"] @property - def false_positive_rate(self) -> Union[float, np.ndarray]: + def false_positive_rate(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}fpr"] @property - def fall_out(self) -> Union[float, np.ndarray]: + def fall_out(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}fpr"] @property - def false_discovery_rate(self) -> Union[float, np.ndarray]: + def false_discovery_rate(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}fdr"] @property - def false_omission_rate(self) -> Union[float, np.ndarray]: + def false_omission_rate(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}for"] @property - def positive_likelihood_ratio(self) -> Union[float, np.ndarray]: + def positive_likelihood_ratio(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}plr"] @property - def negative_likelihood_ratio(self) -> Union[float, np.ndarray]: + def negative_likelihood_ratio(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}nlr"] @property - def prevalence_threshold(self) -> Union[float, np.ndarray]: + def prevalence_threshold(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}pt"] @property - def balanced_accuracy(self) -> Union[float, np.ndarray]: + def balanced_accuracy(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}ba"] @property - def f1_measure(self) -> Union[float, np.ndarray]: + def f1_measure(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}f1"] @property - def fowlkes_mallows_index(self) -> Union[float, np.ndarray]: + def fowlkes_mallows_index(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}fm"] @property - def bookmaker_informedness(self) -> Union[float, np.ndarray]: + def bookmaker_informedness(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}bm"] @property - def markedness(self) -> Union[float, np.ndarray]: + def markedness(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}mk"] @property - def diagnostic_odds_ratio(self) -> Union[float, np.ndarray]: + def diagnostic_odds_ratio(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}dor"] @property def area_under_the_receiver_operater_characteristic_curve( self, - ) -> Union[float, np.ndarray]: + ) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}auroc"] @property - def auroc(self) -> Union[float, np.ndarray]: + def auroc(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}auroc"] @property - def area_under_the_precision_recall_curve(self) -> Union[float, np.ndarray]: + def area_under_the_precision_recall_curve(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}auprc"] @property - def auprc(self) -> Union[float, np.ndarray]: + def auprc(self) -> Union[float, NDArray]: return self._metrics[f"{self.__prefix}auprc"] @property @@ -348,8 +349,8 @@ class 
RPeaksDetectionMetrics(Metrics): .. code-block:: python def extra_metrics( - labels : Sequence[Union[Sequence[int], np.ndarray]], - outputs : Sequence[Union[Sequence[int], np.ndarray]], + labels : Sequence[Union[Sequence[int], NDArray]], + outputs : Sequence[Union[Sequence[int], NDArray]], fs : int ) -> dict @@ -395,8 +396,8 @@ def __init__( ) def compute( self, - labels: Sequence[Union[Sequence[int], np.ndarray]], - outputs: Sequence[Union[Sequence[int], np.ndarray]], + labels: Sequence[Union[Sequence[int], NDArray]], + outputs: Sequence[Union[Sequence[int], NDArray]], fs: int, thr: Optional[float] = None, ) -> "RPeaksDetectionMetrics": @@ -410,8 +411,8 @@ def compute( @add_docstring(compute.__doc__) def __call__( self, - labels: Sequence[Union[Sequence[int], np.ndarray]], - outputs: Sequence[Union[Sequence[int], np.ndarray]], + labels: Sequence[Union[Sequence[int], NDArray]], + outputs: Sequence[Union[Sequence[int], NDArray]], fs: int, thr: Optional[float] = None, ) -> "RPeaksDetectionMetrics": @@ -446,8 +447,8 @@ class WaveDelineationMetrics(Metrics): .. code-block:: python def extra_metrics( - labels: Sequence[Union[Sequence[int], np.ndarray]], - outputs: Sequence[Union[Sequence[int], np.ndarray]], + labels: Sequence[Union[Sequence[int], NDArray]], + outputs: Sequence[Union[Sequence[int], NDArray]], fs: int ) -> dict @@ -531,8 +532,8 @@ def set_macro(self, macro: bool) -> None: ) def compute( self, - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], class_map: Dict[str, int], fs: int, mask_format: str = "channel_first", @@ -592,8 +593,8 @@ def compute( @add_docstring(compute.__doc__) def __call__( self, - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], class_map: Dict[str, int], fs: int, mask_format: str = "channel_first", diff --git a/torch_ecg/components/trainer.py b/torch_ecg/components/trainer.py index 127d23e6..c3e20b62 100644 --- a/torch_ecg/components/trainer.py +++ b/torch_ecg/components/trainer.py @@ -10,7 +10,7 @@ from collections import OrderedDict, deque from copy import deepcopy from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -53,18 +53,18 @@ class BaseTrainer(ReprMixin, ABC): Will also be recorded in the checkpoints. 
`train_config` should at least contain the following keys: - - "monitor": str - - "loss": str - - "n_epochs": int - - "batch_size": int - - "learning_rate": float - - "lr_scheduler": str - - "lr_step_size": int, optional, depending on the scheduler - - "lr_gamma": float, optional, depending on the scheduler - - "max_lr": float, optional, depending on the scheduler - - "optimizer": str - - "decay": float, optional, depending on the optimizer - - "momentum": float, optional, depending on the optimizer + - "monitor": str + - "loss": str + - "n_epochs": int + - "batch_size": int + - "learning_rate": float + - "lr_scheduler": str + - "lr_step_size": int, optional, depending on the scheduler + - "lr_gamma": float, optional, depending on the scheduler + - "max_lr": float, optional, depending on the scheduler + - "optimizer": str + - "decay": float, optional, depending on the optimizer + - "momentum": float, optional, depending on the optimizer collate_fn : callable, optional The collate function for the data loader, defaults to :meth:`default_collate_fn`. @@ -93,7 +93,7 @@ def __init__( dataset_cls: Dataset, model_config: dict, train_config: dict, - collate_fn: Optional[callable] = None, + collate_fn: Optional[Callable] = None, device: Optional[torch.device] = None, lazy: bool = False, ) -> None: @@ -108,7 +108,7 @@ def __init__( self.dataset_cls = dataset_cls self.model_config = CFG(deepcopy(model_config)) self._train_config = CFG(deepcopy(train_config)) - self._train_config.checkpoints = Path(self._train_config.checkpoints) + self._train_config.checkpoints = Path(self._train_config.checkpoints) # type: ignore self.device = device or next(self._model.parameters()).device self.dtype = next(self._model.parameters()).dtype self.model.to(self.device) @@ -150,12 +150,12 @@ def train(self) -> OrderedDict: self._setup_criterion() - if self.train_config.monitor is not None: + if self.train_config.monitor is not None: # type: ignore # if monitor is set but val_loader is None, use train_loader for validation # and choose the best model based on the metrics on the train set if self.val_loader is None and self.val_train_loader is None: self.val_train_loader = self.train_loader - self.log_manager.log_message( + self.log_manager.log_message( # type: ignore ( "No separate validation set is provided, while monitor is set. 
" "The training set will be used for validation, " @@ -179,7 +179,7 @@ def train(self) -> OrderedDict: ----------------------------------------- """ ) - self.log_manager.log_message(msg) + self.log_manager.log_message(msg) # type: ignore start_epoch = self.epoch for _ in range(start_epoch, self.n_epochs): @@ -193,15 +193,15 @@ def train(self) -> OrderedDict: dynamic_ncols=True, mininterval=1.0, ) as pbar: - self.log_manager.epoch_start(self.epoch) + self.log_manager.epoch_start(self.epoch) # type: ignore # train one epoch self.train_one_epoch(pbar) # evaluate on train set, if debug is True if self.val_train_loader is not None: eval_train_res = self.evaluate(self.val_train_loader) - self.log_manager.log_metrics( - metrics=eval_train_res, + self.log_manager.log_metrics( # type: ignore + metrics=eval_train_res, # type: ignore step=self.global_step, epoch=self.epoch, part="train", @@ -211,8 +211,8 @@ def train(self) -> OrderedDict: # evaluate on val set if self.val_loader is not None: eval_res = self.evaluate(self.val_loader) - self.log_manager.log_metrics( - metrics=eval_res, + self.log_manager.log_metrics( # type: ignore + metrics=eval_res, # type: ignore step=self.global_step, epoch=self.epoch, part="val", @@ -224,19 +224,19 @@ def train(self) -> OrderedDict: eval_res = {} # update best model and best metric if monitor is set - if self.train_config.monitor is not None: - if eval_res[self.train_config.monitor] > self.best_metric: - self.best_metric = eval_res[self.train_config.monitor] + if self.train_config.monitor is not None: # type: ignore + if eval_res[self.train_config.monitor] > self.best_metric: # type: ignore + self.best_metric = eval_res[self.train_config.monitor] # type: ignore self.best_state_dict = self._model.state_dict() self.best_eval_res = deepcopy(eval_res) self.best_epoch = self.epoch self.pseudo_best_epoch = self.epoch - elif self.train_config.early_stopping: - if eval_res[self.train_config.monitor] >= self.best_metric - self.train_config.early_stopping.min_delta: + elif self.train_config.early_stopping: # type: ignore + if eval_res[self.train_config.monitor] >= self.best_metric - self.train_config.early_stopping.min_delta: # type: ignore self.pseudo_best_epoch = self.epoch - elif self.epoch - self.pseudo_best_epoch >= self.train_config.early_stopping.patience: + elif self.epoch - self.pseudo_best_epoch >= self.train_config.early_stopping.patience: # type: ignore msg = f"early stopping is triggered at epoch {self.epoch}" - self.log_manager.log_message(msg) + self.log_manager.log_message(msg) # type: ignore break msg = textwrap.dedent( @@ -245,62 +245,65 @@ def train(self) -> OrderedDict: obtained at epoch {self.best_epoch} """ ) - self.log_manager.log_message(msg) + self.log_manager.log_message(msg) # type: ignore # save checkpoint - save_suffix = f"epochloss_{self.epoch_loss:.5f}_metric_{eval_res[self.train_config.monitor]:.2f}" + save_suffix = f"epochloss_{self.epoch_loss:.5f}_metric_{eval_res[self.train_config.monitor]:.2f}" # type: ignore else: save_suffix = f"epochloss_{self.epoch_loss:.5f}" - save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar" - save_path = self.train_config.checkpoints / save_filename - if self.train_config.keep_checkpoint_max != 0: + # save_filename = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}.pth.tar" + save_folder = f"{self.save_prefix}_epoch{self.epoch}_{get_date_str()}_{save_suffix}" + save_path = self.train_config.checkpoints / save_folder # type: ignore + if 
self.train_config.keep_checkpoint_max != 0: # type: ignore self.save_checkpoint(str(save_path)) self.saved_models.append(save_path) # remove outdated models - if len(self.saved_models) > self.train_config.keep_checkpoint_max > 0: + if len(self.saved_models) > self.train_config.keep_checkpoint_max > 0: # type: ignore model_to_remove = self.saved_models.popleft() try: os.remove(model_to_remove) except Exception: - self.log_manager.log_message(f"failed to remove {str(model_to_remove)}") + self.log_manager.log_message(f"failed to remove {str(model_to_remove)}") # type: ignore # update learning rate using lr_scheduler - if self.train_config.lr_scheduler.lower() == "plateau": + if self.train_config.lr_scheduler.lower() == "plateau": # type: ignore self._update_lr(eval_res) - self.log_manager.epoch_end(self.epoch) + self.log_manager.epoch_end(self.epoch) # type: ignore self.epoch += 1 # save the best model if self.best_metric > -np.inf: - if self.train_config.final_model_name: - save_filename = self.train_config.final_model_name + if self.train_config.final_model_name: # type: ignore + save_folder = self.train_config.final_model_name # type: ignore else: - save_suffix = f"metric_{self.best_eval_res[self.train_config.monitor]:.2f}" - save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar" - save_path = self.train_config.model_dir / save_filename + save_suffix = f"metric_{self.best_eval_res[self.train_config.monitor]:.2f}" # type: ignore + # save_filename = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}.pth.tar" + save_folder = f"BestModel_{self.save_prefix}{self.best_epoch}_{get_date_str()}_{save_suffix}" + save_path = self.train_config.model_dir / save_folder # type: ignore # self.save_checkpoint(path=str(save_path)) self._model.save(path=str(save_path), train_config=self.train_config) - self.log_manager.log_message(f"best model is saved at {save_path}") - elif self.train_config.monitor is None: - self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model") + self.log_manager.log_message(f"best model is saved at {save_path}") # type: ignore + elif self.train_config.monitor is None: # type: ignore + self.log_manager.log_message("no monitor is set, the last model is selected and saved as the best model") # type: ignore self.best_state_dict = self._model.state_dict() - save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar" - save_path = self.train_config.model_dir / save_filename + # save_filename = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}.pth.tar" + save_folder = f"BestModel_{self.save_prefix}{self.epoch}_{get_date_str()}" + save_path = self.train_config.model_dir / save_folder # type: ignore # self.save_checkpoint(path=str(save_path)) self._model.save(path=str(save_path), train_config=self.train_config) else: raise ValueError("No best model found!") - self.log_manager.close() + self.log_manager.close() # type: ignore if not self.best_state_dict: # in case no best model is found, # e.g. monitor is not set, or keep_checkpoint_max is 0 self.best_state_dict = self._model.state_dict() - return self.best_state_dict + return self.best_state_dict # type: ignore def train_one_epoch(self, pbar: tqdm) -> None: """Train one epoch, and update the progress bar @@ -311,16 +314,16 @@ def train_one_epoch(self, pbar: tqdm) -> None: The progress bar for training. 
""" - for epoch_step, data in enumerate(self.train_loader): + for epoch_step, data in enumerate(self.train_loader): # type: ignore self.global_step += 1 # data is assumed to be a tuple of tensors, of the following order: # signals, labels, *extra_tensors - data = self.augmenter_manager(*data) + data = self.augmenter_manager(*data) # type: ignore out_tensors = self.run_one_step(*data) loss = self.criterion(*out_tensors).to(self.dtype) - if self.train_config.flooding_level > 0: - flood = (loss - self.train_config.flooding_level).abs() + self.train_config.flooding_level + if self.train_config.flooding_level > 0: # type: ignore + flood = (loss - self.train_config.flooding_level).abs() + self.train_config.flooding_level # type: ignore self.epoch_loss += loss.item() self.optimizer.zero_grad() flood.backward() @@ -331,7 +334,7 @@ def train_one_epoch(self, pbar: tqdm) -> None: self.optimizer.step() self._update_lr() - if self.global_step % self.train_config.log_step == 0: + if self.global_step % self.train_config.log_step == 0: # type: ignore train_step_metrics = {"loss": loss.item()} if self.scheduler: train_step_metrics.update({"lr": self.scheduler.get_last_lr()[0]}) @@ -347,9 +350,9 @@ def train_one_epoch(self, pbar: tqdm) -> None: "loss (batch)": loss.item(), } ) - if self.train_config.flooding_level > 0: - train_step_metrics.update({"flood": flood.item()}) - self.log_manager.log_metrics( + if self.train_config.flooding_level > 0: # type: ignore + train_step_metrics.update({"flood": flood.item()}) # type: ignore + self.log_manager.log_metrics( # type: ignore metrics=train_step_metrics, step=self.global_step, epoch=self.epoch, @@ -458,22 +461,22 @@ def _update_lr(self, eval_res: Optional[dict] = None) -> None: The evaluation results (metrics). """ - if self.train_config.lr_scheduler.lower() == "none": + if self.train_config.lr_scheduler.lower() == "none": # type: ignore pass - elif self.train_config.lr_scheduler.lower() == "plateau": + elif self.train_config.lr_scheduler.lower() == "plateau": # type: ignore if eval_res is None: return - metrics = eval_res[self.train_config.monitor] + metrics = eval_res[self.train_config.monitor] # type: ignore if isinstance(metrics, torch.Tensor): metrics = metrics.item() - self.scheduler.step(metrics) - elif self.train_config.lr_scheduler.lower() == "step": - self.scheduler.step() - elif self.train_config.lr_scheduler.lower() in [ + self.scheduler.step(metrics) # type: ignore + elif self.train_config.lr_scheduler.lower() == "step": # type: ignore + self.scheduler.step() # type: ignore + elif self.train_config.lr_scheduler.lower() in [ # type: ignore "one_cycle", "onecycle", ]: - self.scheduler.step() + self.scheduler.step() # type: ignore def _setup_from_config(self, train_config: dict) -> None: """Setup the trainer from the training configuration. 
@@ -492,14 +495,14 @@ def _setup_from_config(self, train_config: dict) -> None: self._validate_train_config() # set aliases - self.n_epochs = self.train_config.n_epochs - self.batch_size = self.train_config.batch_size - self.lr = self.train_config.learning_rate + self.n_epochs = self.train_config.n_epochs # type: ignore + self.batch_size = self.train_config.batch_size # type: ignore + self.lr = self.train_config.learning_rate # type: ignore # setup log manager first self._setup_log_manager() msg = f"training configurations are as follows:\n{dict_to_str(self.train_config)}" - self.log_manager.log_message(msg) + self.log_manager.log_message(msg) # type: ignore # setup directories self._setup_directories() @@ -517,7 +520,7 @@ def _setup_from_config(self, train_config: dict) -> None: def extra_log_suffix(self) -> str: """Extra suffix for the log file name.""" model_name = self._model.__name__ if hasattr(self._model, "__name__") else self._model.__class__.__name__ - return f"{model_name}_{self.train_config.optimizer}_LR_{self.lr}_BS_{self.batch_size}" + return f"{model_name}_{self.train_config.optimizer}_LR_{self.lr}_BS_{self.batch_size}" # type: ignore def _setup_log_manager(self) -> None: """Setup the log manager.""" @@ -528,27 +531,27 @@ def _setup_log_manager(self) -> None: def _setup_directories(self) -> None: """Setup the directories for saving checkpoints and logs.""" if not self.train_config.get("model_dir", None): - self._train_config.model_dir = self.train_config.checkpoints - self._train_config.model_dir = Path(self._train_config.model_dir) - self.train_config.checkpoints.mkdir(parents=True, exist_ok=True) - self.train_config.model_dir.mkdir(parents=True, exist_ok=True) + self._train_config.model_dir = self.train_config.checkpoints # type: ignore + self._train_config.model_dir = Path(self._train_config.model_dir) # type: ignore + self.train_config.checkpoints.mkdir(parents=True, exist_ok=True) # type: ignore + self.train_config.model_dir.mkdir(parents=True, exist_ok=True) # type: ignore def _setup_callbacks(self) -> None: """Setup the callbacks.""" self._train_config.monitor = self.train_config.get("monitor", None) - if self.train_config.monitor is None: + if self.train_config.monitor is None: # type: ignore assert ( - self.train_config.lr_scheduler.lower() != "plateau" + self.train_config.lr_scheduler.lower() != "plateau" # type: ignore ), "monitor is not specified, lr_scheduler should not be ReduceLROnPlateau" self._train_config.keep_checkpoint_max = self.train_config.get("keep_checkpoint_max", 1) - if self._train_config.keep_checkpoint_max < 0: + if self._train_config.keep_checkpoint_max < 0: # type: ignore self._train_config.keep_checkpoint_max = -1 - self.log_manager.log_message( + self.log_manager.log_message( # type: ignore msg="keep_checkpoint_max is set to -1, all checkpoints will be kept", level=logging.WARNING, ) - elif self._train_config.keep_checkpoint_max == 0: - self.log_manager.log_message( + elif self._train_config.keep_checkpoint_max == 0: # type: ignore + self.log_manager.log_message( # type: ignore msg="keep_checkpoint_max is set to 0, no checkpoint will be kept", level=logging.WARNING, ) @@ -617,7 +620,7 @@ def n_val(self) -> int: def _setup_optimizer(self) -> None: """Setup the optimizer.""" - if self.train_config.optimizer.lower() == "adam": + if self.train_config.optimizer.lower() == "adam": # type: ignore optimizer_kwargs = get_kwargs(optim.Adam) optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()}) 
optimizer_kwargs.update(dict(lr=self.lr)) @@ -625,20 +628,20 @@ def _setup_optimizer(self) -> None: params=self.model.parameters(), **optimizer_kwargs, ) - elif self.train_config.optimizer.lower() in ["adamw", "adamw_amsgrad"]: + elif self.train_config.optimizer.lower() in ["adamw", "adamw_amsgrad"]: # type: ignore optimizer_kwargs = get_kwargs(optim.AdamW) optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()}) optimizer_kwargs.update( dict( lr=self.lr, - amsgrad=self.train_config.optimizer.lower().endswith("amsgrad"), + amsgrad=self.train_config.optimizer.lower().endswith("amsgrad"), # type: ignore ) ) self.optimizer = optim.AdamW( params=self.model.parameters(), **optimizer_kwargs, ) - elif self.train_config.optimizer.lower() == "sgd": + elif self.train_config.optimizer.lower() == "sgd": # type: ignore optimizer_kwargs = get_kwargs(optim.SGD) optimizer_kwargs.update({k: self.train_config.get(k, v) for k, v in optimizer_kwargs.items()}) optimizer_kwargs.update(dict(lr=self.lr)) @@ -648,44 +651,42 @@ def _setup_optimizer(self) -> None: ) else: raise NotImplementedError( - f"optimizer `{self.train_config.optimizer}` not implemented! " + f"optimizer `{self.train_config.optimizer}` not implemented! " # type: ignore "Please use one of the following: `adam`, `adamw`, `adamw_amsgrad`, `sgd`, " "or override this method to setup your own optimizer." ) def _setup_scheduler(self) -> None: """Setup the learning rate scheduler.""" - if self.train_config.lr_scheduler is None or self.train_config.lr_scheduler.lower() == "none": + if self.train_config.lr_scheduler is None or self.train_config.lr_scheduler.lower() == "none": # type: ignore self.train_config.lr_scheduler = "none" self.scheduler = None - elif self.train_config.lr_scheduler.lower() == "plateau": + elif self.train_config.lr_scheduler.lower() == "plateau": # type: ignore self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, "max", patience=2, - verbose=False, ) - elif self.train_config.lr_scheduler.lower() == "step": + elif self.train_config.lr_scheduler.lower() == "step": # type: ignore self.scheduler = optim.lr_scheduler.StepLR( self.optimizer, - self.train_config.lr_step_size, - self.train_config.lr_gamma, - # verbose=False, + self.train_config.lr_step_size, # type: ignore + self.train_config.lr_gamma, # type: ignore ) - elif self.train_config.lr_scheduler.lower() in [ + elif self.train_config.lr_scheduler.lower() in [ # type: ignore "one_cycle", "onecycle", ]: self.scheduler = optim.lr_scheduler.OneCycleLR( optimizer=self.optimizer, - max_lr=self.train_config.max_lr, + max_lr=self.train_config.max_lr, # type: ignore epochs=self.n_epochs, - steps_per_epoch=len(self.train_loader), + steps_per_epoch=len(self.train_loader), # type: ignore # verbose=False, ) else: # TODO: add linear and linear with warmup schedulers raise NotImplementedError( - f"lr scheduler `{self.train_config.lr_scheduler.lower()}` not implemented for training! " + f"lr scheduler `{self.train_config.lr_scheduler.lower()}` not implemented for training! " # type: ignore "Please use one of the following: `none`, `plateau`, `step`, `one_cycle`, " "or override this method to setup your own lr scheduler." 
) @@ -696,7 +697,7 @@ def _setup_criterion(self) -> None: for k, v in loss_kw.items(): if isinstance(v, torch.Tensor): loss_kw[k] = v.to(device=self.device, dtype=self.dtype) - self.criterion = setup_criterion(self.train_config.loss, **loss_kw) + self.criterion = setup_criterion(self.train_config.loss, **loss_kw) # type: ignore self.criterion.to(self.device) def _check_model_config_compatability(self, model_config: dict) -> bool: @@ -765,16 +766,30 @@ def save_checkpoint(self, path: str) -> None: Path to save the checkpoint """ - torch.save( - { - "model_state_dict": self._model.state_dict(), - "optimizer_state_dict": self.optimizer.state_dict(), - "model_config": make_safe_globals(self.model_config), - "train_config": make_safe_globals(self.train_config), - "epoch": self.epoch, - }, - path, - ) + # if self._model has method `save`, then use it + if hasattr(self._model, "save"): + self._model.save( + path=path, + train_config=self.train_config, + extra_items={ + "optimizer_state_dict": self.optimizer.state_dict(), + "epoch": self.epoch, + }, + use_safetensors=True, + ) + else: + if not str(path).endswith(".pth.tar"): + path = Path(path).with_suffix(".pth.tar") # type: ignore + torch.save( + { + "model_state_dict": self._model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + "model_config": make_safe_globals(self.model_config), + "train_config": make_safe_globals(self.train_config), + "epoch": self.epoch, + }, + path, + ) def extra_repr_keys(self) -> List[str]: return [ diff --git a/torch_ecg/databases/aux_data/aha.py b/torch_ecg/databases/aux_data/aha.py index 72fef94d..070d8a2b 100644 --- a/torch_ecg/databases/aux_data/aha.py +++ b/torch_ecg/databases/aux_data/aha.py @@ -11,10 +11,7 @@ import pandas as pd -try: - pd.set_option("future.no_silent_downcasting", True) -except Exception: # pandas._config.config.OptionError: "No such keys(s): 'future.no_silent_downcasting'" - pass +pd.set_option("future.no_silent_downcasting", True) __all__ = [ "df_primary_statements", diff --git a/torch_ecg/databases/aux_data/cinc2020_aux_data.py b/torch_ecg/databases/aux_data/cinc2020_aux_data.py index 2786873a..2bfae5e0 100644 --- a/torch_ecg/databases/aux_data/cinc2020_aux_data.py +++ b/torch_ecg/databases/aux_data/cinc2020_aux_data.py @@ -8,8 +8,8 @@ from numbers import Real from typing import Dict, Literal, Optional, Sequence, Union -import numpy as np import pandas as pd +from numpy.typing import NDArray from ...cfg import CFG @@ -248,7 +248,7 @@ def load_weights( classes: Sequence[Union[int, str]] = None, return_fmt: Literal["np", "pd"] = "np" -) -> Union[np.ndarray, pd.DataFrame]: +) -> Union[NDArray, pd.DataFrame]: """Load the weight matrix of the `classes`. Parameters @@ -342,7 +342,7 @@ def get_class_count( exclude_classes: Optional[Sequence[str]] = None, scored_only: bool = False, normalize: bool = True, - threshold: Optional[Real] = 0, + threshold: Union[float, int] = 0, fmt: str = "a", ) -> Dict[str, int]: """Get the number of classes in the `tranches`. @@ -359,14 +359,15 @@ def get_class_count( normalize : bool, default True whether collapse equivalent classes into one or not, used only when `scored_only` is True. - threshold : numbers.Real, default 0 + threshold : float or int, default 0 Minimum ratio (0-1) or absolute number (>1) of a class to be counted. 
fmt : str, default "a" Format of the names of the classes in the returned dict, can be one of the following (case insensitive): - - "a", abbreviations - - "f", full names - - "s", SNOMED CT Code + + - "a", abbreviations + - "f", full names + - "s", SNOMED CT Code Returns ------- @@ -464,8 +465,8 @@ def get_class_weight( Returns: -------- class_weight : dict - - key: class in the format of `fmt` - - value: weight of a class in `tranches` + key: class in the format of `fmt`, + value: weight of a class in `tranches`. """ class_count = get_class_count( diff --git a/torch_ecg/databases/aux_data/cinc2021_aux_data.py b/torch_ecg/databases/aux_data/cinc2021_aux_data.py index 859aa61d..f291ca9e 100644 --- a/torch_ecg/databases/aux_data/cinc2021_aux_data.py +++ b/torch_ecg/databases/aux_data/cinc2021_aux_data.py @@ -8,8 +8,8 @@ from numbers import Real from typing import Dict, List, Literal, Optional, Sequence, Union -import numpy as np import pandas as pd +from numpy.typing import NDArray from ...cfg import _DATA_CACHE, CFG @@ -336,7 +336,7 @@ def load_weights( classes: Sequence[Union[int, str]] = None, equivalent_classes: Optional[Union[Dict[str, str], List[List[str]]]] = None, return_fmt: Literal["np", "pd"] = "np", -) -> Union[np.ndarray, pd.DataFrame]: +) -> Union[NDArray, pd.DataFrame]: """Load the weight matrix of the `classes`. Parameters @@ -430,7 +430,7 @@ def get_class_count( exclude_classes: Optional[Sequence[str]] = None, scored_only: bool = False, normalize: bool = True, - threshold: Optional[Real] = 0, + threshold: Union[float, int] = 0, fmt: str = "a", ) -> Dict[str, int]: """Get the number of classes in the `tranches`. @@ -447,20 +447,21 @@ def get_class_count( normalize : bool, default True Whether collapse equivalent classes into one or not, used only when `scored_only` is True. - threshold : numbers.Real + threshold : int or float, default 0 Minimum ratio (0-1) or absolute number (>1) of a class to be counted. fmt : str, default "a" Format of the names of the classes in the returned dict, can be one of the following (case insensitive): - - "a", abbreviations - - "f", full names - - "s", SNOMEDCTCode + + - "a", abbreviations + - "f", full names + - "s", SNOMEDCTCode Returns ------- class_count : dict - - key: class in the format of `fmt` - - value: count of a class in `tranches` + key: class in the format of `fmt`, + value: count of a class in `tranches`. """ assert threshold >= 0 @@ -543,9 +544,9 @@ def get_class_weight( fmt : str, default "a" Format of the names of the classes in the returned dict, can be one of the following (case insensitive): - - "a", abbreviations - - "f", full names - - "s", SNOMED CT Code + - "a", abbreviations + - "f", full names + - "s", SNOMED CT Code min_weight : numbers.Real, default 0.5 Minimum value of the weight of all classes, or equivalently the weight of the largest class. @@ -553,8 +554,8 @@ def get_class_weight( Returns: -------- class_weight : dict - - key: class in the format of `fmt` - - value: weight of a class in `tranches` + key: class in the format of `fmt`, + value: weight of a class in `tranches`. 
""" class_count = get_class_count( diff --git a/torch_ecg/databases/base.py b/torch_ecg/databases/base.py index 7b0cf887..34a51afa 100644 --- a/torch_ecg/databases/base.py +++ b/torch_ecg/databases/base.py @@ -2,10 +2,10 @@ """ Base classes for datasets from different sources: - - PhysioNet - - NSRR - - CPSC - - Other databases +- PhysioNet +- NSRR +- CPSC +- Other databases Remarks ------- @@ -36,6 +36,7 @@ import requests import scipy.signal as SS import wfdb +from numpy.typing import NDArray from pyedflib import EdfReader from ..cfg import _DATA_CACHE, CFG, DEFAULTS @@ -135,11 +136,13 @@ class _DataBase(ReprMixin, ABC): """Universal abstract base class for all databases. Abstract methods that should be implemented by the subclass: + - `_ls_rec`: Find all records in the database. - `load_data`: Load data from a record. - `load_ann`: Load annotations of a record. Abstract properties that should be implemented by the subclass: + - `database_info`: The :class:`DataBaseInfo` object of the database. - `url`: URL(s) for downloading the database. @@ -172,7 +175,7 @@ def __init__( f"`db_dir` is not specified, " f"using default `{db_dir}` as the storage path", RuntimeWarning, ) - self.db_dir = Path(db_dir).expanduser().resolve() + self.db_dir = Path(db_dir).expanduser().resolve() # type: ignore if not self.db_dir.exists(): self.db_dir.mkdir(parents=True, exist_ok=True) warnings.warn( @@ -182,7 +185,7 @@ def __init__( "please use the `download()` method.", RuntimeWarning, ) - self.working_dir = Path(working_dir or DEFAULTS.working_dir).expanduser().resolve().absolute() / self.db_name + self.working_dir = Path(working_dir or DEFAULTS.working_dir).expanduser().resolve().absolute() / self.db_name # type: ignore self.working_dir.mkdir(parents=True, exist_ok=True) self.logger = kwargs.get("logger", None) @@ -256,7 +259,7 @@ def get_citation(self, format: Optional[str] = None, style: Optional[str] = None """ self.database_info.get_citation(lookup=True, format=format, style=style, timeout=10.0, print_result=True) - def _auto_infer_units(self, sig: np.ndarray, sig_type: str = "ECG") -> str: + def _auto_infer_units(self, sig: NDArray, sig_type: str = "ECG") -> str: """Automatically infer the units of the signal. It is assumed that `sig` is not raw signal, but with baseline removed. 
@@ -286,7 +289,7 @@ def _auto_infer_units(self, sig: np.ndarray, sig_type: str = "ECG") -> str: return units @property - def all_records(self) -> List[str]: + def all_records(self) -> Union[List[str], None]: if self._all_records is None: self._ls_rec() return self._all_records @@ -312,6 +315,7 @@ def get_absolute_path(self, rec: Union[str, int], extension: Optional[str] = Non path = self._df_records.loc[rec].path if extension is not None: path = path.with_suffix(extension if extension.startswith(".") else f".{extension}") + path = Path(path).expanduser().resolve() # type: ignore return path def _normalize_leads( @@ -347,24 +351,24 @@ def _normalize_leads( all_leads = self.all_leads err_msg = ( f"`leads` should be a subset of {all_leads} or non-negative integers " - f"less than {len(all_leads)}, but got {leads}" + f"less than {len(all_leads)}, but got {leads}" # type: ignore ) if leads is None or (isinstance(leads, str) and leads.lower() == "all"): _leads = all_leads elif isinstance(leads, str): _leads = [leads] elif isinstance(leads, int): - assert len(all_leads) > leads >= 0, err_msg - _leads = [all_leads[leads]] + assert len(all_leads) > leads >= 0, err_msg # type: ignore + _leads = [all_leads[leads]] # type: ignore else: try: - _leads = [ld if isinstance(ld, str) else all_leads[ld] for ld in leads] + _leads = [ld if isinstance(ld, str) else all_leads[ld] for ld in leads] # type: ignore except Exception: raise AssertionError(err_msg) - assert set(_leads).issubset(all_leads), err_msg + assert set(_leads).issubset(all_leads), err_msg # type: ignore if numeric: - _leads = [all_leads.index(ld) for ld in _leads] - return _leads + _leads = [all_leads.index(ld) for ld in _leads] # type: ignore + return _leads # type: ignore @classmethod def get_arrhythmia_knowledge(cls, arrhythmias: Union[str, List[str]]) -> None: @@ -416,7 +420,9 @@ def __len__(self) -> int: Number of records in the database. """ - return len(self.all_records) + if self._all_records is None: + return 0 + return len(self.all_records) # type: ignore def __getitem__(self, index: int) -> str: """Get the record name by index. @@ -434,7 +440,7 @@ def __getitem__(self, index: int) -> str: Record name. """ - return self.all_records[index] + return self.all_records[index] # type: ignore class PhysioNetDataBase(_DataBase): @@ -490,13 +496,13 @@ def __init__( self.df_all_db_info = get_physionet_dbs(local=(self.verbose <= 2)) if self.db_name not in self.df_all_db_info["db_name"].values: if self.verbose <= 2: - self.logger.warning( + self.logger.warning( # type: ignore f"Database `{self.db_name}` is not found in the local database list. " "Please check if the database name is correct, " "or call method `_update_db_list()` to update the database list." ) else: - self.logger.warning( + self.logger.warning( # type: ignore f"Database `{self.db_name}` is not found in the database list on PhysioNet server. " "Please check if the database name is correct." 
) @@ -552,7 +558,7 @@ def _ls_rec(self, db_name: Optional[str] = None, local: bool = True) -> None: len(self._df_records), max(1, int(round(self._subsample * len(self._df_records)))), ) - self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`") + self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`") # type: ignore self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False) self._all_records = self._df_records.index.tolist() except Exception: @@ -572,7 +578,7 @@ def _ls_rec_local(self) -> None: len(self._df_records), max(1, int(round(self._subsample * len(self._df_records)))), ) - self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`") + self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`") # type: ignore self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False) self._df_records["path"] = self._df_records["record"].apply(lambda x: (self.db_dir / x).resolve()) self._df_records = self._df_records[self._df_records["path"].apply(lambda x: x.is_file())] @@ -581,16 +587,16 @@ def _ls_rec_local(self) -> None: if len(self._df_records) == 0: print("Please wait patiently to let the reader find " "all records of the database from local storage...") start = time.time() - self._df_records["path"] = get_record_list_recursive(self.db_dir, self.data_ext, relative=False) + self._df_records["path"] = get_record_list_recursive(self.db_dir, self.data_ext, relative=False) # type: ignore if self._subsample is not None: size = min( len(self._df_records), max(1, int(round(self._subsample * len(self._df_records)))), ) - self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`") + self.logger.debug(f"subsample `{size}` records from `{len(self._df_records)}`") # type: ignore self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False) - self._df_records["path"] = self._df_records["path"].apply(lambda x: Path(x)) - self.logger.info(f"Done in {time.time() - start:.3f} seconds!") + self._df_records["path"] = self._df_records["path"].apply(lambda x: Path(x)) # type: ignore + self.logger.info(f"Done in {time.time() - start:.3f} seconds!") # type: ignore self._df_records["record"] = self._df_records["path"].apply(lambda x: x.name) self._df_records.set_index("record", inplace=True) self._all_records = self._df_records.index.values.tolist() @@ -611,17 +617,17 @@ def get_subject_id(self, rec: Union[str, int]) -> int: """ raise NotImplementedError - def load_data( + def load_data( # type: ignore self, rec: Union[str, int], leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None, sampfrom: Optional[int] = None, sampto: Optional[int] = None, data_format: str = "channel_first", - units: Union[str, type(None)] = "mV", + units: Union[str, None] = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: """Load physical (converted from digital) ECG data, which is more understandable for humans; or load digital signal directly. 
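The signature above exposes `leads`, `data_format`, `units`, `fs` and `return_fs`; the hunks that follow implement the corresponding post-processing (μV scaling, lead-first transposition, flattening, resampling). Below is a condensed, runnable sketch of those output conventions, assuming the raw signal is lead-last and in mV as `wfdb.rdrecord` returns it; it is a paraphrase for illustration, not the method's body.

from typing import Optional

import numpy as np
from numpy.typing import NDArray


def format_loaded_data(p_signal: NDArray, units: Optional[str], data_format: str) -> NDArray:
    # assumption: `p_signal` is lead-last and in mV, like wfdb.rdrecord's `p_signal`
    data = p_signal
    if units is not None and units.lower() in ("μv", "uv", "muv"):
        data = 1000 * data                         # mV -> μV
    if data_format.lower() in ("channel_first", "lead_first"):
        data = data.T                              # (n_samples, n_leads) -> (n_leads, n_samples)
    elif data_format.lower() in ("flat", "plain"):
        data = data.flatten()
    return data


demo = np.zeros((5000, 2))                         # 5000 samples, 2 leads, in mV
print(format_loaded_data(demo, "uV", "channel_first").shape)   # (2, 5000)
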
@@ -705,9 +711,9 @@ def load_data( sampto=sampto, physical=units is not None, return_res=DEFAULTS.DTYPE.INT, - channels=[all_leads.index(ld) for ld in _leads], + channels=[all_leads.index(ld) for ld in _leads], # type: ignore ) # use `channels` instead of `channel_names` since there're exceptional cases where `channel_names` has duplicates - wfdb_rec = wfdb.rdrecord(fp, **rdrecord_kwargs) + wfdb_rec = wfdb.rdrecord(fp, **rdrecord_kwargs) # type: ignore # p_signal or d_signal is in the format of "lead_last", and with units in "mV" if units is None: @@ -715,7 +721,7 @@ def load_data( elif units.lower() == "mv": data = wfdb_rec.p_signal elif units.lower() in ["μv", "uv", "muv"]: - data = 1000 * wfdb_rec.p_signal + data = 1000 * wfdb_rec.p_signal # type: ignore if fs is not None: data_fs = fs @@ -724,18 +730,18 @@ def load_data( else: data_fs = wfdb_rec.fs if data_fs != wfdb_rec.fs: - data = SS.resample_poly(data, data_fs, wfdb_rec.fs, axis=0).astype(data.dtype) + data = SS.resample_poly(data, data_fs, wfdb_rec.fs, axis=0).astype(data.dtype) # type: ignore if data_format.lower() in ["channel_first", "lead_first"]: - data = data.T + data = data.T # type: ignore elif data_format.lower() in ["flat", "plain"]: - data = data.flatten() + data = data.flatten() # type: ignore if return_fs: - return data, data_fs - return data + return data, data_fs # type: ignore + return data # type: ignore - def helper(self, items: Union[List[str], str, type(None)] = None) -> None: + def helper(self, items: Union[List[str], str, None] = None) -> None: """Print corr. meanings of symbols belonging to `items`. More details can be found @@ -837,11 +843,11 @@ def get_file_download_url(self, file_name: Union[str, bytes, os.PathLike]) -> st URL of the file to be downloaded. """ - url = posixpath.join( + url = posixpath.join( # type: ignore wfdb.io.download.PN_INDEX_URL, self.db_name, self.version, - file_name, + file_name, # type: ignore ) return url @@ -879,7 +885,7 @@ def s3_url(self) -> str: return f"s3://physionet-open/{self.db_name}/{self.version}/" @property - def url_(self) -> Union[str, type(None)]: + def url_(self) -> Union[str, None]: """URL of the compressed database file for downloading.""" if self._url_compressed is not None: return self._url_compressed @@ -888,7 +894,7 @@ def url_(self) -> Union[str, type(None)]: try: db_desc = self.df_all_db_info[self.df_all_db_info["db_name"] == self.db_name].iloc[0]["db_description"] except IndexError: - self.logger.info(f"\042{self.db_name}\042 is not in the database list hosted at PhysioNet!") + self.logger.info(f"\042{self.db_name}\042 is not in the database list hosted at PhysioNet!") # type: ignore return None db_desc = re.sub(f"[{punct}]+", "", db_desc).lower() db_desc = re.sub("[\\s:]+", "-", db_desc) @@ -924,7 +930,7 @@ def download(self, compressed: bool = True, use_s3: bool = True) -> None: """ if shutil.which("aws") is None: use_s3 = False - self.logger.warning("AWS CLI is not available! Downloading the database from PhysioNet...") + self.logger.warning("AWS CLI is not available! Downloading the database from PhysioNet...") # type: ignore if use_s3: http_get(self.s3_url, self.db_dir) elif compressed: @@ -933,7 +939,7 @@ def download(self, compressed: bool = True, use_s3: bool = True) -> None: self._ls_rec() return else: - self.logger.info("No compressed database available! Downloading the uncompressed version...") + self.logger.info("No compressed database available! 
Downloading the uncompressed version...") # type: ignore else: wfdb.dl_database( self.db_name, @@ -1085,7 +1091,7 @@ def show_rec_stats(self, rec: Union[str, int]) -> None: """ raise NotImplementedError - def helper(self, items: Union[List[str], str, type(None)] = None) -> None: + def helper(self, items: Union[List[str], str, None] = None) -> None: """Print corr. meanings of symbols belonging to `items`. Parameters @@ -1181,7 +1187,7 @@ def get_subject_id(self, rec: Union[str, int]) -> int: """ raise NotImplementedError - def helper(self, items: Union[List[str], str, type(None)] = None) -> None: + def helper(self, items: Union[List[str], str, None] = None) -> None: """Print corr. meanings of symbols belonging to `items`. Parameters @@ -1331,9 +1337,9 @@ def format_database_docstring(self, indent: Optional[str] = None) -> str: docstring = f"{self.status}\n\n{docstring}" lookup = os.getenv("DB_BIB_LOOKUP", False) - citation = self.get_citation(lookup=lookup, print_result=False) - if citation.startswith("@"): - citation = textwrap.indent(citation, indent) + citation = self.get_citation(lookup=lookup, print_result=False) # type: ignore + if citation.startswith("@"): # type: ignore + citation = textwrap.indent(citation, indent) # type: ignore citation = textwrap.indent(f"""Citation\n--------\n.. code-block:: bibtex\n\n{citation}""", indent) docstring = f"{docstring}\n\n{citation}\n" elif not lookup: @@ -1359,7 +1365,7 @@ def sleep_stage_intervals_to_mask( fs: Optional[int] = None, granularity: int = 30, class_map: Optional[Dict[str, int]] = None, - ) -> np.ndarray: + ) -> NDArray: """Convert sleep stage intervals to sleep stage mask. Parameters @@ -1391,7 +1397,7 @@ def sleep_stage_intervals_to_mask( assert class_map is not None, "`class_map` must be provided" else: class_map = class_map or {k: len(self.sleep_stage_names) - i - 1 for i, k in enumerate(self.sleep_stage_names)} - intervals = { + intervals = { # type: ignore class_map[k]: [[int(round(s / fs / granularity)), int(round(e / fs / granularity))] for s, e in v] for k, v in intervals.items() } @@ -1406,7 +1412,7 @@ def sleep_stage_intervals_to_mask( def plot_hypnogram( self, - mask: np.ndarray, + mask: NDArray, granularity: int = 30, class_map: Optional[Dict[str, int]] = None, **kwargs, @@ -1435,14 +1441,13 @@ def plot_hypnogram( Axes object. 
""" + import matplotlib.pyplot as plt + if not hasattr(self, "sleep_stage_names"): pass else: class_map = class_map or {k: len(self.sleep_stage_names) - i - 1 for i, k in enumerate(self.sleep_stage_names)} - if "plt" not in globals(): - import matplotlib.pyplot as plt - fig_width = len(mask) * granularity / 3600 / 6 * 20 # stardard width is 20 for 6 hours fig, ax = plt.subplots(figsize=(fig_width, 4)) @@ -1457,8 +1462,8 @@ def plot_hypnogram( ax.set_xlabel("Time", fontsize=18) ax.set_xlim(0, len(mask)) # yticks to the format of sleep stages - yticks = sorted(class_map.values()) - yticklabels = [k for k, v in sorted(class_map.items(), key=lambda x: x[1])] + yticks = sorted(class_map.values()) # type: ignore + yticklabels = [k for k, v in sorted(class_map.items(), key=lambda x: x[1])] # type: ignore ax.set_yticks(yticks) ax.set_yticklabels(yticklabels, fontsize=14) ax.set_ylabel("Sleep Stage", fontsize=18) diff --git a/torch_ecg/databases/cpsc_databases/cpsc2018.py b/torch_ecg/databases/cpsc_databases/cpsc2018.py index a1d92e96..4ce15724 100644 --- a/torch_ecg/databases/cpsc_databases/cpsc2018.py +++ b/torch_ecg/databases/cpsc_databases/cpsc2018.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd +from numpy.typing import NDArray from scipy.io import loadmat from ...cfg import DEFAULTS @@ -264,7 +265,7 @@ def load_data( data_format="channel_first", units: str = "mV", return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: """Load the ECG data of a record. Parameters @@ -332,10 +333,9 @@ def load_ann(self, rec: Union[str, int], ann_format: str = "n") -> List[str]: Record name or index of the record in :attr:`all_records`. ann_format : str, default "n" Format of labels, one of the following (case insensitive): - - - "a", abbreviations - - "f", full names - - "n", numeric codes + - "a", abbreviations + - "f", full names + - "n", numeric codes Returns ------- diff --git a/torch_ecg/databases/cpsc_databases/cpsc2019.py b/torch_ecg/databases/cpsc_databases/cpsc2019.py index a49a5b63..fc229b76 100644 --- a/torch_ecg/databases/cpsc_databases/cpsc2019.py +++ b/torch_ecg/databases/cpsc_databases/cpsc2019.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import scipy.signal as SS +from numpy.typing import NDArray from scipy.io import loadmat from ...cfg import DEFAULTS @@ -222,7 +223,7 @@ def load_data( units: str = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: """Load the ECG data of the record `rec`. Parameters @@ -273,7 +274,7 @@ def load_data( return data, data_fs return data - def load_ann(self, rec: Union[int, str]) -> np.ndarray: + def load_ann(self, rec: Union[int, str]) -> NDArray: """Load the annotations (indices of R peaks) of the record `rec`. 
Parameters @@ -292,14 +293,14 @@ def load_ann(self, rec: Union[int, str]) -> np.ndarray: return ann @add_docstring(load_ann.__doc__) - def load_rpeaks(self, rec: Union[int, str]) -> np.ndarray: + def load_rpeaks(self, rec: Union[int, str]) -> NDArray: """ alias of `self.load_ann` """ return self.load_ann(rec=rec) @add_docstring(load_rpeaks.__doc__) - def load_rpeak_indices(self, rec: Union[int, str]) -> np.ndarray: + def load_rpeak_indices(self, rec: Union[int, str]) -> NDArray: """ alias of `self.load_rpeaks` """ @@ -308,8 +309,8 @@ def load_rpeak_indices(self, rec: Union[int, str]) -> np.ndarray: def plot( self, rec: Union[int, str], - data: Optional[np.ndarray] = None, - ann: Optional[np.ndarray] = None, + data: Optional[NDArray] = None, + ann: Optional[NDArray] = None, ticks_granularity: int = 0, **kwargs: Any, ) -> None: @@ -397,8 +398,8 @@ def webpage(self) -> str: def compute_metrics( - rpeaks_truths: Sequence[Union[np.ndarray, Sequence[int]]], - rpeaks_preds: Sequence[Union[np.ndarray, Sequence[int]]], + rpeaks_truths: Sequence[Union[NDArray, Sequence[int]]], + rpeaks_preds: Sequence[Union[NDArray, Sequence[int]]], fs: Real, thr: float = 0.075, verbose: int = 0, diff --git a/torch_ecg/databases/cpsc_databases/cpsc2020.py b/torch_ecg/databases/cpsc_databases/cpsc2020.py index 2139d699..f36f029f 100644 --- a/torch_ecg/databases/cpsc_databases/cpsc2020.py +++ b/torch_ecg/databases/cpsc_databases/cpsc2020.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd import scipy.signal as SS +from numpy.typing import NDArray from scipy.io import loadmat from ...cfg import CFG, DEFAULTS @@ -380,7 +381,7 @@ def load_data( units: str = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: """Load the ECG data of the record `rec`. Parameters @@ -444,7 +445,7 @@ def load_ann( rec: Union[int, str], sampfrom: Optional[int] = None, sampto: Optional[int] = None, - ) -> Dict[str, np.ndarray]: + ) -> Dict[str, NDArray]: """Load the annotations of the record `rec`. Parameters @@ -587,12 +588,12 @@ def locate_premature_beats( def plot( self, rec: Union[int, str], - data: Optional[np.ndarray] = None, - ann: Optional[Dict[str, np.ndarray]] = None, + data: Optional[NDArray] = None, + ann: Optional[Dict[str, NDArray]] = None, ticks_granularity: int = 0, sampfrom: Optional[int] = None, sampto: Optional[int] = None, - rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None, + rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None, ) -> None: """Plot the ECG signal of a record. @@ -750,10 +751,10 @@ def webpage(self) -> str: def compute_metrics( - sbp_true: List[np.ndarray], - pvc_true: List[np.ndarray], - sbp_pred: List[np.ndarray], - pvc_pred: List[np.ndarray], + sbp_true: List[NDArray], + pvc_true: List[NDArray], + sbp_pred: List[NDArray], + pvc_pred: List[NDArray], verbose: int = 0, ) -> Union[Tuple[int], dict]: """Score Function for all (test) records. 
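`compute_metrics` for the CPSC2019 task scores R-peak detections against annotations with a tolerance `thr` (0.075 s by default) at sampling frequency `fs`. The sketch below shows only that core tolerance-matching idea with made-up indices; it is not the challenge's official scoring code.

import numpy as np


def rpeak_tp_count(truth: np.ndarray, pred: np.ndarray, fs: float, thr: float = 0.075) -> int:
    tol = thr * fs                                 # tolerance in samples
    tp = 0
    for t in truth:
        if np.any(np.abs(pred - t) <= tol):        # any prediction within the tolerance window
            tp += 1
    return tp


truth = np.array([400, 800, 1200])                 # annotated R-peak indices (samples)
pred = np.array([395, 790, 1500])                  # hypothetical detector output
print(rpeak_tp_count(truth, pred, fs=500))         # -> 2 (tolerance is 37.5 samples)
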
@@ -771,10 +772,10 @@ def compute_metrics( Tuple of (negative) scores for each ectopic beat type (SBP, PVC), or dict of more scoring details, including - - total_loss: sum of loss of each ectopic beat type (PVC and SPB) - - true_positive: number of true positives of each ectopic beat type - - false_positive: number of false positives of each ectopic beat type - - false_negative: number of false negatives of each ectopic beat type + - total_loss: sum of loss of each ectopic beat type (PVC and SPB) + - true_positive: number of true positives of each ectopic beat type + - false_positive: number of false positives of each ectopic beat type + - false_negative: number of false negatives of each ectopic beat type """ BaseCfg = CFG() diff --git a/torch_ecg/databases/cpsc_databases/cpsc2021.py b/torch_ecg/databases/cpsc_databases/cpsc2021.py index 76aa9d72..990c4f59 100644 --- a/torch_ecg/databases/cpsc_databases/cpsc2021.py +++ b/torch_ecg/databases/cpsc_databases/cpsc2021.py @@ -13,6 +13,7 @@ import pandas as pd import scipy.io as sio import wfdb +from numpy.typing import NDArray from ...cfg import CFG, DEFAULTS from ...utils.misc import add_docstring, get_record_list_recursive3, ms2samples @@ -40,50 +41,50 @@ 7. classification of a record is stored in corresponding .hea file, which can be accessed via the attribute `comments` of a wfdb Record obtained using :func:`wfdb.rdheader`, :func:`wfdb.rdrecord`, and :func:`wfdb.rdsamp`; beat annotations and rhythm annotations can be accessed using the attributes `symbol`, `aux_note` of a ``wfdb`` Annotation obtained using :func:`wfdb.rdann`, corresponding indices in the signal can be accessed via the attribute `sample` 8. challenge task: - - clasification of rhythm types: non-AF rhythm (N), persistent AF rhythm (AFf) and paroxysmal AF rhythm (AFp) - - locating of the onset and offset for any AF episode prediction + - clasification of rhythm types: non-AF rhythm (N), persistent AF rhythm (AFf) and paroxysmal AF rhythm (AFp) + - locating of the onset and offset for any AF episode prediction 9. challenge metrics: - - metrics (Ur, scoring matrix) for classification: + - metrics (Ur, scoring matrix) for classification: - .. tikz:: The scoring matrix for the recording-level classification result. - :align: center - :libs: positioning + .. tikz:: The scoring matrix for the recording-level classification result. 
+ :align: center + :libs: positioning - \tikzstyle{rect} = [rectangle, text width = 50, text centered, inner sep = 3pt, minimum height = 50] - \tikzstyle{txt} = [rectangle, text centered, inner sep = 3pt, minimum height = 1.5] + \tikzstyle{rect} = [rectangle, text width = 50, text centered, inner sep = 3pt, minimum height = 50] + \tikzstyle{txt} = [rectangle, text centered, inner sep = 3pt, minimum height = 1.5] - \node[rect, fill = green!25] at (0,0) (31) {$-0.5$}; - \node[rect, fill = green!10, right = 0 of 31] (32) {$0$}; - \node[rect, fill = red!30, right = 0 of 32] (33) {$+1$}; - \node[rect, fill = green!40, above = 0 of 31] (21) {$-1$}; - \node[rect, fill = red!30, above = 0 of 32] (22) {$+1$}; - \node[rect, fill = green!10, above = 0 of 33] (23) {$0$}; - \node[rect, fill = red!30, above = 0 of 21] (11) {$+1$}; - \node[rect, fill = green!60, above = 0 of 22] (12) {$-2$}; - \node[rect, fill = green!40, above = 0 of 23] (13) {$-1$}; + \node[rect, fill = green!25] at (0,0) (31) {$-0.5$}; + \node[rect, fill = green!10, right = 0 of 31] (32) {$0$}; + \node[rect, fill = red!30, right = 0 of 32] (33) {$+1$}; + \node[rect, fill = green!40, above = 0 of 31] (21) {$-1$}; + \node[rect, fill = red!30, above = 0 of 32] (22) {$+1$}; + \node[rect, fill = green!10, above = 0 of 33] (23) {$0$}; + \node[rect, fill = red!30, above = 0 of 21] (11) {$+1$}; + \node[rect, fill = green!60, above = 0 of 22] (12) {$-2$}; + \node[rect, fill = green!40, above = 0 of 23] (13) {$-1$}; - \node[txt, below = 0 of 31] {N}; - \node[txt, below = 0 of 32] (anchor_h) {AF$_{\text{f}}$}; - \node[txt, below = 0 of 33] {AF$_{\text{p}}$}; - \node[txt, left = 0 of 31] {AF$_{\text{p}}$}; - \node[txt, left = 0 of 21] (anchor_v) {AF$_{\text{f}}$}; - \node[txt, left = 0 of 11] {N}; + \node[txt, below = 0 of 31] {N}; + \node[txt, below = 0 of 32] (anchor_h) {AF$_{\text{f}}$}; + \node[txt, below = 0 of 33] {AF$_{\text{p}}$}; + \node[txt, left = 0 of 31] {AF$_{\text{p}}$}; + \node[txt, left = 0 of 21] (anchor_v) {AF$_{\text{f}}$}; + \node[txt, left = 0 of 11] {N}; - \node[txt, below = 0 of anchor_h] {\large\textbf{Annotation (Label)}}; - \node[txt, left = 0.6 of anchor_v, rotate = 90, anchor = north] {\large\textbf{Prediction}}; + \node[txt, below = 0 of anchor_h] {\large\textbf{Annotation (Label)}}; + \node[txt, left = 0.6 of anchor_v, rotate = 90, anchor = north] {\large\textbf{Prediction}}; - - metric (Ue) for detecting onsets and offsets for AF events (episodes): +1 if the detected onset (or offset) is within ±1 beat of the annotated position, and +0.5 if within ±2 beats. - - final score (U): + - metric (Ue) for detecting onsets and offsets for AF events (episodes): +1 if the detected onset (or offset) is within ±1 beat of the annotated position, and +0.5 if within ±2 beats. + - final score (U): - .. math:: + .. math:: U = \dfrac{1}{N} \sum\limits_{i=1}^N \left( Ur_i + \dfrac{Ma_i}{\max\{Mr_i, Ma_i\}} \right) - where :math:`N` is the number of records, - :math:`Ma` is the number of annotated AF episodes, - :math:`Mr` is the number of predicted AF episodes. + where :math:`N` is the number of records, + :math:`Ma` is the number of annotated AF episodes, + :math:`Mr` is the number of predicted AF episodes. 10. Challenge official website [1]_. Webpage of the database on PhysioNet [2]_. 
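As a quick numeric illustration of the final score defined in item 9 above, the snippet below transcribes U = (1/N) Σ (Ur_i + Ma_i / max{Mr_i, Ma_i}) with made-up per-record values; the degenerate case Ma_i = Mr_i = 0 is not handled in this sketch.

def final_score(ur, ma, mr):
    # ur: per-record classification scores from the scoring matrix
    # ma / mr: numbers of annotated / predicted AF episodes per record
    n = len(ur)
    return sum(u + a / max(r, a) for u, a, r in zip(ur, ma, mr)) / n


ur = [1.0, -0.5, 1.0]        # made-up classification scores
ma = [2, 1, 3]               # annotated AF episodes per record
mr = [3, 1, 2]               # predicted AF episodes per record
print(final_score(ur, ma, mr))   # (1 + 2/3 - 0.5 + 1 + 1 + 1) / 3 ≈ 1.39
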
""", @@ -405,7 +406,7 @@ def load_ann( sampfrom: Optional[int] = None, sampto: Optional[int] = None, **kwargs: Any, - ) -> Union[dict, np.ndarray, List[List[int]], str]: + ) -> Union[dict, NDArray, List[List[int]], str]: """Load annotations of the record. Parameters @@ -424,12 +425,9 @@ def load_ann( Key word arguments for functions loading rpeaks, af_episodes, and label respectively, including: - - - fs: int, optional, - the resampling frequency - - fmt: str, - format of af_episodes, or format of label, - for more details, ref. corresponding functions. + - fs: int, optional, the resampling frequency + - fmt: str, format of af_episodes, or format of label, + for more details, ref. corresponding functions. Used only when `field` is specified (not None). @@ -479,7 +477,7 @@ def load_rpeaks( keep_original: bool = False, valid_only: bool = True, fs: Optional[Real] = None, - ) -> np.ndarray: + ) -> NDArray: """Load position (in terms of samples) of rpeaks. Parameters @@ -542,7 +540,7 @@ def load_rpeak_indices( keep_original: bool = False, valid_only: bool = True, fs: Optional[Real] = None, - ) -> np.ndarray: + ) -> NDArray: """alias of `self.load_rpeaks`""" return self.load_rpeaks( rec=rec, @@ -563,7 +561,7 @@ def load_af_episodes( keep_original: bool = False, fs: Optional[Real] = None, fmt: Literal["intervals", "mask", "c_intervals"] = "intervals", - ) -> Union[List[List[int]], np.ndarray]: + ) -> Union[List[List[int]], NDArray]: """Load the episodes of atrial fibrillation, in terms of intervals or mask. @@ -666,9 +664,9 @@ def load_label( The three classes are: - - "non atrial fibrillation", - - "paroxysmal atrial fibrillation", - - "persistent atrial fibrillation". + - "non atrial fibrillation", + - "paroxysmal atrial fibrillation", + - "persistent atrial fibrillation". Parameters ---------- @@ -682,11 +680,10 @@ def load_label( Not used, to keep in accordance with other methods. fmt : str, default "a" Format of the label, case in-sensitive, can be one of - - - "f", "fullname": the full name of the label - - "a", "abbr", "abbrevation": abbreviation for the label - - "n", "num", "number": class number of the label - (in accordance with the settings of the offical class map) + - "f", "fullname": the full name of the label + - "a", "abbr", "abbrevation": abbreviation for the label + - "n", "num", "number": class number of the label + (in accordance with the settings of the offical class map) Returns ------- @@ -709,7 +706,7 @@ def gen_endpoint_score_mask( rec: Union[str, int], bias: dict = {1: 1, 2: 0.5}, verbose: Optional[int] = None, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> Tuple[NDArray, NDArray]: """Generate the scoring mask for the onsets and offsets of af episodes. 
Parameters @@ -749,8 +746,8 @@ def gen_endpoint_score_mask( def plot( self, rec: Union[str, int], - data: Optional[np.ndarray] = None, - ann: Optional[Dict[str, np.ndarray]] = None, + data: Optional[NDArray] = None, + ann: Optional[Dict[str, NDArray]] = None, ticks_granularity: int = 0, sampfrom: Optional[int] = None, sampto: Optional[int] = None, @@ -1304,7 +1301,7 @@ def gen_endpoint_score_mask( af_intervals: Sequence[Sequence[int]], bias: dict = {1: 1, 2: 0.5}, verbose: int = 0, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[NDArray, NDArray]: """ generate the scoring mask for the onsets and offsets of af episodes, diff --git a/torch_ecg/databases/datasets/cinc2020/cinc2020_dataset.py b/torch_ecg/databases/datasets/cinc2020/cinc2020_dataset.py index cf137e02..c9061819 100644 --- a/torch_ecg/databases/datasets/cinc2020/cinc2020_dataset.py +++ b/torch_ecg/databases/datasets/cinc2020/cinc2020_dataset.py @@ -7,6 +7,7 @@ import numpy as np import torch +from numpy.typing import NDArray from torch.utils.data.dataset import Dataset from tqdm.auto import tqdm @@ -125,7 +126,7 @@ def _load_all_data(self) -> None: self._signals = np.concatenate(self._signals, axis=0).astype(self.dtype) self._labels = np.concatenate(self._labels, axis=0) - def _load_one_record(self, rec: str) -> Tuple[np.ndarray, np.ndarray]: + def _load_one_record(self, rec: str) -> Tuple[NDArray, NDArray]: """Load a record from the database using database reader. NOTE @@ -165,20 +166,20 @@ def _load_one_record(self, rec: str) -> Tuple[np.ndarray, np.ndarray]: return values, labels @property - def signals(self) -> np.ndarray: + def signals(self) -> NDArray: """Cached signals, only available when `lazy=False` or preloading is performed manually. """ return self._signals @property - def labels(self) -> np.ndarray: + def labels(self) -> NDArray: """Cached labels, only available when `lazy=False` or preloading is performed manually. """ return self._labels - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]: return self.signals[index], self.labels[index] def __len__(self) -> int: @@ -396,7 +397,7 @@ def __init__( def __len__(self) -> int: return len(self.records) - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]: if isinstance(index, slice): return collate_fn([self[i] for i in range(*index.indices(len(self)))]) rec = self.records[index] diff --git a/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py b/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py index 9f3a867b..2c9cab4e 100644 --- a/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py +++ b/torch_ecg/databases/datasets/cinc2021/cinc2021_dataset.py @@ -8,6 +8,7 @@ import numpy as np import torch +from numpy.typing import NDArray from torch.utils.data.dataset import Dataset from tqdm.auto import tqdm @@ -141,7 +142,7 @@ def _load_all_data(self) -> None: self._signals = np.concatenate(self._signals, axis=0).astype(self.dtype) self._labels = np.concatenate(self._labels, axis=0) - def _load_one_record(self, rec: str) -> Tuple[np.ndarray, np.ndarray]: + def _load_one_record(self, rec: str) -> Tuple[NDArray, NDArray]: """Load one record from the database using data reader. 
NOTE @@ -277,20 +278,20 @@ def reload_from_extern(self, ext_ds: "CINC2021Dataset") -> None: self._labels = ext_ds._labels.copy() @property - def signals(self) -> np.ndarray: + def signals(self) -> NDArray: """Cached signals, only available when `lazy=False` or preloading is performed manually. """ return self._signals @property - def labels(self) -> np.ndarray: + def labels(self) -> NDArray: """Cached labels, only available when `lazy=False` or preloading is performed manually. """ return self._labels - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]: return self.signals[index], self.labels[index] def __len__(self) -> int: @@ -534,7 +535,7 @@ def __init__( def __len__(self) -> int: return len(self.records) - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]: if isinstance(index, slice): return collate_fn([self[i] for i in range(*index.indices(len(self)))]) rec = self.records[index] diff --git a/torch_ecg/databases/datasets/cpsc2019/cpsc2019_dataset.py b/torch_ecg/databases/datasets/cpsc2019/cpsc2019_dataset.py index 6b48e943..6ffb248d 100644 --- a/torch_ecg/databases/datasets/cpsc2019/cpsc2019_dataset.py +++ b/torch_ecg/databases/datasets/cpsc2019/cpsc2019_dataset.py @@ -6,6 +6,7 @@ from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np +from numpy.typing import NDArray from torch.utils.data.dataset import Dataset from tqdm.auto import tqdm @@ -83,7 +84,7 @@ def __init__( if not self.lazy: self._load_all_data() - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]: if self.lazy: signal, label = self.fdr[index] else: @@ -110,14 +111,14 @@ def _load_all_data(self) -> None: self._labels = np.array(self._labels) @property - def signals(self) -> np.ndarray: + def signals(self) -> NDArray: """Cached signals, only available when `lazy=False` or preloading is performed manually. """ return self._signals @property - def labels(self) -> np.ndarray: + def labels(self) -> NDArray: """Cached labels, only available when `lazy=False` or preloading is performed manually. """ @@ -205,7 +206,7 @@ def __init__( def __len__(self) -> int: return len(self.records) - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]: if isinstance(index, slice): return collate_fn([self[i] for i in range(*index.indices(len(self)))]) rec_name = self.records[index] diff --git a/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py b/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py index 136c67fc..6051a9d2 100644 --- a/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py +++ b/torch_ecg/databases/datasets/cpsc2021/cpsc2021_dataset.py @@ -46,6 +46,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np +from numpy.typing import NDArray from scipy import signal as SS from scipy.io import loadmat, savemat from torch.utils.data.dataset import Dataset @@ -78,19 +79,19 @@ class CPSC2021Dataset(ReprMixin, Dataset): The returned values (tuple) of :meth:`__getitem__` depends on the task: - 1. "qrs_detection": (`data`, `qrs_mask`, None) - 2. "rr_lstm": (`rr_seq`, `rr_af_mask`, `rr_weight_mask`) - 3. 
"main": (`data`, `af_mask`, `weight_mask`) + 1. "qrs_detection": (`data`, `qrs_mask`, None) + 2. "rr_lstm": (`rr_seq`, `rr_af_mask`, `rr_weight_mask`) + 3. "main": (`data`, `af_mask`, `weight_mask`) where - - `data` shape: ``(n_lead, n_sample)`` - - `qrs_mask` shape: ``(n_sample, 1)`` - - `af_mask` shape: ``(n_sample, 1)`` - - `weight_mask` shape: ``(n_sample, 1)`` - - `rr_seq` shape: ``(n_rr, 1)`` - - `rr_af_mask` shape: ``(n_rr, 1)`` - - `rr_weight_mask` shape: ``(n_rr, 1)`` + - `data` shape: ``(n_lead, n_sample)`` + - `qrs_mask` shape: ``(n_sample, 1)`` + - `af_mask` shape: ``(n_sample, 1)`` + - `weight_mask` shape: ``(n_sample, 1)`` + - `rr_seq` shape: ``(n_rr, 1)`` + - `rr_af_mask` shape: ``(n_rr, 1)`` + - `rr_weight_mask` shape: ``(n_rr, 1)`` Typical values of ``n_sample`` and ``n_rr`` are 6000 and 30, respectively. @@ -386,7 +387,7 @@ def all_rr_seq(self) -> CFG: def __len__(self) -> int: return len(self.fdr) - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, ...]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, ...]: if self.lazy: if self.task in ["qrs_detection"]: return self.fdr[index][:2] @@ -426,7 +427,7 @@ def _get_seg_ann_path(self, seg: str) -> Path: fp = self.segments_dirs.ann[subject] / f"{seg}.{self.segment_ext}" return fp - def _load_seg_data(self, seg: str) -> np.ndarray: + def _load_seg_data(self, seg: str) -> NDArray: """Load the data of the segment. Parameters @@ -456,19 +457,18 @@ def _load_seg_ann(self, seg: str) -> dict: ------- seg_ann : dict Annotations of the segment, containing: - - - rpeaks: indices of rpeaks of the segment - - qrs_mask: mask of qrs complexes of the segment - - af_mask: mask of af episodes of the segment - - interval: interval ([start_idx, end_idx]) in - the original ECG record of the segment + - rpeaks: indices of rpeaks of the segment + - qrs_mask: mask of qrs complexes of the segment + - af_mask: mask of af episodes of the segment + - interval: interval ([start_idx, end_idx]) in + the original ECG record of the segment """ seg_ann_fp = self._get_seg_ann_path(seg) seg_ann = {k: v.flatten() for k, v in loadmat(str(seg_ann_fp)).items() if not k.startswith("__")} return seg_ann - def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[np.ndarray, Dict[str, np.ndarray]]: + def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[NDArray, Dict[str, NDArray]]: """Load the mask(s) of segment. Parameters @@ -510,7 +510,7 @@ def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[np.ndarr seg_mask = seg_mask["af_mask"] return seg_mask - def _load_seg_seq_lab(self, seg: str, reduction: int) -> np.ndarray: + def _load_seg_seq_lab(self, seg: str, reduction: int) -> NDArray: """Load sequence labeling annotations of the segment. Parameters @@ -561,7 +561,7 @@ def _get_rr_seq_path(self, seq_name: str) -> Path: fp = self.rr_seq_dirs[subject] / f"{seq_name}.{self.rr_seq_ext}" return fp - def _load_rr_seq(self, seq_name: str) -> Dict[str, np.ndarray]: + def _load_rr_seq(self, seq_name: str) -> Dict[str, NDArray]: """Load the metadata of the rr_seq. 
Parameters @@ -573,13 +573,12 @@ def _load_rr_seq(self, seq_name: str) -> Dict[str, np.ndarray]: ------- dict metadata of sequence of rr intervals, including - - - rr: the sequence of rr intervals, with units in seconds, - of shape ``(self.seglen, 1)`` - - label: label of the rr intervals, 0 for normal, 1 for af, - of shape ``(self.seglen, self.n_classes)`` - - interval: interval of the current rr sequence in the whole - rr sequence in the original record + - rr: the sequence of rr intervals, with units in seconds, + of shape ``(self.seglen, 1)`` + - label: label of the rr intervals, 0 for normal, 1 for af, + of shape ``(self.seglen, self.n_classes)`` + - interval: interval of the current rr sequence in the whole + rr sequence in the original record """ rr_seq_path = self._get_rr_seq_path(seq_name) @@ -680,7 +679,7 @@ def _preprocess_one_record(self, rec: str, force_recompute: bool = False, verbos pps, _ = self.ppm(self.reader.load_data(rec), self.config.fs) savemat(save_fp, {"ecg": pps}, format="5") - def load_preprocessed_data(self, rec: str) -> np.ndarray: + def load_preprocessed_data(self, rec: str) -> NDArray: """Load the preprocessed data of the record. Parameters @@ -860,7 +859,7 @@ def _slice_one_record( def __generate_segment( self, rec: str, - data: np.ndarray, + data: NDArray, start_idx: Optional[int] = None, end_idx: Optional[int] = None, ) -> CFG: @@ -885,12 +884,12 @@ def __generate_segment( dict Segments (meta-)data, containing: - - data: values of the segment, with units in mV - - rpeaks: indices of rpeaks of the segment - - qrs_mask: mask of qrs complexes of the segment - - af_mask: mask of af episodes of the segment - - interval: interval ([start_idx, end_idx]) in the - original ECG record of the segment + - data: values of the segment, with units in mV + - rpeaks: indices of rpeaks of the segment + - qrs_mask: mask of qrs complexes of the segment + - af_mask: mask of af episodes of the segment + - interval: interval ([start_idx, end_idx]) in the + original ECG record of the segment """ assert not all([start_idx is None, end_idx is None]), "at least one of `start_idx` and `end_idx` should be set" @@ -1439,7 +1438,7 @@ def __init__( "main": "af_mask", } - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, ...]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, ...]: if isinstance(index, slice): return collate_fn([self[i] for i in range(*index.indices(len(self)))]) if self.task in [ diff --git a/torch_ecg/databases/datasets/ludb/ludb_dataset.py b/torch_ecg/databases/datasets/ludb/ludb_dataset.py index cddd4043..10cbd1b1 100644 --- a/torch_ecg/databases/datasets/ludb/ludb_dataset.py +++ b/torch_ecg/databases/datasets/ludb/ludb_dataset.py @@ -7,6 +7,7 @@ from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np +from numpy.typing import NDArray from torch.utils.data.dataset import Dataset from tqdm.auto import tqdm @@ -83,7 +84,7 @@ def __len__(self) -> int: return len(self.leads) * len(self.records) return len(self.records) - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]: if isinstance(index, slice): return collate_fn([self[i] for i in range(*index.indices(len(self)))]) if self.config.use_single_lead: @@ -124,14 +125,14 @@ def _load_all_data(self) -> None: self._labels = np.array(self._labels) @property - def signals(self) -> np.ndarray: + def signals(self) -> NDArray: """Cached signals, only available 
when `lazy=False` or preloading is performed manually. """ return self._signals @property - def labels(self) -> np.ndarray: + def labels(self) -> NDArray: """Cached labels, only available when `lazy=False` or preloading is performed manually. """ @@ -230,7 +231,7 @@ def __init__( def __len__(self) -> int: return len(self.records) - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, np.ndarray]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, NDArray]: if isinstance(index, slice): return collate_fn([self[i] for i in range(*index.indices(len(self)))]) rec = self.records[index] diff --git a/torch_ecg/databases/datasets/mitdb/mitdb_dataset.py b/torch_ecg/databases/datasets/mitdb/mitdb_dataset.py index aa1c6c76..7d324fba 100644 --- a/torch_ecg/databases/datasets/mitdb/mitdb_dataset.py +++ b/torch_ecg/databases/datasets/mitdb/mitdb_dataset.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np +from numpy.typing import NDArray from scipy import signal as SS from scipy.io import loadmat, savemat from torch.utils.data.dataset import Dataset @@ -359,7 +360,7 @@ def __len__(self) -> int: return len(self._all_data) return len(self.fdr) - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, ...]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, ...]: if self.task in ["beat_classification"]: return self._all_data[index], self._all_labels[index] if self.lazy: @@ -401,7 +402,7 @@ def _get_seg_ann_path(self, seg: str) -> Path: fp = self.segments_dirs.ann[rec] / f"{seg}.{self.segment_ext}" return fp - def _load_seg_data(self, seg: str) -> np.ndarray: + def _load_seg_data(self, seg: str) -> NDArray: """Load data of the segment. Parameters @@ -432,18 +433,18 @@ def _load_seg_ann(self, seg: str) -> dict: dict A dictionay of annotations of the segment, including - - rpeaks: indices of rpeaks of the segment - - qrs_mask: mask of qrs complexes of the segment - - rhythm_mask: mask of rhythms of the segment - - interval: interval ([start_idx, end_idx]) in the - original ECG record of the segment + - rpeaks: indices of rpeaks of the segment + - qrs_mask: mask of qrs complexes of the segment + - rhythm_mask: mask of rhythms of the segment + - interval: interval ([start_idx, end_idx]) in the + original ECG record of the segment """ seg_ann_fp = self._get_seg_ann_path(seg) seg_ann = {k: v.flatten() for k, v in loadmat(str(seg_ann_fp)).items() if not k.startswith("__")} return seg_ann - def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[np.ndarray, Dict[str, np.ndarray]]: + def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[NDArray, Dict[str, NDArray]]: """Load mask(s) of the segment. Parameters @@ -483,7 +484,7 @@ def _load_seg_mask(self, seg: str, task: Optional[str] = None) -> Union[np.ndarr seg_mask = seg_mask["rhythm_mask"] return seg_mask - def _load_seg_seq_lab(self, seg: str, reduction: int) -> np.ndarray: + def _load_seg_seq_lab(self, seg: str, reduction: int) -> NDArray: """Load sequence label of the segment. Parameters @@ -534,7 +535,7 @@ def _get_rr_seq_path(self, seq_name: str) -> Path: fp = self.rr_seq_dirs[rec] / f"{seq_name}.{self.rr_seq_ext}" return fp - def _load_rr_seq(self, seq_name: str) -> Dict[str, np.ndarray]: + def _load_rr_seq(self, seq_name: str) -> Dict[str, NDArray]: """Load metadata of sequence of rr intervals. 
Parameters @@ -547,12 +548,12 @@ def _load_rr_seq(self, seq_name: str) -> Dict[str, np.ndarray]: dict Metadata of sequence of rr intervals, including - - rr: the sequence of rr intervals, with units in seconds, - of shape ``(self.seglen, 1)`` - - label: label of the rr intervals, - of shape ``(self.seglen, self.n_classes)`` - - interval: interval of the current rr sequence - in the whole rr sequence in the original record + - rr: the sequence of rr intervals, with units in seconds, + of shape ``(self.seglen, 1)``. + - label: label of the rr intervals, + of shape ``(self.seglen, self.n_classes)``. + - interval: interval of the current rr sequence + in the whole rr sequence in the original record. """ rr_seq_path = self._get_rr_seq_path(seq_name) @@ -772,7 +773,7 @@ def _slice_one_record( def __generate_segment( self, rec: str, - data: np.ndarray, + data: NDArray, start_idx: Optional[int] = None, end_idx: Optional[int] = None, ) -> CFG: @@ -1311,7 +1312,7 @@ def __init__( "af_event": "rhythm_mask", # segmentation of AF events } - def __getitem__(self, index: Union[int, slice]) -> Tuple[np.ndarray, ...]: + def __getitem__(self, index: Union[int, slice]) -> Tuple[NDArray, ...]: if isinstance(index, slice): return collate_fn([self[i] for i in range(*index.indices(len(self)))]) if self.task in [ diff --git a/torch_ecg/databases/nsrr_databases/shhs.py b/torch_ecg/databases/nsrr_databases/shhs.py index 2b911edf..c43d3762 100644 --- a/torch_ecg/databases/nsrr_databases/shhs.py +++ b/torch_ecg/databases/nsrr_databases/shhs.py @@ -5,7 +5,6 @@ import re import warnings from datetime import datetime -from numbers import Real from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, Union @@ -13,6 +12,7 @@ import pandas as pd import scipy.signal as SS import xmltodict as xtd +from numpy.typing import NDArray from tqdm.auto import tqdm from ...cfg import DEFAULTS @@ -78,12 +78,14 @@ 2. Obstructive Apnea Index (OAI): - - There is one OAI index in the data set. It reflects obstructive events associated with a 4% desaturation or arousal. Nearly 30% of the cohort has a zero value for this variable + - There is one OAI index in the data set. It reflects obstructive events associated with a 4% desaturation or arousal. + Nearly 30% of the cohort has a zero value for this variable - Dichotomization is suggested (e.g. >=3 or >=4 events per hour indicates positive) 3. Central Apnea Index (CAI): - - Several variables describe central breathing events, with different thresholds for desaturation and requirement/non-requirement of arousals. ~58% of the cohort have zero values + - Several variables describe central breathing events, with different thresholds for desaturation and + requirement/non-requirement of arousals. ~58% of the cohort have zero values - Dichotomization is suggested (e.g. >=3 or >=4 events per hour indicates positive) 4. Sleep Stages: @@ -387,7 +389,7 @@ def _ls_rec(self) -> None: """Find all records in the database directory and store them (path, metadata, etc.) in some private attributes. 
""" - self.logger.info("Finding `edf` records....") + self.logger.info("Finding `edf` records....") # type: ignore self._df_records = pd.DataFrame() self._df_records["path"] = sorted(self.db_dir.rglob("*.edf")) @@ -396,7 +398,7 @@ def _ls_rec(self) -> None: len(self._df_records), max(1, int(round(self._subsample * len(self._df_records)))), ) - self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False) + self._df_records = self._df_records.sample(n=size, random_state=DEFAULTS.SEED, replace=False) # type: ignore # if self._df_records is non-empty, call `form_paths` again if necessary # typically path for a record is like: @@ -423,15 +425,15 @@ def _ls_rec(self) -> None: self._all_records = self._df_records.index.tolist() # update `current_version` - if self.ann_path.is_dir(): - for file in self.ann_path.iterdir(): + if self.ann_path.is_dir(): # type: ignore + for file in self.ann_path.iterdir(): # type: ignore if file.is_file() and len(re.findall(self.version_pattern, file.name)) > 0: self.current_version = re.findall(self.version_pattern, file.name)[0] break - self.logger.info("Loading tables....") + self.logger.info("Loading tables....") # type: ignore # gather tables in self.ann_path and in self.hrv_ann_path - for file in itertools.chain(self.ann_path.glob("*.csv"), self.hrv_ann_path.glob("*.csv")): + for file in itertools.chain(self.ann_path.glob("*.csv"), self.hrv_ann_path.glob("*.csv")): # type: ignore if not file.suffix == ".csv": continue table_name = file.stem.replace(f"-{self.current_version}", "") @@ -440,7 +442,7 @@ def _ls_rec(self) -> None: except UnicodeDecodeError: self._tables[table_name] = pd.read_csv(file, low_memory=False, encoding="latin-1") - self.logger.info("Finding records with HRV annotations....") + self.logger.info("Finding records with HRV annotations....") # type: ignore # find records with hrv annotations self.rec_with_hrv_summary_ann = [] for table_name in ["shhs1-hrv-summary", "shhs2-hrv-summary"]: @@ -457,17 +459,17 @@ def _ls_rec(self) -> None: ) self.rec_with_hrv_detailed_ann = sorted(list(set(self.rec_with_hrv_detailed_ann))) - self.logger.info("Finding records with rpeaks annotations....") + self.logger.info("Finding records with rpeaks annotations....") # type: ignore # find available rpeak annotation files self.rec_with_rpeaks_ann = sorted( - [f.stem.replace("-rpoint", "") for f in self.wave_deli_path.rglob("shhs*-rpoint.csv")] + [f.stem.replace("-rpoint", "") for f in self.wave_deli_path.rglob("shhs*-rpoint.csv")] # type: ignore ) - self.logger.info("Finding records with event annotations....") + self.logger.info("Finding records with event annotations....") # type: ignore # find available event annotation files - self.rec_with_event_ann = sorted([f.stem.replace("-nsrr", "") for f in self.event_ann_path.rglob("shhs*-nsrr.xml")]) + self.rec_with_event_ann = sorted([f.stem.replace("-nsrr", "") for f in self.event_ann_path.rglob("shhs*-nsrr.xml")]) # type: ignore self.rec_with_event_profusion_ann = sorted( - [f.stem.replace("-profusion", "") for f in self.event_profusion_ann_path.rglob("shhs*-profusion.xml")] + [f.stem.replace("-profusion", "") for f in self.event_profusion_ann_path.rglob("shhs*-profusion.xml")] # type: ignore ) self._df_records["available_signals"] = None @@ -568,11 +570,11 @@ def get_available_signals(self, rec: Union[str, int, None]) -> Union[List[str], mininterval=1.0, disable=(self.verbose < 1), ): - rec = row.name - if self._df_records.loc[rec, "available_signals"] is not None: + rec = row.name # 
type: ignore + if self._df_records.loc[rec, "available_signals"] is not None: # type: ignore continue available_signals = self.get_available_signals(rec) - self._df_records.at[rec, "available_signals"] = available_signals + self._df_records.at[rec, "available_signals"] = available_signals # type: ignore return if isinstance(rec, int): @@ -580,8 +582,8 @@ def get_available_signals(self, rec: Union[str, int, None]) -> Union[List[str], if rec in self._df_records.index: available_signals = self._df_records.loc[rec, "available_signals"] - if available_signals is not None and len(available_signals) > 0: - return available_signals + if available_signals is not None and len(available_signals) > 0: # type: ignore + return available_signals # type: ignore frp = self.get_absolute_path(rec) try: @@ -590,10 +592,10 @@ def get_available_signals(self, rec: Union[str, int, None]) -> Union[List[str], self.safe_edf_file_operation("open", frp) except OSError: return None - available_signals = [s.lower() for s in self.file_opened.getSignalLabels()] + available_signals = [s.lower() for s in self.file_opened.getSignalLabels()] # type: ignore self.safe_edf_file_operation("close") self._df_records.at[rec, "available_signals"] = available_signals - self.all_signals = self.all_signals.union(set(available_signals)) + self.all_signals = self.all_signals.union(set(available_signals)) # type: ignore else: available_signals = [] return available_signals @@ -639,7 +641,7 @@ def get_visitnumber(self, rec: Union[str, int]) -> int: Visit number extracted from `rec`. """ - return self.split_rec_name(rec)["visitnumber"] + return self.split_rec_name(rec)["visitnumber"] # type: ignore def get_tranche(self, rec: Union[str, int]) -> str: """Get ``tranche`` ("shhs1" or "shhs2") from `rec`. @@ -656,7 +658,7 @@ def get_tranche(self, rec: Union[str, int]) -> str: Tranche extracted from `rec`. """ - return self.split_rec_name(rec)["tranche"] + return self.split_rec_name(rec)["tranche"] # type: ignore def get_nsrrid(self, rec: Union[str, int]) -> int: """Get ``nsrrid`` from `rec`. @@ -673,14 +675,14 @@ def get_nsrrid(self, rec: Union[str, int]) -> int: ``nsrrid`` extracted from `rec`. """ - return self.split_rec_name(rec)["nsrrid"] + return self.split_rec_name(rec)["nsrrid"] # type: ignore def get_fs( self, rec: Union[str, int], sig: str = "ECG", rec_path: Optional[Union[str, bytes, os.PathLike]] = None, - ) -> Real: + ) -> Union[float, int]: """Get the sampling frequency of a signal of a record. Parameters @@ -698,7 +700,7 @@ def get_fs( Returns ------- - fs : numbers.Real + fs : float or int Sampling frequency of the signal `sig` of the record `rec`. 
If corresponding signal (.edf) file is not available, or the signal file does not contain the signal `sig`, @@ -708,26 +710,26 @@ def get_fs( if isinstance(rec, int): rec = self[rec] sig = self.match_channel(sig, raise_error=False) - assert sig in self.all_signals.union({"rpeak"}), f"Invalid signal name: `{sig}`" + assert sig in self.all_signals.union({"rpeak"}), f"Invalid signal name: `{sig}`" # type: ignore if sig.lower() == "rpeak": df_rpeaks_with_type_info = self.load_wave_delineation_ann(rec) if df_rpeaks_with_type_info.empty: - self.logger.info(f"Rpeak annotation file corresponding to `{rec}` is not available.") + self.logger.info(f"Rpeak annotation file corresponding to `{rec}` is not available.") # type: ignore return -1 return df_rpeaks_with_type_info.iloc[0]["samplingrate"] frp = self.get_absolute_path(rec, rec_path) if not frp.exists(): - self.logger.info(f"Signal (.edf) file corresponding to `{rec}` is not available.") + self.logger.info(f"Signal (.edf) file corresponding to `{rec}` is not available.") # type: ignore return -1 self.safe_edf_file_operation("open", frp) sig = self.match_channel(sig) - available_signals = [s.lower() for s in self.file_opened.getSignalLabels()] + available_signals = [s.lower() for s in self.file_opened.getSignalLabels()] # type: ignore if sig not in available_signals: - self.logger.info(f"Signal `{sig}` is not available in signal file corresponding to `{rec}`.") + self.logger.info(f"Signal `{sig}` is not available in signal file corresponding to `{rec}`.") # type: ignore return -1 chn_num = available_signals.index(sig) - fs = self.file_opened.getSampleFrequency(chn_num) + fs = self.file_opened.getSampleFrequency(chn_num) # type: ignore self.safe_edf_file_operation("close") return fs @@ -761,15 +763,15 @@ def get_chn_num( """ sig = self.match_channel(sig) available_signals = self.get_available_signals(rec) - if sig not in available_signals: + if sig not in available_signals: # type: ignore if isinstance(rec, int): rec = self[rec] - self.logger.info( + self.logger.info( # type: ignore f"Signal (.edf) file corresponding to `{rec}` is not available, or" f"signal `{sig}` is not available in signal file corresponding to `{rec}`." 
) return -1 - chn_num = available_signals.index(self.match_channel(sig)) + chn_num = available_signals.index(self.match_channel(sig)) # type: ignore return chn_num def match_channel(self, channel: str, raise_error: bool = True) -> str: @@ -797,7 +799,7 @@ def match_channel(self, channel: str, raise_error: bool = True) -> str: raise ValueError(f"No channel named `{channel}`") return channel - def get_absolute_path( + def get_absolute_path( # type: ignore self, rec: Union[str, int], rec_path: Optional[Union[str, bytes, os.PathLike]] = None, @@ -823,7 +825,7 @@ def get_absolute_path( """ if rec_path is not None: - rp = Path(rec_path) + rp = Path(rec_path) # type: ignore return rp assert rec_type in self.folder_or_file, ( @@ -835,7 +837,7 @@ def get_absolute_path( tranche, nsrrid = [self.split_rec_name(rec)[k] for k in ["tranche", "nsrrid"]] # rp = self._df_records.loc[rec, rec_type] - rp = self.folder_or_file[rec_type] / tranche / f"{rec}{self.extension[rec_type]}" + rp = self.folder_or_file[rec_type] / tranche / f"{rec}{self.extension[rec_type]}" # type: ignore return rp def database_stats(self) -> None: @@ -856,12 +858,12 @@ def show_rec_stats(self, rec: Union[str, int], rec_path: Optional[Union[str, byt """ frp = self.get_absolute_path(rec, rec_path, rec_type="psg") self.safe_edf_file_operation("open", frp) - for chn, lb in enumerate(self.file_opened.getSignalLabels()): + for chn, lb in enumerate(self.file_opened.getSignalLabels()): # type: ignore print("SignalLabel:", lb) - print("Prefilter:", self.file_opened.getPrefilter(chn)) - print("Transducer:", self.file_opened.getTransducer(chn)) - print("PhysicalDimension:", self.file_opened.getPhysicalDimension(chn)) - print("SampleFrequency:", self.file_opened.getSampleFrequency(chn)) + print("Prefilter:", self.file_opened.getPrefilter(chn)) # type: ignore + print("Transducer:", self.file_opened.getTransducer(chn)) # type: ignore + print("PhysicalDimension:", self.file_opened.getPhysicalDimension(chn)) # type: ignore + print("SampleFrequency:", self.file_opened.getSampleFrequency(chn)) # type: ignore print("*" * 40) self.safe_edf_file_operation("close") @@ -870,11 +872,11 @@ def load_psg_data( rec: Union[str, int], channel: str = "all", rec_path: Optional[Union[str, bytes, os.PathLike]] = None, - sampfrom: Optional[Real] = None, - sampto: Optional[Real] = None, - fs: Optional[int] = None, + sampfrom: Optional[Union[float, int]] = None, + sampto: Optional[Union[float, int]] = None, + fs: Optional[Union[float, int]] = None, physical: bool = True, - ) -> Union[Dict[str, Tuple[np.ndarray, Real]], Tuple[np.ndarray, Real]]: + ) -> Union[Dict[str, Tuple[NDArray, Union[float, int]]], Tuple[NDArray, Union[float, int]]]: """Load PSG data of the record. Parameters @@ -888,13 +890,13 @@ def load_psg_data( rec_path : `path-like`, optional Path of the file which contains the PSG data. If is None, default path will be used. - sampfrom : numbers.Real, optional + sampfrom : float or int, optional Start time (units in seconds) of the data to be loaded, valid only when `channel` is some specific channel. - sampto : numbers.Real, optional + sampto : float or int, optional End time (units in seconds) of the data to be loaded, valid only when `channel` is some specific channel - fs : numbers.Real, optional + fs : float or int, optional Sampling frequency of the loaded data. If not None, the loaded data will be resampled to this frequency, otherwise, the original sampling frequency will be used. 
@@ -908,11 +910,11 @@ def load_psg_data( dict or tuple If `channel` is "all", then a dictionary will be returned: - - keys: PSG channel names; - - values: PSG data and sampling frequency + - keys: PSG channel names; + - values: PSG data and sampling frequency Otherwise, a 2-tuple will be returned: - (:class:`numpy.ndarray`, :class:`numbers.Real`), which is the + (:class:`numpy.ndarray`, :class:`int` or :class:`float`), which is the PSG data of the channel `channel` and its sampling frequency. """ @@ -923,17 +925,17 @@ def load_psg_data( if chn == "all": ret_data = { k: ( - self.file_opened.readSignal(idx, digital=not physical), - self.file_opened.getSampleFrequency(idx), + self.file_opened.readSignal(idx, digital=not physical), # type: ignore + self.file_opened.getSampleFrequency(idx), # type: ignore ) - for idx, k in enumerate(self.file_opened.getSignalLabels()) + for idx, k in enumerate(self.file_opened.getSignalLabels()) # type: ignore } else: - all_signals = [s.lower() for s in self.file_opened.getSignalLabels()] - assert chn in all_signals, f"`channel` should be one of `{self.file_opened.getSignalLabels()}`, but got `{chn}`" + all_signals = [s.lower() for s in self.file_opened.getSignalLabels()] # type: ignore + assert chn in all_signals, f"`channel` should be one of `{self.file_opened.getSignalLabels()}`, but got `{chn}`" # type: ignore idx = all_signals.index(chn) - data_fs = self.file_opened.getSampleFrequency(idx) - data = self.file_opened.readSignal(idx, digital=not physical) + data_fs = self.file_opened.getSampleFrequency(idx) # type: ignore + data = self.file_opened.readSignal(idx, digital=not physical) # type: ignore # the `readSignal` method of `EdfReader` does NOT treat # the parameters `start` and `n` correctly # so we have to do it manually @@ -962,10 +964,10 @@ def load_ecg_data( sampfrom: Optional[int] = None, sampto: Optional[int] = None, data_format: str = "channel_first", - units: Union[str, type(None)] = "mV", - fs: Optional[int] = None, + units: Union[str, None] = "mV", + fs: Optional[Union[float, int]] = None, return_fs: bool = True, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Union[float, int]]]: """Load ECG data of the record. Parameters @@ -988,7 +990,7 @@ def load_ecg_data( units : str or None, default "mV" Units of the output signal, can also be "μV" (aliases "uV", "muV"). None for digital data, without digital-to-physical conversion. - fs : numbers.Real, optional + fs : float or int, optional Sampling frequency of the loaded data. If not None, the loaded data will be resampled to this frequency, otherwise, the original sampling frequency will be used. @@ -999,7 +1001,7 @@ def load_ecg_data( ------- data : numpy.ndarray The loaded ECG data. - data_fs : numbers.Real + data_fs : float or int Sampling frequency of the loaded ECG data. Returned if `return_fs` is True. 
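The docstrings above state that `load_psg_data` takes `sampfrom`/`sampto` in seconds (cut out of the channel manually, since `EdfReader.readSignal` does not handle its `start`/`n` parameters correctly) and that `fs` optionally resamples the output. Below is a runnable stand-alone sketch of those two conventions with assumptions flagged in comments; it is not the SHHS reader's actual implementation.

import numpy as np
import scipy.signal as SS


def slice_and_resample(sig: np.ndarray, native_fs: float, sampfrom=None, sampto=None, fs=None) -> np.ndarray:
    start = int(round((sampfrom or 0) * native_fs))                 # seconds -> samples
    stop = int(round(sampto * native_fs)) if sampto is not None else len(sig)
    out = sig[start:stop]
    if fs is not None and fs != native_fs:
        out = SS.resample_poly(out, int(fs), int(native_fs))        # assumption: integer rates
    return out


ecg = np.random.randn(125 * 60)           # one minute of a 125 Hz channel (made-up data)
seg = slice_and_resample(ecg, 125, sampfrom=10, sampto=20, fs=250)
print(seg.shape)                           # (2500,) -- 10 s at 250 Hz
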
@@ -1029,7 +1031,7 @@ def load_ecg_data( fs=fs, physical=units is not None, ) - data = data.astype(DEFAULTS.DTYPE.NP) + data = data.astype(DEFAULTS.DTYPE.NP) # type: ignore if units is not None and units.lower() in ["μv", "uv", "muv"]: data *= 1e3 @@ -1039,25 +1041,25 @@ def load_ecg_data( data = data[:, np.newaxis] if return_fs: - return data, data_fs + return data, data_fs # type: ignore return data @add_docstring( " " * 8 + "NOTE: one should call `load_psg_data` to load other channels.", mode="append", ) - @add_docstring(load_ecg_data.__doc__) - def load_data( + @add_docstring(load_ecg_data.__doc__) # type: ignore + def load_data( # type: ignore self, rec: Union[str, int], rec_path: Optional[Union[str, bytes, os.PathLike]] = None, sampfrom: Optional[int] = None, sampto: Optional[int] = None, data_format: str = "channel_first", - units: Union[str, type(None)] = "mV", + units: Union[str, None] = "mV", fs: Optional[int] = None, return_fs: bool = True, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Union[float, int]]]: """alias of `load_ecg_data`""" return self.load_ecg_data( rec=rec, @@ -1070,13 +1072,13 @@ def load_data( return_fs=return_fs, ) - def load_ann( + def load_ann( # type: ignore self, rec: Union[str, int], ann_type: str, ann_path: Optional[Union[str, bytes, os.PathLike]] = None, **kwargs: Any, - ) -> Union[np.ndarray, pd.DataFrame, dict]: + ) -> Union[NDArray, pd.DataFrame, dict, None]: """Load annotations of specific type of the record. Parameters @@ -1160,7 +1162,7 @@ def load_event_ann( df_events["EventType"] = df_events["EventType"].apply(lambda s: s.split("|")[1]) df_events["EventConcept"] = df_events["EventConcept"].apply(lambda s: s.split("|")[1]) for c in ["Start", "Duration", "SpO2Nadir", "SpO2Baseline"]: - df_events[c] = df_events[c].apply(self.str_to_real_number) + df_events[c] = df_events[c].apply(self.str_to_real_number) # type: ignore return df_events @@ -1200,7 +1202,7 @@ def load_event_profusion_ann( sleep_stage_list = [int(ss) for ss in doc["CMPStudyConfig"]["SleepStages"]["SleepStage"]] df_events = pd.DataFrame(doc["CMPStudyConfig"]["ScoredEvents"]["ScoredEvent"]) for c in ["Start", "Duration", "LowestSpO2", "Desaturation"]: - df_events[c] = df_events[c].apply(self.str_to_real_number) + df_events[c] = df_events[c].apply(self.str_to_real_number) # type: ignore ret = {"sleep_stage_list": sleep_stage_list, "df_events": df_events} return ret @@ -1320,7 +1322,7 @@ def load_sleep_ann( df_sleep_ann = df_hrv_ann[self.sleep_ann_keys_from_hrv].reset_index(drop=True) else: df_sleep_ann = pd.DataFrame(columns=self.sleep_ann_keys_from_hrv) - self.logger.debug( + self.logger.debug( # type: ignore f"record `{rec}` has `{len(df_sleep_ann)}` sleep annotations from corresponding " f"hrv-5min (detailed) annotation file, with `{len(self.sleep_ann_keys_from_hrv)}` column(s)" ) @@ -1331,7 +1333,7 @@ def load_sleep_ann( df_sleep_ann = df_event_ann[_cols] else: df_sleep_ann = pd.DataFrame(columns=_cols) - self.logger.debug( + self.logger.debug( # type: ignore f"record `{rec}` has `{len(df_sleep_ann)}` sleep annotations from corresponding " f"event-nsrr annotation file, with `{len(_cols)}` column(s)" ) @@ -1340,7 +1342,7 @@ def load_sleep_ann( # temporarily finished # latter to make imporvements df_sleep_ann = dict_event_ann - self.logger.debug( + self.logger.debug( # type: ignore f"record `{rec}` has `{len(df_sleep_ann['df_events'])}` sleep event annotations " "from corresponding event-profusion annotation file, " f"with 
`{len(df_sleep_ann['df_events'].columns)}` column(s)" @@ -1449,11 +1451,11 @@ def load_sleep_stage_ann( ) if source.lower() != "event_profusion": - self.logger.debug( - f"record `{rec}` has `{len(df_tmp)}` raw (epoch_len = 5min) sleep stage annotations, " + self.logger.debug( # type: ignore + f"record `{rec}` has `{len(df_tmp)}` raw (epoch_len = 5min) sleep stage annotations, " # type: ignore f"with `{len(self.sleep_stage_ann_keys_from_hrv)}` column(s)" ) - self.logger.debug( + self.logger.debug( # type: ignore f"after being transformed (epoch_len = 30sec), record `{rec}` has {len(df_sleep_stage_ann)} " f"sleep stage annotations, with `{len(self.sleep_stage_keys)}` column(s)" ) @@ -1509,7 +1511,7 @@ def load_sleep_event_ann( for _, row in df_sleep_ann.iterrows(): if row["hasrespevent"] == 0: continue - l_events = row[self.sleep_event_ann_keys_from_hrv[1:-1]].values.reshape( + l_events = row[self.sleep_event_ann_keys_from_hrv[1:-1]].values.reshape( # type: ignore (len(self.sleep_event_ann_keys_from_hrv) // 2 - 1, 2) ) l_events = l_events[~np.isnan(l_events[:, 0])] @@ -1521,11 +1523,11 @@ def load_sleep_event_ann( ) df_sleep_event_ann = df_sleep_event_ann[self.sleep_event_keys] - self.logger.debug( + self.logger.debug( # type: ignore f"record `{rec}` has `{len(df_sleep_ann)}` raw (epoch_len = 5min) sleep event " f"annotations from hrv, with `{len(self.sleep_event_ann_keys_from_hrv)}` column(s)" ) - self.logger.debug(f"after being transformed, record `{rec}` has `{len(df_sleep_event_ann)}` sleep event(s)") + self.logger.debug(f"after being transformed, record `{rec}` has `{len(df_sleep_event_ann)}` sleep event(s)") # type: ignore elif source.lower() == "event": if event_types is None: event_types = ["respiratory", "arousal"] @@ -1570,7 +1572,7 @@ def load_sleep_event_ann( _cols = _cols | set(self.long_event_names_from_event[3:4]) _cols = list(_cols) - self.logger.debug(f"for record `{rec}`, _cols = `{_cols}`") + self.logger.debug(f"for record `{rec}`, _cols = `{_cols}`") # type: ignore df_sleep_event_ann = df_sleep_ann[df_sleep_ann["EventConcept"].isin(_cols)].reset_index(drop=True) df_sleep_event_ann = df_sleep_event_ann.rename( @@ -1631,7 +1633,7 @@ def load_sleep_event_ann( _cols = _cols | set(self.event_names_from_event_profusion[3:4]) _cols = list(_cols) - self.logger.debug(f"for record `{rec}`, _cols = `{_cols}`") + self.logger.debug(f"for record `{rec}`, _cols = `{_cols}`") # type: ignore df_sleep_event_ann = df_sleep_ann[df_sleep_ann["Name"].isin(_cols)].reset_index(drop=True) df_sleep_event_ann = df_sleep_event_ann.rename( @@ -1649,7 +1651,7 @@ def load_sleep_event_ann( else: raise ValueError(f"Source `{source}` not supported, " "only `hrv`, `event`, `event_profusion` are supported") - return df_sleep_event_ann + return df_sleep_event_ann # type: ignore def load_apnea_ann( self, @@ -1727,7 +1729,7 @@ def load_wave_delineation_ann( file_path = self.get_absolute_path(rec, wave_deli_path, rec_type="wave_delineation") if not file_path.is_file(): - self.logger.debug( + self.logger.debug( # type: ignore f"The annotation file of wave delineation of record `{rec}` has not been downloaded yet. " f"Or the path `{str(file_path)}` is not correct. " f"Or `{rec}` does not have `rpeak.csv` annotation file. Please check!" @@ -1746,7 +1748,7 @@ def load_rpeak_ann( exclude_abnormal_beats: bool = True, units: Optional[Literal["s", "ms"]] = None, **kwargs: Any, - ) -> np.ndarray: + ) -> NDArray: """Load annotations on R peaks of the record. 
Parameters @@ -1788,7 +1790,7 @@ def load_rpeak_ann( rpeaks = df_rpeaks_with_type_info[~df_rpeaks_with_type_info["Type"].isin(exclude_beat_types)]["rpointadj"].values if units is None: - rpeaks = (np.round(rpeaks)).astype(int) + rpeaks = (np.round(rpeaks)).astype(int) # type: ignore elif units.lower() == "s": fs = df_rpeaks_with_type_info.iloc[0]["samplingrate"] rpeaks = rpeaks / fs @@ -1811,7 +1813,7 @@ def load_rr_ann( rpeak_ann_path: Optional[Union[str, bytes, os.PathLike]] = None, units: Literal["s", "ms", None] = "s", **kwargs: Any, - ) -> np.ndarray: + ) -> NDArray: """Load annotations on RR intervals of the record. Parameters @@ -1852,7 +1854,7 @@ def load_nn_ann( rpeak_ann_path: Optional[Union[str, bytes, os.PathLike]] = None, units: Union[str, None] = "s", **kwargs: Any, - ) -> np.ndarray: + ) -> NDArray: """Load annotations on NN intervals of the record. Parameters @@ -1912,7 +1914,7 @@ def locate_artifacts( rec: Union[str, int], wave_deli_path: Optional[Union[str, bytes, os.PathLike]] = None, units: Optional[Literal["s", "ms"]] = None, - ) -> np.ndarray: + ) -> NDArray: """Locate "artifacts" in the record. Parameters @@ -1941,7 +1943,7 @@ def locate_artifacts( return np.array([], dtype=dtype) # df_rpeaks_with_type_info = df_rpeaks_with_type_info[["Type", "rpointadj"]] - artifacts = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 0]["rpointadj"].values)).astype(int) + artifacts = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 0]["rpointadj"].values)).astype(int) # type: ignore if units is not None: fs = df_rpeaks_with_type_info.iloc[0]["samplingrate"] @@ -1965,7 +1967,7 @@ def locate_abnormal_beats( wave_deli_path: Optional[Union[str, bytes, os.PathLike]] = None, abnormal_type: Optional[Literal["VE", "SVE"]] = None, units: Optional[Literal["s", "ms"]] = None, - ) -> Union[Dict[str, np.ndarray], np.ndarray]: + ) -> Union[Dict[str, NDArray], NDArray, None]: """Locate "abnormal beats" in the record. Parameters @@ -2005,8 +2007,8 @@ def locate_abnormal_beats( if not df_rpeaks_with_type_info.empty: # df_rpeaks_with_type_info = df_rpeaks_with_type_info[["Type", "rpointadj"]] # 2 = VE, 3 = SVE - ve = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 2]["rpointadj"].values)).astype(int) - sve = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 3]["rpointadj"].values)).astype(int) + ve = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 2]["rpointadj"].values)).astype(int) # type: ignore + sve = (np.round(df_rpeaks_with_type_info[df_rpeaks_with_type_info["Type"] == 3]["rpointadj"].values)).astype(int) # type: ignore abnormal_rpeaks = {"VE": ve, "SVE": sve} else: dtype = int if units is None or units.lower() != "s" else float @@ -2043,7 +2045,7 @@ def load_eeg_band_ann( rec: Union[str, int], eeg_band_ann_path: Optional[Union[str, bytes, os.PathLike]] = None, **kwargs: Any, - ) -> pd.DataFrame: + ) -> pd.DataFrame: # type: ignore """Load annotations on EEG bands of the record. 
Parameters @@ -2062,7 +2064,7 @@ def load_eeg_band_ann( """ if self.current_version >= "0.15.0": - self.logger.info(f"EEG spectral summary variables are removed in version {self.current_version}") + self.logger.info(f"EEG spectral summary variables are removed in version {self.current_version}") # type: ignore else: raise NotImplementedError @@ -2071,7 +2073,7 @@ def load_eeg_spectral_ann( rec: Union[str, int], eeg_spectral_ann_path: Optional[Union[str, bytes, os.PathLike]] = None, **kwargs: Any, - ) -> pd.DataFrame: + ) -> pd.DataFrame: # type: ignore """Load annotations on EEG spectral summary of the record. Parameters @@ -2090,7 +2092,7 @@ def load_eeg_spectral_ann( """ if self.current_version >= "0.15.0": - self.logger.info(f"EEG spectral summary variables are removed in version {self.current_version}") + self.logger.info(f"EEG spectral summary variables are removed in version {self.current_version}") # type: ignore else: raise NotImplementedError @@ -2215,8 +2217,8 @@ def _plot_ann( raise NotImplementedError("Plotting of some type of events in `df_sleep_event` has not been implemented yet!") if plot_format.lower() == "hypnogram": - stage_mask = df_sleep_stage["sleep_stage"].values - stage_mask = len(self.sleep_stage_names) - 1 - stage_mask + stage_mask = df_sleep_stage["sleep_stage"].values # type: ignore + stage_mask = len(self.sleep_stage_names) - 1 - stage_mask # type: ignore fig, ax = self.plot_hypnogram(stage_mask, granularity=30) return @@ -2238,42 +2240,42 @@ def _plot_ann( ax_events.set_xlabel("Time", fontsize=16) # ax_events.set_ylabel("Events", fontsize=16) else: - ax_stages, ax_events = axes + ax_stages, ax_events = axes # type: ignore ax_stages.set_title("Sleep Stages and Events", fontsize=24) ax_events.set_xlabel("Time", fontsize=16) if ax_stages is not None: - for k, v in sleep_stages.items(): + for k, v in sleep_stages.items(): # type: ignore for itv in v: ax_stages.axvspan( - datetime.fromtimestamp(itv[0]), - datetime.fromtimestamp(itv[1]), + datetime.fromtimestamp(itv[0]), # type: ignore + datetime.fromtimestamp(itv[1]), # type: ignore color=self.palette[k], alpha=plot_alpha, ) ax_stages.legend( - handles=[patches[k] for k in self.all_sleep_stage_names if k in sleep_stages.keys()], + handles=[patches[k] for k in self.all_sleep_stage_names if k in sleep_stages.keys()], # type: ignore loc="best", ) # keep ordering plt.setp(ax_stages.get_yticklabels(), visible=False) ax_stages.tick_params(axis="y", which="both", length=0) if ax_events is not None: - for _, row in df_sleep_event.iterrows(): + for _, row in df_sleep_event.iterrows(): # type: ignore ax_events.axvspan( - datetime.fromtimestamp(row["event_start"]), - datetime.fromtimestamp(row["event_end"]), + datetime.fromtimestamp(row["event_start"]), # type: ignore + datetime.fromtimestamp(row["event_end"]), # type: ignore color=self.palette[row["event_name"]], alpha=plot_alpha, ) ax_events.legend( - handles=[patches[k] for k in current_legal_events if k in set(df_sleep_event["event_name"])], + handles=[patches[k] for k in current_legal_events if k in set(df_sleep_event["event_name"])], # type: ignore loc="best", ) # keep ordering plt.setp(ax_events.get_yticklabels(), visible=False) ax_events.tick_params(axis="y", which="both", length=0) - def str_to_real_number(self, s: Union[str, Real]) -> Real: + def str_to_real_number(self, s: Union[str, float, int]) -> Union[float, int]: """Convert a string to a real number. 
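The `_plot_ann` hunks above shade sleep stages and events by calling `axvspan` with `datetime.fromtimestamp` bounds. A small standalone sketch of that idea with toy intervals and placeholder colors (the real code takes both from the class attributes shown above):

    from datetime import datetime

    import matplotlib.pyplot as plt

    stages = {  # toy data: stage name -> list of (start, end) unix timestamps
        "W": [(1_600_000_000, 1_600_000_600)],
        "N2": [(1_600_000_600, 1_600_002_400)],
        "R": [(1_600_002_400, 1_600_003_000)],
    }
    palette = {"W": "orange", "N2": "lightblue", "R": "green"}  # placeholder colors

    fig, ax = plt.subplots(figsize=(12, 2))
    for stage, intervals in stages.items():
        for start, end in intervals:
            ax.axvspan(
                datetime.fromtimestamp(start),
                datetime.fromtimestamp(end),
                color=palette[stage],
                alpha=0.5,
            )
    ax.set_xlabel("Time")
    ax.set_yticks([])
    plt.show()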
Some columns in the annotations might incorrectly @@ -2560,7 +2562,7 @@ def folder_or_file(self) -> Dict[str, Path]: "wave_delineation": self.wave_deli_path, "event": self.event_ann_path, "event_profusion": self.event_profusion_ann_path, - } + } # type: ignore @property def url(self) -> str: diff --git a/torch_ecg/databases/other_databases/cachet_cadb.py b/torch_ecg/databases/other_databases/cachet_cadb.py index 22726d67..6855c475 100644 --- a/torch_ecg/databases/other_databases/cachet_cadb.py +++ b/torch_ecg/databases/other_databases/cachet_cadb.py @@ -13,6 +13,7 @@ import numpy as np import pandas as pd import scipy.signal as SS +from numpy.typing import NDArray from ...cfg import DEFAULTS from ...utils.download import http_get @@ -349,7 +350,7 @@ def load_data( units: Union[str, type(None)] = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: """Load physical (converted from digital) ECG data, or load digital signal directly. @@ -434,7 +435,7 @@ def load_context_data( channels: Optional[Union[str, int, List[str], List[int]]] = None, units: Optional[str] = None, fs: Optional[Real] = None, - ) -> Union[np.ndarray, pd.DataFrame]: + ) -> Union[NDArray, pd.DataFrame]: """Load context data (e.g. accelerometer, heart rate, etc.). Parameters @@ -544,7 +545,7 @@ def load_context_data( def load_ann( self, rec: Union[str, int], ann_format: str = "pd" - ) -> Union[pd.DataFrame, np.ndarray, Dict[Union[int, str], np.ndarray]]: + ) -> Union[pd.DataFrame, NDArray, Dict[Union[int, str], NDArray]]: """Load annotation from the metadata file. Parameters diff --git a/torch_ecg/databases/other_databases/sph.py b/torch_ecg/databases/other_databases/sph.py index 0e566a03..5aabc0d3 100644 --- a/torch_ecg/databases/other_databases/sph.py +++ b/torch_ecg/databases/other_databases/sph.py @@ -9,6 +9,7 @@ import h5py import numpy as np import pandas as pd +from numpy.typing import NDArray from ...cfg import DEFAULTS from ...utils import EAK @@ -172,7 +173,7 @@ def load_data( data_format: str = "channel_first", units: str = "mV", return_fs: bool = False, - ) -> np.ndarray: + ) -> NDArray: """Load ECG data from h5 file of the record. Parameters @@ -233,10 +234,9 @@ def load_ann(self, rec: Union[str, int], ann_format: str = "c", ignore_modifier: Record name or index of the record in :attr:`all_records`. ann_format : str, default "a" Format of labels, one of the following (case insensitive): - - - "a": abbreviations - - "f": full names - - "c": AHACode + - "a": abbreviations + - "f": full names + - "c": AHACode ignore_modifier : bool, default True Whether to ignore the modifiers of the annotations or not. For example, "60+310" will be converted to "60". @@ -390,7 +390,7 @@ def download(self, files: Optional[Union[str, Sequence[str]]]) -> None: def plot( self, rec: Union[str, int], - data: Optional[np.ndarray] = None, + data: Optional[NDArray] = None, ann: Optional[Sequence[str]] = None, ticks_granularity: int = 0, leads: Optional[Union[str, int, List[Union[str, int]]]] = None, diff --git a/torch_ecg/databases/physionet_databases/afdb.py b/torch_ecg/databases/physionet_databases/afdb.py index 5adf2122..e67a7cda 100644 --- a/torch_ecg/databases/physionet_databases/afdb.py +++ b/torch_ecg/databases/physionet_databases/afdb.py @@ -8,6 +8,7 @@ import numpy as np import wfdb +from numpy.typing import NDArray from ...cfg import CFG from ...utils.misc import add_docstring, get_record_list_recursive @@ -29,10 +30,10 @@ 3. 
signals are sampled at 250 samples per second with 12-bit resolution over a range of ±10 millivolts, with a typical recording bandwidth of approximately 0.1 Hz to 40 Hz 4. 4 classes of rhythms are annotated: - - AFIB: atrial fibrillation - - AFL: atrial flutter - - J: AV junctional rhythm - - N: all other rhythms + - AFIB: atrial fibrillation + - AFL: atrial flutter + - J: AV junctional rhythm + - N: all other rhythms 5. rhythm annotations almost all start with "(N", except for 4 which start with '(AFIB', which are all within 1 second (250 samples) 6. Webpage of the database on PhysioNet [1]_. Paper describing the database [2]_. @@ -138,7 +139,7 @@ def load_ann( sampto: Optional[int] = None, ann_format: Literal["intervals", "mask"] = "intervals", keep_original: bool = False, - ) -> Union[Dict[str, list], np.ndarray]: + ) -> Union[Dict[str, list], NDArray]: """Load annotations (header) from the .hea files. Parameters @@ -203,7 +204,7 @@ def load_beat_ann( sampto: Optional[int] = None, use_manual: bool = True, keep_original: bool = False, - ) -> np.ndarray: + ) -> NDArray: """Load beat annotations from corresponding annotation files. Parameters @@ -253,7 +254,7 @@ def load_rpeak_indices( sampto: Optional[int] = None, use_manual: bool = True, keep_original: bool = False, - ) -> np.ndarray: + ) -> NDArray: """ alias of `self.load_beat_ann` """ @@ -262,9 +263,9 @@ def load_rpeak_indices( def plot( self, rec: Union[str, int], - data: Optional[np.ndarray] = None, - ann: Optional[Dict[str, np.ndarray]] = None, - rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None, + data: Optional[NDArray] = None, + ann: Optional[Dict[str, NDArray]] = None, + rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None, ticks_granularity: int = 0, leads: Optional[Union[str, int, List[str], List[int]]] = None, sampfrom: Optional[int] = None, diff --git a/torch_ecg/databases/physionet_databases/apnea_ecg.py b/torch_ecg/databases/physionet_databases/apnea_ecg.py index 1fd07a37..28677d3e 100644 --- a/torch_ecg/databases/physionet_databases/apnea_ecg.py +++ b/torch_ecg/databases/physionet_databases/apnea_ecg.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd import wfdb +from numpy.typing import NDArray from ...cfg import DEFAULTS from ...utils import add_docstring @@ -200,7 +201,7 @@ def load_data( units: Union[str, type(None)] = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: return super().load_data(rec, leads, sampfrom, sampto, data_format, units, fs, return_fs) @add_docstring(PhysioNetDataBase.load_data.__doc__) @@ -213,7 +214,7 @@ def load_ecg_data( units: Union[str, type(None)] = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: if isinstance(rec, int): rec = self[rec] if rec not in self.ecg_records: @@ -244,7 +245,7 @@ def load_rsp_data( units: Union[str, type(None)] = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> np.ndarray: + ) -> NDArray: if rec not in self.rsp_records: raise ValueError(f"`{rec}` is not a record of RSP signals") data = self.load_data( diff --git a/torch_ecg/databases/physionet_databases/cinc2017.py b/torch_ecg/databases/physionet_databases/cinc2017.py index 799fc7e2..aec1ed8b 100644 --- a/torch_ecg/databases/physionet_databases/cinc2017.py +++ b/torch_ecg/databases/physionet_databases/cinc2017.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd 
+from numpy.typing import NDArray from ...cfg import DEFAULTS from ...utils.misc import add_docstring, get_record_list_recursive3 @@ -208,9 +209,8 @@ def load_ann(self, rec: Union[str, int], version: Optional[int] = None, ann_form Version of the annotation file, by default the latest version. ann_format : {"a", "f"}, optional Format of returned annotation, by default "a". - - - "a" - abbreviation - - "f" - full name + - "a", abbreviation + - "f", full name Returns ------- @@ -236,10 +236,10 @@ def load_ann(self, rec: Union[str, int], version: Optional[int] = None, ann_form def plot( self, rec: Union[str, int], - data: Optional[np.ndarray] = None, + data: Optional[NDArray] = None, ann: Optional[str] = None, ticks_granularity: int = 0, - rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None, + rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None, ) -> None: """Plot the ECG signal of the record. diff --git a/torch_ecg/databases/physionet_databases/cinc2018.py b/torch_ecg/databases/physionet_databases/cinc2018.py index b0470907..a7460b89 100644 --- a/torch_ecg/databases/physionet_databases/cinc2018.py +++ b/torch_ecg/databases/physionet_databases/cinc2018.py @@ -6,10 +6,10 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -import numpy as np import pandas as pd import scipy.signal as SS import wfdb +from numpy.typing import NDArray from tqdm.auto import tqdm from ...cfg import DEFAULTS @@ -96,9 +96,8 @@ class CINC2018(PhysioNetDataBase, PSGDataBaseMixin): Level of logging verbosity. kwargs : dict, optional Auxilliary key word arguments, including: - - - `subset` : {"training", "test"}, default "training" - The subset of the database to use. + - `subset` : {"training", "test"}, default "training". + The subset of the database to use. """ @@ -308,7 +307,7 @@ def load_psg_data( physical: bool = True, fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: """Load PSG data of the record. Parameters @@ -407,7 +406,7 @@ def load_data( units: Union[str, type(None)] = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: """Load ECG data of the record. Parameters @@ -479,7 +478,7 @@ def load_ecg_data( units: Union[str, type(None)] = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: """alias of `load_data`""" return self.load_data( rec=rec, diff --git a/torch_ecg/databases/physionet_databases/cinc2020.py b/torch_ecg/databases/physionet_databases/cinc2020.py index c4794630..f296bdbd 100644 --- a/torch_ecg/databases/physionet_databases/cinc2020.py +++ b/torch_ecg/databases/physionet_databases/cinc2020.py @@ -16,6 +16,7 @@ import pandas as pd import scipy.signal as SS import wfdb +from numpy.typing import NDArray from scipy.io import loadmat from ...cfg import CFG, DEFAULTS @@ -501,7 +502,7 @@ def load_data( units: Literal["mV", "μV", "uV", None] = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: """Load physical (converted from digital) ECG data, which is more understandable for humans; or load digital signal directly. @@ -858,10 +859,9 @@ def get_labels( in the CINC2020 official phase. 
fmt : str, default "s" Format of labels, one of the following (case insensitive): - - - "a", abbreviations - - "f", full names - - "s", SNOMED CT Code + - "a", abbreviations + - "f", full names + - "s", SNOMED CT Code normalize : bool, default True If True, the labels will be transformed into their equavalents, @@ -945,8 +945,8 @@ def get_subject_info(self, rec: Union[str, int], items: Optional[List[str]] = No def plot( self, rec: Union[str, int], - data: Optional[np.ndarray] = None, - ann: Optional[Dict[str, np.ndarray]] = None, + data: Optional[NDArray] = None, + ann: Optional[Dict[str, NDArray]] = None, ticks_granularity: int = 0, leads: Optional[Union[str, List[str]]] = None, same_range: bool = False, @@ -1193,7 +1193,7 @@ def load_resampled_data( rec: Union[str, int], data_format: str = "channel_first", siglen: Optional[int] = None, - ) -> np.ndarray: + ) -> NDArray: """ Resample the data of `rec` to 500Hz, or load the resampled data in 500Hz, if the corr. data file already exists @@ -1248,7 +1248,7 @@ def load_resampled_data( data = np.moveaxis(data, -1, -2) return data - def load_raw_data(self, rec: Union[str, int], backend: Literal["scipy", "wfdb"] = "scipy") -> np.ndarray: + def load_raw_data(self, rec: Union[str, int], backend: Literal["scipy", "wfdb"] = "scipy") -> NDArray: """Load raw data from corresponding files with no further processing, in order to facilitate feeding data into the `run_12ECG_classifier` function. @@ -1406,7 +1406,7 @@ def compute_all_metrics(classes: List[str], truth: Sequence, binary_pred: Sequen ) -def _compute_accuracy(labels: np.ndarray, outputs: np.ndarray) -> float: +def _compute_accuracy(labels: NDArray, outputs: NDArray) -> float: """Compute recording-wise accuracy. Parameters @@ -1434,7 +1434,7 @@ def _compute_accuracy(labels: np.ndarray, outputs: np.ndarray) -> float: return float(num_correct_recordings) / float(num_recordings) -def _compute_confusion_matrices(labels: np.ndarray, outputs: np.ndarray, normalize: bool = False) -> np.ndarray: +def _compute_confusion_matrices(labels: NDArray, outputs: NDArray, normalize: bool = False) -> NDArray: """Compute confusion matrices. Compute a binary confusion matrix for each class k: @@ -1498,7 +1498,7 @@ def _compute_confusion_matrices(labels: np.ndarray, outputs: np.ndarray, normali return A -def _compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> float: +def _compute_f_measure(labels: NDArray, outputs: NDArray) -> float: """Compute macro-averaged F1 score. Parameters @@ -1533,7 +1533,7 @@ def _compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> float: return macro_f_measure -def _compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real) -> Tuple[float, float]: +def _compute_beta_measures(labels: NDArray, outputs: NDArray, beta: Real) -> Tuple[float, float]: """Compute F-beta and G-beta measures. Parameters @@ -1578,7 +1578,7 @@ def _compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real) return macro_f_beta_measure, macro_g_beta_measure -def _compute_auc(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, float]: +def _compute_auc(labels: NDArray, outputs: NDArray) -> Tuple[float, float]: """Compute macro-averaged AUROC and macro-averaged AUPRC. 
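`_compute_confusion_matrices` and `_compute_f_measure`, retyped to `NDArray` in the hunks above, operate on binary multilabel arrays of shape (n_recordings, n_classes). A rough standalone sketch of the quantities they produce; edge-case handling in the repository may differ:

    # Per-class binary confusion-matrix entries and macro-averaged F1.
    import numpy as np

    labels = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])    # ground truth
    outputs = np.array([[1, 0, 0], [0, 1, 1], [1, 0, 0]])   # binarized predictions

    tp = np.sum((labels == 1) & (outputs == 1), axis=0)
    fp = np.sum((labels == 0) & (outputs == 1), axis=0)
    fn = np.sum((labels == 1) & (outputs == 0), axis=0)

    # per-class F1, guarding against classes with no positives or predictions
    denom = 2 * tp + fp + fn
    f1 = np.where(denom > 0, 2 * tp / np.maximum(denom, 1), np.nan)
    macro_f1 = np.nanmean(f1)
    print(f1, macro_f1)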
Parameters @@ -1674,7 +1674,7 @@ def _compute_auc(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, float] return macro_auroc, macro_auprc -def _compute_modified_confusion_matrix(labels: np.ndarray, outputs: np.ndarray) -> np.ndarray: +def _compute_modified_confusion_matrix(labels: NDArray, outputs: NDArray) -> NDArray: """ Compute a binary multi-class, multi-label confusion matrix, where the rows are the labels and the columns are the outputs. @@ -1712,9 +1712,9 @@ def _compute_modified_confusion_matrix(labels: np.ndarray, outputs: np.ndarray) def compute_challenge_metric( - weights: np.ndarray, - labels: np.ndarray, - outputs: np.ndarray, + weights: NDArray, + labels: NDArray, + outputs: NDArray, classes: List[str], normal_class: str, ) -> float: diff --git a/torch_ecg/databases/physionet_databases/cinc2021.py b/torch_ecg/databases/physionet_databases/cinc2021.py index 22088a9b..51a57344 100644 --- a/torch_ecg/databases/physionet_databases/cinc2021.py +++ b/torch_ecg/databases/physionet_databases/cinc2021.py @@ -17,6 +17,7 @@ import pandas as pd import scipy.signal as SS import wfdb +from numpy.typing import NDArray from scipy.io import loadmat from tqdm.auto import tqdm @@ -88,8 +89,8 @@ Each recording is 10 seconds long with a sampling frequency of 500 Hz this tranche contains two subsets: - - Chapman_Shaoxing: "JS00001" - "JS10646" - - Ningbo: "JS10647" - "JS45551" + - Chapman_Shaoxing: "JS00001" - "JS10646" + - Ningbo: "JS10647" - "JS45551" All files can be downloaded from [8]_ or [9]_. @@ -115,6 +116,7 @@ ... leads = ann["df_leads"]["lead_name"].values.tolist() ... if leads not in set_leads: ... set_leads.append(leads) + 5. Challenge official website [1]_. Webpage of the database on PhysioNet [2]_. """, @@ -670,7 +672,7 @@ def load_data( units: Literal["mV", "μV", "uV", "muV", None] = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: """Load physical (converted from digital) ECG data. Parameters @@ -1049,10 +1051,9 @@ def get_labels( in the CINC2021 official phase. fmt : str, default "s" Format of labels, one of the following (case insensitive): - - - "a", abbreviations - - "f", full names - - "s", SNOMED CT Code + - "a", abbreviations + - "f", full names + - "s", SNOMED CT Code normalize : bool, default True If True, the labels will be transformed into their equavalents, @@ -1153,7 +1154,7 @@ def get_subject_info(self, rec: Union[str, int], items: Optional[List[str]] = No def plot( self, rec: Union[str, int], - data: Optional[np.ndarray] = None, + data: Optional[NDArray] = None, ann: Optional[Dict[str, Sequence[str]]] = None, ticks_granularity: int = 0, leads: Optional[Union[str, Sequence[str]]] = None, @@ -1402,7 +1403,7 @@ def load_resampled_data( leads: Optional[Union[str, List[str]]] = None, data_format: str = "channel_first", siglen: Optional[int] = None, - ) -> np.ndarray: + ) -> NDArray: """ Resample the data of `rec` to 500Hz, or load the resampled data in 500Hz, if the corr. data file already exists @@ -1472,7 +1473,7 @@ def load_resampled_data( data = np.moveaxis(data, -1, -2) return data - def load_raw_data(self, rec: Union[str, int], backend: Literal["wfdb", "scipy"] = "scipy") -> np.ndarray: + def load_raw_data(self, rec: Union[str, int], backend: Literal["wfdb", "scipy"] = "scipy") -> NDArray: """Load raw data from corresponding files with no further processing. This method facilitates feeding data into the `run_12ECG_classifier` function. 
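Across these hunks, `np.ndarray` annotations are replaced with `numpy.typing.NDArray`. A minimal illustration of the alias (general NumPy typing, not code from the repository; the helper name is made up):

    # NDArray is a generic alias for np.ndarray that can optionally carry a dtype.
    import numpy as np
    from numpy.typing import NDArray


    def to_millivolts(sig: NDArray, adc_gain: float) -> NDArray[np.float64]:
        """Toy helper: convert a digital signal to physical units."""
        return np.asarray(sig, dtype=np.float64) / adc_gain


    print(to_millivolts(np.array([100, 200, 300]), adc_gain=200.0))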
@@ -1700,7 +1701,7 @@ def database_info(self) -> DataBaseInfo: def compute_all_metrics_detailed( classes: List[str], truth: Sequence, binary_pred: Sequence, scalar_pred: Sequence -) -> Tuple[Union[float, np.ndarray]]: +) -> Tuple[Union[float, NDArray]]: """Compute detailed metrics for each class. Parameters @@ -1781,7 +1782,7 @@ def compute_all_metrics_detailed( def compute_all_metrics( classes: List[str], truth: Sequence, binary_pred: Sequence, scalar_pred: Sequence -) -> Tuple[Union[float, np.ndarray]]: +) -> Tuple[Union[float, NDArray]]: """Simplified version of :func:`compute_all_metrics_detailed`. This function doesnot produce per-class scores. @@ -1841,7 +1842,7 @@ def compute_all_metrics( ) -def _compute_accuracy(labels: np.ndarray, outputs: np.ndarray) -> float: +def _compute_accuracy(labels: NDArray, outputs: NDArray) -> float: """Compute recording-wise accuracy. Parameters @@ -1869,7 +1870,7 @@ def _compute_accuracy(labels: np.ndarray, outputs: np.ndarray) -> float: return float(num_correct_recordings) / float(num_recordings) -def _compute_confusion_matrices(labels: np.ndarray, outputs: np.ndarray, normalize: bool = False) -> np.ndarray: +def _compute_confusion_matrices(labels: NDArray, outputs: NDArray, normalize: bool = False) -> NDArray: """Compute confusion matrices. Compute a binary confusion matrix for each class k: @@ -1933,7 +1934,7 @@ def _compute_confusion_matrices(labels: np.ndarray, outputs: np.ndarray, normali return A -def _compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, np.ndarray]: +def _compute_f_measure(labels: NDArray, outputs: NDArray) -> Tuple[float, NDArray]: """Compute macro-averaged F1 score, and F1 score per class. Parameters @@ -1973,7 +1974,7 @@ def _compute_f_measure(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, return macro_f_measure, f_measure -def _compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real) -> Tuple[float, float]: +def _compute_beta_measures(labels: NDArray, outputs: NDArray, beta: Real) -> Tuple[float, float]: """Compute F-beta and G-beta measures. Parameters @@ -2018,7 +2019,7 @@ def _compute_beta_measures(labels: np.ndarray, outputs: np.ndarray, beta: Real) return macro_f_beta_measure, macro_g_beta_measure -def _compute_auc(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, float, np.ndarray, np.ndarray]: +def _compute_auc(labels: NDArray, outputs: NDArray) -> Tuple[float, float, NDArray, NDArray]: """Compute macro-averaged AUROC and macro-averaged AUPRC. Parameters @@ -2125,7 +2126,7 @@ def _compute_auc(labels: np.ndarray, outputs: np.ndarray) -> Tuple[float, float, # Compute modified confusion matrix for multi-class, multi-label tasks. -def _compute_modified_confusion_matrix(labels: np.ndarray, outputs: np.ndarray) -> np.ndarray: +def _compute_modified_confusion_matrix(labels: NDArray, outputs: NDArray) -> NDArray: """ Compute a binary multi-class, multi-label confusion matrix, where the rows are the labels and the columns are the outputs. @@ -2165,9 +2166,9 @@ def _compute_modified_confusion_matrix(labels: np.ndarray, outputs: np.ndarray) # Compute the evaluation metric for the Challenge. 
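The comment above introduces `compute_challenge_metric`. As a hedged reminder of the normalization the challenge score uses (formula as described in the CinC challenge material; the repository builds the observed score from a weighted, modified confusion matrix):

    # Normalization of the CinC challenge score (sketch, not the repo code):
    # 1.0 for a perfect classifier, 0.0 for one that always predicts the
    # normal / sinus-rhythm class.
    def normalize_challenge_score(observed: float, correct: float, inactive: float) -> float:
        if correct == inactive:  # degenerate case: avoid division by zero
            return float("nan")
        return (observed - inactive) / (correct - inactive)


    print(normalize_challenge_score(observed=42.0, correct=50.0, inactive=10.0))  # 0.8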
def compute_challenge_metric( - weights: np.ndarray, - labels: np.ndarray, - outputs: np.ndarray, + weights: NDArray, + labels: NDArray, + outputs: NDArray, classes: List[str], sinus_rhythm: str, ) -> float: diff --git a/torch_ecg/databases/physionet_databases/ltafdb.py b/torch_ecg/databases/physionet_databases/ltafdb.py index 44638d34..4d458ef7 100644 --- a/torch_ecg/databases/physionet_databases/ltafdb.py +++ b/torch_ecg/databases/physionet_databases/ltafdb.py @@ -9,6 +9,7 @@ import numpy as np import wfdb +from numpy.typing import NDArray from ...cfg import CFG from ...utils.misc import add_docstring @@ -160,7 +161,7 @@ def load_data( units: str = "mV", fs: Optional[Real] = None, return_fs: bool = False, - ) -> Union[np.ndarray, Tuple[np.ndarray, Real]]: + ) -> Union[NDArray, Tuple[NDArray, Real]]: return super().load_data(rec, leads, sampfrom, sampto, data_format, units, fs, return_fs) def load_ann( @@ -234,7 +235,7 @@ def load_rhythm_ann( sampto: Optional[int] = None, rhythm_format: Literal["intervals", "mask"] = "intervals", keep_original: bool = False, - ) -> Union[Dict[str, list], np.ndarray]: + ) -> Union[Dict[str, list], NDArray]: """Load rhythm annotations of the record. Rhythm annotations are stored in the `aux_note` attribute @@ -323,7 +324,7 @@ def load_beat_ann( sampto: Optional[int] = None, beat_format: Literal["beat", "dict"] = "beat", keep_original: bool = False, - ) -> Union[Dict[str, np.ndarray], List[BeatAnn]]: + ) -> Union[Dict[str, NDArray], List[BeatAnn]]: """Load beat annotations of the record. Beat annotations are stored in the `symbol` attribute @@ -390,7 +391,7 @@ def load_rpeak_indices( sampto: Optional[int] = None, use_manual: bool = True, keep_original: bool = False, - ) -> np.ndarray: + ) -> NDArray: """Load rpeak indices of the record. Rpeak indices, or equivalently qrs complex locations, @@ -437,10 +438,10 @@ def load_rpeak_indices( def plot( self, rec: Union[str, int], - data: Optional[np.ndarray] = None, - ann: Optional[Dict[str, np.ndarray]] = None, - beat_ann: Optional[Dict[str, np.ndarray]] = None, - rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None, + data: Optional[NDArray] = None, + ann: Optional[Dict[str, NDArray]] = None, + beat_ann: Optional[Dict[str, NDArray]] = None, + rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None, ticks_granularity: int = 0, leads: Optional[Union[int, List[int]]] = None, sampfrom: Optional[int] = None, diff --git a/torch_ecg/databases/physionet_databases/ludb.py b/torch_ecg/databases/physionet_databases/ludb.py index 994e79d8..89fd86dc 100644 --- a/torch_ecg/databases/physionet_databases/ludb.py +++ b/torch_ecg/databases/physionet_databases/ludb.py @@ -2,13 +2,13 @@ import os from copy import deepcopy -from numbers import Real from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import pandas as pd import wfdb +from numpy.typing import NDArray from ...cfg import CFG, DEFAULTS from ...utils import EAK @@ -293,8 +293,6 @@ def __init__( verbose=verbose, **kwargs, ) - if self.version == "1.0.0": - self.logger.info("Version of LUDB 1.0.0 has bugs, make sure that version 1.0.1 or higher is used") self.fs = 500 self.spacing = 1000 / self.fs self.data_ext = "dat" @@ -320,7 +318,7 @@ def __init__( self._df_subject_info = None self._ls_rec() - def _ls_rec(self) -> None: + def _ls_rec(self) -> None: # type: ignore """Find all records in the database directory and store them (path, metadata, etc.) in some private attributes. 
""" @@ -389,7 +387,7 @@ def get_absolute_path(self, rec: Union[str, int], extension: Optional[str] = Non extension = f".{extension}" return self.db_dir / "data" / f"{rec}{extension or ''}" - def load_ann( + def load_ann( # type: ignore self, rec: Union[str, int], leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None, @@ -430,28 +428,28 @@ def load_ann( df_lead_ann["onset"] = np.nan df_lead_ann["offset"] = np.nan for i, row in df_lead_ann.iterrows(): - peak_idx = peak_inds[i] + peak_idx = peak_inds[i] # type: ignore if peak_idx == 0: - df_lead_ann.loc[i, "onset"] = row["peak"] + df_lead_ann.loc[i, "onset"] = row["peak"] # type: ignore if symbols[peak_idx + 1] == ")": - df_lead_ann.loc[i, "offset"] = ann.sample[peak_idx + 1] + df_lead_ann.loc[i, "offset"] = ann.sample[peak_idx + 1] # type: ignore else: - df_lead_ann.loc[i, "offset"] = row["peak"] + df_lead_ann.loc[i, "offset"] = row["peak"] # type: ignore elif peak_idx == len(symbols) - 1: - df_lead_ann.loc[i, "offset"] = row["peak"] + df_lead_ann.loc[i, "offset"] = row["peak"] # type: ignore if symbols[peak_idx - 1] == "(": - df_lead_ann.loc[i, "onset"] = ann.sample[peak_idx - 1] + df_lead_ann.loc[i, "onset"] = ann.sample[peak_idx - 1] # type: ignore else: - df_lead_ann.loc[i, "onset"] = row["peak"] + df_lead_ann.loc[i, "onset"] = row["peak"] # type: ignore else: if symbols[peak_idx - 1] == "(": - df_lead_ann.loc[i, "onset"] = ann.sample[peak_idx - 1] + df_lead_ann.loc[i, "onset"] = ann.sample[peak_idx - 1] # type: ignore else: - df_lead_ann.loc[i, "onset"] = row["peak"] + df_lead_ann.loc[i, "onset"] = row["peak"] # type: ignore if symbols[peak_idx + 1] == ")": - df_lead_ann.loc[i, "offset"] = ann.sample[peak_idx + 1] + df_lead_ann.loc[i, "offset"] = ann.sample[peak_idx + 1] # type: ignore else: - df_lead_ann.loc[i, "offset"] = row["peak"] + df_lead_ann.loc[i, "offset"] = row["peak"] # type: ignore # df_lead_ann["onset"] = ann.sample[np.where(symbols=="(")[0]] # df_lead_ann["offset"] = ann.sample[np.where(symbols==")")[0]] @@ -501,7 +499,7 @@ def load_masks( leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None, mask_format: str = "channel_first", class_map: Optional[Dict[str, int]] = None, - ) -> np.ndarray: + ) -> NDArray: """Load the wave delineation in the form of masks. Parameters @@ -543,11 +541,11 @@ def load_masks( def from_masks( self, - masks: np.ndarray, + masks: NDArray, mask_format: str = "channel_first", leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None, class_map: Optional[Dict[str, int]] = None, - fs: Optional[Real] = None, + fs: Optional[Union[float, int]] = None, ) -> Dict[str, List[ECGWaveForm]]: """Convert masks into lists of waveforms. @@ -565,7 +563,7 @@ def from_masks( class_map : dict, optional Custom class map. If not set, `self.class_map` will be used. - fs : numbers.Real, optional + fs : float or int, optional Sampling frequency of the signal corresponding to the `masks`, If is None, `self.fs` will be used, to compute `duration` of the ECG waveforms. 
@@ -655,15 +653,15 @@ def _load_header(self, rec: Union[str, int]) -> dict: header_dict["adc_gain"] = header_reader.adc_gain header_dict["record_fmt"] = header_reader.fmt try: - header_dict["age"] = int([line for line in header_reader.comments if "" in line][0].split(": ")[-1]) + header_dict["age"] = int([line for line in header_reader.comments if "" in line][0].split(": ")[-1]) # type: ignore except Exception: header_dict["age"] = np.nan try: - header_dict["sex"] = [line for line in header_reader.comments if "" in line][0].split(": ")[-1] + header_dict["sex"] = [line for line in header_reader.comments if "" in line][0].split(": ")[-1] # type: ignore except Exception: header_dict["sex"] = "" - d_start = [idx for idx, line in enumerate(header_reader.comments) if "" in line][0] + 1 - header_dict["diagnoses"] = header_reader.comments[d_start:] + d_start = [idx for idx, line in enumerate(header_reader.comments) if "" in line][0] + 1 # type: ignore + header_dict["diagnoses"] = header_reader.comments[d_start:] # type: ignore return header_dict def load_subject_info(self, rec: Union[str, int], fields: Optional[Union[str, Sequence[str]]] = None) -> Union[dict, str]: @@ -685,26 +683,26 @@ def load_subject_info(self, rec: Union[str, int], fields: Optional[Union[str, Se """ if isinstance(rec, int): rec = self[rec] - row = self._df_subject_info[self._df_subject_info.ID == rec] + row = self._df_subject_info[self._df_subject_info.ID == rec] # type: ignore if row.empty: return {} row = row.iloc[0] info = row.to_dict() if fields is not None: if isinstance(fields, str): - assert fields in self._df_subject_info.columns, f"No field `{fields}`" + assert fields in self._df_subject_info.columns, f"No field `{fields}`" # type: ignore info = info[fields] else: assert set(fields).issubset( - set(self._df_subject_info.columns) - ), f"No field(s) {set(fields).difference(set(self._df_subject_info.columns))}" + set(self._df_subject_info.columns) # type: ignore + ), f"No field(s) {set(fields).difference(set(self._df_subject_info.columns))}" # type: ignore info = {k: v for k, v in info.items() if k in fields} return info def plot( self, rec: Union[str, int], - data: Optional[np.ndarray] = None, + data: Optional[NDArray] = None, ticks_granularity: int = 0, leads: Optional[Union[str, int, Sequence[Union[str, int]]]] = None, same_range: bool = False, @@ -749,29 +747,30 @@ def plot( Contributors: Jeethan, and WEN Hao """ + import matplotlib.pyplot as plt + from matplotlib.ticker import MultipleLocator + + MultipleLocator.MAXTICKS = 3000 + if isinstance(rec, int): rec = self[rec] - if "plt" not in dir(): - import matplotlib.pyplot as plt - - plt.MultipleLocator.MAXTICKS = 3000 if data is not None: assert leads is not None, "`leads` must be specified when `data` is given" data = np.atleast_2d(data) _leads = self._normalize_leads(leads) - _lead_indices = [self.all_leads.index(ld) for ld in _leads] + _lead_indices = [self.all_leads.index(ld) for ld in _leads] # type: ignore assert len(_leads) == data.shape[0], "number of leads must match data" units = self._auto_infer_units(data) - self.logger.info(f"input data is auto detected to have units in {units}") + self.logger.info(f"input data is auto detected to have units in {units}") # type: ignore if units.lower() == "mv": _data = 1000 * data else: _data = data else: _leads = self._normalize_leads(leads) - _lead_indices = [self.all_leads.index(ld) for ld in _leads] - _data = self.load_data(rec, data_format="channel_first", units="μV")[_lead_indices] + _lead_indices = 
[self.all_leads.index(ld) for ld in _leads] # type: ignore + _data = self.load_data(rec, data_format="channel_first", units="μV")[_lead_indices] # type: ignore if same_range: y_ranges = np.ones((_data.shape[0],)) * np.max(np.abs(_data)) + 100 @@ -828,12 +827,12 @@ def plot( axes[idx].axhline(y=0, linestyle="-", linewidth="1.0", color="red") # NOTE that `Locator` has default `MAXTICKS` equal to 1000 if ticks_granularity >= 1: - axes[idx].xaxis.set_major_locator(plt.MultipleLocator(0.2)) - axes[idx].yaxis.set_major_locator(plt.MultipleLocator(500)) + axes[idx].xaxis.set_major_locator(MultipleLocator(0.2)) + axes[idx].yaxis.set_major_locator(MultipleLocator(500)) axes[idx].grid(which="major", linestyle="-", linewidth="0.5", color="red") if ticks_granularity >= 2: - axes[idx].xaxis.set_minor_locator(plt.MultipleLocator(0.04)) - axes[idx].yaxis.set_minor_locator(plt.MultipleLocator(100)) + axes[idx].xaxis.set_minor_locator(MultipleLocator(0.04)) + axes[idx].yaxis.set_minor_locator(MultipleLocator(100)) axes[idx].grid(which="minor", linestyle=":", linewidth="0.5", color="black") # add extra info. to legend # https://stackoverflow.com/questions/16826711/is-it-possible-to-add-a-string-as-a-legend-item-in-matplotlib @@ -868,10 +867,10 @@ def database_info(self) -> DataBaseInfo: def compute_metrics( - truth_masks: Sequence[np.ndarray], - pred_masks: Sequence[np.ndarray], + truth_masks: Sequence[NDArray], + pred_masks: Sequence[NDArray], class_map: Dict[str, int], - fs: Real, + fs: Union[float, int], mask_format: str = "channel_first", ) -> Dict[str, Dict[str, float]]: """Compute metrics for the wave delineation task. @@ -892,7 +891,7 @@ def compute_metrics( class_map : Dict[str, int] Class map, mapping names to waves to numbers from 0 to n_classes-1, the keys should contain "pwave", "qrs", "twave". - fs : numbers.Real + fs : float or int Sampling frequency of the signal corresponding to the masks, used to compute the duration of each waveform, hence the error and standard deviations of errors. @@ -932,7 +931,7 @@ def compute_metrics( def compute_metrics_waveform( truth_waveforms: Sequence[Sequence[ECGWaveForm]], pred_waveforms: Sequence[Sequence[ECGWaveForm]], - fs: Real, + fs: Union[float, int], ) -> Dict[str, Dict[str, float]]: """ Compute the sensitivity, precision, f1_score, mean error @@ -948,7 +947,7 @@ def compute_metrics_waveform( The predictions corresponding to `truth_waveforms`. Each element is a sequence of :class:`ECGWaveForm` from the same sample. - fs : numbers.Real + fs : float or int Sampling frequency of the signal corresponding to the waveforms, used to compute the duration of each waveform, hence the error and standard deviations of errors. @@ -1047,7 +1046,7 @@ def compute_metrics_waveform( def _compute_metrics_waveform( - truths: Sequence[ECGWaveForm], preds: Sequence[ECGWaveForm], fs: Real + truths: Sequence[ECGWaveForm], preds: Sequence[ECGWaveForm], fs: Union[float, int] ) -> Dict[str, Dict[str, float]]: """ compute the sensitivity, precision, f1_score, mean error @@ -1060,7 +1059,7 @@ def _compute_metrics_waveform( The ground truth. preds : Sequence[ECGWaveForm] The predictions corresponding to `truths`, - fs : numbers.Real + fs : float or int Sampling frequency of the signal corresponding to the waveforms, used to compute the duration of each waveform, hence the error and standard deviations of errors. 
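The plotting hunks above replace `plt.MultipleLocator` with `MultipleLocator` imported from `matplotlib.ticker` and raise its `MAXTICKS`. A minimal sketch of the same grid setup on a standalone axis with placeholder data:

    # Major/minor grid in "ECG paper" style with matplotlib.ticker.MultipleLocator.
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MultipleLocator

    MultipleLocator.MAXTICKS = 3000  # avoid "Locator attempting to generate too many ticks"

    t = np.linspace(0, 10, 5000)               # seconds
    sig = 500 * np.sin(2 * np.pi * 1.2 * t)    # toy signal in uV

    fig, ax = plt.subplots(figsize=(12, 3))
    ax.plot(t, sig)
    ax.xaxis.set_major_locator(MultipleLocator(0.2))   # 0.2 s major grid
    ax.yaxis.set_major_locator(MultipleLocator(500))   # 500 uV major grid
    ax.grid(which="major", linestyle="-", linewidth=0.5, color="red")
    ax.xaxis.set_minor_locator(MultipleLocator(0.04))  # 0.04 s minor grid
    ax.yaxis.set_minor_locator(MultipleLocator(100))   # 100 uV minor grid
    ax.grid(which="minor", linestyle=":", linewidth=0.5, color="black")
    plt.show()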
@@ -1132,17 +1131,17 @@ def _compute_metrics_waveform( def _compute_metrics_base( - truths: Sequence[Real], preds: Sequence[Real], fs: Real + truths: Sequence[Union[float, int]], preds: Sequence[Union[float, int]], fs: Union[float, int] ) -> Tuple[int, int, int, List[float], float, float, float, float, float]: """The base function for computing the metrics. Parameters ---------- - truths : Sequence[Real] + truths : Sequence[Union[float, int]] Ground truth of indices of corresponding critical points. - preds : Sequence[Real] + preds : Sequence[Union[float, int]] Predicted indices of corresponding critical points. - fs : numbers.Real + fs : float or int Sampling frequency of the signal corresponding to the critical points, used to compute the duration of each waveform, hence the error and standard deviations of errors. @@ -1189,7 +1188,7 @@ def _compute_metrics_base( mean_error = np.mean(errors) * 1000 / fs standard_deviation = np.std(errors) * 1000 / fs - return ( + return ( # type: ignore truth_positive, false_negative, false_positive, diff --git a/torch_ecg/databases/physionet_databases/mitdb.py b/torch_ecg/databases/physionet_databases/mitdb.py index e52774b2..e01b8559 100644 --- a/torch_ecg/databases/physionet_databases/mitdb.py +++ b/torch_ecg/databases/physionet_databases/mitdb.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd import wfdb +from numpy.typing import NDArray from tqdm.auto import tqdm from ...cfg import CFG, DEFAULTS @@ -325,7 +326,7 @@ def load_rhythm_ann( rhythm_format: Literal["intervals", "mask"] = "intervals", rhythm_types: Optional[Sequence[str]] = None, keep_original: bool = False, - ) -> Union[Dict[str, list], np.ndarray]: + ) -> Union[Dict[str, list], NDArray]: """Load rhythm annotations of the record. Rhythm annotations are stored in the `aux_note` attribute @@ -374,7 +375,7 @@ def load_beat_ann( beat_format: Literal["beat", "dict"] = "beat", beat_types: Optional[Sequence[str]] = None, keep_original: bool = False, - ) -> Union[Dict[str, np.ndarray], List[BeatAnn]]: + ) -> Union[Dict[str, NDArray], List[BeatAnn]]: """Load beat annotations of the record. Beat annotations are stored in the `symbol` attribute @@ -420,7 +421,7 @@ def load_rpeak_indices( sampfrom: Optional[int] = None, sampto: Optional[int] = None, keep_original: bool = False, - ) -> np.ndarray: + ) -> NDArray: """Load rpeak indices of the record. 
Rpeak indices, or equivalently qrs complex locations, @@ -560,10 +561,10 @@ def rhythm_types_records(self) -> Dict[str, List[str]]: def plot( self, rec: Union[str, int], - data: Optional[np.ndarray] = None, - ann: Optional[Dict[str, np.ndarray]] = None, - beat_ann: Optional[Dict[str, np.ndarray]] = None, - rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None, + data: Optional[NDArray] = None, + ann: Optional[Dict[str, NDArray]] = None, + beat_ann: Optional[Dict[str, NDArray]] = None, + rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None, ticks_granularity: int = 0, leads: Optional[Union[int, List[int]]] = None, sampfrom: Optional[int] = None, diff --git a/torch_ecg/databases/physionet_databases/ptb_xl.py b/torch_ecg/databases/physionet_databases/ptb_xl.py index 0287ce22..541f4170 100644 --- a/torch_ecg/databases/physionet_databases/ptb_xl.py +++ b/torch_ecg/databases/physionet_databases/ptb_xl.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import wfdb +from numpy.typing import NDArray from tqdm.auto import tqdm from ...cfg import DEFAULTS @@ -396,7 +397,7 @@ def _ls_rec(self) -> None: # Fix potential bugs in the database self._fix_bugs() - def load_data(self, rec: Union[str, int], source: Literal["12sl", "unig"] = "12sl") -> np.ndarray: + def load_data(self, rec: Union[str, int], source: Literal["12sl", "unig"] = "12sl") -> NDArray: """Load the data of a record. Parameters @@ -426,7 +427,7 @@ def load_data(self, rec: Union[str, int], source: Literal["12sl", "unig"] = "12s return wfdb.rdrecord(path).p_signal @add_docstring(load_data.__doc__) - def load_median_beats(self, rec: Union[str, int], source: str = "12sl") -> np.ndarray: + def load_median_beats(self, rec: Union[str, int], source: str = "12sl") -> NDArray: """alias of `load_data`.""" return self.load_data(rec, source) diff --git a/torch_ecg/databases/physionet_databases/qtdb.py b/torch_ecg/databases/physionet_databases/qtdb.py index 4185446c..d3795b81 100644 --- a/torch_ecg/databases/physionet_databases/qtdb.py +++ b/torch_ecg/databases/physionet_databases/qtdb.py @@ -5,6 +5,7 @@ import numpy as np import wfdb +from numpy.typing import NDArray from ...cfg import CFG from ...utils.misc import add_docstring @@ -280,7 +281,7 @@ def load_wave_ann( keep_original: bool = False, ignore_beat_types: bool = True, extension: str = "q1c", - ) -> np.ndarray: + ) -> NDArray: """alias of self.load_ann""" return self.load_ann( rec, @@ -299,7 +300,7 @@ def load_wave_masks( mask_format: str = "channel_first", class_map: Optional[Dict[str, int]] = None, extension: str = "q1c", - ) -> np.ndarray: + ) -> NDArray: """Load the wave delineation in the form of masks. Parameters @@ -341,7 +342,7 @@ def load_rhythm_ann( rhythm_types: Optional[Sequence[str]] = None, keep_original: bool = False, extension: Literal["atr", "man"] = "atr", - ) -> Union[Dict[str, list], np.ndarray]: + ) -> Union[Dict[str, list], NDArray]: """Load rhythm annotations of a record. Rhythm annotations are stored in the `aux_note` attribute @@ -385,7 +386,7 @@ def load_beat_ann( beat_types: Optional[Sequence[str]] = None, keep_original: bool = False, extension: Literal["atr", "man"] = "atr", - ) -> Union[Dict[str, np.ndarray], List[BeatAnn]]: + ) -> Union[Dict[str, NDArray], List[BeatAnn]]: """Load beat annotations of the record. 
Beat annotations are stored in the `symbol` attribute @@ -453,7 +454,7 @@ def load_rpeak_indices( sampto: Optional[int] = None, keep_original: bool = False, extension: Literal["atr", "man"] = "atr", - ) -> np.ndarray: + ) -> NDArray: """Load rpeak indices of the record. Rpeak indices, or equivalently qrs complex locations, @@ -509,15 +510,15 @@ def load_rpeak_indices( def plot( self, rec: Union[str, int], - data: Optional[np.ndarray] = None, + data: Optional[NDArray] = None, ticks_granularity: int = 0, leads: Optional[Union[str, int, List[str], List[int]]] = None, sampfrom: Optional[int] = None, sampto: Optional[int] = None, same_range: bool = False, waves: Optional[ECGWaveForm] = None, - beat_ann: Optional[Dict[str, np.ndarray]] = None, - rpeak_inds: Optional[Union[Sequence[int], np.ndarray]] = None, + beat_ann: Optional[Dict[str, NDArray]] = None, + rpeak_inds: Optional[Union[Sequence[int], NDArray]] = None, **kwargs: Any, ) -> None: """ diff --git a/torch_ecg/models/_nets.py b/torch_ecg/models/_nets.py index 6e4ac52f..4a92825e 100644 --- a/torch_ecg/models/_nets.py +++ b/torch_ecg/models/_nets.py @@ -8,7 +8,7 @@ from itertools import repeat from math import sqrt from numbers import Real -from typing import Any, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Literal, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -105,7 +105,7 @@ Activations.relu6 = nn.ReLU6 Activations.rrelu = nn.RReLU Activations.leaky = nn.LeakyReLU -Activations.leaky_relu = Activations.leaky +Activations.leaky_relu = Activations.leaky # type: ignore Activations.gelu = nn.GELU Activations.silu = nn.SiLU Activations.elu = nn.ELU @@ -121,7 +121,7 @@ # Activations.linear = None -def get_activation(act: Union[str, nn.Module, type(None)], kw_act: Optional[dict] = None) -> Optional[nn.Module]: +def get_activation(act: Union[str, nn.Module, None], kw_act: Optional[dict] = None) -> Optional[nn.Module]: """Get the class or instance of the activation. 
Parameters @@ -159,7 +159,7 @@ def get_activation(act: Union[str, nn.Module, type(None)], kw_act: Optional[dict else: raise ValueError(f"activation `{act}` not supported") if kw_act is None: - return _act + return _act # type: ignore return _act(**kw_act) @@ -167,15 +167,15 @@ def get_activation(act: Union[str, nn.Module, type(None)], kw_act: Optional[dict # normalizations Normalizations = CFG() Normalizations.batch_norm = nn.BatchNorm1d -Normalizations.batch_normalization = Normalizations.batch_norm +Normalizations.batch_normalization = Normalizations.batch_norm # type: ignore Normalizations.group_norm = nn.GroupNorm -Normalizations.group_normalization = Normalizations.group_norm +Normalizations.group_normalization = Normalizations.group_norm # type: ignore Normalizations.layer_norm = nn.LayerNorm -Normalizations.layer_normalization = Normalizations.layer_norm +Normalizations.layer_normalization = Normalizations.layer_norm # type: ignore Normalizations.instance_norm = nn.InstanceNorm1d -Normalizations.instance_normalization = Normalizations.instance_norm +Normalizations.instance_normalization = Normalizations.instance_norm # type: ignore Normalizations.local_response_norm = nn.LocalResponseNorm -Normalizations.local_response_normalization = Normalizations.local_response_norm +Normalizations.local_response_normalization = Normalizations.local_response_norm # type: ignore # other normalizations: # weight normalization # batch re-normalization @@ -189,7 +189,7 @@ def get_activation(act: Union[str, nn.Module, type(None)], kw_act: Optional[dict # problem: parameters of different normalizations are different -def get_normalization(norm: Union[str, nn.Module, type(None)], kw_norm: Optional[dict] = None) -> Optional[nn.Module]: +def get_normalization(norm: Union[str, nn.Module, None], kw_norm: Optional[dict] = None) -> Optional[nn.Module]: """Get the class or instance of the normalization. 
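As the two helpers above document, whether a class or a configured instance is returned depends on `kw_act` / `kw_norm`. A hedged usage sketch; the import path `torch_ecg.models._nets` is taken from the file header of this diff:

    # Sketch: class vs. configured instance from get_activation / get_normalization.
    from torch_ecg.models._nets import get_activation, get_normalization

    act_cls = get_activation("leaky_relu")                               # nn.LeakyReLU (class)
    act = get_activation("leaky_relu", kw_act={"negative_slope": 0.1})   # configured instance

    # the helper maps "num_features" to "num_channels" where required, e.g. for GroupNorm
    norm = get_normalization("group_norm", kw_norm={"num_features": 64, "num_groups": 8})
    print(act_cls, act, norm)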
Parameters @@ -224,12 +224,12 @@ def get_normalization(norm: Union[str, nn.Module, type(None)], kw_norm: Optional else: raise ValueError(f"normalization `{norm}` not supported") if kw_norm is None: - return _norm + return _norm # type: ignore if "num_channels" in get_required_args(_norm) and "num_features" in kw_norm: # for some normalizations, the argument name is `num_channels` # instead of `num_features`, e.g., `torch.nn.GroupNorm` kw_norm["num_channels"] = kw_norm.pop("num_features") - return _norm(**kw_norm) + return _norm(**kw_norm) # type: ignore # --------------------------------------------- @@ -440,7 +440,7 @@ def __init__( groups: int = 1, batch_norm: Union[bool, str, nn.Module] = True, activation: Optional[Union[str, nn.Module]] = None, - kernel_initializer: Optional[Union[str, callable]] = None, + kernel_initializer: Optional[Union[str, Callable]] = None, bias: bool = True, ordering: Optional[str] = None, **kwargs: Any, @@ -602,28 +602,28 @@ def __init__( if bn_layer: self.add_module("batch_norm", bn_layer) if act_layer: - self.add_module(act_name, act_layer) + self.add_module(act_name, act_layer) # type: ignore elif self.__ordering in ["cab"]: self.add_module("conv1d", conv_layer) if self.__stride == 1 and self.__kernel_size % 2 == 0: self.__asymmetric_padding = (1, 0) self.add_module("zero_pad", ZeroPad1d(self.__asymmetric_padding)) if act_layer: - self.add_module(act_name, act_layer) + self.add_module(act_name, act_layer) # type: ignore if bn_layer: self.add_module("batch_norm", bn_layer) elif self.__ordering in ["bac", "bc"]: if bn_layer: self.add_module("batch_norm", bn_layer) if act_layer: - self.add_module(act_name, act_layer) + self.add_module(act_name, act_layer) # type: ignore self.add_module("conv1d", conv_layer) if self.__stride == 1 and self.__kernel_size % 2 == 0: self.__asymmetric_padding = (1, 0) self.add_module("zero_pad", ZeroPad1d(self.__asymmetric_padding)) elif self.__ordering in ["acb", "ac"]: if act_layer: - self.add_module(act_name, act_layer) + self.add_module(act_name, act_layer) # type: ignore self.add_module("conv1d", conv_layer) if self.__stride == 1 and self.__kernel_size % 2 == 0: self.__asymmetric_padding = (1, 0) @@ -638,7 +638,7 @@ def __init__( self.__asymmetric_padding = (1, 0) self.add_module("zero_pad", ZeroPad1d(self.__asymmetric_padding)) if act_layer: - self.add_module(act_name, act_layer) + self.add_module(act_name, act_layer) # type: ignore elif self.__ordering in ["c"]: # only convolution self.add_module("conv1d", conv_layer) @@ -745,6 +745,8 @@ def compute_output_shape( *output_shape[:-1], output_shape[-1] + sum(self.__asymmetric_padding), ) + else: + output_shape = [None] return output_shape @add_docstring(_COMPUTE_RECEPTIVE_FIELD_DOC.replace("layer", "block")) @@ -754,7 +756,7 @@ def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[ strides=self.__stride, dilations=self.__dilation, input_len=input_len, - fs=fs, + fs=fs, # type: ignore ) @property @@ -883,7 +885,7 @@ def __init__( if isinstance(dropouts, (Real, dict)): _dropouts = list(repeat(dropouts, self.__num_convs)) else: - _dropouts = list(dropouts) + _dropouts = list(dropouts) # type: ignore assert ( len(_dropouts) == self.__num_convs ), f"`dropouts` must be a real number or dict or sequence of real numbers of length {self.__num_convs}" @@ -896,7 +898,7 @@ def __init__( len(_dilations) == self.__num_convs ), f"`dilations` must be of type int or sequence of int of length {self.__num_convs}" - __ordering = self.config.ordering.lower() + __ordering = 
self.config.ordering.lower() # type: ignore if "a" in __ordering and __ordering.index("a") < __ordering.index("c"): in_activation = out_activation out_activation = True @@ -905,7 +907,7 @@ def __init__( conv_in_channels = self.__in_channels for idx, (oc, ks, sd, dl, dp) in enumerate(zip(self.__out_channels, kernel_sizes, strides, _dilations, _dropouts)): - activation = self.config.activation + activation = self.config.activation # type: ignore if idx == 0 and not in_activation: activation = None if idx == self.__num_convs - 1 and not out_activation: @@ -921,15 +923,15 @@ def __init__( groups=groups, norm=self.config.get("norm", self.config.get("batch_norm")), activation=activation, - kw_activation=self.config.kw_activation, - kernel_initializer=self.config.kernel_initializer, - kw_initializer=self.config.kw_initializer, - ordering=self.config.ordering, - conv_type=self.config.conv_type, - width_multiplier=self.config.width_multiplier, + kw_activation=self.config.kw_activation, # type: ignore + kernel_initializer=self.config.kernel_initializer, # type: ignore + kw_initializer=self.config.kw_initializer, # type: ignore + ordering=self.config.ordering, # type: ignore + conv_type=self.config.conv_type, # type: ignore + width_multiplier=self.config.width_multiplier, # type: ignore ), ) - conv_in_channels = int(oc * self.config.width_multiplier) + conv_in_channels = int(oc * self.config.width_multiplier) # type: ignore if isinstance(dp, dict): if dp["type"] == "1d" and dp["p"] > 0: self.add_module( @@ -956,7 +958,7 @@ def compute_output_shape( if hasattr(module, "__name__") and module.__name__ == Conv_Bn_Activation.__name__: output_shape = module.compute_output_shape(_seq_len, batch_size) _, _, _seq_len = output_shape - return output_shape + return output_shape # type: ignore @add_docstring(_COMPUTE_RECEPTIVE_FIELD_DOC.replace("layer", "block")) def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[Real] = None) -> Union[int, float]: @@ -971,7 +973,7 @@ def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[ strides=strides, dilations=dilations, input_len=input_len, - fs=fs, + fs=fs, # type: ignore ) @property @@ -1052,7 +1054,7 @@ def __init__( if isinstance(dropouts, (Real, dict)): _dropouts = list(repeat(dropouts, self.__num_branches)) else: - _dropouts = list(dropouts) + _dropouts = list(dropouts) # type: ignore assert ( len(_dropouts) == self.__num_branches ), f"`dropouts` must be a real number or dict or sequence of real numbers of length {self.__num_branches}" @@ -1203,7 +1205,7 @@ def __init__( padding: Optional[int] = None, dilation: int = 1, groups: int = 1, - kernel_initializer: Optional[Union[str, callable]] = None, + kernel_initializer: Optional[Union[str, Callable]] = None, bias: bool = True, **kwargs: Any, ) -> None: @@ -1320,7 +1322,7 @@ def compute_receptive_field(self, input_len: Optional[int] = None, fs: Optional[ strides=[self.__stride, 1], dilations=[self.__dilation, 1], input_len=input_len, - fs=fs, + fs=fs, # type: ignore ) @property @@ -1658,7 +1660,7 @@ def compute_output_shape( stride=self.__down_scale, padding=self.__padding, )[-1] - output_shape = (batch_size, self.__out_channels, out_seq_len) + output_shape = (batch_size, self.__out_channels, out_seq_len) # type: ignore return output_shape @property @@ -1691,7 +1693,7 @@ def __init__(self, padding: Union[int, Sequence[int]]) -> None: and all([i >= 0 for i in padding]) ), "`padding` must be non-negative int or a 2-sequence of non-negative int" padding = 
list(repeat(padding, 2)) if isinstance(padding, int) else padding - super().__init__(padding, 0.0) + super().__init__(padding, 0.0) # type: ignore def compute_output_shape( self, @@ -1854,7 +1856,7 @@ def _get_pad_layer(self) -> nn.Module: PadLayer = ZeroPad1d else: raise NotImplementedError(f"Padding type of `{self.__pad_type}` is not implemented") - return PadLayer(self.__pad_sizes) + return PadLayer(self.__pad_sizes) # type: ignore @add_docstring(_COMPUTE_OUTPUT_SHAPE_DOC) def compute_output_shape( @@ -2196,10 +2198,11 @@ def forward( module.flatten_parameters() output, _hx = module(output, _hx) if self.return_sequences: - final_output = output # seq_len, batch_size, n_direction*hidden_size + final_output = output # seq_len, batch_size, n_direction * hidden_size else: - final_output = output[-1, ...] # batch_size, n_direction*hidden_size - return final_output + # batch_size, n_direction * hidden_size + final_output = output[-1, ...] # type: ignore + return final_output # type: ignore @add_docstring(_COMPUTE_OUTPUT_SHAPE_DOC) def compute_output_shape( @@ -2244,12 +2247,12 @@ def __init__(self, in_channels: int, bias: bool = True, initializer: str = "glor self.init(self.W) self.u = Parameter(torch.Tensor(in_channels)) - Initializers.constant(self.u, 1 / in_channels) + Initializers.constant(self.u, 1 / in_channels) # type: ignore # self.init(self.u) if self.bias: self.b = Parameter(torch.Tensor(in_channels)) - Initializers.zeros(self.b) + Initializers.zeros(self.b) # type: ignore # Initializers["zeros"](self.b) else: self.register_parameter("b", None) @@ -2346,7 +2349,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tens return output -@add_docstring(nn.MultiheadAttention.__doc__, "append") +@add_docstring(nn.MultiheadAttention.__doc__, "append") # type: ignore class MultiHeadAttention(nn.MultiheadAttention, SizeMixin): """Multi-head attention. 
@@ -2617,7 +2620,7 @@ def forward(self, input: Tensor) -> Tensor: input = input.permute(0, 2, 1) # -> (batch_size, seq_len, n_channels scores = self.dropout(input) scores = self.mid_linear(scores) # -> (batch_size, seq_len, n_channels) - scores = self.activation(scores) # -> (batch_size, seq_len, n_channels) + scores = self.activation(scores) # -> (batch_size, seq_len, n_channels) # type: ignore scores = self.contraction(scores) # -> (batch_size, seq_len, 1) scores = scores.squeeze(-1) # -> (batch_size, seq_len) scores = self.softmax(scores) # -> (batch_size, seq_len) @@ -2779,9 +2782,9 @@ def __init__( self.__dropouts = [dropouts] else: self.__dropouts = dropouts - assert len(self.__dropouts) == self.__num_layers, ( + assert len(self.__dropouts) == self.__num_layers, ( # type: ignore f"`out_channels` indicates `{self.__num_layers}` linear layers, " - f"while `dropouts` indicates `{len(self.__dropouts)}`" + f"while `dropouts` indicates `{len(self.__dropouts)}`" # type: ignore ) self.__skip_last_activation = kwargs.get("skip_last_activation", False) @@ -2803,10 +2806,10 @@ def __init__( f"act_{idx}", act_layer(**kw_activation), ) - if self.__dropouts[idx] > 0: + if self.__dropouts[idx] > 0: # type: ignore self.add_module( f"dropout_{idx}", - nn.Dropout(self.__dropouts[idx]), + nn.Dropout(self.__dropouts[idx]), # type: ignore ) lin_in_channels = self.__out_channels[idx] @@ -3091,10 +3094,10 @@ def __init__(self, in_channels: int, reduction: int = 16, **config) -> None: SeqLin( in_channels=self.__in_channels, out_channels=[self.__mid_channels, self.__out_channels], - activation=self.config.activation, - kw_activation=self.config.kw_activation, - bias=self.config.bias, - dropouts=self.config.dropouts, + activation=self.config.activation, # type: ignore + kw_activation=self.config.kw_activation, # type: ignore + bias=self.config.bias, # type: ignore + dropouts=self.config.dropouts, # type: ignore skip_last_activation=True, ), nn.Sigmoid(), @@ -3287,7 +3290,7 @@ def spatial_pool(self, x: Tensor) -> Tensor: context = context.squeeze(1) # --> (batch_size, n_channels, 1) elif self.__pooling_type == "avg": context = self.avg_pool(x) # --> (batch_size, n_channels, 1) - return context + return context # type: ignore def forward(self, input: Tensor) -> Tensor: """ @@ -3416,7 +3419,7 @@ def __init__( self.__gate_channels // reduction, self.__gate_channels, ], - activation=activation, + activation=activation, # type: ignore skip_last_activation=True, ) # self.channel_gate_act = nn.Sigmoid() @@ -3466,7 +3469,7 @@ def _fwd_channel_gate(self, input: Tensor) -> Tensor: else: channel_att_sum = channel_att_sum + channel_att_raw # scale = torch.sigmoid(channel_att_sum) - scale = self.channel_gate_act(channel_att_sum) + scale = self.channel_gate_act(channel_att_sum) # type: ignore output = scale.unsqueeze(-1) * input return output @@ -3656,8 +3659,8 @@ def __repr__(self) -> str: def neg_log_likelihood( self, emissions: Tensor, - tags: torch.LongTensor, - mask: Optional[torch.ByteTensor] = None, + tags: torch.Tensor, + mask: Optional[torch.Tensor] = None, reduction: str = "sum", ) -> Tensor: """ @@ -3668,9 +3671,9 @@ def neg_log_likelihood( ---------- emissions: Tensor, emission score tensor of shape (seq_len, batch_size, num_tags) - tags: torch.LongTensor, + tags: torch.Tensor, sequence of tags tensor of shape (seq_len, batch_size) - mask: torch.ByteTensor, + mask: torch.Tensor, mask tensor of shape (seq_len, batch_size) reduction: str, default "sum", specifies the reduction to apply to the output: @@ -3720,7 
+3723,7 @@ def neg_log_likelihood( nll = nll.sum() / mask.float().sum() return nll - def forward(self, emissions: Tensor, mask: Optional[torch.ByteTensor] = None) -> Tensor: + def forward(self, emissions: Tensor, mask: Optional[torch.Tensor] = None) -> Tensor: """ Find the most likely tag sequence using Viterbi algorithm. @@ -3730,7 +3733,7 @@ def forward(self, emissions: Tensor, mask: Optional[torch.ByteTensor] = None) -> emission score tensor, of shape (seq_len, batch_size, num_tags) if batch_first is False, of shape (batch_size, seq_len, num_tags) if batch_first is True. - mask: torch.ByteTensor + mask: torch.Tensor mask tensor of shape (seq_len, batch_size) if batch_first is False, of shape (batch_size, seq_len) if batch_first is True. @@ -3756,8 +3759,8 @@ def forward(self, emissions: Tensor, mask: Optional[torch.ByteTensor] = None) -> def _validate( self, emissions: Tensor, - tags: Optional[torch.LongTensor] = None, - mask: Optional[torch.ByteTensor] = None, + tags: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, ) -> None: """Check validity of input :class:`~torch.Tensor`.""" if emissions.dim() != 3: @@ -3783,7 +3786,7 @@ def _validate( if not no_empty_seq and not no_empty_seq_bf: raise ValueError("mask of the first timestep must all be on") - def _compute_score(self, emissions: Tensor, tags: torch.LongTensor, mask: torch.ByteTensor) -> Tensor: + def _compute_score(self, emissions: Tensor, tags: torch.Tensor, mask: torch.Tensor) -> Tensor: """ # emissions: (seq_len, batch_size, num_tags) # tags: (seq_len, batch_size) @@ -3827,7 +3830,7 @@ def _compute_score(self, emissions: Tensor, tags: torch.LongTensor, mask: torch. return score - def _compute_normalizer(self, emissions: torch.Tensor, mask: torch.ByteTensor) -> Tensor: + def _compute_normalizer(self, emissions: torch.Tensor, mask: torch.Tensor) -> Tensor: """ # emissions: (seq_len, batch_size, num_tags) # mask: (seq_len, batch_size) @@ -3884,7 +3887,7 @@ def _compute_normalizer(self, emissions: torch.Tensor, mask: torch.ByteTensor) - # shape: (batch_size,) return torch.logsumexp(score, dim=1) - def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor) -> List[List[int]]: + def _viterbi_decode(self, emissions: torch.Tensor, mask: torch.Tensor) -> List[List[int]]: """ # emissions: (seq_len, batch_size, num_tags) # mask: (seq_len, batch_size) @@ -4251,18 +4254,18 @@ def __init__( self.wordvec_proj = nn.Identity() self.decoder.duplicate_pooling = Parameter(Tensor(decoder_embedding, 1)) self.decoder.duplicate_pooling_bias = Parameter(Tensor(1)) - self.decoder.duplicate_factor = 1 + self.decoder.duplicate_factor = 1 # type: ignore else: # group fully-connected - self.decoder.out_channels = out_channels - self.decoder.duplicate_factor = int(out_channels / embed_len_decoder + 0.999) + self.decoder.out_channels = out_channels # type: ignore + self.decoder.duplicate_factor = int(out_channels / embed_len_decoder + 0.999) # type: ignore self.decoder.duplicate_pooling = Parameter( Tensor(embed_len_decoder, decoder_embedding, self.decoder.duplicate_factor) ) self.decoder.duplicate_pooling_bias = Parameter(Tensor(out_channels)) nn.init.xavier_normal_(self.decoder.duplicate_pooling) nn.init.constant_(self.decoder.duplicate_pooling_bias, 0) - self.decoder.group_fc = _GroupFC(embed_len_decoder) + self.decoder.group_fc = _GroupFC(embed_len_decoder) # type: ignore self.train_wordvecs = None self.test_wordvecs = None @@ -4427,16 +4430,16 @@ def make_attention_layer(in_channels: int, **config: dict) -> 
nn.Module: """ key = "name" if "name" in config else "type" assert key in config, "config must contain key 'name' or 'type'" - name = config[key].lower() + name = config[key].lower() # type: ignore config.pop(key) if name in ["se"]: - return SEBlock(in_channels, **config) + return SEBlock(in_channels, **config) # type: ignore elif name in ["gc"]: - return GlobalContextBlock(in_channels, **config) + return GlobalContextBlock(in_channels, **config) # type: ignore elif name in ["nl", "non-local", "nonlocal", "non_local"]: - return NonLocalBlock(in_channels, **config) + return NonLocalBlock(in_channels, **config) # type: ignore elif name in ["cbam"]: - return CBAMBlock(in_channels, **config) + return CBAMBlock(in_channels, **config) # type: ignore elif name in ["ca"]: # NOT IMPLEMENTED return CoordAttention(in_channels, **config) diff --git a/torch_ecg/models/cnn/densenet.py b/torch_ecg/models/cnn/densenet.py index 333ff894..b774602c 100644 --- a/torch_ecg/models/cnn/densenet.py +++ b/torch_ecg/models/cnn/densenet.py @@ -585,41 +585,40 @@ class DenseNet(nn.Sequential, SizeMixin, CitationMixin): config : dict Other hyper-parameters of the Module, ref. corresponding config file. Keyword arguments that must be set are as follows: - - - num_layers: sequence of int, - number of building block layers of each dense (macro) block - - init_num_filters: sequence of int, - number of filters of the first convolutional layer - - init_filter_length: sequence of int, - filter length (kernel size) of the first convolutional layer - - init_conv_stride: int, - stride of the first convolutional layer - - init_pool_size: int, - pooling kernel size of the first pooling layer - - init_pool_stride: int, - pooling stride of the first pooling layer - - growth_rates: int or sequence of int or sequence of sequences of int, - growth rates of the building blocks, - with granularity to the whole network, or to each dense (macro) block, - or to each building block - - filter_lengths: int or sequence of int or sequence of sequences of int, - filter length(s) (kernel size(s)) of the convolutions, - with granularity to the whole network, or to each macro block, - or to each building block - - subsample_lengths: int or sequence of int, - subsampling length(s) (ratio(s)) of the transition blocks - - compression: float, - compression factor of the transition blocks - - bn_size: int, - bottleneck base width, used only when building block is :class:`DenseBottleNeck` - - dropouts: float or dict, - dropout ratio of each building block - - groups: int, - connection pattern (of channels) of the inputs and outputs - - block: dict, - other parameters that can be set for the building blocks - - For a full list of configurable parameters, ref. corr. config file + - num_layers: sequence of int, + number of building block layers of each dense (macro) block. + - init_num_filters: sequence of int, + number of filters of the first convolutional layer. + - init_filter_length: sequence of int, + filter length (kernel size) of the first convolutional layer. + - init_conv_stride: int, + stride of the first convolutional layer. + - init_pool_size: int, + pooling kernel size of the first pooling layer. + - init_pool_stride: int, + pooling stride of the first pooling layer. + - growth_rates: int or sequence of int or sequence of sequences of int, + growth rates of the building blocks, + with granularity to the whole network, or to each dense (macro) block, + or to each building block. 
+ - filter_lengths: int or sequence of int or sequence of sequences of int, + filter length(s) (kernel size(s)) of the convolutions, + with granularity to the whole network, or to each macro block, + or to each building block. + - subsample_lengths: int or sequence of int, + subsampling length(s) (ratio(s)) of the transition blocks. + - compression: float, + compression factor of the transition blocks. + - bn_size: int, + bottleneck base width, used only when building block is :class:`DenseBottleNeck`. + - dropouts: float or dict, + dropout ratio of each building block. + - groups: int, + connection pattern (of channels) of the inputs and outputs. + - block: dict, + other parameters that can be set for the building blocks. + + For a full list of configurable parameters, ref. corr. config file. NOTE ---- diff --git a/torch_ecg/models/cnn/mobilenet.py b/torch_ecg/models/cnn/mobilenet.py index b50a270e..86fee0a9 100644 --- a/torch_ecg/models/cnn/mobilenet.py +++ b/torch_ecg/models/cnn/mobilenet.py @@ -214,19 +214,18 @@ class MobileNetV1(nn.Sequential, SizeMixin, CitationMixin): Other hyper-parameters of the Module, ref. corresponding config file. key word arguments that have to be set in 3 sub-dict, namely in "entry_flow", "middle_flow", and "exit_flow", including - - - out_channels: int, - number of channels of the output. - - kernel_size: int, - kernel size of down sampling. - If not specified, defaults to `down_scale`. - - groups: int, - connection pattern (of channels) of the inputs and outputs. - - padding: int, - zero-padding added to both sides of the input. - - batch_norm: bool or Module, - batch normalization, the Module itself - or (if is bool) whether or not to use :class:`torch.nn.BatchNorm1d`. + - out_channels: int, + number of channels of the output. + - kernel_size: int, + kernel size of down sampling. + If not specified, defaults to `down_scale`. + - groups: int, + connection pattern (of channels) of the inputs and outputs. + - padding: int, + zero-padding added to both sides of the input. + - batch_norm: bool or Module, + batch normalization, the Module itself + or (if is bool) whether or not to use :class:`torch.nn.BatchNorm1d`. References ---------- @@ -457,9 +456,9 @@ class InvertedResidual(nn.Module, SizeMixin): If is None, no attention mechanism is used. Keys: - - "name": str, can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc. - - "pos": int, position of the attention mechanism, - other keys are specific to the attention mechanism. + - "name": str, can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc. + - "pos": int, position of the attention mechanism, + other keys are specific to the attention mechanism. """ @@ -669,36 +668,36 @@ class MobileNetV2(nn.Sequential, SizeMixin, CitationMixin): - stem: CFG, config of the stem block, with the following keys: - - num_filters: int or Sequence[int], - number of filters in the first convolutional layer(s). - - filter_lengths: int or Sequence[int], - filter lengths (kernel sizes) in the first convolutional layer(s). - - subsample_lengths: int or Sequence[int], - subsample lengths (strides) in the first convolutional layer(s). + - num_filters: int or Sequence[int], + number of filters in the first convolutional layer(s). + - filter_lengths: int or Sequence[int], + filter lengths (kernel sizes) in the first convolutional layer(s). + - subsample_lengths: int or Sequence[int], + subsample lengths (strides) in the first convolutional layer(s). 
- inv_res: CFG, Config of the inverted residual blocks, with the following keys: - - expansions: Sequence[int], - expansion ratios of the inverted residual blocks. - - out_channels: Sequence[int], - number of output channels in each block. - - n_blocks: Sequence[int], - number of inverted residual blocks. - - strides: Sequence[int], - strides of the inverted residual blocks. - - filter_lengths: Sequence[int], - filter lengths (kernel sizes) in each block. + - expansions: Sequence[int], + expansion ratios of the inverted residual blocks. + - out_channels: Sequence[int], + number of output channels in each block. + - n_blocks: Sequence[int], + number of inverted residual blocks. + - strides: Sequence[int], + strides of the inverted residual blocks. + - filter_lengths: Sequence[int], + filter lengths (kernel sizes) in each block. - exit_flow: CFG, Config of the exit flow blocks, with the following keys: - - num_filters: int or Sequence[int], - number of filters in the final convolutional layer(s). - - filter_lengths: int or Sequence[int], - filter lengths (kernel sizes) in the final convolutional layer(s). - - subsample_lengths: int or Sequence[int], - subsample lengths (strides) in the final convolutional layer(s). + - num_filters: int or Sequence[int], + number of filters in the final convolutional layer(s). + - filter_lengths: int or Sequence[int], + filter lengths (kernel sizes) in the final convolutional layer(s). + - subsample_lengths: int or Sequence[int], + subsample lengths (strides) in the final convolutional layer(s). References ---------- @@ -1000,12 +999,12 @@ class MobileNetV3_STEM(nn.Sequential, SizeMixin): config : CFG, optional Config of the stem block, with the following items: - - num_filters: int or Sequence[int], - number of filters in the first convolutional layer(s). - - filter_lengths: int or Sequence[int], - filter lengths (kernel sizes) in the first convolutional layer(s). - - subsample_lengths: int or Sequence[int], - subsample lengths (strides) in the first convolutional layer(s). + - num_filters: int or Sequence[int], + number of filters in the first convolutional layer(s). + - filter_lengths: int or Sequence[int], + filter lengths (kernel sizes) in the first convolutional layer(s). + - subsample_lengths: int or Sequence[int], + subsample lengths (strides) in the first convolutional layer(s). """ @@ -1088,64 +1087,64 @@ class MobileNetV3(nn.Sequential, SizeMixin, CitationMixin): Other hyper-parameters of the Module, ref. corresponding config file. Keyword arguments that must be set: - - groups: int, - number of groups in the convolutional layer(s) other than depthwise convolutions. - - norm: bool or str or Module, - normalization layer. - - bias: bool, - whether to use bias in the convolutional layer(s). - - width_multiplier: float, - multiplier of the number of output channels of the pointwise convolution. - - stem: CFG, - config of the stem block, with the following keys: - - - num_filters: int or Sequence[int], - number of filters in the first convolutional layer(s). - - filter_lengths: int or Sequence[int], - filter lengths (kernel sizes) in the first convolutional layer(s). - - subsample_lengths: int or Sequence[int], - subsample lengths (strides) in the first convolutional layer(s). - - - inv_res: CFG, - config of the inverted residual blocks, with the following keys: - - - in_channels: Sequence[int], - number of input channels. - - n_blocks: Sequence[int], - number of inverted residual blocks. 
- - expansions: sequence of floats or sequence of sequence of floats, - expansion ratios of the inverted residual blocks. - - filter_lengths: sequence of ints or sequence of sequence of ints, - filter length of the depthwise convolution in the inverted residual blocks. - - stride: sequence of ints or sequence of sequence of ints, optional, - stride of the depthwise convolution in the inverted residual blocks, - defaults to ``[2] + [1] * (n_blocks - 1)``. - - groups: int, default 1, - number of groups in the expansion and pointwise convolution - in the inverted residual blocks. - - dilation: sequence of ints or sequence of sequence of ints, optional, - dilation of the depthwise convolution in the inverted residual blocks. - - batch_norm: bool or str or nn.Module, default True, - normalization layer to use, defaults to batch normalization. - - activation: str or nn.Module or sequence of str or torch.nn.Module, - activation function to use. - - width_multiplier: float or sequence of floats, default 1.0, - width multiplier of the inverted residual blocks. - - out_channels: sequence of ints or sequence of Sequence[int], optional, - number of output channels of the inverted residual blocks, - defaults to ``2 * in_channels``. - - attn: sequence of CFG or sequence of sequence of CFG, optional, - config of attention layer to use, defaults to None. - - - exit_flow: CFG, - config of the exit flow blocks, with the following keys: - - - num_filters: int or Sequence[int], - number of filters in the final convolutional layer(s). - - filter_lengths: int or Sequence[int], - filter lengths (kernel sizes) in the final convolutional layer(s). - - subsample_lengths: int or Sequence[int], - subsample lengths (strides) in the final convolutional layer(s). + - groups: int, + number of groups in the convolutional layer(s) other than depthwise convolutions. + - norm: bool or str or Module, + normalization layer. + - bias: bool, + whether to use bias in the convolutional layer(s). + - width_multiplier: float, + multiplier of the number of output channels of the pointwise convolution. + - stem: CFG, + config of the stem block, with the following keys: + + - num_filters: int or Sequence[int], + number of filters in the first convolutional layer(s). + - filter_lengths: int or Sequence[int], + filter lengths (kernel sizes) in the first convolutional layer(s). + - subsample_lengths: int or Sequence[int], + subsample lengths (strides) in the first convolutional layer(s). + + - inv_res: CFG, + config of the inverted residual blocks, with the following keys: + + - in_channels: Sequence[int], + number of input channels. + - n_blocks: Sequence[int], + number of inverted residual blocks. + - expansions: sequence of floats or sequence of sequence of floats, + expansion ratios of the inverted residual blocks. + - filter_lengths: sequence of ints or sequence of sequence of ints, + filter length of the depthwise convolution in the inverted residual blocks. + - stride: sequence of ints or sequence of sequence of ints, optional, + stride of the depthwise convolution in the inverted residual blocks, + defaults to ``[2] + [1] * (n_blocks - 1)``. + - groups: int, default 1, + number of groups in the expansion and pointwise convolution + in the inverted residual blocks. + - dilation: sequence of ints or sequence of sequence of ints, optional, + dilation of the depthwise convolution in the inverted residual blocks. + - batch_norm: bool or str or nn.Module, default True, + normalization layer to use, defaults to batch normalization. 
+ - activation: str or nn.Module or sequence of str or torch.nn.Module, + activation function to use. + - width_multiplier: float or sequence of floats, default 1.0, + width multiplier of the inverted residual blocks. + - out_channels: sequence of ints or sequence of Sequence[int], optional, + number of output channels of the inverted residual blocks, + defaults to ``2 * in_channels``. + - attn: sequence of CFG or sequence of sequence of CFG, optional, + config of attention layer to use, defaults to None. + + - exit_flow: CFG, + config of the exit flow blocks, with the following keys: + + - num_filters: int or Sequence[int], + number of filters in the final convolutional layer(s). + - filter_lengths: int or Sequence[int], + filter lengths (kernel sizes) in the final convolutional layer(s). + - subsample_lengths: int or Sequence[int], + subsample lengths (strides) in the final convolutional layer(s). References ---------- diff --git a/torch_ecg/models/cnn/multi_scopic.py b/torch_ecg/models/cnn/multi_scopic.py index af2c81b0..fb5b61bf 100644 --- a/torch_ecg/models/cnn/multi_scopic.py +++ b/torch_ecg/models/cnn/multi_scopic.py @@ -368,28 +368,28 @@ class MultiScopicCNN(nn.Module, SizeMixin, CitationMixin): Other hyper-parameters of the Module, ref. corr. config file. Key word arguments that must be set: - - scopes: sequence of sequences of sequences of :obj:`int`, - scopes (in terms of dilation) of each convolution. - - num_filters: sequence of sequences (of :obj:`int` or of sequences of :obj:`int`), - number of filters of the convolutional layers, - with granularity to each block of each branch, - or to each convolution of each block of each branch. - - filter_lengths: sequence of sequences (of :obj:`int` or of sequences of :obj:`int`), - filter length(s) (kernel size(s)) of the convolutions, - with granularity to each block of each branch, - or to each convolution of each block of each branch. - - subsample_lengths: sequence of :obj:`int` or sequence of sequences of :obj:`int`, - subsampling length(s) (ratio(s)) of all blocks, - with granularity to each branch or to each block of each branch, - each subsamples after the last convolution of each block. - - dropouts: sequence of :obj:`int` or sequence of sequences of :obj:`int`, - dropout rates of all blocks, - with granularity to each branch or to each block of each branch, - each dropouts at the last of each block. - - groups: :obj:`int`, - connection pattern (of channels) of the inputs and outputs. - - block: :obj:`dict`, - other parameters that can be set for the building blocks. + - scopes: sequence of sequences of sequences of :obj:`int`, + scopes (in terms of dilation) of each convolution. + - num_filters: sequence of sequences (of :obj:`int` or of sequences of :obj:`int`), + number of filters of the convolutional layers, + with granularity to each block of each branch, + or to each convolution of each block of each branch. + - filter_lengths: sequence of sequences (of :obj:`int` or of sequences of :obj:`int`), + filter length(s) (kernel size(s)) of the convolutions, + with granularity to each block of each branch, + or to each convolution of each block of each branch. + - subsample_lengths: sequence of :obj:`int` or sequence of sequences of :obj:`int`, + subsampling length(s) (ratio(s)) of all blocks, + with granularity to each branch or to each block of each branch, + each subsamples after the last convolution of each block. 
+ - dropouts: sequence of :obj:`int` or sequence of sequences of :obj:`int`, + dropout rates of all blocks, + with granularity to each branch or to each block of each branch, + each dropouts at the last of each block. + - groups: :obj:`int`, + connection pattern (of channels) of the inputs and outputs. + - block: :obj:`dict`, + other parameters that can be set for the building blocks. For a full list of configurable parameters, ref. corr. config file. diff --git a/torch_ecg/models/cnn/regnet.py b/torch_ecg/models/cnn/regnet.py index 9469c21a..b0e3795b 100644 --- a/torch_ecg/models/cnn/regnet.py +++ b/torch_ecg/models/cnn/regnet.py @@ -42,49 +42,50 @@ class AnyStage(nn.Sequential, SizeMixin): Index of the stage in the whole :class:`RegNet`. block_config: dict, (optional) configs for the blocks, including - - block: str or torch.nn.Module, - the block class, can be one of - "bottleneck", "bottle_neck", :class:`ResNetBottleNeck`, etc. - - expansion: int, - the expansion factor for the bottleneck block. - - increase_channels_method: str, - the method to increase the number of channels, - can be one of {"conv", "zero_padding"}. - - subsample_mode: str, - the mode of subsampling, can be one of - {:class:`DownSample`.__MODES__}, - - activation: str or torch.nn.Module, - the activation function, can be one of - {:class:`Activations`}. - - kw_activation: dict, - keyword arguments for the activation function. - - kernel_initializer: str, - the kernel initializer, can be one of - {:class:`Initializers`}. - - kw_initializer: dict, - keyword arguments for the kernel initializer. - - bias: bool, - whether to use bias in the convolution. - - dilation: int, - the dilation factor for the convolution. - - base_width: int, - number of filters per group for the neck conv layer - usually number of filters of the initial conv layer - of the whole :class:`RegNet`. - - base_groups: int, - pattern of connections between inputs and outputs of - conv layers at the two ends, which should divide `groups`. - - base_filter_length: int, - lengths (sizes) of the filter kernels for conv layers at the two ends. - - attn: dict, - attention mechanism for the neck conv layer. - If is None, no attention mechanism is used. - If is not None, it should be a dict with the following items: - - - name: str, can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc. - - pos: int, position of the attention mechanism. - - Other keys are specific to the attention mechanism. + + - block: str or torch.nn.Module, + the block class, can be one of + "bottleneck", "bottle_neck", :class:`ResNetBottleNeck`, etc. + - expansion: int, + the expansion factor for the bottleneck block. + - increase_channels_method: str, + the method to increase the number of channels, + can be one of {"conv", "zero_padding"}. + - subsample_mode: str, + the mode of subsampling, can be one of + {:class:`DownSample`.__MODES__}, + - activation: str or torch.nn.Module, + the activation function, can be one of + {:class:`Activations`}. + - kw_activation: dict, + keyword arguments for the activation function. + - kernel_initializer: str, + the kernel initializer, can be one of + {:class:`Initializers`}. + - kw_initializer: dict, + keyword arguments for the kernel initializer. + - bias: bool, + whether to use bias in the convolution. + - dilation: int, + the dilation factor for the convolution. + - base_width: int, + number of filters per group for the neck conv layer + usually number of filters of the initial conv layer + of the whole :class:`RegNet`. 
+ - base_groups: int, + pattern of connections between inputs and outputs of + conv layers at the two ends, which should divide `groups`. + - base_filter_length: int, + lengths (sizes) of the filter kernels for conv layers at the two ends. + - attn: dict, + attention mechanism for the neck conv layer. + If is None, no attention mechanism is used. + If is not None, it should be a dict with the following items: + + - name: str, can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc. + - pos: int, position of the attention mechanism. + + Other keys are specific to the attention mechanism. """ @@ -283,31 +284,31 @@ class RegNet(nn.Sequential, SizeMixin, CitationMixin): Hyper-parameters of the Module, ref. corr. config file. Keyword arguments that must be set: - - filter_lengths: int or sequence of int, - filter length(s) (kernel size(s)) of the convolutions, - with granularity to the whole network, to each stage. - - subsample_lengths: int or sequence of int, - subsampling length(s) (ratio(s)) of all blocks, - with granularity to the whole network, to each stage. - - tot_blocks: int, - the total number of building blocks. - - w_a, w_0, w_m: float, - the parameters for the widths generating function. - - group_widths: int or sequence of int, - the number of channels in each group, - with granularity to the whole network, to each stage. - - num_blocks: sequence of int, optional, - the number of blocks in each stage, - if not given, will be computed from tot_blocks - and `w_a`, `w_0`, `w_m`. - - num_filters: int or sequence of int, optional, - the number of filters in each stage. - If not given, will be computed from tot_blocks - and `w_a`, `w_0`, `w_m`. - - stem: dict, - the config of the input stem. - - block: dict, - other parameters that can be set for the building blocks. + - filter_lengths: int or sequence of int, + filter length(s) (kernel size(s)) of the convolutions, + with granularity to the whole network, to each stage. + - subsample_lengths: int or sequence of int, + subsampling length(s) (ratio(s)) of all blocks, + with granularity to the whole network, to each stage. + - tot_blocks: int, + the total number of building blocks. + - w_a, w_0, w_m: float, + the parameters for the widths generating function. + - group_widths: int or sequence of int, + the number of channels in each group, + with granularity to the whole network, to each stage. + - num_blocks: sequence of int, optional, + the number of blocks in each stage, + if not given, will be computed from tot_blocks + and `w_a`, `w_0`, `w_m`. + - num_filters: int or sequence of int, optional, + the number of filters in each stage. + If not given, will be computed from tot_blocks + and `w_a`, `w_0`, `w_m`. + - stem: dict, + the config of the input stem. + - block: dict, + other parameters that can be set for the building blocks. """ diff --git a/torch_ecg/models/cnn/resnet.py b/torch_ecg/models/cnn/resnet.py index 551216cc..39122731 100644 --- a/torch_ecg/models/cnn/resnet.py +++ b/torch_ecg/models/cnn/resnet.py @@ -63,10 +63,10 @@ class ResNetBasicBlock(nn.Module, SizeMixin): If is None, no attention mechanism is used. keys: - - name: str, - can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc. - - pos: int, - position of the attention mechanism. + - name: str, + can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc. + - pos: int, + position of the attention mechanism. Other keys are specific to the attention mechanism. 
config : dict @@ -297,10 +297,10 @@ class ResNetBottleNeck(nn.Module, SizeMixin): If is None, no attention mechanism is used. Keys: - - name: str, - can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc. - - pos: int, - position of the attention mechanism. + - name: str, + can be "se", "gc", "nl" (alias "nonlocal", "non-local"), etc. + - pos: int, + position of the attention mechanism. Other keys are specific to the attention mechanism. config : dict @@ -643,25 +643,25 @@ class ResNet(nn.Sequential, SizeMixin, CitationMixin): Hyper-parameters of the Module, ref. corr. config file. keyword arguments that must be set: - - bias: bool, - if True, each convolution will have a bias term. - - num_blocks: sequence of int, - number of building blocks in each macro block. - - filter_lengths: int or sequence of int or sequence of sequences of int, - filter length(s) (kernel size(s)) of the convolutions, - with granularity to the whole network, to each macro block, - or to each building block. - - subsample_lengths: int or sequence of int or sequence of sequences of int, - subsampling length(s) (ratio(s)) of all blocks, - with granularity to the whole network, to each macro block, - or to each building block, - the former 2 subsample at the first building block. - - groups: int, - connection pattern (of channels) of the inputs and outputs. - - stem: dict, - other parameters that can be set for the input stem. - - block: dict, - other parameters that can be set for the building blocks. + - bias: bool, + if True, each convolution will have a bias term. + - num_blocks: sequence of int, + number of building blocks in each macro block. + - filter_lengths: int or sequence of int or sequence of sequences of int, + filter length(s) (kernel size(s)) of the convolutions, + with granularity to the whole network, to each macro block, + or to each building block. + - subsample_lengths: int or sequence of int or sequence of sequences of int, + subsampling length(s) (ratio(s)) of all blocks, + with granularity to the whole network, to each macro block, + or to each building block, + the former 2 subsample at the first building block. + - groups: int, + connection pattern (of channels) of the inputs and outputs. + - stem: dict, + other parameters that can be set for the input stem. + - block: dict, + other parameters that can be set for the building blocks. For a full list of configurable parameters, ref. corr. config file. diff --git a/torch_ecg/models/cnn/vgg.py b/torch_ecg/models/cnn/vgg.py index e0fc8ea0..cad9e128 100644 --- a/torch_ecg/models/cnn/vgg.py +++ b/torch_ecg/models/cnn/vgg.py @@ -140,14 +140,14 @@ class VGG16(nn.Sequential, SizeMixin, CitationMixin): and more for :class:`VGGBlock`. Key word arguments that have to be set: - - num_convs: sequence of int, - number of convolutional layers for each :class:`VGGBlock`. - - num_filters: sequence of int, - number of filters for each :class:`VGGBlock`. - - groups: int, - connection pattern (of channels) of the inputs and outputs. - - block: dict, - other parameters that can be set for :class:`VGGBlock`. + - num_convs: sequence of int, + number of convolutional layers for each :class:`VGGBlock`. + - num_filters: sequence of int, + number of filters for each :class:`VGGBlock`. + - groups: int, + connection pattern (of channels) of the inputs and outputs. + - block: dict, + other parameters that can be set for :class:`VGGBlock`. For a full list of configurable parameters, ref. corr. config file. 
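[Reviewer note, not part of the patch] The VGG16 docstring above, like the DenseNet, MobileNet, ResNet and RegNet docstrings reformatted earlier in this diff, documents the backbone's hyper-parameters as keys of a nested config dict. A minimal, hypothetical usage sketch of how such a config is typically assembled and handed to the backbone follows; the import location of the shipped defaults (torch_ecg.model_configs.ECG_CRNN_CONFIG), the `cnn.vgg16` key, and the concrete values are illustrative assumptions only, not something this diff introduces.

    # Hypothetical sketch, not part of this diff. Assumes the package ships
    # default configs under ``torch_ecg.model_configs`` and that ``VGG16`` is
    # constructed as ``VGG16(in_channels, **config)``; consult the corresponding
    # config file for the authoritative keys and default values.
    from copy import deepcopy

    import torch

    from torch_ecg.model_configs import ECG_CRNN_CONFIG  # assumed location of the defaults
    from torch_ecg.models.cnn.vgg import VGG16

    cnn_config = deepcopy(ECG_CRNN_CONFIG.cnn.vgg16)      # assumed key holding the VGG16 defaults
    cnn_config.num_convs = [2, 2, 3, 3, 3]                # conv layers per VGGBlock (documented key)
    cnn_config.num_filters = [64, 128, 256, 512, 512]     # filters per VGGBlock (documented key)
    cnn_config.groups = 1                                 # channel connection pattern (documented key)

    model = VGG16(in_channels=12, **cnn_config)           # e.g. 12-lead ECG input
    x = torch.randn(2, 12, 4000)                          # (batch, leads, samples)
    with torch.no_grad():
        features = model(x)                               # -> (batch, channels, reduced_seq_len)
    print(features.shape)

This mirrors how the ECG_CRNN model in the next file of this diff selects and builds its CNN backbone via `self.config.cnn[self.config.cnn.name]`.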
diff --git a/torch_ecg/models/ecg_crnn.py b/torch_ecg/models/ecg_crnn.py index 9ef14340..85627d4c 100644 --- a/torch_ecg/models/ecg_crnn.py +++ b/torch_ecg/models/ecg_crnn.py @@ -2,14 +2,15 @@ C(R)NN structure models, for classifying ECG arrhythmias, and other tasks. """ +import os import warnings from copy import deepcopy from typing import Any, List, Optional, Sequence, Tuple, Union -import numpy as np import torch from einops import rearrange from einops.layers.torch import Rearrange +from numpy.typing import NDArray from torch import Tensor, nn from ..cfg import CFG @@ -79,8 +80,8 @@ def __init__( warnings.warn("No config is provided, using default config.", RuntimeWarning) self.config.update(deepcopy(config) or {}) - cnn_choice = self.config.cnn.name.lower() - cnn_config = self.config.cnn[self.config.cnn.name] + cnn_choice = self.config.cnn.name.lower() # type: ignore + cnn_config = self.config.cnn[self.config.cnn.name] # type: ignore if "resnet" in cnn_choice or "resnext" in cnn_choice: self.cnn = ResNet(self.n_leads, **cnn_config) elif "regnet" in cnn_choice: @@ -106,114 +107,118 @@ def __init__( raise NotImplementedError(f"CNN \042{cnn_choice}\042 not implemented yet") rnn_input_size = self.cnn.compute_output_shape(None, None)[1] - if self.config.rnn.name.lower() == "none": + if self.config.rnn.name.lower() == "none": # type: ignore self.rnn_in_rearrange = Rearrange("batch_size channels seq_len -> seq_len batch_size channels") self.rnn = nn.Identity() self.__rnn_seqlen_dim = 0 self.rnn_out_rearrange = nn.Identity() attn_input_size = rnn_input_size - elif self.config.rnn.name.lower() == "lstm": + elif self.config.rnn.name.lower() == "lstm": # type: ignore self.rnn_in_rearrange = Rearrange("batch_size channels seq_len -> seq_len batch_size channels") self.rnn = StackedLSTM( - input_size=rnn_input_size, - hidden_sizes=self.config.rnn.lstm.hidden_sizes, - bias=self.config.rnn.lstm.bias, - dropouts=self.config.rnn.lstm.dropouts, - bidirectional=self.config.rnn.lstm.bidirectional, - return_sequences=self.config.rnn.lstm.retseq, + input_size=rnn_input_size, # type: ignore + hidden_sizes=self.config.rnn.lstm.hidden_sizes, # type: ignore + bias=self.config.rnn.lstm.bias, # type: ignore + dropouts=self.config.rnn.lstm.dropouts, # type: ignore + bidirectional=self.config.rnn.lstm.bidirectional, # type: ignore + return_sequences=self.config.rnn.lstm.retseq, # type: ignore ) self.__rnn_seqlen_dim = 0 self.rnn_out_rearrange = nn.Identity() attn_input_size = self.rnn.compute_output_shape(None, None)[-1] - elif self.config.rnn.name.lower() == "linear": + elif self.config.rnn.name.lower() == "linear": # type: ignore # abuse of notation, to put before the global attention module self.rnn_in_rearrange = Rearrange("batch_size channels seq_len -> batch_size seq_len channels") self.rnn = MLP( - in_channels=rnn_input_size, - out_channels=self.config.rnn.linear.out_channels, - activation=self.config.rnn.linear.activation, - bias=self.config.rnn.linear.bias, - dropouts=self.config.rnn.linear.dropouts, + in_channels=rnn_input_size, # type: ignore + out_channels=self.config.rnn.linear.out_channels, # type: ignore + activation=self.config.rnn.linear.activation, # type: ignore + bias=self.config.rnn.linear.bias, # type: ignore + dropouts=self.config.rnn.linear.dropouts, # type: ignore ) self.__rnn_seqlen_dim = 1 self.rnn_out_rearrange = Rearrange("batch_size seq_len channels -> seq_len batch_size channels") attn_input_size = self.rnn.compute_output_shape(None, None)[-1] else: - raise NotImplementedError(f"RNN 
\042{self.config.rnn.name}\042 not implemented yet") + raise NotImplementedError(f"RNN \042{self.config.rnn.name}\042 not implemented yet") # type: ignore # attention - if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: + if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: # type: ignore self.attn_in_rearrange = nn.Identity() self.attn = nn.Identity() self.__attn_seqlen_dim = 0 self.attn_out_rearrange = nn.Identity() clf_input_size = attn_input_size - if self.config.attn.name.lower() != "none": + if self.config.attn.name.lower() != "none": # type: ignore warnings.warn( - f"since `retseq` of rnn is False, hence attention `{self.config.attn.name}` is ignored", + f"since `retseq` of rnn is False, hence attention `{self.config.attn.name}` is ignored", # type: ignore RuntimeWarning, ) - elif self.config.attn.name.lower() == "none": + elif self.config.attn.name.lower() == "none": # type: ignore self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len") self.attn = nn.Identity() self.__attn_seqlen_dim = -1 self.attn_out_rearrange = nn.Identity() clf_input_size = attn_input_size - elif self.config.attn.name.lower() == "nl": # non_local + elif self.config.attn.name.lower() == "nl": # type: ignore + # non_local self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len") self.attn = NonLocalBlock( - in_channels=attn_input_size, - filter_lengths=self.config.attn.nl.filter_lengths, - subsample_length=self.config.attn.nl.subsample_length, - batch_norm=self.config.attn.nl.batch_norm, + in_channels=attn_input_size, # type: ignore + filter_lengths=self.config.attn.nl.filter_lengths, # type: ignore + subsample_length=self.config.attn.nl.subsample_length, # type: ignore + batch_norm=self.config.attn.nl.batch_norm, # type: ignore ) self.__attn_seqlen_dim = -1 self.attn_out_rearrange = nn.Identity() clf_input_size = self.attn.compute_output_shape(None, None)[1] - elif self.config.attn.name.lower() == "se": # squeeze_exitation + elif self.config.attn.name.lower() == "se": # type: ignore + # squeeze_exitation self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len") self.attn = SEBlock( - in_channels=attn_input_size, - reduction=self.config.attn.se.reduction, - activation=self.config.attn.se.activation, - kw_activation=self.config.attn.se.kw_activation, - bias=self.config.attn.se.bias, + in_channels=attn_input_size, # type: ignore + reduction=self.config.attn.se.reduction, # type: ignore + activation=self.config.attn.se.activation, # type: ignore + kw_activation=self.config.attn.se.kw_activation, # type: ignore + bias=self.config.attn.se.bias, # type: ignore ) self.__attn_seqlen_dim = -1 self.attn_out_rearrange = nn.Identity() clf_input_size = self.attn.compute_output_shape(None, None)[1] - elif self.config.attn.name.lower() == "gc": # global_context + elif self.config.attn.name.lower() == "gc": # type: ignore + # global_context self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len") self.attn = GlobalContextBlock( - in_channels=attn_input_size, - ratio=self.config.attn.gc.ratio, - reduction=self.config.attn.gc.reduction, - pooling_type=self.config.attn.gc.pooling_type, - fusion_types=self.config.attn.gc.fusion_types, + in_channels=attn_input_size, # type: ignore + ratio=self.config.attn.gc.ratio, # type: ignore + reduction=self.config.attn.gc.reduction, # type: ignore + 
pooling_type=self.config.attn.gc.pooling_type, # type: ignore + fusion_types=self.config.attn.gc.fusion_types, # type: ignore ) self.__attn_seqlen_dim = -1 self.attn_out_rearrange = nn.Identity() clf_input_size = self.attn.compute_output_shape(None, None)[1] - elif self.config.attn.name.lower() == "sa": # self_attention + elif self.config.attn.name.lower() == "sa": # type: ignore + # self_attention # NOTE: this branch NOT tested self.attn_in_rearrange = nn.Identity() - self.attn = SelfAttention( + self.attn = SelfAttention( # type: ignore embed_dim=attn_input_size, - num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")), - dropout=self.config.attn.sa.dropout, - bias=self.config.attn.sa.bias, + num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")), # type: ignore + dropout=self.config.attn.sa.dropout, # type: ignore + bias=self.config.attn.sa.bias, # type: ignore ) self.__attn_seqlen_dim = 0 self.attn_out_rearrange = Rearrange("seq_len batch_size channels -> batch_size channels seq_len") clf_input_size = self.attn.compute_output_shape(None, None)[-1] - elif self.config.attn.name.lower() == "transformer": + elif self.config.attn.name.lower() == "transformer": # type: ignore self.attn = Transformer( - input_size=attn_input_size, - hidden_size=self.config.attn.transformer.hidden_size, - num_layers=self.config.attn.transformer.num_layers, - num_heads=self.config.attn.transformer.num_heads, - dropout=self.config.attn.transformer.dropout, - activation=self.config.attn.transformer.activation, + input_size=attn_input_size, # type: ignore + hidden_size=self.config.attn.transformer.hidden_size, # type: ignore + num_layers=self.config.attn.transformer.num_layers, # type: ignore + num_heads=self.config.attn.transformer.num_heads, # type: ignore + dropout=self.config.attn.transformer.dropout, # type: ignore + activation=self.config.attn.transformer.activation, # type: ignore ) if self.attn.batch_first: self.attn_in_rearrange = Rearrange("seq_len batch_size channels -> batch_size seq_len channels") @@ -225,44 +230,44 @@ def __init__( self.__attn_seqlen_dim = 0 clf_input_size = self.attn.compute_output_shape(None, None)[-1] else: - raise NotImplementedError(f"Attention \042{self.config.attn.name}\042 not implemented yet") + raise NotImplementedError(f"Attention \042{self.config.attn.name}\042 not implemented yet") # type: ignore # global pooling - if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: + if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: # type: ignore self.pool = nn.Identity() - if self.config.global_pool.lower() != "none": + if self.config.global_pool.lower() != "none": # type: ignore warnings.warn( - f"since `retseq` of rnn is False, hence global pooling `{self.config.global_pool}` is ignored", + f"since `retseq` of rnn is False, hence global pooling `{self.config.global_pool}` is ignored", # type: ignore RuntimeWarning, ) self.pool_rearrange = nn.Identity() self.__clf_input_seq = False - elif self.config.global_pool.lower() == "max": - self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False) - clf_input_size *= self.config.global_pool_size + elif self.config.global_pool.lower() == "max": # type: ignore + self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False) # type: ignore + clf_input_size *= self.config.global_pool_size # type: ignore self.pool_rearrange = Rearrange("batch_size channels pool_size -> batch_size 
(channels pool_size)") self.__clf_input_seq = False - elif self.config.global_pool.lower() == "avg": - self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,)) - clf_input_size *= self.config.global_pool_size + elif self.config.global_pool.lower() == "avg": # type: ignore + self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,)) # type: ignore + clf_input_size *= self.config.global_pool_size # type: ignore self.pool_rearrange = Rearrange("batch_size channels pool_size -> batch_size (channels pool_size)") self.__clf_input_seq = False - elif self.config.global_pool.lower() == "attn": + elif self.config.global_pool.lower() == "attn": # type: ignore raise NotImplementedError("Attentive pooling not implemented yet!") - elif self.config.global_pool.lower() == "none": + elif self.config.global_pool.lower() == "none": # type: ignore self.pool = nn.Identity() self.pool_rearrange = Rearrange("batch_size channels seq_len -> batch_size seq_len channels") self.__clf_input_seq = True else: - raise NotImplementedError(f"Global Pooling \042{self.config.global_pool}\042 not implemented yet!") + raise NotImplementedError(f"Global Pooling \042{self.config.global_pool}\042 not implemented yet!") # type: ignore # input of `self.clf` has shape: batch_size, channels self.clf = MLP( - in_channels=clf_input_size, - out_channels=self.config.clf.out_channels + [self.n_classes], - activation=self.config.clf.activation, - bias=self.config.clf.bias, - dropouts=self.config.clf.dropouts, + in_channels=clf_input_size, # type: ignore + out_channels=self.config.clf.out_channels + [self.n_classes], # type: ignore + activation=self.config.clf.activation, # type: ignore + bias=self.config.clf.bias, # type: ignore + dropouts=self.config.clf.dropouts, # type: ignore skip_last_activation=True, ) @@ -335,7 +340,7 @@ def forward(self, input: Tensor) -> Tensor: @torch.no_grad() def inference( self, - input: Union[np.ndarray, Tensor], + input: Union[NDArray, Tensor], class_names: bool = False, bin_pred_thr: float = 0.5, ) -> BaseOutput: @@ -356,11 +361,10 @@ def inference( ------- output : BaseOutput The output of the inference method, including the following items: - - - prob: numpy.ndarray or torch.Tensor, - scalar predictions, (and binary predictions if `class_names` is True). - - pred: numpy.ndarray or torch.Tensor, - the array (with values 0, 1 for each class) of binary prediction. + - prob: numpy.ndarray or torch.Tensor, + scalar predictions, (and binary predictions if `class_names` is True). + - pred: numpy.ndarray or torch.Tensor, + the array (with values 0, 1 for each class) of binary prediction. """ raise NotImplementedError("Implement a task-specific inference method.") @@ -403,10 +407,10 @@ def doi(self) -> List[str]: new_candidates = [] for candidate in candidates: if hasattr(candidate, "doi"): - if isinstance(candidate.doi, str): - doi.append(candidate.doi) + if isinstance(candidate.doi, str): # type: ignore + doi.append(candidate.doi) # type: ignore else: - doi.extend(list(candidate.doi)) + doi.extend(list(candidate.doi)) # type: ignore for k, v in candidate.items(): if isinstance(v, CFG): new_candidates.append(v) @@ -416,13 +420,13 @@ def doi(self) -> List[str]: @classmethod def from_v1( - cls, v1_ckpt: str, device: Optional[torch.device] = None, return_config: bool = False + cls, v1_ckpt: Union[str, bytes, os.PathLike], device: Optional[torch.device] = None, return_config: bool = False ) -> Union["ECG_CRNN", Tuple["ECG_CRNN", dict]]: """Restore an instance of the model from a v1 checkpoint. 
Parameters ---------- - v1_ckpt : str + v1_ckpt : path_like Path to the v1 checkpoint file. device : torch.device, optional The device to load the model to. @@ -512,8 +516,8 @@ def __init__( warnings.warn("No config is provided, using default config.", RuntimeWarning) self.config.update(deepcopy(config) or {}) - cnn_choice = self.config.cnn.name.lower() - cnn_config = self.config.cnn[self.config.cnn.name] + cnn_choice = self.config.cnn.name.lower() # type: ignore + cnn_config = self.config.cnn[self.config.cnn.name] # type: ignore if "resnet" in cnn_choice or "resnext" in cnn_choice: self.cnn = ResNet(self.n_leads, **cnn_config) elif "regnet" in cnn_choice: @@ -539,119 +543,123 @@ def __init__( raise NotImplementedError(f"CNN \042{cnn_choice}\042 not implemented yet") rnn_input_size = self.cnn.compute_output_shape(None, None)[1] - if self.config.rnn.name.lower() == "none": + if self.config.rnn.name.lower() == "none": # type: ignore self.rnn = None attn_input_size = rnn_input_size - elif self.config.rnn.name.lower() == "lstm": + elif self.config.rnn.name.lower() == "lstm": # type: ignore self.rnn = StackedLSTM( - input_size=rnn_input_size, - hidden_sizes=self.config.rnn.lstm.hidden_sizes, - bias=self.config.rnn.lstm.bias, - dropouts=self.config.rnn.lstm.dropouts, - bidirectional=self.config.rnn.lstm.bidirectional, - return_sequences=self.config.rnn.lstm.retseq, + input_size=rnn_input_size, # type: ignore + hidden_sizes=self.config.rnn.lstm.hidden_sizes, # type: ignore + bias=self.config.rnn.lstm.bias, # type: ignore + dropouts=self.config.rnn.lstm.dropouts, # type: ignore + bidirectional=self.config.rnn.lstm.bidirectional, # type: ignore + return_sequences=self.config.rnn.lstm.retseq, # type: ignore ) attn_input_size = self.rnn.compute_output_shape(None, None)[-1] - elif self.config.rnn.name.lower() == "linear": + elif self.config.rnn.name.lower() == "linear": # type: ignore # abuse of notation, to put before the global attention module self.rnn = MLP( - in_channels=rnn_input_size, - out_channels=self.config.rnn.linear.out_channels, - activation=self.config.rnn.linear.activation, - bias=self.config.rnn.linear.bias, - dropouts=self.config.rnn.linear.dropouts, + in_channels=rnn_input_size, # type: ignore + out_channels=self.config.rnn.linear.out_channels, # type: ignore + activation=self.config.rnn.linear.activation, # type: ignore + bias=self.config.rnn.linear.bias, # type: ignore + dropouts=self.config.rnn.linear.dropouts, # type: ignore ) attn_input_size = self.rnn.compute_output_shape(None, None)[-1] else: - raise NotImplementedError(f"RNN \042{self.config.rnn.name}\042 not implemented yet") + raise NotImplementedError(f"RNN \042{self.config.rnn.name}\042 not implemented yet") # type: ignore # attention - if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: + if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: # type: ignore self.attn = None clf_input_size = attn_input_size - if self.config.attn.name.lower() != "none": + if self.config.attn.name.lower() != "none": # type: ignore warnings.warn( - f"since `retseq` of rnn is False, hence attention `{self.config.attn.name}` is ignored", + f"since `retseq` of rnn is False, hence attention `{self.config.attn.name}` is ignored", # type: ignore RuntimeWarning, ) - elif self.config.attn.name.lower() == "none": + elif self.config.attn.name.lower() == "none": # type: ignore self.attn = None clf_input_size = attn_input_size - elif self.config.attn.name.lower() == "nl": # non_local + elif 
self.config.attn.name.lower() == "nl": # type: ignore + # non_local self.attn = NonLocalBlock( - in_channels=attn_input_size, - filter_lengths=self.config.attn.nl.filter_lengths, - subsample_length=self.config.attn.nl.subsample_length, - batch_norm=self.config.attn.nl.batch_norm, + in_channels=attn_input_size, # type: ignore + filter_lengths=self.config.attn.nl.filter_lengths, # type: ignore + subsample_length=self.config.attn.nl.subsample_length, # type: ignore + batch_norm=self.config.attn.nl.batch_norm, # type: ignore ) clf_input_size = self.attn.compute_output_shape(None, None)[1] - elif self.config.attn.name.lower() == "se": # squeeze_exitation + elif self.config.attn.name.lower() == "se": # type: ignore + # squeeze_exitation self.attn = SEBlock( - in_channels=attn_input_size, - reduction=self.config.attn.se.reduction, - activation=self.config.attn.se.activation, - kw_activation=self.config.attn.se.kw_activation, - bias=self.config.attn.se.bias, + in_channels=attn_input_size, # type: ignore + reduction=self.config.attn.se.reduction, # type: ignore + activation=self.config.attn.se.activation, # type: ignore + kw_activation=self.config.attn.se.kw_activation, # type: ignore + bias=self.config.attn.se.bias, # type: ignore ) clf_input_size = self.attn.compute_output_shape(None, None)[1] - elif self.config.attn.name.lower() == "gc": # global_context + elif self.config.attn.name.lower() == "gc": # type: ignore + # global_context self.attn = GlobalContextBlock( - in_channels=attn_input_size, - ratio=self.config.attn.gc.ratio, - reduction=self.config.attn.gc.reduction, - pooling_type=self.config.attn.gc.pooling_type, - fusion_types=self.config.attn.gc.fusion_types, + in_channels=attn_input_size, # type: ignore + ratio=self.config.attn.gc.ratio, # type: ignore + reduction=self.config.attn.gc.reduction, # type: ignore + pooling_type=self.config.attn.gc.pooling_type, # type: ignore + fusion_types=self.config.attn.gc.fusion_types, # type: ignore ) clf_input_size = self.attn.compute_output_shape(None, None)[1] - elif self.config.attn.name.lower() == "sa": # self_attention + elif self.config.attn.name.lower() == "sa": # type: ignore + # self_attention # NOTE: this branch NOT tested - self.attn = SelfAttention( + self.attn = SelfAttention( # type: ignore embed_dim=attn_input_size, - num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")), - dropout=self.config.attn.sa.dropout, - bias=self.config.attn.sa.bias, + num_heads=self.config.attn.sa.get("num_heads", self.config.attn.sa.get("head_num")), # type: ignore + dropout=self.config.attn.sa.dropout, # type: ignore + bias=self.config.attn.sa.bias, # type: ignore ) clf_input_size = self.attn.compute_output_shape(None, None)[-1] - elif self.config.attn.name.lower() == "transformer": + elif self.config.attn.name.lower() == "transformer": # type: ignore self.attn = Transformer( - input_size=attn_input_size, - hidden_size=self.config.attn.transformer.hidden_size, - num_layers=self.config.attn.transformer.num_layers, - num_heads=self.config.attn.transformer.num_heads, - dropout=self.config.attn.transformer.dropout, - activation=self.config.attn.transformer.activation, + input_size=attn_input_size, # type: ignore + hidden_size=self.config.attn.transformer.hidden_size, # type: ignore + num_layers=self.config.attn.transformer.num_layers, # type: ignore + num_heads=self.config.attn.transformer.num_heads, # type: ignore + dropout=self.config.attn.transformer.dropout, # type: ignore + activation=self.config.attn.transformer.activation, # 
type: ignore ) clf_input_size = self.attn.compute_output_shape(None, None)[-1] else: - raise NotImplementedError(f"Attention \042{self.config.attn.name}\042 not implemented yet") + raise NotImplementedError(f"Attention \042{self.config.attn.name}\042 not implemented yet") # type: ignore - if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: + if self.config.rnn.name.lower() == "lstm" and not self.config.rnn.lstm.retseq: # type: ignore self.pool = None - if self.config.global_pool.lower() != "none": + if self.config.global_pool.lower() != "none": # type: ignore warnings.warn( - f"since `retseq` of rnn is False, hence global pooling `{self.config.global_pool}` is ignored", + f"since `retseq` of rnn is False, hence global pooling `{self.config.global_pool}` is ignored", # type: ignore RuntimeWarning, ) - elif self.config.global_pool.lower() == "max": - self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False) - clf_input_size *= self.config.global_pool_size - elif self.config.global_pool.lower() == "avg": - self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,)) - clf_input_size *= self.config.global_pool_size - elif self.config.global_pool.lower() == "attn": + elif self.config.global_pool.lower() == "max": # type: ignore + self.pool = nn.AdaptiveMaxPool1d((self.config.global_pool_size,), return_indices=False) # type: ignore + clf_input_size *= self.config.global_pool_size # type: ignore + elif self.config.global_pool.lower() == "avg": # type: ignore + self.pool = nn.AdaptiveAvgPool1d((self.config.global_pool_size,)) # type: ignore + clf_input_size *= self.config.global_pool_size # type: ignore + elif self.config.global_pool.lower() == "attn": # type: ignore raise NotImplementedError("Attentive pooling not implemented yet!") - elif self.config.global_pool.lower() == "none": + elif self.config.global_pool.lower() == "none": # type: ignore self.pool = None else: - raise NotImplementedError(f"Global Pooling \042{self.config.global_pool}\042 not implemented yet!") + raise NotImplementedError(f"Global Pooling \042{self.config.global_pool}\042 not implemented yet!") # type: ignore # input of `self.clf` has shape: batch_size, channels self.clf = MLP( - in_channels=clf_input_size, - out_channels=self.config.clf.out_channels + [self.n_classes], - activation=self.config.clf.activation, - bias=self.config.clf.bias, - dropouts=self.config.clf.dropouts, + in_channels=clf_input_size, # type: ignore + out_channels=self.config.clf.out_channels + [self.n_classes], # type: ignore + activation=self.config.clf.activation, # type: ignore + bias=self.config.clf.bias, # type: ignore + dropouts=self.config.clf.dropouts, # type: ignore skip_last_activation=True, ) @@ -683,14 +691,16 @@ def extract_features(self, input: Tensor) -> Tensor: features = self.cnn(input) # batch_size, channels, seq_len # RNN (optional) - if self.config.rnn.name.lower() in ["lstm"]: + if self.config.rnn.name.lower() in ["lstm"]: # type: ignore # (batch_size, channels, seq_len) --> (seq_len, batch_size, channels) features = features.permute(2, 0, 1) - features = self.rnn(features) # (seq_len, batch_size, channels) or (batch_size, channels) - elif self.config.rnn.name.lower() in ["linear"]: + features = self.rnn(features) # type: ignore + # (seq_len, batch_size, channels) or (batch_size, channels) + elif self.config.rnn.name.lower() in ["linear"]: # type: ignore # (batch_size, channels, seq_len) --> (batch_size, seq_len, channels) features = features.permute(0, 2, 1) - features 
= self.rnn(features) # (batch_size, seq_len, channels) + # (batch_size, seq_len, channels) + features = self.rnn(features) # type: ignore # (batch_size, seq_len, channels) --> (seq_len, batch_size, channels) features = features.permute(1, 0, 2) else: @@ -701,16 +711,18 @@ def extract_features(self, input: Tensor) -> Tensor: if self.attn is None and features.ndim == 3: # (seq_len, batch_size, channels) --> (batch_size, channels, seq_len) features = features.permute(1, 2, 0) - elif self.config.attn.name.lower() in ["nl", "se", "gc"]: + elif self.config.attn.name.lower() in ["nl", "se", "gc"]: # type: ignore # (seq_len, batch_size, channels) --> (batch_size, channels, seq_len) features = features.permute(1, 2, 0) - features = self.attn(features) # (batch_size, channels, seq_len) - elif self.config.attn.name.lower() in ["sa"]: - features = self.attn(features) # (seq_len, batch_size, channels) + # (batch_size, channels, seq_len) + features = self.attn(features) # type: ignore + elif self.config.attn.name.lower() in ["sa"]: # type: ignore + # (seq_len, batch_size, channels) + features = self.attn(features) # type: ignore # (seq_len, batch_size, channels) -> (batch_size, channels, seq_len) features = features.permute(1, 2, 0) - elif self.config.attn.name.lower() in ["transformer"]: - features = self.attn(features) + elif self.config.attn.name.lower() in ["transformer"]: # type: ignore + features = self.attn(features) # type: ignore # (seq_len, batch_size, channels) -> (batch_size, channels, seq_len) features = features.permute(1, 2, 0) return features @@ -753,7 +765,7 @@ def forward(self, input: Tensor) -> Tensor: @torch.no_grad() def inference( self, - input: Union[np.ndarray, Tensor], + input: Union[NDArray, Tensor], class_names: bool = False, bin_pred_thr: float = 0.5, ) -> BaseOutput: @@ -774,11 +786,10 @@ def inference( ------- output : BaseOutput The output of the inference method, including the following items: - - - prob: numpy.ndarray or torch.Tensor, - scalar predictions, (and binary predictions if `class_names` is True). - - pred: numpy.ndarray or torch.Tensor, - the array (with values 0, 1 for each class) of binary prediction. + - prob: numpy.ndarray or torch.Tensor, + scalar predictions, (and binary predictions if `class_names` is True). + - pred: numpy.ndarray or torch.Tensor, + the array (with values 0, 1 for each class) of binary prediction. 
""" raise NotImplementedError("Implement a task-specific inference method.") @@ -824,10 +835,10 @@ def doi(self) -> List[str]: new_candidates = [] for candidate in candidates: if hasattr(candidate, "doi"): - if isinstance(candidate.doi, str): - doi.append(candidate.doi) + if isinstance(candidate.doi, str): # type: ignore + doi.append(candidate.doi) # type: ignore else: - doi.extend(list(candidate.doi)) + doi.extend(list(candidate.doi)) # type: ignore for k, v in candidate.items(): if isinstance(v, CFG): new_candidates.append(v) diff --git a/torch_ecg/models/grad_cam.py b/torch_ecg/models/grad_cam.py index 973e89b9..d5cb5bca 100644 --- a/torch_ecg/models/grad_cam.py +++ b/torch_ecg/models/grad_cam.py @@ -134,7 +134,7 @@ def __call__(self, input: Tensor, index: Optional[int] = None): n_classes = output.shape[-1] if index is None: - index = np.argmax(output.cpu().detach().numpy()[0]) + index = np.argmax(output.detach().cpu().numpy()[0]) one_hot = np.zeros((1, n_classes), dtype=np.float32) one_hot[0][index] = 1 @@ -145,12 +145,12 @@ def __call__(self, input: Tensor, index: Optional[int] = None): self.model.zero_grad() one_hot.backward(retain_graph=True) - grads_val = self.extractor.get_gradients()[-1].cpu().detach().numpy() + grads_val = self.extractor.get_gradients()[-1].detach().cpu().numpy() # of shape (batch_size (=1), channels, seq_len) or (batch_size (=1), seq_len, channels) target = features[-1] # of shape (channels, seq_len) or (seq_len, channels) - target = target.cpu().detach().numpy()[0, :] + target = target.detach().cpu().numpy()[0, :] if self.target_channel_last: weights = np.mean(grads_val, axis=-2)[0, :] diff --git a/torch_ecg/models/loss.py b/torch_ecg/models/loss.py index d872e150..66ecc341 100644 --- a/torch_ecg/models/loss.py +++ b/torch_ecg/models/loss.py @@ -302,7 +302,7 @@ class FocalLoss(nn.modules.loss._WeightedLoss): Where: - - :math:`p_t` is the model's estimated probability for each class. + - :math:`p_t` is the model's estimated probability for each class. 
Parameters ---------- diff --git a/torch_ecg/models/unets/ecg_subtract_unet.py b/torch_ecg/models/unets/ecg_subtract_unet.py index 12839784..f52dad11 100644 --- a/torch_ecg/models/unets/ecg_subtract_unet.py +++ b/torch_ecg/models/unets/ecg_subtract_unet.py @@ -15,8 +15,8 @@ from itertools import repeat from typing import List, Optional, Sequence, Union -import numpy as np import torch +from numpy.typing import NDArray from torch import Tensor, nn from ...cfg import CFG @@ -642,7 +642,7 @@ def forward(self, input: Tensor) -> Tensor: return output @torch.no_grad() - def inference(self, input: Union[np.ndarray, Tensor], bin_pred_thr: float = 0.5) -> Tensor: + def inference(self, input: Union[NDArray, Tensor], bin_pred_thr: float = 0.5) -> Tensor: """Method for making inference on a single input.""" raise NotImplementedError("Implement a task-specific inference method.") diff --git a/torch_ecg/utils/_ecg_plot.py b/torch_ecg/utils/_ecg_plot.py index 94deae8c..e610ec85 100644 --- a/torch_ecg/utils/_ecg_plot.py +++ b/torch_ecg/utils/_ecg_plot.py @@ -42,6 +42,7 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.ticker import AutoMinorLocator +from numpy.typing import NDArray __all__ = ["ecg_plot"] @@ -99,7 +100,7 @@ def inches_to_dots(value: float, resolution: int) -> float: return value * resolution -def create_signal_dictionary(signal: np.ndarray, full_leads: List[str]) -> Dict[str, np.ndarray]: +def create_signal_dictionary(signal: NDArray, full_leads: List[str]) -> Dict[str, NDArray]: record_dict = {} for k in range(len(full_leads)): record_dict[full_leads[k]] = signal[k] @@ -107,7 +108,7 @@ def create_signal_dictionary(signal: np.ndarray, full_leads: List[str]) -> Dict[ def ecg_plot( - ecg: Dict[str, np.ndarray], + ecg: Dict[str, NDArray], sample_rate: int, columns: int, rec_file_name: Union[str, bytes, os.PathLike], @@ -140,7 +141,7 @@ def ecg_plot( Parameters ---------- - ecg : Dict[str, np.ndarray] + ecg : Dict[str, NDArray] Dictionary of ECG signals with lead names as keys, values as 1D numpy arrays. sample_rate : int diff --git a/torch_ecg/utils/_edr.py b/torch_ecg/utils/_edr.py index 02e23223..a084ccec 100644 --- a/torch_ecg/utils/_edr.py +++ b/torch_ecg/utils/_edr.py @@ -8,6 +8,7 @@ from typing import Sequence import numpy as np +from numpy.typing import NDArray __all__ = [ "phs_edr", @@ -23,7 +24,7 @@ def phs_edr( return_with_time: bool = True, mode: str = "complex", verbose: int = 0, -) -> np.ndarray: +) -> NDArray: """ computes the respiratory rate from single-lead ECG signals. 
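# Minimal sketch of the annotation style this patch adopts, assuming only that
# numpy >= 1.20 is available: `numpy.typing.NDArray` replaces bare `np.ndarray`
# in public signatures, while runtime behavior is unchanged.
import numpy as np
from numpy.typing import NDArray

def _normalize(sig: NDArray) -> NDArray:
    # illustrative helper (hypothetical name): same array in, same array out,
    # only the type annotations follow the new convention
    sig = np.asarray(sig, dtype=float)
    return (sig - sig.mean()) / (sig.std() + 1e-8)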
@@ -53,7 +54,7 @@ def phs_edr( Returns ------- - np.ndarray, + NDArray, 1d, if `return_with_time` is set False, 2d in the form of [ts, val] (ts in milliseconds), if `return_with_time` is set True diff --git a/torch_ecg/utils/_preproc.py b/torch_ecg/utils/_preproc.py index f4f9ee91..6899fcad 100644 --- a/torch_ecg/utils/_preproc.py +++ b/torch_ecg/utils/_preproc.py @@ -28,6 +28,7 @@ # from scipy.signal import medfilt # https://github.com/scipy/scipy/issues/9680 from biosppy.signals.tools import filter_signal +from numpy.typing import NDArray from scipy.ndimage.filters import median_filter from ..cfg import CFG @@ -62,14 +63,14 @@ def preprocess_multi_lead_signal( - raw_sig: np.ndarray, + raw_sig: NDArray, fs: Real, sig_fmt: str = "channel_first", bl_win: Optional[List[Real]] = None, band_fs: Optional[List[Real]] = None, rpeak_fn: Optional[str] = None, verbose: int = 0, -) -> Dict[str, np.ndarray]: +) -> Dict[str, NDArray]: """ perform preprocessing for multi-lead ECG signal (with units in mV), preprocessing may include median filter, bandpass filter, and rpeaks detection, etc. @@ -146,13 +147,13 @@ def preprocess_multi_lead_signal( def preprocess_single_lead_signal( - raw_sig: np.ndarray, + raw_sig: NDArray, fs: Real, bl_win: Optional[List[Real]] = None, band_fs: Optional[List[Real]] = None, rpeak_fn: Optional[str] = None, verbose: int = 0, -) -> Dict[str, np.ndarray]: +) -> Dict[str, NDArray]: """ perform preprocessing for single lead ECG signal (with units in mV), preprocessing may include median filter, bandpass filter, and rpeaks detection, etc. @@ -225,12 +226,12 @@ def preprocess_single_lead_signal( def rpeaks_detect_multi_leads( - sig: np.ndarray, + sig: NDArray, fs: Real, sig_fmt: str = "channel_first", rpeak_fn: str = "xqrs", verbose: int = 0, -) -> np.ndarray: +) -> NDArray: """ detect rpeaks from the filtered multi-lead ECG signal (with units in mV) @@ -252,7 +253,7 @@ def rpeaks_detect_multi_leads( Returns ------- - rpeaks: np.ndarray, + rpeaks: NDArray, array of indices of the detected rpeaks of the multi-lead ECG signal """ @@ -273,7 +274,7 @@ def rpeaks_detect_multi_leads( return rpeaks -def merge_rpeaks(rpeaks_candidates: List[np.ndarray], sig: np.ndarray, fs: Real, verbose: int = 0) -> np.ndarray: +def merge_rpeaks(rpeaks_candidates: List[NDArray], sig: NDArray, fs: Real, verbose: int = 0) -> NDArray: """ merge rpeaks that are detected from each lead of multi-lead signals (with units in mV), using certain criterion merging qrs masks from each lead @@ -291,7 +292,7 @@ def merge_rpeaks(rpeaks_candidates: List[np.ndarray], sig: np.ndarray, fs: Real, Returns ------- - final_rpeaks: np.ndarray + final_rpeaks: NDArray the final rpeaks obtained by merging the rpeaks from all the leads """ diff --git a/torch_ecg/utils/download.py b/torch_ecg/utils/download.py index f726af59..069c8137 100644 --- a/torch_ecg/utils/download.py +++ b/torch_ecg/utils/download.py @@ -19,7 +19,7 @@ import warnings import zipfile from pathlib import Path -from typing import Any, Iterable, Literal, Optional, Union +from typing import Any, Iterable, Literal, Optional, Tuple, Union import boto3 import requests @@ -27,6 +27,8 @@ from botocore.client import Config from tqdm.auto import tqdm +from .misc import str2bool + __all__ = [ "http_get", ] @@ -36,10 +38,10 @@ def _requests_retry_session( - retries=5, - backoff_factor=0.5, - status_forcelist=(500, 502, 503, 504), - session=None, + retries: int = 5, + backoff_factor: float = 0.5, + status_forcelist: Tuple[int, ...] 
= (429, 500, 502, 503, 504, 403), + session: Optional[requests.Session] = None, ) -> requests.Session: """Get a requests session with retry strategy. @@ -50,7 +52,7 @@ def _requests_retry_session( backoff_factor : float, default 0.5 A backoff factor to apply between attempts. A backoff factor of 0.5 will sleep for [0.5s, 1s, 2s, ...] between retries. - status_forcelist : tuple, default (500, 502, 503, 504) + status_forcelist : tuple, default (429, 500, 502, 503, 504, 403) A set of HTTP status codes that we should force a retry on. session : requests.Session, optional An existing requests session. @@ -63,17 +65,23 @@ def _requests_retry_session( """ session = session or requests.Session() - retry = requests.packages.urllib3.util.retry.Retry( + from requests.adapters import HTTPAdapter + from requests.packages.urllib3.util.retry import Retry # type: ignore + + retry = Retry( total=retries, read=retries, connect=retries, + status=retries, backoff_factor=backoff_factor, status_forcelist=status_forcelist, - allowed_methods=["HEAD", "GET", "OPTIONS"], + allowed_methods=frozenset(["HEAD", "GET", "OPTIONS"]), + raise_on_redirect=True, + raise_on_status=False, ) - adapter = requests.adapters.HTTPAdapter(max_retries=retry) - session.mount("http://", adapter) + adapter = HTTPAdapter(max_retries=retry) session.mount("https://", adapter) + session.mount("http://", adapter) return session @@ -83,6 +91,9 @@ def http_get( proxies: Optional[dict] = None, extract: Literal[True, False, "auto"] = "auto", filename: Optional[str] = None, + *, + timeout: Union[Tuple[float, float], float] = (5.0, 60.0), + verify_length: bool = True, ) -> Path: """Download contents of a URL and save to a file. @@ -108,6 +119,16 @@ def http_get( which is set to `dst_dir`, and `filename` is only the downloaded file name. .. versionadded:: 0.0.20 + timeout : float or tuple, default (5.0, 60.0) + How many seconds to wait for the server to send data + before giving up, as a float, or a (connect timeout, read timeout) tuple. + + .. versionadded:: 0.0.32 + verify_length : bool, default True + Whether to verify the length of the downloaded file + with the `Content-Length` header in the HTTP response. + + .. versionadded:: 0.0.32 Returns ------- @@ -119,8 +140,8 @@ def http_get( .. 
[1] https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py """ - Path(dst_dir).mkdir(parents=True, exist_ok=True) - if filename is not None and (Path(dst_dir) / filename).exists(): + Path(dst_dir).mkdir(parents=True, exist_ok=True) # type: ignore + if filename is not None and (Path(dst_dir) / filename).exists(): # type: ignore raise FileExistsError("file already exists") url_parsed = urllib.parse.urlparse(url) if url_parsed.scheme == "": @@ -134,7 +155,7 @@ def http_get( if url_parsed.scheme == "s3": _download_from_aws_s3_using_awscli(url, dst_dir) - return Path(dst_dir) + return Path(dst_dir) # type: ignore if url_parsed.netloc == "www.dropbox.com" and url_parsed.query == "dl=0": url_parsed = url_parsed._replace(query="dl=1") @@ -163,7 +184,7 @@ def http_get( delete=False, ) _download_from_google_drive(url, downloaded_file.name) - df_suffix = _suffix(filename) + df_suffix = _suffix(filename) # type: ignore downloaded_file.close() else: print(f"Downloading {url}.") @@ -186,7 +207,7 @@ def http_get( RuntimeWarning, ) extract = False - parent_dir = Path(dst_dir).parent + parent_dir = Path(dst_dir).parent # type: ignore df_suffix = _suffix(pure_url) if filename is None else _suffix(filename) downloaded_file = tempfile.NamedTemporaryFile( dir=parent_dir, @@ -194,19 +215,39 @@ def http_get( delete=False, ) # req = requests.get(url, stream=True, proxies=proxies) - req = _requests_retry_session().get(url, stream=True, proxies=proxies) + req = _requests_retry_session().get(url, stream=True, proxies=proxies, timeout=timeout) + try: + req.raise_for_status() + except Exception: + snippet = "" + try: + snippet = req.text[:300] + except Exception: + pass + raise RuntimeError(f"Failed to download {url}, status={req.status_code}, body[:300]={snippet!r}") content_length = req.headers.get("Content-Length") total = int(content_length) if content_length is not None else None if req.status_code in [403, 404]: raise Exception(f"Could not reach {url}.") - progress = tqdm(unit="B", unit_scale=True, total=total, dynamic_ncols=True, mininterval=1.0) + if str2bool(os.environ.get("CI")): + mininterval = 10.0 + disable = True + else: + mininterval = 1.0 + disable = False + progress = tqdm(unit="B", unit_scale=True, total=total, dynamic_ncols=True, mininterval=mininterval, disable=disable) + downloaded_size = 0 for chunk in req.iter_content(chunk_size=1024): if chunk: # filter out keep-alive new chunks progress.update(len(chunk)) downloaded_file.write(chunk) + downloaded_size += len(chunk) progress.close() downloaded_file.close() + if verify_length and total is not None and downloaded_size != total: + raise IOError(f"Size mismatch for {url}. 
Expected {total} bytes, got {downloaded_size} bytes.") + # add a delay to avoid the error "process cannot access the file because it is being used by another process" time.sleep(0.1) @@ -225,20 +266,20 @@ def http_get( _folder = Path(url).name.replace(_suffix(url), "") else: _folder = _stem(Path(filename)) - if _folder in os.listdir(dst_dir): + if (Path(dst_dir) / _folder).exists(): # type: ignore tmp_folder = str(dst_dir).rstrip(os.sep) + "_tmp" # move (rename) the dst_dir to a temporary folder os.rename(dst_dir, tmp_folder) # move (rename) the extracted folder to the destination folder os.rename(Path(tmp_folder) / _folder, dst_dir) shutil.rmtree(tmp_folder) - final_dst = Path(dst_dir) + final_dst = Path(dst_dir) # type: ignore else: - Path(dst_dir).mkdir(parents=True, exist_ok=True) + Path(dst_dir).mkdir(parents=True, exist_ok=True) # type: ignore if filename is None: - final_dst = Path(dst_dir) / Path(pure_url).name + final_dst = Path(dst_dir) / Path(pure_url).name # type: ignore else: - final_dst = Path(dst_dir) / filename + final_dst = Path(dst_dir) / filename # type: ignore shutil.copyfile(downloaded_file.name, final_dst) os.remove(downloaded_file.name) return final_dst @@ -258,7 +299,9 @@ def _stem(path: Union[str, bytes, os.PathLike]) -> str: Filename without extension. """ - ret = Path(path).stem + if isinstance(path, bytes): + path = path.decode() + ret = Path(path).stem # type: ignore if Path(ret).suffix in [".tar", ".gz", ".tz", ".lz", ".bz2", ".xz", ".zip", ".7z"]: return _stem(ret) return ret @@ -342,8 +385,8 @@ def _untar_file(path_to_tar_file: Union[str, bytes, os.PathLike], dst_dir: Union """ print(f"Extracting file {path_to_tar_file} to {dst_dir}.") - mode = Path(path_to_tar_file).suffix.replace(".", "r:").replace("tar", "").strip(":") - with tarfile.open(str(path_to_tar_file), mode) as tar_ref: + mode = Path(path_to_tar_file).suffix.replace(".", "r:").replace("tar", "").strip(":") # type: ignore + with tarfile.open(str(path_to_tar_file), mode) as tar_ref: # type: ignore # tar_ref.extractall(str(dst_dir)) # CVE-2007-4559 (related to CVE-2001-1267): # directory traversal vulnerability in `extract` and `extractall` in `tarfile` module @@ -370,7 +413,7 @@ def _is_within_directory(directory: Union[str, bytes, os.PathLike], target: Unio abs_directory = os.path.abspath(directory) abs_target = os.path.abspath(target) - prefix = os.path.commonprefix([abs_directory, abs_target]) + prefix = os.path.commonprefix([abs_directory, abs_target]) # type: ignore return prefix == abs_directory @@ -409,7 +452,7 @@ def _safe_tar_extract( """ for member in members or tar.getmembers(): - member_path = os.path.join(dst_dir, member.name) + member_path = os.path.join(dst_dir, member.name) # type: ignore if not _is_within_directory(dst_dir, member_path): raise Exception("Attempted Path Traversal in Tar File") @@ -510,11 +553,12 @@ def count_aws_s3_bucket(bucket_name: str, prefix: str = "") -> int: page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix) object_count = 0 - for page in page_iterator: - if "Contents" in page: - object_count += len(page["Contents"]) - - s3.close() + try: + for page in page_iterator: + if "Contents" in page: + object_count += len(page.get("Contents", [])) + finally: + s3.close() return object_count @@ -562,7 +606,7 @@ def _download_from_aws_s3_using_boto3(url: str, dst_dir: Union[str, bytes, os.Pa if "Contents" in page: for obj in page["Contents"]: key = obj["Key"] - dst_file = Path(dst_dir) / key + dst_file = Path(dst_dir) / key # type: ignore if not 
dst_file.parent.exists(): dst_file.parent.mkdir(parents=True, exist_ok=True) s3.download_file(bucket_name, key, str(dst_file)) @@ -571,7 +615,13 @@ def _download_from_aws_s3_using_boto3(url: str, dst_dir: Union[str, bytes, os.Pa s3.close() -def _download_from_aws_s3_using_awscli(url: str, dst_dir: Union[str, bytes, os.PathLike]) -> None: +def _download_from_aws_s3_using_awscli( + url: str, + dst_dir: Union[str, bytes, os.PathLike], + *, + show_progress: bool = True, + env: Optional[dict] = None, +) -> None: """Download a file from AWS S3 using awscli. Parameters @@ -581,51 +631,78 @@ def _download_from_aws_s3_using_awscli(url: str, dst_dir: Union[str, bytes, os.P For example, "s3://bucket-name/files/pre-fix". dst_dir : `path-like` The output directory. + show_progress : bool, default True + Whether to display tqdm progress bar. + env : dict, optional + Custom environment. Returns ------- None """ - assert shutil.which("aws") is not None, "AWS cli is required to download from S3." + if shutil.which("aws") is None: + raise RuntimeError("AWS cli is required to download from S3 (please install AWS CLI v2).") pattern = "^s3://(?P[^/]+)/(?P.+)$" match = re.match(pattern, url) if match is None: raise ValueError(f"Invalid S3 URL: {url}") + bucket_name = match.group("bucket_name") prefix = match.group("prefix") if prefix.startswith("files/"): - prefix = prefix.replace("files/", "") - object_count = count_aws_s3_bucket(bucket_name, prefix) - print(f"Downloading from S3 bucket: {bucket_name}, prefix: {prefix}, total files: {object_count}") + prefix = prefix[len("files/") :] + total_files = count_aws_s3_bucket(bucket_name, prefix) - pbar = tqdm(total=object_count, dynamic_ncols=True, mininterval=1.0) + if total_files == 0: + print(f"[S3 sync] No objects found at {bucket_name}/{prefix}, skipping.") + return + + print(f"Downloading from S3 bucket: {bucket_name}, prefix: {prefix}, total files: {total_files}") + + dst_dir = Path(dst_dir).expanduser().resolve() # type: ignore + dst_dir.mkdir(parents=True, exist_ok=True) + + if str2bool(os.environ.get("CI")): + mininterval = 10.0 + disable = True + else: + mininterval = 1.0 + disable = not show_progress + pbar = tqdm(total=total_files, dynamic_ncols=True, mininterval=mininterval, disable=disable) download_count = 0 - command = f"aws s3 sync --no-sign-request {url} {dst_dir}" - process = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - debug_stdout = collections.deque(maxlen=10) - while 1: - line = process.stdout.readline().decode("utf-8", errors="replace") - if line.rstrip(): + command = f"aws s3 sync --no-sign-request --only-show-errors {url} {shlex.quote(str(dst_dir))}" + process = subprocess.Popen( + shlex.split(command), + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + bufsize=1, + text=True, + env=env if env is not None else os.environ.copy(), + ) + debug_stdout = collections.deque(maxlen=50) + try: + assert process.stdout is not None # for type checker + for line in process.stdout: + line = line.rstrip("\n") debug_stdout.append(line) - if "download: s3:" in line: + if "download: s3://" in line: download_count += 1 pbar.update(1) - exitcode = process.poll() - if exitcode is not None: - for line in process.stdout: - debug_stdout.append(line.decode("utf-8", errors="replace")) - if exitcode is not None and exitcode != 0: - error_msg = "\n".join(debug_stdout) - process.communicate() - process.stdout.close() - raise subprocess.CalledProcessError(exitcode, error_msg) - else: - break - process.communicate() - 
process.stdout.close() - # object_count - download_count files skipped for they already exist - pbar.update(object_count - download_count) - pbar.close() - print(f"Downloaded {download_count} files from S3.") + + retcode = process.wait() + if retcode != 0: + raise subprocess.CalledProcessError( + retcode, + command, + output="\n".join(debug_stdout), + ) + finally: + if download_count < total_files: + pbar.update(total_files - download_count) + pbar.close() + if process.stdout and not process.stdout.closed: + process.stdout.close() + + print(f"[S3 sync] Completed. Reported downloads: {download_count}/{total_files} (existing files are skipped silently).") diff --git a/torch_ecg/utils/misc.py b/torch_ecg/utils/misc.py index 8579811d..e453a95a 100644 --- a/torch_ecg/utils/misc.py +++ b/torch_ecg/utils/misc.py @@ -14,14 +14,15 @@ from copy import deepcopy from functools import reduce, wraps from glob import glob -from numbers import Number, Real -from pathlib import Path +from numbers import Number +from pathlib import Path, PurePath from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union import numpy as np import pandas as pd from bib_lookup import CitationMixin as _CitationMixin from deprecated import deprecated +from numpy.typing import NDArray from ..cfg import _DATA_CACHE, DEFAULTS @@ -85,10 +86,10 @@ def get_record_list_recursive(db_dir: Union[str, bytes, os.PathLike], rec_ext: s """ if not rec_ext.startswith("."): - res = Path(db_dir).rglob(f"*.{rec_ext}") + res = Path(db_dir).rglob(f"*.{rec_ext}") # type: ignore else: - res = Path(db_dir).rglob(f"*{rec_ext}") - res = [str((item.relative_to(db_dir) if relative else item).with_suffix("")) for item in res if str(item).endswith(rec_ext)] + res = Path(db_dir).rglob(f"*{rec_ext}") # type: ignore + res = [str((item.relative_to(db_dir) if relative else item).with_suffix("")) for item in res if str(item).endswith(rec_ext)] # type: ignore res = sorted(res) return res @@ -173,32 +174,32 @@ def get_record_list_recursive3( res = [] elif isinstance(rec_patterns, dict): res = {k: [] for k in rec_patterns.keys()} - _db_dir = Path(db_dir).resolve() # make absolute + _db_dir = Path(db_dir).resolve() # type: ignore roots = [_db_dir] while len(roots) > 0: new_roots = [] for r in roots: tmp = os.listdir(r) if isinstance(rec_patterns, str): - res += [r / item for item in filter(re.compile(rec_patterns).search, tmp)] + res += [r / item for item in filter(re.compile(rec_patterns).search, tmp)] # type: ignore elif isinstance(rec_patterns, dict): for k in rec_patterns.keys(): - res[k] += [r / item for item in filter(re.compile(rec_patterns[k]).search, tmp)] + res[k] += [r / item for item in filter(re.compile(rec_patterns[k]).search, tmp)] # type: ignore new_roots += [r / item for item in tmp if (r / item).is_dir()] roots = deepcopy(new_roots) if isinstance(rec_patterns, str): if with_suffix: - res = [str((item.relative_to(_db_dir) if relative else item)) for item in res] + res = [str((item.relative_to(_db_dir) if relative else item)) for item in res] # type: ignore else: - res = [str((item.relative_to(_db_dir) if relative else item).with_suffix("")) for item in res] + res = [str((item.relative_to(_db_dir) if relative else item).with_suffix("")) for item in res] # type: ignore res = sorted(res) elif isinstance(rec_patterns, dict): for k in rec_patterns.keys(): if with_suffix: - res[k] = [str((item.relative_to(_db_dir) if relative else item)) for item in res[k]] + res[k] = [str((item.relative_to(_db_dir) if relative else 
item)) for item in res[k]] # type: ignore else: - res[k] = [str((item.relative_to(_db_dir) if relative else item).with_suffix("")) for item in res[k]] - res[k] = sorted(res[k]) + res[k] = [str((item.relative_to(_db_dir) if relative else item).with_suffix("")) for item in res[k]] # type: ignore + res[k] = sorted(res[k]) # type: ignore return res @@ -282,7 +283,11 @@ def dict_to_str(d: Union[dict, list, tuple], current_depth: int = 1, indent_spac return s -def str2bool(v: Union[str, bool]) -> bool: +_TRUE_SET = {"yes", "true", "t", "y", "1"} +_FALSE_SET = {"no", "false", "f", "n", "0"} + + +def str2bool(v: Union[str, bool, None], *, default: bool = False, strict: bool = True) -> bool: """Converts a "boolean" value possibly in the format of :class:`str` to :class:`bool`. @@ -290,8 +295,17 @@ def str2bool(v: Union[str, bool]) -> bool: Parameters ---------- - v : str or bool + v : str or bool or None The "boolean" value. + default : bool, default False + The default value to return if `v` is ``None``, + or if `strict` is ``False`` and `v` could not be converted. + + .. versionadded:: 0.0.32 + strict : bool, default True + Whether to raise error if `v` could not be converted. + + .. versionadded:: 0.0.32 Returns ------- @@ -304,18 +318,25 @@ def str2bool(v: Union[str, bool]) -> bool: """ if isinstance(v, bool): - b = v - elif v.lower() in ("yes", "true", "t", "y", "1"): - b = True - elif v.lower() in ("no", "false", "f", "n", "0"): - b = False - else: - raise ValueError("Boolean value expected.") - return b + return v + if v is None: + return default + if not isinstance(v, str): + if strict: + raise TypeError(f"Expected str|bool|None, got {type(v)}") + return default + v_norm = v.strip().lower() + if v_norm in _TRUE_SET: + return True + if v_norm in _FALSE_SET: + return False + if strict: + raise ValueError(f"Boolean value expected, got {v!r}") + return default @deprecated("Use `np.diff` instead.") -def diff_with_step(a: np.ndarray, step: int = 1) -> np.ndarray: +def diff_with_step(a: NDArray, step: int = 1) -> NDArray: """Compute ``a[n+step] - a[n]`` for all valid `n`. Parameters @@ -337,14 +358,14 @@ def diff_with_step(a: np.ndarray, step: int = 1) -> np.ndarray: return d -def ms2samples(t: Real, fs: Real) -> int: +def ms2samples(t: Union[float, int], fs: Union[float, int]) -> int: """Convert time duration in ms to number of samples. Parameters ---------- - t : numbers.Real + t : float or int Time duration in ms. - fs : numbers.Real + fs : float or int Sampling frequency. Returns @@ -354,23 +375,23 @@ def ms2samples(t: Real, fs: Real) -> int: with sampling frequency `fs`. """ - n_samples = t * fs // 1000 + n_samples = int(t * fs / 1000) return n_samples -def samples2ms(n_samples: int, fs: Real) -> Real: +def samples2ms(n_samples: int, fs: int) -> float: """Convert number of samples to time duration in ms. Parameters ---------- n_samples : int Number of sample points. - fs : numbers.Real + fs : int Sampling frequency. Returns ------- - t : numbers.Real + t : float Time duration in ms converted from `n_samples`, with sampling frequency `fs`. 
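# Minimal usage sketch for the reworked `str2bool` (module path taken from the
# patched file torch_ecg/utils/misc.py; behavior mirrors the code above):
from torch_ecg.utils.misc import str2bool

assert str2bool("Yes") is True          # case-insensitive, whitespace-stripped
assert str2bool("0") is False
assert str2bool(None) is False          # None falls back to `default`
assert str2bool("maybe", default=True, strict=False) is True  # non-strict: unparsable -> `default`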
@@ -380,8 +401,8 @@ def samples2ms(n_samples: int, fs: Real) -> Real: def plot_single_lead( - t: np.ndarray, - sig: np.ndarray, + t: NDArray, + sig: NDArray, ax: Optional[Any] = None, ticks_granularity: int = 0, **kwargs, @@ -405,8 +426,9 @@ def plot_single_lead( None """ - if "plt" not in dir(): - import matplotlib.pyplot as plt + import matplotlib.pyplot as plt + from matplotlib.ticker import MultipleLocator + palette = { "p_waves": "cyan", "qrs": "green", @@ -426,12 +448,12 @@ def plot_single_lead( ax.axhline(y=0, linestyle="-", linewidth="1.0", color="red") # NOTE that `Locator` has default `MAXTICKS` equal to 1000 if ticks_granularity >= 1: - ax.xaxis.set_major_locator(plt.MultipleLocator(0.2)) - ax.yaxis.set_major_locator(plt.MultipleLocator(500)) + ax.xaxis.set_major_locator(MultipleLocator(0.2)) + ax.yaxis.set_major_locator(MultipleLocator(500)) ax.grid(which="major", linestyle="-", linewidth="0.5", color="red") if ticks_granularity >= 2: - ax.xaxis.set_minor_locator(plt.MultipleLocator(0.04)) - ax.yaxis.set_minor_locator(plt.MultipleLocator(100)) + ax.xaxis.set_minor_locator(MultipleLocator(0.04)) + ax.yaxis.set_minor_locator(MultipleLocator(100)) ax.grid(which="minor", linestyle=":", linewidth="0.5", color="black") waves = kwargs.get("waves", {"p_waves": [], "qrs": [], "t_waves": []}) @@ -485,13 +507,13 @@ def init_logger( log_file = None else: if log_file is None: - log_file = f"{DEFAULTS.prefix}-log-{get_date_str()}.txt" - log_dir = Path(log_dir).expanduser().resolve() if log_dir is not None else DEFAULTS.log_dir + log_file = f"{DEFAULTS.prefix}-log-{get_date_str()}.txt" # type: ignore + log_dir = Path(log_dir).expanduser().resolve() if log_dir is not None else DEFAULTS.log_dir # type: ignore log_dir.mkdir(parents=True, exist_ok=True) - log_file = log_dir / log_file + log_file = log_dir / log_file # type: ignore print(f"log file path: {str(log_file)}") - log_name = (log_name or DEFAULTS.prefix) + (f"-{suffix}" if suffix else "") + log_name = (log_name or DEFAULTS.prefix) + (f"-{suffix}" if suffix else "") # type: ignore # if a logger with the same name already exists, remove it if log_name in logging.root.manager.loggerDict: logging.getLogger(log_name).handlers = [] @@ -506,21 +528,21 @@ def init_logger( c_handler.setLevel(logging.DEBUG) if log_file is not None: # print("level of `f_handler` is set DEBUG") - f_handler.setLevel(logging.DEBUG) + f_handler.setLevel(logging.DEBUG) # type: ignore logger.setLevel(logging.DEBUG) elif verbose >= 1: # print("level of `c_handler` is set INFO") c_handler.setLevel(logging.INFO) if log_file is not None: # print("level of `f_handler` is set DEBUG") - f_handler.setLevel(logging.DEBUG) + f_handler.setLevel(logging.DEBUG) # type: ignore logger.setLevel(logging.DEBUG) else: # print("level of `c_handler` is set WARNING") c_handler.setLevel(logging.WARNING) if log_file is not None: # print("level of `f_handler` is set INFO") - f_handler.setLevel(logging.INFO) + f_handler.setLevel(logging.INFO) # type: ignore logger.setLevel(logging.INFO) c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") @@ -529,8 +551,8 @@ def init_logger( if log_file is not None: f_format = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") - f_handler.setFormatter(f_format) - logger.addHandler(f_handler) + f_handler.setFormatter(f_format) # type: ignore + logger.addHandler(f_handler) # type: ignore return logger @@ -598,7 +620,7 @@ def read_log_txt( Scalars summary, in the format of a :class:`~pandas.DataFrame`. 
""" - content = Path(fp).read_text().splitlines() + content = Path(fp).read_text().splitlines() # type: ignore if isinstance(scalar_startswith, str): field_pattern = f"({scalar_startswith})" else: @@ -615,7 +637,8 @@ def read_log_txt( field, val = line.split(":")[-2:] field = field.strip() val = float(val.strip()) - new_line[field] = val + if new_line: + new_line[field] = val summary.append(new_line) summary = pd.DataFrame(summary) return summary @@ -641,7 +664,7 @@ def read_event_scalars( """ try: - from tensorflow.python.summary.event_accumulator import EventAccumulator + from tensorflow.python.summary.event_accumulator import EventAccumulator # type: ignore except Exception: try: from tensorboard.backend.event_processing.event_accumulator import EventAccumulator @@ -662,7 +685,7 @@ def read_event_scalars( df.columns = ["wall_time", "step", "value"] summary[k] = df if isinstance(keys, str): - summary = summary[k] + summary = summary[k] # type: ignore return summary @@ -809,15 +832,15 @@ def default_class_repr(c: object, align: str = "center", depth: int = 1) -> str: closing_indent = 4 * (depth - 1) * " " if not hasattr(c, "extra_repr_keys"): return repr(c) - elif len(c.extra_repr_keys()) > 0: - max_len = max([len(k) for k in c.extra_repr_keys()]) + elif len(c.extra_repr_keys()) > 0: # type: ignore + max_len = max([len(k) for k in c.extra_repr_keys()]) # type: ignore extra_str = ( "(\n" + ",\n".join( [ f"""{indent}{k.ljust(max_len, " ") if align.lower() in ["center", "c"] else k} = {default_class_repr(eval(f"c.{k}"),align,depth+1)}""" for k in c.__dir__() - if k in c.extra_repr_keys() + if k in c.extra_repr_keys() # type: ignore ] ) + f"{closing_indent}\n)" @@ -874,7 +897,7 @@ def get_citation( style: Optional[str] = None, timeout: Optional[float] = None, print_result: bool = True, - ) -> Union[str, type(None)]: + ) -> Union[str, None]: """Get bib citation from DOIs. Overrides the default method to make the `print_result` argument @@ -937,7 +960,7 @@ def __init__(self, data: Optional[Sequence] = None, **kwargs: Any) -> None: self.data = np.array(data) self.verbose = kwargs.get("verbose", 0) - def __call__(self, data: Optional[Sequence] = None, method: str = "ema", **kwargs: Any) -> np.ndarray: + def __call__(self, data: Optional[Sequence] = None, method: str = "ema", **kwargs: Any) -> NDArray: """Compute moving average. Parameters @@ -946,13 +969,11 @@ def __call__(self, data: Optional[Sequence] = None, method: str = "ema", **kwarg The series data to compute its moving average. method : str method for computing moving average, can be one of - - - "sma", "simple", "simple moving average"; - - "ema", "ewma", "exponential", "exponential weighted", - "exponential moving average", "exponential weighted moving average"; - - "cma", "cumulative", "cumulative moving average"; - - "wma", "weighted", "weighted moving average". - + - "sma", "simple", "simple moving average"; + - "ema", "ewma", "exponential", "exponential weighted", + "exponential moving average", "exponential weighted moving average"; + - "cma", "cumulative", "cumulative moving average"; + - "wma", "weighted", "weighted moving average". kwargs : dict, optional Keyword arguments for the specific moving average method. 
@@ -984,7 +1005,7 @@ def __call__(self, data: Optional[Sequence] = None, method: str = "ema", **kwarg self.data = np.array(data) return func(**kwargs) - def _sma(self, window: int = 5, center: bool = False, **kwargs: Any) -> np.ndarray: + def _sma(self, window: int = 5, center: bool = False, **kwargs: Any) -> NDArray: """Simple moving average. Parameters @@ -1020,13 +1041,13 @@ def _sma(self, window: int = 5, center: bool = False, **kwargs: Any) -> np.ndarr smoothed.append(s) smoothed = np.array(smoothed) if center: - smoothed[hw:-hw] = smoothed[window - 1 :] - for n in range(hw): - smoothed[n] = np.mean(self.data[: n + hw + 1]) - smoothed[-n - 1] = np.mean(self.data[-n - hw - 1 :]) + smoothed[hw:-hw] = smoothed[window - 1 :] # type: ignore + for n in range(hw): # type: ignore + smoothed[n] = np.mean(self.data[: n + hw + 1]) # type: ignore + smoothed[-n - 1] = np.mean(self.data[-n - hw - 1 :]) # type: ignore return smoothed - def _ema(self, weight: float = 0.6, **kwargs: Any) -> np.ndarray: + def _ema(self, weight: float = 0.6, **kwargs: Any) -> NDArray: """Exponential moving average This is also the function used in Tensorboard Scalar panel, @@ -1057,7 +1078,7 @@ def _ema(self, weight: float = 0.6, **kwargs: Any) -> np.ndarray: smoothed = np.array(smoothed) return smoothed - def _cma(self, **kwargs) -> np.ndarray: + def _cma(self, **kwargs) -> NDArray: """Cumulative moving average. Parameters @@ -1084,7 +1105,7 @@ def _cma(self, **kwargs) -> np.ndarray: smoothed = np.array(smoothed) return smoothed - def _wma(self, window: int = 5, **kwargs: Any) -> np.ndarray: + def _wma(self, window: int = 5, **kwargs: Any) -> NDArray: """Weighted moving average. Parameters @@ -1199,7 +1220,7 @@ def remove_parameters_returns_from_docstring( if start_idx is not None and len(line.strip()) == 0: indices2remove.extend(list(range(start_idx, idx))) start_idx = None - if parameters_starts and len(line.lstrip()) == len(line) - len(parameters_indent): + if parameters_starts and len(line.lstrip()) == len(line) - len(parameters_indent): # type: ignore if any([line.lstrip().startswith(p) for p in parameters]): if start_idx is not None: indices2remove.extend(list(range(start_idx, idx))) @@ -1210,7 +1231,7 @@ def remove_parameters_returns_from_docstring( else: indices2remove.extend(list(range(start_idx, idx))) start_idx = None - if returns_starts and len(line.lstrip()) == len(line) - len(returns_indent): + if returns_starts and len(line.lstrip()) == len(line) - len(returns_indent): # type: ignore if any([line.lstrip().startswith(p) for p in returns]): if start_idx is not None: indices2remove.extend(list(range(start_idx, idx))) @@ -1226,7 +1247,7 @@ def remove_parameters_returns_from_docstring( @contextmanager -def timeout(duration: float): +def timeout(duration: Union[float, int]): """A context manager that raises a :class:`TimeoutError` after a specified time (in seconds). 
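# Minimal usage sketch for the patched `timeout` context manager (POSIX-only,
# since it relies on SIGALRM; module path taken from torch_ecg/utils/misc.py):
import time
from torch_ecg.utils.misc import timeout

try:
    with timeout(1):
        time.sleep(2)        # exceeds the allowed duration
except TimeoutError:
    pass                     # raised by the handler installed inside the context
# on exit the pending alarm is cancelled and the previous SIGALRM handler restored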
@@ -1247,16 +1268,25 @@ def timeout(duration: float): duration = 0 elif duration < 0: raise ValueError("`duration` must be non-negative") - elif duration > 0: # granularity is 1 second, so round up + elif duration > 0: duration = max(1, int(duration)) + if duration == 0: + yield + return + + old_handler = signal.getsignal(signal.SIGALRM) + def timeout_handler(signum, frame): raise TimeoutError(f"block timedout after `{duration}` seconds") signal.signal(signal.SIGALRM, timeout_handler) - signal.alarm(duration) - yield - signal.alarm(0) + signal.alarm(duration) # type: ignore + try: + yield + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) class Timer(ReprMixin): @@ -1365,7 +1395,7 @@ def extra_repr_keys(self) -> List[str]: return ["name", "verbose"] -def get_kwargs(func_or_cls: callable, kwonly: bool = False) -> Dict[str, Any]: +def get_kwargs(func_or_cls: Callable, kwonly: bool = False) -> Dict[str, Any]: """Get the kwargs of a function or class. Parameters @@ -1403,7 +1433,7 @@ def get_kwargs(func_or_cls: callable, kwonly: bool = False) -> Dict[str, Any]: return kwargs -def get_required_args(func_or_cls: callable) -> List[str]: +def get_required_args(func_or_cls: Callable) -> List[str]: """Get the required positional arguments of a function or class. Parameters @@ -1427,7 +1457,7 @@ def get_required_args(func_or_cls: callable) -> List[str]: return required_args -def add_kwargs(func: callable, **kwargs: Any) -> callable: +def add_kwargs(func: Callable, **kwargs: Any) -> Callable: """Add keyword arguments to a function. This function is used to add keyword arguments to a function @@ -1466,18 +1496,18 @@ def add_kwargs(func: callable, **kwargs: Any) -> callable: # move the VAR_POSITIONAL and VAR_KEYWORD in `func_parameters` to the end for k, v in func_parameters.items(): if v.kind == inspect.Parameter.VAR_POSITIONAL: - func_parameters.move_to_end(k) + func_parameters.move_to_end(k) # type: ignore break for k, v in func_parameters.items(): if v.kind == inspect.Parameter.VAR_KEYWORD: - func_parameters.move_to_end(k) + func_parameters.move_to_end(k) # type: ignore break if isinstance(func, types.MethodType): # can not assign `__signature__` to a bound method directly - func.__func__.__signature__ = func_signature.replace(parameters=func_parameters.values()) + func.__func__.__signature__ = func_signature.replace(parameters=func_parameters.values()) # type: ignore else: - func.__signature__ = func_signature.replace(parameters=func_parameters.values()) + func.__signature__ = func_signature.replace(parameters=func_parameters.values()) # type: ignore # docstring is automatically copied by `functools.wraps` @@ -1492,62 +1522,114 @@ def wrapper(*args: Any, **kwargs_: Any) -> Any: return wrapper -def make_serializable(x: Union[np.ndarray, np.generic, dict, list, tuple]) -> Union[list, dict, Number]: - """Make an object serializable. +def _is_pathlike_string(s: str) -> bool: + """Heuristically check if a string looks like a filesystem path.""" + if not isinstance(s, str): + return False - This function is used to convert all numpy arrays to list in an object, - and also convert numpy data types to python data types in the object, - so that it can be serialized by :mod:`json`. 
+ p = PurePath(s) + if os.sep in s or (os.altsep and os.altsep in s): + return True + if s.startswith((".", "~")) or p.is_absolute(): + return True + if p.suffix != "": + return True + if len(s) > 2 and s[1] == ":" and s[0].isalpha() and s[2] in ("/", "\\"): + return True + return False + + +def make_serializable( + x: Any, drop_unserializable: bool = True, drop_paths: bool = False +) -> Optional[Union[list, dict, str, int, float, bool]]: + """Recursively convert object into JSON-serializable form. + + Rules + ----- + - NDArray → list + - np.generic → Python scalar + - dict → new dict with only serializable values + - list/tuple → list with only serializable values + - str/int/float/bool/None → kept + - if drop_unserializable: + anything else (like Path, custom classes) → dropped (return None) + else: + fallback to str(x) Parameters ---------- - x : Union[numpy.ndarray, numpy.generic, dict, list, tuple] - Input data, which can be numpy array (or numpy data type), - or dict, list, tuple containing numpy arrays (or numpy data type). + x : Any + Input object to be converted. + drop_unserializable : bool, default=True + Whether to drop unserializable objects (return None), + or convert them to string with str(x). + + .. versionadded:: 0.0.32 + drop_paths : bool, default=False + If True, drop all filesystem paths (Path objects and strings + that look like paths). + + .. versionadded:: 0.0.32 Returns ------- - Union[list, dict, numbers.Number] - Converted data. + Optional[Union[list, dict, str, int, float, bool]] + A JSON-serializable object, or None if dropped. Examples -------- >>> import numpy as np - >>> from fl_sim.utils.misc import make_serializable - >>> x = np.array([1, 2, 3]) - >>> make_serializable(x) + >>> make_serializable(np.array([1, 2, 3])) [1, 2, 3] - >>> x = {"a": np.array([1, 2, 3]), "b": np.array([4, 5, 6])} - >>> make_serializable(x) - {'a': [1, 2, 3], 'b': [4, 5, 6]} - >>> x = [np.array([1, 2, 3]), np.array([4, 5, 6])] - >>> make_serializable(x) - [[1, 2, 3], [4, 5, 6]] - >>> x = (np.array([1, 2, 3]), np.array([4, 5, 6]).mean()) - >>> obj = make_serializable(x) - >>> obj - [[1, 2, 3], 5.0] - >>> type(obj[1]), type(x[1]) - (float, numpy.float64) + >>> make_serializable({"a": np.float64(3.14), "b": Path("file.txt")}) + {'a': 3.14, 'b': 'file.txt'} + >>> make_serializable({"a": np.float64(3.14), "b": Path("file.txt")}, drop_paths=True) + {'a': 3.14} """ + if isinstance(x, np.ndarray): - return x.tolist() - elif isinstance(x, (list, tuple)): - # to avoid cases where the list contains numpy data types - return [make_serializable(v) for v in x] - elif isinstance(x, dict): - for k, v in x.items(): - x[k] = make_serializable(v) + return make_serializable(x.tolist(), drop_unserializable=drop_unserializable, drop_paths=drop_paths) + elif isinstance(x, np.generic): return x.item() - # the other types will be returned directly - return x + + elif isinstance(x, dict): + result = {} + for k, v in x.items(): + v_serial = make_serializable(v, drop_unserializable=drop_unserializable, drop_paths=drop_paths) + if v_serial is not None: + result[k] = v_serial + return result if result else None + + elif isinstance(x, (list, tuple)): + result = [] + for v in x: + v_serial = make_serializable(v, drop_unserializable=drop_unserializable, drop_paths=drop_paths) + if v_serial is not None: + result.append(v_serial) + return result if result else None + + elif isinstance(x, (str, int, float, bool, type(None))): + if isinstance(x, str) and drop_paths and _is_pathlike_string(x): + return None + return x + 
+ elif isinstance(x, Path): + if drop_paths: + return None + return str(x) + + else: + if drop_unserializable: + return None + else: + return str(x) def select_k( - arr: np.ndarray, k: Union[int, List[int], np.ndarray], dim: int = -1, largest: bool = True, sorted: bool = True -) -> Tuple[np.ndarray, np.ndarray]: + arr: NDArray, k: Union[int, List[int], NDArray], dim: int = -1, largest: bool = True, sorted: bool = True +) -> Tuple[NDArray, NDArray]: """Select elements from an array along a specified axis of specific rankings. Parameters @@ -1602,7 +1684,7 @@ def select_k( return values, indices -def np_topk(arr: np.ndarray, k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[np.ndarray, np.ndarray]: +def np_topk(arr: NDArray, k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[NDArray, NDArray]: """Find the k largest elements of an array along a specified axis. Parameters diff --git a/torch_ecg/utils/rpeaks.py b/torch_ecg/utils/rpeaks.py index 635fd724..b4163346 100644 --- a/torch_ecg/utils/rpeaks.py +++ b/torch_ecg/utils/rpeaks.py @@ -16,10 +16,10 @@ """ -from numbers import Real +from typing import Union import biosppy.signals.ecg as BSE -import numpy as np +from numpy.typing import NDArray from wfdb.processing.qrs import gqrs_detect as _gqrs_detect from wfdb.processing.qrs import xqrs_detect as _xqrs_detect @@ -38,7 +38,7 @@ # algorithms from wfdb -def xqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def xqrs_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray: """XQRS algorithm. default kwargs: @@ -60,7 +60,7 @@ def xqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: return rpeaks -def gqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def gqrs_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray: """GQRS algorithm. default kwargs: @@ -102,7 +102,7 @@ def gqrs_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: # --------------------------------------------------------------------- # algorithms from biosppy -def hamilton_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def hamilton_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray: """The default detector used by `biosppy`. This algorithm is based on [#ham]_. @@ -126,7 +126,7 @@ def hamilton_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: return rpeaks -def ssf_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def ssf_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray: """Slope Sum Function (SSF) This algorithm is originally proposed for blood pressure (BP) @@ -157,7 +157,7 @@ def ssf_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: return rpeaks -def christov_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def christov_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray: """Christov detector. Detector proposed in [#chr]_. @@ -181,7 +181,7 @@ def christov_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: return rpeaks -def engzee_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def engzee_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray: """Detector proposed by Engelse and Zeelenberg. 
This algorithm is originally proposed in [#ez1]_, @@ -213,7 +213,7 @@ def engzee_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: return rpeaks -def gamboa_detect(sig: np.ndarray, fs: Real, **kwargs) -> np.ndarray: +def gamboa_detect(sig: NDArray, fs: Union[float, int], **kwargs) -> NDArray: """Detector proposed by Gamboa. This algorithm is proposed in a PhD thesis [#gam]_. diff --git a/torch_ecg/utils/utils_data.py b/torch_ecg/utils/utils_data.py index c3d7530e..08565a8c 100644 --- a/torch_ecg/utils/utils_data.py +++ b/torch_ecg/utils/utils_data.py @@ -7,12 +7,12 @@ from collections import Counter from copy import deepcopy from dataclasses import dataclass -from numbers import Real from pathlib import Path from typing import Dict, List, Literal, Optional, Sequence, Tuple, Union import numpy as np import pandas as pd +from numpy.typing import NDArray from sklearn.utils import compute_class_weight from torch import Tensor, from_numpy from torch.nn.functional import interpolate @@ -41,11 +41,11 @@ def get_mask( shape: Union[int, Sequence[int]], - critical_points: np.ndarray, + critical_points: NDArray, left_bias: int, right_bias: int, return_fmt: Literal["mask", "intervals"] = "mask", -) -> Union[np.ndarray, list]: +) -> Union[NDArray, list]: """Get the mask around the given critical points. Parameters @@ -93,19 +93,21 @@ def get_mask( mask[..., itv[0] : itv[1]] = 1 elif return_fmt.lower() == "intervals": mask = l_itv + else: + raise ValueError(f"Unknown return_fmt. Expected 'mask' or 'intervals', but got {return_fmt}") return mask def class_weight_to_sample_weight( - y: np.ndarray, class_weight: Union[str, dict, List[float], np.ndarray] = "balanced" -) -> np.ndarray: + y: Union[NDArray, Sequence], class_weight: Union[str, dict, List[float], NDArray, None] = "balanced" +) -> NDArray: """Transform class weight to sample weight. Parameters ---------- - y : numpy.ndarray + y : numpy.ndarray or Sequence The label (class) of each sample. - class_weight : str or dict or List[float] or numpy.ndarray, default "balanced" + class_weight : str or dict or List[float] or numpy.ndarray or None, default "balanced" The weight for each sample class. If is "balanced", the class weight will automatically be given by the inverse of the class frequency. @@ -138,6 +140,7 @@ def class_weight_to_sample_weight( sample_weight = np.ones_like(y, dtype=DEFAULTS.np_dtype) return sample_weight + y = np.asarray(y) try: sample_weight = np.array(y.copy()).astype(int) except ValueError: @@ -249,7 +252,7 @@ def rdheader(header_data: Union[Path, str, Sequence[str]]) -> Union[Record, Mult return record -def ensure_lead_fmt(values: np.ndarray, n_leads: int = 12, fmt: str = "lead_first") -> np.ndarray: +def ensure_lead_fmt(values: NDArray, n_leads: int = 12, fmt: str = "lead_first") -> NDArray: """Ensure the multi-lead (ECG) signal to be of specified format. Parameters @@ -299,11 +302,11 @@ def ensure_lead_fmt(values: np.ndarray, n_leads: int = 12, fmt: str = "lead_firs def ensure_siglen( - values: np.ndarray, + values: NDArray, siglen: int, fmt: str = "lead_first", tolerance: Optional[float] = None, -) -> np.ndarray: +) -> NDArray: """Ensure the (ECG) signal to be of specified length. Strategy: @@ -398,16 +401,16 @@ class ECGWaveForm: ---------- name : str Name of the wave, e.g. "N", "p", "t", etc. - onset : numbers.Real + onset : float or int Onset index of the wave, :class:`~numpy.nan` for unknown/unannotated onset. 
-    offset : numbers.Real
+    offset : float or int
         Offset index of the wave,
         :class:`~numpy.nan` for unknown/unannotated offset.
-    peak : numbers.Real
+    peak : float or int
         Peak index of the wave,
         :class:`~numpy.nan` for unknown/unannotated peak.
-    duration : numbers.Real
+    duration : float or int
         Duration of the wave, with units in milliseconds,
         :class:`~numpy.nan` for unknown/unannotated duration.
@@ -418,13 +421,13 @@ class ECGWaveForm:
     """

     name: str
-    onset: Real
-    offset: Real
-    peak: Real
-    duration: Real
+    onset: Union[float, int]
+    offset: Union[float, int]
+    peak: Union[float, int]
+    duration: Union[float, int]

     @property
-    def duration_(self) -> Real:
+    def duration_(self) -> Union[float, int]:
         """Duration of the wave, with units in number of samples."""
         try:
             return self.offset - self.onset
@@ -445,9 +448,9 @@ def duration_(self) -> Real:


 def masks_to_waveforms(
-    masks: np.ndarray,
+    masks: NDArray,
     class_map: Dict[str, int],
-    fs: Real,
+    fs: Union[float, int],
     mask_format: str = "channel_first",
     leads: Optional[Sequence[str]] = None,
 ) -> Dict[str, List[ECGWaveForm]]:
@@ -461,7 +464,7 @@ def masks_to_waveforms(
     class_map : dict
         Class map, mapping names to waves to numbers from 0 to n_classes-1,
         the keys should contain "pwave", "qrs", "twave".
-    fs : numbers.Real
+    fs : float or int
         Sampling frequency of the signal corresponding to the `masks`,
         used to compute the duration of each waveform.
     mask_format : str, default "channel_first"
@@ -542,7 +545,7 @@ def masks_to_waveforms(


 def mask_to_intervals(
-    mask: np.ndarray,
+    mask: NDArray,
     vals: Optional[Union[int, Sequence[int]]] = None,
     right_inclusive: bool = False,
 ) -> Union[list, dict]:
@@ -614,14 +617,14 @@ def mask_to_intervals(
     return intervals


-def uniform(low: Real, high: Real, num: int) -> List[float]:
+def uniform(low: Union[float, int], high: Union[float, int], num: int) -> List[float]:
     """Generate a list of numbers uniformly distributed.

     Parameters
     ----------
-    low : numbers.Real
+    low : float or int
         Lower bound of the interval of the uniform distribution.
-    high : numbers.Real
+    high : float or int
         Upper bound of the interval of the uniform distribution.
     num : int
         Number of random numbers to generate.
@@ -694,7 +697,7 @@ def stratified_train_test_split(
     try:
         df_inspection = df[stratified_cols].copy().map(str)
     except AttributeError:
-        df_inspection = df[stratified_cols].copy().applymap(str)
+        df_inspection = df[stratified_cols].copy().applymap(str)  # type: ignore
     for item in stratified_cols:
         all_entities = df_inspection[item].unique().tolist()
         entities_dict = {e: str(i) for i, e in enumerate(all_entities)}
@@ -704,7 +707,7 @@ def stratified_train_test_split(
     df_inspection[inspection_col_name] = ""
     for idx, row in df_inspection.iterrows():
         cn = "-".join([row[sc] for sc in stratified_cols])
-        df_inspection.loc[idx, inspection_col_name] = cn
+        df_inspection.loc[idx, inspection_col_name] = cn  # type: ignore
     item_names = df_inspection[inspection_col_name].unique().tolist()
     item_indices = {n: df_inspection.index[df_inspection[inspection_col_name] == n].tolist() for n in item_names}
     for n in item_names:
@@ -723,8 +726,8 @@ def stratified_train_test_split(


 def one_hot_encode(
-    cls_array: Union[np.ndarray, Tensor, Sequence[Sequence[int]]], num_classes: Optional[int] = None, dtype: type = np.float32
-) -> np.ndarray:
+    cls_array: Union[NDArray, Tensor, Sequence[Sequence[int]]], num_classes: Optional[int] = None, dtype: type = np.float32
+) -> NDArray:
     """Convert a categorical array to a one-hot array.
Convert a categorical (class indices) array of shape ``(num_samples,)`` @@ -768,42 +771,42 @@ def one_hot_encode( else: # sequence of sequences of class indices num_classes = max([max(c) for c in cls_array]) + 1 if isinstance(cls_array, np.ndarray) and cls_array.ndim == 1: - assert num_classes > 0 and num_classes >= cls_array.max() + 1, ( + assert num_classes > 0 and num_classes >= cls_array.max() + 1, ( # type: ignore "num_classes must be greater than 0 and greater than or equal to " "the max value of `cls_array` if `cls_array` is 1D and `num_classes` is specified" ) elif isinstance(cls_array, Sequence): assert all( - [max(c) < num_classes for c in cls_array] + [max(c) < num_classes for c in cls_array] # type: ignore ), "all values in the multi-class `cls_array` should be less than `num_classes`" if isinstance(cls_array, np.ndarray) and cls_array.ndim == 2 and cls_array.shape[1] == num_classes: bin_array = cls_array else: shape = (len(cls_array), num_classes) - bin_array = np.zeros(shape) + bin_array = np.zeros(shape) # type: ignore for i in range(shape[0]): bin_array[i, cls_array[i]] = 1 return bin_array.astype(dtype) -@add_docstring(one_hot_encode.__doc__.replace("one_hot_encode", "cls_to_bin")) +@add_docstring(one_hot_encode.__doc__.replace("one_hot_encode", "cls_to_bin")) # type: ignore def cls_to_bin( - cls_array: Union[np.ndarray, Tensor, Sequence[Sequence[int]]], num_classes: Optional[int] = None, dtype: type = np.float32 -) -> np.ndarray: + cls_array: Union[NDArray, Tensor, Sequence[Sequence[int]]], num_classes: Optional[int] = None, dtype: type = np.float32 +) -> NDArray: """Alias of `one_hot_encode`.""" warnings.warn("`cls_to_bin` is deprecated, use `one_hot_encode` instead", DeprecationWarning) return one_hot_encode(cls_array, num_classes, dtype) def generate_weight_mask( - target_mask: np.ndarray, - fg_weight: Real, - fs: Real, - reduction: Real, - radius: Real, - boundary_weight: Real, + target_mask: NDArray, + fg_weight: Union[float, int], + fs: Union[float, int], + reduction: Union[float, int], + radius: Union[float, int], + boundary_weight: Union[float, int], plot: bool = False, -) -> np.ndarray: +) -> NDArray: """Generate weight mask for a binary target mask, accounting the foreground weight and boundary weight. @@ -811,15 +814,15 @@ def generate_weight_mask( ---------- target_mask : numpy.ndarray The target mask, assumed to be 1D and binary. - fg_weight: numbers.Real + fg_weight: float or int Foreground (value 1) weight, usually > 1. - fs : numbers.Real + fs : float or int Sampling frequency of the signal. - reduction : numbers.Real + reduction : float or int Reduction ratio of the mask w.r.t. the signal. - radius : numbers.Real + radius : float or int Radius of the boundary, with units in seconds. - boundary_weight : numbers.Real + boundary_weight : float or int Weight for the boundaries (positions where values change) of the target map. 
plot : bool, default False @@ -855,7 +858,7 @@ def generate_weight_mask( """ assert target_mask.ndim == 1, "`target_mask` should be 1D" assert set(np.unique(target_mask)).issubset({0, 1}), "`target_mask` should be binary" - assert isinstance(reduction, Real) and reduction >= 1, "`reduction` should be a real number greater than 1" + assert isinstance(reduction, (int, float)) and reduction >= 1, "`reduction` should be a real number greater than 1" if reduction > 1: # downsample the target mask target_mask = ( diff --git a/torch_ecg/utils/utils_interval.py b/torch_ecg/utils/utils_interval.py index b9b3c91e..5eb48aa8 100644 --- a/torch_ecg/utils/utils_interval.py +++ b/torch_ecg/utils/utils_interval.py @@ -20,6 +20,7 @@ from typing import Any, List, Literal, Sequence, Tuple, Union import numpy as np +from numpy.typing import NDArray __all__ = [ "overlaps", @@ -51,9 +52,9 @@ def overlaps(interval: Interval, another: Interval) -> int: The amount of overlap, in bp between interval and anohter, is returned. - - If > 0, the number of bp of overlap - - If 0, they are book-ended - - If < 0, the distance in bp between them + - If > 0, the number of bp of overlap + - If 0, they are book-ended + - If < 0, the distance in bp between them Parameters ---------- @@ -104,8 +105,8 @@ def validate_interval( tuple 2-tuple consisting of - - bool: indicating whether `interval` is a valid interval - - an interval (can be empty) + - bool: indicating whether `interval` is a valid interval + - an interval (can be empty) Examples -------- @@ -770,9 +771,9 @@ def find_max_cont_len(sublist: Interval, tot_rng: Real) -> dict: dict A dictionary containing the following keys: - - "max_cont_len" - - "max_cont_sublist_start" - - "max_cont_sublist" + - "max_cont_len" + - "max_cont_sublist_start" + - "max_cont_sublist" Examples -------- @@ -857,7 +858,7 @@ def generalized_interval_len(generalized_interval: GeneralizedInterval) -> Real: return gi_len -def find_extrema(signal: Union[np.ndarray, Sequence], mode: Literal["max", "min", "both"] = "both") -> np.ndarray: +def find_extrema(signal: Union[NDArray, Sequence], mode: Literal["max", "min", "both"] = "both") -> NDArray: """Locate local extrema points in a 1D signal. This function is based on Fermat's Theorem. diff --git a/torch_ecg/utils/utils_metrics.py b/torch_ecg/utils/utils_metrics.py index 7be1ff78..e7bed79c 100644 --- a/torch_ecg/utils/utils_metrics.py +++ b/torch_ecg/utils/utils_metrics.py @@ -12,6 +12,7 @@ import einops import numpy as np import torch +from numpy.typing import NDArray from torch import Tensor from ..cfg import DEFAULTS @@ -29,8 +30,8 @@ def top_n_accuracy( - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], n: Union[int, Sequence[int]] = 1, ) -> Union[float, Dict[str, float]]: """Compute top n accuracy. @@ -85,10 +86,10 @@ def top_n_accuracy( def confusion_matrix( - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, -) -> np.ndarray: +) -> NDArray: """Compute a binary confusion matrix The columns are ground truth labels and rows are predicted labels. 
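# As a reading aid for the orientation stated in the `confusion_matrix` docstring just
# above (rows index predicted labels, columns index ground-truth labels), here is a
# hedged toy sketch; `_toy_confusion_matrix` is illustrative only and is not part of
# torch_ecg or of this changeset.
import numpy as np
from numpy.typing import NDArray


def _toy_confusion_matrix(labels: NDArray, outputs: NDArray, num_classes: int) -> NDArray:
    cm = np.zeros((num_classes, num_classes), dtype=int)
    for truth, pred in zip(labels, outputs):
        cm[int(pred), int(truth)] += 1  # row = prediction, column = ground truth
    return cm


# three samples, two classes; the sample with truth 1 predicted as 0 lands at cm[0, 1]
print(_toy_confusion_matrix(np.array([0, 1, 1]), np.array([0, 0, 1]), num_classes=2))
# [[1 1]
#  [0 1]]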
@@ -129,10 +130,10 @@ def confusion_matrix( def one_vs_rest_confusion_matrix( - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, -) -> np.ndarray: +) -> NDArray: """Compute binary one-vs-rest confusion matrices. Columns are ground truth labels and rows are predicted labels. @@ -216,13 +217,13 @@ def one_vs_rest_confusion_matrix( "prepend", ) def metrics_from_confusion_matrix( - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, - weights: Optional[Union[np.ndarray, Tensor]] = None, + weights: Optional[Union[NDArray, Tensor]] = None, thr: float = 0.5, fillna: Union[bool, float] = 0.0, -) -> Dict[str, Union[float, np.ndarray]]: +) -> Dict[str, Union[float, NDArray]]: """ Returns ------- @@ -434,13 +435,13 @@ def metrics_from_confusion_matrix( "prepend", ) def f_measure( - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, - weights: Optional[Union[np.ndarray, Tensor]] = None, + weights: Optional[Union[NDArray, Tensor]] = None, thr: float = 0.5, fillna: Union[bool, float] = 0.0, -) -> Tuple[float, np.ndarray]: +) -> Tuple[float, NDArray]: """ Returns ------- @@ -460,13 +461,13 @@ def f_measure( "prepend", ) def sensitivity( - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, - weights: Optional[Union[np.ndarray, Tensor]] = None, + weights: Optional[Union[NDArray, Tensor]] = None, thr: float = 0.5, fillna: Union[bool, float] = 0.0, -) -> Tuple[float, np.ndarray]: +) -> Tuple[float, NDArray]: """ Returns ------- @@ -492,13 +493,13 @@ def sensitivity( "prepend", ) def precision( - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, - weights: Optional[Union[np.ndarray, Tensor]] = None, + weights: Optional[Union[NDArray, Tensor]] = None, thr: float = 0.5, fillna: Union[bool, float] = 0.0, -) -> Tuple[float, np.ndarray]: +) -> Tuple[float, NDArray]: """ Returns ------- @@ -522,13 +523,13 @@ def precision( "prepend", ) def specificity( - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, - weights: Optional[Union[np.ndarray, Tensor]] = None, + weights: Optional[Union[NDArray, Tensor]] = None, thr: float = 0.5, fillna: Union[bool, float] = 0.0, -) -> Tuple[float, np.ndarray]: +) -> Tuple[float, NDArray]: """ Returns ------- @@ -553,13 +554,13 @@ def specificity( "prepend", ) def auc( - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, - weights: Optional[Union[np.ndarray, Tensor]] = None, + weights: Optional[Union[NDArray, Tensor]] = None, thr: float = 0.5, fillna: Union[bool, float] = 0.0, -) -> Tuple[float, float, np.ndarray, np.ndarray]: +) -> Tuple[float, float, NDArray, NDArray]: """ Returns ------- @@ -585,10 +586,10 @@ def auc( "prepend", ) def accuracy( - labels: Union[np.ndarray, Tensor], - outputs: Union[np.ndarray, 
Tensor], + labels: Union[NDArray, Tensor], + outputs: Union[NDArray, Tensor], num_classes: Optional[int] = None, - weights: Optional[Union[np.ndarray, Tensor]] = None, + weights: Optional[Union[NDArray, Tensor]] = None, thr: float = 0.5, fillna: Union[bool, float] = 0.0, ) -> float: @@ -607,8 +608,8 @@ def accuracy( def QRS_score( - rpeaks_truths: Sequence[Union[np.ndarray, Sequence[int]]], - rpeaks_preds: Sequence[Union[np.ndarray, Sequence[int]]], + rpeaks_truths: Sequence[Union[NDArray, Sequence[int]]], + rpeaks_preds: Sequence[Union[NDArray, Sequence[int]]], fs: Real, thr: float = 0.075, ) -> float: @@ -679,10 +680,10 @@ def QRS_score( def one_hot_pair( - labels: Union[np.ndarray, Tensor, Sequence[Sequence[int]]], - outputs: Union[np.ndarray, Tensor, Sequence[Sequence[int]]], + labels: Union[NDArray, Tensor, Sequence[Sequence[int]]], + outputs: Union[NDArray, Tensor, Sequence[Sequence[int]]], num_classes: Optional[int] = None, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[NDArray, NDArray]: """Convert categorical (of shape ``(n_samples,)``) labels and outputs to binary (of shape ``(n_samples, n_classes)``) labels and outputs if applicable. @@ -729,7 +730,7 @@ def one_hot_pair( return labels, outputs -def _one_hot_pair(cls_array: Union[np.ndarray, Sequence[Sequence[int]]], shape: Tuple[int]) -> np.ndarray: +def _one_hot_pair(cls_array: Union[NDArray, Sequence[Sequence[int]]], shape: Tuple[int]) -> NDArray: """Convert categorical array to binary array. Parameters @@ -757,18 +758,18 @@ def _one_hot_pair(cls_array: Union[np.ndarray, Sequence[Sequence[int]]], shape: @add_docstring(one_hot_pair.__doc__) def cls_to_bin( - labels: Union[np.ndarray, Tensor, Sequence[Sequence[int]]], - outputs: Union[np.ndarray, Tensor, Sequence[Sequence[int]]], + labels: Union[NDArray, Tensor, Sequence[Sequence[int]]], + outputs: Union[NDArray, Tensor, Sequence[Sequence[int]]], num_classes: Optional[int] = None, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[NDArray, NDArray]: """Alias of `one_hot_pair`.""" warnings.warn("`cls_to_bin` is deprecated, use `one_hot_pair` instead", DeprecationWarning) return one_hot_pair(labels, outputs, num_classes) def compute_wave_delineation_metrics( - truth_masks: Sequence[np.ndarray], - pred_masks: Sequence[np.ndarray], + truth_masks: Sequence[NDArray], + pred_masks: Sequence[NDArray], class_map: Dict[str, int], fs: Real, mask_format: str = "channel_first", diff --git a/torch_ecg/utils/utils_nn.py b/torch_ecg/utils/utils_nn.py index 22987ded..8b9c4804 100644 --- a/torch_ecg/utils/utils_nn.py +++ b/torch_ecg/utils/utils_nn.py @@ -3,6 +3,7 @@ """ +import json import os import pickle import re @@ -17,6 +18,9 @@ import numpy as np import torch from easydict import EasyDict +from numpy.typing import NDArray +from safetensors import safe_open +from safetensors.torch import load_file, save_file from torch import Tensor, nn from ..cfg import CFG, DEFAULTS, DTYPE @@ -52,17 +56,23 @@ def _get_np_dtypes(): _safe_globals = [ CFG, DTYPE, EasyDict, Path, PosixPath, WindowsPath, - np.core.multiarray._reconstruct, + # np.core.multiarray._reconstruct, np.ndarray, np.dtype, np.float32, np.float64, np.int32, np.int64, np.uint8, np.int8, ] + _get_np_dtypes() + if hasattr(np, "core"): + _safe_globals.append(np.core.multiarray._reconstruct) # type: ignore + else: + _safe_globals.append(np._core.multiarray._reconstruct) # type: ignore # fmt: on if hasattr(torch.serialization, "add_safe_globals"): torch.serialization.add_safe_globals(_safe_globals) -def extend_predictions(preds: Sequence, 
classes: List[str], extended_classes: List[str]) -> np.ndarray: +def extend_predictions( + preds: Union[Sequence, NDArray, torch.Tensor], classes: List[str], extended_classes: List[str] +) -> NDArray: """Extend the prediction arrays to prediction arrays in larger range of classes Parameters @@ -140,8 +150,8 @@ def compute_output_shape( output_padding: Union[Sequence[int], int] = 0, dilation: Union[Sequence[int], int] = 1, channel_last: bool = False, - asymmetric_padding: Union[Sequence[int], Sequence[Sequence[int]]] = None, -) -> Tuple[Union[int, None]]: + asymmetric_padding: Optional[Union[Sequence[int], Sequence[Sequence[int]]]] = None, +) -> Tuple[Union[int, None], ...]: """Compute the output shape of a (transpose) convolution/maxpool/avgpool layer. This function is based on the discussion [#disc]_. @@ -287,17 +297,17 @@ def check_output_validity(shape): none_dim_msg = "spatial dimensions should be all `None`, or all not `None`" if channel_last: if all([n is None for n in input_shape[1:-1]]): - if out_channels is None: + if out_channels is None: # type: ignore raise ValueError("out channel dimension and spatial dimensions are all `None`") - output_shape = tuple(list(input_shape[:-1]) + [out_channels]) + output_shape = tuple(list(input_shape[:-1]) + [out_channels]) # type: ignore return check_output_validity(output_shape) elif any([n is None for n in input_shape[1:-1]]): raise ValueError(none_dim_msg) else: if all([n is None for n in input_shape[2:]]): - if out_channels is None: + if out_channels is None: # type: ignore raise ValueError("out channel dimension and spatial dimensions are all `None`") - output_shape = tuple([input_shape[0], out_channels] + list(input_shape[2:])) + output_shape = tuple([input_shape[0], out_channels] + list(input_shape[2:])) # type: ignore return check_output_validity(output_shape) elif any([n is None for n in input_shape[2:]]): raise ValueError(none_dim_msg) @@ -349,12 +359,12 @@ def check_output_validity(shape): _asymmetric_padding = list(repeat(asymmetric_padding, dim)) else: assert len(asymmetric_padding) == dim and all( - len(ap) == 2 and all(isinstance(p, int) for p in ap) for ap in asymmetric_padding + len(ap) == 2 and all(isinstance(p, int) for p in ap) for ap in asymmetric_padding # type: ignore ), "Invalid `asymmetric_padding`" _asymmetric_padding = asymmetric_padding for idx in range(dim): - _padding[idx][0] += _asymmetric_padding[idx][0] - _padding[idx][1] += _asymmetric_padding[idx][1] + _padding[idx][0] += _asymmetric_padding[idx][0] # type: ignore + _padding[idx][1] += _asymmetric_padding[idx][1] # type: ignore if isinstance(output_padding, int): _output_padding = list(repeat(output_padding, dim)) @@ -400,13 +410,13 @@ def check_output_validity(shape): ] else: output_shape = [ - floor(((i + sum(p) - minus_term(d, k)) / s) + 1) + floor(((i + sum(p) - minus_term(d, k)) / s) + 1) # type: ignore for i, p, d, k, s in zip(_input_shape, _padding, _dilation, _kernel_size, _stride) ] if channel_last: - output_shape = tuple([input_shape[0]] + output_shape + [out_channels]) + output_shape = tuple([input_shape[0]] + output_shape + [out_channels]) # type: ignore else: - output_shape = tuple([input_shape[0], out_channels] + output_shape) + output_shape = tuple([input_shape[0], out_channels] + output_shape) # type: ignore return check_output_validity(output_shape) @@ -419,8 +429,8 @@ def compute_conv_output_shape( padding: Union[Sequence[int], int] = 0, dilation: Union[Sequence[int], int] = 1, channel_last: bool = False, - asymmetric_padding: 
Union[Sequence[int], Sequence[Sequence[int]]] = None, -) -> Tuple[Union[int, None]]: + asymmetric_padding: Optional[Union[Sequence[int], Sequence[Sequence[int]]]] = None, +) -> Tuple[Union[int, None], ...]: """Compute the output shape of a convolution layer. Parameters @@ -477,7 +487,7 @@ def compute_maxpool_output_shape( padding: Union[Sequence[int], int] = 0, dilation: Union[Sequence[int], int] = 1, channel_last: bool = False, -) -> Tuple[Union[int, None]]: +) -> Tuple[Union[int, None], ...]: """Compute the output shape of a maxpool layer. Parameters @@ -527,7 +537,7 @@ def compute_avgpool_output_shape( stride: Union[Sequence[int], int] = 1, padding: Union[Sequence[int], int] = 0, channel_last: bool = False, -) -> Tuple[Union[int, None]]: +) -> Tuple[Union[int, None], ...]: """Compute the output shape of a avgpool layer. Parameters @@ -578,8 +588,8 @@ def compute_deconv_output_shape( output_padding: Union[Sequence[int], int] = 0, dilation: Union[Sequence[int], int] = 1, channel_last: bool = False, - asymmetric_padding: Union[Sequence[int], Sequence[Sequence[int]]] = None, -) -> Tuple[Union[int, None]]: + asymmetric_padding: Optional[Union[Sequence[int], Sequence[Sequence[int]]]] = None, +) -> Tuple[Union[int, None], ...]: """Compute the output shape of a transpose convolution layer Parameters @@ -657,10 +667,12 @@ def compute_sequential_output_shape( """Compute the output shape of a sequential model.""" assert issubclass(type(model), nn.Sequential), f"model should be nn.Sequential, but got {type(model)}" _seq_len = seq_len + if len(model) == 0: + raise AssertionError("model has no modules") for module in model: output_shape = module.compute_output_shape(_seq_len, batch_size) _, _, _seq_len = output_shape - return output_shape + return output_shape # type: ignore def compute_module_size( @@ -760,7 +772,7 @@ def compute_receptive_field( strides: Union[Sequence[int], int] = 1, dilations: Union[Sequence[int], int] = 1, input_len: Optional[int] = None, - fs: Optional[Real] = None, + fs: Optional[Union[int, float]] = None, ) -> Union[int, float]: """Compute the receptive field of several types of :class:`~torch.nn.Module`. @@ -855,11 +867,11 @@ def compute_receptive_field( receptive_field = min(receptive_field, input_len) if fs is not None: receptive_field /= fs - return make_serializable(receptive_field) + return make_serializable(receptive_field) # type: ignore def default_collate_fn( - batch: Sequence[Union[Tuple[np.ndarray, ...], Dict[str, np.ndarray]]], + batch: Sequence[Union[Tuple[NDArray, ...], Dict[str, NDArray]]], ) -> Union[Tuple[Tensor, ...], Dict[str, Tensor]]: """Default collate functions for model training. @@ -883,13 +895,13 @@ def default_collate_fn( """ if isinstance(batch[0], dict): keys = batch[0].keys() - collated = _default_collate_fn([tuple(b[k] for k in keys) for b in batch]) + collated = _default_collate_fn([tuple(b[k] for k in keys) for b in batch]) # type: ignore return {k: collated[i] for i, k in enumerate(keys)} else: - return _default_collate_fn(batch) + return _default_collate_fn(batch) # type: ignore -def _default_collate_fn(batch: Sequence[Tuple[np.ndarray, ...]]) -> Tuple[Tensor, ...]: +def _default_collate_fn(batch: Sequence[Tuple[NDArray, ...]]) -> Tuple[Tensor, ...]: """Collate functions for tuples of tensors. 
The data generator (:class:`~torch.utils.data.Dataset`) should @@ -970,7 +982,7 @@ def _adjust_cnn_filter_lengths( ] elif isinstance(v, Real): # DO NOT use `int`, which might not work for numpy array elements - if v > 1: + if v > 1: # type: ignore config[k] = int(round(v * fs / config["fs"])) if ensure_odd: config[k] = config[k] - config[k] % 2 + 1 @@ -1018,27 +1030,27 @@ class SizeMixin(object): @property def module_size(self) -> int: """Size of trainable parameters in the model in terms of number of parameters.""" - return compute_module_size(self) + return compute_module_size(self) # type: ignore @property def module_size_(self) -> str: """Size of trainable parameters in the model in terms of memory capacity.""" - return compute_module_size(self, human=True) + return compute_module_size(self, human=True) # type: ignore @property def sizeof(self) -> int: """Size of the model in terms of number of parameters, including non-trainable parameters and buffers.""" - return compute_module_size(self, requires_grad=False, include_buffers=True, human=False) + return compute_module_size(self, requires_grad=False, include_buffers=True, human=False) # type: ignore @property def sizeof_(self) -> str: """Size of the model in terms of memory capacity, including non-trainable parameters and buffers.""" - return compute_module_size(self, requires_grad=False, include_buffers=True, human=True) + return compute_module_size(self, requires_grad=False, include_buffers=True, human=True) # type: ignore @property def dtype(self) -> torch.dtype: try: - return next(self.parameters()).dtype + return next(self.parameters()).dtype # type: ignore except StopIteration: return torch.float32 except Exception as err: @@ -1047,7 +1059,7 @@ def dtype(self) -> torch.dtype: @property def device(self) -> torch.device: try: - return next(self.parameters()).device + return next(self.parameters()).device # type: ignore except StopIteration: return torch.device("cpu") except Exception as err: @@ -1103,7 +1115,16 @@ def make_safe_globals(obj: CFG, remove_paths: bool = True) -> CFG: sg = None elif isinstance(sg, (str, bytes)) and os.path.exists(sg): sg = None - return sg + return sg # type: ignore + + +_SFT_META_MODEL_CFG = "model_config" +_SFT_META_TRAIN_CFG = "train_config" +_SFT_META_FORMAT = "__format__" +_SFT_META_VERSION = "torch_ecg.ckpt.v1" +_SFT_META_EXTRA_JSON_PREFIX = "__extra_json__/" # JSON-serialized extras +_SFT_EXTRA_TENSOR_PREFIX = "__extra__/" # Tensor extras grouped under this prefix +_SFT_META_EXTRA_TENSOR_GROUPS = "__extra_tensor_groups__" # JSON list of groups class CkptMixin(object): @@ -1122,8 +1143,9 @@ def from_checkpoint( ---------- path : `path-like` Path to the checkpoint. - If it is a directory, then this directory should contain only one checkpoint file - (with the extension `.pth` or `.pt`). + If it is a directory, then this directory should be one of the following cases: + - contain a `model.safetensors` file, a `model_config.json` file, and a `train_config.json` file. + - contain only one checkpoint file (with the extension `.pth` or `.pt`). device : torch.device, optional Map location of the model parameters, defaults to "cuda" if available, otherwise "cpu". @@ -1140,16 +1162,77 @@ def from_checkpoint( Auxiliary configs that are needed for data preprocessing, etc. 
""" - if Path(path).is_dir(): - candidates = list(Path(path).glob("*.pth")) + list(Path(path).glob("*.pt")) - assert len(candidates) == 1, "The directory should contain only one checkpoint file" - path = candidates[0] - _device = device or DEFAULTS.device + _device = device or DEFAULTS.device # type: ignore if weights_only == "auto": if hasattr(torch.serialization, "add_safe_globals"): weights_only = True else: weights_only = False + + if isinstance(path, bytes): + path = path.decode() + path = Path(path).expanduser().resolve() # type: ignore + + if path.is_dir(): + candidates = list(path.glob("*.pth")) + list(path.glob("*.pt")) + assert len(candidates) in [0, 1], "The directory should contain only one checkpoint file" + if len(candidates) == 0: + # the directory should contain a `model.safetensors` file, a `model_config.json` file, + # and a `train_config.json` file + model_path = path / "model.safetensors" + train_config_path = path / "train_config.json" + model_config_path = path / "model_config.json" + assert model_path.exists(), "model.safetensors file not found" + assert train_config_path.exists(), "train_config.json file not found" + assert model_config_path.exists(), "model_config.json file not found" + train_config = json.loads(train_config_path.read_text()) + model_config = json.loads(model_config_path.read_text()) + aux_config = train_config + kwargs = dict(config=model_config) + if "classes" in aux_config: + kwargs["classes"] = aux_config["classes"] + if "n_leads" in aux_config: + kwargs["n_leads"] = aux_config["n_leads"] + model = cls(**kwargs) + model.load_state_dict(load_file(model_path, device="cpu")) # type: ignore + model.to(_device) # type: ignore + return model, aux_config # type: ignore + + # we have only one checkpoint file in the directory + # and we will use `torch.load` to load the model + path = candidates[0] + + if path.suffix == ".safetensors": # load safetensors format file + try: + with safe_open(str(path), framework="pt", device="cpu") as f: # type: ignore + meta = f.metadata() or {} + if _SFT_META_MODEL_CFG in meta and _SFT_META_TRAIN_CFG in meta: + model_config = json.loads(meta[_SFT_META_MODEL_CFG]) + train_config = json.loads(meta[_SFT_META_TRAIN_CFG]) + aux_config = train_config + kwargs = dict(config=model_config) + if "classes" in aux_config: + kwargs["classes"] = aux_config["classes"] + if "n_leads" in aux_config: + kwargs["n_leads"] = aux_config["n_leads"] + model = cls(**kwargs) + + # load only model weights, and ignores extra tensors + state_dict = {} + for key in f.keys(): + if not key.startswith(_SFT_EXTRA_TENSOR_PREFIX): + state_dict[key] = f.get_tensor(key) + model.load_state_dict(state_dict, strict=False) # type: ignore + model.to(_device) # type: ignore + return model, aux_config # type: ignore + except Exception as e: + # failed to load the safetensors file + raise RuntimeError( + "Failed to load the safetensors file. " + "The file may be corrupted, or the version of safetensors is incompatible. " + "Try updating safetensors to the latest version, or check the file integrity." 
+ ) from e + try: ckpt = torch.load(path, map_location=_device, weights_only=weights_only) except pickle.UnpicklingError as pue: @@ -1160,16 +1243,14 @@ def from_checkpoint( ) from pue aux_config = ckpt.get("train_config", None) or ckpt.get("config", None) assert aux_config is not None, "input checkpoint has no sufficient data to recover a model" - kwargs = dict( - config=ckpt["model_config"], - ) + kwargs = dict(config=ckpt["model_config"]) if "classes" in aux_config: kwargs["classes"] = aux_config["classes"] if "n_leads" in aux_config: kwargs["n_leads"] = aux_config["n_leads"] model = cls(**kwargs) - model.load_state_dict(ckpt["model_state_dict"]) - return model, aux_config + model.load_state_dict(ckpt["model_state_dict"]) # type: ignore + return model, aux_config # type: ignore @classmethod def from_remote( @@ -1209,9 +1290,22 @@ def from_remote( model_path_or_dir = http_get(url, model_dir, extract="auto", filename=filename) return cls.from_checkpoint(model_path_or_dir, device=device, weights_only=weights_only) - def save(self, path: Union[str, bytes, os.PathLike], train_config: CFG) -> None: + def save( + self, + path: Union[str, bytes, os.PathLike], + train_config: CFG, + extra_items: Optional[dict] = None, + use_safetensors: bool = True, + safetensors_single_file: bool = True, + ) -> None: """Save the model to disk. + .. note:: + + `safetensors` is used by default to save the model. + If one wants to save the models in `.pth` or `.pt` format, + he/she must explicitly set ``use_safetensors=False``. + Parameters ---------- path : `path-like` @@ -1219,22 +1313,98 @@ def save(self, path: Union[str, bytes, os.PathLike], train_config: CFG) -> None: train_config : CFG Config for training the model, used when one restores the model. + extra_items : dict, optional + Extra items to save along with the model. + The values should be serializable: can be saved as a json file, + or is a dict of torch tensors. + + .. versionadded:: 0.0.32 + use_safetensors : bool, default True + Whether to use `safetensors` to save the model. + This will be overridden by the suffix of `path`: + if it is `.safetensors`, then `use_safetensors` is set to True; + if it is `.pth` or `.pt`, then if `use_safetensors` is True, + the suffix is changed to `.safetensors`, otherwise it is unchanged. + + .. versionadded:: 0.0.32 + safetensors_single_file : bool, default True + Whether to save the metadata along with the state dict into one file. + + .. versionadded:: 0.0.32 Returns ------- None """ - path = Path(path) + if isinstance(path, bytes): + path = path.decode() + path = Path(path).expanduser().resolve() # type: ignore if not path.parent.exists(): path.parent.mkdir(parents=True) - _model_config = make_safe_globals(self.config) + extra_items = extra_items or {} + + _model_config = make_safe_globals(self.config) # type: ignore _train_config = make_safe_globals(train_config) + + if path.suffix in [".pth", ".pt"]: + if use_safetensors: + path = path.with_suffix(".safetensors") + warnings.warn( + f"`safetensors` is used by default. 
The saved file name is changed to {path.name}", RuntimeWarning
+                )
+        elif path.suffix == ".safetensors":
+            use_safetensors = True
+
+        if use_safetensors and safetensors_single_file:
+            tensors = dict(self.state_dict())  # type: ignore
+
+            tensor_groups = []
+            for key, val in extra_items.items():
+                if isinstance(val, dict) and all(isinstance(v, torch.Tensor) for v in val.values()):
+                    tensor_groups.append(key)
+                    for tname, ten in val.items():
+                        tensors[f"{_SFT_EXTRA_TENSOR_PREFIX}{key}/{tname}"] = ten
+
+            meta = {
+                _SFT_META_FORMAT: _SFT_META_VERSION,
+                _SFT_META_MODEL_CFG: json.dumps(make_serializable(_model_config), ensure_ascii=False),
+                _SFT_META_TRAIN_CFG: json.dumps(make_serializable(_train_config), ensure_ascii=False),
+                _SFT_META_EXTRA_TENSOR_GROUPS: json.dumps(sorted(tensor_groups)),
+            }
+            for key, val in extra_items.items():
+                if isinstance(val, dict) and all(isinstance(v, torch.Tensor) for v in val.values()):
+                    continue
+                meta[f"{_SFT_META_EXTRA_JSON_PREFIX}{key}"] = json.dumps(make_serializable(val), ensure_ascii=False)
+
+            save_file(tensors, path.with_suffix(".safetensors"), metadata=meta)
+            return
+
+        if use_safetensors:  # not single file
+            # save the model with safetensors into a directory with the same name as `path` (suffix stripped)
+            # `model_config` and `train_config` are saved as json files
+            path = path.with_suffix("")
+            path.mkdir(exist_ok=True)
+            _model_config = make_serializable(_model_config)
+            _train_config = make_serializable(_train_config)
+            (path / "model_config.json").write_text(json.dumps(_model_config, ensure_ascii=False))
+            (path / "train_config.json").write_text(json.dumps(_train_config, ensure_ascii=False))
+            save_file(self.state_dict(), path / "model.safetensors")  # type: ignore
+            # save extra items
+            for key, val in extra_items.items():
+                # if val is a dict of torch tensors, save them as safetensors
+                if isinstance(val, dict) and all(isinstance(v, torch.Tensor) for v in val.values()):
+                    save_file(val, path / f"{key}.safetensors")
+                else:
+                    (path / f"{key}.json").write_text(json.dumps(make_serializable(val), ensure_ascii=False))
+            return
+
         torch.save(
             {
-                "model_state_dict": self.state_dict(),
+                "model_state_dict": self.state_dict(),  # type: ignore
                 "model_config": _model_config,
                 "train_config": _train_config,
+                **extra_items,
             },
             path,
         )
diff --git a/torch_ecg/utils/utils_signal.py b/torch_ecg/utils/utils_signal.py
index e3c41d07..305e5588 100644
--- a/torch_ecg/utils/utils_signal.py
+++ b/torch_ecg/utils/utils_signal.py
@@ -5,10 +5,10 @@

 import warnings
 from copy import deepcopy
-from numbers import Real
 from typing import Iterable, Literal, Optional, Sequence, Tuple, Union

 import numpy as np
+from numpy.typing import NDArray
 from scipy import interpolate
 from scipy.signal import butter, filtfilt, peak_prominences

@@ -26,12 +26,12 @@

 def smooth(
-    x: np.ndarray,
+    x: NDArray,
     window_len: int = 11,
     window: Literal["flat", "hanning", "hamming", "bartlett", "blackman"] = "hanning",
     mode: str = "valid",
     keep_dtype: bool = True,
-) -> np.ndarray:
+) -> NDArray:
     """Smooth the 1d data using a window with requested size.

     This method is originally from [#smooth]_,
@@ -119,7 +119,7 @@ def smooth(
     else:
         w = eval("np."
+ window + "(radius)") - y = np.convolve(w / w.sum(), s, mode=mode) + y = np.convolve(w / w.sum(), s, mode=mode) # type: ignore y = y[(radius // 2 - 1) : -(radius // 2) - 1] assert len(x) == len(y) @@ -130,14 +130,14 @@ def smooth( def resample_irregular_timeseries( - sig: np.ndarray, - output_fs: Optional[Real] = None, + sig: NDArray, + output_fs: Optional[Union[float, int]] = None, method: Literal["spline", "interp1d"] = "interp1d", return_with_time: bool = False, - tnew: Optional[np.ndarray] = None, + tnew: Optional[NDArray] = None, interp_kw: dict = {}, verbose: int = 0, -) -> np.ndarray: +) -> NDArray: """ Resample the 2d irregular timeseries `sig` into a 1d or 2d regular time series with frequency `output_fs`, @@ -149,7 +149,7 @@ def resample_irregular_timeseries( sig : numpy.ndarray The 2d irregular timeseries. Each row is ``[time, value]``. - output_fs : numbers.Real, optional + output_fs : float or int, optional the frequency of the output 1d regular timeseries, one and only one of `output_fs` and `tnew` should be specified method : {"spline", "interp1d"}, default "interp1d" @@ -203,7 +203,7 @@ def resample_irregular_timeseries( dtype = sig.dtype time_series = np.atleast_2d(sig).astype(dtype) if tnew is None: - step_ts = 1000 / output_fs + step_ts = 1000 / output_fs # type: ignore tot_len = int((time_series[-1][0] - time_series[0][0]) / step_ts) + 1 xnew = time_series[0][0] + np.arange(0, tot_len * step_ts, step_ts) else: @@ -234,19 +234,19 @@ def resample_irregular_timeseries( regular_timeseries = f(xnew) if return_with_time: - return np.column_stack((xnew, regular_timeseries)).astype(dtype) + return np.column_stack((xnew, regular_timeseries)).astype(dtype) # type: ignore else: - return regular_timeseries.astype(dtype) + return regular_timeseries.astype(dtype) # type: ignore def detect_peaks( x: Sequence, - mph: Optional[Real] = None, + mph: Optional[Union[float, int]] = None, mpd: int = 1, - threshold: Real = 0, - left_threshold: Real = 0, - right_threshold: Real = 0, - prominence: Optional[Real] = None, + threshold: Union[float, int] = 0, + left_threshold: Union[float, int] = 0, + right_threshold: Union[float, int] = 0, + prominence: Optional[Union[float, int]] = None, prominence_wlen: Optional[int] = None, edge: Union[str, None] = "rising", kpsh: bool = False, @@ -254,7 +254,7 @@ def detect_peaks( show: bool = False, ax=None, verbose: int = 0, -) -> np.ndarray: +) -> NDArray: """Detect peaks in data based on their amplitude and other features. Parameters @@ -384,7 +384,7 @@ def detect_peaks( # handle NaN's if ind.size and indnan.size: # NaN's and values close to NaN's cannot be peaks - ind = ind[np.in1d(ind, np.unique(np.hstack((indnan, indnan - 1, indnan + 1))), invert=True)] + ind = ind[np.isin(ind, np.unique(np.hstack((indnan, indnan - 1, indnan + 1))), invert=True)] if verbose >= 1: print(f"after handling nan values, ind = {ind.tolist()}") @@ -455,7 +455,7 @@ def detect_peaks( return ind -def remove_spikes_naive(sig: np.ndarray, threshold: Real = 20, inplace: bool = True) -> np.ndarray: +def remove_spikes_naive(sig: NDArray, threshold: Union[float, int] = 20, inplace: bool = True) -> NDArray: """Remove signal spikes using a naive method. This is a method proposed in entry 0416 of CPSC2019. @@ -470,7 +470,7 @@ def remove_spikes_naive(sig: np.ndarray, threshold: Real = 20, inplace: bool = T 1D, 2D or 3D signal with potential spikes. The last dimension is the time dimension. The signal can be single-lead, multi-lead, or batched signals. 
- threshold : numbers.Real, optional + threshold : float or int, optional Values of `sig` that are larger than `threshold` will be removed. inplace : bool, optional Whether to modify `sig` in place or not. @@ -512,16 +512,18 @@ def remove_spikes_naive(sig: np.ndarray, threshold: Real = 20, inplace: bool = T return sig.astype(dtype) -def butter_bandpass(lowcut: Real, highcut: Real, fs: Real, order: int, verbose: int = 0) -> Tuple[np.ndarray, np.ndarray]: +def butter_bandpass( + lowcut: Union[float, int], highcut: Union[float, int], fs: Union[float, int], order: int, verbose: int = 0 +) -> Tuple[NDArray, NDArray]: """Butterworth Bandpass Filter Design. Parameters ---------- - lowcut : numbers.Real + lowcut : float or int Low cutoff frequency. - highcut : numbers.Real + highcut : float or int High cutoff frequency. - fs : numbers.Real + fs : float or int Sampling frequency of `data`. order : int, Order of the filter. @@ -569,19 +571,19 @@ def butter_bandpass(lowcut: Real, highcut: Real, fs: Real, order: int, verbose: if verbose >= 1: print(f"by the setup of lowcut and highcut, the filter type falls to {btype}, with Wn = {Wn}") - b, a = butter(order, Wn, btype=btype) + b, a = butter(order, Wn, btype=btype) # type: ignore return b, a def butter_bandpass_filter( - data: np.ndarray, - lowcut: Real, - highcut: Real, - fs: Real, + data: NDArray, + lowcut: Union[float, int], + highcut: Union[float, int], + fs: Union[float, int], order: int, btype: Optional[Literal["lohi", "hilo"]] = None, verbose: int = 0, -) -> np.ndarray: +) -> NDArray: """Butterworth bandpass filtering the signals. Apply a Butterworth bandpass filter to the signal. @@ -639,19 +641,19 @@ def butter_bandpass_filter( def get_ampl( - sig: np.ndarray, - fs: Real, + sig: NDArray, + fs: Union[float, int], fmt: str = "lead_first", - window: Real = 0.2, + window: Union[float, int] = 0.2, critical_points: Optional[Sequence] = None, -) -> Union[float, np.ndarray]: +) -> Union[float, NDArray]: """Get amplitude of a signal (near critical points if given). Parameters ---------- sig : numpy.ndarray (ECG) signal. - fs : numbers.Real + fs : float or int Sampling frequency of the signal fmt : str, default "lead_first" Format of the signal, can be @@ -713,13 +715,13 @@ def get_ampl( def normalize( - sig: np.ndarray, + sig: NDArray, method: Literal["naive", "min-max", "z-score"], - mean: Union[Real, Iterable[Real]] = 0.0, - std: Union[Real, Iterable[Real]] = 1.0, + mean: Union[Union[float, int], Iterable[Union[float, int]]] = 0.0, + std: Union[Union[float, int], Iterable[Union[float, int]]] = 1.0, sig_fmt: str = "channel_first", per_channel: bool = False, -) -> np.ndarray: +) -> NDArray: """Normalize a signal. Perform z-score normalization on `sig`, @@ -742,11 +744,11 @@ def normalize( The signal to be normalized. method : {"naive", "min-max", "z-score"} Normalization method, case insensitive. - mean : numbers.Real or array_like, default 0.0 + mean : float or int or array_like, default 0.0 Mean value of the normalized signal, or mean values for each lead of the normalized signal. Useless if `method` is "min-max". - std : numbers.Real or array_like, default 1.0 + std : float or int or array_like, default 1.0 Standard deviation of the normalized signal, or standard deviations for each lead of the normalized signal. Useless if `method` is "min-max". 
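# A minimal usage sketch of `normalize` based on the signature and parameter notes in
# this hunk: z-score a 12-lead, lead-first signal to zero mean and unit variance.
# Per-lead `mean`/`std` arrays are also accepted (with `per_channel=True`), as the
# shape checks in the next hunk indicate. The random input is purely illustrative.
import numpy as np

from torch_ecg.utils.utils_signal import normalize

sig = np.random.default_rng(0).normal(size=(12, 5000))  # (leads, samples), i.e. "channel_first"
nm_sig = normalize(sig, method="z-score", mean=0.0, std=1.0, sig_fmt="channel_first")
assert nm_sig.shape == sig.shape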
@@ -786,17 +788,17 @@ def normalize( ], f"unknown normalization method `{method}`" if not per_channel: if sig.ndim == 2: - assert isinstance(mean, Real) and isinstance( - std, Real + assert isinstance(mean, (float, int)) and isinstance( + std, (float, int) ), "`mean` and `std` should be real numbers in the non per-channel setting for 2d signal" else: # sig.ndim == 3 - assert (isinstance(mean, Real) or np.shape(mean) == (sig.shape[0],)) and ( - isinstance(std, Real) or np.shape(std) == (sig.shape[0],) + assert (isinstance(mean, (float, int)) or np.shape(mean) == (sig.shape[0],)) and ( # type: ignore + isinstance(std, (float, int)) or np.shape(std) == (sig.shape[0],) # type: ignore ), ( f"`mean` and `std` should be real numbers or have shape ({sig.shape[0]},) " "in the non per-channel setting for 3d signal" ) - if isinstance(std, Real): + if isinstance(std, (float, int)): assert std > 0, "standard deviation should be positive" else: assert (np.array(std) > 0).all(), "standard deviations should all be positive" @@ -810,10 +812,10 @@ def normalize( if isinstance(mean, Iterable): assert sig.ndim in [2, 3], "`mean` should be a real number for 1d signal" if sig.ndim == 2: - assert np.shape(mean) in [ + assert np.shape(mean) in [ # type: ignore (sig.shape[0],), (sig.shape[-1],), - ], f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}" + ], f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}" # type: ignore if sig_fmt.lower() in [ "channel_first", "lead_first", @@ -826,40 +828,40 @@ def normalize( "channel_first", "lead_first", ]: - if np.shape(mean) == (sig.shape[0],): + if np.shape(mean) == (sig.shape[0],): # type: ignore _mean = np.array(mean, dtype=dtype)[..., np.newaxis, np.newaxis] - elif np.shape(mean) == (sig.shape[1],): + elif np.shape(mean) == (sig.shape[1],): # type: ignore _mean = np.repeat( np.array(mean, dtype=dtype)[np.newaxis, ..., np.newaxis], sig.shape[0], axis=0, ) - elif np.shape(mean) == sig.shape[:2]: + elif np.shape(mean) == sig.shape[:2]: # type: ignore _mean = np.array(mean, dtype=dtype)[..., np.newaxis] else: - raise AssertionError(f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}") + raise AssertionError(f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}") # type: ignore else: # "channel_last" or "lead_last" - if np.shape(mean) == (sig.shape[0],): + if np.shape(mean) == (sig.shape[0],): # type: ignore _mean = np.array(mean, dtype=dtype)[..., np.newaxis, np.newaxis] - elif np.shape(mean) == (sig.shape[-1],): + elif np.shape(mean) == (sig.shape[-1],): # type: ignore _mean = np.repeat( np.array(mean, dtype=dtype)[np.newaxis, np.newaxis, ...], sig.shape[0], axis=0, ) - elif np.shape(mean) == (sig.shape[0], sig.shape[-1]): + elif np.shape(mean) == (sig.shape[0], sig.shape[-1]): # type: ignore _mean = np.expand_dims(np.array(mean, dtype=dtype), axis=1) else: - raise AssertionError(f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}") + raise AssertionError(f"shape of `mean` = {np.shape(mean)} not compatible with the `sig` = {np.shape(sig)}") # type: ignore else: _mean = mean if isinstance(std, Iterable): assert sig.ndim in [2, 3], "`std` should be a real number for 1d signal" if sig.ndim == 2: - assert np.shape(std) in [ + assert np.shape(std) in [ # type: ignore (sig.shape[0],), (sig.shape[-1],), - ], f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}" + ], f"shape of `std` = 
{np.shape(std)} not compatible with the `sig` = {np.shape(sig)}" # type: ignore if sig_fmt.lower() in [ "channel_first", "lead_first", @@ -872,31 +874,31 @@ def normalize( "channel_first", "lead_first", ]: - if np.shape(std) == (sig.shape[0],): + if np.shape(std) == (sig.shape[0],): # type: ignore _std = np.array(std, dtype=dtype)[..., np.newaxis, np.newaxis] - elif np.shape(std) == (sig.shape[1],): + elif np.shape(std) == (sig.shape[1],): # type: ignore _std = np.repeat( np.array(std, dtype=dtype)[np.newaxis, ..., np.newaxis], sig.shape[0], axis=0, ) - elif np.shape(std) == sig.shape[:2]: + elif np.shape(std) == sig.shape[:2]: # type: ignore _std = np.array(std, dtype=dtype)[..., np.newaxis] else: - raise AssertionError(f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}") + raise AssertionError(f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}") # type: ignore else: # "channel_last" or "lead_last" - if np.shape(std) == (sig.shape[0],): + if np.shape(std) == (sig.shape[0],): # type: ignore _std = np.array(std, dtype=dtype)[..., np.newaxis, np.newaxis] - elif np.shape(std) == (sig.shape[-1],): + elif np.shape(std) == (sig.shape[-1],): # type: ignore _std = np.repeat( np.array(std, dtype=dtype)[np.newaxis, np.newaxis, ...], sig.shape[0], axis=0, ) - elif np.shape(std) == (sig.shape[0], sig.shape[-1]): + elif np.shape(std) == (sig.shape[0], sig.shape[-1]): # type: ignore _std = np.expand_dims(np.array(std, dtype=dtype), axis=1) else: - raise AssertionError(f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}") + raise AssertionError(f"shape of `std` = {np.shape(std)} not compatible with the `sig` = {np.shape(sig)}") # type: ignore else: _std = std @@ -927,7 +929,7 @@ def normalize( options = dict(axis=0, keepdims=True) if _method == "z-score": - nm_sig = ((sig - np.mean(sig, dtype=dtype, **options)) / (np.std(sig, dtype=dtype, **options) + eps)) * _std + _mean + nm_sig = ((sig - np.mean(sig, dtype=dtype, **options)) / (np.std(sig, dtype=dtype, **options) + eps)) * _std + _mean # type: ignore elif _method == "min-max": - nm_sig = (sig - np.amin(sig, **options)) / (np.amax(sig, **options) - np.amin(sig, **options) + eps) - return nm_sig.astype(dtype) + nm_sig = (sig - np.amin(sig, **options)) / (np.amax(sig, **options) - np.amin(sig, **options) + eps) # type: ignore + return nm_sig.astype(dtype) # type: ignore
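Taken together, the `CkptMixin` changes above make `safetensors` the default checkpoint format. The following is a minimal, hedged round-trip sketch under the new API; `TinyECGModel` is a toy module invented for illustration (any torch_ecg model inheriting `CkptMixin` would do), and only the `save`/`from_checkpoint` signatures shown in this diff are relied upon.

import tempfile
from pathlib import Path

from torch import nn

from torch_ecg.cfg import CFG
from torch_ecg.utils.utils_nn import CkptMixin


class TinyECGModel(nn.Module, CkptMixin):
    """Toy model, only to demonstrate the checkpoint round trip."""

    def __init__(self, classes, n_leads, config=None):
        super().__init__()
        self.classes = list(classes)
        self.n_leads = n_leads
        self.config = CFG(config or {"name": "tiny"})
        self.proj = nn.Conv1d(n_leads, len(self.classes), kernel_size=1)

    def forward(self, x):
        return self.proj(x)


model = TinyECGModel(classes=["N", "AF"], n_leads=12)
train_config = CFG(classes=model.classes, n_leads=model.n_leads)

ckpt_dir = Path(tempfile.mkdtemp())
# default: a single .safetensors file with the configs JSON-encoded in its metadata
model.save(ckpt_dir / "tiny_model.safetensors", train_config)
# restore: the classmethod rebuilds the model from the embedded configs
restored, aux_config = TinyECGModel.from_checkpoint(ckpt_dir / "tiny_model.safetensors")
# legacy torch checkpoints remain available by opting out of safetensors
model.save(ckpt_dir / "tiny_model.pth", train_config, use_safetensors=False)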