from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def summary(model, input_size, batch_size=-1, device="cuda"):
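    """Print a Keras-style layer-by-layer summary of `model`.

    Args:
        model: the ``nn.Module`` to inspect.
        input_size: input shape(s) without the batch dimension, as a tuple
            or a list of tuples for multi-input models.
        batch_size: batch size shown in the printed shapes; -1 prints a
            placeholder.
        device: "cuda" or "cpu", where the dummy forward pass is run.
    """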
    def register_hook(module):

        def hook(module, input, output):
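            # Name each entry "<ClassName>-<index>" in forward-call order,
            # e.g. "Conv2d-1".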
            class_name = str(module.__class__).split(".")[-1].split("'")[0]
            module_idx = len(summary)

            m_key = "%s-%i" % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].size())
            summary[m_key]["input_shape"][0] = batch_size
            if isinstance(output, (list, tuple)):
                summary[m_key]["output_shape"] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]["output_shape"] = list(output.size())
                summary[m_key]["output_shape"][0] = batch_size

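            # Count weight/bias parameters; .item() keeps the running
            # totals as plain Python ints rather than 0-dim tensors.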
            params = 0
            if hasattr(module, "weight") and hasattr(module.weight, "size"):
                params += torch.prod(torch.LongTensor(list(module.weight.size()))).item()
                summary[m_key]["trainable"] = module.weight.requires_grad
            if hasattr(module, "bias") and hasattr(module.bias, "size"):
                params += torch.prod(torch.LongTensor(list(module.bias.size()))).item()
            summary[m_key]["nb_params"] = params

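        # Hook every module except containers and the root model itself,
        # so each parameter is counted exactly once.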
        if (
            not isinstance(module, (nn.Sequential, nn.ModuleList))
            and module is not model
        ):
            hooks.append(module.register_forward_hook(hook))

    device = device.lower()
    assert device in [
        "cuda",
        "cpu",
    ], "Input device is not valid, please specify 'cuda' or 'cpu'"

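    # Fall back to CPU tensors when CUDA is requested but unavailable.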
    if device == "cuda" and torch.cuda.is_available():
        dtype = torch.cuda.FloatTensor
    else:
        dtype = torch.FloatTensor

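    # Accept a single shape tuple or a list of shapes for multi-input models.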
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # Batch size of 2 for the dummy forward pass: layers such as BatchNorm
    # cannot run on a single sample in training mode.
    x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]

    summary = OrderedDict()
    hooks = []

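    # Register the shape-recording hooks on the model's submodules ...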
    model.apply(register_hook)

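    # ... run one forward pass to populate `summary` ...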
    model(*x)

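    # ... and remove the hooks so the model is left unmodified.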
    for h in hooks:
        h.remove()

    print("----------------------------------------------------------------")
    line_new = "{:>20} {:>25} {:>15}".format(
        "Layer (type)", "Output Shape", "Param #"
    )
    print(line_new)
    print("================================================================")
    total_params = 0
    total_output = 0
    trainable_params = 0
    for layer in summary:
        line_new = "{:>20} {:>25} {:>15}".format(
            layer,
            str(summary[layer]["output_shape"]),
            "{0:,}".format(summary[layer]["nb_params"]),
        )
        total_params += summary[layer]["nb_params"]
        total_output += np.prod(summary[layer]["output_shape"])
        if summary[layer].get("trainable", False):
            trainable_params += summary[layer]["nb_params"]
        print(line_new)

    # Size estimates assume float32 (4 bytes per element). abs() absorbs the
    # -1 placeholder batch size, and the factor of 2 on activations accounts
    # for the gradients stored during the backward pass.
    total_input_size = abs(np.prod(sum(input_size, ())) * batch_size * 4. / (1024 ** 2.))
    total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))
    total_params_size = abs(total_params * 4. / (1024 ** 2.))
    total_size = total_params_size + total_output_size + total_input_size

    print("================================================================")
    print("Total params: {0:,}".format(total_params))
    print("Trainable params: {0:,}".format(trainable_params))
    print("Non-trainable params: {0:,}".format(total_params - trainable_params))
    print("----------------------------------------------------------------")
    print("Input size (MB): %0.2f" % total_input_size)
    print("Forward/backward pass size (MB): %0.2f" % total_output_size)
    print("Params size (MB): %0.2f" % total_params_size)
    print("Estimated Total Size (MB): %0.2f" % total_size)
    print("----------------------------------------------------------------")
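
# Minimal usage sketch (illustrative, not part of the module). Assuming this
# file is importable as `torchsummary`, a small CNN could be summarized as:
#
#     import torch.nn as nn
#     from torchsummary import summary
#
#     model = nn.Sequential(
#         nn.Conv2d(3, 16, kernel_size=3),  # (3, 32, 32) -> (16, 30, 30)
#         nn.ReLU(),
#         nn.Flatten(),                     # -> (14400,)
#         nn.Linear(16 * 30 * 30, 10),
#     )
#     summary(model, input_size=(3, 32, 32), device="cpu")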