Safemotion Lib
fastreid.utils.checkpoint.Checkpointer Class Reference
Inheritance diagram for fastreid.utils.checkpoint.Checkpointer:

Public Member Functions

 __init__ (self, nn.Module model, str save_dir="", *bool save_to_disk=True, **object checkpointables)
 
 save (self, str name, **dict kwargs)
 
 load (self, str path)
 
 has_checkpoint (self)
 
 get_checkpoint_file (self)
 
 get_all_checkpoint_files (self)
 
 resume_or_load (self, str path, *bool resume=True)
 
 tag_last_checkpoint (self, str last_filename_basename)
 

Public Attributes

 model
 
 checkpointables
 
 logger
 
 save_dir
 
 save_to_disk
 

Protected Member Functions

 _load_file (self, str f)
 
 _load_model (self, Any checkpoint)
 
 _convert_ndarray_to_tensor (self, dict state_dict)
 

Detailed Description

A checkpointer that can save/load a model as well as extra checkpointable
objects.

Definition at line 20 of file checkpoint.py.

Constructor & Destructor Documentation

◆ __init__()

fastreid.utils.checkpoint.Checkpointer.__init__ ( self,
nn.Module model,
str save_dir = "",
*bool save_to_disk = True,
**object checkpointables )
Args:
    model (nn.Module): model.
    save_dir (str): a directory to save and find checkpoints.
    save_to_disk (bool): if True, save checkpoint to disk, otherwise
        disable saving for this checkpointer.
    checkpointables (object): any checkpointable objects, i.e., objects
        that have the `state_dict()` and `load_state_dict()` methods. For
        example, it can be used like
        `Checkpointer(model, "dir", optimizer=optimizer)`.

Definition at line 26 of file checkpoint.py.

def __init__(
    self,
    model: nn.Module,
    save_dir: str = "",
    *,
    save_to_disk: bool = True,
    **checkpointables: object,
):
    """
    Args:
        model (nn.Module): model.
        save_dir (str): a directory to save and find checkpoints.
        save_to_disk (bool): if True, save checkpoint to disk, otherwise
            disable saving for this checkpointer.
        checkpointables (object): any checkpointable objects, i.e., objects
            that have the `state_dict()` and `load_state_dict()` methods. For
            example, it can be used like
            `Checkpointer(model, "dir", optimizer=optimizer)`.
    """
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        model = model.module
    self.model = model
    self.checkpointables = copy.copy(checkpointables)
    self.logger = logging.getLogger(__name__)
    self.save_dir = save_dir
    self.save_to_disk = save_to_disk
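Anything passed via `**checkpointables` only needs to satisfy a small duck-typed contract: expose `state_dict()` and `load_state_dict()`. A minimal stdlib-only sketch of that contract (the `Meter` class is hypothetical, not part of fastreid):

```python
# A minimal sketch of the "checkpointable" contract assumed by the
# constructor: any object exposing state_dict() / load_state_dict()
# can be registered. The Meter class here is purely illustrative.
class Meter:
    def __init__(self):
        self.best_score = 0.0

    def state_dict(self):
        # Return a plain dict describing this object's state.
        return {"best_score": self.best_score}

    def load_state_dict(self, state):
        # Restore state from a dict produced by state_dict().
        self.best_score = state["best_score"]

# Round-trip: serialize one meter's state into a fresh one.
src, dst = Meter(), Meter()
src.best_score = 0.93
dst.load_state_dict(src.state_dict())
print(dst.best_score)
```

With such an object, a call like `Checkpointer(model, "dir", meter=meter)` would save and restore it alongside the model, exactly as the docstring's `optimizer=optimizer` example does.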

Member Function Documentation

◆ _convert_ndarray_to_tensor()

fastreid.utils.checkpoint.Checkpointer._convert_ndarray_to_tensor ( self,
dict state_dict )
protected
In-place convert all numpy arrays in the state_dict to torch tensors.
Args:
    state_dict (dict): a state-dict to be loaded to the model.

Definition at line 228 of file checkpoint.py.

def _convert_ndarray_to_tensor(self, state_dict: dict):
    """
    In-place convert all numpy arrays in the state_dict to torch tensors.
    Args:
        state_dict (dict): a state-dict to be loaded to the model.
    """
    # model could be an OrderedDict with _metadata attribute
    # (as returned by Pytorch's state_dict()). We should preserve these
    # properties.
    for k in list(state_dict.keys()):
        v = state_dict[k]
        if not isinstance(v, np.ndarray) and not isinstance(v, torch.Tensor):
            raise ValueError(
                "Unsupported type found in checkpoint! {}: {}".format(k, type(v))
            )
        if not isinstance(v, torch.Tensor):
            state_dict[k] = torch.from_numpy(v)
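The method follows a validate-and-convert-in-place pattern: reject unknown value types with `ValueError`, convert the convertible ones, and leave values already in the target type alone. A stdlib-only sketch of the same pattern, with `list` standing in for `np.ndarray` (convertible) and `tuple` for `torch.Tensor` (already in target form):

```python
# Validate-and-convert-in-place pattern, with plain Python types as
# stand-ins: list ~ np.ndarray, tuple ~ torch.Tensor.
def convert_in_place(state_dict: dict) -> None:
    for k in list(state_dict.keys()):  # list() so we can mutate while iterating
        v = state_dict[k]
        if not isinstance(v, (list, tuple)):
            raise ValueError(
                "Unsupported type found in checkpoint! {}: {}".format(k, type(v))
            )
        if not isinstance(v, tuple):
            state_dict[k] = tuple(v)  # analogous to torch.from_numpy(v)

sd = {"w": [1, 2], "b": (3,)}
convert_in_place(sd)
print(sd)  # every value is now a tuple
```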

◆ _load_file()

fastreid.utils.checkpoint.Checkpointer._load_file ( self,
str f )
protected
Load a checkpoint file. Can be overridden by subclasses to support
different formats.
Args:
    f (str): a locally mounted file path.
Returns:
    dict: with keys "model" and optionally others that are saved by
        the checkpointer. dict["model"] must be a dict which maps strings
        to torch.Tensor or numpy arrays.

Definition at line 174 of file checkpoint.py.

def _load_file(self, f: str):
    """
    Load a checkpoint file. Can be overridden by subclasses to support
    different formats.
    Args:
        f (str): a locally mounted file path.
    Returns:
        dict: with keys "model" and optionally others that are saved by
            the checkpointer. dict["model"] must be a dict which maps strings
            to torch.Tensor or numpy arrays.
    """
    return torch.load(f, map_location=torch.device("cpu"))
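An override typically dispatches on the file extension and returns a dict in the same shape. A stdlib-only sketch of that idea (the function and its `.json` branch are hypothetical; fastreid's base implementation only calls `torch.load`, for which `pickle` stands in here):

```python
import json
import os
import pickle
import tempfile

# Hypothetical format dispatch of the kind a _load_file override might do.
def load_checkpoint_file(f: str) -> dict:
    if f.endswith(".json"):
        with open(f, "r") as fh:
            return json.load(fh)
    # default: a pickled dict (stand-in for torch.load in this sketch)
    with open(f, "rb") as fh:
        return pickle.load(fh)

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "ckpt.json")
    with open(path, "w") as fh:
        json.dump({"model": {"w": [1.0]}}, fh)
    ckpt = load_checkpoint_file(path)
print("model" in ckpt)
```

Whatever the format, the returned dict must keep the `"model"` key contract described above so that `_load_model` can consume it.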

◆ _load_model()

fastreid.utils.checkpoint.Checkpointer._load_model ( self,
Any checkpoint )
protected
Load weights from a checkpoint.
Args:
    checkpoint (Any): checkpoint contains the weights.

Definition at line 187 of file checkpoint.py.

def _load_model(self, checkpoint: Any):
    """
    Load weights from a checkpoint.
    Args:
        checkpoint (Any): checkpoint contains the weights.
    """
    checkpoint_state_dict = checkpoint.pop("model")
    self._convert_ndarray_to_tensor(checkpoint_state_dict)

    # if the state_dict comes from a model that was wrapped in a
    # DataParallel or DistributedDataParallel during serialization,
    # remove the "module" prefix before performing the matching.
    _strip_prefix_if_present(checkpoint_state_dict, "module.")

    # work around https://github.com/pytorch/pytorch/issues/24139
    model_state_dict = self.model.state_dict()
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            shape_model = tuple(model_state_dict[k].shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                self.logger.warning(
                    "'{}' has shape {} in the checkpoint but {} in the "
                    "model! Skipped.".format(k, shape_checkpoint, shape_model)
                )
                checkpoint_state_dict.pop(k)

    incompatible = self.model.load_state_dict(
        checkpoint_state_dict, strict=False
    )
    if incompatible.missing_keys:
        self.logger.info(
            get_missing_parameters_message(incompatible.missing_keys)
        )
    if incompatible.unexpected_keys:
        self.logger.info(
            get_unexpected_parameters_message(incompatible.unexpected_keys)
        )
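Two of the steps above can be illustrated without torch: stripping the `"module."` prefix left behind by (Distributed)DataParallel, and dropping keys whose shapes disagree with the model. In this sketch, shapes are plain tuples standing in for tensor `.shape` values, and `strip_prefix_if_present` is a re-implementation for illustration, not fastreid's `_strip_prefix_if_present` itself:

```python
# Strip a wrapper prefix, but only when every key carries it.
def strip_prefix_if_present(state_dict: dict, prefix: str) -> None:
    keys = sorted(state_dict.keys())
    if not all(k.startswith(prefix) for k in keys):
        return
    for k in keys:
        state_dict[k[len(prefix):]] = state_dict.pop(k)

# Drop checkpoint entries whose shape conflicts with the model's,
# mirroring the warn-and-skip loop in _load_model.
def filter_shape_mismatches(ckpt_shapes: dict, model_shapes: dict) -> list:
    skipped = []
    for k in list(ckpt_shapes.keys()):
        if k in model_shapes and model_shapes[k] != ckpt_shapes[k]:
            ckpt_shapes.pop(k)  # leave the model's own weight untouched
            skipped.append(k)
    return skipped

ckpt = {"module.fc.weight": (10, 512), "module.fc.bias": (7,)}
strip_prefix_if_present(ckpt, "module.")
skipped = filter_shape_mismatches(ckpt, {"fc.weight": (10, 512), "fc.bias": (10,)})
print(sorted(ckpt), skipped)
```

Skipped keys then surface through `load_state_dict(..., strict=False)` as missing keys, which is why the method logs them instead of raising.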

◆ get_all_checkpoint_files()

fastreid.utils.checkpoint.Checkpointer.get_all_checkpoint_files ( self)
Returns:
    list: All available checkpoint files (.pth files) in target
        directory.

Definition at line 135 of file checkpoint.py.

def get_all_checkpoint_files(self):
    """
    Returns:
        list: All available checkpoint files (.pth files) in target
            directory.
    """
    all_model_checkpoints = [
        os.path.join(self.save_dir, file)
        for file in PathManager.ls(self.save_dir)
        if PathManager.isfile(os.path.join(self.save_dir, file))
        and file.endswith(".pth")
    ]
    return all_model_checkpoints
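The method is a straightforward directory scan. An equivalent sketch using only the standard library, with `os.listdir`/`os.path.isfile` standing in for `PathManager.ls`/`PathManager.isfile`:

```python
import os
import tempfile

# List regular files in save_dir that end with ".pth".
def all_checkpoint_files(save_dir: str) -> list:
    return [
        os.path.join(save_dir, f)
        for f in os.listdir(save_dir)
        if os.path.isfile(os.path.join(save_dir, f)) and f.endswith(".pth")
    ]

with tempfile.TemporaryDirectory() as d:
    for name in ("model_0001.pth", "model_0002.pth", "events.log"):
        open(os.path.join(d, name), "w").close()
    found = sorted(os.path.basename(p) for p in all_checkpoint_files(d))
print(found)  # only the .pth files
```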

◆ get_checkpoint_file()

fastreid.utils.checkpoint.Checkpointer.get_checkpoint_file ( self)
Returns:
    str: The latest checkpoint file in target directory.

Definition at line 120 of file checkpoint.py.

def get_checkpoint_file(self):
    """
    Returns:
        str: The latest checkpoint file in target directory.
    """
    save_file = os.path.join(self.save_dir, "last_checkpoint")
    try:
        with PathManager.open(save_file, "r") as f:
            last_saved = f.read().strip()
    except IOError:
        # if file doesn't exist, maybe because it has just been
        # deleted by a separate process
        return ""
    return os.path.join(self.save_dir, last_saved)

◆ has_checkpoint()

fastreid.utils.checkpoint.Checkpointer.has_checkpoint ( self)
Returns:
    bool: whether a checkpoint exists in the target directory.

Definition at line 112 of file checkpoint.py.

def has_checkpoint(self):
    """
    Returns:
        bool: whether a checkpoint exists in the target directory.
    """
    save_file = os.path.join(self.save_dir, "last_checkpoint")
    return PathManager.exists(save_file)

◆ load()

fastreid.utils.checkpoint.Checkpointer.load ( self,
str path )
Load from the given checkpoint. When path points to a network file, this
function has to be called on all ranks.
Args:
    path (str): path or url to the checkpoint. If empty, will not load
        anything.
Returns:
    dict:
        extra data loaded from the checkpoint that has not been
        processed. For example, those saved with
        :meth:`.save(**extra_data)`.

Definition at line 77 of file checkpoint.py.

def load(self, path: str):
    """
    Load from the given checkpoint. When path points to a network file, this
    function has to be called on all ranks.
    Args:
        path (str): path or url to the checkpoint. If empty, will not load
            anything.
    Returns:
        dict:
            extra data loaded from the checkpoint that has not been
            processed. For example, those saved with
            :meth:`.save(**extra_data)`.
    """
    if not path:
        # no checkpoint provided
        self.logger.info("No checkpoint found. Training model from scratch")
        return {}
    self.logger.info("Loading checkpoint from {}".format(path))
    if not os.path.isfile(path):
        path = PathManager.get_local_path(path)
        assert os.path.isfile(path), "Checkpoint {} not found!".format(path)

    checkpoint = self._load_file(path)
    self._load_model(checkpoint)
    for key, obj in self.checkpointables.items():
        if key in checkpoint:
            self.logger.info("Loading {} from {}".format(key, path))
            obj.load_state_dict(checkpoint.pop(key))

    # return any further checkpoint data
    return checkpoint
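The tail of `load` is a dispatch loop: `"model"` is consumed first, each registered checkpointable pops its own key, and whatever remains is returned to the caller as unprocessed extra data. A stdlib-only sketch of that flow (`DummyOptimizer` is hypothetical):

```python
# Sketch of how load() routes checkpoint entries to checkpointables.
class DummyOptimizer:
    def __init__(self):
        self.state = {}

    def load_state_dict(self, state):
        self.state = state

def dispatch(checkpoint: dict, checkpointables: dict) -> dict:
    checkpoint.pop("model")  # consumed by _load_model in the real class
    for key, obj in checkpointables.items():
        if key in checkpoint:
            obj.load_state_dict(checkpoint.pop(key))
    return checkpoint  # extra, unprocessed data

opt = DummyOptimizer()
extra = dispatch(
    {"model": {}, "optimizer": {"lr": 0.01}, "iteration": 4999},
    {"optimizer": opt},
)
print(opt.state, extra)
```

This is why callers such as training loops can recover values like an iteration counter from the returned dict: anything saved via `save(**kwargs)` that no checkpointable claims comes back out of `load`.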

◆ resume_or_load()

fastreid.utils.checkpoint.Checkpointer.resume_or_load ( self,
str path,
*bool resume = True )
If `resume` is True, this method attempts to resume from the last
checkpoint, if it exists. Otherwise, load the checkpoint from the given path.
This is useful when restarting an interrupted training job.
Args:
    path (str): path to the checkpoint.
    resume (bool): if True, resume from the last checkpoint if it exists.
Returns:
    same as :meth:`load`.

Definition at line 149 of file checkpoint.py.

def resume_or_load(self, path: str, *, resume: bool = True):
    """
    If `resume` is True, this method attempts to resume from the last
    checkpoint, if it exists. Otherwise, load the checkpoint from the given path.
    This is useful when restarting an interrupted training job.
    Args:
        path (str): path to the checkpoint.
        resume (bool): if True, resume from the last checkpoint if it exists.
    Returns:
        same as :meth:`load`.
    """
    if resume and self.has_checkpoint():
        path = self.get_checkpoint_file()
    return self.load(path)
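The whole method is one branch: when `resume=True` and a "last_checkpoint" tag exists, the tagged file overrides the path the caller supplied. A sketch of that control flow with plain callables standing in for the class's methods:

```python
# resume_or_load as a free function over injected callables.
def resume_or_load(path, *, resume, has_checkpoint, get_checkpoint_file, load):
    if resume and has_checkpoint():
        path = get_checkpoint_file()
    return load(path)

chosen = resume_or_load(
    "weights/pretrain.pth",           # path the caller asked for
    resume=True,
    has_checkpoint=lambda: True,      # a last_checkpoint tag exists
    get_checkpoint_file=lambda: "logs/model_final.pth",
    load=lambda p: p,                 # identity stand-in: report the chosen path
)
print(chosen)
```

With `resume=False` (or no tag file present), the caller's `path` would be loaded instead, which is the typical first-run case with pretrained weights.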

◆ save()

fastreid.utils.checkpoint.Checkpointer.save ( self,
str name,
**dict kwargs )
Dump model and checkpointables to a file.
Args:
    name (str): name of the file.
    kwargs (dict): extra arbitrary data to save.

Definition at line 53 of file checkpoint.py.

def save(self, name: str, **kwargs: dict):
    """
    Dump model and checkpointables to a file.
    Args:
        name (str): name of the file.
        kwargs (dict): extra arbitrary data to save.
    """
    if not self.save_dir or not self.save_to_disk:
        return

    data = {}
    data["model"] = self.model.state_dict()
    for key, obj in self.checkpointables.items():
        data[key] = obj.state_dict()
    data.update(kwargs)

    basename = "{}.pth".format(name)
    save_file = os.path.join(self.save_dir, basename)
    assert os.path.basename(save_file) == basename, basename
    self.logger.info("Saving checkpoint to {}".format(save_file))
    with PathManager.open(save_file, "wb") as f:
        torch.save(data, f)
    self.tag_last_checkpoint(basename)
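The method assembles one dict (model state, each checkpointable's state, then the extra kwargs), writes it as `<name>.pth`, and tags it as the latest. A stdlib-only sketch of that sequence, with `pickle` standing in for `torch.save` and plain dicts standing in for `state_dict()` results:

```python
import os
import pickle
import tempfile

# Stdlib sketch of save(): gather states, write "<name>.pth", tag it.
def save(save_dir, name, model_state, checkpointables, **kwargs):
    data = {"model": model_state}
    for key, obj_state in checkpointables.items():
        data[key] = obj_state
    data.update(kwargs)
    save_file = os.path.join(save_dir, "{}.pth".format(name))
    with open(save_file, "wb") as f:
        pickle.dump(data, f)
    with open(os.path.join(save_dir, "last_checkpoint"), "w") as f:
        f.write(os.path.basename(save_file))  # cf. tag_last_checkpoint
    return save_file

with tempfile.TemporaryDirectory() as d:
    p = save(d, "model_0004999", {"w": [1.0]}, {"optimizer": {"lr": 0.01}}, iteration=4999)
    with open(p, "rb") as f:
        roundtrip = pickle.load(f)
    with open(os.path.join(d, "last_checkpoint")) as f:
        tag = f.read()
print(sorted(roundtrip), tag)
```

Note that extra kwargs (here `iteration`) share the same namespace as the model and checkpointables inside the file; `load` later returns them as the leftover dict.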

◆ tag_last_checkpoint()

fastreid.utils.checkpoint.Checkpointer.tag_last_checkpoint ( self,
str last_filename_basename )
Tag the last checkpoint.
Args:
    last_filename_basename (str): the basename of the last filename.

Definition at line 164 of file checkpoint.py.

def tag_last_checkpoint(self, last_filename_basename: str):
    """
    Tag the last checkpoint.
    Args:
        last_filename_basename (str): the basename of the last filename.
    """
    save_file = os.path.join(self.save_dir, "last_checkpoint")
    with PathManager.open(save_file, "w") as f:
        f.write(last_filename_basename)
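Together with `has_checkpoint` and `get_checkpoint_file`, this implements a simple tag-file protocol: a one-line "last_checkpoint" file holds the basename of the most recent save. A stdlib-only sketch of all three sides of the protocol (plain `open` standing in for `PathManager.open`):

```python
import os
import tempfile

def tag_last(save_dir, basename):
    # Write side: record the latest checkpoint's basename.
    with open(os.path.join(save_dir, "last_checkpoint"), "w") as f:
        f.write(basename)

def has_checkpoint(save_dir):
    # Existence check: is there anything to resume from?
    return os.path.exists(os.path.join(save_dir, "last_checkpoint"))

def get_checkpoint_file(save_dir):
    # Read side: resolve the tag back to a full path.
    try:
        with open(os.path.join(save_dir, "last_checkpoint")) as f:
            return os.path.join(save_dir, f.read().strip())
    except IOError:
        return ""  # tag may have been removed by another process

with tempfile.TemporaryDirectory() as d:
    before = has_checkpoint(d)
    tag_last(d, "model_final.pth")
    latest = get_checkpoint_file(d)
print(before, os.path.basename(latest))
```

Storing only the basename keeps the tag valid even if the whole `save_dir` is moved or mounted elsewhere, since `get_checkpoint_file` re-joins it with the current directory.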

Member Data Documentation

◆ checkpointables

fastreid.utils.checkpoint.Checkpointer.checkpointables

Definition at line 48 of file checkpoint.py.

◆ logger

fastreid.utils.checkpoint.Checkpointer.logger

Definition at line 49 of file checkpoint.py.

◆ model

fastreid.utils.checkpoint.Checkpointer.model

Definition at line 47 of file checkpoint.py.

◆ save_dir

fastreid.utils.checkpoint.Checkpointer.save_dir

Definition at line 50 of file checkpoint.py.

◆ save_to_disk

fastreid.utils.checkpoint.Checkpointer.save_to_disk

Definition at line 51 of file checkpoint.py.


The documentation for this class was generated from the following file: checkpoint.py