def __init__(self, model, optimizer, freeze_layers, freeze_iters):
    """Initialise the layer-freezing hook state.

    Args:
        model: Network to manage. If it is a ``DistributedDataParallel``
            wrapper, the wrapped ``.module`` is stored instead of the wrapper.
        optimizer: Optimizer whose ``param_groups`` are read once here. Each
            group must support ``group['name']`` and ``group['freeze']``.
        freeze_layers: Stored unchanged on ``self.freeze_layers``.
            NOTE(review): presumably the names of layers to freeze — confirm
            against the code that consumes it.
        freeze_iters: Stored unchanged on ``self.freeze_iters``.
            NOTE(review): presumably an iteration count — confirm.

    Raises:
        KeyError: If a param group has no ``'name'`` or ``'freeze'`` entry.
            In that case ``self.param_freeze`` is left unset.
    """
    self._logger = logging.getLogger(__name__)

    # Always keep the bare module, never the DDP wrapper around it.
    self.model = model.module if isinstance(model, DistributedDataParallel) else model
    self.optimizer = optimizer

    self.freeze_layers = freeze_layers
    self.freeze_iters = freeze_iters

    # Snapshot each param group's current 'freeze' flag, keyed by group name.
    # If two groups share a name, the later group's flag wins.
    self.param_freeze = {
        group['name']: group['freeze'] for group in self.optimizer.param_groups
    }

    self.is_frozen = False
437