# Safemotion Lib
# posec3d_cat_fl.py
# Model configuration for the 'ActionRecognitionRunner'.
#
# Layout:
#   backbone     - two 'ResNet3d' streams; each one reads its own input
#                  selected by `input_key`.
#   fusion       - 'I3DFusion' over the outputs of both streams.
#   head         - five 'LinearHead' classifiers fed by the fusion output, plus
#                  five auxiliary 'I3DHead' classifiers fed directly by one of
#                  the backbone streams.
#   predict_keys - maps each predicted-label key to the head that produces it.
#
# The two streams share every setting except `base_channels` and `input_key`.
# NOTE(review): in_channels=17 presumably is the number of pose-heatmap
# channels (one per keypoint) - confirm against the data pipeline.
model = dict(
    type='ActionRecognitionRunner',
    backbone=dict(
        # Stream whose output feeds the aux_action_* heads (1024 channels).
        action_feat=dict(
            type='ResNet3d',
            in_channels=17,
            base_channels=64,
            stage_blocks=(4, 6, 3),
            out_indices=(2,),
            spatial_strides=(2, 2, 2),
            temporal_strides=(1, 1, 2),
            dilations=(1, 1, 1),
            conv1_kernel=(1, 7, 7),
            conv1_stride_s=1,
            conv1_stride_t=1,
            pool1_stride_s=1,
            pool1_stride_t=1,
            inflate=(0, 1, 1),
            inflate_style='3x1x1',
            input_key='pose_heatmap_for_action',
        ),
        # Stream whose output feeds the aux_pose / aux_hand / aux_foot heads
        # (512 channels); half the base width of `action_feat`.
        pose_feat=dict(
            type='ResNet3d',
            in_channels=17,
            base_channels=32,
            stage_blocks=(4, 6, 3),
            out_indices=(2,),
            spatial_strides=(2, 2, 2),
            temporal_strides=(1, 1, 2),
            dilations=(1, 1, 1),
            conv1_kernel=(1, 7, 7),
            conv1_stride_s=1,
            conv1_stride_t=1,
            pool1_stride_s=1,
            pool1_stride_t=1,
            inflate=(0, 1, 1),
            inflate_style='3x1x1',
            input_key='pose_heatmap_for_pose',
        ),
    ),
    fusion=dict(
        type='I3DFusion',
        # 1024 + 512 matches the in_channels of the aux heads attached to
        # 'action_feat' and 'pose_feat' respectively.
        in_channels=1024 + 512,
        # NOTE(review): out_channels=0 presumably disables any projection so
        # the fused feature keeps 1024 + 512 channels (the heads below expect
        # that width) - confirm against the I3DFusion implementation.
        out_channels=0,
        dropout_ratio=0.5,
        input_key=['action_feat', 'pose_feat'],
    ),
    head=dict(
        # --- Main heads: classify from the fused feature. ---
        action_upper=dict(
            type='LinearHead',
            in_channels=1024 + 512,
            num_classes=6,
            dropout_ratio=0.5,
            input_key='fusion',
        ),
        action_lower=dict(
            type='LinearHead',
            in_channels=1024 + 512,
            num_classes=12,
            dropout_ratio=0.5,
            input_key='fusion',
        ),
        pose=dict(
            type='LinearHead',
            in_channels=1024 + 512,
            num_classes=8,
            dropout_ratio=0.5,
            input_key='fusion',
        ),
        hand=dict(
            type='LinearHead',
            in_channels=1024 + 512,
            num_classes=4,
            dropout_ratio=0.5,
            input_key='fusion',
        ),
        foot=dict(
            type='LinearHead',
            in_channels=1024 + 512,
            num_classes=2,
            dropout_ratio=0.5,
            input_key='fusion',
        ),
        # --- Auxiliary heads: classify from a single backbone stream. ---
        aux_action_upper=dict(
            type='I3DHead',
            in_channels=1024,
            num_classes=6,
            dropout_ratio=0.5,
            input_key='action_feat',
        ),
        aux_action_lower=dict(
            type='I3DHead',
            in_channels=1024,
            num_classes=12,
            dropout_ratio=0.5,
            input_key='action_feat',
        ),
        aux_pose=dict(
            type='I3DHead',
            in_channels=512,
            num_classes=8,
            dropout_ratio=0.5,
            input_key='pose_feat',
        ),
        aux_hand=dict(
            type='I3DHead',
            in_channels=512,
            num_classes=4,
            dropout_ratio=0.5,
            input_key='pose_feat',
        ),
        aux_foot=dict(
            type='I3DHead',
            in_channels=512,
            num_classes=2,
            dropout_ratio=0.5,
            input_key='pose_feat',
        ),
    ),
    predict_keys=dict(
        # key of the predicted label = key of the head (key of the score)
        pred_action_upper='action_upper',
        pred_action_lower='action_lower',
        pred_pose='pose',
        pred_hand='hand',
        pred_foot='foot',
    ),
)
129
# Key lists shared by the loss and the metric configuration.
#
# `score_keys`, `gt_keys` and `target_tasks` are index-aligned and each has ten
# entries: the five main heads first, then the five auxiliary heads. The
# auxiliary entries reuse the ground-truth keys and task names of the main
# heads, which is why the second half of `gt_keys` / `target_tasks` repeats
# the first half.
score_keys = [
    'action_upper', 'action_lower', 'pose', 'hand', 'foot',
    'aux_action_upper', 'aux_action_lower', 'aux_pose', 'aux_hand', 'aux_foot',
]
pred_keys = [
    'pred_action_upper', 'pred_action_lower', 'pred_pose', 'pred_hand',
    'pred_foot',
]
gt_keys = [
    'gt_action_upper', 'gt_action_lower', 'gt_pose', 'gt_hand', 'gt_foot',
    'gt_action_upper', 'gt_action_lower', 'gt_pose', 'gt_hand', 'gt_foot',
]
target_tasks = [
    'action_upper', 'action_lower', 'pose', 'hand', 'foot',
    'action_upper', 'action_lower', 'pose', 'hand', 'foot',
]
train_tasks = ['action_upper', 'action_lower', 'pose', 'hand', 'foot']


# Loss configuration, keyed by loss name.
# NOTE(review): the key spelling 'MutiTaskSigmoidFocalLoss' ("Muti") is kept
# as is; presumably it must match the name the project registers - confirm
# before renaming.
loss = dict(
    MutiTaskSigmoidFocalLoss=dict(
        weight=1.0,
        task_key='category',
        pred_keys=score_keys,
        gt_keys=gt_keys,
        target_tasks=target_tasks,
        # One weight per entry of `score_keys`: 1.0 for the five main heads,
        # 0.25 for the five auxiliary heads.
        weights=[1.0, 1.0, 1.0, 1.0, 1.0, 0.25, 0.25, 0.25, 0.25, 0.25],
        # data_num = dict(action = [100, 100, 100, 24, 100, 34, 3, 100, 100], pose = [98, 517, 11, 35])
    ),
)

# Metric arguments: only the five main tasks are evaluated, so the aligned
# lists are cut to their first five entries.
metric_args = dict(
    pred_key=pred_keys,
    gt_key=gt_keys[:5],
    target_tasks=target_tasks[:5],
    task_key='category',
)
153
# Sample keys to collect: the two backbone inputs (they match the `input_key`
# of the 'action_feat' and 'pose_feat' streams) followed by the five
# ground-truth label keys.
collect_keys = [
    'pose_heatmap_for_action',
    'pose_heatmap_for_pose',
    'gt_action_upper',
    'gt_action_lower',
    'gt_pose',
    'gt_hand',
    'gt_foot',
]
155
# Dataset loader configuration.
# `category_info` lists the number of classes per task; the values equal the
# `num_classes` of the corresponding heads in `model`.
# NOTE(review): clip_len_action / clip_len_pose presumably are the clip
# lengths (in frames) sampled for the action and pose streams - confirm
# against ActionDatasetLoader_mtml.
data_loader = dict(
    type='ActionDatasetLoader_mtml',
    data_folder='/media/safemotion/HDD5/pjm_test/action_train_dataset_2023/action_mtml_1st_split',
    category_info=dict(
        action_upper=6,
        action_lower=12,
        pose=8,
        hand=4,
        foot=2,
    ),
    clip_len_action=20,
    clip_len_pose=6,
)
# Common multiplier applied to every epoch-based setting below, so the whole
# schedule can be stretched or shrunk by editing one number.
ep_mul = 20

# Training configuration.
train = dict(
    num_workers=8,
    init_lr=0.1,
    batch_size=32,
    epochs=100 * ep_mul,
    optimizer='SGD',
    optimizer_args=dict(momentum=0.9, nesterov=True, weight_decay=0.0001),
    scheduler='CosineAnnealingLR',
    # scheduler = 'StepLR',
    # NOTE(review): T_max (30 * ep_mul) is shorter than epochs (100 * ep_mul).
    # If this maps to torch.optim.lr_scheduler.CosineAnnealingLR stepped once
    # per epoch, the learning rate will rise again after T_max instead of
    # decaying once - confirm this is intended.
    scheduler_args=dict(T_max=30 * ep_mul, eta_min=0),
    # `adjust_lr_epoch` and `adjust_lr_rate` are index-aligned (four entries
    # each).
    # NOTE(review): the last two milestones (100 * ep_mul, 130 * ep_mul) are
    # at or beyond `epochs`, so they presumably never take effect - confirm.
    adjust_lr_epoch=[10 * ep_mul, 50 * ep_mul, 100 * ep_mul, 130 * ep_mul],
    adjust_lr_rate=[0.5, 0.1, 0.1, 0.1],
    val_interval=1,

    # Loss-weight updating is switched off; the interval and base weight are
    # presumably only read when `update_loss_weight` is True.
    update_loss_weight=False,
    update_loss_weight_interval=20,
    base_weight=0.5,

    pretrained=None,  # '/media/safemotion/HDD5/pjm_test/action_train_test/9.pth',
    save_root='/media/safemotion/HDD5/pjm_test/action_train_result/action_cat_fl',
)
191
# Test-time configuration. Both values are placeholders here (no checkpoint
# path and an empty output root); presumably they are filled in before
# running evaluation.
test = dict(
    model_path=None,
    save_root='',
)