Safemotion Lib
checkpoint.py
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import collections
import copy
import logging
import os
from collections import defaultdict
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from termcolor import colored
from torch.nn.parallel import DataParallel, DistributedDataParallel

from fastreid.utils.file_io import PathManager

class Checkpointer(object):
    """
    A checkpointer that can save/load model as well as extra checkpointable
    objects.
    """

    def __init__(
        self,
        model: nn.Module,
        save_dir: str = "",
        *,
        save_to_disk: bool = True,
        **checkpointables: object,
    ):
        """
        Args:
            model (nn.Module): model.
            save_dir (str): a directory to save and find checkpoints.
            save_to_disk (bool): if True, save checkpoint to disk, otherwise
                disable saving for this checkpointer.
            checkpointables (object): any checkpointable objects, i.e., objects
                that have the `state_dict()` and `load_state_dict()` method. For
                example, it can be used like
                `Checkpointer(model, "dir", optimizer=optimizer)`.
        """
        if isinstance(model, (DistributedDataParallel, DataParallel)):
            model = model.module
        self.model = model
        self.checkpointables = copy.copy(checkpointables)
        self.logger = logging.getLogger(__name__)
        self.save_dir = save_dir
        self.save_to_disk = save_to_disk

    def save(self, name: str, **kwargs: dict):
        """
        Dump model and checkpointables to a file.
        Args:
            name (str): name of the file.
            kwargs (dict): extra arbitrary data to save.
        """
        if not self.save_dir or not self.save_to_disk:
            return

        data = {}
        data["model"] = self.model.state_dict()
        for key, obj in self.checkpointables.items():
            data[key] = obj.state_dict()
        data.update(kwargs)

        basename = "{}.pth".format(name)
        save_file = os.path.join(self.save_dir, basename)
        assert os.path.basename(save_file) == basename, basename
        self.logger.info("Saving checkpoint to {}".format(save_file))
        with PathManager.open(save_file, "wb") as f:
            torch.save(data, f)
        self.tag_last_checkpoint(basename)

    def load(self, path: str):
        """
        Load from the given checkpoint. When path points to a network file, this
        function has to be called on all ranks.
        Args:
            path (str): path or URL to the checkpoint. If empty, will not load
                anything.
        Returns:
            dict:
                extra data loaded from the checkpoint that has not been
                processed. For example, those saved with
                :meth:`.save(**extra_data)`.
        """
        print(f'path = {path}')
        if not path:
            # no checkpoint provided
            self.logger.info(
                "No checkpoint found. Training model from scratch"
            )
            return {}
        self.logger.info("Loading checkpoint from {}".format(path))
        if not os.path.isfile(path):
            path = PathManager.get_local_path(path)
            assert os.path.isfile(path), "Checkpoint {} not found!".format(path)

        checkpoint = self._load_file(path)
        self._load_model(checkpoint)
        for key, obj in self.checkpointables.items():
            if key in checkpoint:
                self.logger.info("Loading {} from {}".format(key, path))
                obj.load_state_dict(checkpoint.pop(key))

        # return any further checkpoint data
        return checkpoint

    def has_checkpoint(self):
        """
        Returns:
            bool: whether a checkpoint exists in the target directory.
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        return PathManager.exists(save_file)

    def get_checkpoint_file(self):
        """
        Returns:
            str: The latest checkpoint file in target directory.
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        try:
            with PathManager.open(save_file, "r") as f:
                last_saved = f.read().strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            return ""
        return os.path.join(self.save_dir, last_saved)

    def get_all_checkpoint_files(self):
        """
        Returns:
            list: All available checkpoint files (.pth files) in target
                directory.
        """
        all_model_checkpoints = [
            os.path.join(self.save_dir, file)
            for file in PathManager.ls(self.save_dir)
            if PathManager.isfile(os.path.join(self.save_dir, file))
            and file.endswith(".pth")
        ]
        return all_model_checkpoints

    def resume_or_load(self, path: str, *, resume: bool = True):
        """
        If `resume` is True, this method attempts to resume from the last
        checkpoint, if it exists. Otherwise, load checkpoint from the given path.
        This is useful when restarting an interrupted training job.
        Args:
            path (str): path to the checkpoint.
            resume (bool): if True, resume from the last checkpoint if it exists.
        Returns:
            same as :meth:`load`.
        """
        if resume and self.has_checkpoint():
            path = self.get_checkpoint_file()
        return self.load(path)

    def tag_last_checkpoint(self, last_filename_basename: str):
        """
        Tag the last checkpoint.
        Args:
            last_filename_basename (str): the basename of the last filename.
        """
        save_file = os.path.join(self.save_dir, "last_checkpoint")
        with PathManager.open(save_file, "w") as f:
            f.write(last_filename_basename)

    def _load_file(self, f: str):
        """
        Load a checkpoint file. Can be overridden by subclasses to support
        different formats.
        Args:
            f (str): a locally mounted file path.
        Returns:
            dict: with keys "model" and optionally others that are saved by
                the checkpointer; dict["model"] must be a dict which maps strings
                to torch.Tensor or numpy arrays.
        """
        return torch.load(f, map_location=torch.device("cpu"))

    def _load_model(self, checkpoint: Any):
        """
        Load weights from a checkpoint.
        Args:
            checkpoint (Any): checkpoint contains the weights.
        """
        checkpoint_state_dict = checkpoint.pop("model")
        self._convert_ndarray_to_tensor(checkpoint_state_dict)

        # if the state_dict comes from a model that was wrapped in a
        # DataParallel or DistributedDataParallel during serialization,
        # remove the "module" prefix before performing the matching.
        _strip_prefix_if_present(checkpoint_state_dict, "module.")

        # work around https://github.com/pytorch/pytorch/issues/24139
        model_state_dict = self.model.state_dict()
        for k in list(checkpoint_state_dict.keys()):
            if k in model_state_dict:
                shape_model = tuple(model_state_dict[k].shape)
                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
                if shape_model != shape_checkpoint:
                    self.logger.warning(
                        "'{}' has shape {} in the checkpoint but {} in the "
                        "model! Skipped.".format(
                            k, shape_checkpoint, shape_model
                        )
                    )
                    checkpoint_state_dict.pop(k)

        incompatible = self.model.load_state_dict(
            checkpoint_state_dict, strict=False
        )
        if incompatible.missing_keys:
            self.logger.info(
                get_missing_parameters_message(incompatible.missing_keys)
            )
        if incompatible.unexpected_keys:
            self.logger.info(
                get_unexpected_parameters_message(incompatible.unexpected_keys)
            )

    def _convert_ndarray_to_tensor(self, state_dict: dict):
        """
        In-place convert all numpy arrays in the state_dict to torch tensors.
        Args:
            state_dict (dict): a state-dict to be loaded to the model.
        """
        # model could be an OrderedDict with _metadata attribute
        # (as returned by PyTorch's state_dict()). We should preserve these
        # properties.
        for k in list(state_dict.keys()):
            v = state_dict[k]
            if not isinstance(v, np.ndarray) and not isinstance(
                v, torch.Tensor
            ):
                raise ValueError(
                    "Unsupported type found in checkpoint! {}: {}".format(
                        k, type(v)
                    )
                )
            if not isinstance(v, torch.Tensor):
                state_dict[k] = torch.from_numpy(v)

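
# Illustrative sketch (not part of the original file): how the Checkpointer
# above is typically wired into a training script. The model, optimizer and
# output-directory arguments are placeholders, not names from this library.
def _example_checkpointer_usage(model: nn.Module, optimizer, output_dir: str):
    # Any object exposing state_dict()/load_state_dict() can be registered as
    # an extra checkpointable; here the optimizer is registered.
    checkpointer = Checkpointer(model, output_dir, optimizer=optimizer)

    # Resume from <output_dir>/last_checkpoint when it exists; an empty path
    # with no prior checkpoint simply trains from scratch.
    extra = checkpointer.resume_or_load("", resume=True)
    start_iter = extra.get("iteration", -1) + 1

    # ... training loop would run here ...

    # Dump model + optimizer state, stashing the iteration as extra data.
    checkpointer.save("model_final", iteration=start_iter)
    return start_iter
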

class PeriodicCheckpointer(object):
    """
    Save checkpoints periodically. When `.step(iteration)` is called, it will
    execute `checkpointer.save` on the given checkpointer, if iteration is a
    multiple of period or if `max_iter` is reached.
    """

    def __init__(self, checkpointer: Any, period: int, max_iter: int = None):
        """
        Args:
            checkpointer (Any): the checkpointer object used to save
                checkpoints.
            period (int): the period to save checkpoint.
            max_iter (int): maximum number of iterations. When it is reached,
                a checkpoint named "model_final" will be saved.
        """
        self.checkpointer = checkpointer
        self.period = int(period)
        self.max_iter = max_iter

    def step(self, iteration: int, **kwargs: Any):
        """
        Perform the appropriate action at the given iteration.
        Args:
            iteration (int): the current iteration, ranged in [0, max_iter-1].
            kwargs (Any): extra data to save, same as in
                :meth:`Checkpointer.save`.
        """
        iteration = int(iteration)
        additional_state = {"iteration": iteration}
        additional_state.update(kwargs)
        if self.period > 0 and (iteration + 1) % self.period == 0:
            self.checkpointer.save(
                "model_{:07d}".format(iteration), **additional_state
            )
        if self.max_iter is not None and iteration >= self.max_iter - 1:
            self.checkpointer.save("model_final", **additional_state)

    def save(self, name: str, **kwargs: Any):
        """
        Same argument as :meth:`Checkpointer.save`.
        Use this method to manually save checkpoints outside the schedule.
        Args:
            name (str): file name.
            kwargs (Any): extra data to save, same as in
                :meth:`Checkpointer.save`.
        """
        self.checkpointer.save(name, **kwargs)

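
# Illustrative sketch (not part of the original file): driving the
# PeriodicCheckpointer above from a training loop. The period and max_iter
# values below are arbitrary examples.
def _example_periodic_checkpointing(checkpointer: Checkpointer, max_iter: int = 100):
    periodic = PeriodicCheckpointer(checkpointer, period=20, max_iter=max_iter)
    for iteration in range(max_iter):
        # ... one training step would run here ...
        # Saves "model_0000019.pth", "model_0000039.pth", ... every 20 steps,
        # and "model_final.pth" once iteration reaches max_iter - 1.
        periodic.step(iteration)
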

def get_missing_parameters_message(keys: list):
    """
    Get a logging-friendly message to report parameter names (keys) that are in
    the model but not found in a checkpoint.
    Args:
        keys (list[str]): List of keys that were not found in the checkpoint.
    Returns:
        str: message.
    """
    groups = _group_checkpoint_keys(keys)
    msg = "Some model parameters are not in the checkpoint:\n"
    msg += "\n".join(
        "  " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
    )
    return msg


def get_unexpected_parameters_message(keys: list):
    """
    Get a logging-friendly message to report parameter names (keys) that are in
    the checkpoint but not found in the model.
    Args:
        keys (list[str]): List of keys that were not found in the model.
    Returns:
        str: message.
    """
    groups = _group_checkpoint_keys(keys)
    msg = "The checkpoint contains parameters not used by the model:\n"
    msg += "\n".join(
        "  " + colored(k + _group_to_str(v), "magenta")
        for k, v in groups.items()
    )
    return msg


def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
    """
    Strip the prefix from state_dict in place, if any.
    Args:
        state_dict (OrderedDict): a state-dict to be loaded to the model.
        prefix (str): prefix.
    """
    keys = sorted(state_dict.keys())
    if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
        return

    for key in keys:
        newkey = key[len(prefix):]
        state_dict[newkey] = state_dict.pop(key)

    # also strip the prefix in metadata, if any.
    try:
        metadata = state_dict._metadata
    except AttributeError:
        pass
    else:
        for key in list(metadata.keys()):
            # for the metadata dict, the key can be:
            # '': for the DDP module, which we want to remove.
            # 'module': for the actual model.
            # 'module.xx.xx': for the rest.

            if len(key) == 0:
                continue
            newkey = key[len(prefix):]
            metadata[newkey] = metadata.pop(key)

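
# Small illustration (added for documentation, not in the original file): the
# effect of _strip_prefix_if_present on keys saved from a DataParallel /
# DistributedDataParallel-wrapped model.
def _example_strip_module_prefix():
    state = collections.OrderedDict(
        [
            ("module.backbone.weight", torch.zeros(1)),
            ("module.head.bias", torch.zeros(1)),
        ]
    )
    _strip_prefix_if_present(state, "module.")
    # keys are now "backbone.weight" and "head.bias"
    return list(state.keys())
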

def _group_checkpoint_keys(keys: list):
    """
    Group keys based on common prefixes. A prefix is the string up to the final
    "." in each key.
    Args:
        keys (list[str]): list of parameter names, i.e. keys in the model
            checkpoint dict.
    Returns:
        dict[list]: keys with common prefixes are grouped into lists.
    """
    groups = defaultdict(list)
    for key in keys:
        pos = key.rfind(".")
        if pos >= 0:
            head, tail = key[:pos], [key[pos + 1:]]
        else:
            head, tail = key, []
        groups[head].extend(tail)
    return groups


def _group_to_str(group: list):
    """
    Format a group of parameter name suffixes into a loggable string.
    Args:
        group (list[str]): list of parameter name suffixes.
    Returns:
        str: formatted string.
    """
    if len(group) == 0:
        return ""

    if len(group) == 1:
        return "." + group[0]

    return ".{" + ", ".join(group) + "}"
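
# Illustration (added for documentation, not in the original file): how the two
# helpers above produce the compact key listings used by
# get_missing_parameters_message / get_unexpected_parameters_message.
def _example_group_keys():
    keys = ["backbone.conv1.weight", "backbone.conv1.bias", "heads.weight"]
    groups = _group_checkpoint_keys(keys)
    # groups == {"backbone.conv1": ["weight", "bias"], "heads": ["weight"]}
    # returns ["backbone.conv1.{weight, bias}", "heads.weight"]
    return [head + _group_to_str(suffixes) for head, suffixes in groups.items()]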