Skip to content

scDORI training from the "grn" step OOM issue #13

@DmitriiSeverinov

Description

@DmitriiSeverinov

Report

Hi,

Thank you very much for developing such a nice and powerful tool!

I was training scDoRI on a GPU cluster. I underestimated how long the training would take, so the job was terminated, but at least Phase 1 of training had completed.
To continue training the scDoRI model, I simply commented out the lines for Phase 1 training in my script:

import logging
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from pathlib import Path
from sklearn.preprocessing import OneHotEncoder

from scdori import (
    trainConfig,
    load_scdori_inputs,
    save_model_weights,
    set_seed,
    scDoRI,
    train_scdori_phases,
    train_model_grn,
    initialize_scdori_parameters,
    load_best_model,
)

# Resume script: builds the scDoRI model, skips Phase 1 (its calls are left
# commented out below), loads the best Phase 1 checkpoint and then runs only
# Phase 2 ("grn") training.
logger = logging.getLogger(__name__)
logging.basicConfig(level=trainConfig.logging_level)

# ----------------------------------------------------------------
# Loading and preparing data for training and model initialisation
# ----------------------------------------------------------------

logger.info("Starting scDoRI pipeline with integrated GRN.")
set_seed(trainConfig.random_seed)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# ---------
# Load data
# ---------

rna_metacell, atac_metacell, gene_peak_dist, insilico_act, insilico_rep = (
    load_scdori_inputs(trainConfig)
)
# Binary copy of the distance matrix: every strictly positive entry becomes 1.
gene_peak_fixed = gene_peak_dist.clone()
gene_peak_fixed[gene_peak_fixed > 0] = 1  # mask for peak-gene links based on distance

# ------------------------------------------------------------------------------------------------------------------
# Computing indices of genes which are TFs and setting number of cells per metacell ( set to 1 for single cell data)
# ------------------------------------------------------------------------------------------------------------------

rna_metacell.obs["num_cells"] = 1
# Integer position of each gene, used to recover the positions of the TF genes.
rna_metacell.var["index_int"] = range(rna_metacell.shape[1])
tf_indices = rna_metacell.var[rna_metacell.var.gene_type == "TF"].index_int.values
# Column vector with one row per observation.
num_cells = rna_metacell.obs.num_cells.values.reshape((-1, 1))

# ---------------------------------------------------
# Onehot encoding the batch column for entire dataset
# ---------------------------------------------------

batch_col = trainConfig.batch_col
rna_metacell.obs["batch"] = rna_metacell.obs[batch_col].values
atac_metacell.obs["batch"] = atac_metacell.obs[batch_col].values
# Obtain the one-hot encoding for the technical batch (fitted on the RNA object only).

enc = OneHotEncoder(handle_unknown="ignore")
enc.fit(rna_metacell.obs["batch"].values.reshape(-1, 1))

onehot_batch = enc.transform(rna_metacell.obs["batch"].values.reshape(-1, 1)).toarray()
# NOTE(review): bare expression; it has no effect when run as a script
# (likely a notebook leftover used to inspect the fitted categories).
enc.categories_

# ------------------------------------
# Making train and evaluation datasets
# ------------------------------------

# Split the cell indices 80/20 into train and evaluation sets.
# NOTE(review): the split uses the literal seed 42, not trainConfig.random_seed.
n_cells = rna_metacell.n_obs
indices = np.arange(n_cells)
train_idx, eval_idx = train_test_split(indices, test_size=0.2, random_state=42)
# The datasets hold cell indices only, not the data themselves; presumably the
# training functions use them to index the AnnData objects - confirm in scdori.
train_dataset = TensorDataset(torch.from_numpy(train_idx))
train_loader = DataLoader(
    train_dataset, batch_size=trainConfig.batch_size_cell, shuffle=True
)

eval_dataset = TensorDataset(torch.from_numpy(eval_idx))
eval_loader = DataLoader(
    eval_dataset, batch_size=trainConfig.batch_size_cell, shuffle=False
)

# ----------------------------------------------------
# Build scDoRI model using parameters from config file
# ----------------------------------------------------

num_genes = rna_metacell.n_vars
num_peaks = atac_metacell.n_vars

# Number of TFs is taken from the second dimension of the in-silico activator matrix.
num_tfs = insilico_act.shape[1]

# One batch per one-hot column.
num_batches = onehot_batch.shape[1]
model = scDoRI(
    device=device,
    num_genes=num_genes,
    num_peaks=num_peaks,
    num_tfs=num_tfs,
    num_topics=trainConfig.num_topics,
    num_batches=num_batches,
    dim_encoder1=trainConfig.dim_encoder1,
    dim_encoder2=trainConfig.dim_encoder2,
).to(device)

# # -------------------------------------------------------------------------
# # Initialising scDoRI model with precomputed matrices and setting gradients
# # -------------------------------------------------------------------------

# initialize_scdori_parameters(
#     model,
#     gene_peak_dist.to(device),
#     gene_peak_fixed.to(device),
#     insilico_act=insilico_act.to(device),
#     insilico_rep=insilico_rep.to(device),
#     phase="warmup",
# )

# # -----------------------------
# # Train Phase 1 of scDoRI model
# # -----------------------------

# model = train_scdori_phases(
#     model,
#     device,
#     train_loader,
#     eval_loader,
#     rna_metacell,
#     atac_metacell,
#     num_cells,
#     tf_indices,
#     onehot_batch,
#     trainConfig,
# )

# # saving the model weight corresponding to final epoch where model stopped training
# save_model_weights(model, Path(trainConfig.weights_folder_scdori), "scdori_final")

# -----------------------------
# Train Phase 2 of scDoRI model
# -----------------------------

# loading the best checkpoint from Phase 1
# NOTE(review): the "warmup" initialisation above is skipped in this resumed run;
# confirm that load_best_model restores everything that step would have set.
model = load_best_model(
    model, Path(trainConfig.weights_folder_scdori) / "best_scdori_best_eval.pth", device
)

# ----------------------------------
# Set gradients for Phase 2 training
# ----------------------------------

# NOTE(review): unlike the commented-out "warmup" call, the tensors here are
# passed without .to(device); confirm how initialize_scdori_parameters handles
# device placement.
initialize_scdori_parameters(
    model,
    gene_peak_dist,
    gene_peak_fixed,
    insilico_act=insilico_act,
    insilico_rep=insilico_rep,
    phase="grn",
)

# -----------------------------------------
# Phase 2 training and saving model weights
# -----------------------------------------

# train Phase 2 of scDoRI model, TF-gene links are learnt in this phase and used to reconstruct gene-expression profiles
model = train_model_grn(
    model,
    device,
    train_loader,
    eval_loader,
    rna_metacell,
    atac_metacell,
    num_cells,
    tf_indices,
    onehot_batch,
    trainConfig,
)

# saving the model weight corresponding to final epoch where model stopped training
# (written to the GRN weights folder under the name "scdori_final")
save_model_weights(model, Path(trainConfig.weights_folder_grn), "scdori_final")

However, when I submitted the job on the cluster I got the following error:

Extracting latent topics: 100%|██████████| 1105/1105 [15:28<00:00,  1.19it/s]
/data/horse/ws/dmse952c-scDoRI_env/scDoRI/src/scdori/_core/train_grn.py:143: RuntimeWarning: invalid value encountered in divide
  rna_tf_vals = median_cell * (rna_tf_vals / rna_tf_vals.sum(axis=1, keepdims=True))
INFO:scdori._core.train_grn:Starting GRN training
GRN Epoch 0:   0%|          | 0/1105 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/projects/p_scads_spinal_cord/scDORI/10062025/scripts/train.py", line 172, in <module>
    model = train_model_grn(
            ^^^^^^^^^^^^^^^^
  File "/data/horse/ws/dmse952c-scDoRI_env/scDoRI/src/scdori/_core/train_grn.py", line 465, in train_model_grn
    out = model(
          ^^^^^^
  File "/home/dmse952c/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dmse952c/.local/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/horse/ws/dmse952c-scDoRI_env/scDoRI/src/scdori/_core/models.py", line 323, in forward
    topic_gene_peak = (1 / (topic_peak_denoised1[topic] + 1e-20))[:, None] * gene_peak
                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.23 GiB. GPU 0 has a total capacity of 93.12 GiB of which 249.88 MiB is free. Including non-PyTorch memory, this process has 92.86 GiB memory in use. Of the allocated memory 92.01 GiB is allocated by PyTorch, and 197.52 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

The GPUs on our cluster are H100s, which have a lot of memory. When I checked the GPU memory usage during Phase 2 training, I saw a sharp increase from 14.2 GB up to 100 GB. During Phase 1 in the run that got terminated, 20 GB of GPU memory was enough.
Strangely, when I retrained the model starting from Phase 1, i.e., Phase 1 and Phase 2 together in one run (this time requesting more time on the node), Phase 1 took 20 GB of memory and Phase 2 increased it only to 25 GB, and the model was trained successfully.

So far I have no explanation as to what could have caused this problem.

Best,
Dmitrii

Session info:

| Package      | Version             |
| ------------ | ------------------- |
| torch        | 2.7.0 (2.7.0+cu126) |
| numpy        | 1.26.4              |
| scikit-learn | 1.6.1               |

| Dependency        | Version     |
| ----------------- | ----------- |
| umap-learn        | 0.5.7       |
| pytz              | 2025.2      |
| tqdm              | 4.67.1      |
| h5py              | 3.13.0      |
| pyfaidx           | 0.8.1.4     |
| igraph            | 0.11.8      |
| joblib            | 1.5.1       |
| matplotlib        | 3.10.3      |
| traitlets         | 5.14.3      |
| cffi              | 1.17.1      |
| gtfparse          | 2.5.0       |
| prompt_toolkit    | 3.0.51      |
| wcwidth           | 0.2.13      |
| defusedxml        | 0.7.1       |
| typing_extensions | 4.14.0      |
| comm              | 0.2.2       |
| executing         | 2.2.0       |
| pyparsing         | 3.2.3       |
| pillow            | 11.2.1      |
| session-info2     | 0.1.2       |
| Pygments          | 2.19.1      |
| pure_eval         | 0.2.3       |
| python-dateutil   | 2.9.0.post0 |
| statsmodels       | 0.14.4      |
| scipy             | 1.15.3      |
| texttable         | 1.7.0       |
| pandas            | 2.3.0       |
| kiwisolver        | 1.4.8       |
| polars            | 0.20.31     |
| cycler            | 0.12.1      |
| pyBigWig          | 0.3.24      |
| setuptools        | 80.9.0      |
| ipywidgets        | 8.1.7       |
| decorator         | 5.2.1       |
| pynndescent       | 0.5.13      |
| anndata           | 0.11.4      |
| stack-data        | 0.6.3       |
| seaborn           | 0.13.2      |
| pyarrow           | 14.0.2      |
| psutil            | 7.0.0       |
| ipython           | 9.3.0       |
| PyYAML            | 6.0.2       |
| parso             | 0.8.4       |
| tangermeme        | 0.4.4       |
| threadpoolctl     | 3.6.0       |
| numba             | 0.61.2      |
| patsy             | 1.0.1       |
| llvmlite          | 0.44.0      |
| asttokens         | 3.0.0       |
| six               | 1.17.0      |
| scanpy            | 1.11.2      |
| legacy-api-wrap   | 1.4.1       |
| jedi              | 0.19.2      |
| pycparser         | 2.22        |
| natsort           | 8.4.0       |
| packaging         | 25.0        |
| leidenalg         | 0.10.2      |
| fsspec            | 2025.5.1    |

| Component | Info                                                                           |
| --------- | ------------------------------------------------------------------------------ |
| Python    | 3.12.11 | packaged by conda-forge | (main, Jun  4 2025, 14:45:31) [GCC 13.3.0] |
| OS        | Linux-4.18.0-513.24.1.el8_9.x86_64-x86_64-with-glibc2.28                       |
| Updated   | 2025-06-16 09:16                                                               |

Versions

Metadata

Metadata

Assignees

Labels

bug: Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions