From bb59b1c9b1204b4a948b226e026a3e9045bfbd19 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Fri, 9 Aug 2019 14:48:51 -0400 Subject: [PATCH 1/2] add running_mean and running_var params for BN For batch norm layers, count the running_var and running_mean parameters in batch_norm layers --- torchsummary/torchsummary.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index cbe18e3..6cf3cd7 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -32,6 +32,10 @@ def hook(module, input, output): summary[m_key]["trainable"] = module.weight.requires_grad if hasattr(module, "bias") and hasattr(module.bias, "size"): params += torch.prod(torch.LongTensor(list(module.bias.size()))) + if hasattr(module, "running_mean") and hasattr(module.running_mean, "size"): + params += torch.prod(torch.LongTensor(list(module.running_mean.size()))) + if hasattr(module, "running_var") and hasattr(module.running_var, "size"): + params += torch.prod(torch.LongTensor(list(module.running_var.size()))) summary[m_key]["nb_params"] = params if ( From 9a9a86684400bd8c2ce7d54d064ec87f2ae44cb7 Mon Sep 17 00:00:00 2001 From: Mostafa Elhoushi Date: Fri, 9 Aug 2019 16:30:25 -0400 Subject: [PATCH 2/2] Check if running_mean/var are tracked parameters According to https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html, we need to check the attribute track_running_stats of batchnorm layer to see if running_mean and running_vars are stored as parameters --- torchsummary/torchsummary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 6cf3cd7..34216b1 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -32,9 +32,9 @@ def hook(module, input, output): summary[m_key]["trainable"] = module.weight.requires_grad if hasattr(module, "bias") and hasattr(module.bias, "size"): params += torch.prod(torch.LongTensor(list(module.bias.size()))) - if hasattr(module, "running_mean") and hasattr(module.running_mean, "size"): + if hasattr(module, "running_mean") and hasattr(module.running_mean, "size") and hasattr(module, "track_running_stats") and module.track_running_stats: params += torch.prod(torch.LongTensor(list(module.running_mean.size()))) - if hasattr(module, "running_var") and hasattr(module.running_var, "size"): + if hasattr(module, "running_var") and hasattr(module.running_var, "size") and hasattr(module, "track_running_stats") and module.track_running_stats: params += torch.prod(torch.LongTensor(list(module.running_var.size()))) summary[m_key]["nb_params"] = params