def _shape_with_batch(output, batch_size):
    """Return the shape of a hook input/output as a (possibly nested) list.

    For a tensor, the leading dimension of its shape is replaced with
    ``batch_size``. Lists/tuples (multi-output modules, or nested outputs
    such as ``(out, (h, c))``) are walked recursively and yield a nested list
    of shapes. Items that are neither sequences nor have a ``size()`` method
    (e.g. ``None`` inside a tuple) are skipped.
    """
    if isinstance(output, (list, tuple)):
        return [
            _shape_with_batch(item, batch_size)
            for item in output
            if isinstance(item, (list, tuple)) or hasattr(item, "size")
        ]
    shape = list(output.size())
    if shape:  # a 0-dim tensor has no batch dimension to overwrite
        shape[0] = batch_size
    return shape


def _num_elements(shape):
    """Return the element count described by a (possibly nested) shape list.

    Nested shape lists (from multi-output modules) are summed. The absolute
    value is taken per shape, so a ``-1`` batch placeholder never flips signs
    when shapes from different layers are added together.
    """
    if any(isinstance(dim, list) for dim in shape):
        return sum(_num_elements(sub) for sub in shape)
    return abs(int(np.prod(shape)))


def summary(model, input_size, batch_size=-1, device="cuda"):
    """Print a layer-by-layer summary of a PyTorch model.

    A forward hook is registered on every submodule except ``nn.Sequential``,
    ``nn.ModuleList`` and ``model`` itself. A single forward pass with random
    inputs then records each layer's output shape and parameter count. The
    hooks are always removed afterwards, even if the forward pass raises.

    Args:
        model: the module to summarise.
        input_size: shape of one input sample without the batch dimension, as
            a tuple, or a list of such tuples for a model taking several
            positional inputs.
        batch_size: value shown as the leading dimension in the shape column
            and used for the input-size estimate. With the default ``-1`` the
            memory estimates are per sample.
        device: ``"cuda"`` or ``"cpu"`` (case-insensitive). It selects the
            tensor type of the random inputs. CUDA tensors are only used when
            ``device == "cuda"`` and ``torch.cuda.is_available()``.

    Raises:
        AssertionError: if ``device`` is not ``"cuda"`` or ``"cpu"``.
    """

    def register_hook(module):

        def hook(module, inputs, output):
            # e.g. torch.nn.modules.linear.Linear -> "Linear"
            class_name = type(module).__name__
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            if inputs and hasattr(inputs[0], "size"):
                summary[m_key]["input_shape"] = _shape_with_batch(inputs[0], batch_size)
            summary[m_key]["output_shape"] = _shape_with_batch(output, batch_size)

            # Only a directly owned `weight`/`bias` is counted (not the whole
            # parameter tree); "trainable" follows the weight's requires_grad.
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += _num_elements(list(module.weight.size()))
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += _num_elements(list(module.bias.size()))
            summary[m_key]["nb_params"] = params

        # Containers and the root model are skipped so each leaf is listed once.
        if (
            not isinstance(module, nn.Sequential)
            and not isinstance(module, nn.ModuleList)
            and not (module == model)
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

    # A single tuple means a single input; normalise to a list of shapes.
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # A batch of 2 random samples per input (presumably so BatchNorm layers
    # in training mode accept the batch).
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    # Per-layer records, filled by the hooks in call order.
    # (Note: this local shadows the function name inside this scope.)
    summary = OrderedDict()
    hooks = []

    model.apply(register_hook)

    # NOTE(review): the model is not switched to eval mode, so layers such as
    # BatchNorm in training mode may update running statistics from this
    # random batch.
    try:
        # Only shapes are needed, so skip building an autograd graph.
        with torch.no_grad():
            model(*x)
    finally:
        # Never leave hooks attached to the caller's model, even on failure.
        for h in hooks:
            h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer, info in summary.items():
        line_new = "{:>20} {:>25} {:>15}".format(
            layer,
            str(info["output_shape"]),
            "{0:,}".format(info["nb_params"]),
        )
        total_params += info["nb_params"]
        total_output += _num_elements(info["output_shape"])
        if info.get("trainable"):
            trainable_params += info["nb_params"]
        print(line_new)

    # Memory estimates assume 4 bytes per element (float32). Each input's
    # element count is computed separately and summed.
    total_input_size = (
        sum(_num_elements(list(in_size)) for in_size in input_size)
        * abs(batch_size) * 4. / (1024 ** 2.)
    )
    # Activations are counted twice (forward values plus their gradients).
    total_output_size = 2. * total_output * 4. / (1024 ** 2.)
    total_params_size = total_params * 4. / (1024 ** 2.)
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")