diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index fb3c6966a..d3d4afa7f 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -62,3 +62,4 @@ MutationPathogenicityPrediction, VariantClassificationClinVar, ) +from .sepsis_task import sepsis_ehr_task diff --git a/pyhealth/tasks/sepsis_task.py b/pyhealth/tasks/sepsis_task.py new file mode 100644 index 000000000..5eda88c13 --- /dev/null +++ b/pyhealth/tasks/sepsis_task.py @@ -0,0 +1,58 @@ +""" +EHR-only sepsis classification task for PyHealth. This task computes mean vital +sign values per visit and assigns a binary sepsis labelf ro classification. + +Attributes: + patient (Patient): + A PyHealth Patient object whose `data_source` attribute is a + Polars DataFrame containing the patient's full event history. +Returns: + List[Dict[str, Any]]: + A list of samples. Each sample has the structure: + + { + "patient_id": , + "visit_id": , + "ehr_features_mean": numpy.ndarray of shape (num_features,), + "y": + } +""" +import numpy as np +from typing import Dict, Any, List +from pyhealth.data import Patient + + +def sepsis_ehr_task(patient: Patient) -> List[Dict[str, Any]]: + # Retrieve patient's event dataframe + df = patient.data_source + + # Filter rows where event_type == "ehr" + ehr_df = df.filter(df["event_type"] == "ehr") + if ehr_df.is_empty(): + return [] + + samples = [] + + # Unique visits for this patient + visit_ids = ehr_df["visit_id"].unique().to_list() + + # Explicit whitelist of vital sign columns + vital_cols = ["heart_rate", "spo2", "glucose"] + + for vid in visit_ids: + visit_df = ehr_df.filter(ehr_df["visit_id"] == vid) + + # Select only the known vital sign columns + available = [c for c in vital_cols if c in visit_df.columns] + x = visit_df.select(available).mean().to_numpy().astype(float).flatten() + + y = int(visit_df["label"][0]) + + samples.append({ + "patient_id": patient.patient_id, + "visit_id": vid, + "ehr_features_mean": x, + "y": y, + }) + + return samples diff --git a/tests/core/test_sepsis_task.py b/tests/core/test_sepsis_task.py new file mode 100644 index 000000000..8209789e4 --- /dev/null +++ b/tests/core/test_sepsis_task.py @@ -0,0 +1,37 @@ +import numpy as np +import polars as pl +from pyhealth.data import Patient +from pyhealth.tasks import sepsis_ehr_task + + +def test_sepsis_ehr_task(): + # Build test flat event-level DataFrame + df = pl.DataFrame({ + "timestamp": [1, 2, 3], + "patient_id": ["P1", "P1", "P1"], + "visit_id": ["V1", "V1", "V1"], + "event_type": ["ehr", "ehr", "ehr"], + "table": ["ehr", "ehr", "ehr"], + "heart_rate": [80, 82, 85], + "spo2": [96, 95, 97], + "glucose": [120, 118, 110], + "label": [1, 1, 1], + }) + + # Create patient from raw event DataFrame + patient = Patient( + patient_id="P1", + data_source=df + ) + + # Run task + samples = sepsis_ehr_task(patient) + + assert len(samples) == 1 + sample = samples[0] + + assert "ehr_features_mean" in sample + assert "y" in sample + assert isinstance(sample["ehr_features_mean"], np.ndarray) + assert sample["y"] == 1 + assert sample["ehr_features_mean"].shape[0] == 3 # 3 vital signs