# NOTE: assumes the enclosing module imports `time`, `torch`,
# `from torch.cuda import amp`, and
# `from torch.nn.parallel import DistributedDataParallel`.
def run_step(self):
    """
    Implement the standard training logic described above.
    """
    assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
    start = time.perf_counter()

    # If you want to do something with the data, you can wrap the dataloader.
    data = next(self._data_loader_iter)
    data_time = time.perf_counter() - start
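
    # A minimal sketch of the "wrap the dataloader" idea above, assuming a
    # per-batch hook is wanted. `wrapped_loader` and `transform` are
    # hypothetical names, not part of this trainer:
    #
    #     def wrapped_loader(loader, transform):
    #         for batch in loader:
    #             # Apply an extra per-batch transform before training sees it.
    #             yield transform(batch)
    #
    #     # e.g. build the iterator from the wrapped loader instead:
    #     # self._data_loader_iter = iter(wrapped_loader(data_loader, transform))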

    # If you want to do something with the heads, you can wrap the model
    # (see the sketch after this block).
    with amp.autocast(enabled=self.amp_enabled):
        outs = self.model(data)

        # DistributedDataParallel wraps the model, so the loss computation
        # lives on `model.module`; otherwise call it on the model directly.
        if isinstance(self.model, DistributedDataParallel):
            loss_dict = self.model.module.losses(outs)
        else:
            loss_dict = self.model.losses(outs)

        losses = sum(loss_dict.values())
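
    # A minimal sketch of the "wrap the model" idea above. `HeadsWrapper` is a
    # hypothetical torch.nn.Module; a real wrapper would also need to expose
    # `losses()` so the loss computation above keeps working:
    #
    #     class HeadsWrapper(torch.nn.Module):
    #         def __init__(self, model):
    #             super().__init__()
    #             self.model = model
    #
    #         def forward(self, data):
    #             outs = self.model(data)
    #             # ... post-process the head outputs here ...
    #             return outs
    #
    #         def losses(self, outs):
    #             return self.model.losses(outs)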

    # Collect and write metrics on a side CUDA stream so that metric logging
    # does not block work on the default stream.
    with torch.cuda.stream(torch.cuda.Stream()):
        metrics_dict = loss_dict
        metrics_dict["data_time"] = data_time
        self._write_metrics(metrics_dict)
        self._detect_anomaly(losses, loss_dict)

    # If you need to accumulate gradients or something similar, you can
    # wrap the optimizer with your custom `zero_grad()` method.
    self.optimizer.zero_grad()
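
    # A minimal sketch of gradient accumulation through a custom `zero_grad()`
    # (and a matching `step()`), as suggested above. `GradAccumOptimizer` is
    # hypothetical; rescaling the loss by 1 / accum_steps is omitted for brevity:
    #
    #     class GradAccumOptimizer:
    #         def __init__(self, optimizer, accum_steps=2):
    #             self.optimizer = optimizer
    #             self.accum_steps = accum_steps
    #             self._iter = 0
    #
    #         def zero_grad(self):
    #             # Only clear gradients at the start of an accumulation window.
    #             if self._iter % self.accum_steps == 0:
    #                 self.optimizer.zero_grad()
    #
    #         def step(self):
    #             self._iter += 1
    #             # Only apply the update once a full window has accumulated.
    #             if self._iter % self.accum_steps == 0:
    #                 self.optimizer.step()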

    if self.amp_enabled:
        # AMP path: `self.scaler` is expected to be a torch.cuda.amp.GradScaler;
        # it scales the loss for backward, unscales gradients before stepping,
        # and then updates its scale factor.
        self.scaler.scale(losses).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        losses.backward()
        # If you need gradient clipping/scaling or other processing, you can
        # wrap the optimizer with your custom `step()` method (see the sketch
        # below).
        self.optimizer.step()
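
    # A minimal sketch of gradient clipping through a custom `step()`, as
    # suggested above. `ClipGradOptimizer` is hypothetical; it clips the total
    # gradient norm before delegating to the wrapped optimizer:
    #
    #     class ClipGradOptimizer:
    #         def __init__(self, optimizer, max_norm=5.0):
    #             self.optimizer = optimizer
    #             self.max_norm = max_norm
    #
    #         def zero_grad(self):
    #             self.optimizer.zero_grad()
    #
    #         def step(self):
    #             params = [p for group in self.optimizer.param_groups
    #                       for p in group["params"] if p.grad is not None]
    #             torch.nn.utils.clip_grad_norm_(params, self.max_norm)
    #             self.optimizer.step()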