def __init__(self, optimizer, scheduler):
    """
    Store the optimizer/scheduler pair and choose a representative param group.

    Args:
        optimizer (torch.optim.Optimizer): optimizer whose ``param_groups``
            (dicts with at least ``"params"`` and ``"lr"`` keys) are inspected;
            stored as ``self._optimizer``.
        scheduler (torch.optim._LRScheduler): learning-rate scheduler; stored
            as ``self._scheduler``.

    The index of the chosen group is stored in ``self._best_param_group_id``:

    * if no group holds more than one parameter, it is the first group whose
      ``"lr"`` equals the most common ``"lr"`` across all groups (ties go to
      the LR value encountered first);
    * otherwise, it is the first group with the largest number of parameters.

    Raises:
        ValueError: if ``optimizer.param_groups`` is empty.
    """
    self._optimizer = optimizer
    self._scheduler = scheduler

    groups = optimizer.param_groups
    if not groups:
        # max() below would otherwise fail with an opaque "empty sequence" error.
        raise ValueError("optimizer.param_groups is empty; cannot pick a param group")

    largest_group = max(len(g["params"]) for g in groups)

    if largest_group == 1:
        # Group size cannot discriminate when every group holds a single
        # parameter, so fall back to the most common initial LR. For ties,
        # most_common(1) returns the first-inserted key, as most_common()[0] does.
        lr = Counter(g["lr"] for g in groups).most_common(1)[0][0]
        # Guaranteed to find a match: ``lr`` was taken from these groups.
        self._best_param_group_id = next(
            i for i, g in enumerate(groups) if g["lr"] == lr
        )
    else:
        # Prefer the group holding the most parameters (first one on ties).
        self._best_param_group_id = next(
            i for i, g in enumerate(groups) if len(g["params"]) == largest_group
        )