Python pipeline for single-cell gene expression prediction in the obesity challenge setting: given a control cell type (pre-adipocytes) and the perturbed gene (condition), the model predicts the expression profile after perturbation.
The project trains a Conditional Denoising Autoencoder (CDAE) with:
- Input: control expression (unperturbed pre-adipocyte cells).
- Condition: Gene Ontology (GO) embedding of the perturbed gene.
- Output: prediction of the expression change relative to control.
Evaluation uses the Pearson delta metric between predictions and ground truth on a local test set.
ObesitySCPred/
├── data/ # Data (h5ad, gene lists, embeddings)
│ ├── default/ # Main dataset (train/test, genes_to_predict, etc.)
│ ├── go_terms.csv # GO terms per human gene (generated)
│ ├── go_terms_unique.csv # Unique GO terms (one GO term per row)
│ ├── pca_embeddings.h5 # GO embeddings reduced with PCA (generated)
│ ├── genes_list.csv # Gene list with embeddings (generated)
│ └── deepseek-r1_embeddings.h5 # Raw GO embeddings (generated)
├── scpred/ # Main module
│ ├── preprocess.py # Preprocess (train/test split), PreprocessGO (PCA)
│ ├── autoencoder.py # PerturbDataset, ConditionalDAE, ModelTrainer
│ ├── go_terms.py # GO download from Ensembl
│ └── metrics.py # Pearson delta for evaluation
├── train_cdae.py # CDAE training + test + metric
├── embed_pca.py # PCA on GO embeddings → pca_embeddings.h5
├── get_go_terms.py # Download go_terms.csv / go_term
├── baselines/ # Baseline notebooks, downloaded from Crunch-Lab
└── notebooks/ # EDA, GO exploration.
- Python 3.10+
- Main dependencies: anndata, scanpy, numpy, pandas, torch, h5py, scipy, scikit-learn, tqdm, requests

Suggested install (e.g. with pip): `pip install anndata scanpy numpy pandas torch h5py scipy scikit-learn tqdm requests`

Main dataset files (`data/default/`):

| File | Description |
|---|---|
| `obesity_challenge_1.h5ad` | Training data (cells × genes; `obs["gene"]` = perturbed gene, `obs["pre_adipo"]` for control) |
| `obesity_challenge_1_local_gtruth.h5ad` | Local test set with ground truth |
| `genes_to_predict.txt` | Genes to predict |
| `predict_perturbations.txt` | Perturbations in the test set |
- GO terms (`get_go_terms.py`): `data/go_terms.csv`, `data/go_terms_unique.csv` (from Ensembl BioMart).
- GO embeddings: assumes `data/deepseek-r1_embeddings.h5` (one embedding per GO term).
- PCA (`embed_pca.py`): `data/pca_embeddings.h5`, `data/genes_list.csv` (same order as the embeddings).
- `python get_go_terms.py`: produces `data/go_terms.csv` and `data/go_terms_unique.csv`.
- `python embed.py`: produces `data/deepseek-r1_embeddings.h5`.
- `python embed_pca.py`: requires `data/deepseek-r1_embeddings.h5`, `go_terms.csv`, and `go_terms_unique.csv`; produces `pca_embeddings.h5` and `genes_list.csv`.
- `python train_cdae.py`: reads `data/default/`, `data/pca_embeddings.h5`, and `data/genes_list.csv`; trains the model, runs prediction on the test set, and computes the Pearson delta. Expected console output: a line like `Pearson delta score: 0.xxxx`. The script also builds an AnnData of predictions (perturbed gene in `obs`, real expression in `X`, predictions in `layers["predicted"]`).
- Train/validation split: by perturbed gene (e.g. 80% train, 20% val), not by cell. The control condition `"NC"` is always in the training set. Genes to keep are the union of genes to predict, test perturbations, and genes observed in training; cells are filtered so that each cell's `obs["gene"]` is either one of those genes or `"NC"`.
- Control representation:
  - Training: for each sample, a random control pre-adipocyte cell is used as input (denoising setup).
  - Validation and test: the centroid (mean expression) of control pre-adipocytes is used as the single control input for all samples.
- Target: the model predicts the delta of expression with respect to a reference:
  `delta = x_perturbed - perturbed_centroid`,
  where `perturbed_centroid` is the mean expression over all perturbed (non-`NC`) cells in the training set. At inference, the predicted expression is recovered as control + predicted delta (if needed).
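A minimal NumPy sketch of the gene-level split and the input/target construction described above. It assumes a dense cells × genes matrix `X`, a per-cell label array `genes` (the `obs["gene"]` column), and a boolean `ctrl_mask` marking control pre-adipocytes; the helper names, and how `ctrl_mask` is derived from `obs`, are hypothetical and not the API of `scpred.preprocess`.

```python
import numpy as np


def split_by_gene(genes, val_frac=0.2, seed=0):
    """Gene-level train/val split (hypothetical helper).

    genes: array with one perturbed-gene label per cell; "NC" marks controls.
    Returns boolean (train_mask, val_mask) over cells; "NC" never goes to val.
    """
    rng = np.random.default_rng(seed)
    perturbed = np.unique(genes[genes != "NC"])  # sorted unique perturbed genes
    rng.shuffle(perturbed)
    n_val = int(round(val_frac * len(perturbed)))
    val_mask = np.isin(genes, perturbed[:n_val])
    return ~val_mask, val_mask


def build_inputs_and_targets(X, genes, ctrl_mask, train_mask, seed=0):
    """Control inputs and delta targets (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    # Reference: mean expression of perturbed (non-NC) training cells.
    perturbed_centroid = X[train_mask & (genes != "NC")].mean(axis=0)
    delta = X - perturbed_centroid
    # Training input: one randomly drawn control cell per training sample.
    ctrl_idx = np.flatnonzero(ctrl_mask)
    train_ctrl = X[rng.choice(ctrl_idx, size=int(train_mask.sum()))]
    # Validation/test input: the control centroid, shared by every sample.
    ctrl_centroid = X[ctrl_mask].mean(axis=0)
    return train_ctrl, ctrl_centroid, perturbed_centroid, delta
```

Splitting on genes rather than cells means every validation perturbation is unseen during training, which mirrors the test setting where new perturbed genes must be predicted.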
- Gene Ontology (GO) terms are fetched from Ensembl BioMart (human genes) and stored per gene (`go_terms.csv` / `go_terms_unique.csv`).
- Embeddings: each GO term has an embedding (e.g. from `embed.py` → `deepseek-r1_embeddings.h5`). For each gene, the embeddings of its GO terms are averaged to obtain one vector per gene.
- Dimensionality reduction: PCA is applied (in `embed_pca.py`, e.g. keeping 95% of the variance) to get a fixed-size conditioning vector per gene (`pca_embeddings.h5`, `genes_list.csv`). For the control gene `"NC"`, the conditioning vector is set to zero.
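The per-gene conditioning vector can be sketched with NumPy: average the embeddings of a gene's GO terms, reduce with PCA keeping enough components for 95% of the variance, and map `"NC"` to a zero vector. The dictionaries `gene_to_terms` and `term_emb` are hypothetical stand-ins for the contents of `go_terms.csv` and `deepseek-r1_embeddings.h5`, whose exact layouts are assumed here, and giving genes with no embedded GO terms a zero vector is also an assumption. Since the project depends on scikit-learn, `embed_pca.py` may instead use `sklearn.decomposition.PCA(n_components=0.95)`, which performs the same variance-based component selection.

```python
import numpy as np


def average_go_embeddings(gene_to_terms, term_emb, genes):
    """One vector per gene: mean of its GO-term embeddings (zeros if none known)."""
    dim = len(next(iter(term_emb.values())))
    out = np.zeros((len(genes), dim))
    for i, gene in enumerate(genes):
        vecs = [term_emb[t] for t in gene_to_terms.get(gene, []) if t in term_emb]
        if vecs:
            out[i] = np.mean(vecs, axis=0)
    return out


def pca_reduce(E, var_kept=0.95):
    """PCA via SVD, keeping the fewest components that reach `var_kept` variance."""
    centered = E - E.mean(axis=0)
    _, S, Vt = np.linalg.svd(centered, full_matrices=False)
    explained = np.cumsum(S**2) / np.sum(S**2)
    k = int(np.searchsorted(explained, var_kept)) + 1
    return centered @ Vt[:k].T


def conditioning_table(genes, Z):
    """Map gene -> conditioning vector; the control "NC" is the zero vector."""
    table = {gene: z for gene, z in zip(genes, Z)}
    table["NC"] = np.zeros(Z.shape[1])
    return table
```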
- ConditionalDAE (in `scpred.autoencoder`):
  - Encoder: Linear(n_genes → 512) → LayerNorm → ReLU → Linear(512 → latent_dim). Default `latent_dim = 128`.
  - FiLM (Feature-wise Linear Modulation): the GO embedding (dimension `go_emb_dim`, e.g. the PCA dimension) is mapped to a scale `gamma` and a shift `beta`, each of dimension `latent_dim`, which modulate the latent vector `h`:
    `out = gamma * h + beta`.
    This conditions the latent representation on the perturbed gene.
  - Decoder: Linear(latent_dim → 512) → LayerNorm → ReLU → Linear(512 → n_genes).
- Input: control expression vector (one cell or the centroid).
- Output: predicted delta (a vector of length n_genes), not raw expression.
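A minimal PyTorch sketch of the architecture above. Layer sizes follow the text; producing `gamma` and `beta` with a single `nn.Linear` over the GO embedding is an assumption, since the text does not say how the FiLM generator is built, so this illustrates the idea rather than reproducing `scpred.autoencoder.ConditionalDAE`.

```python
import torch
from torch import nn


class ConditionalDAE(nn.Module):
    """Encoder -> FiLM(GO embedding) -> decoder; outputs a predicted delta."""

    def __init__(self, n_genes, go_emb_dim, latent_dim=128, hidden=512):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(n_genes, hidden), nn.LayerNorm(hidden), nn.ReLU(),
            nn.Linear(hidden, latent_dim),
        )
        # Assumed FiLM generator: one linear map to (gamma, beta).
        self.film = nn.Linear(go_emb_dim, 2 * latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.LayerNorm(hidden), nn.ReLU(),
            nn.Linear(hidden, n_genes),
        )

    def forward(self, x_ctrl, go_emb):
        h = self.encoder(x_ctrl)                          # (batch, latent_dim)
        gamma, beta = self.film(go_emb).chunk(2, dim=-1)  # each (batch, latent_dim)
        return self.decoder(gamma * h + beta)             # (batch, n_genes) delta
```

With this choice, the zero conditioning vector used for `"NC"` makes `gamma` and `beta` equal to the FiLM layer's biases, so the control still gets a learned, gene-independent modulation.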
- Loss (per batch): `loss = loss_recon + 0.2 * loss_corr + 0.5 * loss_delta`
  - loss_recon: Smooth L1 between the predicted delta and the target delta.
  - loss_corr: `1 - mean(Pearson_correlation(pred_delta, target_delta))`, with the correlation computed per sample and averaged over the batch.
  - loss_delta: Smooth L1 between `(pred - x_ctrl)` and `(target_delta - x_ctrl)`, to align the predicted change w.r.t. control with the true change.
- Denoising: The control input is corrupted during training: a random fraction of entries is replaced by Gaussian noise. The fraction and noise std increase with epoch (e.g. epochs 0–4: 10% mask, std 0.05; 5–9: 15%, 0.1; 10+: 20%, 0.2). Validation uses uncorrupted control.
- Optimizer: AdamW (lr=1e-3, weight_decay=1e-5) with gradient-norm clipping (max norm 1.0). ReduceLROnPlateau on validation correlation (mode="max", factor=0.5, patience=10).
- Early stopping: Training stops if validation correlation does not improve for 10 consecutive epochs. The best checkpoint (by validation correlation) is restored.
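The training recipe above (loss, corruption schedule, optimizer, early stopping) can be sketched in PyTorch as follows. The callbacks `train_batches` and `val_corr`, the cap `max_epochs`, zero-mean replacement noise, and computing `loss_delta` on the uncorrupted control are assumptions; none of these helper names come from `scpred`.

```python
import copy

import torch
import torch.nn.functional as F
from torch import nn


def batch_pearson(a, b, eps=1e-8):
    """Per-row Pearson correlation between two (batch, n_genes) tensors."""
    a = a - a.mean(dim=1, keepdim=True)
    b = b - b.mean(dim=1, keepdim=True)
    return (a * b).sum(dim=1) / (a.norm(dim=1) * b.norm(dim=1) + eps)


def cdae_loss(pred_delta, target_delta, x_ctrl):
    """loss = loss_recon + 0.2 * loss_corr + 0.5 * loss_delta, as described."""
    loss_recon = F.smooth_l1_loss(pred_delta, target_delta)
    loss_corr = 1 - batch_pearson(pred_delta, target_delta).mean()
    loss_delta = F.smooth_l1_loss(pred_delta - x_ctrl, target_delta - x_ctrl)
    return loss_recon + 0.2 * loss_corr + 0.5 * loss_delta


def noise_schedule(epoch):
    """(mask fraction, noise std) per epoch, using the example schedule."""
    if epoch < 5:
        return 0.10, 0.05
    if epoch < 10:
        return 0.15, 0.1
    return 0.20, 0.2


def corrupt(x_ctrl, epoch):
    """Replace a random fraction of entries with zero-mean Gaussian noise."""
    frac, std = noise_schedule(epoch)
    mask = torch.rand_like(x_ctrl) < frac
    return torch.where(mask, torch.randn_like(x_ctrl) * std, x_ctrl)


def fit(model, train_batches, val_corr, max_epochs=200, patience=10):
    """train_batches(): iterable of (x_ctrl, go_emb, target_delta).
    val_corr(model): mean validation Pearson correlation on uncorrupted control."""
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="max", factor=0.5, patience=10
    )
    best_corr, stale = float("-inf"), 0
    best_state = copy.deepcopy(model.state_dict())
    for epoch in range(max_epochs):
        model.train()
        for x_ctrl, go_emb, target in train_batches():
            pred = model(corrupt(x_ctrl, epoch), go_emb)
            loss = cdae_loss(pred, target, x_ctrl)
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()
        model.eval()
        with torch.no_grad():
            corr = val_corr(model)
        sched.step(corr)
        if corr > best_corr:  # new best: snapshot weights, reset patience
            best_corr, stale = corr, 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            stale += 1
            if stale >= patience:
                break
    model.load_state_dict(best_state)  # restore best checkpoint
    return best_corr
```

Two properties of the configuration as written are worth knowing. Because Smooth L1 acts on the elementwise difference, `x_ctrl` cancels in `loss_delta`, which therefore equals `loss_recon`. And since `ReduceLROnPlateau` lowers the learning rate only after more than `patience` epochs without improvement (beyond its default relative threshold), early stopping at the same patience of 10 usually ends training before the first halving.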
- Per perturbation: For each perturbed gene in the test set, cells are averaged in ground truth and in the prediction matrix, giving two vectors of length n_genes.
- Centroid correction: a reference vector `perturbed_centroid_train` (stored in `gtruth.uns`) is subtracted from both vectors, which centers the comparison on the training perturbed centroid.
- Score: the Pearson correlation between the two centered vectors is computed per perturbation; the Pearson delta metric is the mean of these correlations over all test perturbations.
- Gene subset: the metric can be computed on a subset of genes (e.g. 1000 genes sampled with a fixed seed for reproducibility). Implementation: `scpred.metrics.compute_metric_pearson_delta`.
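A NumPy sketch of the metric as described, assuming ground truth and predictions are aligned row by row with one perturbed-gene label per cell (as in the AnnData built by `train_cdae.py`, with real expression in `X` and predictions in `layers["predicted"]`). The function name and signature are hypothetical and do not reproduce `scpred.metrics.compute_metric_pearson_delta`.

```python
import numpy as np


def pearson_delta(X_true, X_pred, genes, perturbed_centroid, n_sub=None, seed=0):
    """Mean over perturbations of Pearson(mean_true - c, mean_pred - c)."""
    cols = np.arange(X_true.shape[1])
    if n_sub is not None and n_sub < len(cols):
        # Fixed-seed gene subset for reproducibility.
        rng = np.random.default_rng(seed)
        cols = np.sort(rng.choice(cols, size=n_sub, replace=False))
    c = perturbed_centroid[cols]
    scores = []
    for gene in np.unique(genes):
        rows = genes == gene
        true_vec = X_true[rows][:, cols].mean(axis=0) - c  # centered truth
        pred_vec = X_pred[rows][:, cols].mean(axis=0) - c  # centered prediction
        scores.append(np.corrcoef(true_vec, pred_vec)[0, 1])
    return float(np.mean(scores))
```

Subtracting the training perturbed centroid before correlating means a model that only predicts the average perturbation response gets no credit, because its centered prediction is close to zero.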
Adjust according to challenge usage (CrunchDAO / challenge organizers).