# Safemotion Lib - config file: posec3d_cat_mlp2_fl.py
# Both backbone branches share this ResNet3d configuration; they differ only in
# width (base_channels) and in which heatmap they read (input_key).
# This helper name is temporary and deleted at the end of the block, so the
# config namespace still only gains `model`.
_resnet3d_common = dict(
    type='ResNet3d',
    # NOTE(review): presumably one channel per keypoint heatmap - confirm.
    in_channels=17,
    stage_blocks=(4, 6, 3),
    out_indices=(2, ),
    spatial_strides=(2, 2, 2),
    temporal_strides=(1, 1, 2),
    dilations=(1, 1, 1),
    conv1_kernel=(1, 7, 7),
    conv1_stride_s=1,
    conv1_stride_t=1,
    pool1_stride_s=1,
    pool1_stride_t=1,
    inflate=(0, 1, 1),
    inflate_style='3x1x1',
)
# Width of the fused feature: the aux heads below read `action_feat` at 1024
# channels and `pose_feat` at 512, and the fused heads read their sum.
_fused_channels = 1024 + 512

model = dict(
    type='ActionRecognitionRunner',
    backbone=dict(
        # Wider branch (base_channels=64).
        action_feat=dict(
            _resnet3d_common,
            base_channels=64,
            input_key='pose_heatmap_for_action',
        ),
        # Narrower branch (base_channels=32).
        pose_feat=dict(
            _resnet3d_common,
            base_channels=32,
            input_key='pose_heatmap_for_pose',
        ),
    ),
    # Merges both branch outputs. The "cat" in the config name suggests
    # concatenation - confirm in I3DFusion.
    fusion=dict(
        type='I3DFusion',
        in_channels=_fused_channels,
        # NOTE(review): 0 presumably means "no output projection" - confirm.
        out_channels=0,
        dropout_ratio=0.5,
        input_key=['action_feat', 'pose_feat'],
    ),
    # Main heads read the fused feature; the aux_* heads read one branch each.
    # The num_classes values mirror data_loader['category_info'].
    head=dict(
        action_upper=dict(
            type='LinearHead',
            in_channels=_fused_channels,
            num_classes=6,
            dropout_ratio=0.5,
            input_key='fusion',
        ),
        action_lower=dict(
            type='LinearHead',
            in_channels=_fused_channels,
            num_classes=12,
            dropout_ratio=0.5,
            input_key='fusion',
        ),
        pose=dict(
            type='MLPHead',
            in_channels=_fused_channels,
            num_classes=8,
            layer_channels=[2048],
            dropout_ratio=0.5,
            input_key='fusion',
        ),
        hand=dict(
            type='MLPHead',
            in_channels=_fused_channels,
            num_classes=4,
            layer_channels=[2048],
            dropout_ratio=0.5,
            input_key='fusion',
        ),
        foot=dict(
            type='MLPHead',
            in_channels=_fused_channels,
            num_classes=2,
            layer_channels=[2048],
            dropout_ratio=0.5,
            input_key='fusion',
        ),
        # Auxiliary heads on the action branch alone.
        aux_action_upper=dict(
            type='I3DHead',
            in_channels=1024,
            num_classes=6,
            dropout_ratio=0.5,
            input_key='action_feat',
        ),
        aux_action_lower=dict(
            type='I3DHead',
            in_channels=1024,
            num_classes=12,
            dropout_ratio=0.5,
            input_key='action_feat',
        ),
        # Auxiliary heads on the pose branch alone; they are told the
        # input is an unpooled 3d feature map.
        aux_pose=dict(
            type='MLPHead',
            in_channels=512,
            num_classes=8,
            layer_channels=[1024],
            dropout_ratio=0.5,
            input_key='pose_feat',
            input_type='3d',
        ),
        aux_hand=dict(
            type='MLPHead',
            in_channels=512,
            num_classes=4,
            layer_channels=[1024],
            dropout_ratio=0.5,
            input_key='pose_feat',
            input_type='3d',
        ),
        aux_foot=dict(
            type='MLPHead',
            in_channels=512,
            num_classes=2,
            layer_channels=[1024],
            dropout_ratio=0.5,
            input_key='pose_feat',
            input_type='3d',
        ),
    ),
    # Maps each predicted-label key to the head whose scores produce it
    # (the head key is also the key of that head's scores).
    predict_keys=dict(
        pred_action_upper='action_upper',
        pred_action_lower='action_lower',
        pred_pose='pose',
        pred_hand='hand',
        pred_foot='foot',
    ),
)

del _resnet3d_common, _fused_channels
138
# Tasks supervised by the five main (fused) heads; the aux heads repeat them.
train_tasks = ['action_upper', 'action_lower', 'pose', 'hand', 'foot']

# The lists below are positionally parallel (10 entries each): the first
# five entries are the main heads, the last five the matching aux_* heads.
# Deriving them from train_tasks keeps them aligned.
score_keys = train_tasks + ['aux_' + task for task in train_tasks]
pred_keys = ['pred_' + task for task in train_tasks]
# Aux heads are trained on the same labels as their main counterparts.
target_tasks = train_tasks * 2
gt_keys = ['gt_' + task for task in target_tasks]
144
145
# One focal-loss term applied over every head listed in score_keys.
# NOTE(review): the key spelling 'Muti...' is presumably looked up by name in
# the project's loss registry, so keep it unless the loss class is renamed.
loss = {
    'MutiTaskSigmoidFocalLoss': {
        'weight': 1.0,
        'task_key': 'category',
        # Presumably paired positionally: i-th score vs. i-th label/task.
        'pred_keys': score_keys,
        'gt_keys': gt_keys,
        'target_tasks': target_tasks,
        # Per-head weights in score_keys order: 1.0 for the five main heads,
        # 0.25 for the five aux heads.
        'weights': [1.0] * 5 + [0.25] * 5,
    },
}
157
# Metrics cover only the five main heads: pred_keys plus the first five
# (non-aux) entries of gt_keys / target_tasks.
metric_args = {
    'pred_key': pred_keys,
    'gt_key': gt_keys[:5],
    'target_tasks': target_tasks[:5],
    'task_key': 'category',
}
162
# Presumably the per-sample fields gathered for training: the two heatmap
# inputs read by the backbone branches, then one ground-truth label per task.
collect_keys = (
    ['pose_heatmap_for_action', 'pose_heatmap_for_pose']
    + ['gt_action_upper', 'gt_action_lower', 'gt_pose', 'gt_hand', 'gt_foot']
)
164
# Dataset loader settings.
data_loader = {
    'type': 'ActionDatasetLoader_mtml',
    'data_folder': '/media/safemotion/HDD5/pjm_test/action_train_dataset_2023/action_mtml_1st_split',
    # Class count per task; matches num_classes of the corresponding heads.
    'category_info': {
        'action_upper': 6,
        'action_lower': 12,
        'pose': 8,
        'hand': 4,
        'foot': 2,
    },
    # NOTE(review): presumably clip lengths (frames) for the action and pose
    # inputs respectively - confirm against ActionDatasetLoader_mtml.
    'clip_len_action': 20,
    'clip_len_pose': 6,
}
# Multiplier applied to every epoch-based schedule value below.
ep_mul = 20

train = {
    'num_workers': 8,
    'init_lr': 0.1,
    'batch_size': 32,
    'epochs': 100 * ep_mul,
    'optimizer': 'SGD',
    'optimizer_args': {'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.0001},
    # NOTE(review): T_max (30 * ep_mul) is far below `epochs` (100 * ep_mul).
    # If this maps to torch.optim.lr_scheduler.CosineAnnealingLR stepped once
    # per epoch, the LR rises again after T_max instead of staying at eta_min
    # - confirm this is intended. 'StepLR' was the commented-out alternative.
    'scheduler': 'CosineAnnealingLR',
    'scheduler_args': {'T_max': 30 * ep_mul, 'eta_min': 0},
    # NOTE(review): presumably only used for step-style LR adjustment; the
    # 100 * ep_mul milestone coincides with `epochs` and 130 * ep_mul lies
    # beyond it, so those two would not fire within this run.
    'adjust_lr_epoch': [milestone * ep_mul for milestone in (10, 50, 100, 130)],
    'adjust_lr_rate': [0.5, 0.1, 0.1, 0.1],
    'val_interval': 1,
    'update_loss_weight': False,
    'update_loss_weight_interval': 20,
    'base_weight': 0.5,
    # Previously used: '/media/safemotion/HDD5/pjm_test/action_train_test/9.pth'
    'pretrained': None,
    'save_root': '/media/safemotion/HDD5/pjm_test/action_train_result/action_cat_mlp2_fl',
}
200
# Evaluation settings.
test = {
    # NOTE(review): checkpoint presumably written by the training run above
    # (same save_root); '1824' looks like an epoch index - confirm.
    'model_path': '/media/safemotion/HDD5/pjm_test/action_train_result/action_cat_mlp2_fl/weights/1824.pth',
    'save_root': '',
}