def inference_on_dataset(model, data_loader, evaluator):
    """
    Run the model on `data_loader` and evaluate the metrics with `evaluator`.
    The model will be used in eval mode.

    Args:
        model (nn.Module): a module which accepts an object from
            `data_loader` and returns some outputs. It will be temporarily set
            to `eval` mode. If you wish to evaluate a model in `training` mode
            instead, wrap the given model and override its behavior of
            `.eval()` and `.train()`.
        data_loader: an iterable object with a length.
            The elements it generates will be the inputs to the model.
        evaluator (DatasetEvaluator): the evaluator to run. Use
            :class:`DatasetEvaluators([])` if you only want to benchmark
            but don't want to do any evaluation.

    Returns:
        The return value of `evaluator.evaluate()`
    """
    logger = logging.getLogger(__name__)
    # `data_loader` is only guaranteed to be a sized iterable (see docstring),
    # so use its own length rather than assuming a `.dataset` attribute.
    logger.info("Start inference on {} batches".format(len(data_loader)))

    total = len(data_loader)  # the inference data loader must have a fixed length
    evaluator.reset()

    # Skip the first few iterations, which carry one-time startup costs
    # (e.g. CUDA context creation, cudnn autotuning) that would skew timing.
    num_warmup = min(5, total - 1)
    start_time = time.perf_counter()
    total_compute_time = 0
    with inference_context(model), torch.no_grad():
        for idx, inputs in enumerate(data_loader):
            if idx == num_warmup:
                # Warmup is over: restart both clocks.
                start_time = time.perf_counter()
                total_compute_time = 0

            start_compute_time = time.perf_counter()
            outputs = model(inputs)
            if torch.cuda.is_available():
                # CUDA kernels run asynchronously; synchronize so the timer
                # measures the compute itself, not just the kernel launches.
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            evaluator.process(inputs, outputs)

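            # Bookkeeping example for the formula below: with num_warmup = 5
            # the clocks restart at idx == 5, so at idx == 10 they span
            # iters_after_start = 10 + 1 - 5 = 6 iterations; before the restart
            # (idx < 5) the subtraction is skipped and all idx + 1 iterations count.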
            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_batch = total_compute_time / iters_after_start
            # Log progress only once timings have stabilized, or early if the
            # model is very slow; `log_every_n_seconds` rate-limits the output.
            if idx >= num_warmup * 2 or seconds_per_batch > 30:
                total_seconds_per_batch = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_batch * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Inference done {}/{}. {:.4f} s / batch. ETA={}".format(
                        idx + 1, total, seconds_per_batch, str(eta)
                    ),
                    n=30,
                )

    # Wall-clock time includes data loading; pure compute time covers only the
    # model forward passes. Both exclude the warmup iterations.
    total_time = time.perf_counter() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    logger.info(
        "Total inference time: {} ({:.6f} s / batch per device)".format(
            total_time_str, total_time / (total - num_warmup)
        )
    )
    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / batch per device)".format(
            total_compute_time_str, total_compute_time / (total - num_warmup)
        )
    )
    results = evaluator.evaluate()
    # An evaluator may return None (e.g. when not running on the main process
    # during distributed evaluation); normalize to an empty dict so callers
    # can handle the result uniformly.
    if results is None:
        results = {}
    return results

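# Typical usage of `inference_on_dataset` (a sketch, not part of this module's
# API: a config object `cfg` is assumed, `build_model`,
# `build_detection_test_loader`, and `COCOEvaluator` are stand-ins for whatever
# builders your project provides, and "my_dataset_val" is a hypothetical
# dataset name):
#
#     model = build_model(cfg)
#     val_loader = build_detection_test_loader(cfg, "my_dataset_val")
#     evaluator = COCOEvaluator("my_dataset_val")
#     results = inference_on_dataset(model, val_loader, evaluator)
#
# To benchmark speed only, pass `DatasetEvaluators([])` as the evaluator.
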
@contextmanager