import collections
import copy
import logging
import os
from collections import defaultdict
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from termcolor import colored
from torch.nn.parallel import DataParallel, DistributedDataParallel

# PathManager is the project's file I/O abstraction; the exact import path
# depends on the code base (e.g. fvcore.common.file_io in detectron2-style
# projects).
from fvcore.common.file_io import PathManager

class Checkpointer(object):
    """
    A checkpointer that can save/load model as well as extra checkpointable
    objects. A usage sketch follows the class definition.
    """

    def __init__(self, model: nn.Module, save_dir: str = "", *, save_to_disk: bool = True, **checkpointables: object):
        """
        Args:
            model (nn.Module): model.
            save_dir (str): a directory to save and find checkpoints.
            save_to_disk (bool): if True, save checkpoint to disk, otherwise
                disable saving for this checkpointer.
            checkpointables (object): any checkpointable objects, i.e., objects
                that have the `state_dict()` and `load_state_dict()` method. For
                example, it can be used like
                `Checkpointer(model, "dir", optimizer=optimizer)`.
        """
        if isinstance(model, (DistributedDataParallel, DataParallel)):
            # checkpoint the wrapped module so keys are not prefixed with "module."
            model = model.module
        self.model = model
        self.checkpointables = copy.copy(checkpointables)
        self.logger = logging.getLogger(__name__)
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk
    def save(self, name: str, **kwargs: dict):
        """
        Dump model and checkpointables to a file.

        Args:
            name (str): name of the file.
            kwargs (dict): extra arbitrary data to save.
        """
        if not self.save_dir or not self.save_to_disk:
            return

        data = {}
        data["model"] = self.model.state_dict()
        for key, obj in self.checkpointables.items():
            data[key] = obj.state_dict()
        data.update(kwargs)

        basename = "{}.pth".format(name)
        save_file = os.path.join(self.save_dir, basename)
        assert os.path.basename(save_file) == basename, basename
        self.logger.info("Saving checkpoint to {}".format(save_file))
        with PathManager.open(save_file, "wb") as f:
            torch.save(data, f)
        self.tag_last_checkpoint(basename)
    def load(self, path: str) -> object:
        """
        Load from the given checkpoint. When path points to network file, this
        function has to be called on all ranks.

        Args:
            path (str): path or url to the checkpoint. If empty, will not load
                anything.
        Returns:
            dict: extra data loaded from the checkpoint that has not been
                processed. For example, those saved with
                :meth:`.save(**extra_data)`.
        """
        print(f'path = {path}')
        if not path:
            self.logger.info("No checkpoint found. Training model from scratch")
            return {}
        self.logger.info("Loading checkpoint from {}".format(path))
        if not os.path.isfile(path):
            path = PathManager.get_local_path(path)
            assert os.path.isfile(path), "Checkpoint {} not found!".format(path)

        checkpoint = self._load_file(path)
        self._load_model(checkpoint)
        for key, obj in self.checkpointables.items():
            if key in checkpoint:
                self.logger.info("Loading {} from {}".format(key, path))
                obj.load_state_dict(checkpoint.pop(key))

        # return any further checkpoint data
        return checkpoint
    def has_checkpoint(self) -> bool:
        """
        Returns:
            bool: whether a checkpoint exists in the target directory.
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        return PathManager.exists(save_file)
    def get_checkpoint_file(self) -> str:
        """
        Returns:
            str: The latest checkpoint file in target directory.
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        try:
            with PathManager.open(save_file, "r") as f:
                last_saved = f.read().strip()
        except IOError:
            # the file is missing, e.g. because it was deleted concurrently
            return ""
        return os.path.join(self.save_dir, last_saved)
    def get_all_checkpoint_files(self) -> list:
        """
        Returns:
            list: All available checkpoint files (.pth files) in target
                directory.
        """
        all_model_checkpoints = [
            os.path.join(self.save_dir, file)
            for file in PathManager.ls(self.save_dir)
            if PathManager.isfile(os.path.join(self.save_dir, file))
            and file.endswith(".pth")
        ]
        return all_model_checkpoints
    def resume_or_load(self, path: str, *, resume: bool = True) -> object:
        """
        If `resume` is True, this method attempts to resume from the last
        checkpoint, if it exists. Otherwise, load checkpoint from the given path.
        This is useful when restarting an interrupted training job.

        Args:
            path (str): path to the checkpoint.
            resume (bool): if True, resume from the last checkpoint if it exists.
        Returns:
            same as :meth:`load`.
        """
        if resume and self.has_checkpoint():
            path = self.get_checkpoint_file()
        return self.load(path)
    def tag_last_checkpoint(self, last_filename_basename: str):
        """
        Tag the last checkpoint.

        Args:
            last_filename_basename (str): the basename of the last filename.
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        with PathManager.open(save_file, "w") as f:
            f.write(last_filename_basename)
    def _load_file(self, f: str) -> object:
        """
        Load a checkpoint file. Can be overwritten by subclasses to support
        different formats (see the subclass sketch at the end of this file).

        Args:
            f (str): a locally mounted file path.
        Returns:
            dict: with keys "model" and optionally others that are saved by
                the checkpointer; dict["model"] must be a dict which maps strings
                to torch.Tensor or numpy arrays.
        """
        return torch.load(f, map_location=torch.device("cpu"))
    def _load_model(self, checkpoint: Any):
        """
        Load weights from a checkpoint.

        Args:
            checkpoint (Any): checkpoint contains the weights.
        """
        checkpoint_state_dict = checkpoint.pop("model")
        self._convert_ndarray_to_tensor(checkpoint_state_dict)

        # if the checkpoint was produced by a DataParallel/DistributedDataParallel
        # wrapped model, drop the "module." prefix before loading
        _strip_prefix_if_present(checkpoint_state_dict, "module.")

        # drop parameters whose shapes do not match the model, so that
        # load_state_dict does not raise on size mismatches
        model_state_dict = self.model.state_dict()
        for k in list(checkpoint_state_dict.keys()):
            if k in model_state_dict:
                shape_model = tuple(model_state_dict[k].shape)
                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                if shape_model != shape_checkpoint:
                    self.logger.warning(
                        "'{}' has shape {} in the checkpoint but {} in the "
                        "model! Skipped.".format(k, shape_checkpoint, shape_model)
                    )
                    checkpoint_state_dict.pop(k)

        incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
        if incompatible.missing_keys:
            self.logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            self.logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )
    def _convert_ndarray_to_tensor(self, state_dict: dict):
        """
        In-place convert all numpy arrays in the state_dict to torch tensor.

        Args:
            state_dict (dict): a state-dict to be loaded to the model.
        """
        for k in list(state_dict.keys()):
            v = state_dict[k]
            if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
                raise ValueError(
                    "Unsupported type found in checkpoint! {}: {}".format(k, type(v))
                )
            if not isinstance(v, torch.Tensor):
                state_dict[k] = torch.from_numpy(v)
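
# Illustrative usage of the Checkpointer API above (not part of the library):
# the model, optimizer, and output directory below are placeholders; any
# nn.Module and any object exposing state_dict()/load_state_dict() will do.
def _example_checkpointer_usage():
    model = nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    os.makedirs("output", exist_ok=True)

    # every keyword argument with state_dict()/load_state_dict() is checkpointed
    checkpointer = Checkpointer(model, "output", optimizer=optimizer)

    # writes output/model_0000009.pth and updates output/last_checkpoint;
    # extra keyword arguments are stored alongside the state dicts
    checkpointer.save("model_0000009", iteration=9)

    # resume from the last checkpoint if one exists, otherwise load `path`
    # (empty here, i.e. start from scratch); unconsumed entries are returned
    extra = checkpointer.resume_or_load("", resume=True)
    return extra.get("iteration")  # 9 when a checkpoint was found
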
class PeriodicCheckpointer:
    """
    Save checkpoints periodically. When `.step(iteration)` is called, it will
    execute `checkpointer.save` on the given checkpointer, if iteration is a
    multiple of period or if `max_iter` is reached. A training-loop sketch
    follows the class.
    """

    def __init__(self, checkpointer: Any, period: int, max_iter: int = None):
        """
        Args:
            checkpointer (Any): the checkpointer object used to save checkpoints.
            period (int): the period to save checkpoint.
            max_iter (int): maximum number of iterations. When it is reached,
                a checkpoint named "model_final" will be saved.
        """
        self.checkpointer = checkpointer
        self.period = int(period)
        self.max_iter = max_iter
    def step(self, iteration: int, **kwargs: Any):
        """
        Perform the appropriate action at the given iteration.

        Args:
            iteration (int): the current iteration, ranged in [0, max_iter-1].
            kwargs (Any): extra data to save, same as in
                :meth:`Checkpointer.save`.
        """
        iteration = int(iteration)
        additional_state = {"iteration": iteration}
        additional_state.update(kwargs)
        if self.period > 0 and (iteration + 1) % self.period == 0:
            self.checkpointer.save(
                "model_{:07d}".format(iteration), **additional_state
            )
        if self.max_iter is not None and iteration >= self.max_iter - 1:
            self.checkpointer.save("model_final", **additional_state)
    def save(self, name: str, **kwargs: Any):
        """
        Same argument as :meth:`Checkpointer.save`.
        Use this method to manually save checkpoints outside the schedule.

        Args:
            name (str): file name.
            kwargs (Any): extra data to save, same as in
                :meth:`Checkpointer.save`.
        """
        self.checkpointer.save(name, **kwargs)
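
# Illustrative training-loop wiring for PeriodicCheckpointer; the model,
# optimizer, loss, and iteration counts are placeholders.
def _example_periodic_checkpointing(max_iter: int = 20):
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    os.makedirs("output", exist_ok=True)
    checkpointer = Checkpointer(model, "output", optimizer=optimizer)
    periodic = PeriodicCheckpointer(checkpointer, period=5, max_iter=max_iter)

    for iteration in range(max_iter):
        loss = model(torch.randn(8, 4)).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # saves model_0000004.pth, model_0000009.pth, ... every 5 iterations,
        # and model_final.pth once the last iteration is reached
        periodic.step(iteration, loss=loss.item())
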
def get_missing_parameters_message(keys: list) -> str:
    """
    Get a logging-friendly message to report parameter names (keys) that are in
    the model but not found in a checkpoint.

    Args:
        keys (list[str]): List of keys that were not found in the checkpoint.
    Returns:
        str: message.
    """
    groups = _group_checkpoint_keys(keys)
    msg = "Some model parameters are not in the checkpoint:\n"
    msg += "\n".join("  " + colored(k + _group_to_str(v), "blue") for k, v in groups.items())
    return msg
def get_unexpected_parameters_message(keys: list) -> str:
    """
    Get a logging-friendly message to report parameter names (keys) that are in
    the checkpoint but not found in the model.

    Args:
        keys (list[str]): List of keys that were not found in the model.
    Returns:
        str: message.
    """
    groups = _group_checkpoint_keys(keys)
    msg = "The checkpoint contains parameters not used by the model:\n"
    msg += "\n".join("  " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items())
    return msg
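
# Quick illustration of the log messages produced by the two helpers above,
# using made-up parameter names (ANSI color codes omitted from the comments).
def _example_parameter_messages():
    missing = ["backbone.conv1.weight", "backbone.conv1.bias", "head.weight"]
    unexpected = ["fc.weight", "fc.bias"]
    print(get_missing_parameters_message(missing))
    #   Some model parameters are not in the checkpoint:
    #     backbone.conv1.{weight, bias}
    #     head.weight
    print(get_unexpected_parameters_message(unexpected))
    #   The checkpoint contains parameters not used by the model:
    #     fc.{weight, bias}
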
def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
    """
    Strip the prefix from the state_dict keys and its metadata, if any.

    Args:
        state_dict (OrderedDict): a state-dict to be loaded to the model.
        prefix (str): prefix.
    """
    keys = sorted(state_dict.keys())
    if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
        return

    for key in keys:
        newkey = key[len(prefix):]
        state_dict[newkey] = state_dict.pop(key)

    # also strip the prefix in the metadata that PyTorch attaches to state dicts
    try:
        metadata = state_dict._metadata
    except AttributeError:
        pass
    else:
        for key in list(metadata.keys()):
            # the empty key refers to the module itself and is kept as-is
            if len(key) == 0:
                continue
            newkey = key[len(prefix):]
            metadata[newkey] = metadata.pop(key)
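
# Small demonstration of _strip_prefix_if_present: checkpoints saved from a
# DataParallel/DistributedDataParallel-wrapped model carry a "module." prefix
# on every key, which must be removed before loading into the bare model.
def _example_strip_module_prefix():
    state_dict = collections.OrderedDict(
        [("module.fc.weight", torch.zeros(2, 4)), ("module.fc.bias", torch.zeros(2))]
    )
    _strip_prefix_if_present(state_dict, "module.")
    assert sorted(state_dict.keys()) == ["fc.bias", "fc.weight"]
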
def _group_checkpoint_keys(keys: list) -> dict:
    """
    Group keys based on common prefixes. A prefix is the string up to the final
    "." in each key; keys without a "." form their own (empty) group.

    Args:
        keys (list[str]): list of parameter names, i.e. keys in the model
            checkpoint dict.
    Returns:
        dict[list]: keys with common prefixes are grouped into lists.
    """
    groups = defaultdict(list)
    for key in keys:
        pos = key.rfind(".")
        head, tail = (key[:pos], [key[pos + 1:]]) if pos >= 0 else (key, [])
        groups[head].extend(tail)
    return groups
def _group_to_str(group: list) -> str:
    """
    Format a group of parameter name suffixes into a loggable string.

    Args:
        group (list[str]): list of parameter name suffixes.
    Returns:
        str: formatted string.
    """
    if len(group) == 0:
        return ""
    if len(group) == 1:
        return "." + group[0]
    return ".{" + ", ".join(group) + "}"
Checkpointer.save(self, name: str, **kwargs: dict)
Checkpointer.get_checkpoint_file(self)
Checkpointer._convert_ndarray_to_tensor(self, state_dict: dict)
Checkpointer.tag_last_checkpoint(self, last_filename_basename: str)
Checkpointer.resume_or_load(self, path: str, *, resume: bool = True)
Checkpointer.get_all_checkpoint_files(self)
Checkpointer._load_model(self, checkpoint: Any)
Checkpointer.__init__(self, model: nn.Module, save_dir: str = "", *, save_to_disk: bool = True, **checkpointables: object)
PeriodicCheckpointer.step(self, iteration: int, **kwargs: Any)
PeriodicCheckpointer.__init__(self, checkpointer: Any, period: int, max_iter: int = None)
PeriodicCheckpointer.save(self, name: str, **kwargs: Any)
_strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str)
_group_checkpoint_keys(keys: list)
_group_to_str(group: list)
get_unexpected_parameters_message(keys: list)
get_missing_parameters_message(keys: list)
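
# Sketch of the subclassing hook mentioned in Checkpointer._load_file: a
# hypothetical checkpointer that also understands pickled ".pkl" weight files
# in addition to the default torch ".pth" format.
import pickle

class PickleCheckpointer(Checkpointer):
    def _load_file(self, f: str) -> object:
        if f.endswith(".pkl"):
            with PathManager.open(f, "rb") as fh:
                data = pickle.load(fh)
            # the base class expects a dict with a "model" key mapping
            # parameter names to torch.Tensor or numpy arrays
            return data if isinstance(data, dict) and "model" in data else {"model": data}
        return super()._load_file(f)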