
Python library for single-cell gene expression prediction. Work in progress.


GabrielCabas/SCPred


ObesitySCPred

Python pipeline for single-cell gene expression prediction in the obesity challenge setting: given a control cell type (pre-adipocytes) and the perturbed gene (condition), the model predicts the expression profile after perturbation.

Description

The project trains a Conditional Denoising Autoencoder (CDAE) that uses:

  • Control expression (unperturbed pre-adipocyte cells) as input.
  • Gene Ontology (GO) embeddings of the perturbed gene as condition.
  • Output: prediction of the expression change relative to control.

Evaluation uses the Pearson delta metric between predictions and ground truth on a local test set.

Project structure

ObesitySCPred/
├── data/                    # Data (h5ad, gene lists, embeddings)
│   ├── default/             # Main dataset (train/test, genes_to_predict, etc.)
│   ├── go_terms.csv         # GO terms per human gene (generated)
│   ├── go_terms_unique.csv  # Unique GO terms (one GO term per row)
│   ├── pca_embeddings.h5    # GO embeddings reduced with PCA (generated)
│   ├── genes_list.csv       # Gene list with embeddings (generated)
│   └── deepseek-r1_embeddings.h5   # Raw GO embeddings (generated)
├── scpred/                  # Main module
│   ├── preprocess.py        # Preprocess (train/test split), PreprocessGO (PCA)
│   ├── autoencoder.py       # PerturbDataset, ConditionalDAE, ModelTrainer
│   ├── go_terms.py          # GO download from Ensembl
│   └── metrics.py           # Pearson delta for evaluation
├── train_cdae.py            # CDAE training + test + metric
├── embed_pca.py             # PCA on GO embeddings → pca_embeddings.h5
├── get_go_terms.py          # Download go_terms.csv / go_terms_unique.csv
├── baselines/               # Baseline notebooks, downloaded from Crunch-Lab
└── notebooks/               # EDA, GO exploration.

Requirements

  • Python 3.10+
  • Main dependencies: anndata, scanpy, numpy, pandas, torch, h5py, scipy, scikit-learn, tqdm, requests

Suggested install (e.g. with pip):

pip install anndata scanpy numpy pandas torch h5py scipy scikit-learn tqdm requests

Data

Expected input in data/default/

  • obesity_challenge_1.h5ad: Training data (cells × genes; obs["gene"] = perturbed gene, obs["pre_adipo"] marks control cells)
  • obesity_challenge_1_local_gtruth.h5ad: Local test set with ground truth
  • genes_to_predict.txt: Genes to predict
  • predict_perturbations.txt: Perturbations in the test set

Files produced by the pipeline

  1. GO terms (get_go_terms.py): data/go_terms.csv, data/go_terms_unique.csv (Ensembl BioMart).
  2. GO embeddings: assumes data/deepseek-r1_embeddings.h5 (embedding per GO term).
  3. PCA (embed_pca.py): data/pca_embeddings.h5, data/genes_list.csv (same order as embeddings).

Usage pipeline

1. Fetch GO terms (optional if you already have the CSVs)

python get_go_terms.py

Produces data/go_terms.csv and data/go_terms_unique.csv.

Then generate the raw GO embeddings (skip if data/deepseek-r1_embeddings.h5 already exists):

python embed.py

Produces data/deepseek-r1_embeddings.h5.

2. Reduce GO embeddings with PCA

Requires data/deepseek-r1_embeddings.h5, go_terms.csv, go_terms_unique.csv. Produces pca_embeddings.h5 and genes_list.csv.

python embed_pca.py

3. Train the CDAE and evaluate

Reads data/default/ and data/pca_embeddings.h5, data/genes_list.csv. Trains, runs prediction on test, and computes Pearson delta.

python train_cdae.py

Expected console output: a line like Pearson delta score: 0.xxxx. The script also builds an AnnData of predictions (perturbed gene in obs, real expression in X, predictions in layers["predicted"]).

Methodology

Data and preprocessing

  • Train/validation split: By perturbed gene (e.g. 80% train, 20% val), not by cells. The control condition "NC" is always in the training set. Genes to keep are the union of genes to predict, test perturbations, and genes observed in training; cells are filtered so that each cell’s obs["gene"] is either one of those genes or "NC".
  • Control representation:
    • Training: For each sample, a random control pre-adipocyte cell is used as input (denoising setup).
    • Validation and test: The centroid of control pre-adipocytes (mean expression) is used as the single control input for all samples.
  • Target: The model predicts the delta of expression with respect to a reference:
    delta = x_perturbed - perturbed_centroid,
    where perturbed_centroid is the mean expression over all perturbed (non‑NC) cells in the training set. At inference, predicted expression is recovered as control + predicted delta (if needed).
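As a concrete sketch of the target construction above, using toy NumPy arrays (all variable names and sizes hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy expression matrices: rows = cells, columns = genes.
x_perturbed = rng.random((50, 10))   # perturbed (non-NC) training cells
x_control = rng.random((30, 10))     # control pre-adipocyte cells

# Centroids used by the pipeline.
perturbed_centroid = x_perturbed.mean(axis=0)
control_centroid = x_control.mean(axis=0)   # single control input at val/test

# Training target: delta relative to the perturbed centroid.
delta = x_perturbed - perturbed_centroid

# At inference, expression is recovered as control + predicted delta.
predicted_delta = delta[0]                  # stand-in for a model output
recovered = control_centroid + predicted_delta
```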

GO conditioning

  • Gene Ontology (GO) terms are fetched from Ensembl BioMart (human genes) and stored per gene (go_terms.csv / go_terms_unique.csv).
  • Embeddings: Each GO term has an embedding (e.g. from embed.py, stored in deepseek-r1_embeddings.h5). For each gene, embeddings of its GO terms are averaged to obtain one vector per gene.
  • Dimensionality reduction: PCA is applied (e.g. 95% variance in embed_pca.py) to get a fixed-size conditioning vector per gene (pca_embeddings.h5, genes_list.csv). For the control gene "NC", the conditioning vector is set to zero.
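A minimal sketch of how the per-gene conditioning vectors are built, assuming toy GO-term embeddings (all identifiers and dimensions hypothetical):

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)

# Hypothetical inputs: one embedding per GO term, and GO terms per gene.
term_emb = {f"GO:{i:07d}": rng.random(64) for i in range(6)}
gene_terms = {
    "GENE_A": ["GO:0000000", "GO:0000001"],
    "GENE_B": ["GO:0000002", "GO:0000003", "GO:0000004"],
    "GENE_C": ["GO:0000001", "GO:0000005"],
}

# One vector per gene: mean over the embeddings of its GO terms.
gene_emb = np.stack([
    np.mean([term_emb[t] for t in terms], axis=0)
    for terms in gene_terms.values()
])

# PCA to a fixed-size conditioning vector per gene (the pipeline keeps
# 95% of variance; a fixed component count is used here for the toy data).
cond = PCA(n_components=2).fit_transform(gene_emb)

# The control gene "NC" gets an all-zero conditioning vector.
nc_vector = np.zeros(cond.shape[1])
```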

Model architecture

  • ConditionalDAE (in scpred.autoencoder):
    • Encoder: Linear(n_genes → 512) → LayerNorm → ReLU → Linear(512 → latent_dim). Default latent_dim = 128.
    • FiLM (Feature-wise Linear Modulation): The GO embedding (dimension go_emb_dim, e.g. PCA dimension) is mapped to scale and shift:
      out = gamma * h + beta,
      with gamma, beta of dimension latent_dim. This conditions the latent representation on the perturbed gene.
    • Decoder: Linear(latent_dim → 512) → LayerNorm → ReLU → Linear(512 → n_genes).
  • Input: Control expression vector (one cell or the centroid). Output: Predicted delta (dimension n_genes), not raw expression.
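The architecture above can be sketched in PyTorch as follows (layer sizes taken from the description; a sketch, not the repository's exact implementation):

```python
import torch
import torch.nn as nn

class ConditionalDAE(nn.Module):
    """Encoder -> FiLM conditioning -> decoder, as described above."""

    def __init__(self, n_genes: int, go_emb_dim: int, latent_dim: int = 128):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(n_genes, 512), nn.LayerNorm(512), nn.ReLU(),
            nn.Linear(512, latent_dim),
        )
        # FiLM: map the GO embedding to a per-dimension scale and shift.
        self.film = nn.Linear(go_emb_dim, 2 * latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512), nn.LayerNorm(512), nn.ReLU(),
            nn.Linear(512, n_genes),
        )

    def forward(self, x_ctrl: torch.Tensor, go_emb: torch.Tensor) -> torch.Tensor:
        h = self.encoder(x_ctrl)
        gamma, beta = self.film(go_emb).chunk(2, dim=-1)
        h = gamma * h + beta          # condition latent on the perturbed gene
        return self.decoder(h)        # predicted delta, not raw expression

model = ConditionalDAE(n_genes=2000, go_emb_dim=50)
delta_pred = model(torch.randn(4, 2000), torch.randn(4, 50))
```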

Training

  • Loss (per batch):
    loss = loss_recon + 0.2 * loss_corr + 0.5 * loss_delta
    • loss_recon: Smooth L1 between predicted delta and target delta.
    • loss_corr: 1 - mean(Pearson_correlation(pred_delta, target_delta)) over samples (batch).
    • loss_delta: Smooth L1 between (pred - x_ctrl) and (target_delta - x_ctrl) to align predicted change w.r.t. control with the true change.
  • Denoising: The control input is corrupted during training: a random fraction of entries is replaced by Gaussian noise. The fraction and noise std increase with epoch (e.g. epochs 0–4: 10% mask, std 0.05; 5–9: 15%, 0.1; 10+: 20%, 0.2). Validation uses uncorrupted control.
  • Optimizer: AdamW (lr=1e-3, weight_decay=1e-5). Gradient clipping (max norm 1.0). ReduceLROnPlateau on validation correlation (max, factor=0.5, patience=10).
  • Early stopping: Training stops if validation correlation does not improve for 10 consecutive epochs. The best checkpoint (by validation correlation) is restored.
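A sketch of the batch loss and the epoch-scheduled input corruption described above (loss weights and schedule values from the text; helper names hypothetical):

```python
import torch
import torch.nn.functional as F

def combined_loss(pred_delta, target_delta, x_ctrl):
    """loss_recon + 0.2 * loss_corr + 0.5 * loss_delta, as described above."""
    loss_recon = F.smooth_l1_loss(pred_delta, target_delta)
    # 1 - mean per-sample Pearson correlation between pred and target deltas.
    p = pred_delta - pred_delta.mean(dim=1, keepdim=True)
    t = target_delta - target_delta.mean(dim=1, keepdim=True)
    corr = (p * t).sum(dim=1) / (p.norm(dim=1) * t.norm(dim=1) + 1e-8)
    loss_corr = 1.0 - corr.mean()
    # Align the predicted change w.r.t. control with the true change.
    loss_delta = F.smooth_l1_loss(pred_delta - x_ctrl, target_delta - x_ctrl)
    return loss_recon + 0.2 * loss_corr + 0.5 * loss_delta

def corrupt(x, epoch):
    """Replace a random fraction of entries with Gaussian noise; the
    fraction and noise std grow with the epoch (schedule from the text)."""
    frac, std = ((0.10, 0.05) if epoch < 5 else
                 (0.15, 0.10) if epoch < 10 else
                 (0.20, 0.20))
    mask = torch.rand_like(x) < frac
    return torch.where(mask, torch.randn_like(x) * std, x)

target = torch.randn(8, 100)
ctrl = torch.randn(8, 100)
loss = combined_loss(target, target, ctrl)  # perfect prediction: near-zero loss
```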

Evaluation (Pearson delta)

  • Per perturbation: For each perturbed gene in the test set, cells are averaged in ground truth and in the prediction matrix, giving two vectors of length n_genes.
  • Centroid correction: A reference vector perturbed_centroid_train (stored in gtruth.uns) is subtracted from both vectors. This centers the comparison on the training perturbed centroid.
  • Score: Pearson correlation between these two centered vectors is computed. The Pearson delta metric is the mean of these correlations over all test perturbations.
  • Gene subset: The metric can be computed on a subset of genes (e.g. 1000 genes sampled with a fixed seed for reproducibility). Implementation: scpred.metrics.compute_metric_pearson_delta.
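The evaluation steps above can be sketched as follows (names hypothetical; the actual implementation lives in scpred.metrics.compute_metric_pearson_delta):

```python
import numpy as np

def pearson_delta(true_means, pred_means, perturbed_centroid_train):
    """Mean Pearson correlation over per-perturbation mean-expression
    vectors, each centered on the training perturbed centroid."""
    scores = []
    for t, p in zip(true_means, pred_means):
        a = t - perturbed_centroid_train
        b = p - perturbed_centroid_train
        scores.append(np.corrcoef(a, b)[0, 1])
    return float(np.mean(scores))

rng = np.random.default_rng(0)
true_means = rng.random((5, 20))   # 5 test perturbations x 20 genes
centroid = rng.random(20)
score = pearson_delta(true_means, true_means, centroid)  # identical inputs
```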

License

Adjust according to challenge usage (CrunchDAO / challenge organizers).
