Safemotion Lib
build.py
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from . import lr_scheduler
from . import optim
from torch import nn


def build_optimizer(cfg, model):
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue

        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        if "heads" in key:  # the projection head may use a larger LR
            lr *= cfg.SOLVER.HEADS_LR_FACTOR
        else:  # backbone convs and BNs may use different factors
            if 'bn' in key or 'downsample.1' in key:
                lr *= cfg.SOLVER.BACKBONE_BN_LR_FACTOR
            elif 'backbone.1' in key and isinstance(model.backbone[1], nn.BatchNorm2d):
                lr *= cfg.SOLVER.BACKBONE_BN_LR_FACTOR
        if "bias" in key:
            lr *= cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        params += [{"name": key, "params": [value], "lr": lr,
                    "weight_decay": weight_decay, "freeze": False}]

    solver_opt = cfg.SOLVER.OPT
    if solver_opt == "SGD":
        opt_fns = getattr(optim, solver_opt)(params, momentum=cfg.SOLVER.MOMENTUM)
    else:
        opt_fns = getattr(optim, solver_opt)(params)
    return opt_fns
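
# A minimal usage sketch for build_optimizer (the YACS-style cfg below is
# hypothetical; the keys mirror the cfg.SOLVER.* fields read above, and "SGD"
# is assumed to name a class in the local `optim` module):
#
#   from yacs.config import CfgNode as CN
#   cfg = CN(); cfg.SOLVER = CN()
#   cfg.SOLVER.OPT = "SGD"; cfg.SOLVER.MOMENTUM = 0.9
#   cfg.SOLVER.BASE_LR = 0.01; cfg.SOLVER.WEIGHT_DECAY = 5e-4
#   cfg.SOLVER.HEADS_LR_FACTOR = 10.0; cfg.SOLVER.BACKBONE_BN_LR_FACTOR = 1.0
#   cfg.SOLVER.BIAS_LR_FACTOR = 2.0; cfg.SOLVER.WEIGHT_DECAY_BIAS = 0.0
#   optimizer = build_optimizer(cfg, model)   # one parameter group per tensor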
def build_lr_scheduler(cfg, optimizer):
    scheduler_args = {
        "optimizer": optimizer,

        # warmup options
        "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
        "warmup_iters": cfg.SOLVER.WARMUP_ITERS,
        "warmup_method": cfg.SOLVER.WARMUP_METHOD,

        # multi-step lr scheduler options
        "milestones": cfg.SOLVER.STEPS,
        "gamma": cfg.SOLVER.GAMMA,

        # cosine annealing lr scheduler options
        "max_iters": cfg.SOLVER.MAX_ITER,
        "delay_iters": cfg.SOLVER.DELAY_ITERS,
        "eta_min_lr": cfg.SOLVER.ETA_MIN_LR,
    }
    return getattr(lr_scheduler, cfg.SOLVER.SCHED)(**scheduler_args)
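
# Continuing the sketch above: cfg.SOLVER.SCHED must name a class defined in
# the local `lr_scheduler` module. Every entry of scheduler_args is passed
# through, so the selected class is assumed to accept all of these keyword
# arguments (e.g. by swallowing unused ones via **kwargs). The class name and
# values below are illustrative assumptions, not part of this file:
#
#   cfg.SOLVER.SCHED = "WarmupMultiStepLR"
#   cfg.SOLVER.WARMUP_FACTOR = 0.01; cfg.SOLVER.WARMUP_ITERS = 1000
#   cfg.SOLVER.WARMUP_METHOD = "linear"
#   cfg.SOLVER.STEPS = [40000, 90000]; cfg.SOLVER.GAMMA = 0.1
#   cfg.SOLVER.MAX_ITER = 120000; cfg.SOLVER.DELAY_ITERS = 0
#   cfg.SOLVER.ETA_MIN_LR = 3e-7
#   scheduler = build_lr_scheduler(cfg, optimizer)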