189 changes: 189 additions & 0 deletions examples/sdoh_icd9_llm_eval.py
@@ -0,0 +1,189 @@
import argparse
import json
import os
from datetime import datetime
from typing import List, Set

import numpy as np
from pyhealth.datasets import MIMIC3NoteDataset
from pyhealth.metrics import multilabel_metrics_fn
from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM
from pyhealth.tasks.sdoh_icd9_detection import TARGET_CODES
from pyhealth.tasks.sdoh_utils import codes_to_multihot, load_sdoh_icd9_labels


def parse_args():
"""Parse CLI arguments for SDOH ICD-9 note evaluation."""
parser = argparse.ArgumentParser(
description="Admission-level SDOH ICD-9 evaluation with per-note LLM calls."
)
parser.add_argument("--mimic-root", required=True, help="Root folder for MIMIC-III CSVs")
parser.add_argument("--label-csv-path", required=True, help="Path to sdoh_icd9_dataset.csv")
parser.add_argument(
"--label-source",
default="manual",
choices=["manual", "true"],
help="Which labels to use as primary ground truth.",
)
parser.add_argument(
"--max-notes",
default="all",
help="Limit notes per admission (e.g., 1, 2, 5, or 'all').",
)
parser.add_argument(
"--max-admissions",
default="all",
help="Limit admissions to process (e.g., 5 or 'all').",
)
parser.add_argument(
"--note-categories",
help="Comma-separated NOTE_CATEGORY values to include (optional).",
)
parser.add_argument("--output-dir", default=".", help="Directory to save outputs.")
parser.add_argument("--dry-run", action="store_true")
return parser.parse_args()


def main():
"""Run admission-level evaluation with per-note LLM calls."""
args = parse_args()
target_codes = list(TARGET_CODES)
label_map = load_sdoh_icd9_labels(args.label_csv_path, target_codes)

include_categories = (
[cat.strip() for cat in args.note_categories.split(",")]
if args.note_categories
else None
)
if str(args.max_notes).lower() == "all":
max_notes = None
else:
try:
max_notes = int(args.max_notes)
except ValueError as exc:
raise ValueError("--max-notes must be an integer or 'all'") from exc
if max_notes <= 0:
raise ValueError("--max-notes must be a positive integer or 'all'")
if str(args.max_admissions).lower() == "all":
max_admissions = None
else:
try:
max_admissions = int(args.max_admissions)
except ValueError as exc:
raise ValueError("--max-admissions must be an integer or 'all'") from exc
if max_admissions <= 0:
raise ValueError("--max-admissions must be a positive integer or 'all'")

hadm_ids = list(label_map.keys())
if max_admissions is not None:
hadm_ids = hadm_ids[:max_admissions]
label_map = {hadm_id: label_map[hadm_id] for hadm_id in hadm_ids}

note_dataset = MIMIC3NoteDataset(
root=args.mimic_root,
target_codes=target_codes,
hadm_ids=hadm_ids,
include_categories=include_categories,
)
sample_dataset = note_dataset.set_task(
label_source=args.label_source,
label_map=label_map,
)

dry_run = args.dry_run or not os.environ.get("OPENAI_API_KEY")
model = SDOHICD9LLM(
target_codes=target_codes,
dry_run=dry_run,
max_notes=max_notes,
)

results = []
predicted_codes_all: List[Set[str]] = []
manual_codes_all: List[Set[str]] = []
true_codes_all: List[Set[str]] = []

for sample in sample_dataset:
predicted_codes, note_results = model.predict_admission_with_notes(
sample["notes"],
sample.get("note_categories"),
sample.get("chartdates"),
)
predicted_codes_all.append(predicted_codes)
visit_id = str(sample.get("visit_id", ""))
label_entry = label_map.get(visit_id, {"manual": set(), "true": set()})
manual_codes = set(label_entry["manual"])
true_codes = set(label_entry["true"])
manual_codes_all.append(manual_codes)
true_codes_all.append(true_codes)

results.append(
{
"visit_id": sample.get("visit_id"),
"patient_id": sample.get("patient_id"),
"num_notes": sample.get("num_notes"),
"text_length": sample.get("text_length"),
"is_gap_case": sample.get("is_gap_case"),
"manual_codes": ",".join(sorted(manual_codes)),
"true_codes": ",".join(sorted(true_codes)),
"predicted_codes": ",".join(sorted(predicted_codes)),
"note_results": json.dumps(note_results),
}
)

y_pred = np.stack(
[codes_to_multihot(codes, target_codes).numpy() for codes in predicted_codes_all],
axis=0,
)
y_manual = np.stack(
[codes_to_multihot(codes, target_codes).numpy() for codes in manual_codes_all],
axis=0,
)
y_true = np.stack(
[codes_to_multihot(codes, target_codes).numpy() for codes in true_codes_all],
axis=0,
)

metrics_list = [
"accuracy",
"hamming_loss",
"f1_micro",
"f1_macro",
"precision_micro",
"recall_micro",
]
metrics_manual = multilabel_metrics_fn(y_manual, y_pred, metrics=metrics_list)
metrics_true = multilabel_metrics_fn(y_true, y_pred, metrics=metrics_list)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs(args.output_dir, exist_ok=True)

results_path = os.path.join(
args.output_dir, f"admission_level_results_per_note_{timestamp}.json"
)
with open(results_path, "w") as f:
json.dump(results, f, indent=2)

metrics_path = os.path.join(
args.output_dir, f"admission_level_metrics_per_note_{timestamp}.json"
)
with open(metrics_path, "w") as f:
json.dump(
{
"evaluation_timestamp": timestamp,
"processing_method": "per_note",
"total_admissions": len(results),
"dry_run": dry_run,
"manual_labels_metrics": metrics_manual,
"true_codes_metrics": metrics_true,
},
f,
indent=2,
)

print("Saved results to:", results_path)
print("Saved metrics to:", metrics_path)
print("Manual labels micro F1:", metrics_manual.get("f1_micro"))


if __name__ == "__main__":
main()
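
For reference, a typical invocation of this example might look like the following; the paths are placeholders and the note-category filter values are only illustrative:

python examples/sdoh_icd9_llm_eval.py \
    --mimic-root /path/to/mimic-iii \
    --label-csv-path /path/to/sdoh_icd9_dataset.csv \
    --label-source manual \
    --max-notes 2 \
    --max-admissions 10 \
    --note-categories "Discharge summary,Social Work" \
    --output-dir ./sdoh_outputs \
    --dry-run

With --dry-run (or when OPENAI_API_KEY is unset), the script exercises the full loading, task, and output pipeline without making LLM calls, which is useful for checking data paths and output files before a real run.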
3 changes: 2 additions & 1 deletion pyhealth/datasets/__init__.py
@@ -57,7 +57,7 @@ def __init__(self, *args, **kwargs):
from .eicu import eICUDataset
from .isruc import ISRUCDataset
from .medical_transcriptions import MedicalTranscriptionsDataset
from .mimic3 import MIMIC3Dataset
from .mimic3 import MIMIC3Dataset, MIMIC3NoteDataset
from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
from .mimicextract import MIMICExtractDataset
from .omop import OMOPDataset
@@ -66,6 +66,7 @@ def __init__(self, *args, **kwargs):
from .sleepedf import SleepEDFDataset
from .bmd_hs import BMDHSDataset
from .support2 import Support2Dataset
from .sdoh_icd9 import SDOHICD9NotesDataset
from .tcga_prad import TCGAPRADDataset
from .splitter import (
split_by_patient,
159 changes: 158 additions & 1 deletion pyhealth/datasets/mimic3.py
@@ -1,11 +1,15 @@
import logging
import warnings
from pathlib import Path
from typing import List, Optional
from typing import Dict, Iterable, List, Optional, Sequence, Set

import narwhals as pl
import pandas as pd

from .base_dataset import BaseDataset
from .sample_dataset import SampleDataset, create_sample_dataset
from ..tasks.sdoh_icd9_detection import SDOHICD9AdmissionTask
from ..tasks.sdoh_utils import TARGET_CODES

logger = logging.getLogger(__name__)

@@ -85,3 +89,156 @@ def preprocess_noteevents(self, df: pl.LazyFrame) -> pl.LazyFrame:
.alias("charttime")
)
return df


class MIMIC3NoteDataset:
"""Note-only loader for MIMIC-III NOTEEVENTS.

This loader streams NOTEEVENTS in chunks and optionally filters by HADM_IDs
and note categories. Use set_task() to convert admissions into samples.
"""

def __init__(
self,
noteevents_path: Optional[str] = None,
root: Optional[str] = None,
target_codes: Optional[Sequence[str]] = None,
hadm_ids: Optional[Iterable[str]] = None,
include_categories: Optional[Sequence[str]] = None,
chunksize: int = 200_000,
dataset_name: Optional[str] = None,
) -> None:
"""Initialize the note-only loader.

Args:
noteevents_path: Path to NOTEEVENTS CSV/CSV.GZ.
root: MIMIC-III root directory (used if noteevents_path is None).
target_codes: Target ICD-9 codes for label vector construction.
hadm_ids: Optional admission IDs to include.
include_categories: Optional NOTE_CATEGORY values to include.
chunksize: Number of rows per chunk read.
dataset_name: Optional dataset name for the SampleDataset.
"""
if noteevents_path is None:
if root is None:
raise ValueError("root is required when noteevents_path is not set.")
noteevents_path = str(Path(root) / "NOTEEVENTS.csv.gz")
self.noteevents_path = noteevents_path
self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES)
self.hadm_ids = {str(x) for x in hadm_ids} if hadm_ids is not None else None
self.include_categories = (
{cat.strip().upper() for cat in include_categories}
if include_categories
else None
)
self.chunksize = chunksize
self.dataset_name = dataset_name or "mimic3_note"

def _load_notes(self) -> Dict[str, List[Dict]]:
"""Load and group note events by admission."""
keep_cols = {
"SUBJECT_ID",
"HADM_ID",
"CHARTDATE",
"CHARTTIME",
"CATEGORY",
"TEXT",
}

notes_by_hadm: Dict[str, List[Dict]] = {}
for chunk in pd.read_csv(
self.noteevents_path,
chunksize=self.chunksize,
usecols=lambda c: c.upper() in keep_cols,
dtype={"SUBJECT_ID": "string", "HADM_ID": "string"},
low_memory=False,
):
chunk.columns = [c.upper() for c in chunk.columns]
if self.hadm_ids is not None:
chunk = chunk[chunk["HADM_ID"].astype("string").isin(self.hadm_ids)]
if chunk.empty:
continue
if self.include_categories is not None:
chunk = chunk[
chunk["CATEGORY"].astype("string")
.str.upper()
.isin(self.include_categories)
]
if chunk.empty:
continue

charttime = pd.to_datetime(chunk["CHARTTIME"], errors="coerce")
chartdate = pd.to_datetime(chunk["CHARTDATE"], errors="coerce")
timestamp = charttime.fillna(chartdate)

for row, ts in zip(chunk.itertuples(index=False), timestamp):
hadm_id = str(row.HADM_ID)
entry = {
"patient_id": str(row.SUBJECT_ID) if pd.notna(row.SUBJECT_ID) else "",
"text": row.TEXT if pd.notna(row.TEXT) else "",
"category": row.CATEGORY if pd.notna(row.CATEGORY) else "",
"timestamp": ts,
}
notes_by_hadm.setdefault(hadm_id, []).append(entry)

return notes_by_hadm

def _build_admissions(self) -> List[Dict]:
"""Build admission-level note bundles with timestamps and categories."""
notes_by_hadm = self._load_notes()
admissions: List[Dict] = []
for hadm_id, notes in notes_by_hadm.items():
notes.sort(
key=lambda x: x["timestamp"]
if pd.notna(x["timestamp"])
else pd.Timestamp.min
)
note_texts = [str(n["text"]) for n in notes]
note_categories = [str(n["category"]) for n in notes]
chartdates = [
n["timestamp"].strftime("%Y-%m-%d") if pd.notna(n["timestamp"]) else "Unknown"
for n in notes
]

admissions.append(
{
"visit_id": hadm_id,
"patient_id": notes[0]["patient_id"],
"notes": note_texts,
"note_categories": note_categories,
"chartdates": chartdates,
"num_notes": len(note_texts),
"text_length": int(sum(len(note) for note in note_texts)),
}
)

logger.info("Loaded %d admissions from NOTEEVENTS", len(admissions))
return admissions

def set_task(
self,
task: Optional[SDOHICD9AdmissionTask] = None,
label_source: str = "manual",
label_map: Optional[Dict[str, Dict[str, Set[str]]]] = None,
in_memory: bool = True,
) -> SampleDataset:
"""Apply a task to admissions and return a SampleDataset."""
if task is None:
task = SDOHICD9AdmissionTask(
target_codes=self.target_codes,
label_source=label_source,
label_map=label_map,
)

samples: List[Dict] = []
for admission in self._build_admissions():
samples.extend(task(admission))

return create_sample_dataset(
samples=samples,
input_schema=task.input_schema,
output_schema=task.output_schema,
dataset_name=self.dataset_name,
task_name=task.task_name,
in_memory=in_memory,
)
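
As a minimal sketch of using the new loader outside the CLI example (paths and the category filter are placeholders; the label CSV is assumed to follow the HADM_ID -> {"manual", "true"} format produced by load_sdoh_icd9_labels):

from pyhealth.datasets import MIMIC3NoteDataset
from pyhealth.tasks.sdoh_icd9_detection import TARGET_CODES
from pyhealth.tasks.sdoh_utils import load_sdoh_icd9_labels

# Hypothetical paths; point these at the local MIMIC-III copy and label CSV.
label_map = load_sdoh_icd9_labels("sdoh_icd9_dataset.csv", list(TARGET_CODES))

dataset = MIMIC3NoteDataset(
    root="/data/mimic3",                       # expects NOTEEVENTS.csv.gz under this root
    target_codes=list(TARGET_CODES),
    hadm_ids=list(label_map.keys()),           # restrict streaming to labeled admissions
    include_categories=["Discharge summary"],  # optional NOTE_CATEGORY filter
)

# Group notes by admission and attach labels from the provided map.
samples = dataset.set_task(label_source="manual", label_map=label_map)
for sample in samples:
    print(sample.get("visit_id"), sample.get("num_notes"))

Because the loader streams NOTEEVENTS in chunks and filters by HADM_ID up front, restricting hadm_ids to the labeled admissions keeps memory use proportional to the labeled subset rather than the full notes table.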