From 7e4e1c70f67dc4593017cc24c84c00da980765a3 Mon Sep 17 00:00:00 2001 From: mmg10 <65535131+mmg10@users.noreply.github.com> Date: Tue, 21 Mar 2023 15:10:07 +0500 Subject: [PATCH] fixes number of batches len(batch) = 4 since it has four key,value pairs in the Dataset --- FSDP_Workshop/main_training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/FSDP_Workshop/main_training.py b/FSDP_Workshop/main_training.py index da149ee..05a8241 100644 --- a/FSDP_Workshop/main_training.py +++ b/FSDP_Workshop/main_training.py @@ -224,7 +224,7 @@ def train( optimizer.step() ddp_loss[0] += loss.item() - ddp_loss[1] += len(batch) + ddp_loss[1] += 1 if rank == 0: inner_pbar.update(1) @@ -262,7 +262,7 @@ def validation(model, local_rank, rank, world_size, test_loader): labels=batch["target_ids"], ) ddp_loss[0] += output["loss"].item() # sum up batch loss - ddp_loss[1] += len(batch) + ddp_loss[1] += 1 if rank == 0: inner_pbar.update(1) @@ -633,4 +633,4 @@ def fsdp_main(args): # torch run start fsdp_main(args) - \ No newline at end of file +