18 def __init__(self, in_channels, num_classes, layer_channels, dropout_ratio, input_key, input_type='linear'):
24 layer_num = len(layer_channels)
28 for i
in range(layer_num):
29 out_ch = layer_channels[i]
32 layer_list.append(nn.Linear(in_ch, out_ch))
33 layer_list.append(nn.BatchNorm1d(out_ch))
34 layer_list.append(nn.ReLU(inplace=
True))
40 layer_list.append(nn.Linear(in_ch, num_classes))
42 self.
mlp = nn.Sequential(*layer_list)
44 if input_type ==
'3d':
45 self.
avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))