def __init__(self, in_channels, num_classes, layer_channels, dropout_ratio, input_key, input_type='linear'):
    """Build an MLP head: hidden blocks followed by a final linear classifier.

    Each hidden block is ``[Dropout] -> Linear -> BatchNorm1d -> ReLU(inplace)``.
    After the last hidden block comes an optional Dropout and then a Linear
    layer that projects to ``num_classes``. All layers are stored in order
    as ``self.mlp`` (an ``nn.Sequential``).

    Args:
        in_channels: Input feature width of the first Linear layer.
        num_classes: Output width of the final Linear layer.
        layer_channels: Iterable of hidden-layer widths, one hidden block per
            entry. If it is empty, the head is just ``[Dropout] -> Linear``.
        dropout_ratio: ``p`` passed to every ``nn.Dropout``. A value of 0
            omits the Dropout layers entirely.
        input_key: Stored as ``self.input_key``. Presumably ``forward`` uses
            it to select its input (e.g. from a dict); TODO confirm.
        input_type: If ``'3d'``, ``self.avg_pool`` is an
            ``nn.AdaptiveAvgPool3d((1, 1, 1))``. For any other value,
            ``self.avg_pool`` is ``None``. Presumably ``forward`` pools
            volumetric features before the MLP when it is set; verify.
    """
    super().__init__()

    self.dropout_ratio = dropout_ratio
    self.input_key = input_key

    # Decide once whether Dropout layers are inserted at all.
    use_dropout = self.dropout_ratio != 0

    layers = []
    in_ch = in_channels
    # Iterate the widths directly instead of indexing, so any iterable works.
    for out_ch in layer_channels:
        if use_dropout:
            layers.append(nn.Dropout(p=self.dropout_ratio))
        layers += [
            nn.Linear(in_ch, out_ch),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
        ]
        in_ch = out_ch

    # Classifier: optional Dropout, then project to the number of classes.
    if use_dropout:
        layers.append(nn.Dropout(p=self.dropout_ratio))
    layers.append(nn.Linear(in_ch, num_classes))

    self.mlp = nn.Sequential(*layers)

    # Global average pooling to a 1x1x1 volume is only built for '3d' inputs.
    self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1)) if input_type == '3d' else None
48
49