docling-ibm-models 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
  2. docling_ibm_models/tableformer/__init__.py +0 -0
  3. docling_ibm_models/tableformer/common.py +200 -0
  4. docling_ibm_models/tableformer/data_management/__init__.py +0 -0
  5. docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
  6. docling_ibm_models/tableformer/data_management/functional.py +574 -0
  7. docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
  8. docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
  9. docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
  10. docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
  11. docling_ibm_models/tableformer/data_management/transforms.py +396 -0
  12. docling_ibm_models/tableformer/models/__init__.py +0 -0
  13. docling_ibm_models/tableformer/models/common/__init__.py +0 -0
  14. docling_ibm_models/tableformer/models/common/base_model.py +279 -0
  15. docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
  16. docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
  17. docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
  18. docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
  19. docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
  20. docling_ibm_models/tableformer/otsl.py +541 -0
  21. docling_ibm_models/tableformer/settings.py +90 -0
  22. docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
  23. docling_ibm_models/tableformer/test_prepare_image.py +99 -0
  24. docling_ibm_models/tableformer/utils/__init__.py +0 -0
  25. docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
  26. docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
  27. docling_ibm_models/tableformer/utils/utils.py +376 -0
  28. docling_ibm_models/tableformer/utils/variance.py +175 -0
  29. docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
  30. docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
  31. docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
  32. docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,396 @@
+ #
+ # Copyright IBM Corp. 2024 - 2024
+ # SPDX-License-Identifier: MIT
+ #
+ from __future__ import division
+ 
+ import collections.abc
+ import numbers
+ import random
+ 
+ import torch
+ 
+ from docling_ibm_models.tableformer.data_management import functional as F
+ 
+ 
+ def box_cxcywh_to_xyxy(x):
+     # (center_x, center_y, width, height) -> (x0, y0, x1, y1)
+     x_c, y_c, w, h = x.unbind(-1)
+     b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+     return torch.stack(b, dim=-1)
+ 
+ 
+ def box_xyxy_to_cxcywh(x):
+     # (x0, y0, x1, y1) -> (center_x, center_y, width, height)
+     x0, y0, x1, y1 = x.unbind(-1)
+     b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
+     return torch.stack(b, dim=-1)
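
For orientation, the two box helpers are exact inverses of each other. A minimal round-trip sketch (illustrative, not part of the package):

    import torch
    from docling_ibm_models.tableformer.data_management.transforms import (
        box_cxcywh_to_xyxy,
        box_xyxy_to_cxcywh,
    )

    boxes = torch.tensor([[50.0, 40.0, 20.0, 10.0]])  # (cx, cy, w, h)
    xyxy = box_cxcywh_to_xyxy(boxes)                  # tensor([[40., 35., 60., 45.]])
    assert torch.allclose(box_xyxy_to_cxcywh(xyxy), boxes)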
+ 
+ 
+ class Lambda(object):
+     """Apply a user-defined lambda as a transform.
+     Attention: the multiprocessing used in the PyTorch dataloader
+     does not work with lambda functions on Windows.
+     Args:
+         lambd (function): Lambda/function to be used for transform.
+     """
+ 
+     def __init__(self, lambd):
+         # assert isinstance(lambd, types.LambdaType)
+         self.lambd = lambd
+         # if 'Windows' in platform.system():
+         #     raise RuntimeError("Can't pickle lambda function on Windows")
+ 
+     def __call__(self, img):
+         return self.lambd(img)
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + "()"
+ 
+ 
+ class RandomTransforms(object):
+     """Base class for a list of transformations with randomness
+     Args:
+         transforms (list or tuple): list of transformations
+     """
+ 
+     def __init__(self, transforms):
+         assert isinstance(transforms, (list, tuple))
+         self.transforms = transforms
+ 
+     def __call__(self, *args, **kwargs):
+         raise NotImplementedError()
+ 
+     def __repr__(self):
+         format_string = self.__class__.__name__ + "("
+         for t in self.transforms:
+             format_string += "\n"
+             format_string += "    {0}".format(t)
+         format_string += "\n)"
+         return format_string
+ 
+ 
+ class RandomChoice(RandomTransforms):
+     """Apply a single transformation randomly picked from a list"""
+ 
+     def __call__(self, img, target):
+         t = random.choice(self.transforms)
+         return t(img, target)
+ 
+ 
+ class RandomCrop(object):
+     def __init__(self, size, margin_crop):
+         self.size = list(size)
+         self.margin_crop = list(margin_crop)  # margin_crop: (w, h)
+ 
+     def __call__(self, img, target):
+         """
+         img (np.ndarray): Image of shape (h, w, ch). A margin_crop-sized
+         border is removed from each side via F.crop(img, top, left, height,
+         width), and the target boxes are shifted accordingly.
+         """
+         image_height, image_width = img.shape[0], img.shape[1]
+         target_ = target.copy()
+         if image_width > 0 and image_height > 0:
+             cropped_image = F.crop(
+                 img,
+                 self.margin_crop[1],
+                 self.margin_crop[0],
+                 image_height - (self.margin_crop[1] * 2),
+                 image_width - (self.margin_crop[0] * 2),
+             )
+             target_["boxes"][:, 0] = target_["boxes"][:, 0] - self.margin_crop[0]
+             target_["boxes"][:, 1] = target_["boxes"][:, 1] - self.margin_crop[1]
+             target_["boxes"][:, 2] = target_["boxes"][:, 2] - self.margin_crop[0]
+             target_["boxes"][:, 3] = target_["boxes"][:, 3] - self.margin_crop[1]
+         else:
+             # Degenerate image: return it unchanged, together with the copied target
+             cropped_image = img
+         return cropped_image, target_
+ 
+ 
+ class RandomPad(object):
+     def __init__(self, max_pad):
+         self.max_pad = max_pad
+ 
+     def __call__(self, img, target):
+         # Sample an independent padding for each side of the image
+         pad_x = random.randint(0, self.max_pad)
+         pad_y = random.randint(0, self.max_pad)
+         pad_x1 = random.randint(0, self.max_pad)
+         pad_y1 = random.randint(0, self.max_pad)
+         img = img.copy()
+         padded_image = F.pad(img, (pad_x, pad_y, pad_x1, pad_y1), fill=(255, 255, 255))
+         target_ = target.copy()
+         if target["boxes"] is not None:
+             # Left/top padding shifts all box coordinates by (pad_x, pad_y)
+             target_["boxes"][:, 0] = target_["boxes"][:, 0] + pad_x
+             target_["boxes"][:, 1] = target_["boxes"][:, 1] + pad_y
+             target_["boxes"][:, 2] = target_["boxes"][:, 2] + pad_x
+             target_["boxes"][:, 3] = target_["boxes"][:, 3] + pad_y
+         return padded_image, target_
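
The box bookkeeping in RandomCrop and RandomPad is a pure translation: cropping a margin shifts every box up and to the left, padding shifts it down and to the right. A toy check of that arithmetic, assuming the boxes are a NumPy array as the in-place slice updates above suggest:

    import numpy as np

    boxes = np.array([[100.0, 80.0, 200.0, 160.0]])  # x0, y0, x1, y1
    margin_w, margin_h = 10, 20                      # RandomCrop margin (w, h)
    boxes[:, [0, 2]] -= margin_w
    boxes[:, [1, 3]] -= margin_h
    pad_x, pad_y = 5, 7                              # RandomPad left/top padding
    boxes[:, [0, 2]] += pad_x
    boxes[:, [1, 3]] += pad_y
    # boxes is now [[95., 67., 195., 147.]]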
+ 
+ 
+ class ColorJitter(object):
+     """Randomly change the brightness, contrast and saturation of an image.
+     Args:
+         brightness (float): How much to jitter brightness. brightness_factor
+             is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
+         contrast (float): How much to jitter contrast. contrast_factor
+             is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
+         saturation (float): How much to jitter saturation. saturation_factor
+             is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
+         hue (float): How much to jitter hue. hue_factor is chosen uniformly from
+             [-hue, hue]. Should be >= 0 and <= 0.5.
+     """
+ 
+     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
+         # Each argument is either a single number or a (min, max) pair
+         assert isinstance(brightness, numbers.Number) or (
+             isinstance(brightness, collections.abc.Iterable) and len(brightness) == 2
+         )
+         assert isinstance(contrast, numbers.Number) or (
+             isinstance(contrast, collections.abc.Iterable) and len(contrast) == 2
+         )
+         assert isinstance(saturation, numbers.Number) or (
+             isinstance(saturation, collections.abc.Iterable) and len(saturation) == 2
+         )
+         assert isinstance(hue, numbers.Number) or (
+             isinstance(hue, collections.abc.Iterable) and len(hue) == 2
+         )
+ 
+         self.brightness = brightness
+         self.contrast = contrast
+         self.saturation = saturation
+         self.hue = hue
+ 
+     @staticmethod
+     def get_params(brightness, contrast, saturation, hue):
+         """Get a randomized transform to be applied on image.
+         Arguments are the same as those of __init__.
+         Returns:
+             Transform which randomly adjusts brightness, contrast and
+             saturation in a random order.
+         """
+         transforms = []
+ 
+         if isinstance(brightness, numbers.Number):
+             if brightness > 0:
+                 brightness_factor = random.uniform(
+                     max(0, 1 - brightness), 1 + brightness
+                 )
+                 transforms.append(
+                     Lambda(lambda img: F.adjust_brightness(img, brightness_factor))
+                 )
+ 
+             if contrast > 0:
+                 contrast_factor = random.uniform(max(0, 1 - contrast), 1 + contrast)
+                 transforms.append(
+                     Lambda(lambda img: F.adjust_contrast(img, contrast_factor))
+                 )
+ 
+             if saturation > 0:
+                 saturation_factor = random.uniform(
+                     max(0, 1 - saturation), 1 + saturation
+                 )
+                 transforms.append(
+                     Lambda(lambda img: F.adjust_saturation(img, saturation_factor))
+                 )
+ 
+             if hue > 0:
+                 hue_factor = random.uniform(-hue, hue)
+                 transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
+         else:
+             # (min, max) ranges: sample each factor uniformly from its range
+             if brightness[0] > 0 and brightness[1] > 0:
+                 brightness_factor = random.uniform(brightness[0], brightness[1])
+                 transforms.append(
+                     Lambda(lambda img: F.adjust_brightness(img, brightness_factor))
+                 )
+ 
+             if contrast[0] > 0 and contrast[1] > 0:
+                 contrast_factor = random.uniform(contrast[0], contrast[1])
+                 transforms.append(
+                     Lambda(lambda img: F.adjust_contrast(img, contrast_factor))
+                 )
+ 
+             if saturation[0] > 0 and saturation[1] > 0:
+                 saturation_factor = random.uniform(saturation[0], saturation[1])
+                 transforms.append(
+                     Lambda(lambda img: F.adjust_saturation(img, saturation_factor))
+                 )
+ 
+             if hue[0] > 0 and hue[1] > 0:
+                 hue_factor = random.uniform(hue[0], hue[1])
+                 transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))
+ 
+         random.shuffle(transforms)
+         transform = ComposeSingle(transforms)
+ 
+         return transform
+ 
+     def __call__(self, img, target):
+         """
+         Args:
+             img (np.ndarray): Input image.
+         Returns:
+             np.ndarray: Color jittered image.
+         """
+         transform = self.get_params(
+             self.brightness, self.contrast, self.saturation, self.hue
+         )
+         return transform(img), target
+ 
+     def __repr__(self):
+         format_string = self.__class__.__name__ + "("
+         format_string += "brightness={0}".format(self.brightness)
+         format_string += ", contrast={0}".format(self.contrast)
+         format_string += ", saturation={0}".format(self.saturation)
+         format_string += ", hue={0})".format(self.hue)
+         return format_string
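
Because get_params shuffles the per-channel adjustments and wraps them in a ComposeSingle, every call applies brightness, contrast, saturation and hue in a fresh random order. A usage sketch (the uint8 HWC image layout is an assumption about what functional.py expects):

    import numpy as np
    from docling_ibm_models.tableformer.data_management.transforms import ColorJitter

    jitter = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    img = np.random.randint(0, 256, size=(448, 448, 3), dtype=np.uint8)
    jittered, target = jitter(img, target=None)  # target passes through unchanged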
+ 
+ 
+ class Normalize(object):
+     """Normalize a tensor image with mean and standard deviation.
+     Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels, this
+     transform will normalize each channel of the input ``torch.*Tensor``, i.e.
+     ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
+     Args:
+         mean (sequence): Sequence of means for each channel.
+         std (sequence): Sequence of standard deviations for each channel.
+     """
+ 
+     def __init__(self, mean, std):
+         self.mean = mean
+         self.std = std
+ 
+     def __call__(self, tensor, target=None):
+         """
+         Args:
+             tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
+         Returns:
+             Tensor: Normalized Tensor image.
+         """
+         return F.normalize(tensor, self.mean, self.std), target
+ 
+     def __repr__(self):
+         return self.__class__.__name__ + "(mean={0}, std={1})".format(
+             self.mean, self.std
+         )
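
Normalize defers to F.normalize, which per the docstring applies the usual per-channel formula. A worked instance: a channel value of 0.8 with mean 0.5 and std 0.25 maps to (0.8 - 0.5) / 0.25 = 1.2, i.e.

    import torch

    t = torch.full((1, 2, 2), 0.8)             # one channel, constant value
    mean = torch.tensor([0.5]).view(-1, 1, 1)
    std = torch.tensor([0.25]).view(-1, 1, 1)
    assert torch.allclose((t - mean) / std, torch.full((1, 2, 2), 1.2))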
+ 
+ 
+ class NoTransformation(object):
+     """Do nothing"""
+ 
+     def __call__(self, img, target):
+         return img, target
+ 
+ 
+ class Compose(object):
+     """Composes several transforms together.
+     Each transform must accept and return an ``(img, target)`` pair.
+     Args:
+         transforms (list of ``Transform`` objects): list of transforms to compose.
+     Example:
+         >>> Compose([
+         >>>     RandomPad(10),
+         >>>     Resize((448, 448)),
+         >>> ])
+     """
+ 
+     def __init__(self, transforms):
+         self.transforms = transforms
+ 
+     def __call__(self, img, target):
+         for t in self.transforms:
+             img, target = t(img, target)
+         return img, target
+ 
+     def __repr__(self):
+         format_string = self.__class__.__name__ + "("
+         for t in self.transforms:
+             format_string += "\n"
+             format_string += "    {0}".format(t)
+         format_string += "\n)"
+         return format_string
+ 
+ 
+ class ComposeSingle(object):
+     """Composes several transforms together.
+     Unlike ``Compose``, each transform takes and returns only an image.
+     Args:
+         transforms (list of ``Transform`` objects): list of transforms to compose.
+     Example:
+         >>> ComposeSingle([
+         >>>     Lambda(lambda img: F.adjust_brightness(img, 1.2)),
+         >>> ])
+     """
+ 
+     def __init__(self, transforms):
+         self.transforms = transforms
+ 
+     def __call__(self, img):
+         for t in self.transforms:
+             img = t(img)
+         return img
+ 
+     def __repr__(self):
+         format_string = self.__class__.__name__ + "("
+         for t in self.transforms:
+             format_string += "\n"
+             format_string += "    {0}".format(t)
+         format_string += "\n)"
+         return format_string
+ 
+ 
+ class Resize(object):
+     """Resize the input image to the given size.
+     Args:
+         size (sequence): Desired output size as (width, height) in pixels.
+             The bounding boxes of the target are rescaled by the same
+             factors as the image.
+         interpolation (str, optional): Desired interpolation. Default is
+             ``BILINEAR``.
+     """
+ 
+     def __init__(self, size, interpolation="BILINEAR"):
+         self.size = size
+         self.interpolation = interpolation
+ 
+     def __call__(self, img, target=None):
+         """
+         Args:
+             img (np.ndarray): Image to be scaled.
+         Returns:
+             np.ndarray: Rescaled image.
+         """
+         # Scale factors for the bboxes (in pixels); guard against empty images
+         x_scale = 0
+         y_scale = 0
+         if img.shape[1] > 0:
+             x_scale = self.size[0] / img.shape[1]
+         if img.shape[0] > 0:
+             y_scale = self.size[1] / img.shape[0]
+ 
+         # Rescale the bboxes in place by the same factors as the image
+         if target is not None and target["boxes"] is not None:
+             target["boxes"][:, 0] = x_scale * target["boxes"][:, 0]
+             target["boxes"][:, 1] = y_scale * target["boxes"][:, 1]
+             target["boxes"][:, 2] = x_scale * target["boxes"][:, 2]
+             target["boxes"][:, 3] = y_scale * target["boxes"][:, 3]
+         return F.resize(img, self.size, self.interpolation), target
+ 
+     def __repr__(self):
+         interpolate_str = self.interpolation
+         return self.__class__.__name__ + "(size={0}, interpolation={1})".format(
+             self.size, interpolate_str
+         )
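
Taken together, the pair-wise transforms compose into a preprocessing pipeline. An end-to-end sketch (image size, target layout, and the behavior of functional.py are assumptions):

    import numpy as np
    from docling_ibm_models.tableformer.data_management.transforms import (
        ColorJitter,
        Compose,
        RandomPad,
        Resize,
    )

    pipeline = Compose([
        RandomPad(max_pad=10),
        ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
        Resize((448, 448)),
    ])
    img = np.random.randint(0, 256, size=(600, 800, 3), dtype=np.uint8)
    target = {"boxes": np.array([[10.0, 20.0, 110.0, 220.0]])}
    img_out, target_out = pipeline(img, target)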
File without changes
@@ -0,0 +1,279 @@
+ #
+ # Copyright IBM Corp. 2024 - 2024
+ # SPDX-License-Identifier: MIT
+ #
+ import glob
+ import logging
+ import os
+ import time
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ 
+ import torch
+ 
+ import docling_ibm_models.tableformer.settings as s
+ 
+ LOG_LEVEL = logging.INFO
+ # LOG_LEVEL = logging.DEBUG
+ 
+ 
+ class BaseModel(ABC):
+     r"""
+     BaseModel provides some common functionality for all models:
+     - Saves checkpoint files for each epoch
+     - Loads the model from the best available checkpoint
+     - Saves the repository branch and commit
+     """
+ 
+     def __init__(self, config, init_data, device):
+         r"""
+         Inputs:
+             config: The configuration file
+             init_data: Dictionary with initialization data. This dictionary can
+                 be used to pass any kind of initialization data to the models.
+             device: The device used to move the tensors of the model
+         """
+         super(BaseModel, self).__init__()
+ 
+         # Set config and device
+         self._config = config
+         self._init_data = init_data
+         self._device = device
+ 
+         self._save_dir = config["model"]["save_dir"]
+         self._load_checkpoint = None
+         if "load_checkpoint" in config["model"]:
+             self._load_checkpoint = config["model"]["load_checkpoint"]
+ 
+         self._branch_name = "dev/next"
+         self._commit_sha = "1"
+ 
+         # Keep a dictionary with the starting times per epoch.
+         # NOTICE: Epochs start from 0
+         self._epoch_start_ts = {0: time.time()}
+ 
+     def _log(self):
+         # Set up a custom logger
+         return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
+ 
+     @abstractmethod
+     def predict(self, img, max_steps, beam_size, return_attention):
+         pass
+ 
+     def count_parameters(self):
+         r"""Counts the number of trainable parameters of this model
+ 
+         Output:
+             num_parameters: number of trainable parameters
+         """
+         num_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         return num_parameters
+ 
+     def get_code_version(self):
+         r"""Gets the source control version of this model code
+ 
+         Returns
+         -------
+         branch_name : str
+             The name of the Git branch of this model code
+         commit_sha : str
+             The unique identifier of the Git commit of this model code
+         """
+         return self._branch_name, self._commit_sha
+ 
+     def get_save_directory(self):
+         r"""
+         Return the save directory
+         """
+         return self._save_dir
+ 
+     def is_saved(self):
+         r"""
+         This method returns True if both conditions are met:
+         1. There is a checkpoint file for the model.
+         2. The checkpoint file corresponds to the last training epoch set in
+            the configuration file.
+         """
+         # Get the saved_model
+         saved_model, _ = self._load_best_checkpoint()
+         if saved_model is None:
+             return False
+ 
+         epochs = self._config["train"]["epochs"]
+         self._log().debug(
+             "Best epoch in saved model: {}; Number of epochs in config: {}".format(
+                 saved_model["epoch"], epochs
+             )
+         )
+         if epochs == saved_model["epoch"] + 1:
+             return True
+         return False
+ 
+     def save(self, epoch=None, optimizers=None, losses=None, model_parameters=None):
+         r"""
+         Save the model data to disk as a pickle file.
+ 
+         Parameters
+         ----------
+         epoch: Training epoch
+         optimizers: Dictionary with the optimizers. The key specifies what the
+                     optimizer is used for. The 'state_dict' of each optimizer
+                     will be saved in the checkpoint file.
+         losses: Dictionary with the losses. The key specifies what the loss is
+                 used for. Each value is a list.
+         model_parameters: Dictionary with model-specific parameters that we
+                           need to save in the checkpoint file.
+ 
+         Returns
+         -------
+         True if success, False otherwise
+         """
+         # Get the checkpoint filename
+         c_filename = self._build_checkpoint_filename(epoch)
+         self._log().debug("Trying to save checkpoint file: {}".format(c_filename))
+ 
+         # Prepare a dictionary with all the data we want to save
+         optimizers_state_dict = None
+         if optimizers is not None:
+             optimizers_state_dict = {k: v.state_dict() for k, v in optimizers.items()}
+ 
+         model_data = {
+             "model_state_dict": self.state_dict(),
+             "epoch": epoch,
+             "optimizers": optimizers_state_dict,
+             "losses": losses,
+             "model_parameters": model_parameters,
+         }
+ 
+         # Add the processing time per epoch
+         now = time.time()
+         self._epoch_start_ts[epoch + 1] = now
+         if epoch in self._epoch_start_ts:
+             dt = now - self._epoch_start_ts[epoch]
+             model_data["epoch_start_ts"] = self._epoch_start_ts[epoch]
+             model_data["epoch_dt"] = dt
+ 
+         # Create the save directory
+         Path(self._save_dir).mkdir(parents=True, exist_ok=True)
+ 
+         # Save the model
+         torch.save(model_data, c_filename)
+ 
+         # Return True only if the file is actually present on disk
+         if not os.path.isfile(c_filename):
+             self._log().error("Cannot find the saved file: " + c_filename)
+             return False
+ 
+         # Store the code branch name and commit
+         version_file = os.path.join(self._save_dir, "_version")
+         with open(version_file, "w") as text_file:
+             print("Model is using code [commit:branch]", file=text_file)
+             print("{}:{}".format(self._commit_sha, self._branch_name), file=text_file)
+ 
+         return True
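
The optimizers and losses dictionaries are keyed by role, so a training loop built on a BaseModel subclass might checkpoint like this (a sketch; MyTableModel and the config values are hypothetical):

    import torch

    config = {
        "model": {"save_dir": "./checkpoints", "type": "TableModel04", "name": "rs"},
        "dataset": {"name": "PubTabNet"},
        "train": {"epochs": 10},
    }
    model = MyTableModel(config, init_data={}, device="cpu")  # BaseModel subclass
    optimizers = {"main": torch.optim.Adam(model.parameters(), lr=1e-4)}
    losses = {"total": []}
    for epoch in range(config["train"]["epochs"]):
        # ... train one epoch, appending the epoch loss to losses["total"] ...
        model.save(epoch=epoch, optimizers=optimizers, losses=losses)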
+ 
+     def load(self, optimizers=None):
+         r"""
+         Load the model data from the disk.
+         The method will iterate over all *.check files and try to load the one
+         from the highest epoch.
+ 
+         Input:
+             optimizers: Dictionary with optimizers. If it is not null, the keys
+                         will be used to associate the corresponding state_dicts
+                         from the checkpoint file and update the internal states
+                         of the provided optimizers.
+ 
+         Output:
+             - success: True/False
+             - epoch: Loaded epoch, or -1 if there are no checkpoint files
+             - optimizers: Dictionary with loaded optimizers, or an empty
+                           dictionary if there is no checkpoint file
+             - losses: Dictionary with loaded losses, or an empty dictionary if
+                       there is no checkpoint file
+             - model_parameters: Dictionary with the model parameters, or an
+                                 empty dictionary if there are no checkpoint
+                                 files
+         """
+         # Get the saved_model
+         saved_model, _ = self._load_best_checkpoint()
+ 
+         # Restore the model
+         if saved_model is None:
+             self._log().debug("No saved model checkpoint found")
+             return False, -1, optimizers, {}, {}
+ 
+         self._log().debug("Loading model from checkpoint file")
+         self.load_state_dict(saved_model["model_state_dict"])
+ 
+         epoch = 0
+         if "epoch" in saved_model:
+             epoch = saved_model["epoch"]
+         losses = {}
+         if "losses" in saved_model:
+             losses = saved_model["losses"]
+         model_parameters = saved_model["model_parameters"]
+ 
+         if optimizers is not None:
+             for key, optimizer_state_dict in saved_model["optimizers"].items():
+                 optimizers[key].load_state_dict(optimizer_state_dict)
+ 
+         # Reset the start_ts of the next epoch
+         self._epoch_start_ts[epoch + 1] = time.time()
+ 
+         return True, epoch, optimizers, losses, model_parameters
+ 
+     def _load_best_checkpoint(self):
+         r"""
+         If a "load_checkpoint" file has been provided, load this one.
+         Otherwise use the "save_dir" and load the checkpoint with the most
+         advanced epoch.
+ 
+         Returns
+         -------
+         saved_model : dictionary
+             Checkpoint file contents generated by torch.load, or None
+         checkpoint_file : string
+             Filename of the loaded checkpoint, or None
+         """
+         checkpoint_files = []
+         # If a "load_checkpoint" file is provided, try to load it
+         if self._load_checkpoint is not None:
+             if not os.path.exists(self._load_checkpoint):
+                 self._log().error(
+                     "Cannot load the checkpoint: {}".format(self._load_checkpoint)
+                 )
+                 return None, None
+             checkpoint_files.append(self._load_checkpoint)
+         else:
+             # Iterate over all .check files from the directory in reverse
+             # alphabetical order. This gets the biggest epoch first.
+             checkpoint_files = glob.glob(os.path.join(self._save_dir, "*.check"))
+             checkpoint_files.sort(reverse=True)
+ 
+         for checkpoint_file in checkpoint_files:
+             try:
+                 # Try to load the file
+                 self._log().info(
+                     "Loading model checkpoint file: {}".format(checkpoint_file)
+                 )
+                 saved_model = torch.load(checkpoint_file, map_location=self._device)
+                 return saved_model, checkpoint_file
+             except RuntimeError:
+                 self._log().error("Cannot load file: {}".format(checkpoint_file))
+ 
+         return None, None
+ 
+     def _build_checkpoint_filename(self, epoch):
+         r"""
+         Construct the full path for the filename of this checkpoint
+         """
+         dataset_name = self._config["dataset"]["name"]
+         model_type = self._config["model"]["type"]
+         model_name = self._config["model"]["name"]
+         filename = "{}_{}_{}_{:03}.check".format(
+             model_type, model_name, dataset_name, epoch
+         )
+         c_filename = os.path.join(self._save_dir, filename)
+         return c_filename
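
The zero-padded epoch in the checkpoint name is what lets the reverse alphabetical sort in _load_best_checkpoint find the most recent epoch first. With illustrative config values (assumptions, not read from the package):

    # model.type = "TableModel04", model.name = "rs", dataset.name = "PubTabNet"
    "{}_{}_{}_{:03}.check".format("TableModel04", "rs", "PubTabNet", 7)
    # -> 'TableModel04_rs_PubTabNet_007.check'
    # '..._010.check' sorts after '..._009.check', so epoch 10 is tried first.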