From d223a4a165097b9cbe275dbd4c7a7e3c0a029236 Mon Sep 17 00:00:00 2001 From: 201870262-Liu Yi <201870262@smail.nju.edu.cn> Date: Sun, 28 Jan 2024 16:34:36 +0800 Subject: [PATCH 1/2] feat: allow nested list/tuple input/output and None type --- torchsummary/torchsummary.py | 41 +++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..231a849 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -6,6 +6,31 @@ import numpy as np +def parse_tuple_list(output, batch_size): + if isinstance(output, (list, tuple)): + ret = [] + for o in output: + ret.append(parse_tuple_list(o, batch_size)) + elif output is None: + ret = [0, ] + else: + ret = list(output.size()) + ret[0] = batch_size + return ret + + +def prod_tuple_list(output_shape): + if isinstance(output_shape, (list, tuple)): + ret = 0 + for o in output_shape: + ret += prod_tuple_list(o) + elif output_shape is None: + ret = 0 + else: + ret = np.prod(output_shape) + return ret + + def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): result, params_info = summary_string( model, input_size, batch_size, device, dtypes) @@ -27,15 +52,8 @@ def hook(module, input, output): m_key = "%s-%i" % (class_name, module_idx + 1) summary[m_key] = OrderedDict() - summary[m_key]["input_shape"] = list(input[0].size()) - summary[m_key]["input_shape"][0] = batch_size - if isinstance(output, (list, tuple)): - summary[m_key]["output_shape"] = [ - [-1] + list(o.size())[1:] for o in output - ] - else: - summary[m_key]["output_shape"] = list(output.size()) - summary[m_key]["output_shape"][0] = batch_size + summary[m_key]["input_shape"] = parse_tuple_list(input, batch_size) + summary[m_key]["output_shape"] = parse_tuple_list(output, batch_size) params = 0 if hasattr(module, "weight") and hasattr(module.weight, "size"): @@ -91,15 +109,14 @@ def hook(module, input, 
output): ) total_params += summary[layer]["nb_params"] - total_output += np.prod(summary[layer]["output_shape"]) + total_output += prod_tuple_list(summary[layer]["output_shape"]) if "trainable" in summary[layer]: if summary[layer]["trainable"] == True: trainable_params += summary[layer]["nb_params"] summary_str += line_new + "\n" # assume 4 bytes/number (float on cuda). - total_input_size = abs(np.prod(sum(input_size, ())) - * batch_size * 4. / (1024 ** 2.)) + total_input_size = abs(prod_tuple_list(input_size) * batch_size * 4. / (1024 ** 2.)) total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients total_params_size = abs(total_params * 4. / (1024 ** 2.)) From f28eda3002a92ba67101d4920b90e9096d6e2e5f Mon Sep 17 00:00:00 2001 From: cijinsama Date: Sat, 2 Mar 2024 23:10:15 +0800 Subject: [PATCH 2/2] Add docstrings and change public functions to private ones --- torchsummary/torchsummary.py | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 231a849..2e88038 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -6,11 +6,22 @@ import numpy as np -def parse_tuple_list(output, batch_size): +def _parse_tuple_list(output, batch_size): + """Recursively parse all elements in tuple/list in the arg 'output' + + If the output is list/tuple, iterate over all elements recursively to obtain the tensor shape. + Else if the output is None, use 0 as the tensor shape. + Else (the output is supposed to be a tensor) return corresponding shape. 
+ + Args: + output (Union[list, tuple, Tensor]): output to be parsed + batch_size (int): specific batch_size + """ + if isinstance(output, (list, tuple)): ret = [] for o in output: - ret.append(parse_tuple_list(o, batch_size)) + ret.append(_parse_tuple_list(o, batch_size)) elif output is None: ret = [0, ] else: @@ -19,11 +30,20 @@ def parse_tuple_list(output, batch_size): return ret -def prod_tuple_list(output_shape): +def _prod_tuple_list(output_shape): + """Recursively calculate output size + + If the output is list/tuple, iterate over all elements recursively to obtain the tensor shape and accumulate them. + Else if the output is None, use 0 as the tensor size. + Else (the output is supposed to be a tensor) return its size. + + Args: + output_shape (Union[list, tuple, Tensor]): output_shape to be parsed + """ if isinstance(output_shape, (list, tuple)): ret = 0 for o in output_shape: - ret += prod_tuple_list(o) + ret += _prod_tuple_list(o) elif output_shape is None: ret = 0 else: @@ -52,8 +72,8 @@ def hook(module, input, output): m_key = "%s-%i" % (class_name, module_idx + 1) summary[m_key] = OrderedDict() - summary[m_key]["input_shape"] = parse_tuple_list(input, batch_size) - summary[m_key]["output_shape"] = parse_tuple_list(output, batch_size) + summary[m_key]["input_shape"] = _parse_tuple_list(input, batch_size) + summary[m_key]["output_shape"] = _parse_tuple_list(output, batch_size) params = 0 if hasattr(module, "weight") and hasattr(module.weight, "size"): @@ -109,14 +129,14 @@ def hook(module, input, output): ) total_params += summary[layer]["nb_params"] - total_output += prod_tuple_list(summary[layer]["output_shape"]) + total_output += _prod_tuple_list(summary[layer]["output_shape"]) if "trainable" in summary[layer]: if summary[layer]["trainable"] == True: trainable_params += summary[layer]["nb_params"] summary_str += line_new + "\n" # assume 4 bytes/number (float on cuda). - total_input_size = abs(prod_tuple_list(input_size) * batch_size * 4. 
/ (1024 ** 2.)) + total_input_size = abs(_prod_tuple_list(input_size) * batch_size * 4. / (1024 ** 2.)) total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients total_params_size = abs(total_params * 4. / (1024 ** 2.))