Description
Hi, first of all, thank you for sharing your work on CogACT; we are interested in finetuning it and deploying it on our own setup.
We collected data using UR5e robots and converted it to the custom dataset format as instructed. However, when running finetuning we are seeing very slow training despite the small dataset: for reference, we have about 25 episodes, and the dataset contains fewer than ~6000 examples.
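For context, here is roughly how we sanity-checked the converted dataset, assuming the standard RLDS episode/steps layout (the path follows our setup above; this is a sketch rather than the exact script we ran):

```python
# Verify the converted RLDS dataset: episode count and per-step schema.
import os
import tensorflow_datasets as tfds

data_dir = os.path.join(os.environ["SCRATCH"], "cogact_data/custom_finetuning/1.0.0")
builder = tfds.builder_from_directory(builder_dir=data_dir)

print(builder.info.splits)  # episodes per split
print("train episodes:", builder.info.splits["train"].num_examples)

# Peek at one episode to confirm the observation/action keys look right.
ds = builder.as_dataset(split="train")
for episode in ds.take(1):
    steps = list(episode["steps"])
    print("steps in first episode:", len(steps))
    print("step keys:", list(steps[0].keys()))
```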
The current command we are using to run finetuning on 4 A100s is:

torchrun --standalone --nnodes 1 --nproc-per-node 4 scripts/train.py \
  --pretrained_checkpoint $HOME/CogACT/base_models/CogACT-Base/checkpoints/CogACT-Base.pt \
  --vla.type prism-dinosiglip-224px+oxe+diffusion \
  --vla.data_mix custom_finetuning \
  --vla.expected_world_size 4 \
  --vla.global_batch_size 128 \
  --vla.per_device_batch_size 32 \
  --vla.learning_rate 2e-5 \
  --vla.epochs 50 \
  --vla.weight_decay 0.01 \
  --vla.lr_scheduler_type linear-warmup+cosine-decay \
  --vla.warmup_ratio 0.1 \
  --vla.max_grad_norm 1.0 \
  --vla.freeze_vision_backbone True \
  --vla.freeze_llm_backbone True \
  --vla.unfreeze_last_llm_layer True \
  --data_root_dir $SCRATCH/cogact_data \
  --run_root_dir $SCRATCH/cogact_finetune_logs \
  --image_aug True \
  --save_interval 50 \
  --repeated_diffusion_steps 8 \
  --future_action_window_size 15 \
  --action_model_type DiT-B \
  --is_resume False \
  --hf_token HF_TOKEN \
  --load_all_data_for_training True
Currently we are seeing a speed of ~2.3 seconds per iteration, which works out to about 5.5 hours per epoch. Is this the expected speed for finetuning?
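If it helps with diagnosis, a rough way we could separate input-pipeline time from compute time is something like the sketch below (illustrative only; `loader` stands in for whatever iterable of batches scripts/train.py builds internally, it is not an actual CogACT symbol):

```python
# Time how long it takes to pull batches out of the data pipeline alone,
# to see how much of the ~2.3 s/iteration is spent waiting on data.
import time

def time_batches(loader, n_batches=20):
    it = iter(loader)
    next(it)  # warm-up: the first batch includes pipeline startup cost
    t0 = time.perf_counter()
    for _ in range(n_batches):
        next(it)
    dt = time.perf_counter() - t0
    print(f"avg batch fetch time: {dt / n_batches:.3f} s")
```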
The script also produced the following output:
######################################################################################
Loading the following 1 datasets (incl. sampling weight):
custom_finetuning: =======================================================1.000000
######################################################################################
INFO | >> [*] Threads per Dataset: [1] dataset.py:537
INFO | >> [*] Reads per Dataset: [1] dataset.py:538
INFO | >> [*] Constructing datasets... dataset.py:541
INFO | >> Load dataset info from $SCRATCH/cogact_data/custom_finetuning/1.0.0 dataset_info.py:599
INFO | >> Constructing tf.data.Dataset custom_finetuning for split train, from $SCRATCH/cogact_data/custom_finetuning/1.0.0 logging_logger.py:49
2025-08-08 17:50:19.526656: I tensorflow/core/grappler/optimizers/data/replicate_on_split.cc:32] Running replicate on split optimization
I was wondering why the threads/reads are set to 1 here, and whether this impacts training speed?
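For reference, our (possibly incorrect) mental model is that these counts gate tf.data read/transform parallelism roughly as in the sketch below, in which case a value of 1 would serialize shard reads and map transforms; please correct us if the loader works differently. This is plain tf.data for illustration, not the actual CogACT/OXE loader code:

```python
# Illustration of the parallelism knobs we assume "Reads/Threads per Dataset" map to.
import tensorflow as tf

num_read_threads = 8       # parallel TFRecord shard readers ("Reads per Dataset"?)
num_transform_threads = 8  # parallel map() workers          ("Threads per Dataset"?)

files = tf.data.Dataset.list_files("custom_finetuning/1.0.0/*.tfrecord*")
ds = files.interleave(
    tf.data.TFRecordDataset,
    cycle_length=num_read_threads,
    num_parallel_calls=num_read_threads,
)
ds = ds.map(lambda x: x, num_parallel_calls=num_transform_threads)  # decode/transform would go here
ds = ds.prefetch(tf.data.AUTOTUNE)
```

If there is a recommended way to raise these values for finetuning, we would appreciate a pointer.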
Thank you :)