diff --git a/examples/sdoh_icd9_llm_eval.py b/examples/sdoh_icd9_llm_eval.py
new file mode 100644
index 000000000..5db4f1cde
--- /dev/null
+++ b/examples/sdoh_icd9_llm_eval.py
@@ -0,0 +1,189 @@
+import argparse
+import json
+import os
+from datetime import datetime
+from typing import List, Set
+
+import numpy as np
+from pyhealth.datasets import MIMIC3NoteDataset
+from pyhealth.metrics import multilabel_metrics_fn
+from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM
+from pyhealth.tasks.sdoh_icd9_detection import TARGET_CODES
+from pyhealth.tasks.sdoh_utils import codes_to_multihot, load_sdoh_icd9_labels
+
+
+def parse_args():
+    """Parse CLI arguments for SDOH ICD-9 note evaluation."""
+    parser = argparse.ArgumentParser(
+        description="Admission-level SDOH ICD-9 evaluation with per-note LLM calls."
+    )
+    parser.add_argument("--mimic-root", required=True, help="Root folder for MIMIC-III CSVs")
+    parser.add_argument("--label-csv-path", required=True, help="Path to sdoh_icd9_dataset.csv")
+    parser.add_argument(
+        "--label-source",
+        default="manual",
+        choices=["manual", "true"],
+        help="Which labels to use as primary ground truth.",
+    )
+    parser.add_argument(
+        "--max-notes",
+        default="all",
+        help="Limit notes per admission (e.g., 1, 2, 5, or 'all').",
+    )
+    parser.add_argument(
+        "--max-admissions",
+        default="all",
+        help="Limit admissions to process (e.g., 5 or 'all').",
+    )
+    parser.add_argument(
+        "--note-categories",
+        help="Comma-separated NOTE_CATEGORY values to include (optional).",
+    )
+    parser.add_argument("--output-dir", default=".", help="Directory to save outputs.")
+    parser.add_argument("--dry-run", action="store_true", help="Skip LLM calls and exercise the pipeline only.")
+    return parser.parse_args()
+
+
+def main():
+    """Run admission-level evaluation with per-note LLM calls."""
+    args = parse_args()
+    target_codes = list(TARGET_CODES)
+    label_map = load_sdoh_icd9_labels(args.label_csv_path, target_codes)
+
+    include_categories = (
+        [cat.strip() for cat in args.note_categories.split(",")]
+        if args.note_categories
+        else None
+    )
+    if str(args.max_notes).lower() == "all":
+        max_notes = None
+    else:
+        try:
+            max_notes = int(args.max_notes)
+        except ValueError as exc:
+            raise ValueError("--max-notes must be an integer or 'all'") from exc
+        if max_notes <= 0:
+            raise ValueError("--max-notes must be a positive integer or 'all'")
+    if str(args.max_admissions).lower() == "all":
+        max_admissions = None
+    else:
+        try:
+            max_admissions = int(args.max_admissions)
+        except ValueError as exc:
+            raise ValueError("--max-admissions must be an integer or 'all'") from exc
+        if max_admissions <= 0:
+            raise ValueError("--max-admissions must be a positive integer or 'all'")
+
+    hadm_ids = list(label_map.keys())
+    if max_admissions is not None:
+        hadm_ids = hadm_ids[:max_admissions]
+    label_map = {hadm_id: label_map[hadm_id] for hadm_id in hadm_ids}
+
+    note_dataset = MIMIC3NoteDataset(
+        root=args.mimic_root,
+        target_codes=target_codes,
+        hadm_ids=hadm_ids,
+        include_categories=include_categories,
+    )
+    sample_dataset = note_dataset.set_task(
+        label_source=args.label_source,
+        label_map=label_map,
+    )
+
+    dry_run = args.dry_run or not os.environ.get("OPENAI_API_KEY")
+    model = SDOHICD9LLM(
+        target_codes=target_codes,
+        dry_run=dry_run,
+        max_notes=max_notes,
+    )
+
+    results = []
+    predicted_codes_all: List[Set[str]] = []
+    manual_codes_all: List[Set[str]] = []
+    true_codes_all: List[Set[str]] = []
+
+    for sample in sample_dataset:
+        predicted_codes, note_results = model.predict_admission_with_notes(
+            sample["notes"],
+            sample.get("note_categories"),
+            sample.get("chartdates"),
+        )
+        predicted_codes_all.append(predicted_codes)
+        visit_id = str(sample.get("visit_id", ""))
+        label_entry = label_map.get(visit_id, {"manual": set(), "true": set()})
+        manual_codes = set(label_entry["manual"])
+        true_codes = set(label_entry["true"])
+        manual_codes_all.append(manual_codes)
+        true_codes_all.append(true_codes)
+
+        results.append(
+            {
+                "visit_id": sample.get("visit_id"),
+                "patient_id": sample.get("patient_id"),
+                "num_notes": sample.get("num_notes"),
+                "text_length": sample.get("text_length"),
+                "is_gap_case": sample.get("is_gap_case"),
+                "manual_codes": ",".join(sorted(manual_codes)),
+                "true_codes": ",".join(sorted(true_codes)),
+                "predicted_codes": ",".join(sorted(predicted_codes)),
+                "note_results": json.dumps(note_results),
+            }
+        )
+
+    y_pred = np.stack(
+        [codes_to_multihot(codes, target_codes).numpy() for codes in predicted_codes_all],
+        axis=0,
+    )
+    y_manual = np.stack(
+        [codes_to_multihot(codes, target_codes).numpy() for codes in manual_codes_all],
+        axis=0,
+    )
+    y_true = np.stack(
+        [codes_to_multihot(codes, target_codes).numpy() for codes in true_codes_all],
+        axis=0,
+    )
+
+    metrics_list = [
+        "accuracy",
+        "hamming_loss",
+        "f1_micro",
+        "f1_macro",
+        "precision_micro",
+        "recall_micro",
+    ]
+    metrics_manual = multilabel_metrics_fn(y_manual, y_pred, metrics=metrics_list)
+    metrics_true = multilabel_metrics_fn(y_true, y_pred, metrics=metrics_list)
+
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    results_path = os.path.join(
+        args.output_dir, f"admission_level_results_per_note_{timestamp}.json"
+    )
+    with open(results_path, "w") as f:
+        json.dump(results, f, indent=2)
+
+    metrics_path = os.path.join(
+        args.output_dir, f"admission_level_metrics_per_note_{timestamp}.json"
+    )
+    with open(metrics_path, "w") as f:
+        json.dump(
+            {
+                "evaluation_timestamp": timestamp,
+                "processing_method": "per_note",
+                "total_admissions": len(results),
+                "dry_run": dry_run,
+                "manual_labels_metrics": metrics_manual,
+                "true_codes_metrics": metrics_true,
+            },
+            f,
+            indent=2,
+        )
+
+    print("Saved results to:", results_path)
+    print("Saved metrics to:", metrics_path)
+    print("Manual labels micro F1:", metrics_manual.get("f1_micro"))
+
+
+if __name__ == "__main__":
+    main()
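For reference, a dry-run invocation of the script above might look like this (paths are placeholders; the script also downgrades to dry-run automatically when `OPENAI_API_KEY` is unset):

```
python examples/sdoh_icd9_llm_eval.py \
    --mimic-root /data/mimic3 \
    --label-csv-path /data/sdoh_icd9_dataset.csv \
    --max-admissions 5 \
    --max-notes 2 \
    --dry-run
```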
sample.get("note_categories"), + sample.get("chartdates"), + ) + predicted_codes_all.append(predicted_codes) + visit_id = str(sample.get("visit_id", "")) + label_entry = label_map.get(visit_id, {"manual": set(), "true": set()}) + manual_codes = set(label_entry["manual"]) + true_codes = set(label_entry["true"]) + manual_codes_all.append(manual_codes) + true_codes_all.append(true_codes) + + results.append( + { + "visit_id": sample.get("visit_id"), + "patient_id": sample.get("patient_id"), + "num_notes": sample.get("num_notes"), + "text_length": sample.get("text_length"), + "is_gap_case": sample.get("is_gap_case"), + "manual_codes": ",".join(sorted(manual_codes)), + "true_codes": ",".join(sorted(true_codes)), + "predicted_codes": ",".join(sorted(predicted_codes)), + "note_results": json.dumps(note_results), + } + ) + + y_pred = np.stack( + [codes_to_multihot(codes, target_codes).numpy() for codes in predicted_codes_all], + axis=0, + ) + y_manual = np.stack( + [codes_to_multihot(codes, target_codes).numpy() for codes in manual_codes_all], + axis=0, + ) + y_true = np.stack( + [codes_to_multihot(codes, target_codes).numpy() for codes in true_codes_all], + axis=0, + ) + + metrics_list = [ + "accuracy", + "hamming_loss", + "f1_micro", + "f1_macro", + "precision_micro", + "recall_micro", + ] + metrics_manual = multilabel_metrics_fn(y_manual, y_pred, metrics=metrics_list) + metrics_true = multilabel_metrics_fn(y_true, y_pred, metrics=metrics_list) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + os.makedirs(args.output_dir, exist_ok=True) + + results_path = os.path.join( + args.output_dir, f"admission_level_results_per_note_{timestamp}.json" + ) + with open(results_path, "w") as f: + json.dump(results, f, indent=2) + + metrics_path = os.path.join( + args.output_dir, f"admission_level_metrics_per_note_{timestamp}.json" + ) + with open(metrics_path, "w") as f: + json.dump( + { + "evaluation_timestamp": timestamp, + "processing_method": "per_note", + "total_admissions": len(results), + "dry_run": dry_run, + "manual_labels_metrics": metrics_manual, + "true_codes_metrics": metrics_true, + }, + f, + indent=2, + ) + + print("Saved results to:", results_path) + print("Saved metrics to:", metrics_path) + print("Manual labels micro F1:", metrics_manual.get("f1_micro")) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 5176fdb42..0eb6babca 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -57,7 +57,7 @@ def __init__(self, *args, **kwargs): from .eicu import eICUDataset from .isruc import ISRUCDataset from .medical_transcriptions import MedicalTranscriptionsDataset -from .mimic3 import MIMIC3Dataset +from .mimic3 import MIMIC3Dataset, MIMIC3NoteDataset from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset from .mimicextract import MIMICExtractDataset from .omop import OMOPDataset @@ -66,6 +66,7 @@ def __init__(self, *args, **kwargs): from .sleepedf import SleepEDFDataset from .bmd_hs import BMDHSDataset from .support2 import Support2Dataset +from .sdoh_icd9 import SDOHICD9NotesDataset from .tcga_prad import TCGAPRADDataset from .splitter import ( split_by_patient, diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py index 22ca79d5c..d8c2f117c 100644 --- a/pyhealth/datasets/mimic3.py +++ b/pyhealth/datasets/mimic3.py @@ -1,11 +1,15 @@ import logging import warnings from pathlib import Path -from typing import List, Optional +from typing import 
diff --git a/pyhealth/datasets/mimic3.py b/pyhealth/datasets/mimic3.py
index 22ca79d5c..d8c2f117c 100644
--- a/pyhealth/datasets/mimic3.py
+++ b/pyhealth/datasets/mimic3.py
@@ -1,11 +1,15 @@
 import logging
 import warnings
 from pathlib import Path
-from typing import List, Optional
+from typing import Dict, Iterable, List, Optional, Sequence, Set
 
 import narwhals as pl
+import pandas as pd
 
 from .base_dataset import BaseDataset
+from .sample_dataset import SampleDataset, create_sample_dataset
+from ..tasks.sdoh_icd9_detection import SDOHICD9AdmissionTask
+from ..tasks.sdoh_utils import TARGET_CODES
 
 logger = logging.getLogger(__name__)
 
@@ -85,3 +89,156 @@ def preprocess_noteevents(self, df: pl.LazyFrame) -> pl.LazyFrame:
             .alias("charttime")
         )
         return df
+
+
+class MIMIC3NoteDataset:
+    """Note-only loader for MIMIC-III NOTEEVENTS.
+
+    This loader streams NOTEEVENTS in chunks and optionally filters by HADM_IDs
+    and note categories. Use set_task() to convert admissions into samples.
+    """
+
+    def __init__(
+        self,
+        noteevents_path: Optional[str] = None,
+        root: Optional[str] = None,
+        target_codes: Optional[Sequence[str]] = None,
+        hadm_ids: Optional[Iterable[str]] = None,
+        include_categories: Optional[Sequence[str]] = None,
+        chunksize: int = 200_000,
+        dataset_name: Optional[str] = None,
+    ) -> None:
+        """Initialize the note-only loader.
+
+        Args:
+            noteevents_path: Path to NOTEEVENTS CSV/CSV.GZ.
+            root: MIMIC-III root directory (used if noteevents_path is None).
+            target_codes: Target ICD-9 codes for label vector construction.
+            hadm_ids: Optional admission IDs to include.
+            include_categories: Optional NOTE_CATEGORY values to include.
+            chunksize: Number of rows per chunk read.
+            dataset_name: Optional dataset name for the SampleDataset.
+        """
+        if noteevents_path is None:
+            if root is None:
+                raise ValueError("root is required when noteevents_path is not set.")
+            noteevents_path = str(Path(root) / "NOTEEVENTS.csv.gz")
+        self.noteevents_path = noteevents_path
+        self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES)
+        self.hadm_ids = {str(x) for x in hadm_ids} if hadm_ids is not None else None
+        self.include_categories = (
+            {cat.strip().upper() for cat in include_categories}
+            if include_categories
+            else None
+        )
+        self.chunksize = chunksize
+        self.dataset_name = dataset_name or "mimic3_note"
+
+    def _load_notes(self) -> Dict[str, List[Dict]]:
+        """Load and group note events by admission."""
+        keep_cols = {
+            "SUBJECT_ID",
+            "HADM_ID",
+            "CHARTDATE",
+            "CHARTTIME",
+            "CATEGORY",
+            "TEXT",
+        }
+
+        notes_by_hadm: Dict[str, List[Dict]] = {}
+        for chunk in pd.read_csv(
+            self.noteevents_path,
+            chunksize=self.chunksize,
+            usecols=lambda c: c.upper() in keep_cols,
+            dtype={"SUBJECT_ID": "string", "HADM_ID": "string"},
+            low_memory=False,
+        ):
+            chunk.columns = [c.upper() for c in chunk.columns]
+            if self.hadm_ids is not None:
+                chunk = chunk[chunk["HADM_ID"].astype("string").isin(self.hadm_ids)]
+            if chunk.empty:
+                continue
+            if self.include_categories is not None:
+                chunk = chunk[
+                    chunk["CATEGORY"].astype("string")
+                    .str.upper()
+                    .isin(self.include_categories)
+                ]
+            if chunk.empty:
+                continue
+
+            charttime = pd.to_datetime(chunk["CHARTTIME"], errors="coerce")
+            chartdate = pd.to_datetime(chunk["CHARTDATE"], errors="coerce")
+            timestamp = charttime.fillna(chartdate)
+
+            for row, ts in zip(chunk.itertuples(index=False), timestamp):
+                hadm_id = str(row.HADM_ID)
+                entry = {
+                    "patient_id": str(row.SUBJECT_ID) if pd.notna(row.SUBJECT_ID) else "",
+                    "text": row.TEXT if pd.notna(row.TEXT) else "",
+                    "category": row.CATEGORY if pd.notna(row.CATEGORY) else "",
+                    "timestamp": ts,
+                }
+                notes_by_hadm.setdefault(hadm_id, []).append(entry)
+
+        return notes_by_hadm
+
+    def _build_admissions(self) -> List[Dict]:
+        """Build admission-level note bundles with timestamps and categories."""
+        notes_by_hadm = self._load_notes()
+        admissions: List[Dict] = []
+        for hadm_id, notes in notes_by_hadm.items():
+            notes.sort(
+                key=lambda x: x["timestamp"]
+                if pd.notna(x["timestamp"])
+                else pd.Timestamp.min
+            )
+            note_texts = [str(n["text"]) for n in notes]
+            note_categories = [str(n["category"]) for n in notes]
+            chartdates = [
+                n["timestamp"].strftime("%Y-%m-%d") if pd.notna(n["timestamp"]) else "Unknown"
+                for n in notes
+            ]
+
+            admissions.append(
+                {
+                    "visit_id": hadm_id,
+                    "patient_id": notes[0]["patient_id"],
+                    "notes": note_texts,
+                    "note_categories": note_categories,
+                    "chartdates": chartdates,
+                    "num_notes": len(note_texts),
+                    "text_length": int(sum(len(note) for note in note_texts)),
+                }
+            )
+
+        logger.info("Loaded %d admissions from NOTEEVENTS", len(admissions))
+        return admissions
+
+    def set_task(
+        self,
+        task: Optional[SDOHICD9AdmissionTask] = None,
+        label_source: str = "manual",
+        label_map: Optional[Dict[str, Dict[str, Set[str]]]] = None,
+        in_memory: bool = True,
+    ) -> SampleDataset:
+        """Apply a task to admissions and return a SampleDataset."""
+        if task is None:
+            task = SDOHICD9AdmissionTask(
+                target_codes=self.target_codes,
+                label_source=label_source,
+                label_map=label_map,
+            )
+
+        samples: List[Dict] = []
+        for admission in self._build_admissions():
+            samples.extend(task(admission))
+
+        return create_sample_dataset(
+            samples=samples,
+            input_schema=task.input_schema,
+            output_schema=task.output_schema,
+            dataset_name=self.dataset_name,
+            task_name=task.task_name,
+            in_memory=in_memory,
+        )
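A minimal usage sketch for the loader above, assuming a local NOTEEVENTS file and a hand-built `label_map` (both hypothetical here; the shape mirrors the tests at the end of this diff):

```python
from pyhealth.datasets import MIMIC3NoteDataset

# Hypothetical labels: HADM_ID -> {"manual": codes, "true": codes}
label_map = {"10": {"manual": {"V600"}, "true": set()}}

dataset = MIMIC3NoteDataset(
    noteevents_path="NOTEEVENTS.csv",  # placeholder path
    hadm_ids=["10"],                   # stream only labeled admissions
    include_categories=["Physician"],  # optional category filter
)
samples = dataset.set_task(label_source="manual", label_map=label_map)
for sample in samples:
    print(sample["visit_id"], sample["num_notes"], sample["label"])
```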
categories.""" + notes_by_hadm = self._load_notes() + admissions: List[Dict] = [] + for hadm_id, notes in notes_by_hadm.items(): + notes.sort( + key=lambda x: x["timestamp"] + if pd.notna(x["timestamp"]) + else pd.Timestamp.min + ) + note_texts = [str(n["text"]) for n in notes] + note_categories = [str(n["category"]) for n in notes] + chartdates = [ + n["timestamp"].strftime("%Y-%m-%d") if pd.notna(n["timestamp"]) else "Unknown" + for n in notes + ] + + admissions.append( + { + "visit_id": hadm_id, + "patient_id": notes[0]["patient_id"], + "notes": note_texts, + "note_categories": note_categories, + "chartdates": chartdates, + "num_notes": len(note_texts), + "text_length": int(sum(len(note) for note in note_texts)), + } + ) + + logger.info("Loaded %d admissions from NOTEEVENTS", len(admissions)) + return admissions + + def set_task( + self, + task: Optional[SDOHICD9AdmissionTask] = None, + label_source: str = "manual", + label_map: Optional[Dict[str, Dict[str, Set[str]]]] = None, + in_memory: bool = True, + ) -> SampleDataset: + """Apply a task to admissions and return a SampleDataset.""" + if task is None: + task = SDOHICD9AdmissionTask( + target_codes=self.target_codes, + label_source=label_source, + label_map=label_map, + ) + + samples: List[Dict] = [] + for admission in self._build_admissions(): + samples.extend(task(admission)) + + return create_sample_dataset( + samples=samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name=self.dataset_name, + task_name=task.task_name, + in_memory=in_memory, + ) diff --git a/pyhealth/datasets/sdoh_icd9.py b/pyhealth/datasets/sdoh_icd9.py new file mode 100644 index 000000000..e12a695a2 --- /dev/null +++ b/pyhealth/datasets/sdoh_icd9.py @@ -0,0 +1,111 @@ +import logging +from typing import Dict, List, Optional, Sequence + +import pandas as pd + +from .sample_dataset import SampleDataset, create_sample_dataset +from ..tasks.sdoh_icd9_detection import SDOHICD9AdmissionTask +from ..tasks.sdoh_utils import TARGET_CODES, parse_codes + +logger = logging.getLogger(__name__) + + +REQUIRED_COLUMNS = { + "HADM_ID", + "SUBJECT_ID", + "NOTE_CATEGORY", + "CHARTDATE", + "FULL_TEXT", + "ADMISSION_TRUE_CODES", + "ADMISSION_MANUAL_LABELS", +} + + +class SDOHICD9NotesDataset: + """CSV-backed dataset for SDOH ICD-9 V-code detection from notes.""" + + def __init__( + self, + csv_path: str, + dataset_name: Optional[str] = None, + target_codes: Optional[Sequence[str]] = None, + include_categories: Optional[Sequence[str]] = None, + ) -> None: + self.csv_path = csv_path + self.dataset_name = dataset_name or "sdoh_icd9_notes" + self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES) + self.include_categories = ( + {cat.strip().upper() for cat in include_categories} + if include_categories + else None + ) + self._admissions = self._load_admissions() + + def _load_admissions(self) -> List[Dict]: + df = pd.read_csv(self.csv_path) + missing = REQUIRED_COLUMNS - set(df.columns) + if missing: + raise ValueError(f"Missing required columns: {sorted(missing)}") + + admissions: List[Dict] = [] + df["CHARTDATE"] = pd.to_datetime(df["CHARTDATE"], errors="coerce") + if self.include_categories is not None: + df = df[ + df["NOTE_CATEGORY"].astype("string") + .str.upper() + .isin(self.include_categories) + ] + for hadm_id, group in df.groupby("HADM_ID"): + group = group.sort_values("CHARTDATE") + first = group.iloc[0] + notes = [str(text).strip() for text in group["FULL_TEXT"].fillna("")] + chartdates = [ + dt.strftime("%Y-%m-%d") 
if pd.notna(dt) else "Unknown" + for dt in group["CHARTDATE"] + ] + admission = { + "visit_id": str(hadm_id), + "patient_id": str(first["SUBJECT_ID"]), + "is_gap_case": first.get("IS_GAP_CASE"), + "note_categories": [ + str(cat).strip() for cat in group["NOTE_CATEGORY"].fillna("") + ], + "chartdates": chartdates, + "notes": notes, + "num_notes": len(notes), + "text_length": int(sum(len(note) for note in notes)), + "manual_codes": parse_codes( + first["ADMISSION_MANUAL_LABELS"], self.target_codes + ), + "true_codes": parse_codes( + first["ADMISSION_TRUE_CODES"], self.target_codes + ), + } + admissions.append(admission) + logger.info("Loaded %d admissions from %s", len(admissions), self.csv_path) + return admissions + + def set_task( + self, + task: Optional[SDOHICD9AdmissionTask] = None, + label_source: str = "manual", + in_memory: bool = True, + ) -> SampleDataset: + if task is None: + task = SDOHICD9AdmissionTask( + target_codes=self.target_codes, + label_source=label_source, + ) + + samples: List[Dict] = [] + for admission in self._admissions: + samples.extend(task(admission)) + + return create_sample_dataset( + samples=samples, + input_schema=task.input_schema, + output_schema=task.output_schema, + dataset_name=self.dataset_name, + task_name=task.task_name, + in_memory=in_memory, + ) diff --git a/pyhealth/metrics/__init__.py b/pyhealth/metrics/__init__.py index da8da0f5b..fa9578ed8 100644 --- a/pyhealth/metrics/__init__.py +++ b/pyhealth/metrics/__init__.py @@ -1,12 +1,19 @@ from .binary import binary_metrics_fn from .drug_recommendation import ddi_rate_score -from .interpretability import ( - ComprehensivenessMetric, - Evaluator, - RemovalBasedMetric, - SufficiencyMetric, - evaluate_attribution, -) +try: + from .interpretability import ( + ComprehensivenessMetric, + Evaluator, + RemovalBasedMetric, + SufficiencyMetric, + evaluate_attribution, + ) +except Exception: # pragma: no cover - optional dependencies + ComprehensivenessMetric = None + Evaluator = None + RemovalBasedMetric = None + SufficiencyMetric = None + evaluate_attribution = None from .multiclass import multiclass_metrics_fn from .multilabel import multilabel_metrics_fn @@ -17,11 +24,11 @@ __all__ = [ "binary_metrics_fn", "ddi_rate_score", - "ComprehensivenessMetric", - "SufficiencyMetric", - "RemovalBasedMetric", - "Evaluator", - "evaluate_attribution", + "ComprehensivenessMetric", + "SufficiencyMetric", + "RemovalBasedMetric", + "Evaluator", + "evaluate_attribution", "multiclass_metrics_fn", "multilabel_metrics_fn", "ranking_metrics_fn", diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 3c0b5384d..b34100b33 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -37,3 +37,4 @@ from .vae import VAE from .sdoh import SdohClassifier from .medlink import MedLink +from .sdoh_icd9_llm import SDOHICD9LLM diff --git a/pyhealth/models/sdoh_icd9_llm.py b/pyhealth/models/sdoh_icd9_llm.py new file mode 100644 index 000000000..ebbd60947 --- /dev/null +++ b/pyhealth/models/sdoh_icd9_llm.py @@ -0,0 +1,382 @@ +import hashlib +import os +import re +import time +from typing import Iterable, List, Optional, Sequence, Set, Tuple + +from pyhealth.tasks.sdoh_utils import TARGET_CODES + + +PROMPT_TEMPLATE = """\ +# SDOH Medical Coding Task + +You are an expert medical coder. Your job is to find Social Determinants of Health (SDOH) codes in clinical notes. + +## CRITICAL PREDICTION RULES: +1. **IF YOU FIND EVIDENCE FOR A CODE → YOU MUST PREDICT THAT CODE** +2. 
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 3c0b5384d..b34100b33 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -37,3 +37,4 @@
 from .vae import VAE
 from .sdoh import SdohClassifier
 from .medlink import MedLink
+from .sdoh_icd9_llm import SDOHICD9LLM
diff --git a/pyhealth/models/sdoh_icd9_llm.py b/pyhealth/models/sdoh_icd9_llm.py
new file mode 100644
index 000000000..ebbd60947
--- /dev/null
+++ b/pyhealth/models/sdoh_icd9_llm.py
@@ -0,0 +1,382 @@
+import hashlib
+import os
+import re
+import time
+from typing import Iterable, List, Optional, Sequence, Set, Tuple
+
+from pyhealth.tasks.sdoh_utils import TARGET_CODES
+
+
+PROMPT_TEMPLATE = """\
+# SDOH Medical Coding Task
+
+You are an expert medical coder. Your job is to find Social Determinants of Health (SDOH) codes in clinical notes.
+
+## CRITICAL PREDICTION RULES:
+1. **IF YOU FIND EVIDENCE FOR A CODE → YOU MUST PREDICT THAT CODE**
+2. **BE CONSERVATIVE**: Only predict codes with STRONG evidence
+3. **Evidence detection = Automatic code prediction**
+
+## CODE HIERARCHY - ABSOLUTELY CRITICAL:
+**HOMELESS (V600) COMPLETELY EXCLUDES V602:**
+- ❌ If you see homelessness → ONLY predict V600, NEVER V602
+- ❌ "Homeless and can't afford X" → STILL only V600
+- ✅ V602 is EXCLUSIVELY for housed people with explicit financial hardship
+
+**V602 ULTRA-STRICT REQUIREMENTS:**
+- ✅ Must have exact quotes: "cannot afford medications", "no money for food"
+- ❌ NEVER infer from: homelessness, unemployment, substance use, mental health
+- ❌ NEVER predict from: "poor", "disadvantaged", social circumstances
+
+
+## Available ICD-9 Codes:
+
+**Housing & Resources:**
+- V600: Homeless (living on streets, shelters, hotels, motels, temporary housing)
+- V602: Cannot afford basic needs. **ULTRA-STRICT RULES**:
+  - ✅ ONLY if EXPLICIT financial statements: "cannot afford medications", "unable to pay for treatment", "no money for food", "financial hardship preventing care"
+  - ❌ NEVER predict from social circumstances: substance abuse, unemployment, mental health, divorce
+  - ❌ NEVER predict from discharge inventory: "no money/wallet returned"
+  - ❌ NEVER predict for insurance/benefits mentions: "on disability", "has Medicaid"
+- V604: No family/caregiver available. **CRITICAL RULES**:
+  - ❌ NOT just "lives alone" or "elderly"
+  - ✅ ONLY if: "no family contact" AND "no support" AND "no one to help"
+  - ❌ NOT for: "lives alone but daughter visits", "independent"
+
+**Employment & Legal:**
+- V620: Unemployed (current employment status) — **EXPLICIT ONLY**
+  - ✅ **Predict only if one of these exact employment-status phrases appears (word boundaries, case-insensitive):**
+    - unemploy(ed|ment)
+    - jobless | no job
+    - out of work | without work (employment context only; see exclusions)
+    - not working (employment context only; see exclusions)
+    - between jobs
+    - on unemployment | receiving unemployment
+    - Recent loss: laid off | fired | terminated | lost (my|his|her|their)? job
+    - Label–value fields: Employment|Employment status|Work: → unemployed|none|not working ⇒ V620
+  - ❌ Do NOT infer from former-only phrases: used to work | formerly | previously employed | last worked in ... UNLESS illness/disability is given as the reason for stopping work
+  - ❌ **Contradictions (block V620 if present nearby, ±1 sentence)**: employed | works at/as | working | return(ed) to work | full-time/part-time | self-employed | contractor | on leave/LOA | work excuse | at work today
+  - ❌ **Non-employment uses of "work" (never trigger)**: work of breathing | workup/work-up | social work | line not working | device not working | therapy not working | meds not working
+  - ❌ **Context rule for "not/without work"**: Only treat as unemployment when it clearly refers to employment status (e.g., in Social History/Employment section or followed by "due to X" re job). Otherwise, do not code
+  - ❌ **Exclusions**: retired | student | stay-at-home parent | on medical leave — unless an explicit unemployment phrase above is present
+  - ✅ Include cases where patient stopped working or is unable to work due to illness/disability (including SSDI/SSI).
+- V625: Legal problems. **ACTIVE LEGAL PROCESSES ONLY**:
+  - ✅ Criminal: Arrest, jail, prison, parole, probation, active charges, bail, or court case pending
+  - ✅ Civil: Court-ordered custody, restraining order, supervised release, or active litigation
+  - ✅ Guardianship: Legal capacity hearing, court-appointed guardian, or power of attorney via court
+  - ✅ Child welfare: DCF/CPS removed children, placed them in care, or court-ordered custody change
+  - ❌ Do NOT predict V625 for:
+    - History of crime, substance use, or homelessness without active legal process
+    - Drug use history alone (marijuana, cocaine, etc.) without legal consequences
+    - Psych history, social issues, or missed appointments without legal involvement
+    - DCF "awareness" without custody change or legal action
+
+**Health History:**
+- V1541: Physical/sexual abuse (violence BY another person). **CRITICAL**:
+  - ✅ Physical violence, sexual abuse, assault, domestic violence, rape
+  - ❌ NOT accidents, falls, fights where patient was aggressor
+- V1542: Pure emotional/psychological abuse. **MUTUALLY EXCLUSIVE WITH V1541**:
+  - ✅ **ONLY if NO physical/sexual abuse mentioned AND explicit emotional abuse:**
+  - ✅ Witnessed violence: "witnessed violence as a child", "saw domestic violence"
+  - ✅ Verbal abuse: "verbal abuse", "emotionally abusive", "psychological abuse"
+  - ✅ Emotional manipulation: "jealous and controlling", "isolation", "intimidation"
+  - ❌ **NEVER predict V1542 if ANY physical/sexual abuse mentioned (use V1541 instead)**
+  - ❌ **NEVER predict both V1541 AND V1542 for same patient**
+  - ❌ Depression, anxiety, suicidal ideation alone without explicit abuse = NO CODE
+  - ❌ "History of abuse" without specifying type = NO CODE
+  - ❌ Psychiatric history alone = NO CODE
+- V1584: Past asbestos exposure
+
+**Family History:**
+- V6141: Family alcoholism (family member drinks) — PREDICT if any kinship term + alcohol term appear together.
+  - ✅ Kinship terms: father, mother, dad, mom, brother, sister, son, daughter, uncle, aunt, cousin, grand*; Headers: FH, FHx, Family History
+  - ✅ Alcohol terms (case-insensitive, synonyms OK): ETOH/EtOH/etoh, alcohol, alcoholism, AUD, alcohol use disorder, EtOH abuse, EtOH dependence, "alcohol problem(s)", "drinks heavily", "alcoholic"
+  - ✅ Mixed substance OK: If text says "drug and alcohol problems," still predict V6141
+  - ✅ Outside headers OK: If kinship + alcohol appear in same clause/sentence or within ±1 line (e.g., "Pt's father ... has hx of etoh"), predict V6141
+  - ✅ Examples to capture: "FH: Father – ETOH", "Mother has h/o alcoholism", "Father with depression and alcoholism", "Multiple family members with ETOH abuse ... (cousin, sister, uncle, aunt, father)", "Both brothers have drug and alcohol problems"
+  - ❌ Negations: Do not predict if explicitly denied (e.g., "denies family history of alcoholism")
+  - ❌ NEVER for PATIENT'S own history: "history of alcohol abuse", "with a history of alcohol use", "past medical history significant for heavy EtOH abuse", "patient alcoholic", "ETOH abuse"
+
+## ENHANCED NEGATIVE EXAMPLES:
+
+**V602 FALSE POSITIVES TO AVOID:**
+❌ "Homeless patient" → Predict V600 ONLY, NEVER V602
+❌ "Lives in shelter, gets food stamps" → V600 ONLY, NEVER V602
+❌ "Homeless, on disability" → V600 ONLY, NEVER V602
+❌ "No permanent address, has Medicaid" → V600 ONLY, NEVER V602
+❌ "Homeless and can't afford medications" → V600 ONLY, NEVER V602
+❌ "Unemployed alcoholic" → V620 (unemployment is explicit), NEVER V602
+❌ "Lives in poverty" → NEVER V602 (too vague)
+❌ "Financial strain from divorce" → NEVER V602 (circumstantial)
+
+**V604 FALSE POSITIVES TO AVOID:**
+❌ "82 year old lives alone" → NO CODE unless no support mentioned
+❌ "Lives by herself" → NO CODE unless isolation confirmed
+❌ "Widowed, lives alone, son calls daily" → NO CODE (has support)
+
+**V1542 FALSE POSITIVES TO AVOID:**
+❌ "History of physical and sexual abuse" → V1541 ONLY (physical trumps emotional)
+❌ "PTSD from rape at age 7" → V1541 ONLY (sexual abuse)
+❌ "Childhood sexual abuse by uncle" → V1541 ONLY (sexual abuse)
+❌ "History of domestic abuse" → V1541 ONLY (physical abuse)
+❌ "Depression and anxiety" → NO CODE (psychiatric symptoms alone)
+❌ "Suicide attempts" → NO CODE (mental health history alone)
+❌ "History of abuse" → NO CODE (unspecified type)
+❌ "Recent argument with partner" → NO CODE (relationship conflict)
+
+**V1542 TRUE POSITIVES TO CAPTURE:**
+✅ "Witnessed violence as a child" → V1542 (pure emotional trauma, no physical)
+✅ "Emotionally abusive relationship for 14 years" → V1542 (explicit emotional abuse)
+✅ "Verbal abuse from controlling partner" → V1542 (explicit emotional abuse)
+✅ "Jealous and controlling behavior" → V1542 (emotional manipulation)
+
+## CONFIDENCE RULES:
+
+**HIGH CONFIDENCE (Predict):**
+- Direct statement of condition
+- Multiple supporting evidence pieces
+- Explicit language matching code definition
+
+**LOW CONFIDENCE (Don't Predict):**
+- Ambiguous language
+- Single weak indicator
+- Contradictory evidence
+
+## Key Rules:
+
+1. **Precision over Recall**: Better to miss a code than falsely predict
+2. **Evidence-Driven**: Strong evidence required for prediction
+3. **Multiple codes allowed**: But each needs independent evidence
+4. **Conservative approach**: When in doubt, don't predict
+
+## Output Format:
+Return applicable codes separated by commas, or "None" if no codes apply.
+
+Example:
+```
+V600, V625
+```
+
+or if no codes apply:
+```
+None
+```
+
+---
+
+**Clinical Note to Analyze:**
+{note}
+"""
+
+
+def _load_prompt_template() -> str:
+    return PROMPT_TEMPLATE
+
+
+class SDOHICD9LLM:
+    """Admission-level SDOH ICD-9 V-code detector using an LLM.
+
+    This model sends each note for an admission to an LLM, parses predicted
+    ICD-9 V-codes, and aggregates the codes across notes (set union).
+
+    Notes:
+        - Use ``dry_run=True`` to skip LLM calls while exercising the pipeline.
+        - Predictions are derived entirely from the LLM response parsing logic.
+
+    Examples:
+        >>> from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM
+        >>> notes = [
+        ...     "Pt is homeless and has no family support.",
+        ...     "Social work consulted for housing resources.",
+        ... ]
+        >>> model = SDOHICD9LLM(dry_run=True, max_notes=2)
+        >>> codes, note_results = model.predict_admission_with_notes(notes)
+        >>> codes
+        set()
+
+        >>> model = SDOHICD9LLM(model_name="gpt-4o-mini", max_notes=1)
+        >>> codes, note_results = model.predict_admission_with_notes(notes)
+    """
+
+    def __init__(
+        self,
+        target_codes: Optional[Sequence[str]] = None,
+        model_name: str = "gpt-4o-mini",
+        prompt_template: Optional[str] = None,
+        api_key: Optional[str] = None,
+        max_tokens: int = 100,
+        max_chars: int = 100000,
+        temperature: float = 0.0,
+        sleep_s: float = 0.2,
+        max_notes: Optional[int] = None,
+        dry_run: bool = False,
+    ) -> None:
+        """Initialize the LLM wrapper.
+
+        Args:
+            target_codes: Target ICD-9 codes to retain after parsing.
+            model_name: OpenAI model name.
+            prompt_template: Optional prompt template override. Uses built-in
+                SDOH template if not provided.
+            api_key: OpenAI API key. Defaults to ``OPENAI_API_KEY`` env var.
+            max_tokens: Max tokens for LLM response.
+            max_chars: Max chars from each note to send.
+            temperature: LLM temperature.
+            sleep_s: Delay between per-note requests (seconds).
+            max_notes: Optional limit on notes per admission.
+            dry_run: If True, skips API calls and returns "None" responses.
+        """
+        self.target_codes = list(target_codes) if target_codes else list(TARGET_CODES)
+        self.model_name = model_name
+        self.prompt_template = prompt_template or _load_prompt_template()
+        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
+        self.max_tokens = max_tokens
+        self.max_chars = max_chars
+        self.temperature = temperature
+        self.sleep_s = sleep_s
+        self.max_notes = max_notes
+        self.dry_run = dry_run
+        self._client = None
+
+        if not self.api_key and not self.dry_run:
+            raise EnvironmentError(
+                "OPENAI_API_KEY is required unless dry_run=True."
+            )
+
+    def _get_client(self):
+        """Initialize and cache the OpenAI client."""
+        if self._client is None:
+            from openai import OpenAI
+
+            self._client = OpenAI(api_key=self.api_key)
+        return self._client
+
+    def _call_openai_api(self, text: str) -> str:
+        """Send a single note to the LLM and return the raw response.
+
+        Args:
+            text: Note text to send.
+
+        Returns:
+            Raw string response from the LLM.
+        """
+        self._write_prompt_preview(text)
+
+        if self.dry_run:
+            return "```None```"
+
+        if len(text) > self.max_chars:
+            text = text[: self.max_chars] + "\n\n[Note truncated due to length...]"
+
+        client = self._get_client()
+        response = client.chat.completions.create(
+            model=self.model_name,
+            messages=[
+                {"role": "system", "content": self.prompt_template.format(note=text)},
+            ],
+            max_tokens=self.max_tokens,
+            temperature=self.temperature,
+        )
+        return response.choices[0].message.content.strip()
+
+    def _write_prompt_preview(self, text: str) -> None:
+        """Write the fully rendered prompt (with note) to a local file."""
+        prompt = self.prompt_template.format(note=text)
+        digest = hashlib.sha1(prompt.encode("utf-8")).hexdigest()[:10]
+        filename = f"sdoh_prompt_{digest}.txt"
+        with open(filename, "w", encoding="utf-8") as f:
+            f.write(prompt)
+
+    def _parse_llm_response(self, response: str) -> Set[str]:
+        """Parse the LLM response into a set of valid target codes.
+
+        Returns:
+            A set of ICD-9 codes intersected with ``target_codes``.
+        """
+        if not response:
+            return set()
+
+        matches = re.findall(r"```(.*?)```", response, re.DOTALL)
+        if matches:
+            response = matches[0].strip()
+        else:
+            response = response.strip()
+
+        if response.lower().strip() == "none":
+            return set()
+
+        response = response.replace("Answer:", "").replace("Codes:", "").strip()
+        for delimiter in [",", ";", " ", "\n"]:
+            if delimiter in response:
+                parts = [c.strip() for c in response.split(delimiter)]
+                break
+        else:
+            parts = [response.strip()]
+
+        valid = {code.upper() for code in parts if code.strip()}
+        target_set = {code.upper() for code in self.target_codes}
+        return {code for code in valid if code in target_set}
+
+    def _predict_admission(
+        self,
+        notes: Iterable[str],
+        note_categories: Optional[Iterable[str]] = None,
+        chartdates: Optional[Iterable[str]] = None,
+    ) -> Tuple[Set[str], List[dict]]:
+        """Run per-note predictions and aggregate codes for one admission.
+
+        Args:
+            notes: Iterable of note texts.
+            note_categories: Optional note categories aligned to ``notes``.
+            chartdates: Optional chart dates aligned to ``notes``.
+
+        Returns:
+            A tuple of (aggregated_codes, per_note_results).
+        """
+        aggregated: Set[str] = set()
+        note_results: List[dict] = []
+        categories = list(note_categories) if note_categories is not None else []
+        dates = list(chartdates) if chartdates is not None else []
+        notes_list = list(notes)
+        if self.max_notes and self.max_notes > 0:
+            notes_list = notes_list[: self.max_notes]
+            categories = categories[: self.max_notes]
+            dates = dates[: self.max_notes]
+
+        for idx, note in enumerate(notes_list):
+            category = categories[idx] if idx < len(categories) else "Unknown"
+            date = dates[idx] if idx < len(dates) else "Unknown"
+            response = self._call_openai_api(note)
+            predicted = self._parse_llm_response(response)
+            aggregated.update(predicted)
+
+            note_results.append(
+                {
+                    "category": category,
+                    "date": date,
+                    "predicted_codes": sorted(predicted),
+                    "llm_response": response,
+                }
+            )
+            if self.sleep_s > 0 and not self.dry_run:
+                time.sleep(self.sleep_s)
+
+        return aggregated, note_results
+
+    def predict_admission_with_notes(
+        self,
+        notes: Iterable[str],
+        note_categories: Optional[Iterable[str]] = None,
+        chartdates: Optional[Iterable[str]] = None,
+    ) -> Tuple[Set[str], List[dict]]:
+        """Predict codes for one admission using per-note LLM calls.
+
+        Args:
+            notes: Iterable of note texts.
+            note_categories: Optional note categories aligned to ``notes``.
+            chartdates: Optional chart dates aligned to ``notes``.
+
+        Returns:
+            A tuple of (aggregated_codes, per_note_results).
+        """
+        return self._predict_admission(notes, note_categories, chartdates)
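Since `_parse_llm_response` is the only step that turns free-form LLM output into labels, its behavior on the response shapes the prompt allows is worth illustrating (`dry_run=True` avoids the API-key requirement):

```python
from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM

model = SDOHICD9LLM(dry_run=True)
print(model._parse_llm_response("```V600, V625```"))  # {'V600', 'V625'}
print(model._parse_llm_response("None"))              # set()
print(model._parse_llm_response("Codes: V620"))       # {'V620'}
```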
+ """ + if not response: + return set() + + matches = re.findall(r"```(.*?)```", response, re.DOTALL) + if matches: + response = matches[0].strip() + else: + response = response.strip() + + if response.lower().strip() == "none": + return set() + + response = response.replace("Answer:", "").replace("Codes:", "").strip() + for delimiter in [",", ";", " ", "\n"]: + if delimiter in response: + parts = [c.strip() for c in response.split(delimiter)] + break + else: + parts = [response.strip()] + + valid = {code.upper() for code in parts if code.strip()} + target_set = {code.upper() for code in self.target_codes} + return {code for code in valid if code in target_set} + + def _predict_admission( + self, + notes: Iterable[str], + note_categories: Optional[Iterable[str]] = None, + chartdates: Optional[Iterable[str]] = None, + ) -> Tuple[Set[str], List[dict]]: + """Run per-note predictions and aggregate codes for one admission. + + Args: + notes: Iterable of note texts. + note_categories: Optional note categories aligned to ``notes``. + chartdates: Optional chart dates aligned to ``notes``. + + Returns: + A tuple of (aggregated_codes, per_note_results). + """ + aggregated: Set[str] = set() + note_results: List[dict] = [] + categories = list(note_categories) if note_categories is not None else [] + dates = list(chartdates) if chartdates is not None else [] + notes_list = list(notes) + if self.max_notes and self.max_notes > 0: + notes_list = notes_list[: self.max_notes] + categories = categories[: self.max_notes] + dates = dates[: self.max_notes] + + for idx, note in enumerate(notes_list): + category = categories[idx] if idx < len(categories) else "Unknown" + date = dates[idx] if idx < len(dates) else "Unknown" + response = self._call_openai_api(note) + predicted = self._parse_llm_response(response) + aggregated.update(predicted) + + note_results.append( + { + "category": category, + "date": date, + "predicted_codes": sorted(predicted), + "llm_response": response, + } + ) + if self.sleep_s > 0 and not self.dry_run: + time.sleep(self.sleep_s) + + return aggregated, note_results + + def predict_admission_with_notes( + self, + notes: Iterable[str], + note_categories: Optional[Iterable[str]] = None, + chartdates: Optional[Iterable[str]] = None, + ) -> Tuple[Set[str], List[dict]]: + """Predict codes for one admission using per-note LLM calls. + + Args: + notes: Iterable of note texts. + note_categories: Optional note categories aligned to ``notes``. + chartdates: Optional chart dates aligned to ``notes``. + + Returns: + A tuple of (aggregated_codes, per_note_results). + """ + return self._predict_admission(notes, note_categories, chartdates) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index 28c71f142..31cb01c0d 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -69,3 +69,4 @@ VariantClassificationClinVar, ) from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task +from .sdoh_icd9_detection import SDOHICD9AdmissionTask diff --git a/pyhealth/tasks/sdoh_icd9_detection.py b/pyhealth/tasks/sdoh_icd9_detection.py new file mode 100644 index 000000000..16cdace7b --- /dev/null +++ b/pyhealth/tasks/sdoh_icd9_detection.py @@ -0,0 +1,63 @@ +from typing import Dict, List, Optional, Sequence, Set + +import torch + +from .base_task import BaseTask +from .sdoh_utils import TARGET_CODES, codes_to_multihot + + +class SDOHICD9AdmissionTask(BaseTask): + """Builds admission-level samples for SDOH ICD-9 V-code detection. 
diff --git a/pyhealth/tasks/sdoh_utils.py b/pyhealth/tasks/sdoh_utils.py
new file mode 100644
index 000000000..78e81bd57
--- /dev/null
+++ b/pyhealth/tasks/sdoh_utils.py
@@ -0,0 +1,91 @@
+"""Utilities for SDOH ICD-9 V-code detection tasks."""
+
+from typing import Dict, Iterable, Sequence, Set
+
+import pandas as pd
+import torch
+
+
+# Standard SDOH ICD-9 V-codes
+TARGET_CODES = [
+    "V600",
+    "V602",
+    "V604",
+    "V620",
+    "V625",
+    "V1541",
+    "V1542",
+    "V1584",
+    "V6141",
+]
+
+
+def parse_codes(codes_str: object, target_codes: Sequence[str]) -> Set[str]:
+    """Parse ICD-9 codes from various string formats.
+
+    Args:
+        codes_str: String representation of codes (comma/semicolon separated).
+        target_codes: Valid target codes to filter for.
+
+    Returns:
+        Set of valid codes found in the string.
+    """
+    if pd.isna(codes_str) or str(codes_str).strip() == "":
+        return set()
+
+    # Clean string
+    codes = (
+        str(codes_str)
+        .replace("[", "")
+        .replace("]", "")
+        .replace('"', "")
+        .replace("'", "")
+    )
+
+    # Split by delimiter
+    if "," in codes:
+        values = [c.strip() for c in codes.split(",")]
+    elif ";" in codes:
+        values = [c.strip() for c in codes.split(";")]
+    else:
+        values = [codes.strip()]
+
+    # Filter to valid target codes
+    target_set = {code.upper() for code in target_codes}
+    parsed = {value.upper() for value in values if value.strip()}
+    return {code for code in parsed if code in target_set}
+
+
+def codes_to_multihot(codes: Iterable[str], target_codes: Sequence[str]) -> torch.Tensor:
+    """Convert code set to multi-hot encoding.
+
+    Args:
+        codes: Iterable of code strings.
+        target_codes: Ordered list of target codes.
+
+    Returns:
+        Multi-hot tensor aligned with target_codes.
+    """
+    code_set = {code.upper() for code in codes}
+    return torch.tensor(
+        [1.0 if code in code_set else 0.0 for code in target_codes],
+        dtype=torch.float32,
+    )
+
+
+def load_sdoh_icd9_labels(
+    csv_path: str, target_codes: Sequence[str]
+) -> Dict[str, Dict[str, Set[str]]]:
+    """Load per-admission manual/true label code sets from a labels CSV."""
+    df = pd.read_csv(csv_path)
+    if "HADM_ID" not in df.columns:
+        raise ValueError("CSV must include HADM_ID column.")
+
+    labels: Dict[str, Dict[str, Set[str]]] = {}
+    for hadm_id, group in df.groupby("HADM_ID"):
+        first = group.iloc[0]
+        labels[str(hadm_id)] = {
+            "manual": parse_codes(first.get("ADMISSION_MANUAL_LABELS"), target_codes),
+            "true": parse_codes(first.get("ADMISSION_TRUE_CODES"), target_codes),
+        }
+    return labels
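A worked example of the two helpers above, using the default TARGET_CODES ordering:

```python
from pyhealth.tasks.sdoh_utils import TARGET_CODES, codes_to_multihot, parse_codes

codes = parse_codes('["V600", "V625"]', TARGET_CODES)  # brackets/quotes stripped
print(codes)                                           # {'V600', 'V625'}
print(codes_to_multihot(codes, TARGET_CODES))
# tensor([1., 0., 0., 0., 1., 0., 0., 0., 0.])
```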
diff --git a/tests/core/test_sdoh_llm.py b/tests/core/test_sdoh_llm.py
new file mode 100644
index 000000000..2bc7e99e9
--- /dev/null
+++ b/tests/core/test_sdoh_llm.py
@@ -0,0 +1,43 @@
+from unittest import mock
+
+from base import BaseTestCase
+from pyhealth.models.sdoh_icd9_llm import SDOHICD9LLM
+
+
+class TestSdohLLM(BaseTestCase):
+    def setUp(self):
+        self.set_random_seed()
+
+    def test_llm_aggregation(self):
+        model = SDOHICD9LLM(dry_run=True)
+        notes = ["note one", "note two", "note three"]
+
+        responses = iter(
+            [
+                "```V600```",
+                "None",
+                "```V620, V625```",
+            ]
+        )
+
+        with mock.patch.object(model, "_call_openai_api", side_effect=lambda _: next(responses)):
+            with mock.patch.object(model, "_write_prompt_preview", return_value=None):
+                aggregated, note_results = model.predict_admission_with_notes(notes)
+
+        self.assertEqual({"V600", "V620", "V625"}, aggregated)
+        self.assertEqual(3, len(note_results))
+        self.assertEqual({"V600"}, set(note_results[0]["predicted_codes"]))
+        self.assertEqual(set(), set(note_results[1]["predicted_codes"]))
+        self.assertEqual({"V620", "V625"}, set(note_results[2]["predicted_codes"]))
+
+    def test_llm_max_notes(self):
+        model = SDOHICD9LLM(dry_run=True, max_notes=1)
+        notes = ["note one", "note two"]
+
+        with mock.patch.object(model, "_call_openai_api", return_value="V600") as mocked:
+            with mock.patch.object(model, "_write_prompt_preview", return_value=None):
+                aggregated, note_results = model.predict_admission_with_notes(notes)
+
+        self.assertEqual({"V600"}, aggregated)
+        self.assertEqual(1, len(note_results))
+        mocked.assert_called_once()
"2020-01-01 13:00:00", + "Radiology", + "XR chest normal", + ] + ) + writer.writerow( + [ + "2", + "20", + "2020-02-01", + "2020-02-01 09:00:00", + "Physician", + "No issues reported", + ] + ) + + label_map = { + "10": {"manual": {"V600"}, "true": set()}, + "20": {"manual": set(), "true": set()}, + } + + dataset = MIMIC3NoteDataset( + noteevents_path=noteevents_path, + hadm_ids=["10"], + include_categories=["Physician"], + ) + sample_dataset = dataset.set_task( + label_source="manual", + label_map=label_map, + ) + + self.assertEqual(1, len(sample_dataset)) + sample = next(iter(sample_dataset)) + self.assertEqual("10", sample["visit_id"]) + self.assertEqual(["Pt is homeless"], sample["notes"]) + self.assertEqual(["Physician"], sample["note_categories"]) + + expected_label = codes_to_multihot({"V600"}, TARGET_CODES).tolist() + self.assertEqual(expected_label, sample["label"].tolist()) + + def test_note_sorting_and_chartdate_fallback(self): + """Test ordering with missing charttime and chartdate fallback.""" + with tempfile.TemporaryDirectory() as tmpdir: + noteevents_path = os.path.join(tmpdir, "NOTEEVENTS.csv") + with open(noteevents_path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow( + [ + "SUBJECT_ID", + "HADM_ID", + "CHARTDATE", + "CHARTTIME", + "CATEGORY", + "TEXT", + ] + ) + writer.writerow( + [ + "1", + "10", + "2020-01-02", + "", + "Physician", + "Later note", + ] + ) + writer.writerow( + [ + "1", + "10", + "2020-01-01", + "2020-01-01 08:00:00", + "Physician", + "Earlier note", + ] + ) + + label_map = {"10": {"manual": {"V600"}, "true": set()}} + dataset = MIMIC3NoteDataset( + noteevents_path=noteevents_path, + include_categories=["Physician"], + ) + sample_dataset = dataset.set_task( + label_source="manual", + label_map=label_map, + ) + + sample = next(iter(sample_dataset)) + self.assertEqual(["Earlier note", "Later note"], sample["notes"])