fastreid.engine.train_loop.SimpleTrainer Class Reference
Inheritance: fastreid.engine.train_loop.SimpleTrainer inherits fastreid.engine.train_loop.TrainerBase and is inherited by fastreid.engine.defaults.DefaultTrainer.

Public Member Functions

 __init__ (self, model, data_loader, optimizer, amp_enabled)
 
 run_step (self)
 
- Public Member Functions inherited from fastreid.engine.train_loop.TrainerBase
 register_hooks (self, hooks)
 
 train (self, int start_iter, int max_iter)
 
 before_train (self)
 
 after_train (self)
 
 before_step (self)
 
 after_step (self)
 

Public Attributes

 model
 
 data_loader
 
 optimizer
 
 amp_enabled
 
 scaler
 
 iter
 
- Public Attributes inherited from fastreid.engine.train_loop.TrainerBase
 iter
 
 max_iter
 

Protected Member Functions

 _detect_anomaly (self, losses, loss_dict)
 
 _write_metrics (self, dict metrics_dict)
 

Protected Attributes

 _data_loader_iter
 
- Protected Attributes inherited from fastreid.engine.train_loop.TrainerBase
 _hooks
 

Detailed Description

A simple trainer for the most common type of task:
single-cost, single-optimizer, single-data-source iterative optimization.
It assumes that in every step, you:
1. Compute the loss with data from the data_loader.
2. Compute the gradients with the above loss.
3. Update the model with the optimizer.
If you want to do anything fancier than this,
either subclass TrainerBase and implement your own `run_step`
(see the sketch below), or write your own training loop.

Definition at line 154 of file train_loop.py.
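As a rough illustration of the subclassing route mentioned above, here is a hedged sketch of a custom trainer with its own `run_step`. Only `TrainerBase` and its `train`/hook interface come from this page; the two-optimizer setup, the class name, and the model's `losses()` call are assumptions made for illustration.

from fastreid.engine.train_loop import TrainerBase


class TwoOptimizerTrainer(TrainerBase):
    """Hypothetical trainer that steps two optimizers per iteration (not part of fastreid)."""

    def __init__(self, model, data_loader, opt_backbone, opt_head):
        super().__init__()
        model.train()
        self.model = model
        self._data_loader_iter = iter(data_loader)
        self.opt_backbone = opt_backbone
        self.opt_head = opt_head

    def run_step(self):
        # Same three phases as SimpleTrainer: fetch data, compute losses, update.
        data = next(self._data_loader_iter)
        outs = self.model(data)
        loss_dict = self.model.losses(outs)  # assumes the losses() API shown in run_step below
        losses = sum(loss_dict.values())

        self.opt_backbone.zero_grad()
        self.opt_head.zero_grad()
        losses.backward()
        self.opt_backbone.step()
        self.opt_head.step()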

Constructor & Destructor Documentation

◆ __init__()

fastreid.engine.train_loop.SimpleTrainer.__init__(self, model, data_loader, optimizer, amp_enabled)
Args:
    model: a torch Module. Takes data from data_loader and returns a
        dict of heads.
    data_loader: an iterable. Contains data to be used to call model.
    optimizer: a torch optimizer.
    amp_enabled: whether to train with automatic mixed precision; when True,
        a GradScaler is created for the backward/step path.

Reimplemented from fastreid.engine.train_loop.TrainerBase.

Reimplemented in fastreid.engine.defaults.DefaultTrainer.

Definition at line 167 of file train_loop.py.

167 def __init__(self, model, data_loader, optimizer, amp_enabled):
168     """
169     Args:
170         model: a torch Module. Takes data from data_loader and returns a
171             dict of heads.
172         data_loader: an iterable. Contains data to be used to call model.
173         optimizer: a torch optimizer.
174     """
175     super().__init__()
176 
177     """
178     We set the model to training mode in the trainer.
179     However, it's valid to train a model that's in eval mode.
180     If you want your model (or a submodule of it) to behave
181     like evaluation during training, you can overwrite its train() method.
182     """
183     model.train()
184 
185     self.model = model
186     self.data_loader = data_loader
187     self._data_loader_iter = iter(data_loader)
188     self.optimizer = optimizer
189     self.amp_enabled = amp_enabled
190 
191     if amp_enabled:
192         # Creates a GradScaler once at the beginning of training.
193         self.scaler = amp.GradScaler()
194 
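For orientation, a hedged usage sketch of the constructor above: the toy model, data iterator, and hyperparameters are invented for illustration, and only the SimpleTrainer(...) call and the inherited train(start_iter, max_iter) signature come from this page. It also assumes a CUDA-capable environment, since run_step opens a CUDA stream.

import torch
from torch import nn
from fastreid.engine.train_loop import SimpleTrainer


class ToyModel(nn.Module):
    """Placeholder model: forward() returns a dict of outputs and losses() turns it
    into a dict of scalar losses, matching the contract run_step relies on."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, batch):
        feats, targets = batch
        return {"logits": self.fc(feats), "targets": targets}

    def losses(self, outs):
        return {"loss_cls": nn.functional.cross_entropy(outs["logits"], outs["targets"])}


def toy_loader():
    # Endless iterable of (features, labels) batches; stands in for a real data loader.
    while True:
        yield torch.randn(4, 8), torch.randint(0, 2, (4,))


model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # placeholder hyperparameters
trainer = SimpleTrainer(model, toy_loader(), optimizer, amp_enabled=False)
trainer.train(0, 20)  # start_iter, max_iter (inherited from TrainerBase)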

Member Function Documentation

◆ _detect_anomaly()

fastreid.engine.train_loop.SimpleTrainer._detect_anomaly(self, losses, loss_dict)
protected

Definition at line 246 of file train_loop.py.

246 def _detect_anomaly(self, losses, loss_dict):
247     if not torch.isfinite(losses).all():
248         raise FloatingPointError(
249             "Loss became infinite or NaN at iteration={}!\nloss_dict = {}".format(
250                 self.iter, loss_dict
251             )
252         )
253 
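To see when this guard trips, a tiny standalone illustration (values invented): a NaN anywhere in the summed loss makes torch.isfinite(...).all() false, so _detect_anomaly would raise FloatingPointError carrying the offending loss_dict.

import torch

loss_dict = {"loss_cls": torch.tensor(float("nan")), "loss_triplet": torch.tensor(0.3)}
losses = sum(loss_dict.values())     # tensor(nan)
print(torch.isfinite(losses).all())  # tensor(False): the check above would raise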

◆ _write_metrics()

fastreid.engine.train_loop.SimpleTrainer._write_metrics(self, dict metrics_dict)
protected
Args:
    metrics_dict (dict): dict of scalar metrics

Definition at line 254 of file train_loop.py.

254 def _write_metrics(self, metrics_dict: dict):
255     """
256     Args:
257         metrics_dict (dict): dict of scalar metrics
258     """
259     metrics_dict = {
260         k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
261         for k, v in metrics_dict.items()
262     }
263     # gather metrics among all workers for logging
264     # This assumes we do DDP-style training, which is currently the only
265     # supported method in fastreid.
266     all_metrics_dict = comm.gather(metrics_dict)
267 
268     if comm.is_main_process():
269         if "data_time" in all_metrics_dict[0]:
270             # data_time among workers can have high variance. The actual latency
271             # caused by data_time is the maximum among workers.
272             data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
273             self.storage.put_scalar("data_time", data_time)
274 
275         # average the remaining metrics
276         metrics_dict = {
277             k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
278         }
279         total_losses_reduced = sum(loss for loss in metrics_dict.values())
280 
281         self.storage.put_scalar("total_loss", total_losses_reduced)
282         if len(metrics_dict) > 1:
283             self.storage.put_scalars(**metrics_dict)
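To make the reduction above concrete, here is a toy sketch of what the main process sees: comm.gather returns one metrics dict per worker (values invented here), data_time is reduced with max because a step is only as fast as its slowest worker, and the remaining metrics are averaged and summed into total_loss.

import numpy as np

# Invented per-worker metrics, shaped like the list comm.gather() hands to the main process.
all_metrics_dict = [
    {"loss_cls": 0.52, "loss_triplet": 0.31, "data_time": 0.02},
    {"loss_cls": 0.48, "loss_triplet": 0.29, "data_time": 0.19},  # one slow data worker
]

data_time = np.max([x.pop("data_time") for x in all_metrics_dict])  # 0.19
metrics_dict = {k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0]}
total_loss = sum(metrics_dict.values())  # 0.5 + 0.3 = 0.8
print(data_time, metrics_dict, total_loss)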

◆ run_step()

fastreid.engine.train_loop.SimpleTrainer.run_step(self)
Implement the standard training logic described above.

Reimplemented from fastreid.engine.train_loop.TrainerBase.

Definition at line 195 of file train_loop.py.

195 def run_step(self):
196     """
197     Implement the standard training logic described above.
198     """
199     assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
200     start = time.perf_counter()
201     """
202     If you want to do something with the data, you can wrap the dataloader.
203     """
204     data = next(self._data_loader_iter)
205     data_time = time.perf_counter() - start
206 
207     """
208     If you want to do something with the heads, you can wrap the model.
209     """
210 
211     with amp.autocast(enabled=self.amp_enabled):
212         outs = self.model(data)
213 
214         # Compute loss
215         if isinstance(self.model, DistributedDataParallel):
216             loss_dict = self.model.module.losses(outs)
217         else:
218             loss_dict = self.model.losses(outs)
219 
220         losses = sum(loss_dict.values())
221 
222     with torch.cuda.stream(torch.cuda.Stream()):
223         metrics_dict = loss_dict
224         metrics_dict["data_time"] = data_time
225         self._write_metrics(metrics_dict)
226         self._detect_anomaly(losses, loss_dict)
227 
228     """
229     If you need to accumulate gradients or something similar, you can
230     wrap the optimizer with your custom `zero_grad()` method.
231     """
232     self.optimizer.zero_grad()
233 
234     if self.amp_enabled:
235         self.scaler.scale(losses).backward()
236         self.scaler.step(self.optimizer)
237         self.scaler.update()
238     else:
239         losses.backward()
240         """
241         If you need gradient clipping/scaling or other processing, you can
242         wrap the optimizer with your custom `step()` method.
243         """
244         self.optimizer.step()
245 
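The docstrings in run_step suggest wrapping the optimizer when you need custom zero_grad() or step() behavior. Below is a hedged sketch of such a wrapper that adds gradient clipping; the class name and clipping threshold are assumptions, not fastreid API, and the sketch targets the non-AMP branch (with amp_enabled, the GradScaler drives the step).

import torch


class ClippedOptimizer:
    """Hypothetical optimizer wrapper: clips gradients before delegating to step()."""

    def __init__(self, optimizer, params, max_norm=5.0):
        self._optimizer = optimizer
        self._params = list(params)
        self._max_norm = max_norm

    def zero_grad(self):
        self._optimizer.zero_grad()

    def step(self):
        torch.nn.utils.clip_grad_norm_(self._params, self._max_norm)
        self._optimizer.step()

    def __getattr__(self, name):
        # Delegate everything else (param_groups, state_dict, ...) to the wrapped optimizer.
        return getattr(self._optimizer, name)


# Usage sketch: hand the wrapper to SimpleTrainer in place of the bare optimizer.
# trainer = SimpleTrainer(model, data_loader,
#                         ClippedOptimizer(base_optimizer, model.parameters()),
#                         amp_enabled=False)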

Member Data Documentation

◆ _data_loader_iter

fastreid.engine.train_loop.SimpleTrainer._data_loader_iter
protected

Definition at line 187 of file train_loop.py.

◆ amp_enabled

fastreid.engine.train_loop.SimpleTrainer.amp_enabled

Definition at line 189 of file train_loop.py.

◆ data_loader

fastreid.engine.train_loop.SimpleTrainer.data_loader

Definition at line 186 of file train_loop.py.

◆ iter

fastreid.engine.train_loop.SimpleTrainer.iter

Definition at line 250 of file train_loop.py.

◆ model

fastreid.engine.train_loop.SimpleTrainer.model

Definition at line 185 of file train_loop.py.

◆ optimizer

fastreid.engine.train_loop.SimpleTrainer.optimizer

Definition at line 188 of file train_loop.py.

◆ scaler

fastreid.engine.train_loop.SimpleTrainer.scaler

Definition at line 193 of file train_loop.py.


The documentation for this class was generated from the following file: train_loop.py