Safemotion Lib
train_loop.py
# encoding: utf-8
"""
credit:
https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/train_loop.py
"""

import logging
import time
import weakref

import numpy as np
import torch
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel

import fastreid.utils.comm as comm
from fastreid.utils.events import EventStorage

__all__ = ["HookBase", "TrainerBase", "SimpleTrainer"]


class HookBase:
    """
    Base class for hooks that can be registered with :class:`TrainerBase`.

    Each hook can implement 4 methods. The way they are called is demonstrated
    in the following snippet:

    .. code-block:: python

        hook.before_train()
        for iter in range(start_iter, max_iter):
            hook.before_step()
            trainer.run_step()
            hook.after_step()
        hook.after_train()

    Notes:
        1. In the hook method, users can access `self.trainer` to access more
           properties about the context (e.g., current iteration).
        2. A hook that does something in :meth:`before_step` can often be
           implemented equivalently in :meth:`after_step`.
           If the hook takes non-trivial time, it is strongly recommended to
           implement the hook in :meth:`after_step` instead of :meth:`before_step`.
           The convention is that :meth:`before_step` should only take negligible time.
           Following this convention will allow hooks that do care about the difference
           between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
           function properly.

    Attributes:
        trainer: A weak reference to the trainer object. Set by the trainer when the hook is
            registered.
    """

    def before_train(self):
        """
        Called before the first iteration.
        """
        pass

    def after_train(self):
        """
        Called after the last iteration.
        """
        pass

    def before_step(self):
        """
        Called before each iteration.
        """
        pass

    def after_step(self):
        """
        Called after each iteration.
        """
        pass
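
# Illustrative only (not part of the original file): a minimal hook that
# follows the contract above. `IterationTimer` is a hypothetical name.
#
#   class IterationTimer(HookBase):
#       def before_step(self):
#           self._step_start = time.perf_counter()
#
#       def after_step(self):
#           # `self.trainer` is the weak reference set by register_hooks().
#           elapsed = time.perf_counter() - self._step_start
#           self.trainer.storage.put_scalar("iter_time", elapsed)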


class TrainerBase:
    """
    Base class for iterative trainer with hooks.

    The only assumption we make here is: the training runs in a loop.
    A subclass can implement what the loop is.
    We make no assumptions about the existence of dataloader, optimizer, model, etc.

    Attributes:
        iter (int): the current iteration.
        start_iter (int): The iteration to start with.
            By convention the minimum possible value is 0.
        max_iter (int): The iteration to end training.
        storage (EventStorage): An EventStorage that's opened during the course of training.
    """

    def __init__(self):
        self._hooks = []

    def register_hooks(self, hooks):
        """
        Register hooks to the trainer. The hooks are executed in the order
        they are registered.

        Args:
            hooks (list[Optional[HookBase]]): list of hooks
        """
        hooks = [h for h in hooks if h is not None]
        for h in hooks:
            assert isinstance(h, HookBase)
            # To avoid circular reference, hooks and trainer cannot own each other.
            # This normally does not matter, but will cause a memory leak if the
            # involved objects contain __del__:
            # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
            h.trainer = weakref.proxy(self)
        self._hooks.extend(hooks)
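
    # Illustrative usage (not part of the original file); `trainer` and
    # `IterationTimer` are hypothetical:
    #
    #   trainer.register_hooks([IterationTimer(), None])  # None entries are dropped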

    def train(self, start_iter: int, max_iter: int):
        """
        Args:
            start_iter, max_iter (int): See docs above
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            except Exception:
                logger.exception("Exception during training:")
                # re-raise so a failure is not silently swallowed;
                # after_train hooks still run via the finally block
                raise
            finally:
                self.after_train()
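
    # Illustrative usage (not part of the original file); assumes a concrete
    # subclass that implements run_step(), e.g. SimpleTrainer below:
    #
    #   trainer.register_hooks([...])
    #   trainer.train(start_iter=0, max_iter=120000)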

    def before_train(self):
        for h in self._hooks:
            h.before_train()

    def after_train(self):
        for h in self._hooks:
            h.after_train()

    def before_step(self):
        for h in self._hooks:
            h.before_step()

    def after_step(self):
        for h in self._hooks:
            h.after_step()
        # this guarantees that in each hook's after_step, storage.iter == trainer.iter
        self.storage.step()

    def run_step(self):
        raise NotImplementedError


class SimpleTrainer(TrainerBase):
    """
    A simple trainer for the most common type of task:
    single-cost single-optimizer single-data-source iterative optimization.

    It assumes that every step, you:

    1. Compute the loss with a batch of data from the data_loader.
    2. Compute the gradients with the above loss.
    3. Update the model with the optimizer.

    If you want to do anything fancier than this,
    either subclass TrainerBase and implement your own `run_step`,
    or write your own training loop.
    """

    def __init__(self, model, data_loader, optimizer, amp_enabled):
        """
        Args:
            model: a torch Module. Takes a batch of data from data_loader and
                returns a dict of heads.
            data_loader: an iterable. Contains data to be used to call model.
            optimizer: a torch optimizer.
            amp_enabled (bool): whether to train with automatic mixed precision
                via `torch.cuda.amp`.
        """
        super().__init__()

        """
        We set the model to training mode in the trainer.
        However, it's valid to train a model that's in eval mode.
        If you want your model (or a submodule of it) to behave
        like evaluation during training, you can override its train() method.
        """
        model.train()

        self.model = model
        self.data_loader = data_loader
        self._data_loader_iter = iter(data_loader)
        self.optimizer = optimizer
        self.amp_enabled = amp_enabled

        if amp_enabled:
            # Creates a GradScaler once at the beginning of training.
            self.scaler = amp.GradScaler()
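
    # Illustrative construction (not part of the original file); `build_model`,
    # `build_reid_train_loader` and `build_optimizer` are assumed fastreid
    # builders, and `cfg` a config node:
    #
    #   trainer = SimpleTrainer(
    #       model=build_model(cfg),
    #       data_loader=build_reid_train_loader(cfg),
    #       optimizer=build_optimizer(cfg, model),
    #       amp_enabled=True,
    #   )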

    def run_step(self):
        """
        Implement the standard training logic described above.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        """
        If you want to do something with the data, you can wrap the dataloader.
        """
        data = next(self._data_loader_iter)
        data_time = time.perf_counter() - start

        """
        If you want to do something with the heads, you can wrap the model.
        """
        with amp.autocast(enabled=self.amp_enabled):
            outs = self.model(data)

            # Compute loss
            if isinstance(self.model, DistributedDataParallel):
                loss_dict = self.model.module.losses(outs)
            else:
                loss_dict = self.model.losses(outs)

            losses = sum(loss_dict.values())
        # Log metrics on a separate CUDA stream, presumably so the CPU-GPU
        # syncs in _write_metrics do not stall work on the default stream.
        with torch.cuda.stream(torch.cuda.Stream()):
            metrics_dict = loss_dict
            metrics_dict["data_time"] = data_time
            self._write_metrics(metrics_dict)
            self._detect_anomaly(losses, loss_dict)

        """
        If you need to accumulate gradients or something similar, you can
        wrap the optimizer with your custom `zero_grad()` method.
        """
        self.optimizer.zero_grad()

        if self.amp_enabled:
            # Scale the loss before backward; the scaler then unscales the
            # gradients, skips the step on inf/NaN, and updates its scale.
            self.scaler.scale(losses).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            losses.backward()
            """
            If you need gradient clipping/scaling or other processing, you can
            wrap the optimizer with your custom `step()` method.
            """
            self.optimizer.step()
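
            # A sketch (not in the original file) of the clipping the docstring
            # above alludes to, using torch.nn.utils.clip_grad_norm_ between
            # backward() and step(); max_norm=5.0 is an arbitrary example value:
            #
            #   losses.backward()
            #   torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)
            #   self.optimizer.step()
            #
            # With AMP enabled, call self.scaler.unscale_(self.optimizer) before
            # clipping so norms are computed on unscaled gradients.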

    def _detect_anomaly(self, losses, loss_dict):
        if not torch.isfinite(losses).all():
            raise FloatingPointError(
                "Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
                    self.iter, loss_dict
                )
            )
    def _write_metrics(self, metrics_dict: dict):
        """
        Args:
            metrics_dict (dict): dict of scalar metrics
        """
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        # gather metrics among all workers for logging
        # This assumes we do DDP-style training, which is currently the only
        # supported method in fastreid.
        all_metrics_dict = comm.gather(metrics_dict)

        if comm.is_main_process():
            if "data_time" in all_metrics_dict[0]:
                # data_time among workers can have high variance. The actual latency
                # caused by data_time is the maximum among workers.
                data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
                self.storage.put_scalar("data_time", data_time)

            # average the remaining metrics
            metrics_dict = {
                k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
            }
            total_losses_reduced = sum(loss for loss in metrics_dict.values())

            self.storage.put_scalar("total_loss", total_losses_reduced)
            if len(metrics_dict) > 1:
                self.storage.put_scalars(**metrics_dict)
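
    # Worked example (not from the original file): with 2 workers reporting
    # {"loss_cls": 1.0, "data_time": 0.1} and {"loss_cls": 3.0, "data_time": 0.4},
    # comm.gather collects both dicts on the main process, which logs
    # data_time = 0.4 (the max across workers) and loss_cls = 2.0 (the mean),
    # so total_loss = 2.0.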