These are my args:
- run_cfg@_global_=llama2_7b_drope_qk_norm.yaml
- train_batch_size=512
- per_device_train_batch_size=4
Dataset: PrimeIntellect/fineweb-edu
The only thing I changed was reducing per_device_train_batch_size, since I am running on 40 x 8 GPUs.
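For context, here is a minimal sanity-check sketch (not from the repo) of the usual relation between these three values, assuming train_batch_size is the global batch, that the trainer derives gradient accumulation as global / (per_device x world_size), and that "40 x 8 GPUs" means 40 nodes with 8 GPUs each; the world_size and variable names here are hypothetical:

```python
# Hypothetical check, not part of the training code: verify the global batch
# divides evenly across devices and gradient-accumulation steps.
world_size = 40 * 8              # assumption: 40 nodes x 8 GPUs each
train_batch_size = 512           # global batch size from the args above
per_device_train_batch_size = 4  # reduced per-device batch from the args above

# Samples processed across all devices in one forward/backward pass
samples_per_pass = per_device_train_batch_size * world_size
grad_accum_steps, remainder = divmod(train_batch_size, samples_per_pass)

if remainder != 0:
    print(f"train_batch_size={train_batch_size} is not divisible by "
          f"per_device_train_batch_size * world_size ({samples_per_pass})")
else:
    print(f"gradient_accumulation_steps={grad_accum_steps}")
```

Under that node-count assumption the numbers do not divide evenly (4 x 320 = 1280 > 512), so the actual topology or the intended reading of "40 x 8" may differ from what I assumed here.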