Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyhealth/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,4 @@
MutationPathogenicityPrediction,
VariantClassificationClinVar,
)
from .sepsis_task import sepsis_ehr_task
58 changes: 58 additions & 0 deletions pyhealth/tasks/sepsis_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
EHR-only sepsis classification task for PyHealth. This task computes mean vital
sign values per visit and assigns a binary sepsis labelf ro classification.

Attributes:
patient (Patient):
A PyHealth Patient object whose `data_source` attribute is a
Polars DataFrame containing the patient's full event history.
Returns:
List[Dict[str, Any]]:
A list of samples. Each sample has the structure:

{
"patient_id": <str>,
"visit_id": <str>,
"ehr_features_mean": numpy.ndarray of shape (num_features,),
"y": <int binary label>
}
"""
import numpy as np
from typing import Dict, Any, List
from pyhealth.data import Patient


def sepsis_ehr_task(patient: Patient) -> List[Dict[str, Any]]:
# Retrieve patient's event dataframe
df = patient.data_source

# Filter rows where event_type == "ehr"
ehr_df = df.filter(df["event_type"] == "ehr")
if ehr_df.is_empty():
return []

samples = []

# Unique visits for this patient
visit_ids = ehr_df["visit_id"].unique().to_list()

# Explicit whitelist of vital sign columns
vital_cols = ["heart_rate", "spo2", "glucose"]

for vid in visit_ids:
visit_df = ehr_df.filter(ehr_df["visit_id"] == vid)

# Select only the known vital sign columns
available = [c for c in vital_cols if c in visit_df.columns]
x = visit_df.select(available).mean().to_numpy().astype(float).flatten()

y = int(visit_df["label"][0])

samples.append({
"patient_id": patient.patient_id,
"visit_id": vid,
"ehr_features_mean": x,
"y": y,
})

return samples
37 changes: 37 additions & 0 deletions tests/core/test_sepsis_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import numpy as np
import polars as pl
from pyhealth.data import Patient
from pyhealth.tasks import sepsis_ehr_task


def test_sepsis_ehr_task():
# Build test flat event-level DataFrame
df = pl.DataFrame({
"timestamp": [1, 2, 3],
"patient_id": ["P1", "P1", "P1"],
"visit_id": ["V1", "V1", "V1"],
"event_type": ["ehr", "ehr", "ehr"],
"table": ["ehr", "ehr", "ehr"],
"heart_rate": [80, 82, 85],
"spo2": [96, 95, 97],
"glucose": [120, 118, 110],
"label": [1, 1, 1],
})

# Create patient from raw event DataFrame
patient = Patient(
patient_id="P1",
data_source=df
)

# Run task
samples = sepsis_ehr_task(patient)

assert len(samples) == 1
sample = samples[0]

assert "ehr_features_mean" in sample
assert "y" in sample
assert isinstance(sample["ehr_features_mean"], np.ndarray)
assert sample["y"] == 1
assert sample["ehr_features_mean"].shape[0] == 3 # 3 vital signs