# Safemotion Lib — configuration file: stgcn_mc2.py
# (Source listing recovered from the library's generated documentation page.)
# ---------------------------------------------------------------------------
# Model: one ST-GCN backbone whose output (stored under 'stgcn') feeds two
# classification heads — 'action' (9 classes) and 'pose' (4 classes).
# ---------------------------------------------------------------------------
model = {
    # Runner class selected by name; presumably resolved via a registry — confirm.
    'type': 'ActionRecognitionRunner',
    'backbone': {
        'stgcn': {
            'type': 'STGCN',
            # NOTE(review): 3 input channels — presumably (x, y, score) per joint; confirm.
            'in_channels': 3,
            'graph_args': {'layout': 'coco', 'strategy': 'spatial'},
            'edge_importance_weighting': True,
            'input_key': 'keypoints',
            'dropout': 0.5,
        },
    },
    'head': {
        # Both heads read the same backbone output key, 'stgcn'.
        'action': {
            'type': 'GCNHead',
            'in_channels': 256,  # assumed to equal the backbone's output width — confirm
            'num_class': 9,
            'input_key': 'stgcn',
        },
        'pose': {
            'type': 'GCNHead',
            'in_channels': 256,
            'num_class': 4,
            'input_key': 'stgcn',
        },
    },
    # Prediction key -> head name (presumably the key under which each head's
    # output is reported).
    'predict_keys': {
        'pred_action': 'action',
        'pred_pose': 'pose',
    },
}
# Sample fields to collect: the keypoint input plus both label sets.
# For action-only training, the alternative was ['keypoints', 'label_action'].
collect_keys = [
    'keypoints',
    'label_action',
    'label_pose',
]
34
35# loss = dict(
36# CrossEntropyLoss = dict(
37# weight = 1.0,
38# weights = [1.0],
39# pred_keys = ['action'],
40# gt_keys = ['label_action'],
41# data_num = dict(action = [100, 100, 100, 24, 100, 34, 3, 100, 100])
42# )
43# )
44loss = dict(
45 CrossEntropyLoss = dict(
46 weight = 1.0,
47 weights = [1.0, 0.5],
48 pred_keys = ['action', 'pose'],
49 gt_keys = ['label_action', 'label_pose'],
50 data_num = dict(action = [100, 100, 100, 24, 100, 34, 3, 100, 100], pose = [98, 517, 11, 35])
51 )
52)
# ---------------------------------------------------------------------------
# Data loading.
# NOTE(review): test_data_folder is the parent directory of train_data_folder
# ('.../1st' vs '.../1st/train'); confirm this is intentional.
# ---------------------------------------------------------------------------
data_loader = {
    'type': 'ActionDatasetLoader',
    'train_data_folder': '/media/safemotion/HDD5/pjm_test/action_mc_train/1st/train',
    'test_data_folder': '/media/safemotion/HDD5/pjm_test/action_mc_train/1st',
    'use_normalize': True,
}
# ---------------------------------------------------------------------------
# Training schedule: SGD (Nesterov momentum) for 3000 epochs. At each epoch in
# adjust_lr_epoch the matching adjust_lr_rate entry applies (presumably a
# multiplicative LR decay — confirm against the trainer).
# ---------------------------------------------------------------------------
train = {
    'num_workers': 8,
    'init_lr': 0.1,
    'batch_size': 32,
    'epochs': 3000,
    'optimizer': 'SGD',
    'optimizer_args': {'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.0001},
    'adjust_lr_epoch': [500, 1000, 1500, 2000],
    'adjust_lr_rate': [0.1, 0.1, 0.1, 0.1],  # one entry per adjust_lr_epoch
    'val_interval': 1,

    # Set to a checkpoint path to warm-start, e.g.
    # '/media/safemotion/HDD5/pjm_test/action_train_test2/9.pth'
    'pretrained': None,
    'save_root': '/media/safemotion/HDD5/pjm_test/action_train_mc_test2',
}
# ---------------------------------------------------------------------------
# Evaluation: checkpoint to load. save_root is left empty here; what an empty
# value means is decided by the consuming script (not visible in this file).
# ---------------------------------------------------------------------------
test = {
    'model_path': '/media/safemotion/HDD5/pjm_test/action_train_mc_test2/weights/1368.pth',
    'save_root': '',
}