diff --git a/docs/tutorials/.pages b/docs/tutorials/.pages index 875214fc..4f1079b3 100644 --- a/docs/tutorials/.pages +++ b/docs/tutorials/.pages @@ -3,5 +3,7 @@ nav: - basic_renewal_model.md - custom_randomvariables.md - hospital_admissions_model.md + - observation_processes_counts.md + - observation_processes_measurements.md - day_of_the_week.md - periodic_effects.md diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd new file mode 100644 index 00000000..6715279a --- /dev/null +++ b/docs/tutorials/observation_processes_counts.qmd @@ -0,0 +1,621 @@ +--- +title: "Observation processes for count data" +format: + gfm: + fig-width: 16 + fig-height: 10 + html: + toc: true + embed-resources: true + self-contained-math: true + code-fold: true + code-tools: true +engine: jupyter +--- + +This tutorial demonstrates how to use the `Counts` observation process to model count data such as hospital admissions, emergency department visits, or deaths. + +```{python} +# | label: setup +# | output: false +import jax.numpy as jnp +import numpy as np +import numpyro +import plotnine as p9 +import pandas as pd +from pathlib import Path +import sys +import warnings +from plotnine.exceptions import PlotnineWarning + +warnings.filterwarnings("ignore", category=PlotnineWarning) + +import matplotlib.pyplot as plt + +from pyrenew.observation import Counts, NegativeBinomialNoise, PoissonNoise +from pyrenew.deterministic import DeterministicVariable, DeterministicPMF +from pyrenew import datasets +``` + +## Overview + +Count observation processes model the lag between infections and an observed outcome such as hospital admissions, emergency department visits, confirmed cases, or deaths. +Observed data can be aggregated or available as subpopulation-level counts, which are modeled by classes `Counts` and `CountsBySubpop`, respectively. 
+ +Count observation processes transform infections into predicted counts by applying an event probability and/or ascertainment rate and convolving with a delay distribution. + +The predicted observations on day $t$ are: + +$$\lambda_t = \alpha \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ + +where: + +- $I_{t-d}$ is the number of incident (new) infections on day $t-d$ (i.e., $d$ days before day $t$) +- $\alpha$ is the rate of ascertained counts per infection (e.g., infection-to-hospital admission rate). This can model a mix of biological effects (e.g. some percentage of infections lead to hospital admissions, but not all) and reporting effects (e.g. some percentage of admissions that occur are reported, but not all). +- $p_d$ is the delay distribution from infection to observation, conditional on an infection leading to an observation +- $D$ is the maximum delay + +Discrete observations are generated by sampling from a noise distribution—e.g. Poisson or negative binomial—to model reporting variability. +Poisson assumes variance equals the mean; negative binomial accommodates the overdispersion common in surveillance data. + +**Note on terminology:** In real-world inference, incident infections are typically a *latent* (unobserved) quantity and must be estimated from observed data like hospital admissions. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce hospital admissions through convolution and sampling. + +## Hospital admissions example + +For hospital admissions data, we construct a `Counts` observation process. +The delay is the key mechanism: infections from $d$ days ago ($I_{t-d}$) contribute to today's hospital admissions ($\lambda_t$) weighted by the probability ($p_d$) that an infection leads to hospitalization after exactly $d$ days. The convolution sums these contributions across all past days. 
+ +The process generates hospital admissions by sampling from a negative binomial distribution: + +$$Y_t \sim \text{NegativeBinomial}(\mu = \lambda_t, \text{concentration} = \phi)$$ + +The concentration parameter $\phi$ (sometimes called $k$ or the dispersion parameter) controls overdispersion: as $\phi \to \infty$, the distribution approaches Poisson; smaller values allow greater overdispersion. + +We use the negative binomial distribution because real-world hospital admission counts exhibit overdispersion—the variance exceeds the mean. +The Poisson distribution assumes variance equals the mean, which is too restrictive. The negative binomial adds an overdispersion term: + +$$\text{Var}[Y_t] = \mu + \frac{\mu^2}{\phi}$$ + +In this example, we use fixed parameter values for illustration; in practice, these parameters would be estimated from data using weakly informative priors. + +## Infection-to-hospitalization delay distribution + +The delay distribution specifies the probability that an infected person is hospitalized $d$ days after infection, conditional on the infection leading to a hospitalization. +For example, if `hosp_delay_pmf[5] = 0.2`, then 20% of infections that result in hospitalization will appear as hospital admissions 5 days after infection. + +We load a delay distribution from PyRenew's datasets which peaks around day 8-9 post-infection, compute summary statistics, and plot it. 
+ +```{python} +# | label: delay-distribution +inf_hosp_int = datasets.load_infection_admission_interval() +hosp_delay_pmf = jnp.array(inf_hosp_int["probability_mass"].to_numpy()) +delay_rv = DeterministicPMF("inf_to_hosp_delay", hosp_delay_pmf) + +# Summary statistics +days = np.arange(len(hosp_delay_pmf)) +mean_delay = float(np.sum(days * hosp_delay_pmf)) +mode_delay = int(np.argmax(hosp_delay_pmf)) +sd = float(np.sqrt(np.sum(days**2 * hosp_delay_pmf) - mean_delay**2)) +print( + f"mode delay: {mode_delay}, mean delay: {mean_delay:.1f}, sd: {sd:.1f}" +) +``` + +```{python} +# | label: plot-delay-distribution +delay_df = pd.DataFrame( + {"days": days, "probability": np.array(hosp_delay_pmf)} +) + +plot_delay = ( + p9.ggplot(delay_df, p9.aes(x="days", y="probability")) + + p9.geom_col(fill="steelblue", alpha=0.7, color="black") + + p9.geom_vline( + xintercept=mode_delay, color="purple", linetype="solid", size=1 + ) + + p9.geom_vline( + xintercept=mean_delay, color="red", linetype="dashed", size=1 + ) + + p9.labs( + x="Days from infection to hospitalization", + y="Probability", + title="Infection-to-Hospitalization Delay Distribution", + ) + + p9.theme_grey() + + p9.theme(plot_title=p9.element_text(size=14, weight="bold")) + + p9.annotate( + "text", + x=mode_delay + 8, + y=max(delay_df["probability"]) * 0.95, + label=f"Mode: {mode_delay} days", + color="purple", + size=10, + ) + + p9.annotate( + "text", + x=mean_delay + 8, + y=max(delay_df["probability"]) * 0.8, + label=f"Mean: {mean_delay:.1f} days", + color="red", + size=10, + ) +) +plot_delay +``` + +## Creating a Counts observation process + +A `Counts` object takes the following arguments: + +- **`ascertainment_rate_rv`**: the probability an infection results in an observation (e.g., IHR) +- **`delay_distribution_rv`**: delay distribution from infection to observation (PMF) +- **`noise`**: noise model (`PoissonNoise()` or `NegativeBinomialNoise(concentration_rv)`) + +For hospital admissions, the ascertainment 
rate is specifically called the infection-hospitalization rate (IHR). +In this example, the percentage of infections that lead to hospitalization is treated as a fixed value, +which will allow us to see how different values affect the model. +The concentration parameter for the negative binomial noise model is also fixed. +In practice, both of these parameters would be given a somewhat informative prior and then inferred. + +```{python} +# | label: create-counts-process +# Infection-hospitalization ratio (1% of infections lead to hospitalization) +ihr_rv = DeterministicVariable("ihr", 0.01) + +# Overdispersion parameter for negative binomial +concentration_rv = DeterministicVariable("concentration", 10.0) + +# Create the observation process +hosp_process = Counts( + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(concentration_rv), +) +``` + +### Timeline alignment and lookback period + +The observation process convolves infections with a delay distribution, maintaining alignment between input and output: day $t$ in the output corresponds to day $t$ in the input. + +Hospital admissions on day $t$ depend on infections on day $t$ and on the preceding days (the length of our delay distribution minus one). The method `lookback_days()` returns the length of the delay distribution, so the first valid observation day is at index `lookback_days() - 1`. Earlier days are marked invalid. + +```{python} +# | label: helper-function +print(f"Required lookback: {hosp_process.lookback_days()} days") + + +def first_valid_observation_day(obs_process) -> int: + """Return the first day index with complete infection history for convolution.""" + return obs_process.lookback_days() - 1 +``` + +## Simulating observed hospital admissions given a single day's worth of infections + +To demonstrate how a `Counts` observation process works, we examine how infections occurring on a single day result in observed hospital admissions. 
+ + +```{python} +# | label: simulate-spike +n_days = 100 +day_one = first_valid_observation_day(hosp_process) + +# Create infections with a spike +infection_spike_day = day_one + 10 +infections = jnp.zeros(n_days) +infections = infections.at[infection_spike_day].set(2000) +``` + +We plot the infections starting from `day_one` (the first valid observation day, after the lookback period). +```{python} +# | label: plot-infections +# Plot relative to first valid observation day +n_plot_days = n_days - day_one +rel_spike_day = infection_spike_day - day_one + +infections_df = pd.DataFrame( + { + "day": np.arange(n_plot_days), + "count": np.array(infections[day_one:]), + } +) + +max_infection_count = float(jnp.max(infections[day_one:])) + +plot_infections = ( + p9.ggplot(infections_df, p9.aes(x="day", y="count")) + + p9.geom_line(color="darkblue", size=1) + + p9.geom_point(color="darkblue", size=2) + + p9.geom_vline( + xintercept=rel_spike_day, + color="darkred", + linetype="dashed", + alpha=0.5, + ) + + p9.labs(x="Day", y="Daily Infections", title="Infections (Input)") + + p9.theme_grey() + + p9.theme(plot_title=p9.element_text(size=13, weight="bold")) + + p9.annotate( + "text", + x=rel_spike_day + 2, + y=max_infection_count * 0.9, + label=f"Infection spike (day {rel_spike_day})", + color="darkred", + size=10, + ) +) +plot_infections +``` + +Because all infections occur on a single day, we can see how one day's worth of infections results in hospital admissions spread over subsequent days according to the delay distribution. + +## Predicted admissions without observation noise + +First, we compute the predicted admissions from the convolution alone, without observation noise. This is the mean of the distribution from which samples are drawn. 
+ +```{python} +# | label: predicted-no-noise +# Compute predicted admissions (convolution only, no observation noise) +from pyrenew.convolve import compute_delay_ascertained_incidence + +# Scale infections by IHR (ascertainment rate) +infections_scaled = infections * float(ihr_rv.sample()) +predicted_admissions, offset = compute_delay_ascertained_incidence( + p_observed_given_incident=1.0, + latent_incidence=infections_scaled, + delay_incidence_to_observation_pmf=hosp_delay_pmf, + pad=True, +) +``` + +```{python} +# | label: plot-predicted-no-noise +# Relative peak day for plotting +peak_day = rel_spike_day + mode_delay + +# Plot predicted admissions (x-axis: day_one = first valid observation day) +predicted_df = pd.DataFrame( + { + "day": np.arange(n_plot_days), + "admissions": np.array(predicted_admissions[day_one:]), + } +) + +max_predicted = float(predicted_df["admissions"].max()) +plot_predicted = ( + p9.ggplot(predicted_df, p9.aes(x="day", y="admissions")) + + p9.geom_line(color="purple", size=1) + + p9.geom_point(color="purple", size=1.5) + + p9.geom_vline( + xintercept=rel_spike_day, + color="darkred", + linetype="dashed", + alpha=0.5, + ) + + p9.geom_vline( + xintercept=peak_day, + color="purple", + linetype="dashed", + alpha=0.5, + ) + + p9.labs( + x="Day", + y="Hospital Admissions", + title="Predicted Hospital Admissions (Deterministic)", + ) + + p9.theme_grey() + + p9.annotate( + "text", + x=rel_spike_day, + y=max_predicted * 1.05, + label=f"Infection spike\n(day {rel_spike_day})", + color="darkred", + size=9, + ha="center", + ) + + p9.annotate( + "text", + x=peak_day, + y=max_predicted * 1.05, + label=f"Peak\n(day {peak_day})", + color="purple", + size=9, + ha="center", + ) +) +plot_predicted +``` + +The predicted admissions mirror the delay distribution, shifted by the infection spike day and scaled by the IHR. + + +## Observation Noise (Negative Binomial) + +The negative binomial distribution adds stochastic variation. 
Sampling multiple times from the same infections shows the range of possible observations: + +```{python} +# | label: sample-realizations +# Sample 50 realizations of hospital admissions from the same infection spike +n_samples = 50 +samples_list = [] + +for seed in range(n_samples): + with numpyro.handlers.seed(rng_seed=seed): + hosp_sample = hosp_process.sample(infections=infections, obs=None) + + for i, val in enumerate(hosp_sample.observed[day_one:]): + samples_list.append( + { + "day": i, + "admissions": float(val), + "sample": seed, + "type": "sampled", + } + ) + +# Add predicted values +for i, val in enumerate(predicted_admissions[day_one:]): + samples_list.append( + { + "day": i, + "admissions": float(val), + "sample": -1, + "type": "predicted", + } + ) +``` + +```{python} +# | label: plot-realizations +samples_df = pd.DataFrame(samples_list) +sampled_df = samples_df[samples_df["type"] == "sampled"] +predicted_noise_df = samples_df[samples_df["type"] == "predicted"] + +# Separate one sample to highlight +highlight_sample = 0 +other_samples_df = sampled_df[sampled_df["sample"] != highlight_sample] +highlight_df = sampled_df[sampled_df["sample"] == highlight_sample] + +plot_50_samples = ( + p9.ggplot() + + p9.geom_line( + p9.aes(x="day", y="admissions", group="sample"), + data=other_samples_df, + color="orange", + alpha=0.15, + size=0.5, + ) + + p9.geom_line( + p9.aes(x="day", y="admissions"), + data=highlight_df, + color="steelblue", + size=1, + ) + + p9.geom_line( + p9.aes(x="day", y="admissions"), + data=predicted_noise_df, + color="darkred", + size=1.2, + ) + + p9.geom_vline( + xintercept=rel_spike_day, + color="darkblue", + linetype="dashed", + alpha=0.5, + ) + + p9.labs( + x="Day", + y="Hospital Admissions", + title=f"Observation Noise: {n_samples} Samples from Same Infections", + subtitle="Blue: one realization | Orange: other samples | Dark red: predicted", + ) + + p9.theme_grey() +) +plot_50_samples +``` + +```{python} +# | label: timeline-stats +# 
Print timeline statistics +print("Timeline Analysis:") +print( + f" Infection spike on day {rel_spike_day}: {infections[infection_spike_day]:.0f} people" +) +print(f" Mode delay from infection to hospitalization: {mode_delay} days") +print( + f" Predicted hospitalization peak: day {rel_spike_day + mode_delay} (= {rel_spike_day} + {mode_delay})" +) +``` + +## Effect of the ascertainment rate + +The ascertainment rate (here, the infection-hospitalization rate or IHR) directly scales the number of predicted hospital admissions. +We compare two contrasting IHR values: **0.5%** and **2.5%**. + +```{python} +# | label: compare-ihr +# Two contrasting IHR values +ihr_values = [0.005, 0.025] +peak_value = 3000 # Peak infections +infections_decay = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) + +# Compute predicted hospital admissions (no noise) for each IHR +results_list = [] +for ihr_val in ihr_values: + infections_scaled = infections_decay * ihr_val + predicted_hosp, _ = compute_delay_ascertained_incidence( + p_observed_given_incident=1.0, + latent_incidence=infections_scaled, + delay_incidence_to_observation_pmf=hosp_delay_pmf, + pad=True, + ) + + for i, admit in enumerate(predicted_hosp[day_one:]): + results_list.append( + { + "day": i, + "admissions": float(admit), + "IHR": f"IHR = {ihr_val:.1%}", + } + ) +``` + + +```{python} +# | label: plot-ihr-comparisons +results_df = pd.DataFrame(results_list) + +plot_ihr = ( + p9.ggplot(results_df, p9.aes(x="day", y="admissions", color="IHR")) + + p9.geom_line(size=1) + + p9.scale_color_manual(values=["steelblue", "darkred"]) + + p9.labs( + x="Day", + y="Predicted Hospital Admissions", + title="Effect of IHR on Predicted Hospital Admissions", + color="Infection-Hospitalization\nrate", + ) + + p9.theme_grey() +) +plot_ihr +``` + +## Negative binomial concentration parameter + +The concentration parameter $\phi$ controls overdispersion: + +- Higher $\phi$ → less overdispersion (approaches Poisson) +- Lower $\phi$ → more 
overdispersion (noisier data) + +We compare three concentration values spanning two orders of magnitude: + +- **φ = 1**: high overdispersion (noisy) +- **φ = 10**: moderate overdispersion +- **φ = 100**: nearly Poisson (minimal noise) + +```{python} +# | label: concentration-comparisons +# Use constant infections +peak_value = 2000 +infections_constant = peak_value * jnp.ones(n_days) + +# Concentration values spanning two orders of magnitude +concentration_values = [1.0, 10.0, 100.0] +n_replicates = 10 + +# Collect results +conc_results = [] +for conc_val in concentration_values: + conc_rv_temp = DeterministicVariable("conc", conc_val) + process_temp = Counts( + ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=NegativeBinomialNoise(conc_rv_temp), + ) + + for seed in range(n_replicates): + with numpyro.handlers.seed(rng_seed=seed): + hosp_temp = process_temp.sample( + infections=infections_constant, + obs=None, + ) + + # Use relative days + for i, admit in enumerate(hosp_temp.observed[day_one:]): + conc_results.append( + { + "day": i, + "admissions": float(admit), + "concentration": f"φ = {int(conc_val)}", + "replicate": seed, + } + ) +``` + +```{python} +# | label: plot-concentration-comparisons +conc_df = pd.DataFrame(conc_results) + +# Convert to ordered categorical +conc_df["concentration"] = pd.Categorical( + conc_df["concentration"], + categories=["φ = 1", "φ = 10", "φ = 100"], + ordered=True, +) + +plot_concentration = ( + p9.ggplot(conc_df, p9.aes(x="day", y="admissions", group="replicate")) + + p9.geom_line(alpha=0.5, size=0.8, color="steelblue") + + p9.facet_wrap("~ concentration", ncol=3) + + p9.labs( + x="Day", + y="Hospital Admissions", + title="Effect of Concentration Parameter on Variability", + ) + + p9.theme_grey() +) +plot_concentration +``` + +## Swapping noise models + +To use Poisson noise instead of negative binomial, change the noise model: + +```{python} +# | label: poisson-noise +hosp_process_poisson = Counts( + 
ascertainment_rate_rv=ihr_rv, + delay_distribution_rv=delay_rv, + noise=PoissonNoise(), +) + +with numpyro.handlers.seed(rng_seed=42): + poisson_result = hosp_process_poisson.sample( + infections=infections, + obs=None, + ) + +print( + f"Sampled {len(poisson_result.observed)} days of hospital admissions with Poisson noise" +) +``` + +We can visualize the Poisson noise model using the same constant infection scenario as the concentration comparison above. Since Poisson assumes variance equals the mean, it produces less variability than the negative binomial with low concentration values. + +To see the reduction in noise, it is necessary to keep the y-axis on the same scale as in the previous plot. + +```{python} +# | label: poisson-realizations +# Sample multiple realizations with Poisson noise +n_replicates_poisson = 10 + +poisson_results = [] +for seed in range(n_replicates_poisson): + with numpyro.handlers.seed(rng_seed=seed): + poisson_temp = hosp_process_poisson.sample( + infections=infections_constant, + obs=None, + ) + + for i, admit in enumerate(poisson_temp.observed[day_one:]): + poisson_results.append( + { + "day": i, + "admissions": float(admit), + "replicate": seed, + } + ) +poisson_df = pd.DataFrame(poisson_results) + +plot_poisson = ( + p9.ggplot(poisson_df, p9.aes(x="day", y="admissions", group="replicate")) + + p9.geom_line(alpha=0.5, size=0.8, color="steelblue") + + p9.labs( + x="Day", + y="Hospital Admissions", + title="Poisson Noise Model (Variance = Mean)", + ) + + p9.theme_grey() + + p9.ylim(0, 105) +) +plot_poisson +``` diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd new file mode 100644 index 00000000..ddf89a8d --- /dev/null +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -0,0 +1,726 @@ +--- +title: "Observation processes for continuous measurements" +format: + gfm: + fig-width: 16 + fig-height: 10 + html: + toc: true + embed-resources: true + 
self-contained-math: true + code-fold: true + code-tools: true +engine: jupyter +--- + +This tutorial demonstrates how to use the `Measurements` observation process to model continuous measurement data. We first explain the general framework, then illustrate with a wastewater viral concentration example. + +```{python} +# | label: setup +# | output: false +import jax +import jax.numpy as jnp +import numpy as np +import numpyro +import matplotlib.pyplot as plt +import pandas as pd +import plotnine as p9 + +import numpyro.distributions as dist + +from pyrenew.observation import Measurements, HierarchicalNormalNoise +from pyrenew.randomvariable import DistributionalVariable +from pyrenew.deterministic import DeterministicVariable, DeterministicPMF +``` + +## The Measurements Class + +The `Measurements` class models continuous signals derived from infections. Unlike count observations (hospital admissions, deaths), measurements are continuous values that may span orders of magnitude or even be negative (e.g., log-transformed data). + +**Examples of measurement data:** + +- Wastewater viral concentrations +- Air quality pathogen levels +- Serological assay results +- Environmental sensor readings + +### The general pattern + +All measurement observation processes follow the same pattern: + +$$\text{observed} \sim \text{Noise}\bigl(\text{predicted}(\text{infections})\bigr)$$ + +where: + +1. **`_predicted_obs(infections)`**: Transforms infections into predicted measurement values (you implement this) +2. 
**Noise model**: Adds stochastic variation around predictions (provided by PyRenew) + +The `Measurements` base class provides: + +- Convolution utilities for temporal delays +- Timeline alignment between infections and observations +- Integration with hierarchical noise models +- Support for multiple sensors and subpopulations + +### Comparison with count observations + +The core convolution structure is shared with count observations, but key aspects differ: + +| Aspect | Counts | Measurements | +|--------|--------|--------------| +| Output type | Discrete counts | Continuous values | +| Output space | Linear (expected counts) | Often log-transformed | +| Noise model | Poisson or Negative Binomial | Normal (often on log scale) | +| Scaling | Ascertainment rate $\alpha \in [0,1]$ | Domain-specific | +| Subpop structure | Optional (`CountsBySubpop`) | Inherent (sensor/site effects) | + +### The noise model + +Measurement data typically exhibits **sensor-level variability**: different instruments, labs, or sampling locations have systematic biases and different precision levels. 
+ +`HierarchicalNormalNoise` models this with two per-sensor parameters: + +- **Sensor mode**: Systematic bias (additive shift) +- **Sensor SD**: Measurement precision (noise level) + +``` +observed ~ Normal(predicted + sensor_mode[sensor], sensor_sd[sensor]) +``` + +The noise model samples sensor-level parameters within a plate, so any `RandomVariable` can be used as a prior: + +```{python} +# | label: noise-model-general +# Sensor modes: zero-centered, allowing positive or negative bias +sensor_mode_rv = DistributionalVariable("sensor_mode", dist.Normal(0, 0.5)) + +# Sensor SDs: must be positive, truncated normal is a common choice +sensor_sd_rv = DistributionalVariable( + "sensor_sd", dist.TruncatedNormal(loc=0.3, scale=0.15, low=0.05) +) + +# Create noise model +noise = HierarchicalNormalNoise( + sensor_mode_rv=sensor_mode_rv, + sensor_sd_rv=sensor_sd_rv, +) +``` + +### The indexing system + +Measurement observations use three index arrays to map observations to their context: + +| Index array | Purpose | +|-------------|---------| +| `times` | Day index for each observation | +| `subpop_indices` | Which infection trajectory (subpopulation) generated each observation | +| `sensor_indices` | Which sensor made each observation (determines noise parameters) | + +This flexible indexing supports: + +- **Irregular sampling**: Observations don't need to be daily +- **Multiple sensors per subpopulation**: Different labs analyzing the same source +- **Multiple subpopulations per sensor**: One sensor serving multiple areas (less common) + +### Subclassing Measurements + +To create a measurement process for your domain, subclass `Measurements` and implement: + +1. **`_predicted_obs(infections)`**: Transform infections to predicted values +2. **`validate()`**: Check parameter validity +3. 
**`lookback_days()`**: Return the temporal PMF length + +```python +class MyMeasurement(Measurements): + def __init__(self, temporal_pmf_rv, noise, my_scaling_param): + super().__init__(temporal_pmf_rv=temporal_pmf_rv, noise=noise) + self.my_scaling_param = my_scaling_param + + def _predicted_obs(self, infections): + # Your domain-specific transformation here + pmf = self.temporal_pmf_rv() + # ... convolve, scale, transform ... + return predicted_values + + def validate(self): + pmf = self.temporal_pmf_rv() + self._validate_pmf(pmf, "temporal_pmf_rv") + + def lookback_days(self): + return len(self.temporal_pmf_rv()) +``` + + +## Measurement Example: Wastewater + +To illustrate the framework, we specify a wastewater viral concentration observation process, +based on the [PyRenew-HEW](https://github.com/cdcgov/pyrenew-hew) family of models. + +**The wastewater signal** + +Wastewater treatment plants measure viral RNA concentrations in sewage. +The predicted concentration depends on: + +- **Infections**: People shed virus into wastewater +- **Shedding kinetics**: Viral shedding peaks a few days after infection +- **Scaling factors**: Genome copies per infection, wastewater volume + +The predicted log-concentration on day $t$ is: + +$$\log(\lambda_t) = \log\left(\frac{G}{V} \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d\right)$$ + +where: + +- $I_{t-d}$ is infections on day $t-d$ +- $p_d$ is the shedding kinetics PMF (fraction shed on day $d$ post-infection) +- $G$ is genome copies shed per infection +- $V$ is wastewater volume per person per day + +Observations are log-concentrations with normal noise: + +$$y_t \sim \text{Normal}(\log(\lambda_t) + \text{sensor\_mode}, \text{sensor\_sd})$$ + +### Implementing the Wastewater class + +```{python} +# | label: wastewater-class +from jax.typing import ArrayLike +from pyrenew.metaclass import RandomVariable +from pyrenew.observation.noise import MeasurementNoise + + +class Wastewater(Measurements): + """ + Wastewater viral 
concentration observation process. + + Transforms site-level infections into predicted log-concentrations + via shedding kinetics convolution and genome/volume scaling. + """ + + def __init__( + self, + shedding_kinetics_rv: RandomVariable, + log10_genome_per_infection_rv: RandomVariable, + ml_per_person_per_day: float, + noise: MeasurementNoise, + ) -> None: + """ + Initialize wastewater observation process. + + Parameters + ---------- + shedding_kinetics_rv : RandomVariable + Viral shedding PMF (fraction shed each day post-infection). + log10_genome_per_infection_rv : RandomVariable + Log10 genome copies shed per infection. + ml_per_person_per_day : float + Wastewater volume per person per day (mL). + noise : MeasurementNoise + Noise model (e.g., HierarchicalNormalNoise). + """ + super().__init__(temporal_pmf_rv=shedding_kinetics_rv, noise=noise) + self.log10_genome_per_infection_rv = log10_genome_per_infection_rv + self.ml_per_person_per_day = ml_per_person_per_day + + def validate(self) -> None: + """Validate parameters.""" + shedding_pmf = self.temporal_pmf_rv() + self._validate_pmf(shedding_pmf, "shedding_kinetics_rv") + self.noise.validate() + + def lookback_days(self) -> int: + """Return shedding PMF length.""" + return len(self.temporal_pmf_rv()) + + def _predicted_obs(self, infections: ArrayLike) -> ArrayLike: + """ + Compute predicted log-concentration from infections. + + Applies shedding kinetics convolution, then scales by + genome copies and volume to get concentration. 
+ """ + shedding_pmf = self.temporal_pmf_rv() + log10_genome = self.log10_genome_per_infection_rv() + + # Convolve each site's infections with shedding kinetics + def convolve_site(site_infections): + convolved, _ = self._convolve_with_alignment( + site_infections, shedding_pmf, p_observed=1.0 + ) + return convolved + + # Apply to all subpops (infections shape: n_days x n_subpops) + shedding_signal = jax.vmap(convolve_site, in_axes=1, out_axes=1)( + infections + ) + + # Convert to concentration: genomes per mL + genome_copies = 10**log10_genome + concentration = ( + shedding_signal * genome_copies / self.ml_per_person_per_day + ) + + # Return log-concentration (what we model) + return jnp.log(concentration) +``` + +#### Configuring wastewater-specific parameters + +**Viral shedding kinetics** + +The shedding PMF describes what fraction of total viral shedding occurs on each day after infection: + +```{python} +# | label: shedding-pmf +# Peak shedding ~3 days after infection, continues for ~10 days +shedding_pmf = jnp.array( + [0.0, 0.05, 0.15, 0.25, 0.20, 0.15, 0.10, 0.05, 0.03, 0.02] +) +print(f"PMF sums to: {shedding_pmf.sum():.2f}") + +shedding_rv = DeterministicPMF("viral_shedding", shedding_pmf) + +# Summary statistics +days = np.arange(len(shedding_pmf)) +mean_shedding_day = float(np.sum(days * shedding_pmf)) +mode_shedding_day = int(np.argmax(shedding_pmf)) +print(f"Mode: {mode_shedding_day} days, Mean: {mean_shedding_day:.1f} days") +``` + +```{python} +# | label: plot-shedding +# Visualize the shedding distribution +shedding_df = pd.DataFrame( + {"days": days, "probability": np.array(shedding_pmf)} +) + +( + p9.ggplot(shedding_df, p9.aes(x="days", y="probability")) + + p9.geom_col(fill="steelblue", alpha=0.7, color="black") + + p9.geom_vline( + xintercept=mode_shedding_day, color="purple", linetype="solid", size=1 + ) + + p9.geom_vline( + xintercept=mean_shedding_day, color="red", linetype="dashed", size=1 + ) + + p9.labs( + x="Days after infection", + 
y="Fraction of total shedding", + title="Viral Shedding Kinetics", + ) + + p9.theme_grey() + + p9.theme(plot_title=p9.element_text(size=14, weight="bold")) + + p9.annotate( + "text", + x=mode_shedding_day + 2, + y=max(shedding_df["probability"]) * 0.95, + label=f"Mode: {mode_shedding_day} days", + color="purple", + size=10, + ) + + p9.annotate( + "text", + x=mean_shedding_day + 2, + y=max(shedding_df["probability"]) * 0.8, + label=f"Mean: {mean_shedding_day:.1f} days", + color="red", + size=10, + ) +) +``` + +**Genome copies and wastewater volume** + +```{python} +# | label: scaling-params +# Log10 genome copies shed per infection (typical range: 8-10) +log10_genome_rv = DeterministicVariable("log10_genome", 9.0) + +# Wastewater volume per person per day (mL) +ml_per_person_per_day = 1000.0 +``` + +### Sensor-level noise + +For wastewater, a "sensor" is a WWTP/lab pair—the combination of treatment plant and laboratory that determines measurement characteristics: + +```{python} +# | label: ww-noise-model +# Sensor-level mode: systematic differences between WWTP/lab pairs +ww_sensor_mode_rv = DistributionalVariable( + "ww_sensor_mode", dist.Normal(0, 0.5) +) + +# Sensor-level SD: measurement variability within each WWTP/lab pair +ww_sensor_sd_rv = DistributionalVariable( + "ww_sensor_sd", dist.TruncatedNormal(loc=0.3, scale=0.15, low=0.10) +) + +ww_noise = HierarchicalNormalNoise( + sensor_mode_rv=ww_sensor_mode_rv, + sensor_sd_rv=ww_sensor_sd_rv, +) +``` + +### Creating the wastewater observation process + +```{python} +# | label: create-process +ww_process = Wastewater( + shedding_kinetics_rv=shedding_rv, + log10_genome_per_infection_rv=log10_genome_rv, + ml_per_person_per_day=ml_per_person_per_day, + noise=ww_noise, +) + +print(f"Required lookback: {ww_process.lookback_days()} days") +``` + +## Simulations + +### Timeline alignment + +The observation process maintains alignment: day $t$ in output corresponds to day $t$ in input. 
The first `lookback_days() - 1` days have incomplete history and are marked invalid. + +```{python} +# | label: helper-function +def first_valid_observation_day(obs_process) -> int: + """Return the first day index with complete infection history.""" + return obs_process.lookback_days() - 1 +``` + +### Simulating observations from a single-day infection spike + +To see how infections spread into concentrations via shedding kinetics, we simulate from a single-day spike: + +```{python} +# | label: simulate-spike +n_days = 50 +day_one = first_valid_observation_day(ww_process) + +# Create infections with a spike (shape: n_days x n_subpops) +infection_spike_day = day_one + 10 +infections = jnp.zeros((n_days, 1)) # 1 subpopulation +infections = infections.at[infection_spike_day, 0].set(2000.0) + +# For plotting +rel_spike_day = infection_spike_day - day_one +n_plot_days = n_days - day_one + +# Observation times and indices +observation_days = jnp.arange(day_one, 40, dtype=jnp.int32) +n_obs = len(observation_days) + +with numpyro.handlers.seed(rng_seed=42): + ww_obs = ww_process.sample( + infections=infections, + subpop_indices=jnp.zeros(n_obs, dtype=jnp.int32), + sensor_indices=jnp.zeros(n_obs, dtype=jnp.int32), + times=observation_days, + obs=None, + n_sensors=1, + ) +``` +We plot the input infections, with days counted from the first valid observation day. 
+ +```{python} +# | label: plot-spike-infections +infections_df = pd.DataFrame( + { + "day": np.arange(n_plot_days), + "infections": np.array(infections[day_one:, 0]), + } +) + +max_infection_count = float(jnp.max(infections[day_one:])) + +plot_infections = ( + p9.ggplot(infections_df, p9.aes(x="day", y="infections")) + + p9.geom_line(color="darkblue", size=1) + + p9.geom_point(color="darkblue", size=2) + + p9.geom_vline( + xintercept=rel_spike_day, + color="darkred", + linetype="dashed", + alpha=0.5, + ) + + p9.labs(x="Day", y="Daily Infections", title="Infections (Input)") + + p9.theme_grey() + + p9.theme(plot_title=p9.element_text(size=13, weight="bold")) + + p9.annotate( + "text", + x=rel_spike_day, + y=max_infection_count * 1.05, + label=f"Infection spike\n(day {rel_spike_day})", + color="darkred", + size=10, + ha="center", + ) +) +plot_infections +``` + +### Observation noise + +Sampling multiple times from the same infections shows the range of possible observations: + +```{python} +# | label: sample-realizations +n_samples = 50 +ww_samples_list = [] + +for seed in range(n_samples): + with numpyro.handlers.seed(rng_seed=seed): + ww_result = ww_process.sample( + infections=infections, + subpop_indices=jnp.zeros(n_obs, dtype=jnp.int32), + sensor_indices=jnp.zeros(n_obs, dtype=jnp.int32), + times=observation_days, + obs=None, + n_sensors=1, + ) + for day_idx, conc in zip(observation_days, ww_result.observed): + ww_samples_list.append( + { + "day": int(day_idx) - day_one, + "log_concentration": float(conc), + "sample": seed, + } + ) + +ww_samples_df = pd.DataFrame(ww_samples_list) +``` + +```{python} +# | label: plot-sampled-concentrations +# Compute mean across samples for each day +mean_by_day = ( + ww_samples_df.groupby("day")["log_concentration"].mean().reset_index() +) +mean_by_day["sample"] = -1 + +# Relative peak day for plotting (using mode, not mean, since distribution is skewed) +peak_day = rel_spike_day + mode_shedding_day + +# Separate one sample to 
highlight +highlight_sample = 0 +other_samples_df = ww_samples_df[ww_samples_df["sample"] != highlight_sample] +highlight_df = ww_samples_df[ww_samples_df["sample"] == highlight_sample] + +# For annotation positioning +max_conc = ww_samples_df["log_concentration"].max() + +( + p9.ggplot() + + p9.geom_line( + p9.aes(x="day", y="log_concentration", group="sample"), + data=other_samples_df, + color="orange", + alpha=0.15, + size=0.5, + ) + + p9.geom_line( + p9.aes(x="day", y="log_concentration"), + data=highlight_df, + color="steelblue", + size=1, + ) + + p9.geom_line( + p9.aes(x="day", y="log_concentration"), + data=mean_by_day, + color="darkred", + size=1.2, + ) + + p9.geom_vline( + xintercept=rel_spike_day, + color="darkblue", + linetype="dashed", + alpha=0.5, + ) + + p9.geom_vline( + xintercept=peak_day, + color="darkred", + linetype="dotted", + alpha=0.7, + ) + + p9.annotate( + "text", + x=rel_spike_day, + y=max_conc * 1.05, + label=f"Infection spike\n(day {rel_spike_day})", + color="darkblue", + size=9, + ha="center", + ) + + p9.annotate( + "text", + x=peak_day, + y=max_conc * 0.98, + label=f"Expected peak\n(day {peak_day})", + color="darkred", + size=9, + ha="center", + ) + + p9.labs( + x="Day", + y="Log Viral Concentration", + title=f"Observation Noise: {n_samples} Samples from Same Infections", + subtitle="Blue: one realization | Orange: other samples | Dark red: sample mean", + ) + + p9.theme_grey() +) +``` + +### Sensor-level variability + +The previous plot showed variability from repeatedly sampling the entire observation process (resampling sensor parameters and noise each time). In practice, we have multiple physical sensors, each with fixed but unknown characteristics. + +This plot shows four sensors observing the **same infection spike**. 
Each sensor has: + +- A **sensor mode** (systematic bias): shifts all observations up or down +- A **sensor SD** (measurement precision): determines noise level around predictions + +These parameters are sampled once per sensor, then held fixed across all observations from that sensor. + +```{python} +# | label: multi-sensor +num_sensors = 4 + +# Use the same observation times and infections as the sampled-concentrations plot +sensor_obs_times = jnp.tile(observation_days, num_sensors) +sensor_ids = jnp.repeat( + jnp.arange(num_sensors, dtype=jnp.int32), len(observation_days) +) +subpop_ids = jnp.zeros(num_sensors * len(observation_days), dtype=jnp.int32) + +with numpyro.handlers.seed(rng_seed=42): + ww_multi_sensor = ww_process.sample( + infections=infections, # Same spike as before + subpop_indices=subpop_ids, + sensor_indices=sensor_ids, + times=sensor_obs_times, + obs=None, + n_sensors=num_sensors, + ) + +# Create DataFrame for plotting (using relative days) +multi_sensor_df = pd.DataFrame( + { + "day": np.array(sensor_obs_times) - day_one, + "log_concentration": np.array(ww_multi_sensor.observed), + "sensor": [f"Sensor {i}" for i in np.array(sensor_ids)], + } +) +``` + +```{python} +# | label: plot-multi-sensor +# Use same y-axis range as sampled-concentrations plot for comparison +y_min = ww_samples_df["log_concentration"].min() +y_max = ww_samples_df["log_concentration"].max() + +( + p9.ggplot( + multi_sensor_df, p9.aes(x="day", y="log_concentration", color="sensor") + ) + + p9.geom_line(size=1) + + p9.geom_point(size=2) + + p9.ylim(y_min, y_max * 1.05) + + p9.labs( + x="Day", + y="Log Viral Concentration", + title="Four Sensors Observing the Same Infection Spike", + color="Sensor", + ) + + p9.theme_grey() +) +``` + +Compare this to the previous plot: here, each colored line represents a distinct physical sensor with its own systematic bias. 
The vertical spread between sensors reflects differences in sensor modes, while the noise within each line reflects each sensor's measurement precision. During inference, these sensor-specific effects are learned from data. + +### Multiple subpopulations + +In regional surveillance, each wastewater treatment plant serves a distinct **catchment area** (subpopulation) with its own infection dynamics. The `subpop_indices` array maps each observation to the appropriate infection trajectory. + +This example shows two subpopulations with different epidemic curves: + +- **Subpopulation 0**: Slow decay (e.g., large urban area with sustained transmission) +- **Subpopulation 1**: Fast decay (e.g., smaller community with rapid burnout) + +Each subpopulation is observed by its own sensor. The observed concentrations reflect both the underlying infection differences AND the sensor-specific measurement characteristics. + +```{python} +# | label: multi-subpop +# Two subpopulations with different infection patterns +n_days_mp = 40 +infections_subpop1 = 1000.0 * jnp.exp( + -jnp.arange(n_days_mp) / 20.0 +) # Slow decay +infections_subpop2 = 2000.0 * jnp.exp( + -jnp.arange(n_days_mp) / 10.0 +) # Fast decay +infections_multi = jnp.stack([infections_subpop1, infections_subpop2], axis=1) + +# Two sensors, each observing a different subpopulation +obs_days_mp = jnp.tile(jnp.arange(10, 30, 2, dtype=jnp.int32), 2) +subpop_ids_mp = jnp.array([0] * 10 + [1] * 10, dtype=jnp.int32) +sensor_ids_mp = jnp.array([0] * 10 + [1] * 10, dtype=jnp.int32) + +with numpyro.handlers.seed(rng_seed=42): + ww_multi_subpop = ww_process.sample( + infections=infections_multi, + subpop_indices=subpop_ids_mp, + sensor_indices=sensor_ids_mp, + times=obs_days_mp, + obs=None, + n_sensors=2, + ) + +# Create DataFrame for plotting +multi_subpop_df = pd.DataFrame( + { + "day": np.array(obs_days_mp), + "log_concentration": np.array(ww_multi_subpop.observed), + "subpopulation": [f"Subpop {i}" for i in 
np.array(subpop_ids_mp)], + } +) +``` + +```{python} +# | label: plot-multi-subpop +( + p9.ggplot( + multi_subpop_df, + p9.aes(x="day", y="log_concentration", color="subpopulation"), + ) + + p9.geom_line(size=1) + + p9.geom_point(size=2) + + p9.labs( + x="Day", + y="Log Viral Concentration", + title="Two Subpopulations with Different Infection Dynamics", + color="Subpopulation", + ) + + p9.theme_grey() +) +``` + +The diverging trajectories reflect the different underlying infection curves. Subpopulation 1 starts higher but decays faster, while Subpopulation 0 maintains more sustained levels. In a full model, you would jointly infer the infection trajectories for each subpopulation while accounting for sensor-specific biases. + +--- + +## Summary + +The `Measurements` class provides: + +1. **A consistent interface** for continuous observation processes +2. **Hierarchical noise models** that capture sensor-level variability +3. **Flexible indexing** for irregular, multi-sensor, multi-subpopulation data +4. **Convolution utilities** with proper timeline alignment + +To use it for your domain: + +1. Subclass `Measurements` +2. Implement `_predicted_obs()` with your signal transformation +3. Configure appropriate priors for sensor-level effects +4. 
Use the indexing system to map observations to their context diff --git a/mkdocs.yml b/mkdocs.yml index a4da9382..1c2745eb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -64,9 +64,9 @@ markdown_extensions: - callouts extra_javascript: - - javascripts/katex.js - https://unpkg.com/katex@0/dist/katex.min.js - https://unpkg.com/katex@0/dist/contrib/auto-render.min.js + - javascripts/katex.js extra_css: - https://unpkg.com/katex@0/dist/katex.min.css diff --git a/pyproject.toml b/pyproject.toml index 01d7800a..514e6d3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dev = [ "nbconvert>=7.16.6", "pytest>=8.4.2", "pytest-cov>=6.3.0", + "plotnine>=0.14.0", "pytest-mpl>=0.17.0", "scipy>=1.16.1", ] diff --git a/pyrenew/observation/__init__.py b/pyrenew/observation/__init__.py index b0e04e69..8a0cdeab 100644 --- a/pyrenew/observation/__init__.py +++ b/pyrenew/observation/__init__.py @@ -1,9 +1,54 @@ # numpydoc ignore=GL08 +""" +Observation processes for connecting infections to observed data. +Architecture +------------ +``BaseObservationProcess`` is the abstract base. Concrete subclasses: + +- ``Counts``: Aggregate counts (admissions, deaths) +- ``CountsBySubpop``: Subpopulation-level counts +- ``Measurements``: Continuous subpopulation-level signals (e.g., wastewater) + +All observation processes implement: + +- ``sample()``: Sample observations given infections +- ``infection_resolution()``: returns ``"aggregate"`` or ``"subpop"`` +- ``lookback_days()``: returns required infection history length + +Noise models (``CountNoise``, ``MeasurementNoise``) are composable—pass them +to observation constructors to control the output distribution. 
+""" + +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.observation.count_observations import Counts, CountsBySubpop +from pyrenew.observation.measurements import Measurements from pyrenew.observation.negativebinomial import NegativeBinomialObservation +from pyrenew.observation.noise import ( + CountNoise, + HierarchicalNormalNoise, + MeasurementNoise, + NegativeBinomialNoise, + PoissonNoise, +) from pyrenew.observation.poisson import PoissonObservation +from pyrenew.observation.types import ObservationSample __all__ = [ + # Existing (kept for backward compatibility) "NegativeBinomialObservation", "PoissonObservation", + # Base classes and types + "BaseObservationProcess", + "ObservationSample", + # Noise models + "CountNoise", + "PoissonNoise", + "NegativeBinomialNoise", + "MeasurementNoise", + "HierarchicalNormalNoise", + # Observation processes + "Counts", + "CountsBySubpop", + "Measurements", ] diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py new file mode 100644 index 00000000..f415671a --- /dev/null +++ b/pyrenew/observation/base.py @@ -0,0 +1,337 @@ +# numpydoc ignore=GL08 +""" +Abstract base class for observation processes. + +Provides common functionality for observation processes that use convolution +with temporal distributions to connect infections to observed data. +""" + +from __future__ import annotations + +from abc import abstractmethod + +import jax.numpy as jnp +import numpyro +from jax.typing import ArrayLike + +from pyrenew.convolve import compute_delay_ascertained_incidence +from pyrenew.metaclass import RandomVariable + + +class BaseObservationProcess(RandomVariable): + """ + Abstract base class for observation processes that use convolution + with temporal distributions. + + This class provides common functionality for connecting infections + to observed data (e.g., hospital admissions, wastewater concentrations) + through temporal convolution operations. 
+ + Key features provided: + + - PMF validation (sum to 1, non-negative) + - Minimum observation day calculation + - Convolution wrapper with timeline alignment + - Deterministic quantity tracking + + Subclasses must implement: + + - ``validate()``: Validate parameters (call ``_validate_pmf()`` for PMFs) + - ``lookback_days()``: Return PMF length for initialization + - ``infection_resolution()``: Return ``"aggregate"`` or ``"subpop"`` + - ``_predicted_obs()``: Transform infections to predicted values + - ``sample()``: Apply noise model to predicted observations + + Notes + ----- + Computing predicted observations on day t requires infection history + from previous days (determined by the temporal PMF length). + The first ``len(pmf) - 1`` days have insufficient history and return NaN. + + See Also + -------- + pyrenew.convolve.compute_delay_ascertained_incidence : + Underlying convolution function + pyrenew.metaclass.RandomVariable : + Base class for all random variables + """ + + def __init__(self, temporal_pmf_rv: RandomVariable) -> None: + """ + Initialize base observation process. + + Parameters + ---------- + temporal_pmf_rv : RandomVariable + The temporal distribution PMF (e.g., delay or shedding distribution). + Must sample to a 1D array that sums to ~1.0 with non-negative values. + Subclasses may have additional parameters. + + Notes + ----- + Subclasses should call ``super().__init__(temporal_pmf_rv)`` + in their constructors and may add additional parameters. + """ + self.temporal_pmf_rv = temporal_pmf_rv + + @abstractmethod + def validate(self) -> None: + """ + Validate observation process parameters. + + Subclasses must implement this method to validate all parameters. + Typically this involves calling ``_validate_pmf()`` for the PMF + and adding any additional parameter-specific validation. + + Raises + ------ + ValueError + If any parameters fail validation. 
+ """ + pass # pragma: no cover + + @abstractmethod + def lookback_days(self) -> int: + """ + Return the number of days this observation process needs to look back. + + This determines the minimum n_initialization_points required by the + latent process when this observation is included in a multi-signal model. + + Returns + ------- + int + Number of days of infection history required. + Typically the length of the delay or shedding PMF. + + Notes + ----- + This is used by model builders to automatically compute + n_initialization_points as: + ``max(gen_int_length, max(all lookbacks)) - 1`` + """ + pass # pragma: no cover + + @abstractmethod + def infection_resolution(self) -> str: + """ + Return whether this observation uses aggregate or subpop infections. + + Returns one of: + + - ``"aggregate"``: Uses a single aggregated infection trajectory. + Shape: ``(n_days,)`` + - ``"subpop"``: Uses subpopulation-level infection trajectories. + Shape: ``(n_days, n_subpops)``, indexed via ``subpop_indices``. + + Returns + ------- + str + Either ``"aggregate"`` or ``"subpop"`` + + Examples + -------- + >>> # Aggregated count observations + >>> hosp_obs.infection_resolution() # Returns "aggregate" + >>> + >>> # Subpopulation-level observations (wastewater, subpop-specific counts) + >>> ww_obs.infection_resolution() # Returns "subpop" + + Notes + ----- + This is used by multi-signal models to route the correct infection + output to each observation process. + """ + pass # pragma: no cover + + def _validate_pmf( + self, + pmf: ArrayLike, + param_name: str, + atol: float = 1e-6, + ) -> None: + """ + Validate that an array is a valid probability mass function. 
+ + Checks: + + - Non-empty array + - Sums to 1.0 (within tolerance) + - All non-negative values + + Parameters + ---------- + pmf : ArrayLike + The PMF array to validate + param_name : str + Name of the parameter (for error messages) + atol : float, default 1e-6 + Absolute tolerance for sum-to-one check + + Raises + ------ + ValueError + If PMF is empty, doesn't sum to 1.0 (within tolerance), + or contains negative values. + """ + if pmf.size == 0: + raise ValueError(f"{param_name} must return non-empty array") + + pmf_sum = jnp.sum(pmf) + if not jnp.isclose(pmf_sum, 1.0, atol=atol): + raise ValueError( + f"{param_name} must sum to 1.0 (±{atol}), got {float(pmf_sum):.6f}" + ) + + if jnp.any(pmf < 0): + raise ValueError(f"{param_name} must have non-negative values") + + def get_minimum_observation_day(self) -> int: + """ + Get the first day with valid (non-NaN) convolution results. + + Due to the convolution operation requiring a history window, + the first ``len(pmf) - 1`` days will have NaN values in the + output. This method returns the index of the first valid day. + + Returns + ------- + int + Day index (0-based) of first valid observation. + Equal to ``len(pmf) - 1``. + """ + pmf = self.temporal_pmf_rv() + return int(len(pmf) - 1) + + def _convolve_with_alignment( + self, + latent_incidence: ArrayLike, + pmf: ArrayLike, + p_observed: float = 1.0, + ) -> tuple[ArrayLike, int]: + """ + Convolve latent incidence with PMF while maintaining timeline alignment. + + This is a wrapper around ``compute_delay_ascertained_incidence`` that + always uses ``pad=True`` to ensure day t in the output corresponds to + day t in the input. The first ``len(pmf) - 1`` days will be NaN. + + Parameters + ---------- + latent_incidence : ArrayLike + Latent incidence time series (infections, prevalence, etc.). + Shape: (n_days,) + pmf : ArrayLike + Delay or shedding PMF. Shape: (n_pmf,) + p_observed : float, default 1.0 + Observation probability multiplier. 
Scales the convolution result. + + Returns + ------- + tuple[ArrayLike, int] + - convolved_array : ArrayLike + Convolved time series with same length as input. + First ``len(pmf) - 1`` days are NaN. + Shape: (n_days,) + - offset : int + Always 0 when pad=True (maintained for API compatibility) + + Notes + ----- + For t < len(pmf)-1, there is insufficient history, so output[t] = NaN. + + See Also + -------- + pyrenew.convolve.compute_delay_ascertained_incidence : + Underlying function + """ + return compute_delay_ascertained_incidence( + latent_incidence=latent_incidence, + delay_incidence_to_observation_pmf=pmf, + p_observed_given_incident=p_observed, + pad=True, # Maintains timeline alignment + ) + + def _deterministic(self, name: str, value: ArrayLike) -> None: + """ + Track a deterministic quantity in the numpyro execution trace. + + This is a convenience wrapper around ``numpyro.deterministic`` for + tracking intermediate quantities (e.g., latent admissions, predicted + concentrations) that are useful for diagnostics and model checking. + These quantities are stored in MCMC samples and can be used for + model diagnostics and posterior predictive checks. + + Parameters + ---------- + name : str + Name for the tracked quantity. Will appear in MCMC samples. + value : ArrayLike + Value to track. Can be any shape. + """ + numpyro.deterministic(name, value) + + @abstractmethod + def _predicted_obs( + self, + infections: ArrayLike, + ) -> ArrayLike: + """ + Transform infections to predicted observation values. + + This is the core transformation that each observation process must + implement. It converts infections (from the infection process) + to predicted values for the observation model. + + Parameters + ---------- + infections : ArrayLike + Infections from the infection process. 
+ Shape: (n_days,) for aggregate observations + Shape: (n_days, n_subpops) for subpop-level observations + + Returns + ------- + ArrayLike + Predicted observation values (counts, log-concentrations, etc.). + Same shape as input, with first len(pmf)-1 days as NaN. + + Notes + ----- + The transformation is observation-specific: + + - Count observations: ascertainment x delay convolution -> predicted counts + - Wastewater: shedding convolution -> genome scaling -> dilution -> log + + See Also + -------- + sample : Uses this method then applies noise model + """ + pass # pragma: no cover + + @abstractmethod + def sample( + self, + obs: ArrayLike | None = None, + **kwargs, + ) -> ArrayLike: + """ + Sample from the observation process. + + Subclasses must implement this method to define the specific + observation model. Typically calls ``_predicted_obs`` first, + then applies the noise model. + + Parameters + ---------- + obs : ArrayLike | None + Observed data for conditioning, or None for prior predictive sampling. + **kwargs + Subclass-specific parameters (e.g., infections from the infection process). + + Returns + ------- + ArrayLike + Observed or sampled values from the observation process. + """ + pass # pragma: no cover diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py new file mode 100644 index 00000000..48a8d574 --- /dev/null +++ b/pyrenew/observation/count_observations.py @@ -0,0 +1,348 @@ +# numpydoc ignore=GL08 +""" +Count observations with composable noise models. + +Ascertainment x delay convolution with pluggable noise (Poisson, Negative Binomial, etc.). 
+""" + +from __future__ import annotations + +import jax +import jax.numpy as jnp +from jax.typing import ArrayLike + +from pyrenew.metaclass import RandomVariable +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.observation.noise import CountNoise +from pyrenew.observation.types import ObservationSample + + +class _CountBase(BaseObservationProcess): + """ + Internal base for count observation processes. + + Implements ascertainment x delay convolution with pluggable noise model. + """ + + def __init__( + self, + ascertainment_rate_rv: RandomVariable, + delay_distribution_rv: RandomVariable, + noise: CountNoise, + ) -> None: + """ + Initialize count observation base. + + Parameters + ---------- + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1] (e.g., IHR, IER). + delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model for count observations (Poisson, NegBin, etc.). + """ + super().__init__(temporal_pmf_rv=delay_distribution_rv) + self.ascertainment_rate_rv = ascertainment_rate_rv + self.noise = noise + + def validate(self) -> None: + """ + Validate observation parameters. + + Raises + ------ + ValueError + If delay PMF invalid, ascertainment rate outside [0,1], + or noise params invalid. + """ + delay_pmf = self.temporal_pmf_rv() + self._validate_pmf(delay_pmf, "delay_distribution_rv") + + ascertainment_rate = self.ascertainment_rate_rv() + if jnp.any(ascertainment_rate < 0) or jnp.any(ascertainment_rate > 1): + raise ValueError( + "ascertainment_rate_rv must be in [0, 1], " + "got value(s) outside this range" + ) + + self.noise.validate() + + def lookback_days(self) -> int: + """ + Return delay PMF length. + + Returns + ------- + int + Length of delay distribution PMF. + """ + return len(self.temporal_pmf_rv()) + + def infection_resolution(self) -> str: + """ + Return required infection resolution. + + Returns + ------- + str + "aggregate" or "subpop". 
+ """ + raise NotImplementedError("Subclasses must implement infection_resolution()") + + def _predicted_obs( + self, + infections: ArrayLike, + ) -> ArrayLike: + """ + Compute predicted counts via ascertainment x delay convolution. + + Parameters + ---------- + infections : ArrayLike + Infections from the infection process. + Shape: (n_days,) for aggregate + Shape: (n_days, n_subpops) for subpop-level + + Returns + ------- + ArrayLike + Predicted counts with timeline alignment. + Same shape as input. + First len(delay_pmf)-1 days are NaN. + """ + delay_pmf = self.temporal_pmf_rv() + ascertainment_rate = self.ascertainment_rate_rv() + + is_1d = infections.ndim == 1 + if is_1d: + infections = infections[:, jnp.newaxis] + + def convolve_col(col): # numpydoc ignore=GL08 + return self._convolve_with_alignment(col, delay_pmf, ascertainment_rate)[0] + + predicted_counts = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) + + return predicted_counts[:, 0] if is_1d else predicted_counts + + +class Counts(_CountBase): + """ + Aggregated count observation. + + Maps aggregate infections to counts through ascertainment x delay + convolution with composable noise model. + + Parameters + ---------- + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1] (e.g., IHR, IER). + delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model (PoissonNoise, NegativeBinomialNoise, etc.). + + Notes + ----- + Output preserves input timeline. First len(delay_pmf)-1 days return + -1 or ~0 (depending on noise model) due to NaN padding. + + Examples + -------- + >>> from pyrenew.deterministic import DeterministicVariable, DeterministicPMF + >>> from pyrenew.observation import Counts, NegativeBinomialNoise + >>> import jax.numpy as jnp + >>> import numpyro + >>> + >>> delay_pmf = jnp.array([0.2, 0.5, 0.3]) + >>> counts_obs = Counts( + ... ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + ... 
delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + ... noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ... ) + >>> + >>> with numpyro.handlers.seed(rng_seed=42): + ... infections = jnp.ones(30) * 1000 + ... sampled_counts = counts_obs.sample(infections=infections, obs=None) + """ + + def infection_resolution(self) -> str: + """ + Return "aggregate" for aggregated observations. + + Returns + ------- + str + The string "aggregate". + """ + return "aggregate" + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"Counts(ascertainment_rate_rv={self.ascertainment_rate_rv!r}, " + f"delay_distribution_rv={self.temporal_pmf_rv!r}, " + f"noise={self.noise!r})" + ) + + def sample( + self, + infections: ArrayLike, + obs: ArrayLike | None = None, + times: ArrayLike | None = None, + ) -> ObservationSample: + """ + Sample aggregated counts with dense or sparse observations. + + Validation is performed before JAX tracing at runtime, + prior to calling this method. + + Parameters + ---------- + infections : ArrayLike + Aggregate infections from the infection process. + Shape: (n_days,) + obs : ArrayLike | None + Observed counts. Dense: (n_days,), Sparse: (n_obs,), None: prior. + times : ArrayLike | None + Day indices for sparse observations. None for dense observations. + + Returns + ------- + ObservationSample + Named tuple with `observed` (sampled/conditioned counts) and + `predicted` (predicted counts before noise). 
+ """ + predicted_counts = self._predicted_obs(infections) + self._deterministic("predicted_counts", predicted_counts) + predicted_counts_safe = jnp.nan_to_num(predicted_counts, nan=0.0) + + # Only use sparse indexing when conditioning on observations + if times is not None and obs is not None: + predicted_obs = predicted_counts_safe[times] + else: + predicted_obs = predicted_counts_safe + + observed = self.noise.sample( + name="counts", + predicted=predicted_obs, + obs=obs, + ) + + return ObservationSample(observed=observed, predicted=predicted_counts) + + +class CountsBySubpop(_CountBase): + """ + Subpopulation-level count observation. + + Maps subpopulation-level infections to counts through + ascertainment x delay convolution with composable noise model. + + Parameters + ---------- + ascertainment_rate_rv : RandomVariable + Ascertainment rate in [0, 1]. + delay_distribution_rv : RandomVariable + Delay distribution PMF (must sum to ~1.0). + noise : CountNoise + Noise model (PoissonNoise, NegativeBinomialNoise, etc.). + + Notes + ----- + Output preserves input timeline. First len(delay_pmf)-1 days are NaN. + + Examples + -------- + >>> from pyrenew.deterministic import DeterministicVariable, DeterministicPMF + >>> from pyrenew.observation import CountsBySubpop, PoissonNoise + >>> import jax.numpy as jnp + >>> import numpyro + >>> + >>> delay_pmf = jnp.array([0.3, 0.4, 0.3]) + >>> counts_obs = CountsBySubpop( + ... ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), + ... delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + ... noise=PoissonNoise(), + ... ) + >>> + >>> with numpyro.handlers.seed(rng_seed=42): + ... infections = jnp.ones((30, 3)) * 500 # 30 days, 3 subpops + ... times = jnp.array([10, 15, 10, 15]) + ... subpop_indices = jnp.array([0, 0, 1, 1]) + ... sampled = counts_obs.sample( + ... infections=infections, + ... subpop_indices=subpop_indices, + ... times=times, + ... obs=None, + ... 
) + """ + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"CountsBySubpop(ascertainment_rate_rv={self.ascertainment_rate_rv!r}, " + f"delay_distribution_rv={self.temporal_pmf_rv!r}, " + f"noise={self.noise!r})" + ) + + def infection_resolution(self) -> str: + """ + Return "subpop" for subpopulation-level observations. + + Returns + ------- + str + The string "subpop". + """ + return "subpop" + + def sample( + self, + infections: ArrayLike, + subpop_indices: ArrayLike, + times: ArrayLike, + obs: ArrayLike | None = None, + ) -> ObservationSample: + """ + Sample subpopulation-level counts with flexible indexing. + + Validation is performed before JAX tracing at runtime, + prior to calling this method. + + Parameters + ---------- + infections : ArrayLike + Subpopulation-level infections from the infection process. + Shape: (n_days, n_subpops) + subpop_indices : ArrayLike + Subpopulation index for each observation (0-indexed). + Shape: (n_obs,) + times : ArrayLike + Day index for each observation (0-indexed). + Shape: (n_obs,) + obs : ArrayLike | None + Observed counts (n_obs,), or None for prior sampling. + + Returns + ------- + ObservationSample + Named tuple with `observed` (sampled/conditioned counts) and + `predicted` (predicted counts before noise, shape: n_days x n_subpops). 
+ """ + # Compute predicted counts for all subpops + predicted_counts_all = self._predicted_obs(infections) + + self._deterministic("predicted_counts_by_subpop", predicted_counts_all) + + # Replace NaN padding with 0 for distribution creation + predicted_counts_safe = jnp.nan_to_num(predicted_counts_all, nan=0.0) + predicted_obs = predicted_counts_safe[times, subpop_indices] + + observed = self.noise.sample( + name="counts_by_subpop", + predicted=predicted_obs, + obs=obs, + ) + + return ObservationSample(observed=observed, predicted=predicted_counts_all) diff --git a/pyrenew/observation/measurements.py b/pyrenew/observation/measurements.py new file mode 100644 index 00000000..82043e65 --- /dev/null +++ b/pyrenew/observation/measurements.py @@ -0,0 +1,144 @@ +# numpydoc ignore=GL08 +""" +Continuous measurement observation processes. + +Abstract base for any population-level continuous measurements (wastewater, +air quality, serology, etc.) with signal-specific processing. +""" + +from jax.typing import ArrayLike + +from pyrenew.metaclass import RandomVariable +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.observation.noise import MeasurementNoise +from pyrenew.observation.types import ObservationSample + + +class Measurements(BaseObservationProcess): + """ + Abstract base for continuous measurement observations. + + Subclasses implement signal-specific transformations from infections + to predicted measurement values, then add measurement noise. + + Parameters + ---------- + temporal_pmf_rv : RandomVariable + Temporal distribution PMF (e.g., shedding kinetics for wastewater). + noise : MeasurementNoise + Noise model for continuous measurements + (e.g., HierarchicalNormalNoise). + + Notes + ----- + Subclasses must implement ``_predicted_obs()`` according to their + specific signal processing (e.g., wastewater shedding kinetics, + dilution factors, etc.). 
+ + See Also + -------- + pyrenew.observation.noise.HierarchicalNormalNoise : + Suitable noise model for sensor-level measurements + pyrenew.observation.base.BaseObservationProcess : + Parent class with common observation utilities + """ + + def __init__( + self, + temporal_pmf_rv: RandomVariable, + noise: MeasurementNoise, + ) -> None: + """ + Initialize measurement observation base. + + Parameters + ---------- + temporal_pmf_rv : RandomVariable + Temporal distribution PMF (e.g., shedding kinetics). + noise : MeasurementNoise + Noise model (e.g., HierarchicalNormalNoise with sensor effects). + """ + super().__init__(temporal_pmf_rv=temporal_pmf_rv) + self.noise = noise + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"{self.__class__.__name__}(" + f"temporal_pmf_rv={self.temporal_pmf_rv!r}, " + f"noise={self.noise!r})" + ) + + def infection_resolution(self) -> str: + """ + Return "subpop" for measurement observations. + + Measurement observations require subpopulation-level infections + because each measurement corresponds to a specific catchment area. + + Returns + ------- + str + ``"subpop"`` + """ + return "subpop" + + def sample( + self, + infections: ArrayLike, + subpop_indices: ArrayLike, + sensor_indices: ArrayLike, + times: ArrayLike, + obs: ArrayLike | None, + n_sensors: int, + ) -> ObservationSample: + """ + Sample measurements from observed sensors. + + This method does not perform runtime validation of index values + (times, subpop_indices, sensor_indices). Validate observation data + before sampling. + + Transforms infections to predicted values via signal-specific processing + (``_predicted_obs``), then applies noise model. + + Parameters + ---------- + infections : ArrayLike + Infections from the infection process. + Shape: (n_days, n_subpops) + subpop_indices : ArrayLike + Subpopulation index for each observation (0-indexed). 
+ Shape: (n_obs,) + sensor_indices : ArrayLike + Sensor index for each observation (0-indexed). + Shape: (n_obs,) + times : ArrayLike + Day index for each observation (0-indexed). + Shape: (n_obs,) + obs : ArrayLike | None + Observed measurements (n_obs,), or None for prior sampling. + n_sensors : int + Total number of measurement sensors. + + Returns + ------- + ObservationSample + Named tuple with `observed` (sampled/conditioned measurements) and + `predicted` (predicted values before noise, shape: n_days x n_subpops). + """ + predicted_values = self._predicted_obs(infections) + + self._deterministic("predicted_log_conc", predicted_values) + + predicted_obs = predicted_values[times, subpop_indices] + + observed = self.noise.sample( + name="concentrations", + predicted=predicted_obs, + obs=obs, + sensor_indices=sensor_indices, + n_sensors=n_sensors, + ) + + return ObservationSample(observed=observed, predicted=predicted_values) diff --git a/pyrenew/observation/noise.py b/pyrenew/observation/noise.py new file mode 100644 index 00000000..9ce282e7 --- /dev/null +++ b/pyrenew/observation/noise.py @@ -0,0 +1,381 @@ +# numpydoc ignore=GL08 +""" +Noise models for observation processes. + +Provides composable noise strategies for count and measurement observations, +separating the noise distribution from the observation structure. + +Count Noise +----------- +- ``PoissonNoise``: Equidispersed counts (variance = mean). No parameters. +- ``NegativeBinomialNoise``: Overdispersed counts relative to Poisson (variance > mean). + Takes ``concentration_rv`` (higher concentration = less overdispersed, more Poisson-like). + +Measurement Noise +----------------- +- ``HierarchicalNormalNoise``: Normal noise with hierarchical sensor effects. + Takes ``sensor_mode_rv`` and ``sensor_sd_rv`` for sensor-level + bias and variability. 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +from jax.typing import ArrayLike + +from pyrenew.metaclass import RandomVariable + +_EPSILON = 1e-10 + + +class CountNoise(ABC): + """ + Abstract base for count observation noise models. + + Defines how discrete count observations are distributed around predicted values. + """ + + @abstractmethod + def sample( + self, + name: str, + predicted: ArrayLike, + obs: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample count observations given predicted counts. + + Parameters + ---------- + name : str + Numpyro sample site name. + predicted : ArrayLike + Predicted count values (non-negative). + obs : ArrayLike | None + Observed counts for conditioning, or None for prior sampling. + + Returns + ------- + ArrayLike + Sampled or conditioned counts, same shape as predicted. + """ + pass # pragma: no cover + + @abstractmethod + def validate(self) -> None: + """ + Validate noise model parameters. + + Raises + ------ + ValueError + If parameters are invalid. + """ + pass # pragma: no cover + + +class PoissonNoise(CountNoise): + """ + Poisson noise for equidispersed counts (variance = mean). + """ + + def __init__(self) -> None: + """Initialize Poisson noise (no parameters).""" + pass + + def __repr__(self) -> str: + """Return string representation.""" + return "PoissonNoise()" + + def validate(self) -> None: + """Validate Poisson noise (always valid).""" + pass + + def sample( + self, + name: str, + predicted: ArrayLike, + obs: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample from Poisson distribution. + + Parameters + ---------- + name : str + Numpyro sample site name. + predicted : ArrayLike + Predicted count values. + obs : ArrayLike | None + Observed counts for conditioning. + + Returns + ------- + ArrayLike + Poisson-distributed counts. 
+ """ + return numpyro.sample( + name, + dist.Poisson(rate=predicted + _EPSILON), + obs=obs, + ) + + +class NegativeBinomialNoise(CountNoise): + """ + Negative Binomial noise for overdispersed counts (variance > mean). + + Uses NB2 parameterization. Higher concentration reduces overdispersion. + + Parameters + ---------- + concentration_rv : RandomVariable + Concentration parameter (must be > 0). + Higher values reduce overdispersion. + + Notes + ----- + The NB2 parameterization has variance = mean + mean^2 / concentration. + As concentration -> infinity, this approaches Poisson. + """ + + def __init__(self, concentration_rv: RandomVariable) -> None: + """ + Initialize Negative Binomial noise. + + Parameters + ---------- + concentration_rv : RandomVariable + Concentration parameter (must be > 0). + Higher values reduce overdispersion. + """ + self.concentration_rv = concentration_rv + + def __repr__(self) -> str: + """Return string representation.""" + return f"NegativeBinomialNoise(concentration_rv={self.concentration_rv!r})" + + def validate(self) -> None: + """ + Validate concentration is positive. + + Raises + ------ + ValueError + If concentration <= 0. + """ + concentration = self.concentration_rv() + if jnp.any(concentration <= 0): + raise ValueError( + f"NegativeBinomialNoise: concentration must be positive, " + f"got {float(concentration)}" + ) + + def sample( + self, + name: str, + predicted: ArrayLike, + obs: ArrayLike | None = None, + ) -> ArrayLike: + """ + Sample from Negative Binomial distribution. + + Parameters + ---------- + name : str + Numpyro sample site name. + predicted : ArrayLike + Predicted count values. + obs : ArrayLike | None + Observed counts for conditioning. + + Returns + ------- + ArrayLike + Negative Binomial-distributed counts. 
+ """ + concentration = self.concentration_rv() + return numpyro.sample( + name, + dist.NegativeBinomial2( + mean=predicted + _EPSILON, + concentration=concentration, + ), + obs=obs, + ) + + +class MeasurementNoise(ABC): + """ + Abstract base for continuous measurement noise models. + + Defines how continuous observations are distributed around predicted values. + """ + + @abstractmethod + def sample( + self, + name: str, + predicted: ArrayLike, + obs: ArrayLike | None = None, + **kwargs, + ) -> ArrayLike: + """ + Sample continuous observations given predicted values. + + Parameters + ---------- + name : str + Numpyro sample site name. + predicted : ArrayLike + Predicted measurement values. + obs : ArrayLike | None + Observed measurements for conditioning, or None for prior sampling. + **kwargs + Additional context (e.g., sensor indices). + + Returns + ------- + ArrayLike + Sampled or conditioned measurements, same shape as predicted. + """ + pass # pragma: no cover + + @abstractmethod + def validate(self) -> None: + """ + Validate noise model parameters. + + Raises + ------ + ValueError + If parameters are invalid. + """ + pass # pragma: no cover + + +class HierarchicalNormalNoise(MeasurementNoise): + """ + Normal noise with hierarchical sensor-level effects. + + Observation model: ``obs ~ Normal(predicted + sensor_mode, sensor_sd)`` + where sensor_mode and sensor_sd are sampled per-sensor within a plate. + + Parameters + ---------- + sensor_mode_rv : RandomVariable + Prior for sensor-level modes (log-scale biases). + Sampled once per sensor within a plate context. + Example: ``DistributionalVariable("mode", dist.Normal(0, 0.5))`` + sensor_sd_rv : RandomVariable + Prior for sensor-level SDs (should be > 0). + Sampled once per sensor within a plate context. + Example: ``DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05))`` + + Notes + ----- + Expects data already on log scale for wastewater applications. 
+ + The sensor-level parameters are sampled within a numpyro plate context, + so any standard RandomVariable can be used (no special interface required). + + Examples + -------- + >>> from pyrenew.randomvariable import DistributionalVariable + >>> import numpyro.distributions as dist + >>> + >>> noise = HierarchicalNormalNoise( + ... sensor_mode_rv=DistributionalVariable("mode", dist.Normal(0, 0.5)), + ... sensor_sd_rv=DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.05)), + ... ) + """ + + def __init__( + self, + sensor_mode_rv: RandomVariable, + sensor_sd_rv: RandomVariable, + ) -> None: + """ + Initialize hierarchical Normal noise. + + Parameters + ---------- + sensor_mode_rv : RandomVariable + Prior for sensor-level modes (log-scale biases). + Sampled once per sensor within a plate context. + sensor_sd_rv : RandomVariable + Prior for sensor-level SDs (should be > 0). + Sampled once per sensor within a plate context. + """ + self.sensor_mode_rv = sensor_mode_rv + self.sensor_sd_rv = sensor_sd_rv + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"HierarchicalNormalNoise(" + f"sensor_mode_rv={self.sensor_mode_rv!r}, " + f"sensor_sd_rv={self.sensor_sd_rv!r})" + ) + + def validate(self) -> None: + """ + Validate noise parameters. + + Notes + ----- + Full validation requires n_sensors, which is only available during sample(). + """ + pass + + def sample( + self, + name: str, + predicted: ArrayLike, + obs: ArrayLike | None = None, + *, + sensor_indices: ArrayLike, + n_sensors: int, + ) -> ArrayLike: + """ + Sample from Normal distribution with sensor-level hierarchical effects. + + Parameters + ---------- + name : str + Numpyro sample site name. + predicted : ArrayLike + Predicted log-scale measurement values. + Shape: (n_obs,) + obs : ArrayLike | None + Observed log-scale measurements for conditioning. + Shape: (n_obs,) + sensor_indices : ArrayLike + Sensor index for each observation (0-indexed). 
+ Shape: (n_obs,) + n_sensors : int + Total number of sensors. + + Returns + ------- + ArrayLike + Normal distributed measurements with hierarchical sensor effects. + Shape: (n_obs,) + + Raises + ------ + ValueError + If sensor_sd samples non-positive values. + """ + with numpyro.plate("sensor", n_sensors): + sensor_mode = self.sensor_mode_rv() + sensor_sd = self.sensor_sd_rv() + + loc = predicted + sensor_mode[sensor_indices] + scale = sensor_sd[sensor_indices] + + return numpyro.sample(name, dist.Normal(loc=loc, scale=scale), obs=obs) diff --git a/pyrenew/observation/types.py b/pyrenew/observation/types.py new file mode 100644 index 00000000..f8a1163c --- /dev/null +++ b/pyrenew/observation/types.py @@ -0,0 +1,28 @@ +# numpydoc ignore=GL08 +""" +Return types for observation processes. + +Named tuples providing structured access to observation process outputs. +""" + +from typing import NamedTuple + +from jax.typing import ArrayLike + + +class ObservationSample(NamedTuple): + """ + Return type for observation process sample() methods. + + Attributes + ---------- + observed : ArrayLike + Sampled or conditioned observations. Shape depends on the + observation process and indexing. + predicted : ArrayLike + Predicted values before noise is applied. Useful for + diagnostics and posterior predictive checks. + """ + + observed: ArrayLike + predicted: ArrayLike diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 00000000..bc95fd10 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,397 @@ +""" +Shared pytest fixtures for PyRenew tests. + +This module provides reusable fixtures for creating observation processes, +test data, and common configurations used across multiple test files. 
+""" + +import jax.numpy as jnp +import numpyro.distributions as dist +import pytest + +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.observation import Counts, NegativeBinomialNoise +from pyrenew.randomvariable import DistributionalVariable + +# ============================================================================= +# PMF Fixtures +# ============================================================================= + + +@pytest.fixture +def simple_delay_pmf(): + """ + Simple 1-day delay PMF (no delay). + + Returns + ------- + jnp.ndarray + A single-element PMF array representing no delay. + """ + return jnp.array([1.0]) + + +@pytest.fixture +def short_delay_pmf(): + """ + Short 2-day delay PMF. + + Returns + ------- + jnp.ndarray + A 2-element PMF array. + """ + return jnp.array([0.5, 0.5]) + + +@pytest.fixture +def medium_delay_pmf(): + """ + Medium 4-day delay PMF. + + Returns + ------- + jnp.ndarray + A 4-element PMF array. + """ + return jnp.array([0.1, 0.3, 0.4, 0.2]) + + +@pytest.fixture +def realistic_delay_pmf(): + """ + Realistic 10-day delay PMF (shifted gamma-like). + + Returns + ------- + jnp.ndarray + A 10-element PMF array with gamma-like shape. + """ + return jnp.array([0.01, 0.05, 0.10, 0.15, 0.20, 0.20, 0.15, 0.08, 0.04, 0.02]) + + +@pytest.fixture +def long_delay_pmf(): + """ + Long 10-day delay PMF for edge case testing. + + Returns + ------- + jnp.ndarray + A 10-element PMF array. + """ + return jnp.array([0.05, 0.1, 0.15, 0.2, 0.2, 0.15, 0.1, 0.03, 0.01, 0.01]) + + +@pytest.fixture +def simple_shedding_pmf(): + """ + Simple 1-day shedding PMF (no delay). + + Returns + ------- + jnp.ndarray + A single-element PMF array representing no shedding delay. + """ + return jnp.array([1.0]) + + +@pytest.fixture +def short_shedding_pmf(): + """ + Short 3-day shedding PMF. + + Returns + ------- + jnp.ndarray + A 3-element PMF array. 
+ """ + return jnp.array([0.3, 0.4, 0.3]) + + +@pytest.fixture +def medium_shedding_pmf(): + """ + Medium 5-day shedding PMF. + + Returns + ------- + jnp.ndarray + A 5-element PMF array. + """ + return jnp.array([0.1, 0.3, 0.3, 0.2, 0.1]) + + +# ============================================================================= +# Sensor Prior Fixtures +# ============================================================================= + + +@pytest.fixture +def sensor_mode_rv(): + """ + Default zero-mean normal prior for sensor modes (standard deviation 0.5). + + Returns + ------- + DistributionalVariable + A normal prior with standard deviation 0.5. + """ + return DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.5)) + + +@pytest.fixture +def sensor_mode_rv_tight(): + """ + Tight normal prior for deterministic-like behavior. + + Returns + ------- + DistributionalVariable + A normal prior with small standard deviation 0.01. + """ + return DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.01)) + + +@pytest.fixture +def sensor_sd_rv(): + """ + Standard truncated normal prior for sensor standard deviations. + + Returns + ------- + DistributionalVariable + A truncated normal prior for sensor standard deviations. + """ + return DistributionalVariable( + "ww_sensor_sd", dist.TruncatedNormal(0.3, 0.15, low=0.10) + ) + + +@pytest.fixture +def sensor_sd_rv_tight(): + """ + Tight truncated normal prior for deterministic-like behavior. + + Returns + ------- + DistributionalVariable + A truncated normal prior with small scale for tight behavior. + """ + return DistributionalVariable( + "ww_sensor_sd", dist.TruncatedNormal(0.01, 0.005, low=0.005) + ) + + +# ============================================================================= +# Counts Process Fixtures +# ============================================================================= + + +@pytest.fixture +def counts_process(simple_delay_pmf): + """ + Standard Counts observation process with simple delay. 
+ + Returns + ------- + Counts + A Counts observation process with no delay. + """ + return Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + + +@pytest.fixture +def counts_process_medium_delay(medium_delay_pmf): + """ + Counts observation process with medium delay. + + Returns + ------- + Counts + A Counts observation process with 4-day delay. + """ + return Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", medium_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 50.0)), + ) + + +@pytest.fixture +def counts_process_realistic(realistic_delay_pmf): + """ + Counts observation process with realistic delay and ascertainment. + + Returns + ------- + Counts + A Counts observation process with realistic parameters. + """ + return Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.005), + delay_distribution_rv=DeterministicPMF("delay", realistic_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 100.0)), + ) + + +class CountsProcessFactory: + """Factory for creating Counts processes with custom parameters.""" + + @staticmethod + def create( + delay_pmf=None, + ascertainment_rate=0.01, + concentration=10.0, + ): + """ + Create a Counts process with specified parameters. + + Returns + ------- + Counts + A Counts observation process with the specified parameters. + """ + if delay_pmf is None: + delay_pmf = jnp.array([1.0]) + return Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", ascertainment_rate), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", concentration)), + ) + + +@pytest.fixture +def counts_factory(): + """ + Factory fixture for creating custom Counts processes. 
+ + Returns + ------- + CountsProcessFactory + A factory for creating Counts processes. + """ + return CountsProcessFactory() + + +# ============================================================================= +# Infection Fixtures +# ============================================================================= + + +@pytest.fixture +def constant_infections(): + """ + Constant infections array (30 days, 100 infections/day). + + Returns + ------- + jnp.ndarray + A 1D array of shape (30,) with constant value 100. + """ + return jnp.ones(30) * 100 + + +@pytest.fixture +def constant_infections_2d(): + """ + Constant infections array for 2 subpopulations. + + Returns + ------- + jnp.ndarray + A 2D array of shape (30, 2) with constant value 100. + """ + return jnp.ones((30, 2)) * 100 + + +def make_infections(n_days, n_subpops=None, value=100.0): + """ + Create infection arrays for testing. + + Parameters + ---------- + n_days : int + Number of days + n_subpops : int, optional + Number of subpopulations (None for 1D array) + value : float + Constant infection value + + Returns + ------- + jnp.ndarray + Infections array + """ + if n_subpops is None: + return jnp.ones(n_days) * value + return jnp.ones((n_days, n_subpops)) * value + + +def make_spike_infections(n_days, spike_day, spike_value=1000.0, n_subpops=None): + """ + Create spike infection arrays for testing. 
+ + Parameters + ---------- + n_days : int + Number of days + spike_day : int + Day of the spike + spike_value : float + Value at spike + n_subpops : int, optional + Number of subpopulations + + Returns + ------- + jnp.ndarray + Infections array with spike + """ + if n_subpops is None: + infections = jnp.zeros(n_days) + return infections.at[spike_day].set(spike_value) + infections = jnp.zeros((n_days, n_subpops)) + return infections.at[spike_day, :].set(spike_value) + + +def create_mock_infections( + n_days: int, + peak_day: int = 10, + peak_value: float = 1000.0, + shape: str = "spike", +) -> jnp.ndarray: + """ + Create mock infection time series for testing. + + Parameters + ---------- + n_days : int + Number of days + peak_day : int + Day of peak infections + peak_value : float + Peak infection value + shape : str + Shape of the curve: "spike", "constant", or "decay" + + Returns + ------- + jnp.ndarray + Array of infections of shape (n_days,) + """ + if shape == "spike": + infections = jnp.zeros(n_days) + infections = infections.at[peak_day].set(peak_value) + elif shape == "constant": + infections = jnp.ones(n_days) * peak_value + elif shape == "decay": + infections = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) + else: + raise ValueError(f"Unknown shape: {shape}") + + return infections diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py new file mode 100644 index 00000000..33f4eade --- /dev/null +++ b/test/test_observation_counts.py @@ -0,0 +1,560 @@ +""" +Unit tests for Counts (aggregated count observations). + +These tests validate the count observation process implementation. 
+""" + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +import pytest + +from pyrenew.deterministic import DeterministicPMF, DeterministicVariable +from pyrenew.observation import ( + Counts, + CountsBySubpop, + NegativeBinomialNoise, + PoissonNoise, +) +from pyrenew.observation.count_observations import _CountBase +from pyrenew.randomvariable import DistributionalVariable + + +def create_mock_infections( + n_days: int, + peak_day: int = 10, + peak_value: float = 1000.0, + shape: str = "spike", +) -> jnp.ndarray: + """ + Create mock infection time series for testing. + + Parameters + ---------- + n_days : int + Number of days + peak_day : int + Day of peak infections + peak_value : float + Peak infection value + shape : str + Shape of the curve: "spike", "constant", or "decay" + + Returns + ------- + jnp.ndarray + Array of infections of shape (n_days,) + """ + if shape == "spike": + infections = jnp.zeros(n_days) + infections = infections.at[peak_day].set(peak_value) + elif shape == "constant": + infections = jnp.ones(n_days) * peak_value + elif shape == "decay": + infections = peak_value * jnp.exp(-jnp.arange(n_days) / 20.0) + else: + raise ValueError(f"Unknown shape: {shape}") + + return infections + + +class TestCountsBasics: + """Test basic functionality of aggregated count observation process.""" + + def test_sample_returns_correct_shape(self, counts_process): + """Test that sample returns correct shape.""" + infections = jnp.ones(30) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result = counts_process.sample( + infections=infections, + obs=None, + ) + + assert result.observed.shape[0] > 0 + assert result.observed.ndim == 1 + assert result.predicted.shape == infections.shape + + def test_delay_convolution(self, counts_factory, short_delay_pmf): + """Test that delay is properly applied.""" + process = counts_factory.create(delay_pmf=short_delay_pmf) + + infections = jnp.zeros(30) + infections = infections.at[10].set(1000) 
+ + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + obs=None, + ) + + # Timeline alignment: output length equals input length + assert result.observed.shape[0] == len(infections) + # First len(delay_pmf)-1 days are NaN (appear as -1 after NegativeBinomial sampling) + assert jnp.all(result.observed[1:] >= 0) + assert jnp.sum(result.observed[result.observed >= 0]) > 0 + + def test_ascertainment_scaling(self, counts_factory, simple_delay_pmf): + """Test that ascertainment rate properly scales counts.""" + infections = jnp.ones(20) * 100 + + results = [] + for rate_value in [0.01, 0.02, 0.05]: + process = counts_factory.create( + delay_pmf=simple_delay_pmf, + ascertainment_rate=rate_value, + ) + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + obs=None, + ) + results.append(jnp.mean(result.observed)) + + # Higher ascertainment rate should lead to more counts + assert results[1] > results[0] + assert results[2] > results[1] + + def test_negative_binomial_observation(self, counts_factory, simple_delay_pmf): + """Test that negative binomial observation is used.""" + process = counts_factory.create( + delay_pmf=simple_delay_pmf, + concentration=5.0, + ) + + infections = jnp.ones(20) * 100 + + samples = [] + for seed in range(5): + with numpyro.handlers.seed(rng_seed=seed): + result = process.sample( + infections=infections, + obs=None, + ) + samples.append(jnp.sum(result.observed)) + + # Should have some variability due to negative binomial sampling + assert jnp.std(jnp.array(samples)) > 0 + + +class TestCountsWithPriors: + """Test aggregated count observation with uncertain parameters.""" + + def test_with_stochastic_ascertainment(self, short_shedding_pmf): + """Test with uncertain ascertainment rate parameter.""" + delay = DeterministicPMF("delay", jnp.array([0.2, 0.5, 0.3])) + ascertainment = DistributionalVariable("ihr", dist.Beta(2, 100)) + concentration = 
DeterministicVariable("conc", 10.0) + + process = Counts( + ascertainment_rate_rv=ascertainment, + delay_distribution_rv=delay, + noise=NegativeBinomialNoise(concentration), + ) + + infections = jnp.ones(20) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + obs=None, + ) + + assert result.observed.shape[0] > 0 + # Skip NaN padding + valid_counts = result.observed[2:] + assert jnp.all(valid_counts >= 0) + + def test_with_stochastic_concentration(self, simple_delay_pmf): + """Test with uncertain concentration parameter.""" + delay = DeterministicPMF("delay", simple_delay_pmf) + ascertainment = DeterministicVariable("ihr", 0.01) + concentration = DistributionalVariable("conc", dist.HalfNormal(10.0)) + + process = Counts( + ascertainment_rate_rv=ascertainment, + delay_distribution_rv=delay, + noise=NegativeBinomialNoise(concentration), + ) + + infections = jnp.ones(20) * 100 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + obs=None, + ) + + assert result.observed.shape[0] > 0 + assert jnp.all(result.observed >= 0) + + +class TestCountsEdgeCases: + """Test edge cases and error handling.""" + + def test_zero_infections(self, counts_process): + """Test with zero infections.""" + infections = jnp.zeros(20) + + with numpyro.handlers.seed(rng_seed=42): + result = counts_process.sample( + infections=infections, + obs=None, + ) + + assert result.observed.shape[0] > 0 + + def test_small_infections(self, counts_process): + """Test with small infection values.""" + infections = jnp.ones(20) * 10 + + with numpyro.handlers.seed(rng_seed=42): + result = counts_process.sample( + infections=infections, + obs=None, + ) + + assert result.observed.shape[0] > 0 + assert jnp.all(result.observed >= 0) + + def test_long_delay_distribution(self, counts_factory, long_delay_pmf): + """Test with longer delay distribution.""" + process = counts_factory.create(delay_pmf=long_delay_pmf) + + 
infections = create_mock_infections(40, peak_day=20, shape="spike") + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + obs=None, + ) + + # Timeline alignment maintained + assert result.observed.shape[0] == infections.shape[0] + # Skip NaN padding: 10-day delay -> first 9 days are NaN + valid_counts = result.observed[9:] + assert jnp.sum(valid_counts) > 0 + + +class TestCountsSparseObservations: + """Test sparse observation support.""" + + def test_sparse_observations(self, counts_process): + """Test with sparse (irregular) observations.""" + n_days = 30 + infections = jnp.ones(n_days) * 100 + + # Sparse observations: only days 5, 10, 15, 20 + times = jnp.array([5, 10, 15, 20]) + counts_data = jnp.array([10, 12, 8, 15]) + + with numpyro.handlers.seed(rng_seed=42): + result = counts_process.sample( + infections=infections, + obs=counts_data, + times=times, + ) + + assert result.observed.shape == times.shape + assert jnp.allclose(result.observed, counts_data) + + def test_sparse_vs_dense_sampling(self, counts_process): + """Test that sparse sampling gives different output shape than dense.""" + n_days = 30 + infections = jnp.ones(n_days) * 100 + + # Dense: prior sampling (obs=None, no times) + with numpyro.handlers.seed(rng_seed=42): + dense_result = counts_process.sample( + infections=infections, + obs=None, + ) + + # Sparse with observed data: only some days + times = jnp.array([5, 10, 15, 20]) + sparse_obs_data = jnp.array([10, 12, 8, 15]) + with numpyro.handlers.seed(rng_seed=42): + sparse_result = counts_process.sample( + infections=infections, + obs=sparse_obs_data, + times=times, + ) + + # Dense prior produces full length output + assert dense_result.observed.shape == (n_days,) + + # Sparse observations produce output matching times shape + assert sparse_result.observed.shape == times.shape + assert jnp.allclose(sparse_result.observed, sparse_obs_data) + + def test_prior_sampling_ignores_times(self, 
counts_process): + """Test that times parameter is ignored when obs=None (prior sampling).""" + n_days = 30 + infections = jnp.ones(n_days) * 100 + times = jnp.array([5, 10, 15, 20]) + + # When obs=None, times is ignored - output is dense + with numpyro.handlers.seed(rng_seed=42): + result_with_times = counts_process.sample( + infections=infections, + obs=None, + times=times, + ) + + with numpyro.handlers.seed(rng_seed=42): + result_without_times = counts_process.sample( + infections=infections, + obs=None, + ) + + # Both should produce dense output of shape (n_days,) + assert result_with_times.observed.shape == (n_days,) + assert result_without_times.observed.shape == (n_days,) + # With same seed, outputs should be identical + assert jnp.allclose(result_with_times.observed, result_without_times.observed) + + +class TestCountsBySubpop: + """Test CountsBySubpop for subpopulation-level observations.""" + + def test_sample_returns_correct_shape(self): + """Test that CountsBySubpop sample returns correct shape.""" + delay_pmf = jnp.array([0.3, 0.4, 0.3]) + process = CountsBySubpop( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + ) + + infections = jnp.ones((30, 3)) * 500 # 30 days, 3 subpops + times = jnp.array([10, 15, 10, 15]) + subpop_indices = jnp.array([0, 0, 1, 1]) + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + subpop_indices=subpop_indices, + times=times, + obs=None, + ) + + assert result.observed.shape == times.shape + assert result.predicted.shape == infections.shape + + def test_infection_resolution(self): + """Test that CountsBySubpop returns 'subpop' resolution.""" + delay_pmf = jnp.array([1.0]) + process = CountsBySubpop( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=PoissonNoise(), + ) + + assert 
process.infection_resolution() == "subpop" + + +class TestPoissonNoise: + """Test PoissonNoise model.""" + + def test_poisson_counts(self, simple_delay_pmf): + """Test Counts with Poisson noise.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + + infections = jnp.ones(20) * 1000 + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + obs=None, + ) + + assert result.observed.shape[0] == 20 + assert jnp.all(result.observed >= 0) + + +class TestCountBaseInternalMethods: + """Test internal _CountBase methods for coverage.""" + + def test_count_base_infection_resolution_raises(self, simple_delay_pmf): + """Test that _CountBase.infection_resolution() raises NotImplementedError.""" + + # Create a subclass that doesn't override infection_resolution + class IncompleteCountProcess(_CountBase): + """Incomplete count process for testing.""" + + def sample(self, **kwargs): + """Sample method stub.""" + pass + + process = IncompleteCountProcess( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + with pytest.raises( + NotImplementedError, match="Subclasses must implement infection_resolution" + ): + process.infection_resolution() + + +class TestValidationMethods: + """Test validation methods for coverage.""" + + def test_validate_calls_all_validations(self, simple_delay_pmf): + """Test that validate() calls all necessary validations.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + # Should not raise + process.validate() + + def test_validate_invalid_ascertainment_rate_negative(self, 
simple_delay_pmf): + """Test that validate raises for negative ascertainment rate.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", -0.1), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + with pytest.raises(ValueError, match="ascertainment_rate_rv must be in"): + process.validate() + + def test_validate_invalid_ascertainment_rate_greater_than_one( + self, simple_delay_pmf + ): + """Test that validate raises for ascertainment rate > 1.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 1.5), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + with pytest.raises(ValueError, match="ascertainment_rate_rv must be in"): + process.validate() + + def test_lookback_days(self, simple_delay_pmf, long_delay_pmf): + """Test lookback_days returns PMF length.""" + process_short = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + assert process_short.lookback_days() == 1 + + process_long = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", long_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + assert process_long.lookback_days() == 10 + + def test_infection_resolution_counts(self, simple_delay_pmf): + """Test that Counts returns 'aggregate' resolution.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + assert process.infection_resolution() == "aggregate" + + +class TestNoiseRepr: + """Test noise model __repr__ 
methods.""" + + def test_poisson_noise_repr(self): + """Test PoissonNoise __repr__ method.""" + noise = PoissonNoise() + assert repr(noise) == "PoissonNoise()" + + def test_negative_binomial_noise_repr(self): + """Test NegativeBinomialNoise __repr__ method.""" + conc_rv = DeterministicVariable("conc", 10.0) + noise = NegativeBinomialNoise(conc_rv) + repr_str = repr(noise) + assert "NegativeBinomialNoise" in repr_str + assert "concentration_rv" in repr_str + + +class TestNoiseValidation: + """Test noise model validation methods.""" + + def test_poisson_noise_validate(self): + """Test PoissonNoise validate method.""" + noise = PoissonNoise() + # Should not raise - Poisson has no parameters to validate + noise.validate() + + def test_negative_binomial_noise_validate_success(self): + """Test NegativeBinomialNoise validate with valid concentration.""" + noise = NegativeBinomialNoise(DeterministicVariable("conc", 10.0)) + # Should not raise + noise.validate() + + def test_negative_binomial_noise_validate_zero_concentration(self): + """Test NegativeBinomialNoise validate with zero concentration.""" + noise = NegativeBinomialNoise(DeterministicVariable("conc", 0.0)) + with pytest.raises(ValueError, match="concentration must be positive"): + noise.validate() + + def test_negative_binomial_noise_validate_negative_concentration(self): + """Test NegativeBinomialNoise validate with negative concentration.""" + noise = NegativeBinomialNoise(DeterministicVariable("conc", -1.0)) + with pytest.raises(ValueError, match="concentration must be positive"): + noise.validate() + + +class TestBaseObservationProcessValidation: + """Test base observation process PMF validation.""" + + def test_validate_pmf_empty_array(self, simple_delay_pmf): + """Test that _validate_pmf raises for empty array.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + 
noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + empty_pmf = jnp.array([]) + with pytest.raises(ValueError, match="must return non-empty array"): + process._validate_pmf(empty_pmf, "test_pmf") + + def test_validate_pmf_sum_not_one(self, simple_delay_pmf): + """Test that _validate_pmf raises for PMF not summing to 1.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + bad_pmf = jnp.array([0.3, 0.3, 0.3]) # sums to 0.9 + with pytest.raises(ValueError, match="must sum to 1.0"): + process._validate_pmf(bad_pmf, "test_pmf") + + def test_validate_pmf_negative_values(self, simple_delay_pmf): + """Test that _validate_pmf raises for negative values.""" + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + bad_pmf = jnp.array([1.5, -0.5]) # sums to 1.0 but has negative + with pytest.raises(ValueError, match="must have non-negative values"): + process._validate_pmf(bad_pmf, "test_pmf") + + def test_get_minimum_observation_day(self): + """Test get_minimum_observation_day returns correct value.""" + delay_pmf = jnp.array([0.2, 0.5, 0.3]) # length 3 + process = Counts( + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", delay_pmf), + noise=NegativeBinomialNoise(DeterministicVariable("conc", 10.0)), + ) + # First valid day should be len(pmf) - 1 = 2 + assert process.get_minimum_observation_day() == 2 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/test_observation_measurements.py b/test/test_observation_measurements.py new file mode 100644 index 00000000..b4ed3488 --- /dev/null +++ b/test/test_observation_measurements.py @@ -0,0 
+1,256 @@ +""" +Unit tests for Measurements (continuous measurement observations). + +These tests validate the measurement observation process base class implementation. +""" + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +import pytest + +from pyrenew.deterministic import DeterministicPMF +from pyrenew.observation import HierarchicalNormalNoise, Measurements +from pyrenew.observation.base import BaseObservationProcess +from pyrenew.randomvariable import DistributionalVariable + + +class ConcreteMeasurements(Measurements): + """Concrete implementation of Measurements for testing.""" + + def __init__(self, temporal_pmf_rv, noise, log10_scale=9.0): + """Initialize the concrete measurements for testing.""" + super().__init__(temporal_pmf_rv=temporal_pmf_rv, noise=noise) + self.log10_scale = log10_scale + + def validate(self) -> None: + """Validate parameters.""" + pmf = self.temporal_pmf_rv() + self._validate_pmf(pmf, "temporal_pmf_rv") + + def lookback_days(self) -> int: + """ + Return temporal PMF length. + + Returns + ------- + int + Length of the temporal PMF. + """ + return len(self.temporal_pmf_rv()) + + def _predicted_obs(self, infections): + """ + Simple predicted signal: log(convolution * scale). + + Returns + ------- + jnp.ndarray + Log-transformed predicted signal. 
+ """ + pmf = self.temporal_pmf_rv() + + # Handle 2D infections (n_days, n_subpops) + if infections.ndim == 1: + infections = infections[:, jnp.newaxis] + + def convolve_col(col): # numpydoc ignore=GL08 + return self._convolve_with_alignment(col, pmf, 1.0)[0] + + import jax + + predicted = jax.vmap(convolve_col, in_axes=1, out_axes=1)(infections) + + # Apply log10 scaling (simplified from wastewater model) + log_predicted = jnp.log(predicted + 1e-10) + self.log10_scale * jnp.log(10) + + return log_predicted + + +class TestMeasurementsBase: + """Test Measurements abstract base class.""" + + def test_is_base_observation_process(self): + """Test that Measurements inherits from BaseObservationProcess.""" + assert issubclass(Measurements, BaseObservationProcess) + + def test_infection_resolution_is_subpop(self): + """Test that Measurements returns 'subpop' resolution.""" + shedding_pmf = jnp.array([0.3, 0.4, 0.3]) + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + + process = ConcreteMeasurements( + temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), + noise=noise, + ) + + assert process.infection_resolution() == "subpop" + + +class TestHierarchicalNormalNoise: + """Test HierarchicalNormalNoise model.""" + + def test_repr(self): + """Test HierarchicalNormalNoise __repr__ method.""" + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + repr_str = repr(noise) + assert "HierarchicalNormalNoise" in repr_str + assert "sensor_mode_rv" in repr_str + assert "sensor_sd_rv" in repr_str + + def test_validate(self): + """Test HierarchicalNormalNoise validate method.""" + sensor_mode_rv = DistributionalVariable("mode", 
dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + # Should not raise - validation is deferred to sample time + noise.validate() + + def test_sample_shape(self): + """Test that HierarchicalNormalNoise produces correct shape.""" + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + + predicted = jnp.array([1.0, 2.0, 3.0, 4.0]) + sensor_indices = jnp.array([0, 0, 1, 1]) + + with numpyro.handlers.seed(rng_seed=42): + samples = noise.sample( + name="test", + predicted=predicted, + obs=None, + sensor_indices=sensor_indices, + n_sensors=2, + ) + + assert samples.shape == predicted.shape + + def test_sample_with_observations(self): + """Test that HierarchicalNormalNoise conditions on observations.""" + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + + predicted = jnp.array([1.0, 2.0, 3.0, 4.0]) + obs = jnp.array([1.1, 2.1, 3.1, 4.1]) + sensor_indices = jnp.array([0, 0, 1, 1]) + + with numpyro.handlers.seed(rng_seed=42): + samples = noise.sample( + name="test", + predicted=predicted, + obs=obs, + sensor_indices=sensor_indices, + n_sensors=2, + ) + + # When obs is provided, samples should equal obs + assert jnp.allclose(samples, obs) + + +class TestConcreteMeasurements: + """Test concrete Measurements implementation.""" + + def test_repr(self): + """Test Measurements __repr__ method.""" + shedding_pmf = jnp.array([0.3, 0.4, 0.3]) + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) + ) 
+ noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + + process = ConcreteMeasurements( + temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), + noise=noise, + ) + + repr_str = repr(process) + assert "ConcreteMeasurements" in repr_str + assert "temporal_pmf_rv" in repr_str + assert "noise" in repr_str + + def test_sample_shape(self): + """Test that sample returns correct shape.""" + shedding_pmf = jnp.array([0.3, 0.4, 0.3]) + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.5)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.3, 0.15, low=0.05) + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + + process = ConcreteMeasurements( + temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), + noise=noise, + ) + + # 30 days, 2 subpops + infections = jnp.ones((30, 2)) * 1000 + subpop_indices = jnp.array([0, 0, 1, 1]) + sensor_indices = jnp.array([0, 0, 1, 1]) + times = jnp.array([10, 15, 10, 15]) + + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=infections, + subpop_indices=subpop_indices, + sensor_indices=sensor_indices, + times=times, + obs=None, + n_sensors=2, + ) + + assert result.observed.shape == times.shape + assert result.predicted.shape == infections.shape + + def test_predicted_obs_stored(self): + """Test that predicted_log_conc is stored as deterministic.""" + shedding_pmf = jnp.array([0.5, 0.5]) + sensor_mode_rv = DistributionalVariable("mode", dist.Normal(0, 0.01)) + sensor_sd_rv = DistributionalVariable( + "sd", dist.TruncatedNormal(0.01, 0.005, low=0.001) + ) + noise = HierarchicalNormalNoise(sensor_mode_rv, sensor_sd_rv) + + process = ConcreteMeasurements( + temporal_pmf_rv=DeterministicPMF("shedding", shedding_pmf), + noise=noise, + ) + + infections = jnp.ones((20, 2)) * 1000 + subpop_indices = jnp.array([0, 1]) + sensor_indices = jnp.array([0, 1]) + times = jnp.array([10, 10]) + + with numpyro.handlers.seed(rng_seed=42): + trace = 
numpyro.handlers.trace( + lambda: process.sample( + infections=infections, + subpop_indices=subpop_indices, + sensor_indices=sensor_indices, + times=times, + obs=None, + n_sensors=2, + ) + ).get_trace() + + assert "predicted_log_conc" in trace + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/test/test_observation_poisson.py b/test/test_observation_poisson.py index b9d975be..8b9c0716 100644 --- a/test/test_observation_poisson.py +++ b/test/test_observation_poisson.py @@ -20,3 +20,10 @@ def test_poisson_obs(): sim_pois = pois(mu=rates) testing.assert_array_equal(sim_pois, jnp.ceil(sim_pois)) + + +def test_poisson_validate(): + """ + Check that PoissonObservation.validate() runs without error. + """ + PoissonObservation.validate() diff --git a/uv.lock b/uv.lock index f2477f56..6a0bfd17 100644 --- a/uv.lock +++ b/uv.lock @@ -949,6 +949,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7a/f0/8282d9641415e9e33df173516226b404d367a0fc55e1a60424a152913abc/mistune-3.1.4-py3-none-any.whl", hash = "sha256:93691da911e5d9d2e23bc54472892aff676df27a75274962ff9edc210364266d", size = 53481, upload-time = "2025-08-29T07:20:42.218Z" }, ] +[[package]] +name = "mizani" +version = "0.14.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pandas" }, + { name = "scipy" }, + { name = "tzdata", marker = "sys_platform == 'emscripten' or sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/19/98f2bd61e5441b687e0a5d3b36041981cc032451f2d11472021b040d27fd/mizani-0.14.3.tar.gz", hash = "sha256:c2fb886b3c9e8109be5b8fd21e1130fba1f0a20230a987146240221209fc0ddd", size = 772470, upload-time = "2025-10-30T20:16:53.268Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/d2/4ffcaa27c8a4b4f9ad456da4821c76dfbdfada23e8210cd4d80e1eb3236a/mizani-0.14.3-py3-none-any.whl", hash = "sha256:6d2ca9b1b8366ff85668f0cc1b6095f1e702e26e66f132c4f02a949efa32a688", size = 
133433, upload-time = "2025-10-30T20:16:51.218Z" }, +] + [[package]] name = "mkdocs" version = "1.6.1" @@ -1348,6 +1363,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, ] +[[package]] +name = "patsy" +version = "1.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/be/44/ed13eccdd0519eff265f44b670d46fbb0ec813e2274932dc1c0e48520f7d/patsy-1.0.2.tar.gz", hash = "sha256:cdc995455f6233e90e22de72c37fcadb344e7586fb83f06696f54d92f8ce74c0", size = 399942, upload-time = "2025-10-20T16:17:37.535Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/70/ba4b949bdc0490ab78d545459acd7702b211dfccf7eb89bbc1060f52818d/patsy-1.0.2-py2.py3-none-any.whl", hash = "sha256:37bfddbc58fcf0362febb5f54f10743f8b21dd2aa73dec7e7ef59d1b02ae668a", size = 233301, upload-time = "2025-10-20T16:17:36.563Z" }, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -1427,6 +1454,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/cb/ac7874b3e5d58441674fb70742e6c374b28b0c7cb988d37d991cde47166c/platformdirs-4.5.0-py3-none-any.whl", hash = "sha256:e578a81bb873cbb89a41fcc904c7ef523cc18284b7e3b3ccf06aca1403b7ebd3", size = 18651, upload-time = "2025-10-08T17:44:47.223Z" }, ] +[[package]] +name = "plotnine" +version = "0.15.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, + { name = "mizani" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "scipy" }, + { name = "statsmodels" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/14/3adedabe6b8710caee34e4ac9f4edc48218a381594ee1980c323b8866577/plotnine-0.15.2.tar.gz", hash = 
"sha256:ec2e4cdf2d022eb0dab63ef4aa0017ce0d84c60bd99d55093e72637fddf757e6", size = 6787690, upload-time = "2025-12-12T10:41:37.249Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/58/27/4e6ffe2f095fbfd6285343aa6114903a4cf011564b4f1f2bb706341472df/plotnine-0.15.2-py3-none-any.whl", hash = "sha256:7dc508bc51625b9b9f945e274d8ee4463cf30b280749190a5b707e6828003fa6", size = 1332822, upload-time = "2025-12-12T10:41:34.203Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -1582,6 +1626,7 @@ dev = [ { name = "mkdocstrings" }, { name = "mkdocstrings-python" }, { name = "nbconvert" }, + { name = "plotnine" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "pytest-mpl" }, @@ -1609,6 +1654,7 @@ dev = [ { name = "mkdocstrings", specifier = ">=0.30.0" }, { name = "mkdocstrings-python", specifier = ">=1.18.2" }, { name = "nbconvert", specifier = ">=7.16.6" }, + { name = "plotnine", specifier = ">=0.14.0" }, { name = "pytest", specifier = ">=8.4.2" }, { name = "pytest-cov", specifier = ">=6.3.0" }, { name = "pytest-mpl", specifier = ">=0.17.0" }, @@ -1971,6 +2017,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, ] +[[package]] +name = "statsmodels" +version = "0.14.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "patsy" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0d/81/e8d74b34f85285f7335d30c5e3c2d7c0346997af9f3debf9a0a9a63de184/statsmodels-0.14.6.tar.gz", hash = "sha256:4d17873d3e607d398b85126cd4ed7aad89e4e9d89fc744cdab1af3189a996c2a", size = 20689085, upload-time = "2025-12-05T23:08:39.522Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/81/59/a5aad5b0cc266f5be013db8cde563ac5d2a025e7efc0c328d83b50c72992/statsmodels-0.14.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:47ee7af083623d2091954fa71c7549b8443168f41b7c5dce66510274c50fd73e", size = 10072009, upload-time = "2025-12-05T23:11:14.021Z" }, + { url = "https://files.pythonhosted.org/packages/53/dd/d8cfa7922fc6dc3c56fa6c59b348ea7de829a94cd73208c6f8202dd33f17/statsmodels-0.14.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:aa60d82e29fcd0a736e86feb63a11d2380322d77a9369a54be8b0965a3985f71", size = 9980018, upload-time = "2025-12-05T23:11:30.907Z" }, + { url = "https://files.pythonhosted.org/packages/ee/77/0ec96803eba444efd75dba32f2ef88765ae3e8f567d276805391ec2c98c6/statsmodels-0.14.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89ee7d595f5939cc20bf946faedcb5137d975f03ae080f300ebb4398f16a5bd4", size = 10060269, upload-time = "2025-12-05T23:11:46.338Z" }, + { url = "https://files.pythonhosted.org/packages/10/b9/fd41f1f6af13a1a1212a06bb377b17762feaa6d656947bf666f76300fc05/statsmodels-0.14.6-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:730f3297b26749b216a06e4327fe0be59b8d05f7d594fb6caff4287b69654589", size = 10324155, upload-time = "2025-12-05T23:12:01.805Z" }, + { url = "https://files.pythonhosted.org/packages/ee/0f/a6900e220abd2c69cd0a07e3ad26c71984be6061415a60e0f17b152ecf08/statsmodels-0.14.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f1c08befa85e93acc992b72a390ddb7bd876190f1360e61d10cf43833463bc9c", size = 10349765, upload-time = "2025-12-05T23:12:18.018Z" }, + { url = "https://files.pythonhosted.org/packages/98/08/b79f0c614f38e566eebbdcff90c0bcacf3c6ba7a5bbb12183c09c29ca400/statsmodels-0.14.6-cp313-cp313-win_amd64.whl", hash = "sha256:8021271a79f35b842c02a1794465a651a9d06ec2080f76ebc3b7adce77d08233", size = 9540043, upload-time = "2025-12-05T23:12:33.887Z" }, + { url = 
"https://files.pythonhosted.org/packages/71/de/09540e870318e0c7b58316561d417be45eff731263b4234fdd2eee3511a8/statsmodels-0.14.6-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:00781869991f8f02ad3610da6627fd26ebe262210287beb59761982a8fa88cae", size = 10069403, upload-time = "2025-12-05T23:12:48.424Z" }, + { url = "https://files.pythonhosted.org/packages/ab/f0/63c1bfda75dc53cee858006e1f46bd6d6f883853bea1b97949d0087766ca/statsmodels-0.14.6-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:73f305fbf31607b35ce919fae636ab8b80d175328ed38fdc6f354e813b86ee37", size = 9989253, upload-time = "2025-12-05T23:13:05.274Z" }, + { url = "https://files.pythonhosted.org/packages/c1/98/b0dfb4f542b2033a3341aa5f1bdd97024230a4ad3670c5b0839d54e3dcab/statsmodels-0.14.6-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e443e7077a6e2d3faeea72f5a92c9f12c63722686eb80bb40a0f04e4a7e267ad", size = 10090802, upload-time = "2025-12-05T23:13:20.653Z" }, + { url = "https://files.pythonhosted.org/packages/34/0e/2408735aca9e764643196212f9069912100151414dd617d39ffc72d77eee/statsmodels-0.14.6-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3414e40c073d725007a6603a18247ab7af3467e1af4a5e5a24e4c27bc26673b4", size = 10337587, upload-time = "2025-12-05T23:13:37.597Z" }, + { url = "https://files.pythonhosted.org/packages/0f/36/4d44f7035ab3c0b2b6a4c4ebb98dedf36246ccbc1b3e2f51ebcd7ac83abb/statsmodels-0.14.6-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a518d3f9889ef920116f9fa56d0338069e110f823926356946dae83bc9e33e19", size = 10363350, upload-time = "2025-12-05T23:13:53.08Z" }, + { url = "https://files.pythonhosted.org/packages/26/33/f1652d0c59fa51de18492ee2345b65372550501ad061daa38f950be390b6/statsmodels-0.14.6-cp314-cp314-win_amd64.whl", hash = "sha256:151b73e29f01fe619dbce7f66d61a356e9d1fe5e906529b78807df9189c37721", size = 9588010, upload-time = "2025-12-05T23:14:07.28Z" }, +] + [[package]] name = 
"tinycss2" version = "1.4.0"