docling-ibm-models 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
- docling_ibm_models/tableformer/__init__.py +0 -0
- docling_ibm_models/tableformer/common.py +200 -0
- docling_ibm_models/tableformer/data_management/__init__.py +0 -0
- docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
- docling_ibm_models/tableformer/data_management/functional.py +574 -0
- docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
- docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
- docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
- docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
- docling_ibm_models/tableformer/data_management/transforms.py +396 -0
- docling_ibm_models/tableformer/models/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/base_model.py +279 -0
- docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
- docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
- docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
- docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
- docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
- docling_ibm_models/tableformer/otsl.py +541 -0
- docling_ibm_models/tableformer/settings.py +90 -0
- docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
- docling_ibm_models/tableformer/test_prepare_image.py +99 -0
- docling_ibm_models/tableformer/utils/__init__.py +0 -0
- docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
- docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
- docling_ibm_models/tableformer/utils/utils.py +376 -0
- docling_ibm_models/tableformer/utils/variance.py +175 -0
- docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
- docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
- docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
- docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,396 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
from __future__ import division
|
6
|
+
|
7
|
+
import collections
|
8
|
+
import numbers
|
9
|
+
import random
|
10
|
+
|
11
|
+
import torch
|
12
|
+
|
13
|
+
from docling_ibm_models.tableformer.data_management import functional as F
|
14
|
+
|
15
|
+
|
16
|
+
def box_cxcywh_to_xyxy(x):
    """Convert boxes from center form to corner form.

    The last dimension of ``x`` holds ``(cx, cy, w, h)``; the returned tensor
    has the same layout with the last dimension holding ``(x0, y0, x1, y1)``.
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = [center_x - half_w, center_y - half_h, center_x + half_w, center_y + half_h]
    return torch.stack(corners, dim=-1)
|
20
|
+
|
21
|
+
|
22
|
+
def box_xyxy_to_cxcywh(x):
    """Convert boxes from corner form to center form.

    The last dimension of ``x`` holds ``(x0, y0, x1, y1)``; the returned tensor
    has the same layout with the last dimension holding ``(cx, cy, w, h)``.
    """
    left, top, right, bottom = x.unbind(-1)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    return torch.stack([center_x, center_y, right - left, bottom - top], dim=-1)
|
26
|
+
|
27
|
+
|
28
|
+
class Lambda(object):
    """Wrap an arbitrary single-argument callable so it can be used as a transform.

    Note: PyTorch's DataLoader multiprocessing may be unable to pickle lambda
    functions on Windows, so prefer named functions there.

    Args:
        lambd (function): Callable applied to the image on every call.
    """

    def __init__(self, lambd):
        self.lambd = lambd

    def __call__(self, img):
        fn = self.lambd
        return fn(img)

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)
|
47
|
+
|
48
|
+
|
49
|
+
class RandomTransforms(object):
    """Abstract base for transforms that randomly pick from a list of transforms.

    Subclasses must implement ``__call__``.

    Args:
        transforms (list or tuple): Candidate transformations.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, (list, tuple))
        self.transforms = transforms

    def __call__(self, *args, **kwargs):
        raise NotImplementedError()

    def __repr__(self):
        # One line for the class name, one line per wrapped transform, then ")".
        lines = [self.__class__.__name__ + "("]
        lines.extend(" {0}".format(t) for t in self.transforms)
        return "\n".join(lines) + "\n)"
|
69
|
+
|
70
|
+
|
71
|
+
class RandomChoice(RandomTransforms):
    """Pick one transform uniformly at random from the list and apply it to (img, target)."""

    def __call__(self, img, target):
        chosen = random.choice(self.transforms)
        result = chosen(img, target)
        return result
|
77
|
+
|
78
|
+
|
79
|
+
class RandomCrop(object):
    """Crop a fixed margin from every side of an image and shift its boxes.

    NOTE(review): despite the name, ``__call__`` uses no randomness; it always
    removes exactly ``margin_crop`` pixels from each side. ``size`` is stored
    but not read by ``__call__``.

    Args:
        size (sequence): Stored as ``self.size``.
        margin_crop (sequence): ``(margin_w, margin_h)``, the pixels removed
            from the left/right and top/bottom edges respectively.
    """

    def __init__(self, size, margin_crop):
        self.size = list(size)
        # margin_crop: w, h
        self.margin_crop = list(margin_crop)

    def __call__(self, img, target):
        """Crop ``img`` and translate ``target["boxes"]`` into the cropped frame.

        Args:
            img (np.ndarray): Image to be cropped. ``img.shape[0]`` is used as
                the height and ``img.shape[1]`` as the width.
            target (dict or None): Annotations. If it is not None and its
                ``"boxes"`` entry is not None, box columns 0 and 2 are shifted
                by the horizontal margin and columns 1 and 3 by the vertical one.

        Returns:
            tuple: ``(cropped_image, target_)``. If the image has a zero-sized
            dimension, or there are no boxes, the corresponding input is
            returned unchanged.
        """
        image_height, image_width = img.shape[0], img.shape[1]
        if image_width <= 0 or image_height <= 0:
            # Nothing to crop: hand back the inputs untouched.
            return img, target

        margin_w, margin_h = self.margin_crop[0], self.margin_crop[1]
        # Presumably F.crop(img, top, left, height, width); the original notes
        # called these "upper" and "left" coordinates. Confirm in functional.py.
        cropped_image = F.crop(
            img,
            margin_h,
            margin_w,
            image_height - (margin_h * 2),
            image_width - (margin_w * 2),
        )

        if target is None or target["boxes"] is None:
            return cropped_image, target

        # NOTE(review): dict.copy() is shallow, so ``boxes`` is the same object
        # as the caller's target["boxes"] and is shifted in place.
        target_ = target.copy()
        boxes = target_["boxes"]
        boxes[:, 0] = boxes[:, 0] - margin_w
        boxes[:, 1] = boxes[:, 1] - margin_h
        boxes[:, 2] = boxes[:, 2] - margin_w
        boxes[:, 3] = boxes[:, 3] - margin_h
        return cropped_image, target_
|
112
|
+
|
113
|
+
|
114
|
+
class RandomPad(object):
    """Pad an image with white borders of random size and shift its boxes.

    Four independent paddings are drawn from ``random.randint(0, max_pad)``
    (inclusive). Only the first two offsets move the boxes: the first shifts
    columns 0/2, the second shifts columns 1/3.

    Args:
        max_pad (int): Largest padding, in pixels, drawn for any side.
    """

    def __init__(self, max_pad):
        self.max_pad = max_pad

    def __call__(self, img, target):
        upper = self.max_pad
        # Drawn in a fixed order; the tuple is passed to F.pad unchanged.
        dx0, dy0, dx1, dy1 = (random.randint(0, upper) for _ in range(4))

        source = img.copy()
        padded_image = F.pad(source, (dx0, dy0, dx1, dy1), fill=(255, 255, 255))

        # Shallow copy: the boxes array is shared with the caller's target.
        target_ = target.copy()
        if target["boxes"] is not None:
            boxes = target_["boxes"]
            for column, offset in ((0, dx0), (1, dy0), (2, dx0), (3, dy0)):
                boxes[:, column] = boxes[:, column] + offset
        return padded_image, target_
|
132
|
+
|
133
|
+
|
134
|
+
class ColorJitter(object):
    """Randomly change the brightness, contrast, saturation and hue of an image.

    Each argument is either a single number or a ``(low, high)`` pair. The two
    forms may be mixed freely across arguments. A value of 0 (the default)
    disables the corresponding adjustment.

    Args:
        brightness (number or pair): How much to jitter brightness. For a
            number ``b > 0``, brightness_factor is chosen uniformly from
            ``[max(0, 1 - b), 1 + b]``. For a pair, it is chosen uniformly from
            ``[low, high]``, and only when both ends are > 0.
        contrast (number or pair): Same scheme as ``brightness``, for
            contrast_factor.
        saturation (number or pair): Same scheme as ``brightness``, for
            saturation_factor.
        hue (number or pair): For a number ``h > 0``, hue_factor is chosen
            uniformly from ``[-h, h]``; ``h`` should be >= 0 and <= 0.5. For a
            pair, it is chosen from ``[low, high]``, and only when both ends
            are > 0.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        # Local import: the bare ``collections.Iterable`` alias was removed in
        # Python 3.10, so use the ``collections.abc`` location explicitly.
        from collections.abc import Iterable

        for value in (brightness, contrast, saturation, hue):
            # Any real number (including the integer defaults) or a length-2
            # iterable giving an explicit (low, high) range.
            assert isinstance(value, numbers.Number) or (
                isinstance(value, Iterable) and len(value) == 2
            ), "jitter arguments must be a number or a (low, high) pair"

        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def _sample_factor(value, symmetric=False):
        """Draw a random adjustment factor for one jitter argument.

        Args:
            value: A number or a ``(low, high)`` pair.
            symmetric (bool): For a number ``v``, sample from ``[-v, v]`` when
                true (used for hue), otherwise from ``[max(0, 1 - v), 1 + v]``.

        Returns:
            The sampled factor, or ``None`` when ``value`` disables the
            adjustment (a number that is not > 0, or a pair with an end <= 0).
        """
        if isinstance(value, numbers.Number):
            if not value > 0:
                return None
            if symmetric:
                return random.uniform(-value, value)
            return random.uniform(max(0, 1 - value), 1 + value)

        low, high = value[0], value[1]
        if low > 0 and high > 0:
            return random.uniform(low, high)
        return None

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.

        Arguments are same as that of __init__.

        Returns:
            Transform which randomly adjusts brightness, contrast, saturation
            and hue in a random order.
        """
        transforms = []

        # Factors are sampled in a fixed order (brightness, contrast,
        # saturation, hue), and the resulting transforms are then shuffled.
        brightness_factor = ColorJitter._sample_factor(brightness)
        if brightness_factor is not None:
            transforms.append(
                Lambda(lambda img: F.adjust_brightness(img, brightness_factor))
            )

        contrast_factor = ColorJitter._sample_factor(contrast)
        if contrast_factor is not None:
            transforms.append(
                Lambda(lambda img: F.adjust_contrast(img, contrast_factor))
            )

        saturation_factor = ColorJitter._sample_factor(saturation)
        if saturation_factor is not None:
            transforms.append(
                Lambda(lambda img: F.adjust_saturation(img, saturation_factor))
            )

        hue_factor = ColorJitter._sample_factor(hue, symmetric=True)
        if hue_factor is not None:
            transforms.append(Lambda(lambda img: F.adjust_hue(img, hue_factor)))

        random.shuffle(transforms)
        transform = ComposeSingle(transforms)

        return transform

    def __call__(self, img, target):
        """
        Args:
            img (np.ndarray): Input image.
            target: Passed through unchanged.
        Returns:
            tuple: (color jittered image, target).
        """
        transform = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue
        )
        return transform(img), target

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        format_string += "brightness={0}".format(self.brightness)
        format_string += ", contrast={0}".format(self.contrast)
        format_string += ", saturation={0}".format(self.saturation)
        format_string += ", hue={0})".format(self.hue)
        return format_string
|
256
|
+
|
257
|
+
|
258
|
+
class Normalize(object):
    """Normalize a tensor image channel-wise with a mean and standard deviation.

    For ``n`` channels, given mean ``(M1,...,Mn)`` and std ``(S1,..,Sn)``, each
    channel becomes ``(input[channel] - mean[channel]) / std[channel]``.
    The computation itself is delegated to ``F.normalize``.

    Args:
        mean (sequence): Per-channel means.
        std (sequence): Per-channel standard deviations.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor, target=None):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
            target: Passed through unchanged.
        Returns:
            tuple: (normalized tensor image, target).
        """
        normalized = F.normalize(tensor, self.mean, self.std)
        return normalized, target

    def __repr__(self):
        return "{0}(mean={1}, std={2})".format(
            self.__class__.__name__, self.mean, self.std
        )
|
285
|
+
|
286
|
+
|
287
|
+
class NoTransformation(object):
    """Identity transform: return the image and target exactly as given."""

    def __call__(self, img, target):
        result = (img, target)
        return result
|
292
|
+
|
293
|
+
|
294
|
+
class Compose(object):
    """Chain several ``(img, target)`` transforms, applied in list order.

    Every transform receives the image and target produced by the previous one
    and must return an ``(img, target)`` pair.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, target):
        current_img, current_target = img, target
        for step in self.transforms:
            current_img, current_target = step(current_img, current_target)
        return current_img, current_target

    def __repr__(self):
        # One line for the class name, one line per wrapped transform, then ")".
        lines = [self.__class__.__name__ + "("]
        lines.extend(" {0}".format(t) for t in self.transforms)
        return "\n".join(lines) + "\n)"
|
320
|
+
|
321
|
+
|
322
|
+
class ComposeSingle(object):
    """Chain several single-argument (image-only) transforms, applied in list order.

    Unlike ``Compose``, each transform takes and returns only the image.

    Args:
        transforms (list of callables): list of image transforms to compose.

    Example:
        >>> transforms.ComposeSingle([
        >>>     transforms.Lambda(lambda img: img),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        out = img
        for step in self.transforms:
            out = step(out)
        return out

    def __repr__(self):
        # One line for the class name, one line per wrapped transform, then ")".
        lines = [self.__class__.__name__ + "("]
        lines.extend(" {0}".format(t) for t in self.transforms)
        return "\n".join(lines) + "\n)"
|
348
|
+
|
349
|
+
|
350
|
+
class Resize(object):
    """Resize an image and rescale the bounding boxes in its target to match.

    Args:
        size (sequence): Output size, passed unchanged to ``F.resize``. For box
            rescaling, ``size[0]`` is divided by ``img.shape[1]`` (x scale) and
            ``size[1]`` by ``img.shape[0]`` (y scale). For an image laid out as
            (height, width, ...), ``size`` is therefore read as
            ``(width, height)``. Because ``size`` is indexed, a plain int is
            not supported.
        interpolation (str, optional): Interpolation mode name passed to
            ``F.resize``. Default is ``"BILINEAR"``.
    """

    def __init__(self, size, interpolation="BILINEAR"):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img, target=None):
        """
        Args:
            img (np.ndarray): Image to be scaled.
            target (dict, optional): Annotations. If its ``"boxes"`` entry is
                not None, box columns 0 and 2 are multiplied by the x scale and
                columns 1 and 3 by the y scale.
        Returns:
            tuple: (rescaled image, target with rescaled boxes). When
            ``target`` is None or has no boxes, it is returned as given.
        """
        height, width = img.shape[0], img.shape[1]
        # A zero-sized dimension leaves its scale at 0 instead of dividing by 0.
        x_scale = self.size[0] / width if width > 0 else 0
        y_scale = self.size[1] / height if height > 0 else 0

        target_ = target
        if target is not None and target["boxes"] is not None:
            # NOTE(review): dict.copy() is shallow, so the boxes array is shared
            # with the caller's target and is rescaled in place.
            target_ = target.copy()
            boxes = target_["boxes"]
            boxes[:, 0] = x_scale * boxes[:, 0]
            boxes[:, 1] = y_scale * boxes[:, 1]
            boxes[:, 2] = x_scale * boxes[:, 2]
            boxes[:, 3] = y_scale * boxes[:, 3]
        # Return the rescaled copy (previously built but discarded).
        return F.resize(img, self.size, self.interpolation), target_

    def __repr__(self):
        interpolate_str = self.interpolation
        return self.__class__.__name__ + "(size={0}, interpolation={1})".format(
            self.size, interpolate_str
        )
|
File without changes
|
File without changes
|
@@ -0,0 +1,279 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import glob
|
6
|
+
import logging
|
7
|
+
import os
|
8
|
+
import time
|
9
|
+
from abc import ABC, abstractmethod
|
10
|
+
from pathlib import Path
|
11
|
+
|
12
|
+
import torch
|
13
|
+
|
14
|
+
import docling_ibm_models.tableformer.settings as s
|
15
|
+
|
16
|
+
# Log level passed to s.get_custom_logger() by BaseModel._log().
# Switch to the commented-out DEBUG level to see the checkpoint save/load debug messages.
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
|
18
|
+
|
19
|
+
|
20
|
+
class BaseModel(ABC):
    r"""
    BaseModel provides some common functionality for all models:
    - Saves checkpoint files for each epoch
    - Loads the model from the best available checkpoint
    - Save repository branch and commit

    NOTE(review): ``parameters()``, ``state_dict()`` and ``load_state_dict()``
    are called below but not defined here. Concrete models are presumably also
    ``torch.nn.Module`` subclasses; confirm.
    """

    def __init__(self, config, init_data, device):
        r"""
        Inputs:
            config: The configuration dictionary. ``config["model"]["save_dir"]``
                is required and ``config["model"]["load_checkpoint"]`` is optional.
            init_data: Dictionary with initialization data. This dictionary can be used to pass any
                kind of initialization data for the models
            device: The device used to move the tensors of the model. It is also
                used as ``map_location`` when loading checkpoints.
        """
        super(BaseModel, self).__init__()

        # Set config and device
        self._config = config
        self._init_data = init_data

        self._device = device

        self._save_dir = config["model"]["save_dir"]
        # Optional explicit checkpoint file; when set, the save_dir scan is skipped
        self._load_checkpoint = None
        if "load_checkpoint" in config["model"]:
            self._load_checkpoint = config["model"]["load_checkpoint"]

        # Hard-coded values written to the "_version" file by save()
        self._branch_name = "dev/next"
        self._commit_sha = "1"

        # Keep a dictionary with the starting times per epoch.
        # NOTICE: Epochs start from 0
        self._epoch_start_ts = {0: time.time()}

    def _log(self):
        # Setup a custom logger
        return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)

    @abstractmethod
    def predict(self, img, max_steps, beam_size, return_attention):
        pass

    def count_parameters(self):
        r"""Counts the number of trainable parameters of this model

        Output:
            num_parameters: number of trainable parameters
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_code_version(self):
        r"""Gets the source control version of this model code

        Returns
        -------
        branch_name : str
            The name of the Git branch of this model code
        commit_sha : str
            The unique identifier of the Git commit of this model code
        """
        return self._branch_name, self._commit_sha

    def get_save_directory(self):
        r"""
        Return the save directory
        """
        return self._save_dir

    def is_saved(self):
        r"""
        This method returns True if both conditions are met:
        1. There is a checkpoint file for the model.
        2. The checkpoint file corresponds to the last training epoch set in the configuration file.
        """
        # Get the saved_model
        saved_model, _ = self._load_best_checkpoint()

        if saved_model is None:
            return False

        epochs = self._config["train"]["epochs"]
        self._log().debug(
            "Best epoch in saved model: {}; Number of epochs in config: {}".format(
                saved_model["epoch"], epochs
            )
        )
        # Epochs are 0-based, so the final epoch is ``epochs - 1``
        return epochs == saved_model["epoch"] + 1

    def save(self, epoch=None, optimizers=None, losses=None, model_parameters=None):
        r"""
        Save the model data to the disk with ``torch.save``.

        Parameters
        ----------
        epoch: Training epoch. Despite the ``None`` default, an int is required:
            it is formatted into the checkpoint filename and incremented below.
        optimizers: Dictionary with the optimizers. The key specifies what the optimizer is
            used for. The 'state_dict' of each optimizer will be saved in the
            checkpoint file.
        losses: Dictionary with the losses. The key specifies what the loss is used for. Each
            value is a list
        model_parameters: Dictionary with model specific parameters that we need to save in the
            checkpoint file.

        Returns
        -------
        True if success, False otherwise
        """
        # Get the checkpoint_filename
        c_filename = self._build_checkpoint_filename(epoch)
        self._log().debug("Trying to save checkpoint file: {}".format(c_filename))

        # Prepare a dictionary with all data we want to save
        optimizers_state_dict = None
        if optimizers is not None:
            optimizers_state_dict = {k: v.state_dict() for k, v in optimizers.items()}

        model_data = {
            "model_state_dict": self.state_dict(),
            "epoch": epoch,
            "optimizers": optimizers_state_dict,
            "losses": losses,
            "model_parameters": model_parameters,
        }

        # Add the processing time per epoch. The save time also marks the
        # start of the next epoch.
        now = time.time()
        self._epoch_start_ts[epoch + 1] = now
        if epoch in self._epoch_start_ts:
            dt = now - self._epoch_start_ts[epoch]
            model_data["epoch_start_ts"] = self._epoch_start_ts[epoch]
            model_data["epoch_dt"] = dt

        # Create the save directory
        Path(self._save_dir).mkdir(parents=True, exist_ok=True)

        # Save the model
        torch.save(model_data, c_filename)

        # Return true if file is present, otherwise false
        if not os.path.isfile(c_filename):
            self._log().error("Cannot find the file to save: " + c_filename)
            return False

        # store code branch name and commit
        version_file = os.path.join(self._save_dir, "_version")
        with open(version_file, "w") as text_file:
            print("Model is using code [commit:branch]", file=text_file)
            print("{}:{}".format(self._commit_sha, self._branch_name), file=text_file)

        return True

    def load(self, optimizers=None):
        r"""
        Load the model data from the disk.
        The method will iterate over all *.check files and try to load the one from the highest
        epoch.

        Input:
            -optimizers: Dictionary with optimizers. If it is not null the keys will be used to
                         associate the corresponding state_dicts from the checkpoint file and update
                         the internal states of the provided optimizers. If the checkpoint has no
                         optimizer states, the optimizers are left unchanged.

        Output:
            - Success: True/ False
            - epoch: Loaded epoch (0 if the checkpoint has no "epoch" entry) or -1 if there are
                     no checkpoint files
            - optimizers: The ``optimizers`` argument, with loaded states when available
            - losses: Dictionary with loaded losses or empty dictionary if there is no checkpoint
                      file or no "losses" entry
            - model_parameters: Dictionary with the model parameters or empty dictionary if there
                                are no checkpoint files or no "model_parameters" entry
        """
        # Get the saved_model
        saved_model, _ = self._load_best_checkpoint()

        # Restore the model
        if saved_model is None:
            self._log().debug("No saved model checkpoint found")
            return False, -1, optimizers, {}, {}

        self._log().debug("Loading model from checkpoint file")
        self.load_state_dict(saved_model["model_state_dict"])

        epoch = saved_model.get("epoch", 0)
        losses = saved_model.get("losses", {})
        # Older or hand-made checkpoints may lack this entry; don't crash on them
        model_parameters = saved_model.get("model_parameters", {})

        if optimizers is not None:
            # save() stores None here when it was called without optimizers
            saved_optimizers = saved_model.get("optimizers")
            if saved_optimizers:
                for key, optimizer_state_dict in saved_optimizers.items():
                    optimizers[key].load_state_dict(optimizer_state_dict)
            else:
                self._log().debug(
                    "No optimizer states in checkpoint; optimizers left unchanged"
                )

        # Reset the start_ts of the next epoch
        self._epoch_start_ts[epoch + 1] = time.time()

        return True, epoch, optimizers, losses, model_parameters

    def _load_best_checkpoint(self):
        r"""
        If a "load_checkpoint" file has been provided, load this one.
        Otherwise use the "save_dir" and load the one with the most advanced epoch

        Returns
        -------
        saved_model : dictionary
            Checkpoint file contents generated by torch.load, or None
        checkpoint_file : string
            Filename of the loaded checkpoint, or None
        """
        checkpoint_files = []
        # If a "load_checkpoint" file is provided, try to load it
        if self._load_checkpoint is not None:
            if not os.path.exists(self._load_checkpoint):
                self._log().error(
                    "Cannot load the checkpoint: {}".format(self._load_checkpoint)
                )
                return None, None
            checkpoint_files.append(self._load_checkpoint)
        else:
            # Iterate over all check files from the directory by reverse alphabetical order.
            # NOTE(review): this picks the highest epoch only while every file shares the same
            # "<type>_<name>_<dataset>_" prefix and epochs fit the 3-digit zero padding.
            checkpoint_files = glob.glob(os.path.join(self._save_dir, "*.check"))
            checkpoint_files.sort(reverse=True)

        for checkpoint_file in checkpoint_files:
            try:
                # Try to load the file.
                # NOTE(review): torch.load unpickles data; only load trusted checkpoints.
                self._log().info(
                    "Loading model checkpoint file: {}".format(checkpoint_file)
                )
                saved_model = torch.load(checkpoint_file, map_location=self._device)
                return saved_model, checkpoint_file
            except RuntimeError:
                self._log().error("Cannot load file: {}".format(checkpoint_file))

        return None, None

    def _build_checkpoint_filename(self, epoch):
        r"""
        Construct the full path for the filename of this checkpoint:
        ``<save_dir>/<model type>_<model name>_<dataset name>_<epoch:03>.check``
        """
        dataset_name = self._config["dataset"]["name"]
        model_type = self._config["model"]["type"]
        model_name = self._config["model"]["name"]
        filename = "{}_{}_{}_{:03}.check".format(
            model_type, model_name, dataset_name, epoch
        )
        c_filename = os.path.join(self._save_dir, filename)

        return c_filename
|
File without changes
|