hooks.py
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import datetime
import itertools
import logging
import os
import tempfile
import time
from collections import Counter

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from fastreid.evaluation.testing import flatten_results_dict
from fastreid.solver import optim
from fastreid.utils import comm
from fastreid.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from fastreid.utils.events import EventStorage, EventWriter
from fastreid.utils.file_io import PathManager
from fastreid.utils.precision_bn import update_bn_stats, get_bn_modules
from fastreid.utils.timer import Timer
from .train_loop import HookBase

__all__ = [
    "CallbackHook",
    "IterationTimer",
    "PeriodicWriter",
    "PeriodicCheckpointer",
    "LRScheduler",
    "AutogradProfiler",
    "EvalHook",
    "PreciseBN",
    "FreezeLayer",
]

"""
Implement some common hooks.
"""


class CallbackHook(HookBase):
    """
    Create a hook using callback functions provided by the user.
    """

    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
        """
        Each argument is a function that takes one argument: the trainer.
        """
        self._before_train = before_train
        self._before_step = before_step
        self._after_step = after_step
        self._after_train = after_train

    def before_train(self):
        if self._before_train:
            self._before_train(self.trainer)

    def after_train(self):
        if self._after_train:
            self._after_train(self.trainer)
        # The functions may be closures that hold a reference to the trainer.
        # Therefore, delete them to avoid a circular reference.
        del self._before_train, self._after_train
        del self._before_step, self._after_step

    def before_step(self):
        if self._before_step:
            self._before_step(self.trainer)

    def after_step(self):
        if self._after_step:
            self._after_step(self.trainer)
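
# Usage sketch (illustration only, not part of the original file): CallbackHook lets you
# hook into the training loop without subclassing HookBase. The `trainer` object and its
# `register_hooks` / `storage` attributes below are assumptions about the surrounding
# train loop, not defined in this file.
#
#   trainer.register_hooks([
#       CallbackHook(
#           before_train=lambda t: print("training starts at iter", t.iter),
#           after_step=lambda t: t.storage.put_scalar("custom/iter", t.iter),
#       )
#   ])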
79 """
80 Track the time spent for each iteration (each run_step call in the trainer).
81 Print a summary in the end of training.
82 This hook uses the time between the call to its :meth:`before_step`
83 and :meth:`after_step` methods.
84 Under the convention that :meth:`before_step` of all hooks should only
85 take negligible amount of time, the :class:`IterationTimer` hook should be
86 placed at the beginning of the list of hooks to obtain accurate timing.
87 """
88
89 def __init__(self, warmup_iter=3):
90 """
91 Args:
92 warmup_iter (int): the number of iterations at the beginning to exclude
93 from timing.
94 """
95 self._warmup_iter = warmup_iter
97
98 def before_train(self):
99 self._start_time = time.perf_counter()
101 self._total_timer.pause()
102
103 def after_train(self):
104 logger = logging.getLogger(__name__)
105 total_time = time.perf_counter() - self._start_time
106 total_time_minus_hooks = self._total_timer.seconds()
107 hook_time = total_time - total_time_minus_hooks
108
109 num_iter = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter
110
111 if num_iter > 0 and total_time_minus_hooks > 0:
112 # Speed is meaningful only after warmup
113 # NOTE this format is parsed by grep in some scripts
114 logger.info(
115 "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
116 num_iter,
117 str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
118 total_time_minus_hooks / num_iter,
119 )
120 )
121
122 logger.info(
123 "Total training time: {} ({} on hooks)".format(
124 str(datetime.timedelta(seconds=int(total_time))),
125 str(datetime.timedelta(seconds=int(hook_time))),
126 )
127 )
128
129 def before_step(self):
130 self._step_timer.reset()
131 self._total_timer.resume()
132
133 def after_step(self):
134 # +1 because we're in after_step
135 iter_done = self.trainer.iter - self.trainer.start_iter + 1
136 if iter_done >= self._warmup_iter:
137 sec = self._step_timer.seconds()
138 self.trainer.storage.put_scalars(time=sec)
139 else:
140 self._start_time = time.perf_counter()
141 self._total_timer.reset()
142
143 self._total_timer.pause()
144
145
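
# Ordering sketch (illustration only): because IterationTimer measures the time between its
# own before_step and after_step calls, it should be registered first, so that time spent in
# other hooks' after_step (e.g., writing or evaluation) is not counted as iteration time.
# `trainer.register_hooks` and the other hook arguments below are assumed wiring, not part
# of this file.
#
#   trainer.register_hooks([
#       IterationTimer(warmup_iter=3),          # first, per the docstring above
#       LRScheduler(optimizer, scheduler),
#       PeriodicWriter(writers, period=20),
#   ])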


class PeriodicWriter(HookBase):
    """
    Write events to EventStorage periodically.
    It is executed every ``period`` iterations and after the last iteration.
    """

    def __init__(self, writers, period=20):
        """
        Args:
            writers (list[EventWriter]): a list of EventWriter objects
            period (int):
        """
        self._writers = writers
        for w in writers:
            assert isinstance(w, EventWriter), w
        self._period = period

    def after_step(self):
        if (self.trainer.iter + 1) % self._period == 0 or (
                self.trainer.iter == self.trainer.max_iter - 1
        ):
            for writer in self._writers:
                writer.write()

    def after_train(self):
        for writer in self._writers:
            writer.close()
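
# Usage sketch (illustration only): the writers must be EventWriter instances. The writer
# classes named below are assumed to come from fastreid.utils.events in the surrounding
# project; they are not defined in this file.
#
#   from fastreid.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
#
#   writers = [
#       CommonMetricPrinter(max_iter),
#       JSONWriter(os.path.join(output_dir, "metrics.json")),
#       TensorboardXWriter(output_dir),
#   ]
#   trainer.register_hooks([PeriodicWriter(writers, period=20)])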


class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
    """
    Same as :class:`fastreid.utils.checkpoint.PeriodicCheckpointer`, but as a hook.
    Note that when used as a hook,
    it is unable to save additional data other than what's defined
    by the given `checkpointer`.
    It is executed every ``period`` iterations and after the last iteration.
    """

    def before_train(self):
        self.max_iter = self.trainer.max_iter

    def after_step(self):
        # No way to use **kwargs
        self.step(self.trainer.iter)
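
# Usage sketch (illustration only): the hook wraps an existing checkpointer object. The
# `Checkpointer` constructor call and argument names below are assumptions about the
# fastreid.utils.checkpoint API, shown only to indicate the intended wiring.
#
#   from fastreid.utils.checkpoint import Checkpointer
#
#   checkpointer = Checkpointer(model, save_dir=output_dir, optimizer=optimizer)
#   trainer.register_hooks([PeriodicCheckpointer(checkpointer, period=5000)])
#   # max_iter is filled in from the trainer in before_train().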


class LRScheduler(HookBase):
    """
    A hook which executes a torch builtin LR scheduler and summarizes the LR.
    It is executed after every iteration.
    """

    def __init__(self, optimizer, scheduler):
        """
        Args:
            optimizer (torch.optim.Optimizer):
            scheduler (torch.optim._LRScheduler):
        """
        self._optimizer = optimizer
        self._scheduler = scheduler

        # NOTE: some heuristics on which LR to summarize
        # summarize the param group with the most parameters
        largest_group = max(len(g["params"]) for g in optimizer.param_groups)

        if largest_group == 1:
            # If all groups have one parameter,
            # then find the most common initial LR, and use it for the summary
            lr_count = Counter([g["lr"] for g in optimizer.param_groups])
            lr = lr_count.most_common()[0][0]
            for i, g in enumerate(optimizer.param_groups):
                if g["lr"] == lr:
                    self._best_param_group_id = i
                    break
        else:
            for i, g in enumerate(optimizer.param_groups):
                if len(g["params"]) == largest_group:
                    self._best_param_group_id = i
                    break

    def after_step(self):
        lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
        self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
        self._scheduler.step()
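
# Usage sketch (illustration only): the hook records the LR of one representative param
# group and steps the scheduler once per iteration, so the scheduler should be built for
# per-iteration (not per-epoch) stepping. The milestones below are illustrative values.
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
#   scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30000, 45000])
#   trainer.register_hooks([LRScheduler(optimizer, scheduler)])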


class AutogradProfiler(HookBase):
    """
    A hook which runs `torch.autograd.profiler.profile`.

    Examples:
    .. code-block:: python

        hooks.AutogradProfiler(
            lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
        )

    The above example will run the profiler for iterations 10~20 and dump
    results to ``OUTPUT_DIR``. We do not profile the first few iterations
    because they are typically slower than the rest.
    The result files can be loaded in the ``chrome://tracing`` page in the Chrome browser.

    Note:
        When used together with NCCL on older versions of GPUs,
        the autograd profiler may cause deadlock because it unnecessarily allocates
        memory on every device it sees. The memory management calls, if
        interleaved with NCCL calls, lead to deadlock on GPUs that do not
        support ``cudaLaunchCooperativeKernelMultiDevice``.
    """

    def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
        """
        Args:
            enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
                and returns whether to enable the profiler.
                It will be called once every step, and can be used to select which steps to profile.
            output_dir (str): the output directory to dump tracing files.
            use_cuda (bool): same as in `torch.autograd.profiler.profile`.
        """
        self._enable_predicate = enable_predicate
        self._use_cuda = use_cuda
        self._output_dir = output_dir

    def before_step(self):
        if self._enable_predicate(self.trainer):
            self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
            self._profiler.__enter__()
        else:
            self._profiler = None

    def after_step(self):
        if self._profiler is None:
            return
        self._profiler.__exit__(None, None, None)
        out_file = os.path.join(
            self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
        )
        if "://" not in out_file:
            self._profiler.export_chrome_trace(out_file)
        else:
            # Support non-posix filesystems
            with tempfile.TemporaryDirectory(prefix="fastreid_profiler") as d:
                tmp_file = os.path.join(d, "tmp.json")
                self._profiler.export_chrome_trace(tmp_file)
                with open(tmp_file) as f:
                    content = f.read()
            with PathManager.open(out_file, "w") as f:
                f.write(content)


class EvalHook(HookBase):
    """
    Run an evaluation function periodically, and at the end of training.
    It is executed every ``eval_period`` iterations and after the last iteration.
    """

    def __init__(self, eval_period, eval_function):
        """
        Args:
            eval_period (int): the period to run `eval_function`.
            eval_function (callable): a function which takes no arguments, and
                returns a nested dict of evaluation metrics.

        Note:
            This hook must be enabled in all workers or in none of them.
            If you would like only certain workers to perform evaluation,
            give the other workers a no-op function (`eval_function=lambda: None`).
        """
        self._period = eval_period
        self._func = eval_function

    def _do_eval(self):
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                try:
                    v = float(v)
                except Exception:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    )
            self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

        # Remove the extra memory cache of the main process due to evaluation
        torch.cuda.empty_cache()

    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final or (self._period > 0 and next_iter % self._period == 0):
            self._do_eval()
        # Evaluation may take different time among workers.
        # A barrier makes them start the next iteration together.
        comm.synchronize()

    def after_train(self):
        # func is likely a closure that holds a reference to the trainer;
        # therefore we delete it to avoid a circular reference at the end.
        del self._func
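
# Usage sketch (illustration only): eval_function takes no arguments and returns a nested
# dict of floats; flatten_results_dict joins the nested keys before the values are written
# to the trainer's storage. The dataset/metric names below are illustrative, not produced
# by this file.
#
#   def eval_function():
#       return {"Market1501": {"Rank-1": 0.95, "mAP": 0.88}}
#
#   trainer.register_hooks([EvalHook(eval_period=5000, eval_function=eval_function)])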


class PreciseBN(HookBase):
    """
    The standard implementation of BatchNorm uses EMA in inference, which is
    sometimes suboptimal.
    This class computes the true average of statistics rather than the moving average,
    and puts the true averages into every BN layer in the given model.
    It is executed after the last iteration.
    """

    def __init__(self, model, data_loader, num_iter):
        """
        Args:
            model (nn.Module): a module whose BN layers in training mode will be
                updated by precise BN.
                Note that the user is responsible for ensuring that the BN layers to be
                updated are in training mode when this hook is triggered.
            data_loader (iterable): it will produce data to be run by `model(data)`.
            num_iter (int): number of iterations used to compute the precise
                statistics.
        """
        self._logger = logging.getLogger(__name__)
        if len(get_bn_modules(model)) == 0:
            self._logger.info(
                "PreciseBN is disabled because model does not contain BN layers in training mode."
            )
            self._disabled = True
            return

        self._model = model
        self._data_loader = data_loader
        self._num_iter = num_iter
        self._disabled = False

        self._data_iter = None

    def after_step(self):
        next_iter = self.trainer.iter + 1
        is_final = next_iter == self.trainer.max_iter
        if is_final:
            self.update_stats()

    def update_stats(self):
        """
        Update the model with precise statistics. Users can manually call this method.
        """
        if self._disabled:
            return

        if self._data_iter is None:
            self._data_iter = iter(self._data_loader)

        def data_loader():
            for num_iter in itertools.count(1):
                if num_iter % 100 == 0:
                    self._logger.info(
                        "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
                    )
                # This way we can reuse the same iterator
                yield next(self._data_iter)

        with EventStorage():  # capture events in a new storage to discard them
            self._logger.info(
                "Running precise-BN for {} iterations... ".format(self._num_iter)
                + "Note that this could produce different statistics every time."
            )
            update_bn_stats(self._model, data_loader(), self._num_iter)
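
# Usage sketch (illustration only): the hook only runs after the final iteration, so it is
# registered like any other hook; `train_loader` is an assumed iterable whose batches can
# be consumed by `model(data)`, and num_iter is an illustrative value.
#
#   trainer.register_hooks([PreciseBN(model, train_loader, num_iter=200)])
#   # or trigger the recomputation manually at any point:
#   # PreciseBN(model, train_loader, num_iter=200).update_stats()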


class FreezeLayer(HookBase):
    def __init__(self, model, optimizer, freeze_layers, freeze_iters):
        self._logger = logging.getLogger(__name__)

        if isinstance(model, DistributedDataParallel):
            model = model.module
        self.model = model
        self.optimizer = optimizer

        self.freeze_layers = freeze_layers
        self.freeze_iters = freeze_iters

        # Previous freeze status of the parameters
        param_freeze = {}
        for param_group in self.optimizer.param_groups:
            param_name = param_group['name']
            param_freeze[param_name] = param_group['freeze']
        self.param_freeze = param_freeze

        self.is_frozen = False

    def before_step(self):
        # Freeze the specified layers
        if self.trainer.iter <= self.freeze_iters and not self.is_frozen:
            self.freeze_specific_layer()

        # Recover the original layer status
        if self.trainer.iter > self.freeze_iters and self.is_frozen:
            self.open_all_layer()

    def freeze_specific_layer(self):
        for layer in self.freeze_layers:
            if not hasattr(self.model, layer):
                self._logger.info(f'{layer} is not an attribute of the model, will skip this layer')

        for param_group in self.optimizer.param_groups:
            param_name = param_group['name']
            if param_name.split('.')[0] in self.freeze_layers:
                param_group['freeze'] = True

        # Change BN in frozen layers to eval mode
        for name, module in self.model.named_children():
            if name in self.freeze_layers:
                module.eval()

        self.is_frozen = True

    def open_all_layer(self):
        self.model.train()
        for param_group in self.optimizer.param_groups:
            param_name = param_group['name']
            param_group['freeze'] = self.param_freeze[param_name]

        self.is_frozen = False
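
# Usage sketch (illustration only): FreezeLayer relies on every optimizer param group
# carrying custom 'name' and 'freeze' keys; the first dotted component of 'name' is matched
# against freeze_layers. The per-parameter group layout below shows the assumed convention,
# which is not enforced by this file.
#
#   param_groups = [
#       {"params": [p], "name": name, "freeze": False}
#       for name, p in model.named_parameters()
#   ]
#   optimizer = torch.optim.SGD(param_groups, lr=0.01)
#   trainer.register_hooks([
#       FreezeLayer(model, optimizer, freeze_layers=["backbone"], freeze_iters=1000)
#   ])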


class SWA(HookBase):
    def __init__(self, swa_start: int, swa_freq: int, swa_lr_factor: float, eta_min: float, lr_sched=False):
        self.swa_start = swa_start
        self.swa_freq = swa_freq
        self.swa_lr_factor = swa_lr_factor
        self.eta_min = eta_min
        self.lr_sched = lr_sched

    def before_step(self):
        is_swa = self.trainer.iter == self.swa_start
        if is_swa:
            # Wrap the optimizer with SWA
            self.trainer.optimizer = optim.SWA(self.trainer.optimizer, self.swa_freq, self.swa_lr_factor)
            self.trainer.optimizer.reset_lr_to_swa()

            if self.lr_sched:
                self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                    optimizer=self.trainer.optimizer,
                    T_0=self.swa_freq,
                    eta_min=self.eta_min,
                )

    def after_step(self):
        next_iter = self.trainer.iter + 1

        # Use the cyclic learning rate scheduler
        if next_iter > self.swa_start and self.lr_sched:
            self.scheduler.step()

        is_final = next_iter == self.trainer.max_iter
        if is_final:
            self.trainer.optimizer.swap_swa_param()
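
# Usage sketch (illustration only): at iteration `swa_start` the hook replaces the trainer's
# optimizer with the SWA wrapper from fastreid.solver.optim, which is assumed to average the
# weights every `swa_freq` iterations; the averaged weights are swapped in after the final
# iteration. The numbers below are illustrative only.
#
#   trainer.register_hooks([
#       SWA(swa_start=9000, swa_freq=50, swa_lr_factor=10, eta_min=1e-7, lr_sched=True)
#   ])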