From 525b97c7effe2410c400e1592eebe7f281863f8a Mon Sep 17 00:00:00 2001 From: khancepts101 Date: Sat, 10 Jan 2026 21:30:35 -0500 Subject: [PATCH 1/7] Add SDOH ICD-9 note pipeline --- examples/sdoh_icd9_llm_eval.py | 131 ++++++++++++++++ pyhealth/datasets/__init__.py | 2 + pyhealth/datasets/mimic3_notes.py | 129 ++++++++++++++++ pyhealth/datasets/sdoh_icd9.py | 99 ++++++++++++ pyhealth/metrics/__init__.py | 31 ++-- pyhealth/models/__init__.py | 1 + pyhealth/models/sdoh_icd9_llm.py | 210 ++++++++++++++++++++++++++ pyhealth/tasks/__init__.py | 1 + pyhealth/tasks/sdoh_icd9_detection.py | 148 ++++++++++++++++++ pyhealth/tasks/sdoh_utils.py | 72 +++++++++ tools/mimic3_sdoh_subset.py | 79 ++++++++++ 11 files changed, 891 insertions(+), 12 deletions(-) create mode 100644 examples/sdoh_icd9_llm_eval.py create mode 100644 pyhealth/datasets/mimic3_notes.py create mode 100644 pyhealth/datasets/sdoh_icd9.py create mode 100644 pyhealth/models/sdoh_icd9_llm.py create mode 100644 pyhealth/tasks/sdoh_icd9_detection.py create mode 100644 pyhealth/tasks/sdoh_utils.py create mode 100644 tools/mimic3_sdoh_subset.py diff --git a/examples/sdoh_icd9_llm_eval.py b/examples/sdoh_icd9_llm_eval.py new file mode 100644 index 000000000..c9f390768 --- /dev/null +++ b/examples/sdoh_icd9_llm_eval.py @@ -0,0 +1,131 @@ +import argparse +import json +import os +from datetime import datetime +from typing import Iterable, List, Sequence, Set + +import numpy as np +from pyhealth.datasets import MIMIC3NotesDataset +from pyhealth.metrics import multilabel_metrics_fn +from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM +from pyhealth.tasks.sdoh_icd9_detection import TARGET_CODES + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Admission-level SDOH ICD-9 evaluation with per-note LLM calls." 
+ ) + parser.add_argument("--mimic-root", required=True, help="Root folder for MIMIC-III CSVs") + parser.add_argument("--label-csv-path", required=True, help="Path to sdoh_icd9_dataset.csv") + parser.add_argument( + "--label-source", + default="manual", + choices=["manual", "true"], + help="Which labels to use as primary ground truth.", + ) + parser.add_argument("--output-dir", default=".", help="Directory to save outputs.") + parser.add_argument("--dry-run", action="store_true") + return parser.parse_args() + + +def codes_to_multihot(codes_list: Iterable[Set[str]], target_codes: Sequence[str]) -> np.ndarray: + target_set = [code.upper() for code in target_codes] + rows: List[List[int]] = [] + for codes in codes_list: + code_set = {code.upper() for code in codes} + rows.append([1 if code in code_set else 0 for code in target_set]) + return np.array(rows, dtype=np.float32) + + +def main(): + args = parse_args() + target_codes = list(TARGET_CODES) + + noteevents_path = f"{args.mimic_root}/NOTEEVENTS.csv.gz" + note_dataset = MIMIC3NotesDataset( + noteevents_path=noteevents_path, + label_csv_path=args.label_csv_path, + target_codes=target_codes, + ) + sample_dataset = note_dataset.set_task(label_source=args.label_source) + + dry_run = args.dry_run or not os.environ.get("OPENAI_API_KEY") + model = SDOHICD9LLM(target_codes=target_codes, dry_run=dry_run) + + results = [] + predicted_codes_all: List[Set[str]] = [] + manual_codes_all: List[Set[str]] = [] + true_codes_all: List[Set[str]] = [] + + for sample in sample_dataset: + predicted_codes, note_results = model.predict_admission_with_notes( + sample["notes"], + sample.get("note_categories"), + sample.get("chartdates"), + ) + predicted_codes_all.append(predicted_codes) + manual_codes_all.append(set(sample.get("manual_codes", []))) + true_codes_all.append(set(sample.get("true_codes", []))) + + results.append( + { + "visit_id": sample.get("visit_id"), + "patient_id": sample.get("patient_id"), + "num_notes": 
sample.get("num_notes"), + "text_length": sample.get("text_length"), + "is_gap_case": sample.get("is_gap_case"), + "manual_codes": ",".join(sorted(sample.get("manual_codes", []))), + "true_codes": ",".join(sorted(sample.get("true_codes", []))), + "predicted_codes": ",".join(sorted(predicted_codes)), + "note_results": json.dumps(note_results), + } + ) + + y_pred = codes_to_multihot(predicted_codes_all, target_codes) + y_manual = codes_to_multihot(manual_codes_all, target_codes) + y_true = codes_to_multihot(true_codes_all, target_codes) + + metrics_list = [ + "accuracy", + "hamming_loss", + "f1_micro", + "f1_macro", + "precision_micro", + "recall_micro", + ] + metrics_manual = multilabel_metrics_fn(y_manual, y_pred, metrics=metrics_list) + metrics_true = multilabel_metrics_fn(y_true, y_pred, metrics=metrics_list) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + os.makedirs(args.output_dir, exist_ok=True) + + results_path = os.path.join( + args.output_dir, f"admission_level_results_per_note_{timestamp}.json" + ) + with open(results_path, "w") as f: + json.dump(results, f, indent=2) + + metrics_path = os.path.join( + args.output_dir, f"admission_level_metrics_per_note_{timestamp}.json" + ) + with open(metrics_path, "w") as f: + json.dump( + { + "evaluation_timestamp": timestamp, + "processing_method": "per_note", + "total_admissions": len(results), + "dry_run": dry_run, + "manual_labels_metrics": metrics_manual, + "true_codes_metrics": metrics_true, + }, + f, + indent=2, + ) + + print("Saved results to:", results_path) + print("Saved metrics to:", metrics_path) + print("Manual labels micro F1:", metrics_manual.get("f1_micro")) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 5176fdb42..dc493a3db 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -58,6 +58,7 @@ def __init__(self, *args, **kwargs): from .isruc import ISRUCDataset from .medical_transcriptions 
import MedicalTranscriptionsDataset from .mimic3 import MIMIC3Dataset +from .mimic3_notes import MIMIC3NotesDataset from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset @@ -66,6 +67,7 @@ def __init__(self, *args, **kwargs): from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset from .support2 import Support2Dataset +from .sdoh_icd9 import SDOHICD9NotesDataset from .tcga_prad import TCGAPRADDataset from .splitter import ( split_by_patient, diff --git a/pyhealth/datasets/mimic3_notes.py b/pyhealth/datasets/mimic3_notes.py new file mode 100644 index 000000000..1e1584cdc --- /dev/null +++ b/pyhealth/datasets/mimic3_notes.py @@ -0,0 +1,129 @@ +import logging +from typing import Dict, Iterable, List, Optional, Sequence + +import pandas as pd + +from .sample_dataset import SampleDataset, create_sample_dataset +from ..tasks.sdoh_icd9_detection import SDOHICD9AdmissionTask, load_sdoh_icd9_labels +from ..tasks.sdoh_utils import TARGET_CODES + +logger = logging.getLogger(__name__) + + +class MIMIC3NotesDataset: + """Note-only loader for MIMIC-III NOTEEVENTS with label CSV filtering.""" + + def __init__( + self, + noteevents_path: str, + label_csv_path: str, + target_codes: Optional[Sequence[str]] = None, + hadm_ids: Optional[Iterable[str]] = None, + chunksize: int = 200_000, + dataset_name: Optional[str] = None, + ) -> None: + self.noteevents_path = noteevents_path + self.label_csv_path = label_csv_path + self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) + self.label_map = load_sdoh_icd9_labels(label_csv_path, self.target_codes) + if hadm_ids is None: + hadm_ids = self.label_map.keys() + self.hadm_ids = {str(x) for x in hadm_ids} + self.chunksize = chunksize + self.dataset_name = dataset_name or "mimic3_notes" + + def _load_notes(self) -> Dict[str, List[Dict]]: + keep_cols = { + "SUBJECT_ID", + "HADM_ID", + "CHARTDATE", + 
"CHARTTIME", + "CATEGORY", + "TEXT", + } + + notes_by_hadm: Dict[str, List[Dict]] = {} + for chunk in pd.read_csv( + self.noteevents_path, + chunksize=self.chunksize, + usecols=lambda c: c.upper() in keep_cols, + dtype={"SUBJECT_ID": "string", "HADM_ID": "string"}, + low_memory=False, + ): + chunk.columns = [c.upper() for c in chunk.columns] + filtered = chunk[chunk["HADM_ID"].astype("string").isin(self.hadm_ids)] + if filtered.empty: + continue + + charttime = pd.to_datetime(filtered["CHARTTIME"], errors="coerce") + chartdate = pd.to_datetime(filtered["CHARTDATE"], errors="coerce") + timestamp = charttime.fillna(chartdate) + + for row, ts in zip(filtered.itertuples(index=False), timestamp): + hadm_id = str(row.HADM_ID) + entry = { + "patient_id": str(row.SUBJECT_ID) if row.SUBJECT_ID is not pd.NA else "", + "text": row.TEXT if row.TEXT is not pd.NA else "", + "category": row.CATEGORY if row.CATEGORY is not pd.NA else "", + "timestamp": ts, + } + notes_by_hadm.setdefault(hadm_id, []).append(entry) + + return notes_by_hadm + + def _build_admissions(self) -> List[Dict]: + notes_by_hadm = self._load_notes() + admissions: List[Dict] = [] + for hadm_id, notes in notes_by_hadm.items(): + if hadm_id not in self.label_map: + continue + notes.sort(key=lambda x: x["timestamp"] if x["timestamp"] is not pd.NaT else pd.Timestamp.min) + note_texts = [str(n["text"]) for n in notes] + note_categories = [str(n["category"]) for n in notes] + chartdates = [ + n["timestamp"].strftime("%Y-%m-%d") if pd.notna(n["timestamp"]) else "Unknown" + for n in notes + ] + + labels = self.label_map[hadm_id] + admissions.append( + { + "visit_id": hadm_id, + "patient_id": notes[0]["patient_id"], + "notes": note_texts, + "note_categories": note_categories, + "chartdates": chartdates, + "num_notes": len(note_texts), + "text_length": int(sum(len(note) for note in note_texts)), + "manual_codes": labels["manual"], + "true_codes": labels["true"], + } + ) + + logger.info("Loaded %d admissions from 
NOTEEVENTS", len(admissions)) + return admissions + + def set_task( + self, + task: Optional[SDOHICD9AdmissionTask] = None, + label_source: str = "manual", + in_memory: bool = True, + ) -> SampleDataset: + if task is None: + task = SDOHICD9AdmissionTask( + target_codes=self.target_codes, + label_source=label_source, + ) + + samples: List[Dict] = [] + for admission in self._build_admissions(): + samples.extend(task(admission)) + + return create_sample_dataset( + samples=samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name=self.dataset_name, + task_name=task.task_name, + in_memory=in_memory, + ) diff --git a/pyhealth/datasets/sdoh_icd9.py b/pyhealth/datasets/sdoh_icd9.py new file mode 100644 index 000000000..49def92d4 --- /dev/null +++ b/pyhealth/datasets/sdoh_icd9.py @@ -0,0 +1,99 @@ +import logging +from typing import Dict, List, Optional, Sequence, Set + +import pandas as pd + +from .sample_dataset import SampleDataset, create_sample_dataset +from ..tasks.sdoh_icd9_detection import SDOHICD9AdmissionTask +from ..tasks.sdoh_utils import TARGET_CODES, parse_codes + +logger = logging.getLogger(__name__) + + +REQUIRED_COLUMNS = { + "HADM_ID", + "SUBJECT_ID", + "NOTE_CATEGORY", + "CHARTDATE", + "FULL_TEXT", + "ADMISSION_TRUE_CODES", + "ADMISSION_MANUAL_LABELS", +} + + +class SDOHICD9NotesDataset: + """CSV-backed dataset for SDOH ICD-9 V-code detection from notes.""" + + def __init__( + self, + csv_path: str, + dataset_name: Optional[str] = None, + target_codes: Optional[Sequence[str]] = None, + ) -> None: + self.csv_path = csv_path + self.dataset_name = dataset_name or "sdoh_icd9_notes" + self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) + self._admissions = self._load_admissions() + + def _load_admissions(self) -> List[Dict]: + df = pd.read_csv(self.csv_path) + missing = REQUIRED_COLUMNS - set(df.columns) + if missing: + raise ValueError(f"Missing required columns: {sorted(missing)}") + + admissions: 
List[Dict] = [] + df["CHARTDATE"] = pd.to_datetime(df["CHARTDATE"], errors="coerce") + for hadm_id, group in df.groupby("HADM_ID"): + group = group.sort_values("CHARTDATE") + first = group.iloc[0] + notes = [str(text).strip() for text in group["FULL_TEXT"].fillna("")] + chartdates = [ + dt.strftime("%Y-%m-%d") if pd.notna(dt) else "Unknown" + for dt in group["CHARTDATE"] + ] + admission = { + "visit_id": str(hadm_id), + "patient_id": str(first["SUBJECT_ID"]), + "is_gap_case": first.get("IS_GAP_CASE"), + "note_categories": [ + str(cat).strip() for cat in group["NOTE_CATEGORY"].fillna("") + ], + "chartdates": chartdates, + "notes": notes, + "num_notes": len(notes), + "text_length": int(sum(len(note) for note in notes)), + "manual_codes": parse_codes( + first["ADMISSION_MANUAL_LABELS"], self.target_codes + ), + "true_codes": parse_codes( + first["ADMISSION_TRUE_CODES"], self.target_codes + ), + } + admissions.append(admission) + logger.info("Loaded %d admissions from %s", len(admissions), self.csv_path) + return admissions + + def set_task( + self, + task: Optional[SDOHICD9AdmissionTask] = None, + label_source: str = "manual", + in_memory: bool = True, + ) -> SampleDataset: + if task is None: + task = SDOHICD9AdmissionTask( + target_codes=self.target_codes, + label_source=label_source, + ) + + samples: List[Dict] = [] + for admission in self._admissions: + samples.extend(task(admission)) + + return create_sample_dataset( + samples=samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name=self.dataset_name, + task_name=task.task_name, + in_memory=in_memory, + ) diff --git a/pyhealth/metrics/__init__.py b/pyhealth/metrics/__init__.py index da8da0f5b..fa9578ed8 100644 --- a/pyhealth/metrics/__init__.py +++ b/pyhealth/metrics/__init__.py @@ -1,12 +1,19 @@ from .binary import binary_metrics_fn from .drug_recommendation import ddi_rate_score -from .interpretability import ( - ComprehensivenessMetric, - Evaluator, - RemovalBasedMetric, - 
SufficiencyMetric, - evaluate_attribution, -) +try: + from .interpretability import ( + ComprehensivenessMetric, + Evaluator, + RemovalBasedMetric, + SufficiencyMetric, + evaluate_attribution, + ) +except Exception: # pragma: no cover - optional dependencies + ComprehensivenessMetric = None + Evaluator = None + RemovalBasedMetric = None + SufficiencyMetric = None + evaluate_attribution = None from .multiclass import multiclass_metrics_fn from .multilabel import multilabel_metrics_fn @@ -17,11 +24,11 @@ __all__ = [ "binary_metrics_fn", "ddi_rate_score", - "ComprehensivenessMetric", - "SufficiencyMetric", - "RemovalBasedMetric", - "Evaluator", - "evaluate_attribution", + "ComprehensivenessMetric", + "SufficiencyMetric", + "RemovalBasedMetric", + "Evaluator", + "evaluate_attribution", "multiclass_metrics_fn", "multilabel_metrics_fn", "ranking_metrics_fn", diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 3c0b5384d..b34100b33 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -37,3 +37,4 @@ from .vae import VAE from .sdoh import SdohClassifier from .medlink import MedLink +from .sdoh_icd9_llm import SDOHICD9LLM diff --git a/pyhealth/models/sdoh_icd9_llm.py b/pyhealth/models/sdoh_icd9_llm.py new file mode 100644 index 000000000..f22ce29f0 --- /dev/null +++ b/pyhealth/models/sdoh_icd9_llm.py @@ -0,0 +1,210 @@ +import logging +import os +import re +import time +from typing import Iterable, List, Optional, Sequence, Set, Tuple + +import torch +from torch import nn + +from pyhealth.tasks.sdoh_utils import TARGET_CODES, codes_to_multihot + +logger = logging.getLogger(__name__) + + +PROMPT_TEMPLATE = """\ +You are an assistant that extracts SDOH ICD-9 V-codes from clinical notes. +Return only the codes, comma-separated, inside triple backticks. +If no target codes are present, return None inside triple backticks. 
+Target codes: {codes} +""" + + +class SDOHICD9LLM(nn.Module): + """Admission-level SDOH ICD-9 V-code detector using an LLM.""" + + mode = "multilabel" + + def __init__( + self, + target_codes: Optional[Sequence[str]] = None, + model_name: str = "gpt-4o-mini", + prompt_template: Optional[str] = None, + api_key: Optional[str] = None, + max_tokens: int = 100, + max_chars: int = 100000, + temperature: float = 0.0, + sleep_s: float = 0.2, + dry_run: bool = False, + ) -> None: + super().__init__() + self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) + self.model_name = model_name + self.prompt_template = prompt_template or PROMPT_TEMPLATE + self.api_key = api_key or os.environ.get("OPENAI_API_KEY") + self.max_tokens = max_tokens + self.max_chars = max_chars + self.temperature = temperature + self.sleep_s = sleep_s + self.dry_run = dry_run + self._client = None + + if not self.api_key and not self.dry_run: + raise EnvironmentError( + "OPENAI_API_KEY is required unless dry_run=True." 
+ ) + + mode = "dry-run" if dry_run else "live" + logger.info( + "Initialized SDOHICD9LLM (mode=%s, model=%s, codes=%d)", + mode, model_name, len(self.target_codes) + ) + + def _get_client(self): + if self._client is None: + from openai import OpenAI + + self._client = OpenAI(api_key=self.api_key) + return self._client + + def _call_openai_api(self, text: str) -> str: + if self.dry_run: + return "```None```" + + if len(text) > self.max_chars: + text = text[: self.max_chars] + "\n\n[Note truncated due to length...]" + + client = self._get_client() + response = client.chat.completions.create( + model=self.model_name, + messages=[ + { + "role": "system", + "content": self.prompt_template.format( + codes=", ".join(self.target_codes) + ), + }, + { + "role": "user", + "content": ( + "Analyze this clinical note and identify SDOH codes:\n\n" + f"{text}" + ), + }, + ], + max_tokens=self.max_tokens, + temperature=self.temperature, + ) + return response.choices[0].message.content.strip() + + def _parse_llm_response(self, response: str) -> Set[str]: + if not response: + return set() + + matches = re.findall(r"```(.*?)```", response, re.DOTALL) + if matches: + response = matches[0].strip() + else: + response = response.strip() + + if response.lower().strip() == "none": + return set() + + response = response.replace("Answer:", "").replace("Codes:", "").strip() + for delimiter in [",", ";", " ", "\n"]: + if delimiter in response: + parts = [c.strip() for c in response.split(delimiter)] + break + else: + parts = [response.strip()] + + valid = {code.upper() for code in parts if code.strip()} + target_set = {code.upper() for code in self.target_codes} + return {code for code in valid if code in target_set} + + def _predict_admission( + self, + notes: Iterable[str], + note_categories: Optional[Iterable[str]] = None, + chartdates: Optional[Iterable[str]] = None, + ) -> Tuple[Set[str], List[dict]]: + aggregated: Set[str] = set() + note_results: List[dict] = [] + + categories = 
list(note_categories) if note_categories is not None else [] + dates = list(chartdates) if chartdates is not None else [] + notes_list = list(notes) + + logger.debug("Processing admission with %d notes", len(notes_list)) + + for idx, note in enumerate(notes_list): + category = categories[idx] if idx < len(categories) else "Unknown" + date = dates[idx] if idx < len(dates) else "Unknown" + + response = self._call_openai_api(note) + predicted = self._parse_llm_response(response) + aggregated.update(predicted) + + logger.debug( + "Note %d/%d (%s, %s): predicted %s", + idx + 1, len(notes_list), category, date, sorted(predicted) or "none" + ) + + note_results.append( + { + "category": category, + "date": date, + "predicted_codes": sorted(predicted), + "llm_response": response, + } + ) + + if self.sleep_s > 0 and not self.dry_run: + time.sleep(self.sleep_s) + + logger.debug("Admission complete: aggregated codes %s", sorted(aggregated)) + return aggregated, note_results + + def forward( + self, + notes, + note_categories=None, + chartdates=None, + label=None, + **kwargs, + ): + if notes and isinstance(notes[0], str): + notes_batch = [notes] + categories_batch = [note_categories] if note_categories is not None else [None] + dates_batch = [chartdates] if chartdates is not None else [None] + else: + notes_batch = notes + categories_batch = note_categories or [None] * len(notes_batch) + dates_batch = chartdates or [None] * len(notes_batch) + + batch_probs: List[List[int]] = [] + for note_list, cats, dates in zip( + notes_batch, categories_batch, dates_batch + ): + predicted, _ = self._predict_admission(note_list, cats, dates) + batch_probs.append( + codes_to_multihot(predicted, self.target_codes) + ) + + y_prob = torch.tensor(batch_probs, dtype=torch.float32) + if label is not None and isinstance(label, torch.Tensor): + y_prob = y_prob.to(label.device) + y_true = label + else: + y_true = label + + loss = torch.zeros(1, device=y_prob.device).sum() + return {"loss": loss, 
"y_prob": y_prob, "y_true": y_true} + + def predict_admission_with_notes( + self, + notes: Iterable[str], + note_categories: Optional[Iterable[str]] = None, + chartdates: Optional[Iterable[str]] = None, + ) -> Tuple[Set[str], List[dict]]: + return self._predict_admission(notes, note_categories, chartdates) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 28c71f142..31cb01c0d 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -69,3 +69,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .sdoh_icd9_detection import SDOHICD9AdmissionTask diff --git a/pyhealth/tasks/sdoh_icd9_detection.py b/pyhealth/tasks/sdoh_icd9_detection.py new file mode 100644 index 000000000..42ce722ea --- /dev/null +++ b/pyhealth/tasks/sdoh_icd9_detection.py @@ -0,0 +1,148 @@ +import logging +from typing import Dict, List, Optional, Sequence, Set + +import pandas as pd +import torch + +from ..data import Event, Patient +from .base_task import BaseTask +from .sdoh_utils import TARGET_CODES, codes_to_multihot, parse_codes + +logger = logging.getLogger(__name__) + + +class SDOHICD9AdmissionTask(BaseTask): + """Builds admission-level samples for SDOH ICD-9 V-code detection.""" + + task_name: str = "SDOHICD9Admission" + input_schema: Dict[str, str] = { + "notes": "raw", + "note_categories": "raw", + "chartdates": "raw", + "patient_id": "raw", + "visit_id": "raw", + } + output_schema: Dict[str, object] = { + "label": ("tensor", {"dtype": torch.float32}), + } + + def __init__( + self, + target_codes: Optional[Sequence[str]] = None, + label_source: str = "manual", + ) -> None: + self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) + if label_source not in {"manual", "true"}: + raise ValueError("label_source must be 'manual' or 'true'") + self.label_source = label_source + + def __call__(self, admission: Dict) -> List[Dict]: + if self.label_source == "manual": + label_codes: 
Set[str] = admission.get("manual_codes", set()) + else: + label_codes = admission.get("true_codes", set()) + + sample = { + "visit_id": admission["visit_id"], + "patient_id": admission["patient_id"], + "notes": admission["notes"], + "note_categories": admission["note_categories"], + "chartdates": admission["chartdates"], + "num_notes": admission.get("num_notes", len(admission["notes"])), + "text_length": admission.get("text_length", 0), + "is_gap_case": admission.get("is_gap_case"), + "manual_codes": admission.get("manual_codes", set()), + "true_codes": admission.get("true_codes", set()), + "label_codes": sorted(label_codes), + "label": codes_to_multihot(label_codes, self.target_codes), + } + return [sample] + + +def load_sdoh_icd9_labels( + csv_path: str, target_codes: Sequence[str] +) -> Dict[str, Dict[str, Set[str]]]: + df = pd.read_csv(csv_path) + if "HADM_ID" not in df.columns: + raise ValueError("CSV must include HADM_ID column.") + + labels: Dict[str, Dict[str, Set[str]]] = {} + for hadm_id, group in df.groupby("HADM_ID"): + first = group.iloc[0] + labels[str(hadm_id)] = { + "manual": parse_codes( + first.get("ADMISSION_MANUAL_LABELS"), target_codes + ), + "true": parse_codes( + first.get("ADMISSION_TRUE_CODES"), target_codes + ), + } + return labels + + +class SDOHICD9MIMIC3NoteTask(BaseTask): + """Builds admission-level samples from MIMIC-III noteevents with CSV labels.""" + + task_name: str = "SDOHICD9MIMIC3Notes" + input_schema: Dict[str, str] = { + "notes": "raw", + "note_categories": "raw", + "chartdates": "raw", + "patient_id": "raw", + "visit_id": "raw", + } + output_schema: Dict[str, object] = { + "label": ("tensor", {"dtype": torch.float32}), + } + + def __init__( + self, + label_csv_path: str, + target_codes: Optional[Sequence[str]] = None, + label_source: str = "manual", + ) -> None: + self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) + if label_source not in {"manual", "true"}: + raise ValueError("label_source must 
be 'manual' or 'true'") + self.label_source = label_source + self.label_map = load_sdoh_icd9_labels(label_csv_path, self.target_codes) + + def __call__(self, patient: Patient) -> List[Dict]: + notes: List[Event] = patient.get_events(event_type="noteevents") + if not notes: + return [] + + by_hadm: Dict[str, List[Event]] = {} + for event in notes: + hadm_id = str(event.hadm_id) + if hadm_id not in self.label_map: + continue + by_hadm.setdefault(hadm_id, []).append(event) + + samples: List[Dict] = [] + for hadm_id, events in by_hadm.items(): + events.sort(key=lambda e: e.timestamp or "") + note_texts = [str(e.text) if e.text is not None else "" for e in events] + note_categories = [str(e.category) if e.category is not None else "" for e in events] + chartdates = [ + e.timestamp.strftime("%Y-%m-%d") if e.timestamp is not None else "Unknown" + for e in events + ] + + label_codes = self.label_map[hadm_id][self.label_source] + sample = { + "visit_id": hadm_id, + "patient_id": patient.patient_id, + "notes": note_texts, + "note_categories": note_categories, + "chartdates": chartdates, + "num_notes": len(note_texts), + "text_length": int(sum(len(note) for note in note_texts)), + "manual_codes": self.label_map[hadm_id]["manual"], + "true_codes": self.label_map[hadm_id]["true"], + "label_codes": sorted(label_codes), + "label": codes_to_multihot(label_codes, self.target_codes), + } + samples.append(sample) + + return samples diff --git a/pyhealth/tasks/sdoh_utils.py b/pyhealth/tasks/sdoh_utils.py new file mode 100644 index 000000000..a79721d4d --- /dev/null +++ b/pyhealth/tasks/sdoh_utils.py @@ -0,0 +1,72 @@ +"""Utilities for SDOH ICD-9 V-code detection tasks.""" + +import logging +from typing import Iterable, List, Sequence, Set + +import pandas as pd + +logger = logging.getLogger(__name__) + + +# Standard SDOH ICD-9 V-codes +TARGET_CODES = [ + "V600", + "V602", + "V604", + "V620", + "V625", + "V1541", + "V1542", + "V1584", + "V6141", +] + + +def parse_codes(codes_str: 
object, target_codes: Sequence[str]) -> Set[str]: + """Parse ICD-9 codes from various string formats. + + Args: + codes_str: String representation of codes (comma/semicolon separated) + target_codes: Valid target codes to filter for + + Returns: + Set of valid codes found in the string + """ + if pd.isna(codes_str) or str(codes_str).strip() == "": + return set() + + # Clean string + codes = ( + str(codes_str) + .replace("[", "") + .replace("]", "") + .replace('"', "") + .replace("'", "") + ) + + # Split by delimiter + if "," in codes: + values = [c.strip() for c in codes.split(",")] + elif ";" in codes: + values = [c.strip() for c in codes.split(";")] + else: + values = [codes.strip()] + + # Filter to valid target codes + target_set = {code.upper() for code in target_codes} + parsed = {value.upper() for value in values if value.strip()} + return {code for code in parsed if code in target_set} + + +def codes_to_multihot(codes: Iterable[str], target_codes: Sequence[str]) -> List[int]: + """Convert code set to multi-hot encoding. 
+ + Args: + codes: Iterable of code strings + target_codes: Ordered list of target codes + + Returns: + Multi-hot binary list aligned with target_codes + """ + code_set = {code.upper() for code in codes} + return [1 if code in code_set else 0 for code in target_codes] diff --git a/tools/mimic3_sdoh_subset.py b/tools/mimic3_sdoh_subset.py new file mode 100644 index 000000000..560169475 --- /dev/null +++ b/tools/mimic3_sdoh_subset.py @@ -0,0 +1,79 @@ +import argparse +import gzip +from pathlib import Path + +import pandas as pd + + +def write_gzip_csv(df: pd.DataFrame, path: Path, header: bool) -> None: + mode = "wt" if header else "at" + with gzip.open(path, mode) as f: + df.to_csv(f, index=False, header=header) + + +def filter_csv( + source: Path, + dest: Path, + column: str, + allowed: set, + chunksize: int = 200_000, +) -> None: + header_written = False + for chunk in pd.read_csv(source, chunksize=chunksize): + filtered = chunk[chunk[column].astype(str).isin(allowed)] + if filtered.empty: + continue + write_gzip_csv(filtered, dest, header=not header_written) + header_written = True + if not header_written: + write_gzip_csv(pd.DataFrame(columns=pd.read_csv(source, nrows=0).columns), dest, header=True) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Build a small MIMIC-III subset for SDOH evaluation.") + parser.add_argument("--mimic-root", required=True, help="Path to full MIMIC-III root.") + parser.add_argument("--label-csv-path", required=True, help="Path to sdoh_icd9_dataset.csv.") + parser.add_argument("--output-root", required=True, help="Output folder for the subset.") + parser.add_argument("--note-chunksize", type=int, default=200_000) + parser.add_argument("--max-hadm", type=int, default=0, help="Limit to first N HADM_IDs.") + args = parser.parse_args() + + mimic_root = Path(args.mimic_root) + output_root = Path(args.output_root) + output_root.mkdir(parents=True, exist_ok=True) + + labels = pd.read_csv(args.label_csv_path) + 
hadm_list = labels["HADM_ID"].astype(str).unique().tolist() + if args.max_hadm and args.max_hadm > 0: + hadm_list = hadm_list[: args.max_hadm] + hadm_ids = set(hadm_list) + subject_ids = set(labels["SUBJECT_ID"].astype(str).unique().tolist()) + + admissions_src = mimic_root / "ADMISSIONS.csv.gz" + icustays_src = mimic_root / "ICUSTAYS.csv.gz" + patients_src = mimic_root / "PATIENTS.csv.gz" + noteevents_src = mimic_root / "NOTEEVENTS.csv.gz" + + print("Filtering ADMISSIONS...") + filter_csv(admissions_src, output_root / "ADMISSIONS.csv.gz", "HADM_ID", hadm_ids) + + print("Filtering ICUSTAYS...") + filter_csv(icustays_src, output_root / "ICUSTAYS.csv.gz", "HADM_ID", hadm_ids) + + print("Filtering PATIENTS...") + filter_csv(patients_src, output_root / "PATIENTS.csv.gz", "SUBJECT_ID", subject_ids) + + print("Filtering NOTEEVENTS...") + filter_csv( + noteevents_src, + output_root / "NOTEEVENTS.csv.gz", + "HADM_ID", + hadm_ids, + chunksize=args.note_chunksize, + ) + + print("Subset written to:", output_root) + + +if __name__ == "__main__": + main() From 3a56cfd5102e5b601657f40c72c3b90538cb1d84 Mon Sep 17 00:00:00 2001 From: khancepts101 Date: Sat, 10 Jan 2026 22:23:58 -0500 Subject: [PATCH 2/7] Refine SDOH eval and labels --- examples/sdoh_icd9_llm_eval.py | 27 ++++++++++++++------------- pyhealth/models/sdoh_icd9_llm.py | 8 +++----- pyhealth/tasks/sdoh_utils.py | 12 ++++++++---- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/examples/sdoh_icd9_llm_eval.py b/examples/sdoh_icd9_llm_eval.py index c9f390768..81b84bd4d 100644 --- a/examples/sdoh_icd9_llm_eval.py +++ b/examples/sdoh_icd9_llm_eval.py @@ -2,13 +2,14 @@ import json import os from datetime import datetime -from typing import Iterable, List, Sequence, Set +from typing import List, Set import numpy as np from pyhealth.datasets import MIMIC3NotesDataset from pyhealth.metrics import multilabel_metrics_fn from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM from pyhealth.tasks.sdoh_icd9_detection 
import TARGET_CODES +from pyhealth.tasks.sdoh_utils import codes_to_multihot def parse_args(): @@ -28,15 +29,6 @@ def parse_args(): return parser.parse_args() -def codes_to_multihot(codes_list: Iterable[Set[str]], target_codes: Sequence[str]) -> np.ndarray: - target_set = [code.upper() for code in target_codes] - rows: List[List[int]] = [] - for codes in codes_list: - code_set = {code.upper() for code in codes} - rows.append([1 if code in code_set else 0 for code in target_set]) - return np.array(rows, dtype=np.float32) - - def main(): args = parse_args() target_codes = list(TARGET_CODES) @@ -81,9 +73,18 @@ def main(): } ) - y_pred = codes_to_multihot(predicted_codes_all, target_codes) - y_manual = codes_to_multihot(manual_codes_all, target_codes) - y_true = codes_to_multihot(true_codes_all, target_codes) + y_pred = np.stack( + [codes_to_multihot(codes, target_codes).numpy() for codes in predicted_codes_all], + axis=0, + ) + y_manual = np.stack( + [codes_to_multihot(codes, target_codes).numpy() for codes in manual_codes_all], + axis=0, + ) + y_true = np.stack( + [codes_to_multihot(codes, target_codes).numpy() for codes in true_codes_all], + axis=0, + ) metrics_list = [ "accuracy", diff --git a/pyhealth/models/sdoh_icd9_llm.py b/pyhealth/models/sdoh_icd9_llm.py index f22ce29f0..a9f956083 100644 --- a/pyhealth/models/sdoh_icd9_llm.py +++ b/pyhealth/models/sdoh_icd9_llm.py @@ -182,16 +182,14 @@ def forward( categories_batch = note_categories or [None] * len(notes_batch) dates_batch = chartdates or [None] * len(notes_batch) - batch_probs: List[List[int]] = [] + batch_probs: List[torch.Tensor] = [] for note_list, cats, dates in zip( notes_batch, categories_batch, dates_batch ): predicted, _ = self._predict_admission(note_list, cats, dates) - batch_probs.append( - codes_to_multihot(predicted, self.target_codes) - ) + batch_probs.append(codes_to_multihot(predicted, self.target_codes)) - y_prob = torch.tensor(batch_probs, dtype=torch.float32) + y_prob = 
torch.stack(batch_probs, dim=0) if label is not None and isinstance(label, torch.Tensor): y_prob = y_prob.to(label.device) y_true = label diff --git a/pyhealth/tasks/sdoh_utils.py b/pyhealth/tasks/sdoh_utils.py index a79721d4d..d57773d3f 100644 --- a/pyhealth/tasks/sdoh_utils.py +++ b/pyhealth/tasks/sdoh_utils.py @@ -1,9 +1,10 @@ """Utilities for SDOH ICD-9 V-code detection tasks.""" import logging -from typing import Iterable, List, Sequence, Set +from typing import Iterable, Sequence, Set import pandas as pd +import torch logger = logging.getLogger(__name__) @@ -58,7 +59,7 @@ def parse_codes(codes_str: object, target_codes: Sequence[str]) -> Set[str]: return {code for code in parsed if code in target_set} -def codes_to_multihot(codes: Iterable[str], target_codes: Sequence[str]) -> List[int]: +def codes_to_multihot(codes: Iterable[str], target_codes: Sequence[str]) -> torch.Tensor: """Convert code set to multi-hot encoding. Args: @@ -66,7 +67,10 @@ def codes_to_multihot(codes: Iterable[str], target_codes: Sequence[str]) -> List target_codes: Ordered list of target codes Returns: - Multi-hot binary list aligned with target_codes + Multi-hot tensor aligned with target_codes """ code_set = {code.upper() for code in codes} - return [1 if code in code_set else 0 for code in target_codes] + return torch.tensor( + [1.0 if code in code_set else 0.0 for code in target_codes], + dtype=torch.float32, + ) From a56220a80b114c2e407c4f530cebaf53e49eecd0 Mon Sep 17 00:00:00 2001 From: khancepts101 Date: Sat, 17 Jan 2026 22:27:48 -0500 Subject: [PATCH 3/7] Add SDOH ICD-9 note evaluation pipeline --- examples/sdoh_icd9_llm_eval.py | 47 +++++++- pyhealth/datasets/mimic3_notes.py | 31 +++-- pyhealth/datasets/sdoh_icd9.py | 14 ++- pyhealth/models/sdoh_icd9_llm.py | 104 ++++------------- pyhealth/models/sdoh_icd9_task.txt | 158 ++++++++++++++++++++++++++ pyhealth/tasks/sdoh_icd9_detection.py | 113 +----------------- pyhealth/tasks/sdoh_utils.py | 22 +++- tools/mimic3_sdoh_subset.py | 
79 ------------- 8 files changed, 287 insertions(+), 281 deletions(-) create mode 100644 pyhealth/models/sdoh_icd9_task.txt delete mode 100644 tools/mimic3_sdoh_subset.py diff --git a/examples/sdoh_icd9_llm_eval.py b/examples/sdoh_icd9_llm_eval.py index 81b84bd4d..8162b1be6 100644 --- a/examples/sdoh_icd9_llm_eval.py +++ b/examples/sdoh_icd9_llm_eval.py @@ -24,6 +24,20 @@ def parse_args(): choices=["manual", "true"], help="Which labels to use as primary ground truth.", ) + parser.add_argument( + "--max-notes", + default="all", + help="Limit notes per admission (e.g., 1, 2, 5, or 'all').", + ) + parser.add_argument( + "--max-admissions", + default="all", + help="Limit admissions to process (e.g., 5 or 'all').", + ) + parser.add_argument( + "--note-categories", + help="Comma-separated NOTE_CATEGORY values to include (optional).", + ) parser.add_argument("--output-dir", default=".", help="Directory to save outputs.") parser.add_argument("--dry-run", action="store_true") return parser.parse_args() @@ -33,16 +47,47 @@ def main(): args = parse_args() target_codes = list(TARGET_CODES) + include_categories = ( + [cat.strip() for cat in args.note_categories.split(",")] + if args.note_categories + else None + ) + if str(args.max_notes).lower() == "all": + max_notes = None + else: + try: + max_notes = int(args.max_notes) + except ValueError as exc: + raise ValueError("--max-notes must be an integer or 'all'") from exc + if max_notes <= 0: + raise ValueError("--max-notes must be a positive integer or 'all'") + if str(args.max_admissions).lower() == "all": + max_admissions = None + else: + try: + max_admissions = int(args.max_admissions) + except ValueError as exc: + raise ValueError("--max-admissions must be an integer or 'all'") from exc + if max_admissions <= 0: + raise ValueError("--max-admissions must be a positive integer or 'all'") + noteevents_path = f"{args.mimic_root}/NOTEEVENTS.csv.gz" note_dataset = MIMIC3NotesDataset( noteevents_path=noteevents_path, 
label_csv_path=args.label_csv_path, target_codes=target_codes, + include_categories=include_categories, ) sample_dataset = note_dataset.set_task(label_source=args.label_source) + if max_admissions is not None: + sample_dataset = sample_dataset.subset(slice(0, max_admissions)) dry_run = args.dry_run or not os.environ.get("OPENAI_API_KEY") - model = SDOHICD9LLM(target_codes=target_codes, dry_run=dry_run) + model = SDOHICD9LLM( + target_codes=target_codes, + dry_run=dry_run, + max_notes=max_notes, + ) results = [] predicted_codes_all: List[Set[str]] = [] diff --git a/pyhealth/datasets/mimic3_notes.py b/pyhealth/datasets/mimic3_notes.py index 1e1584cdc..c0c223486 100644 --- a/pyhealth/datasets/mimic3_notes.py +++ b/pyhealth/datasets/mimic3_notes.py @@ -4,8 +4,8 @@ import pandas as pd from .sample_dataset import SampleDataset, create_sample_dataset -from ..tasks.sdoh_icd9_detection import SDOHICD9AdmissionTask, load_sdoh_icd9_labels -from ..tasks.sdoh_utils import TARGET_CODES +from ..tasks.sdoh_icd9_detection import SDOHICD9AdmissionTask +from ..tasks.sdoh_utils import TARGET_CODES, load_sdoh_icd9_labels logger = logging.getLogger(__name__) @@ -19,16 +19,21 @@ def __init__( label_csv_path: str, target_codes: Optional[Sequence[str]] = None, hadm_ids: Optional[Iterable[str]] = None, + include_categories: Optional[Sequence[str]] = None, chunksize: int = 200_000, dataset_name: Optional[str] = None, ) -> None: self.noteevents_path = noteevents_path - self.label_csv_path = label_csv_path self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) self.label_map = load_sdoh_icd9_labels(label_csv_path, self.target_codes) if hadm_ids is None: hadm_ids = self.label_map.keys() self.hadm_ids = {str(x) for x in hadm_ids} + self.include_categories = ( + {cat.strip().upper() for cat in include_categories} + if include_categories + else None + ) self.chunksize = chunksize self.dataset_name = dataset_name or "mimic3_notes" @@ -54,6 +59,14 @@ def _load_notes(self) -> 
Dict[str, List[Dict]]: filtered = chunk[chunk["HADM_ID"].astype("string").isin(self.hadm_ids)] if filtered.empty: continue + if self.include_categories is not None: + filtered = filtered[ + filtered["CATEGORY"].astype("string") + .str.upper() + .isin(self.include_categories) + ] + if filtered.empty: + continue charttime = pd.to_datetime(filtered["CHARTTIME"], errors="coerce") chartdate = pd.to_datetime(filtered["CHARTDATE"], errors="coerce") @@ -62,9 +75,9 @@ def _load_notes(self) -> Dict[str, List[Dict]]: for row, ts in zip(filtered.itertuples(index=False), timestamp): hadm_id = str(row.HADM_ID) entry = { - "patient_id": str(row.SUBJECT_ID) if row.SUBJECT_ID is not pd.NA else "", - "text": row.TEXT if row.TEXT is not pd.NA else "", - "category": row.CATEGORY if row.CATEGORY is not pd.NA else "", + "patient_id": str(row.SUBJECT_ID) if pd.notna(row.SUBJECT_ID) else "", + "text": row.TEXT if pd.notna(row.TEXT) else "", + "category": row.CATEGORY if pd.notna(row.CATEGORY) else "", "timestamp": ts, } notes_by_hadm.setdefault(hadm_id, []).append(entry) @@ -77,7 +90,11 @@ def _build_admissions(self) -> List[Dict]: for hadm_id, notes in notes_by_hadm.items(): if hadm_id not in self.label_map: continue - notes.sort(key=lambda x: x["timestamp"] if x["timestamp"] is not pd.NaT else pd.Timestamp.min) + notes.sort( + key=lambda x: x["timestamp"] + if pd.notna(x["timestamp"]) + else pd.Timestamp.min + ) note_texts = [str(n["text"]) for n in notes] note_categories = [str(n["category"]) for n in notes] chartdates = [ diff --git a/pyhealth/datasets/sdoh_icd9.py b/pyhealth/datasets/sdoh_icd9.py index 49def92d4..e12a695a2 100644 --- a/pyhealth/datasets/sdoh_icd9.py +++ b/pyhealth/datasets/sdoh_icd9.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Sequence, Set +from typing import Dict, List, Optional, Sequence import pandas as pd @@ -29,10 +29,16 @@ def __init__( csv_path: str, dataset_name: Optional[str] = None, target_codes: Optional[Sequence[str]] = 
None, + include_categories: Optional[Sequence[str]] = None, ) -> None: self.csv_path = csv_path self.dataset_name = dataset_name or "sdoh_icd9_notes" self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) + self.include_categories = ( + {cat.strip().upper() for cat in include_categories} + if include_categories + else None + ) self._admissions = self._load_admissions() def _load_admissions(self) -> List[Dict]: @@ -43,6 +49,12 @@ def _load_admissions(self) -> List[Dict]: admissions: List[Dict] = [] df["CHARTDATE"] = pd.to_datetime(df["CHARTDATE"], errors="coerce") + if self.include_categories is not None: + df = df[ + df["NOTE_CATEGORY"].astype("string") + .str.upper() + .isin(self.include_categories) + ] for hadm_id, group in df.groupby("HADM_ID"): group = group.sort_values("CHARTDATE") first = group.iloc[0] diff --git a/pyhealth/models/sdoh_icd9_llm.py b/pyhealth/models/sdoh_icd9_llm.py index a9f956083..71a6bda6d 100644 --- a/pyhealth/models/sdoh_icd9_llm.py +++ b/pyhealth/models/sdoh_icd9_llm.py @@ -1,30 +1,23 @@ -import logging import os +import hashlib import re import time from typing import Iterable, List, Optional, Sequence, Set, Tuple -import torch -from torch import nn +from pyhealth.tasks.sdoh_utils import TARGET_CODES -from pyhealth.tasks.sdoh_utils import TARGET_CODES, codes_to_multihot -logger = logging.getLogger(__name__) +PROMPT_PATH = os.path.join(os.path.dirname(__file__), "sdoh_icd9_task.txt") -PROMPT_TEMPLATE = """\ -You are an assistant that extracts SDOH ICD-9 V-codes from clinical notes. -Return only the codes, comma-separated, inside triple backticks. -If no target codes are present, return None inside triple backticks. 
-Target codes: {codes} -""" +def _load_prompt_template() -> str: + with open(PROMPT_PATH, "r", encoding="utf-8") as f: + return f.read() -class SDOHICD9LLM(nn.Module): +class SDOHICD9LLM: """Admission-level SDOH ICD-9 V-code detector using an LLM.""" - mode = "multilabel" - def __init__( self, target_codes: Optional[Sequence[str]] = None, @@ -35,17 +28,18 @@ def __init__( max_chars: int = 100000, temperature: float = 0.0, sleep_s: float = 0.2, + max_notes: Optional[int] = None, dry_run: bool = False, ) -> None: - super().__init__() self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) self.model_name = model_name - self.prompt_template = prompt_template or PROMPT_TEMPLATE + self.prompt_template = prompt_template or _load_prompt_template() self.api_key = api_key or os.environ.get("OPENAI_API_KEY") self.max_tokens = max_tokens self.max_chars = max_chars self.temperature = temperature self.sleep_s = sleep_s + self.max_notes = max_notes self.dry_run = dry_run self._client = None @@ -54,12 +48,6 @@ def __init__( "OPENAI_API_KEY is required unless dry_run=True." 
) - mode = "dry-run" if dry_run else "live" - logger.info( - "Initialized SDOHICD9LLM (mode=%s, model=%s, codes=%d)", - mode, model_name, len(self.target_codes) - ) - def _get_client(self): if self._client is None: from openai import OpenAI @@ -68,6 +56,8 @@ def _get_client(self): return self._client def _call_openai_api(self, text: str) -> str: + self._write_prompt_preview(text) + if self.dry_run: return "```None```" @@ -78,25 +68,20 @@ def _call_openai_api(self, text: str) -> str: response = client.chat.completions.create( model=self.model_name, messages=[ - { - "role": "system", - "content": self.prompt_template.format( - codes=", ".join(self.target_codes) - ), - }, - { - "role": "user", - "content": ( - "Analyze this clinical note and identify SDOH codes:\n\n" - f"{text}" - ), - }, + {"role": "system", "content": self.prompt_template.format(note=text)}, ], max_tokens=self.max_tokens, temperature=self.temperature, ) return response.choices[0].message.content.strip() + def _write_prompt_preview(self, text: str) -> None: + prompt = self.prompt_template.format(note=text) + digest = hashlib.sha1(prompt.encode("utf-8")).hexdigest()[:10] + filename = f"sdoh_prompt_{digest}.txt" + with open(filename, "w", encoding="utf-8") as f: + f.write(prompt) + def _parse_llm_response(self, response: str) -> Set[str]: if not response: return set() @@ -130,26 +115,21 @@ def _predict_admission( ) -> Tuple[Set[str], List[dict]]: aggregated: Set[str] = set() note_results: List[dict] = [] - categories = list(note_categories) if note_categories is not None else [] dates = list(chartdates) if chartdates is not None else [] notes_list = list(notes) - - logger.debug("Processing admission with %d notes", len(notes_list)) + if self.max_notes and self.max_notes > 0: + notes_list = notes_list[: self.max_notes] + categories = categories[: self.max_notes] + dates = dates[: self.max_notes] for idx, note in enumerate(notes_list): category = categories[idx] if idx < len(categories) else "Unknown" 
date = dates[idx] if idx < len(dates) else "Unknown" - response = self._call_openai_api(note) predicted = self._parse_llm_response(response) aggregated.update(predicted) - logger.debug( - "Note %d/%d (%s, %s): predicted %s", - idx + 1, len(notes_list), category, date, sorted(predicted) or "none" - ) - note_results.append( { "category": category, @@ -158,47 +138,11 @@ def _predict_admission( "llm_response": response, } ) - if self.sleep_s > 0 and not self.dry_run: time.sleep(self.sleep_s) - logger.debug("Admission complete: aggregated codes %s", sorted(aggregated)) return aggregated, note_results - def forward( - self, - notes, - note_categories=None, - chartdates=None, - label=None, - **kwargs, - ): - if notes and isinstance(notes[0], str): - notes_batch = [notes] - categories_batch = [note_categories] if note_categories is not None else [None] - dates_batch = [chartdates] if chartdates is not None else [None] - else: - notes_batch = notes - categories_batch = note_categories or [None] * len(notes_batch) - dates_batch = chartdates or [None] * len(notes_batch) - - batch_probs: List[torch.Tensor] = [] - for note_list, cats, dates in zip( - notes_batch, categories_batch, dates_batch - ): - predicted, _ = self._predict_admission(note_list, cats, dates) - batch_probs.append(codes_to_multihot(predicted, self.target_codes)) - - y_prob = torch.stack(batch_probs, dim=0) - if label is not None and isinstance(label, torch.Tensor): - y_prob = y_prob.to(label.device) - y_true = label - else: - y_true = label - - loss = torch.zeros(1, device=y_prob.device).sum() - return {"loss": loss, "y_prob": y_prob, "y_true": y_true} - def predict_admission_with_notes( self, notes: Iterable[str], diff --git a/pyhealth/models/sdoh_icd9_task.txt b/pyhealth/models/sdoh_icd9_task.txt new file mode 100644 index 000000000..ed2f847ef --- /dev/null +++ b/pyhealth/models/sdoh_icd9_task.txt @@ -0,0 +1,158 @@ +# SDOH Medical Coding Task + +You are an expert medical coder. 
Your job is to find Social Determinants of Health (SDOH) codes in clinical notes. + +## CRITICAL PREDICTION RULES: +1. **IF YOU FIND EVIDENCE FOR A CODE → YOU MUST PREDICT THAT CODE** +2. **BE CONSERVATIVE**: Only predict codes with STRONG evidence +3. **Evidence detection = Automatic code prediction** + +## CODE HIERARCHY - ABSOLUTELY CRITICAL: +**HOMELESS (V600) COMPLETELY EXCLUDES V602:** +- ❌ If you see homelessness → ONLY predict V600, NEVER V602 +- ❌ "Homeless and can't afford X" → STILL only V600 +- ✅ V602 is EXCLUSIVELY for housed people with explicit financial hardship + +**V602 ULTRA-STRICT REQUIREMENTS:** +- ✅ Must have exact quotes: "cannot afford medications", "no money for food" +- ❌ NEVER infer from: homelessness, unemployment, substance use, mental health +- ❌ NEVER predict from: "poor", "disadvantaged", social circumstances + + +## Available ICD-9 Codes: + +**Housing & Resources:** +- V600: Homeless (living on streets, shelters, hotels, motels, temporary housing) +- V602: Cannot afford basic needs. **ULTRA-STRICT RULES**: + - ✅ ONLY if EXPLICIT financial statements: "cannot afford medications", "unable to pay for treatment", "no money for food", "financial hardship preventing care" + - ❌ NEVER predict from social circumstances: substance abuse, unemployment, mental health, divorce + - ❌ NEVER predict from discharge inventory: "no money/wallet returned" + - ❌ NEVER predict for insurance/benefits mentions: "on disability", "has Medicaid" +- V604: No family/caregiver available. 
**CRITICAL RULES**: + - ❌ NOT just "lives alone" or "elderly" + - ✅ ONLY if: "no family contact" AND "no support" AND "no one to help" + - ❌ NOT for: "lives alone but daughter visits", "independent" + +**Employment & Legal:** +- V620: Unemployed (current employment status) — **EXPLICIT ONLY** + - ✅ **Predict only if one of these exact employment-status phrases appears (word boundaries, case-insensitive):** + - unemploy(ed|ment) + - jobless | no job + - out of work | without work (employment context only; see exclusions) + - not working (employment context only; see exclusions) + - between jobs + - on unemployment | receiving unemployment + - Recent loss: laid off | fired | terminated | lost (my|his|her|their)? job + - Label–value fields: Employment|Employment status|Work: → unemployed|none|not working ⇒ V620 + - ❌ Do NOT infer from former-only phrases: used to work | formerly | previously employed | last worked in ... UNLESS illness/disability is given as the reason for stopping work + - ❌ **Contradictions (block V620 if present nearby, ±1 sentence)**: employed | works at/as | working | return(ed) to work | full-time/part-time | self-employed | contractor | on leave/LOA | work excuse | at work today + - ❌ **Non-employment uses of "work" (never trigger)**: work of breathing | workup/work-up | social work | line not working | device not working | therapy not working | meds not working + - ❌ **Context rule for "not/without work"**: Only treat as unemployment when it clearly refers to employment status (e.g., in Social History/Employment section or followed by "due to X" re job). Otherwise, do not code + - ❌ **Exclusions**: retired | student | stay-at-home parent | on medical leave — unless an explicit unemployment phrase above is present + - ✅ Include cases where patient stopped working or is unable to work due to illness/disability (including SSDI/SSI). +- V625: Legal problems. 
**ACTIVE LEGAL PROCESSES ONLY**: + - ✅ Criminal: Arrest, jail, prison, parole, probation, active charges, bail, or court case pending + - ✅ Civil: Court-ordered custody, restraining order, supervised release, or active litigation + - ✅ Guardianship: Legal capacity hearing, court-appointed guardian, or power of attorney via court + - ✅ Child welfare: DCF/CPS removed children, placed them in care, or court-ordered custody change + - ❌ Do NOT predict V625 for: + - History of crime, substance use, or homelessness without active legal process + - Drug use history alone (marijuana, cocaine, etc.) without legal consequences + - Psych history, social issues, or missed appointments without legal involvement + - DCF "awareness" without custody change or legal action + +**Health History:** +- V1541: Physical/sexual abuse (violence BY another person). **CRITICAL**: + - ✅ Physical violence, sexual abuse, assault, domestic violence, rape + - ❌ NOT accidents, falls, fights where patient was aggressor +- V1542: Pure emotional/psychological abuse. **MUTUALLY EXCLUSIVE WITH V1541**: + - ✅ **ONLY if NO physical/sexual abuse mentioned AND explicit emotional abuse:** + - ✅ Witnessed violence: "witnessed violence as a child", "saw domestic violence" + - ✅ Verbal abuse: "verbal abuse", "emotionally abusive", "psychological abuse" + - ✅ Emotional manipulation: "jealous and controlling", "isolation", "intimidation" + - ❌ **NEVER predict V1542 if ANY physical/sexual abuse mentioned (use V1541 instead)** + - ❌ **NEVER predict both V1541 AND V1542 for same patient** + - ❌ Depression, anxiety, suicidal ideation alone without explicit abuse = NO CODE + - ❌ "History of abuse" without specifying type = NO CODE + - ❌ Psychiatric history alone = NO CODE +- V1584: Past asbestos exposure + +**Family History:** +- V6141: Family alcoholism (family member drinks) — PREDICT if any kinship term + alcohol term appear together. 
+ - ✅ Kinship terms: father, mother, dad, mom, brother, sister, son, daughter, uncle, aunt, cousin, grand*; Headers: FH, FHx, Family History + - ✅ Alcohol terms (case-insensitive, synonyms OK): ETOH/EtOH/etoh, alcohol, alcoholism, AUD, alcohol use disorder, EtOH abuse, EtOH dependence, "alcohol problem(s)", "drinks heavily", "alcoholic" + - ✅ Mixed substance OK: If text says "drug and alcohol problems," still predict V6141 + - ✅ Outside headers OK: If kinship + alcohol appear in same clause/sentence or within ±1 line (e.g., "Pt's father ... has hx of etoh"), predict V6141 + - ✅ Examples to capture: "FH: Father – ETOH", "Mother has h/o alcoholism", "Father with depression and alcoholism", "Multiple family members with ETOH abuse ... (cousin, sister, uncle, aunt, father)", "Both brothers have drug and alcohol problems" + - ❌ Negations: Do not predict if explicitly denied (e.g., "denies family history of alcoholism") + - ❌ NEVER for PATIENT'S own history: "history of alcohol abuse", "with a history of alcohol use", "past medical history significant for heavy EtOH abuse", "patient alcoholic", "ETOH abuse" + +## ENHANCED NEGATIVE EXAMPLES: + +**V602 FALSE POSITIVES TO AVOID:** +❌ "Homeless patient" → Predict V600 ONLY, NEVER V602 +❌ "Lives in shelter, gets food stamps" → V600 ONLY, NEVER V602 +❌ "Homeless, on disability" → V600 ONLY, NEVER V602 +❌ "No permanent address, has Medicaid" → V600 ONLY, NEVER V602 +❌ "Homeless and can't afford medications" → V600 ONLY, NEVER V602 +❌ "Unemployed alcoholic" → V620 (unemployment is explicit), NEVER V602 +❌ "Lives in poverty" → NEVER V602 (too vague) +❌ "Financial strain from divorce" → NEVER V602 (circumstantial) + +**V604 FALSE POSITIVES TO AVOID:** +❌ "82 year old lives alone" → NO CODE unless no support mentioned +❌ "Lives by herself" → NO CODE unless isolation confirmed +❌ "Widowed, lives alone, son calls daily" → NO CODE (has support) + +**V1542 FALSE POSITIVES TO AVOID:** +❌ "History of physical and sexual abuse" → V1541 
ONLY (physical trumps emotional) +❌ "PTSD from rape at age 7" → V1541 ONLY (sexual abuse) +❌ "Childhood sexual abuse by uncle" → V1541 ONLY (sexual abuse) +❌ "History of domestic abuse" → V1541 ONLY (physical abuse) +❌ "Depression and anxiety" → NO CODE (psychiatric symptoms alone) +❌ "Suicide attempts" → NO CODE (mental health history alone) +❌ "History of abuse" → NO CODE (unspecified type) +❌ "Recent argument with partner" → NO CODE (relationship conflict) + +**V1542 TRUE POSITIVES TO CAPTURE:** +✅ "Witnessed violence as a child" → V1542 (pure emotional trauma, no physical) +✅ "Emotionally abusive relationship for 14 years" → V1542 (explicit emotional abuse) +✅ "Verbal abuse from controlling partner" → V1542 (explicit emotional abuse) +✅ "Jealous and controlling behavior" → V1542 (emotional manipulation) + +## CONFIDENCE RULES: + +**HIGH CONFIDENCE (Predict):** +- Direct statement of condition +- Multiple supporting evidence pieces +- Explicit language matching code definition + +**LOW CONFIDENCE (Don't Predict):** +- Ambiguous language +- Single weak indicator +- Contradictory evidence + +## Key Rules: + +1. **Precision over Recall**: Better to miss a code than falsely predict +2. **Evidence-Driven**: Strong evidence required for prediction +3. **Multiple codes allowed**: But each needs independent evidence +4. **Conservative approach**: When in doubt, don't predict + +## Output Format: +Return applicable codes separated by commas, or "None" if no codes apply. 
+ +Example: +``` +V600, V625 +``` + +or if no codes apply: +``` +None +``` + +--- + +**Clinical Note to Analyze:** +{note} \ No newline at end of file diff --git a/pyhealth/tasks/sdoh_icd9_detection.py b/pyhealth/tasks/sdoh_icd9_detection.py index 42ce722ea..978307d02 100644 --- a/pyhealth/tasks/sdoh_icd9_detection.py +++ b/pyhealth/tasks/sdoh_icd9_detection.py @@ -1,14 +1,9 @@ -import logging from typing import Dict, List, Optional, Sequence, Set -import pandas as pd import torch -from ..data import Event, Patient from .base_task import BaseTask -from .sdoh_utils import TARGET_CODES, codes_to_multihot, parse_codes - -logger = logging.getLogger(__name__) +from .sdoh_utils import TARGET_CODES, codes_to_multihot class SDOHICD9AdmissionTask(BaseTask): @@ -42,107 +37,7 @@ def __call__(self, admission: Dict) -> List[Dict]: else: label_codes = admission.get("true_codes", set()) - sample = { - "visit_id": admission["visit_id"], - "patient_id": admission["patient_id"], - "notes": admission["notes"], - "note_categories": admission["note_categories"], - "chartdates": admission["chartdates"], - "num_notes": admission.get("num_notes", len(admission["notes"])), - "text_length": admission.get("text_length", 0), - "is_gap_case": admission.get("is_gap_case"), - "manual_codes": admission.get("manual_codes", set()), - "true_codes": admission.get("true_codes", set()), - "label_codes": sorted(label_codes), - "label": codes_to_multihot(label_codes, self.target_codes), - } + sample = dict(admission) + sample["label_codes"] = sorted(label_codes) + sample["label"] = codes_to_multihot(label_codes, self.target_codes) return [sample] - - -def load_sdoh_icd9_labels( - csv_path: str, target_codes: Sequence[str] -) -> Dict[str, Dict[str, Set[str]]]: - df = pd.read_csv(csv_path) - if "HADM_ID" not in df.columns: - raise ValueError("CSV must include HADM_ID column.") - - labels: Dict[str, Dict[str, Set[str]]] = {} - for hadm_id, group in df.groupby("HADM_ID"): - first = group.iloc[0] - 
labels[str(hadm_id)] = { - "manual": parse_codes( - first.get("ADMISSION_MANUAL_LABELS"), target_codes - ), - "true": parse_codes( - first.get("ADMISSION_TRUE_CODES"), target_codes - ), - } - return labels - - -class SDOHICD9MIMIC3NoteTask(BaseTask): - """Builds admission-level samples from MIMIC-III noteevents with CSV labels.""" - - task_name: str = "SDOHICD9MIMIC3Notes" - input_schema: Dict[str, str] = { - "notes": "raw", - "note_categories": "raw", - "chartdates": "raw", - "patient_id": "raw", - "visit_id": "raw", - } - output_schema: Dict[str, object] = { - "label": ("tensor", {"dtype": torch.float32}), - } - - def __init__( - self, - label_csv_path: str, - target_codes: Optional[Sequence[str]] = None, - label_source: str = "manual", - ) -> None: - self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) - if label_source not in {"manual", "true"}: - raise ValueError("label_source must be 'manual' or 'true'") - self.label_source = label_source - self.label_map = load_sdoh_icd9_labels(label_csv_path, self.target_codes) - - def __call__(self, patient: Patient) -> List[Dict]: - notes: List[Event] = patient.get_events(event_type="noteevents") - if not notes: - return [] - - by_hadm: Dict[str, List[Event]] = {} - for event in notes: - hadm_id = str(event.hadm_id) - if hadm_id not in self.label_map: - continue - by_hadm.setdefault(hadm_id, []).append(event) - - samples: List[Dict] = [] - for hadm_id, events in by_hadm.items(): - events.sort(key=lambda e: e.timestamp or "") - note_texts = [str(e.text) if e.text is not None else "" for e in events] - note_categories = [str(e.category) if e.category is not None else "" for e in events] - chartdates = [ - e.timestamp.strftime("%Y-%m-%d") if e.timestamp is not None else "Unknown" - for e in events - ] - - label_codes = self.label_map[hadm_id][self.label_source] - sample = { - "visit_id": hadm_id, - "patient_id": patient.patient_id, - "notes": note_texts, - "note_categories": note_categories, - 
"chartdates": chartdates, - "num_notes": len(note_texts), - "text_length": int(sum(len(note) for note in note_texts)), - "manual_codes": self.label_map[hadm_id]["manual"], - "true_codes": self.label_map[hadm_id]["true"], - "label_codes": sorted(label_codes), - "label": codes_to_multihot(label_codes, self.target_codes), - } - samples.append(sample) - - return samples diff --git a/pyhealth/tasks/sdoh_utils.py b/pyhealth/tasks/sdoh_utils.py index d57773d3f..78e81bd57 100644 --- a/pyhealth/tasks/sdoh_utils.py +++ b/pyhealth/tasks/sdoh_utils.py @@ -1,13 +1,10 @@ """Utilities for SDOH ICD-9 V-code detection tasks.""" -import logging -from typing import Iterable, Sequence, Set +from typing import Dict, Iterable, Sequence, Set import pandas as pd import torch -logger = logging.getLogger(__name__) - # Standard SDOH ICD-9 V-codes TARGET_CODES = [ @@ -74,3 +71,20 @@ def codes_to_multihot(codes: Iterable[str], target_codes: Sequence[str]) -> torc [1.0 if code in code_set else 0.0 for code in target_codes], dtype=torch.float32, ) + + +def load_sdoh_icd9_labels( + csv_path: str, target_codes: Sequence[str] +) -> Dict[str, Dict[str, Set[str]]]: + df = pd.read_csv(csv_path) + if "HADM_ID" not in df.columns: + raise ValueError("CSV must include HADM_ID column.") + + labels: Dict[str, Dict[str, Set[str]]] = {} + for hadm_id, group in df.groupby("HADM_ID"): + first = group.iloc[0] + labels[str(hadm_id)] = { + "manual": parse_codes(first.get("ADMISSION_MANUAL_LABELS"), target_codes), + "true": parse_codes(first.get("ADMISSION_TRUE_CODES"), target_codes), + } + return labels diff --git a/tools/mimic3_sdoh_subset.py b/tools/mimic3_sdoh_subset.py deleted file mode 100644 index 560169475..000000000 --- a/tools/mimic3_sdoh_subset.py +++ /dev/null @@ -1,79 +0,0 @@ -import argparse -import gzip -from pathlib import Path - -import pandas as pd - - -def write_gzip_csv(df: pd.DataFrame, path: Path, header: bool) -> None: - mode = "wt" if header else "at" - with gzip.open(path, mode) as f: - 
df.to_csv(f, index=False, header=header) - - -def filter_csv( - source: Path, - dest: Path, - column: str, - allowed: set, - chunksize: int = 200_000, -) -> None: - header_written = False - for chunk in pd.read_csv(source, chunksize=chunksize): - filtered = chunk[chunk[column].astype(str).isin(allowed)] - if filtered.empty: - continue - write_gzip_csv(filtered, dest, header=not header_written) - header_written = True - if not header_written: - write_gzip_csv(pd.DataFrame(columns=pd.read_csv(source, nrows=0).columns), dest, header=True) - - -def main() -> None: - parser = argparse.ArgumentParser(description="Build a small MIMIC-III subset for SDOH evaluation.") - parser.add_argument("--mimic-root", required=True, help="Path to full MIMIC-III root.") - parser.add_argument("--label-csv-path", required=True, help="Path to sdoh_icd9_dataset.csv.") - parser.add_argument("--output-root", required=True, help="Output folder for the subset.") - parser.add_argument("--note-chunksize", type=int, default=200_000) - parser.add_argument("--max-hadm", type=int, default=0, help="Limit to first N HADM_IDs.") - args = parser.parse_args() - - mimic_root = Path(args.mimic_root) - output_root = Path(args.output_root) - output_root.mkdir(parents=True, exist_ok=True) - - labels = pd.read_csv(args.label_csv_path) - hadm_list = labels["HADM_ID"].astype(str).unique().tolist() - if args.max_hadm and args.max_hadm > 0: - hadm_list = hadm_list[: args.max_hadm] - hadm_ids = set(hadm_list) - subject_ids = set(labels["SUBJECT_ID"].astype(str).unique().tolist()) - - admissions_src = mimic_root / "ADMISSIONS.csv.gz" - icustays_src = mimic_root / "ICUSTAYS.csv.gz" - patients_src = mimic_root / "PATIENTS.csv.gz" - noteevents_src = mimic_root / "NOTEEVENTS.csv.gz" - - print("Filtering ADMISSIONS...") - filter_csv(admissions_src, output_root / "ADMISSIONS.csv.gz", "HADM_ID", hadm_ids) - - print("Filtering ICUSTAYS...") - filter_csv(icustays_src, output_root / "ICUSTAYS.csv.gz", "HADM_ID", hadm_ids) - - 
print("Filtering PATIENTS...") - filter_csv(patients_src, output_root / "PATIENTS.csv.gz", "SUBJECT_ID", subject_ids) - - print("Filtering NOTEEVENTS...") - filter_csv( - noteevents_src, - output_root / "NOTEEVENTS.csv.gz", - "HADM_ID", - hadm_ids, - chunksize=args.note_chunksize, - ) - - print("Subset written to:", output_root) - - -if __name__ == "__main__": - main() From 3f21b5c090813607d3289fc70b5e5cb8284335f3 Mon Sep 17 00:00:00 2001 From: khancepts101 Date: Sun, 18 Jan 2026 14:11:38 -0500 Subject: [PATCH 4/7] Refactor MIMIC3 note loader to chunked reader --- examples/sdoh_icd9_llm_eval.py | 40 ++++--- pyhealth/datasets/__init__.py | 3 +- pyhealth/datasets/mimic3.py | 141 ++++++++++++++++++++++++- pyhealth/datasets/mimic3_notes.py | 146 -------------------------- pyhealth/tasks/sdoh_icd9_detection.py | 8 ++ 5 files changed, 174 insertions(+), 164 deletions(-) delete mode 100644 pyhealth/datasets/mimic3_notes.py diff --git a/examples/sdoh_icd9_llm_eval.py b/examples/sdoh_icd9_llm_eval.py index 8162b1be6..aa896de3a 100644 --- a/examples/sdoh_icd9_llm_eval.py +++ b/examples/sdoh_icd9_llm_eval.py @@ -5,11 +5,11 @@ from typing import List, Set import numpy as np -from pyhealth.datasets import MIMIC3NotesDataset +from pyhealth.datasets import MIMIC3NoteDataset from pyhealth.metrics import multilabel_metrics_fn from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM from pyhealth.tasks.sdoh_icd9_detection import TARGET_CODES -from pyhealth.tasks.sdoh_utils import codes_to_multihot +from pyhealth.tasks.sdoh_utils import codes_to_multihot, load_sdoh_icd9_labels def parse_args(): @@ -46,6 +46,7 @@ def parse_args(): def main(): args = parse_args() target_codes = list(TARGET_CODES) + label_map = load_sdoh_icd9_labels(args.label_csv_path, target_codes) include_categories = ( [cat.strip() for cat in args.note_categories.split(",")] @@ -68,19 +69,24 @@ def main(): max_admissions = int(args.max_admissions) except ValueError as exc: raise ValueError("--max-admissions must be an 
integer or 'all'") from exc - if max_admissions <= 0: - raise ValueError("--max-admissions must be a positive integer or 'all'") + if max_admissions <= 0: + raise ValueError("--max-admissions must be a positive integer or 'all'") - noteevents_path = f"{args.mimic_root}/NOTEEVENTS.csv.gz" - note_dataset = MIMIC3NotesDataset( - noteevents_path=noteevents_path, - label_csv_path=args.label_csv_path, + hadm_ids = list(label_map.keys()) + if max_admissions is not None: + hadm_ids = hadm_ids[:max_admissions] + label_map = {hadm_id: label_map[hadm_id] for hadm_id in hadm_ids} + + note_dataset = MIMIC3NoteDataset( + root=args.mimic_root, target_codes=target_codes, + hadm_ids=hadm_ids, include_categories=include_categories, ) - sample_dataset = note_dataset.set_task(label_source=args.label_source) - if max_admissions is not None: - sample_dataset = sample_dataset.subset(slice(0, max_admissions)) + sample_dataset = note_dataset.set_task( + label_source=args.label_source, + label_map=label_map, + ) dry_run = args.dry_run or not os.environ.get("OPENAI_API_KEY") model = SDOHICD9LLM( @@ -101,8 +107,12 @@ def main(): sample.get("chartdates"), ) predicted_codes_all.append(predicted_codes) - manual_codes_all.append(set(sample.get("manual_codes", []))) - true_codes_all.append(set(sample.get("true_codes", []))) + visit_id = str(sample.get("visit_id", "")) + label_entry = label_map.get(visit_id, {"manual": set(), "true": set()}) + manual_codes = set(label_entry["manual"]) + true_codes = set(label_entry["true"]) + manual_codes_all.append(manual_codes) + true_codes_all.append(true_codes) results.append( { @@ -111,8 +121,8 @@ def main(): "num_notes": sample.get("num_notes"), "text_length": sample.get("text_length"), "is_gap_case": sample.get("is_gap_case"), - "manual_codes": ",".join(sorted(sample.get("manual_codes", []))), - "true_codes": ",".join(sorted(sample.get("true_codes", []))), + "manual_codes": ",".join(sorted(manual_codes)), + "true_codes": ",".join(sorted(true_codes)), 
"predicted_codes": ",".join(sorted(predicted_codes)), "note_results": json.dumps(note_results), } diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index dc493a3db..0eb6babca 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -57,8 +57,7 @@ def __init__(self, *args, **kwargs): from .eicu import eICUDataset from .isruc import ISRUCDataset from .medical_transcriptions import MedicalTranscriptionsDataset -from .mimic3 import MIMIC3Dataset -from .mimic3_notes import MIMIC3NotesDataset +from .mimic3 import MIMIC3Dataset, MIMIC3NoteDataset from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index 22ca79d5c..783e0e828 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -1,11 +1,15 @@ import logging import warnings from pathlib import Path -from typing import List, Optional +from typing import Dict, Iterable, List, Optional, Sequence, Set import narwhals as pl +import pandas as pd from .base_dataset import BaseDataset +from .sample_dataset import SampleDataset, create_sample_dataset +from ..tasks.sdoh_icd9_detection import SDOHICD9AdmissionTask +from ..tasks.sdoh_utils import TARGET_CODES logger = logging.getLogger(__name__) @@ -85,3 +89,138 @@ def preprocess_noteevents(self, df: pl.LazyFrame) -> pl.LazyFrame: .alias("charttime") ) return df + + +class MIMIC3NoteDataset: + """Note-only loader for MIMIC-III NOTEEVENTS.""" + + def __init__( + self, + noteevents_path: Optional[str] = None, + root: Optional[str] = None, + target_codes: Optional[Sequence[str]] = None, + hadm_ids: Optional[Iterable[str]] = None, + include_categories: Optional[Sequence[str]] = None, + chunksize: int = 200_000, + dataset_name: Optional[str] = None, + ) -> None: + if noteevents_path is None: + if root is None: + raise ValueError("root is 
required when noteevents_path is not set.") + noteevents_path = str(Path(root) / "NOTEEVENTS.csv.gz") + self.noteevents_path = noteevents_path + self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) + self.hadm_ids = {str(x) for x in hadm_ids} if hadm_ids is not None else None + self.include_categories = ( + {cat.strip().upper() for cat in include_categories} + if include_categories + else None + ) + self.chunksize = chunksize + self.dataset_name = dataset_name or "mimic3_note" + + def _load_notes(self) -> Dict[str, List[Dict]]: + keep_cols = { + "SUBJECT_ID", + "HADM_ID", + "CHARTDATE", + "CHARTTIME", + "CATEGORY", + "TEXT", + } + + notes_by_hadm: Dict[str, List[Dict]] = {} + for chunk in pd.read_csv( + self.noteevents_path, + chunksize=self.chunksize, + usecols=lambda c: c.upper() in keep_cols, + dtype={"SUBJECT_ID": "string", "HADM_ID": "string"}, + low_memory=False, + ): + chunk.columns = [c.upper() for c in chunk.columns] + if self.hadm_ids is not None: + chunk = chunk[chunk["HADM_ID"].astype("string").isin(self.hadm_ids)] + if chunk.empty: + continue + if self.include_categories is not None: + chunk = chunk[ + chunk["CATEGORY"].astype("string") + .str.upper() + .isin(self.include_categories) + ] + if chunk.empty: + continue + + charttime = pd.to_datetime(chunk["CHARTTIME"], errors="coerce") + chartdate = pd.to_datetime(chunk["CHARTDATE"], errors="coerce") + timestamp = charttime.fillna(chartdate) + + for row, ts in zip(chunk.itertuples(index=False), timestamp): + hadm_id = str(row.HADM_ID) + entry = { + "patient_id": str(row.SUBJECT_ID) if pd.notna(row.SUBJECT_ID) else "", + "text": row.TEXT if pd.notna(row.TEXT) else "", + "category": row.CATEGORY if pd.notna(row.CATEGORY) else "", + "timestamp": ts, + } + notes_by_hadm.setdefault(hadm_id, []).append(entry) + + return notes_by_hadm + + def _build_admissions(self) -> List[Dict]: + notes_by_hadm = self._load_notes() + admissions: List[Dict] = [] + for hadm_id, notes in 
notes_by_hadm.items(): + notes.sort( + key=lambda x: x["timestamp"] + if pd.notna(x["timestamp"]) + else pd.Timestamp.min + ) + note_texts = [str(n["text"]) for n in notes] + note_categories = [str(n["category"]) for n in notes] + chartdates = [ + n["timestamp"].strftime("%Y-%m-%d") if pd.notna(n["timestamp"]) else "Unknown" + for n in notes + ] + + admissions.append( + { + "visit_id": hadm_id, + "patient_id": notes[0]["patient_id"], + "notes": note_texts, + "note_categories": note_categories, + "chartdates": chartdates, + "num_notes": len(note_texts), + "text_length": int(sum(len(note) for note in note_texts)), + } + ) + + logger.info("Loaded %d admissions from NOTEEVENTS", len(admissions)) + return admissions + + def set_task( + self, + task: Optional[SDOHICD9AdmissionTask] = None, + label_source: str = "manual", + label_map: Optional[Dict[str, Dict[str, Set[str]]]] = None, + in_memory: bool = True, + ) -> SampleDataset: + if task is None: + task = SDOHICD9AdmissionTask( + target_codes=self.target_codes, + label_source=label_source, + label_map=label_map, + ) + + samples: List[Dict] = [] + for admission in self._build_admissions(): + samples.extend(task(admission)) + + return create_sample_dataset( + samples=samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name=self.dataset_name, + task_name=task.task_name, + in_memory=in_memory, + ) diff --git a/pyhealth/datasets/mimic3_notes.py b/pyhealth/datasets/mimic3_notes.py deleted file mode 100644 index c0c223486..000000000 --- a/pyhealth/datasets/mimic3_notes.py +++ /dev/null @@ -1,146 +0,0 @@ -import logging -from typing import Dict, Iterable, List, Optional, Sequence - -import pandas as pd - -from .sample_dataset import SampleDataset, create_sample_dataset -from ..tasks.sdoh_icd9_detection import SDOHICD9AdmissionTask -from ..tasks.sdoh_utils import TARGET_CODES, load_sdoh_icd9_labels - -logger = logging.getLogger(__name__) - - -class MIMIC3NotesDataset: - """Note-only loader 
for MIMIC-III NOTEEVENTS with label CSV filtering.""" - - def __init__( - self, - noteevents_path: str, - label_csv_path: str, - target_codes: Optional[Sequence[str]] = None, - hadm_ids: Optional[Iterable[str]] = None, - include_categories: Optional[Sequence[str]] = None, - chunksize: int = 200_000, - dataset_name: Optional[str] = None, - ) -> None: - self.noteevents_path = noteevents_path - self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) - self.label_map = load_sdoh_icd9_labels(label_csv_path, self.target_codes) - if hadm_ids is None: - hadm_ids = self.label_map.keys() - self.hadm_ids = {str(x) for x in hadm_ids} - self.include_categories = ( - {cat.strip().upper() for cat in include_categories} - if include_categories - else None - ) - self.chunksize = chunksize - self.dataset_name = dataset_name or "mimic3_notes" - - def _load_notes(self) -> Dict[str, List[Dict]]: - keep_cols = { - "SUBJECT_ID", - "HADM_ID", - "CHARTDATE", - "CHARTTIME", - "CATEGORY", - "TEXT", - } - - notes_by_hadm: Dict[str, List[Dict]] = {} - for chunk in pd.read_csv( - self.noteevents_path, - chunksize=self.chunksize, - usecols=lambda c: c.upper() in keep_cols, - dtype={"SUBJECT_ID": "string", "HADM_ID": "string"}, - low_memory=False, - ): - chunk.columns = [c.upper() for c in chunk.columns] - filtered = chunk[chunk["HADM_ID"].astype("string").isin(self.hadm_ids)] - if filtered.empty: - continue - if self.include_categories is not None: - filtered = filtered[ - filtered["CATEGORY"].astype("string") - .str.upper() - .isin(self.include_categories) - ] - if filtered.empty: - continue - - charttime = pd.to_datetime(filtered["CHARTTIME"], errors="coerce") - chartdate = pd.to_datetime(filtered["CHARTDATE"], errors="coerce") - timestamp = charttime.fillna(chartdate) - - for row, ts in zip(filtered.itertuples(index=False), timestamp): - hadm_id = str(row.HADM_ID) - entry = { - "patient_id": str(row.SUBJECT_ID) if pd.notna(row.SUBJECT_ID) else "", - "text": row.TEXT if 
pd.notna(row.TEXT) else "", - "category": row.CATEGORY if pd.notna(row.CATEGORY) else "", - "timestamp": ts, - } - notes_by_hadm.setdefault(hadm_id, []).append(entry) - - return notes_by_hadm - - def _build_admissions(self) -> List[Dict]: - notes_by_hadm = self._load_notes() - admissions: List[Dict] = [] - for hadm_id, notes in notes_by_hadm.items(): - if hadm_id not in self.label_map: - continue - notes.sort( - key=lambda x: x["timestamp"] - if pd.notna(x["timestamp"]) - else pd.Timestamp.min - ) - note_texts = [str(n["text"]) for n in notes] - note_categories = [str(n["category"]) for n in notes] - chartdates = [ - n["timestamp"].strftime("%Y-%m-%d") if pd.notna(n["timestamp"]) else "Unknown" - for n in notes - ] - - labels = self.label_map[hadm_id] - admissions.append( - { - "visit_id": hadm_id, - "patient_id": notes[0]["patient_id"], - "notes": note_texts, - "note_categories": note_categories, - "chartdates": chartdates, - "num_notes": len(note_texts), - "text_length": int(sum(len(note) for note in note_texts)), - "manual_codes": labels["manual"], - "true_codes": labels["true"], - } - ) - - logger.info("Loaded %d admissions from NOTEEVENTS", len(admissions)) - return admissions - - def set_task( - self, - task: Optional[SDOHICD9AdmissionTask] = None, - label_source: str = "manual", - in_memory: bool = True, - ) -> SampleDataset: - if task is None: - task = SDOHICD9AdmissionTask( - target_codes=self.target_codes, - label_source=label_source, - ) - - samples: List[Dict] = [] - for admission in self._build_admissions(): - samples.extend(task(admission)) - - return create_sample_dataset( - samples=samples, - input_schema=task.input_schema, - output_schema=task.output_schema, - dataset_name=self.dataset_name, - task_name=task.task_name, - in_memory=in_memory, - ) diff --git a/pyhealth/tasks/sdoh_icd9_detection.py b/pyhealth/tasks/sdoh_icd9_detection.py index 978307d02..5263f35df 100644 --- a/pyhealth/tasks/sdoh_icd9_detection.py +++ 
b/pyhealth/tasks/sdoh_icd9_detection.py @@ -25,13 +25,21 @@ def __init__( self, target_codes: Optional[Sequence[str]] = None, label_source: str = "manual", + label_map: Optional[Dict[str, Dict[str, Set[str]]]] = None, ) -> None: self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) if label_source not in {"manual", "true"}: raise ValueError("label_source must be 'manual' or 'true'") self.label_source = label_source + self.label_map = label_map or {} def __call__(self, admission: Dict) -> List[Dict]: + admission_id = str(admission.get("visit_id", "")) + if admission_id and admission_id in self.label_map: + admission = dict(admission) + admission.setdefault("manual_codes", self.label_map[admission_id]["manual"]) + admission.setdefault("true_codes", self.label_map[admission_id]["true"]) + if self.label_source == "manual": label_codes: Set[str] = admission.get("manual_codes", set()) else: From a481c941b0ac8725b3653587c915fabb0c06d20a Mon Sep 17 00:00:00 2001 From: khancepts101 Date: Sun, 25 Jan 2026 13:07:26 -0500 Subject: [PATCH 5/7] Embed SDOH prompt and add LLM tests --- pyhealth/models/sdoh_icd9_llm.py | 186 ++++++++++++++++++++++++++- pyhealth/models/sdoh_icd9_task.txt | 158 ----------------------- tests/core/test_sdoh_llm.py | 43 +++++++ tests/core/test_sdoh_mimic3_notes.py | 133 +++++++++++++++++++ 4 files changed, 357 insertions(+), 163 deletions(-) delete mode 100644 pyhealth/models/sdoh_icd9_task.txt create mode 100644 tests/core/test_sdoh_llm.py create mode 100644 tests/core/test_sdoh_mimic3_notes.py diff --git a/pyhealth/models/sdoh_icd9_llm.py b/pyhealth/models/sdoh_icd9_llm.py index 71a6bda6d..55d0a59e0 100644 --- a/pyhealth/models/sdoh_icd9_llm.py +++ b/pyhealth/models/sdoh_icd9_llm.py @@ -1,5 +1,5 @@ -import os import hashlib +import os import re import time from typing import Iterable, List, Optional, Sequence, Set, Tuple @@ -7,16 +7,192 @@ from pyhealth.tasks.sdoh_utils import TARGET_CODES -PROMPT_PATH = 
os.path.join(os.path.dirname(__file__), "sdoh_icd9_task.txt") +PROMPT_TEMPLATE = """\ +# SDOH Medical Coding Task + +You are an expert medical coder. Your job is to find Social Determinants of Health (SDOH) codes in clinical notes. + +## CRITICAL PREDICTION RULES: +1. **IF YOU FIND EVIDENCE FOR A CODE → YOU MUST PREDICT THAT CODE** +2. **BE CONSERVATIVE**: Only predict codes with STRONG evidence +3. **Evidence detection = Automatic code prediction** + +## CODE HIERARCHY - ABSOLUTELY CRITICAL: +**HOMELESS (V600) COMPLETELY EXCLUDES V602:** +- ❌ If you see homelessness → ONLY predict V600, NEVER V602 +- ❌ "Homeless and can't afford X" → STILL only V600 +- ✅ V602 is EXCLUSIVELY for housed people with explicit financial hardship + +**V602 ULTRA-STRICT REQUIREMENTS:** +- ✅ Must have exact quotes: "cannot afford medications", "no money for food" +- ❌ NEVER infer from: homelessness, unemployment, substance use, mental health +- ❌ NEVER predict from: "poor", "disadvantaged", social circumstances + + +## Available ICD-9 Codes: + +**Housing & Resources:** +- V600: Homeless (living on streets, shelters, hotels, motels, temporary housing) +- V602: Cannot afford basic needs. **ULTRA-STRICT RULES**: + - ✅ ONLY if EXPLICIT financial statements: "cannot afford medications", "unable to pay for treatment", "no money for food", "financial hardship preventing care" + - ❌ NEVER predict from social circumstances: substance abuse, unemployment, mental health, divorce + - ❌ NEVER predict from discharge inventory: "no money/wallet returned" + - ❌ NEVER predict for insurance/benefits mentions: "on disability", "has Medicaid" +- V604: No family/caregiver available. 
**CRITICAL RULES**: + - ❌ NOT just "lives alone" or "elderly" + - ✅ ONLY if: "no family contact" AND "no support" AND "no one to help" + - ❌ NOT for: "lives alone but daughter visits", "independent" + +**Employment & Legal:** +- V620: Unemployed (current employment status) — **EXPLICIT ONLY** + - ✅ **Predict only if one of these exact employment-status phrases appears (word boundaries, case-insensitive):** + - unemploy(ed|ment) + - jobless | no job + - out of work | without work (employment context only; see exclusions) + - not working (employment context only; see exclusions) + - between jobs + - on unemployment | receiving unemployment + - Recent loss: laid off | fired | terminated | lost (my|his|her|their)? job + - Label–value fields: Employment|Employment status|Work: → unemployed|none|not working ⇒ V620 + - ❌ Do NOT infer from former-only phrases: used to work | formerly | previously employed | last worked in ... UNLESS illness/disability is given as the reason for stopping work + - ❌ **Contradictions (block V620 if present nearby, ±1 sentence)**: employed | works at/as | working | return(ed) to work | full-time/part-time | self-employed | contractor | on leave/LOA | work excuse | at work today + - ❌ **Non-employment uses of "work" (never trigger)**: work of breathing | workup/work-up | social work | line not working | device not working | therapy not working | meds not working + - ❌ **Context rule for "not/without work"**: Only treat as unemployment when it clearly refers to employment status (e.g., in Social History/Employment section or followed by "due to X" re job). Otherwise, do not code + - ❌ **Exclusions**: retired | student | stay-at-home parent | on medical leave — unless an explicit unemployment phrase above is present + - ✅ Include cases where patient stopped working or is unable to work due to illness/disability (including SSDI/SSI). +- V625: Legal problems. 
**ACTIVE LEGAL PROCESSES ONLY**: + - ✅ Criminal: Arrest, jail, prison, parole, probation, active charges, bail, or court case pending + - ✅ Civil: Court-ordered custody, restraining order, supervised release, or active litigation + - ✅ Guardianship: Legal capacity hearing, court-appointed guardian, or power of attorney via court + - ✅ Child welfare: DCF/CPS removed children, placed them in care, or court-ordered custody change + - ❌ Do NOT predict V625 for: + - History of crime, substance use, or homelessness without active legal process + - Drug use history alone (marijuana, cocaine, etc.) without legal consequences + - Psych history, social issues, or missed appointments without legal involvement + - DCF "awareness" without custody change or legal action + +**Health History:** +- V1541: Physical/sexual abuse (violence BY another person). **CRITICAL**: + - ✅ Physical violence, sexual abuse, assault, domestic violence, rape + - ❌ NOT accidents, falls, fights where patient was aggressor +- V1542: Pure emotional/psychological abuse. **MUTUALLY EXCLUSIVE WITH V1541**: + - ✅ **ONLY if NO physical/sexual abuse mentioned AND explicit emotional abuse:** + - ✅ Witnessed violence: "witnessed violence as a child", "saw domestic violence" + - ✅ Verbal abuse: "verbal abuse", "emotionally abusive", "psychological abuse" + - ✅ Emotional manipulation: "jealous and controlling", "isolation", "intimidation" + - ❌ **NEVER predict V1542 if ANY physical/sexual abuse mentioned (use V1541 instead)** + - ❌ **NEVER predict both V1541 AND V1542 for same patient** + - ❌ Depression, anxiety, suicidal ideation alone without explicit abuse = NO CODE + - ❌ "History of abuse" without specifying type = NO CODE + - ❌ Psychiatric history alone = NO CODE +- V1584: Past asbestos exposure + +**Family History:** +- V6141: Family alcoholism (family member drinks) — PREDICT if any kinship term + alcohol term appear together. 
+ - ✅ Kinship terms: father, mother, dad, mom, brother, sister, son, daughter, uncle, aunt, cousin, grand*; Headers: FH, FHx, Family History + - ✅ Alcohol terms (case-insensitive, synonyms OK): ETOH/EtOH/etoh, alcohol, alcoholism, AUD, alcohol use disorder, EtOH abuse, EtOH dependence, "alcohol problem(s)", "drinks heavily", "alcoholic" + - ✅ Mixed substance OK: If text says "drug and alcohol problems," still predict V6141 + - ✅ Outside headers OK: If kinship + alcohol appear in same clause/sentence or within ±1 line (e.g., "Pt's father ... has hx of etoh"), predict V6141 + - ✅ Examples to capture: "FH: Father – ETOH", "Mother has h/o alcoholism", "Father with depression and alcoholism", "Multiple family members with ETOH abuse ... (cousin, sister, uncle, aunt, father)", "Both brothers have drug and alcohol problems" + - ❌ Negations: Do not predict if explicitly denied (e.g., "denies family history of alcoholism") + - ❌ NEVER for PATIENT'S own history: "history of alcohol abuse", "with a history of alcohol use", "past medical history significant for heavy EtOH abuse", "patient alcoholic", "ETOH abuse" + +## ENHANCED NEGATIVE EXAMPLES: + +**V602 FALSE POSITIVES TO AVOID:** +❌ "Homeless patient" → Predict V600 ONLY, NEVER V602 +❌ "Lives in shelter, gets food stamps" → V600 ONLY, NEVER V602 +❌ "Homeless, on disability" → V600 ONLY, NEVER V602 +❌ "No permanent address, has Medicaid" → V600 ONLY, NEVER V602 +❌ "Homeless and can't afford medications" → V600 ONLY, NEVER V602 +❌ "Unemployed alcoholic" → V620 (unemployment is explicit), NEVER V602 +❌ "Lives in poverty" → NEVER V602 (too vague) +❌ "Financial strain from divorce" → NEVER V602 (circumstantial) + +**V604 FALSE POSITIVES TO AVOID:** +❌ "82 year old lives alone" → NO CODE unless no support mentioned +❌ "Lives by herself" → NO CODE unless isolation confirmed +❌ "Widowed, lives alone, son calls daily" → NO CODE (has support) + +**V1542 FALSE POSITIVES TO AVOID:** +❌ "History of physical and sexual abuse" → V1541 
ONLY (physical trumps emotional) +❌ "PTSD from rape at age 7" → V1541 ONLY (sexual abuse) +❌ "Childhood sexual abuse by uncle" → V1541 ONLY (sexual abuse) +❌ "History of domestic abuse" → V1541 ONLY (physical abuse) +❌ "Depression and anxiety" → NO CODE (psychiatric symptoms alone) +❌ "Suicide attempts" → NO CODE (mental health history alone) +❌ "History of abuse" → NO CODE (unspecified type) +❌ "Recent argument with partner" → NO CODE (relationship conflict) + +**V1542 TRUE POSITIVES TO CAPTURE:** +✅ "Witnessed violence as a child" → V1542 (pure emotional trauma, no physical) +✅ "Emotionally abusive relationship for 14 years" → V1542 (explicit emotional abuse) +✅ "Verbal abuse from controlling partner" → V1542 (explicit emotional abuse) +✅ "Jealous and controlling behavior" → V1542 (emotional manipulation) + +## CONFIDENCE RULES: + +**HIGH CONFIDENCE (Predict):** +- Direct statement of condition +- Multiple supporting evidence pieces +- Explicit language matching code definition + +**LOW CONFIDENCE (Don't Predict):** +- Ambiguous language +- Single weak indicator +- Contradictory evidence + +## Key Rules: + +1. **Precision over Recall**: Better to miss a code than falsely predict +2. **Evidence-Driven**: Strong evidence required for prediction +3. **Multiple codes allowed**: But each needs independent evidence +4. **Conservative approach**: When in doubt, don't predict + +## Output Format: +Return applicable codes separated by commas, or "None" if no codes apply. + +Example: +``` +V600, V625 +``` + +or if no codes apply: +``` +None +``` + +--- + +**Clinical Note to Analyze:** +{note} +""" def _load_prompt_template() -> str: - with open(PROMPT_PATH, "r", encoding="utf-8") as f: - return f.read() + return PROMPT_TEMPLATE class SDOHICD9LLM: - """Admission-level SDOH ICD-9 V-code detector using an LLM.""" + """Admission-level SDOH ICD-9 V-code detector using an LLM. 
+ + This model runs an LLM on each note for an admission, parses the predicted + ICD-9 V-codes, and aggregates predictions across notes (union). + + Examples: + >>> from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM + >>> notes = [ + ... "Pt is homeless and has no family support.", + ... "Social work consulted for housing resources.", + ... ] + >>> model = SDOHICD9LLM(dry_run=True, max_notes=2) + >>> codes, note_results = model.predict_admission_with_notes(notes) + >>> codes + set() + + >>> model = SDOHICD9LLM(model_name="gpt-4o-mini", max_notes=1) + >>> codes, note_results = model.predict_admission_with_notes(notes) + """ def __init__( self, diff --git a/pyhealth/models/sdoh_icd9_task.txt b/pyhealth/models/sdoh_icd9_task.txt deleted file mode 100644 index ed2f847ef..000000000 --- a/pyhealth/models/sdoh_icd9_task.txt +++ /dev/null @@ -1,158 +0,0 @@ -# SDOH Medical Coding Task - -You are an expert medical coder. Your job is to find Social Determinants of Health (SDOH) codes in clinical notes. - -## CRITICAL PREDICTION RULES: -1. **IF YOU FIND EVIDENCE FOR A CODE → YOU MUST PREDICT THAT CODE** -2. **BE CONSERVATIVE**: Only predict codes with STRONG evidence -3. **Evidence detection = Automatic code prediction** - -## CODE HIERARCHY - ABSOLUTELY CRITICAL: -**HOMELESS (V600) COMPLETELY EXCLUDES V602:** -- ❌ If you see homelessness → ONLY predict V600, NEVER V602 -- ❌ "Homeless and can't afford X" → STILL only V600 -- ✅ V602 is EXCLUSIVELY for housed people with explicit financial hardship - -**V602 ULTRA-STRICT REQUIREMENTS:** -- ✅ Must have exact quotes: "cannot afford medications", "no money for food" -- ❌ NEVER infer from: homelessness, unemployment, substance use, mental health -- ❌ NEVER predict from: "poor", "disadvantaged", social circumstances - - -## Available ICD-9 Codes: - -**Housing & Resources:** -- V600: Homeless (living on streets, shelters, hotels, motels, temporary housing) -- V602: Cannot afford basic needs. 
**ULTRA-STRICT RULES**: - - ✅ ONLY if EXPLICIT financial statements: "cannot afford medications", "unable to pay for treatment", "no money for food", "financial hardship preventing care" - - ❌ NEVER predict from social circumstances: substance abuse, unemployment, mental health, divorce - - ❌ NEVER predict from discharge inventory: "no money/wallet returned" - - ❌ NEVER predict for insurance/benefits mentions: "on disability", "has Medicaid" -- V604: No family/caregiver available. **CRITICAL RULES**: - - ❌ NOT just "lives alone" or "elderly" - - ✅ ONLY if: "no family contact" AND "no support" AND "no one to help" - - ❌ NOT for: "lives alone but daughter visits", "independent" - -**Employment & Legal:** -- V620: Unemployed (current employment status) — **EXPLICIT ONLY** - - ✅ **Predict only if one of these exact employment-status phrases appears (word boundaries, case-insensitive):** - - unemploy(ed|ment) - - jobless | no job - - out of work | without work (employment context only; see exclusions) - - not working (employment context only; see exclusions) - - between jobs - - on unemployment | receiving unemployment - - Recent loss: laid off | fired | terminated | lost (my|his|her|their)? job - - Label–value fields: Employment|Employment status|Work: → unemployed|none|not working ⇒ V620 - - ❌ Do NOT infer from former-only phrases: used to work | formerly | previously employed | last worked in ... 
UNLESS illness/disability is given as the reason for stopping work - - ❌ **Contradictions (block V620 if present nearby, ±1 sentence)**: employed | works at/as | working | return(ed) to work | full-time/part-time | self-employed | contractor | on leave/LOA | work excuse | at work today - - ❌ **Non-employment uses of "work" (never trigger)**: work of breathing | workup/work-up | social work | line not working | device not working | therapy not working | meds not working - - ❌ **Context rule for "not/without work"**: Only treat as unemployment when it clearly refers to employment status (e.g., in Social History/Employment section or followed by "due to X" re job). Otherwise, do not code - - ❌ **Exclusions**: retired | student | stay-at-home parent | on medical leave — unless an explicit unemployment phrase above is present - - ✅ Include cases where patient stopped working or is unable to work due to illness/disability (including SSDI/SSI). -- V625: Legal problems. **ACTIVE LEGAL PROCESSES ONLY**: - - ✅ Criminal: Arrest, jail, prison, parole, probation, active charges, bail, or court case pending - - ✅ Civil: Court-ordered custody, restraining order, supervised release, or active litigation - - ✅ Guardianship: Legal capacity hearing, court-appointed guardian, or power of attorney via court - - ✅ Child welfare: DCF/CPS removed children, placed them in care, or court-ordered custody change - - ❌ Do NOT predict V625 for: - - History of crime, substance use, or homelessness without active legal process - - Drug use history alone (marijuana, cocaine, etc.) without legal consequences - - Psych history, social issues, or missed appointments without legal involvement - - DCF "awareness" without custody change or legal action - -**Health History:** -- V1541: Physical/sexual abuse (violence BY another person). 
**CRITICAL**: - - ✅ Physical violence, sexual abuse, assault, domestic violence, rape - - ❌ NOT accidents, falls, fights where patient was aggressor -- V1542: Pure emotional/psychological abuse. **MUTUALLY EXCLUSIVE WITH V1541**: - - ✅ **ONLY if NO physical/sexual abuse mentioned AND explicit emotional abuse:** - - ✅ Witnessed violence: "witnessed violence as a child", "saw domestic violence" - - ✅ Verbal abuse: "verbal abuse", "emotionally abusive", "psychological abuse" - - ✅ Emotional manipulation: "jealous and controlling", "isolation", "intimidation" - - ❌ **NEVER predict V1542 if ANY physical/sexual abuse mentioned (use V1541 instead)** - - ❌ **NEVER predict both V1541 AND V1542 for same patient** - - ❌ Depression, anxiety, suicidal ideation alone without explicit abuse = NO CODE - - ❌ "History of abuse" without specifying type = NO CODE - - ❌ Psychiatric history alone = NO CODE -- V1584: Past asbestos exposure - -**Family History:** -- V6141: Family alcoholism (family member drinks) — PREDICT if any kinship term + alcohol term appear together. - - ✅ Kinship terms: father, mother, dad, mom, brother, sister, son, daughter, uncle, aunt, cousin, grand*; Headers: FH, FHx, Family History - - ✅ Alcohol terms (case-insensitive, synonyms OK): ETOH/EtOH/etoh, alcohol, alcoholism, AUD, alcohol use disorder, EtOH abuse, EtOH dependence, "alcohol problem(s)", "drinks heavily", "alcoholic" - - ✅ Mixed substance OK: If text says "drug and alcohol problems," still predict V6141 - - ✅ Outside headers OK: If kinship + alcohol appear in same clause/sentence or within ±1 line (e.g., "Pt's father ... has hx of etoh"), predict V6141 - - ✅ Examples to capture: "FH: Father – ETOH", "Mother has h/o alcoholism", "Father with depression and alcoholism", "Multiple family members with ETOH abuse ... 
(cousin, sister, uncle, aunt, father)", "Both brothers have drug and alcohol problems" - - ❌ Negations: Do not predict if explicitly denied (e.g., "denies family history of alcoholism") - - ❌ NEVER for PATIENT'S own history: "history of alcohol abuse", "with a history of alcohol use", "past medical history significant for heavy EtOH abuse", "patient alcoholic", "ETOH abuse" - -## ENHANCED NEGATIVE EXAMPLES: - -**V602 FALSE POSITIVES TO AVOID:** -❌ "Homeless patient" → Predict V600 ONLY, NEVER V602 -❌ "Lives in shelter, gets food stamps" → V600 ONLY, NEVER V602 -❌ "Homeless, on disability" → V600 ONLY, NEVER V602 -❌ "No permanent address, has Medicaid" → V600 ONLY, NEVER V602 -❌ "Homeless and can't afford medications" → V600 ONLY, NEVER V602 -❌ "Unemployed alcoholic" → V620 (unemployment is explicit), NEVER V602 -❌ "Lives in poverty" → NEVER V602 (too vague) -❌ "Financial strain from divorce" → NEVER V602 (circumstantial) - -**V604 FALSE POSITIVES TO AVOID:** -❌ "82 year old lives alone" → NO CODE unless no support mentioned -❌ "Lives by herself" → NO CODE unless isolation confirmed -❌ "Widowed, lives alone, son calls daily" → NO CODE (has support) - -**V1542 FALSE POSITIVES TO AVOID:** -❌ "History of physical and sexual abuse" → V1541 ONLY (physical trumps emotional) -❌ "PTSD from rape at age 7" → V1541 ONLY (sexual abuse) -❌ "Childhood sexual abuse by uncle" → V1541 ONLY (sexual abuse) -❌ "History of domestic abuse" → V1541 ONLY (physical abuse) -❌ "Depression and anxiety" → NO CODE (psychiatric symptoms alone) -❌ "Suicide attempts" → NO CODE (mental health history alone) -❌ "History of abuse" → NO CODE (unspecified type) -❌ "Recent argument with partner" → NO CODE (relationship conflict) - -**V1542 TRUE POSITIVES TO CAPTURE:** -✅ "Witnessed violence as a child" → V1542 (pure emotional trauma, no physical) -✅ "Emotionally abusive relationship for 14 years" → V1542 (explicit emotional abuse) -✅ "Verbal abuse from controlling partner" → V1542 (explicit emotional 
abuse) -✅ "Jealous and controlling behavior" → V1542 (emotional manipulation) - -## CONFIDENCE RULES: - -**HIGH CONFIDENCE (Predict):** -- Direct statement of condition -- Multiple supporting evidence pieces -- Explicit language matching code definition - -**LOW CONFIDENCE (Don't Predict):** -- Ambiguous language -- Single weak indicator -- Contradictory evidence - -## Key Rules: - -1. **Precision over Recall**: Better to miss a code than falsely predict -2. **Evidence-Driven**: Strong evidence required for prediction -3. **Multiple codes allowed**: But each needs independent evidence -4. **Conservative approach**: When in doubt, don't predict - -## Output Format: -Return applicable codes separated by commas, or "None" if no codes apply. - -Example: -``` -V600, V625 -``` - -or if no codes apply: -``` -None -``` - ---- - -**Clinical Note to Analyze:** -{note} \ No newline at end of file diff --git a/tests/core/test_sdoh_llm.py b/tests/core/test_sdoh_llm.py new file mode 100644 index 000000000..2bc7e99e9 --- /dev/null +++ b/tests/core/test_sdoh_llm.py @@ -0,0 +1,43 @@ +from unittest import mock + +from base import BaseTestCase +from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM + + +class TestSdohLLM(BaseTestCase): + def setUp(self): + self.set_random_seed() + + def test_llm_aggregation(self): + model = SDOHICD9LLM(dry_run=True) + notes = ["note one", "note two", "note three"] + + responses = iter( + [ + "```V600```", + "None", + "```V620, V625```", + ] + ) + + with mock.patch.object(model, "_call_openai_api", side_effect=lambda _: next(responses)): + with mock.patch.object(model, "_write_prompt_preview", return_value=None): + aggregated, note_results = model.predict_admission_with_notes(notes) + + self.assertEqual({"V600", "V620", "V625"}, aggregated) + self.assertEqual(3, len(note_results)) + self.assertEqual({"V600"}, set(note_results[0]["predicted_codes"])) + self.assertEqual(set(), set(note_results[1]["predicted_codes"])) + self.assertEqual({"V620", "V625"}, 
set(note_results[2]["predicted_codes"])) + + def test_llm_max_notes(self): + model = SDOHICD9LLM(dry_run=True, max_notes=1) + notes = ["note one", "note two"] + + with mock.patch.object(model, "_call_openai_api", return_value="V600") as mocked: + with mock.patch.object(model, "_write_prompt_preview", return_value=None): + aggregated, note_results = model.predict_admission_with_notes(notes) + + self.assertEqual({"V600"}, aggregated) + self.assertEqual(1, len(note_results)) + mocked.assert_called_once() diff --git a/tests/core/test_sdoh_mimic3_notes.py b/tests/core/test_sdoh_mimic3_notes.py new file mode 100644 index 000000000..0cf63c293 --- /dev/null +++ b/tests/core/test_sdoh_mimic3_notes.py @@ -0,0 +1,133 @@ +import csv +import os +import tempfile + +from base import BaseTestCase +from pyhealth.datasets import MIMIC3NoteDataset +from pyhealth.tasks.sdoh_utils import TARGET_CODES, codes_to_multihot + + +class TestSdohMimic3Notes(BaseTestCase): + def setUp(self): + self.set_random_seed() + + def test_mimic3_note_dataset(self): + """Test MIMIC3NoteDataset with label filtering and categories.""" + with tempfile.TemporaryDirectory() as tmpdir: + noteevents_path = os.path.join(tmpdir, "NOTEEVENTS.csv") + with open(noteevents_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow( + [ + "SUBJECT_ID", + "HADM_ID", + "CHARTDATE", + "CHARTTIME", + "CATEGORY", + "TEXT", + ] + ) + writer.writerow( + [ + "1", + "10", + "2020-01-01", + "2020-01-01 12:00:00", + "Physician", + "Pt is homeless", + ] + ) + writer.writerow( + [ + "1", + "10", + "2020-01-01", + "2020-01-01 13:00:00", + "Radiology", + "XR chest normal", + ] + ) + writer.writerow( + [ + "2", + "20", + "2020-02-01", + "2020-02-01 09:00:00", + "Physician", + "No issues reported", + ] + ) + + label_map = { + "10": {"manual": {"V600"}, "true": set()}, + "20": {"manual": set(), "true": set()}, + } + + dataset = MIMIC3NoteDataset( + noteevents_path=noteevents_path, + hadm_ids=["10"], + 
include_categories=["Physician"], + ) + sample_dataset = dataset.set_task( + label_source="manual", + label_map=label_map, + ) + + self.assertEqual(1, len(sample_dataset)) + sample = next(iter(sample_dataset)) + self.assertEqual("10", sample["visit_id"]) + self.assertEqual(["Pt is homeless"], sample["notes"]) + self.assertEqual(["Physician"], sample["note_categories"]) + + expected_label = codes_to_multihot({"V600"}, TARGET_CODES).tolist() + self.assertEqual(expected_label, sample["label"].tolist()) + + def test_note_sorting_and_chartdate_fallback(self): + """Test ordering with missing charttime and chartdate fallback.""" + with tempfile.TemporaryDirectory() as tmpdir: + noteevents_path = os.path.join(tmpdir, "NOTEEVENTS.csv") + with open(noteevents_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow( + [ + "SUBJECT_ID", + "HADM_ID", + "CHARTDATE", + "CHARTTIME", + "CATEGORY", + "TEXT", + ] + ) + writer.writerow( + [ + "1", + "10", + "2020-01-02", + "", + "Physician", + "Later note", + ] + ) + writer.writerow( + [ + "1", + "10", + "2020-01-01", + "2020-01-01 08:00:00", + "Physician", + "Earlier note", + ] + ) + + label_map = {"10": {"manual": {"V600"}, "true": set()}} + dataset = MIMIC3NoteDataset( + noteevents_path=noteevents_path, + include_categories=["Physician"], + ) + sample_dataset = dataset.set_task( + label_source="manual", + label_map=label_map, + ) + + sample = next(iter(sample_dataset)) + self.assertEqual(["Earlier note", "Later note"], sample["notes"]) From 18a272217f6a0538ecce1797af821badd6aa1290 Mon Sep 17 00:00:00 2001 From: khancepts101 Date: Sun, 25 Jan 2026 13:44:06 -0500 Subject: [PATCH 6/7] Add docstrings for SDOH loaders and LLM --- examples/sdoh_icd9_llm_eval.py | 2 ++ pyhealth/datasets/mimic3.py | 20 +++++++++++++++++++- pyhealth/models/sdoh_icd9_llm.py | 6 ++++++ pyhealth/tasks/sdoh_icd9_detection.py | 14 +++++++++++++- 4 files changed, 40 insertions(+), 2 deletions(-) diff --git a/examples/sdoh_icd9_llm_eval.py 
b/examples/sdoh_icd9_llm_eval.py index aa896de3a..5db4f1cde 100644 --- a/examples/sdoh_icd9_llm_eval.py +++ b/examples/sdoh_icd9_llm_eval.py @@ -13,6 +13,7 @@ def parse_args(): + """Parse CLI arguments for SDOH ICD-9 note evaluation.""" parser = argparse.ArgumentParser( description="Admission-level SDOH ICD-9 evaluation with per-note LLM calls." ) @@ -44,6 +45,7 @@ def parse_args(): def main(): + """Run admission-level evaluation with per-note LLM calls.""" args = parse_args() target_codes = list(TARGET_CODES) label_map = load_sdoh_icd9_labels(args.label_csv_path, target_codes) diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index 783e0e828..d8c2f117c 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -92,7 +92,11 @@ def preprocess_noteevents(self, df: pl.LazyFrame) -> pl.LazyFrame: class MIMIC3NoteDataset: - """Note-only loader for MIMIC-III NOTEEVENTS.""" + """Note-only loader for MIMIC-III NOTEEVENTS. + + This loader streams NOTEEVENTS in chunks and optionally filters by HADM_IDs + and note categories. Use set_task() to convert admissions into samples. + """ def __init__( self, @@ -104,6 +108,17 @@ def __init__( chunksize: int = 200_000, dataset_name: Optional[str] = None, ) -> None: + """Initialize the note-only loader. + + Args: + noteevents_path: Path to NOTEEVENTS CSV/CSV.GZ. + root: MIMIC-III root directory (used if noteevents_path is None). + target_codes: Target ICD-9 codes for label vector construction. + hadm_ids: Optional admission IDs to include. + include_categories: Optional NOTEEVENTS CATEGORY values to include. + chunksize: Number of rows per chunk read. + dataset_name: Optional dataset name for the SampleDataset.
+ """ if noteevents_path is None: if root is None: raise ValueError("root is required when noteevents_path is not set.") @@ -120,6 +135,7 @@ def __init__( self.dataset_name = dataset_name or "mimic3_note" def _load_notes(self) -> Dict[str, List[Dict]]: + """Load and group note events by admission.""" keep_cols = { "SUBJECT_ID", "HADM_ID", @@ -168,6 +184,7 @@ def _load_notes(self) -> Dict[str, List[Dict]]: return notes_by_hadm def _build_admissions(self) -> List[Dict]: + """Build admission-level note bundles with timestamps and categories.""" notes_by_hadm = self._load_notes() admissions: List[Dict] = [] for hadm_id, notes in notes_by_hadm.items(): @@ -205,6 +222,7 @@ def set_task( label_map: Optional[Dict[str, Dict[str, Set[str]]]] = None, in_memory: bool = True, ) -> SampleDataset: + """Apply a task to admissions and return a SampleDataset.""" if task is None: task = SDOHICD9AdmissionTask( target_codes=self.target_codes, diff --git a/pyhealth/models/sdoh_icd9_llm.py b/pyhealth/models/sdoh_icd9_llm.py index 55d0a59e0..4c680ea4c 100644 --- a/pyhealth/models/sdoh_icd9_llm.py +++ b/pyhealth/models/sdoh_icd9_llm.py @@ -225,6 +225,7 @@ def __init__( ) def _get_client(self): + """Initialize and cache the OpenAI client.""" if self._client is None: from openai import OpenAI @@ -232,6 +233,7 @@ def _get_client(self): return self._client def _call_openai_api(self, text: str) -> str: + """Send a single note to the LLM and return the raw response.""" self._write_prompt_preview(text) if self.dry_run: @@ -252,6 +254,7 @@ def _call_openai_api(self, text: str) -> str: return response.choices[0].message.content.strip() def _write_prompt_preview(self, text: str) -> None: + """Write the fully rendered prompt (with note) to a local file.""" prompt = self.prompt_template.format(note=text) digest = hashlib.sha1(prompt.encode("utf-8")).hexdigest()[:10] filename = f"sdoh_prompt_{digest}.txt" @@ -259,6 +262,7 @@ def _write_prompt_preview(self, text: str) -> None: f.write(prompt) def 
_parse_llm_response(self, response: str) -> Set[str]: + """Parse the LLM response into a set of valid target codes.""" if not response: return set() @@ -289,6 +293,7 @@ def _predict_admission( note_categories: Optional[Iterable[str]] = None, chartdates: Optional[Iterable[str]] = None, ) -> Tuple[Set[str], List[dict]]: + """Run per-note predictions and aggregate codes for one admission.""" aggregated: Set[str] = set() note_results: List[dict] = [] categories = list(note_categories) if note_categories is not None else [] @@ -325,4 +330,5 @@ def predict_admission_with_notes( note_categories: Optional[Iterable[str]] = None, chartdates: Optional[Iterable[str]] = None, ) -> Tuple[Set[str], List[dict]]: + """Public helper to predict and return codes for one admission.""" return self._predict_admission(notes, note_categories, chartdates) diff --git a/pyhealth/tasks/sdoh_icd9_detection.py b/pyhealth/tasks/sdoh_icd9_detection.py index 5263f35df..16cdace7b 100644 --- a/pyhealth/tasks/sdoh_icd9_detection.py +++ b/pyhealth/tasks/sdoh_icd9_detection.py @@ -7,7 +7,11 @@ class SDOHICD9AdmissionTask(BaseTask): - """Builds admission-level samples for SDOH ICD-9 V-code detection.""" + """Builds admission-level samples for SDOH ICD-9 V-code detection. + + The task attaches a multi-hot label vector and the corresponding label + codes based on the provided admission dictionary (and optional label_map). + """ task_name: str = "SDOHICD9Admission" input_schema: Dict[str, str] = { @@ -27,6 +31,13 @@ def __init__( label_source: str = "manual", label_map: Optional[Dict[str, Dict[str, Set[str]]]] = None, ) -> None: + """Initialize the admission-level task. + + Args: + target_codes: Target ICD-9 codes for label vector construction. + label_source: Which label set to use ("manual" or "true"). + label_map: Optional mapping of HADM_ID to label codes. 
+ """ self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) if label_source not in {"manual", "true"}: raise ValueError("label_source must be 'manual' or 'true'") @@ -34,6 +45,7 @@ def __init__( self.label_map = label_map or {} def __call__(self, admission: Dict) -> List[Dict]: + """Build a single labeled sample from an admission dictionary.""" admission_id = str(admission.get("visit_id", "")) if admission_id and admission_id in self.label_map: admission = dict(admission) From b3515216771bef5f120b568066851f055a4ce784 Mon Sep 17 00:00:00 2001 From: khancepts101 Date: Sun, 25 Jan 2026 13:54:20 -0500 Subject: [PATCH 7/7] Expand SDOH LLM docstrings --- pyhealth/models/sdoh_icd9_llm.py | 60 ++++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/pyhealth/models/sdoh_icd9_llm.py b/pyhealth/models/sdoh_icd9_llm.py index 4c680ea4c..ebbd60947 100644 --- a/pyhealth/models/sdoh_icd9_llm.py +++ b/pyhealth/models/sdoh_icd9_llm.py @@ -176,8 +176,12 @@ def _load_prompt_template() -> str: class SDOHICD9LLM: """Admission-level SDOH ICD-9 V-code detector using an LLM. - This model runs an LLM on each note for an admission, parses the predicted - ICD-9 V-codes, and aggregates predictions across notes (union). + This model sends each note for an admission to an LLM, parses predicted + ICD-9 V-codes, and aggregates the codes across notes (set union). + + Notes: + - Use ``dry_run=True`` to skip LLM calls while exercising the pipeline. + - Predictions are derived entirely from the LLM response parsing logic. Examples: >>> from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM @@ -207,6 +211,21 @@ def __init__( max_notes: Optional[int] = None, dry_run: bool = False, ) -> None: + """Initialize the LLM wrapper. + + Args: + target_codes: Target ICD-9 codes to retain after parsing. + model_name: OpenAI model name. + prompt_template: Optional prompt template override. Uses built-in + SDOH template if not provided. 
+ api_key: OpenAI API key. Defaults to ``OPENAI_API_KEY`` env var. + max_tokens: Max tokens for LLM response. + max_chars: Max chars from each note to send. + temperature: LLM temperature. + sleep_s: Delay between per-note requests (seconds). + max_notes: Optional limit on notes per admission. + dry_run: If True, skips API calls and returns "None" responses. + """ self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) self.model_name = model_name self.prompt_template = prompt_template or _load_prompt_template() @@ -233,7 +252,14 @@ def _get_client(self): return self._client def _call_openai_api(self, text: str) -> str: - """Send a single note to the LLM and return the raw response.""" + """Send a single note to the LLM and return the raw response. + + Args: + text: Note text to send. + + Returns: + Raw string response from the LLM. + """ self._write_prompt_preview(text) if self.dry_run: @@ -262,7 +288,11 @@ def _write_prompt_preview(self, text: str) -> None: f.write(prompt) def _parse_llm_response(self, response: str) -> Set[str]: - """Parse the LLM response into a set of valid target codes.""" + """Parse the LLM response into a set of valid target codes. + + Returns: + A set of ICD-9 codes intersected with ``target_codes``. + """ if not response: return set() @@ -293,7 +323,16 @@ def _predict_admission( note_categories: Optional[Iterable[str]] = None, chartdates: Optional[Iterable[str]] = None, ) -> Tuple[Set[str], List[dict]]: - """Run per-note predictions and aggregate codes for one admission.""" + """Run per-note predictions and aggregate codes for one admission. + + Args: + notes: Iterable of note texts. + note_categories: Optional note categories aligned to ``notes``. + chartdates: Optional chart dates aligned to ``notes``. + + Returns: + A tuple of (aggregated_codes, per_note_results). 
+ """ aggregated: Set[str] = set() note_results: List[dict] = [] categories = list(note_categories) if note_categories is not None else [] @@ -330,5 +369,14 @@ def predict_admission_with_notes( note_categories: Optional[Iterable[str]] = None, chartdates: Optional[Iterable[str]] = None, ) -> Tuple[Set[str], List[dict]]: - """Public helper to predict and return codes for one admission.""" + """Predict codes for one admission using per-note LLM calls. + + Args: + notes: Iterable of note texts. + note_categories: Optional note categories aligned to ``notes``. + chartdates: Optional chart dates aligned to ``notes``. + + Returns: + A tuple of (aggregated_codes, per_note_results). + """ return self._predict_admission(notes, note_categories, chartdates)