docling-ibm-models 1.3.0__py3-none-any.whl → 1.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling_ibm_models/tableformer/common.py +0 -94
- docling_ibm_models/tableformer/data_management/functional.py +1 -478
- docling_ibm_models/tableformer/data_management/tf_predictor.py +5 -71
- docling_ibm_models/tableformer/data_management/transforms.py +0 -305
- docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +1 -1
- {docling_ibm_models-1.3.0.dist-info → docling_ibm_models-1.3.2.dist-info}/METADATA +2 -2
- {docling_ibm_models-1.3.0.dist-info → docling_ibm_models-1.3.2.dist-info}/RECORD +9 -13
- docling_ibm_models/tableformer/data_management/data_transformer.py +0 -504
- docling_ibm_models/tableformer/data_management/tf_dataset.py +0 -1233
- docling_ibm_models/tableformer/test_dataset_cache.py +0 -37
- docling_ibm_models/tableformer/test_prepare_image.py +0 -99
- {docling_ibm_models-1.3.0.dist-info → docling_ibm_models-1.3.2.dist-info}/LICENSE +0 -0
- {docling_ibm_models-1.3.0.dist-info → docling_ibm_models-1.3.2.dist-info}/WHEEL +0 -0
@@ -1,504 +0,0 @@
|
|
1
|
-
#
|
2
|
-
# Copyright IBM Corp. 2024 - 2024
|
3
|
-
# SPDX-License-Identifier: MIT
|
4
|
-
#
|
5
|
-
import copy
|
6
|
-
import logging
|
7
|
-
import os
|
8
|
-
import random
|
9
|
-
|
10
|
-
import cv2
|
11
|
-
import numpy as np
|
12
|
-
import torch
|
13
|
-
import torchvision.transforms as transforms
|
14
|
-
from PIL import Image, ImageDraw
|
15
|
-
|
16
|
-
import docling_ibm_models.tableformer.data_management.transforms as T
|
17
|
-
import docling_ibm_models.tableformer.settings as s
|
18
|
-
|
19
|
-
# Level handed to s.get_custom_logger() by DataTransformer._log();
# switch to logging.DEBUG for more verbose output while debugging.
LOG_LEVEL = logging.INFO
|
21
|
-
|
22
|
-
|
23
|
-
class DataTransformer:
    r"""
    Data transformations for the images and bboxes

    Check the "help" fields inside the config file for an explanation of each parameter
    """

    def __init__(self, config):
        r"""
        Parameters
        ----------
        config : dict
            Configuration dictionary. The methods of this class read its optional
            "dataset" and "debug" sections.
        """
        self._config = config

        print("DataTransformer Init!")

    def _log(self):
        # Setup a custom logger
        return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)

    def append_id(self, filename):
        r"""
        Return ``filename`` with "_resized" inserted before its extension,
        e.g. "dir/page.png" -> "dir/page_resized.png".
        """
        name, ext = os.path.splitext(filename)
        return "{name}_{uid}{ext}".format(name=name, uid="resized", ext=ext)

    def load_image(self, img_fn):
        r"""
        Load an image from the disk

        Parameters
        ----------
        img_fn: The filename of the image

        Returns
        -------
        PIL image object
        """
        img = Image.open(img_fn)
        return img

    def load_image_cv2(self, img_fn):
        r"""Load an image from the disk with OpenCV

        Parameters
        ----------
        img_fn: The filename of the image

        Returns
        -------
        CV2 image object (numpy array), or None when OpenCV cannot read the file
        """
        img = cv2.imread(img_fn)
        return img

    def save_image(self, img, img_fn):
        r"""Save ``img`` under ``img_fn`` with the "_resized" suffix (see ``append_id``)."""
        img.save(self.append_id(img_fn))

    def renderbboxes(self, img, bboxes):
        r"""
        Draw every bbox as a red rectangle outline onto ``img`` (modified in place) and
        return it. Only the first four values [x1, y1, x2, y2] of each box are used, so
        boxes that carry an extra class field can be rendered as well.
        """
        draw_img = ImageDraw.Draw(img)
        for box in bboxes:
            draw_img.rectangle(
                [float(v) for v in box[:4]], fill=None, outline=(255, 0, 0)
            )
        return img

    def get_dataset_settings(self):
        r"""
        Return the ("dataset", "debug") sections of the config. A missing "dataset"
        section defaults to {} and a missing "debug" one to {"save_debug_images": False}.
        """
        dataset = self._config.get("dataset", {})
        debug = self._config.get("debug", {"save_debug_images": False})
        return dataset, debug

    @staticmethod
    def _compute_target_size(old_size, desired_size, keep_ar, up_scaling_enabled):
        r"""
        Compute the (width, height) an image of ``old_size`` is resized to before padding.

        Parameters
        ----------
        old_size : (int, int)
            Current (width, height) of the image
        desired_size : int
            Side of the square output canvas
        keep_ar : bool
            If true the longest side is scaled to ``desired_size`` keeping the aspect ratio,
            otherwise the image is stretched to a ``desired_size`` square
        up_scaling_enabled : bool
            If false, images whose longest side is below ``desired_size`` keep their size

        Returns
        -------
        (int, int)
            The target (width, height)

        Raises
        ------
        ValueError
            If the image has a zero width or height
        """
        if min(old_size) <= 0:
            raise ValueError("Cannot rescale an image of size {}".format(tuple(old_size)))
        if max(old_size) < desired_size and not up_scaling_enabled:
            # Smaller than the canvas and upscaling is disabled: keep the original size
            return tuple(old_size)
        if keep_ar:
            ratio = float(desired_size) / max(old_size)
            # Clamp to 1 px so very thin images do not collapse to a zero dimension
            return tuple(max(1, int(x * ratio)) for x in old_size)
        return (desired_size, desired_size)

    @staticmethod
    def _scale_and_shift_bboxes(bboxes, x_scale, y_scale, x_pad=0, y_pad=0):
        r"""
        Return a deep copy of ``bboxes`` whose [x1, y1, x2, y2] values are multiplied by
        (x_scale, y_scale) and then offset by (x_pad, y_pad). Extra per-box fields
        (e.g. a class id) are copied unchanged; the input is not modified.
        """
        new_bboxes = copy.deepcopy(bboxes)
        # Index access keeps any int-indexable container of boxes working
        for i in range(len(bboxes)):
            box = bboxes[i]
            new_bboxes[i][0] = x_scale * box[0] + x_pad
            new_bboxes[i][1] = y_scale * box[1] + y_pad
            new_bboxes[i][2] = x_scale * box[2] + x_pad
            new_bboxes[i][3] = y_scale * box[3] + y_pad
        return new_bboxes

    def _image_to_tensor(self, image, normalization):
        r"""
        Convert a (height, width, channels) image into a (channels, height, width)
        FloatTensor divided by 255, optionally normalized with the mean/std found in
        config["dataset"]["image_normalization"].

        Parameters
        ----------
        image : PIL image or numpy array
        normalization : truthy to apply the mean/std normalization

        Returns
        -------
        FloatTensor
        """
        # np.asarray gives (height, width, channels) for a PIL image
        chw = np.asarray(image).transpose(2, 0, 1)  # (channels, height, width)
        tensor = torch.FloatTensor(chw / 255.0)
        if normalization:
            norm_cfg = self._config["dataset"]["image_normalization"]
            normalize = transforms.Normalize(mean=norm_cfg["mean"], std=norm_cfg["std"])
            tensor = normalize(tensor)
        return tensor

    def _prepare_image_from_file(self, image_fn, bboxes, convert_box=True):
        r"""
        Load the image from file and prepare it

        Parameters
        ----------
        image_fn : string
            Filename to load the image
        bboxes : sequence
            Bounding boxes of the image as [x1, y1, x2, y2, ...] in pixels
        convert_box : bool
            If true the bboxes are converted to xcycwh format

        Returns
        -------
        PIL image
            A PIL image object with the image prepared according to the settings in the config file
        bboxes : list
            Converted bboxes of the image
        """
        im = self.load_image(image_fn)
        return self._prepare_image(im, bboxes, convert_box, image_fn)

    def _prepare_image(self, im, bboxes, convert_box=True, image_fn=None):
        r"""
        Resize the image, paste it on a square padded canvas and transform its bboxes.

        The image is resized with OpenCV (Lanczos) according to the "dataset" settings
        "resized_image", "keep_AR" and "up_scaling_enabled". It is then pasted on a
        "resized_image" x "resized_image" RGB canvas filled with "padding_color": centred
        when "padding_mode" is "frame", otherwise at the top-left corner. The bboxes
        receive the same scaling and offset.

        Parameters
        ----------
        im : PIL image object
        bboxes : sequence
            Bounding boxes of the image as [x1, y1, x2, y2, ...] in pixels (may be empty)
        convert_box : bool
            If true the bboxes are converted to xcycwh format, normalized by the canvas size
        image_fn : string
            Filename of the original image or None. It is used to save the augmented image
            for debugging; nothing is saved when it is None

        Returns
        -------
        PIL image
            A PIL image object with the image prepared according to the settings in the config file
        bboxes : list
            The transformed bboxes: xcycwh if convert_box is true, otherwise
            [x1, y1, x2, y2, ...] in pixels of the returned canvas

        Raises
        ------
        ValueError
            If the input image has a zero width or height
        """
        settings, debug_settings = self.get_dataset_settings()

        desired_size = settings["resized_image"]
        old_size = im.size  # PIL reports (width, height)
        new_size = self._compute_target_size(
            old_size,
            desired_size,
            settings["keep_AR"],
            settings["up_scaling_enabled"],
        )

        # Resize with OpenCV (Lanczos) instead of PIL
        np_resized = cv2.resize(np.array(im), new_size, interpolation=cv2.INTER_LANCZOS4)
        im = Image.fromarray(np_resized)

        # Create empty canvas of background color and desired size
        bcolor = tuple(settings["padding_color"][:3])
        new_im = Image.new(mode="RGB", size=(desired_size, desired_size), color=bcolor)

        if settings.get("grayscale"):
            im = im.convert("LA")

        if settings["padding_mode"] == "frame":
            # Paddings around the image: paste the resized image in the center
            x_pad = (desired_size - new_size[0]) // 2
            y_pad = (desired_size - new_size[1]) // 2
        else:
            # Otherwise paste at the (0, 0) coordinates
            x_pad, y_pad = 0, 0
        new_im.paste(im, (x_pad, y_pad))

        x_scale = new_size[0] / old_size[0]
        y_scale = new_size[1] / old_size[1]
        new_bboxes = self._scale_and_shift_bboxes(bboxes, x_scale, y_scale, x_pad, y_pad)

        save_debug = bool(debug_settings) and debug_settings.get("save_debug_images", False)
        if save_debug and image_fn is not None:
            # Draw on a copy so the returned image stays free of debug rectangles, and
            # use the transformed boxes so they line up with the padded canvas
            aug_im = self.renderbboxes(new_im.copy(), new_bboxes)
            if settings.get("grayscale"):
                aug_im = aug_im.convert("LA")
            self.save_image(aug_im, image_fn)

        if convert_box:
            return new_im, self.xyxy_to_xcycwh(new_bboxes, desired_size)
        return new_im, new_bboxes

    def xyxy_to_xcycwh(self, bboxes, size):
        r"""
        Convert bboxes from [x1, y1, x2, y2] to [xc, yc, w, h], dividing every coordinate
        by ``size``. When config["dataset"]["bbox_format"] is "5plet" the fifth element
        (class) of each box is appended; the default format is "4plet".
        """
        bbox_format = self._config.get("dataset", {}).get("bbox_format", "4plet")

        conv_bboxes = []
        for box in bboxes:
            x1 = box[0] / size
            y1 = box[1] / size
            x2 = box[2] / size
            y2 = box[3] / size
            converted = [(x1 + x2) / 2, (y1 + y2) / 2, abs(x2 - x1), abs(y2 - y1)]
            if bbox_format == "5plet":
                converted.append(box[4])
            conv_bboxes.append(converted)
        return conv_bboxes

    def rescale_in_memory(self, image, normalization):
        r"""
        Receive an image and rescale it in memory

        Parameters
        ----------
        image : PIL image
            The image data to rescale
        normalization : dictionary
            The normalization information with the format:
            "state": "true or false if image normalization is to be enabled",
            "mean": "mean values to use if state is true",
            "std": "std values to use if state is true"
            Only its truthiness is checked here; the mean/std actually applied come
            from config["dataset"]["image_normalization"].

        Returns
        -------
        npimgc : FloatTensor
            The loaded and properly transformed image data, (channels, height, width)
        """
        new_image, _ = self._prepare_image(image, {}, convert_box=False, image_fn=None)
        return self._image_to_tensor(new_image, normalization)

    def _rescale(self, image_fn, bboxes, normalization):
        r"""
        Rescale, resize, pad the given image and its associated bboxes according to the settings
        from the config

        Parameters
        ----------
        image_fn: full image file name
        bboxes: List with bboxes in the format [x1, y1, x2, y2] in pixels
        normalization: truthy to normalize the image with the mean/std found in
            config["dataset"]["image_normalization"]

        Returns
        -------
        npimgc : FloatTensor
            The loaded and properly transformed image data, (channels, height, width)
        bboxes: List with bboxes in the format [xc, yc, w, h] where xc, yc are the coords of the
            center, w, h the width and height of the bbox and all are normalized to the
            scaled size of the image

        Raises
        ------
        ValueError
            If the image has a zero width or height
        """
        new_image, new_bboxes = self._prepare_image_from_file(image_fn, bboxes)
        return self._image_to_tensor(new_image, normalization), new_bboxes

    def rescale_old(self, image, bboxes, statistics=None, image_size=448):
        r"""
        Legacy resize of a numpy image to an ``image_size`` x ``image_size`` PIL image.

        Input:
            image: np array, transposed with (1, 2, 0) before being handed to PIL
                (presumably channels-first - confirm with callers)
            bboxes: returned unchanged
            statistics: unused
            image_size: side of the square output image (default 448)

        Output:
            image: PIL image of size (image_size, image_size)
            bboxes: the input bboxes, unchanged
        """
        # Move the channels axis last for PIL
        image = image.transpose(1, 2, 0)

        # Convert to PIL format and resize. Image.ANTIALIAS was removed in Pillow 10;
        # Image.LANCZOS is the same filter.
        image = Image.fromarray(image)
        image = image.resize((image_size, image_size), Image.LANCZOS)
        return image, bboxes

    def rescale_batch(self, images, bboxes, statistics=None):
        r"""
        Not implemented: always returns None.

        Intended to rescale, resize and pad a batch of images (batch_size, channels, width,
        height) and their associated bboxes according to the settings from the config.
        """
        pass

    def sample_preprocessor(self, image_fn, bboxes, purpose, table_bboxes=None):
        r"""
        Load an image with OpenCV and apply the dataset transformations to it and its bboxes.

        Parameters
        ----------
        image_fn: full image file name
        bboxes: List with bboxes in the format [x1, y1, x2, y2, class]. Not used when
            purpose is s.PREDICT_PURPOSE
        purpose: one of the purposes defined in settings; random jitter/pad/crop candidates
            are only added for s.TRAIN_PURPOSE
        table_bboxes: optional; its elements [0] and [1] bound the random crop offsets

        Returns
        -------
        npimgc : FloatTensor
            The loaded and properly transformed image data
        bboxes: List with bboxes in the format [xc, yc, w, h, class] where xc, yc are the
            coords of the center, w, h the width and height of the bbox and all are
            normalized to the scaled size of the image (None for s.PREDICT_PURPOSE)

        Raises
        ------
        ValueError
            If the image file cannot be read
        """
        img = self.load_image_cv2(image_fn)
        if img is None:
            # cv2.imread reports an unreadable/missing file by returning None
            raise ValueError("Cannot load image: {}".format(image_fn))
        img = np.ascontiguousarray(img)
        # NOTE(review): cv2 loads channels in BGR order whereas the PIL based methods of
        # this class produce RGB - confirm this difference is intended.

        # cv2 arrays are laid out as (height, width, channels)
        height, width = img.shape[0], img.shape[1]
        has_targets = purpose != s.PREDICT_PURPOSE
        np_bboxes = np.array(bboxes) if has_targets else None
        target = {
            "size": [height, width],
            "boxes": torch.from_numpy(np_bboxes[:, :4]) if has_targets else None,
            "classes": np_bboxes[:, -1] if has_targets else None,
            "area": height * width,
        }

        optional_transforms = [T.NoTransformation()]

        # Necessary preprocessing ends here, experimental options begin below.
        # DETR format, might be necessary to keep this structure to share other functions used by
        # the community
        dataset_cfg = self._config["dataset"]
        if purpose == s.TRAIN_PURPOSE:
            if dataset_cfg["color_jitter"]:
                jitter = T.ColorJitter(
                    brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
                )
                optional_transforms.append(jitter)

            if dataset_cfg["rand_pad"]:
                pad_val = random.randint(0, 50)
                optional_transforms.append(T.RandomPad(pad_val))

            if table_bboxes is not None and dataset_cfg["rand_crop"]:
                # NOTE(review): (img.shape[0], img.shape[1]) is (height, width) but is kept
                # in its original order - confirm which order T.RandomCrop expects.
                w_, h_ = img.shape[0], img.shape[1]
                w_c, h_c = table_bboxes[0], table_bboxes[1]
                f_w, f_h = random.randint(0, w_c), random.randint(0, h_c)
                optional_transforms.append(T.RandomCrop((w_, h_), (f_w, f_h)))

        normalize = T.Normalize(
            mean=dataset_cfg["image_normalization"]["mean"],
            std=dataset_cfg["image_normalization"]["std"],
        )
        resized_size = dataset_cfg["resized_image"]
        resize = T.Resize([resized_size, resized_size])

        transformations = T.RandomChoice(optional_transforms)

        img, target = transformations(img, target)
        # NOTE(review): normalization runs before the division by 255 below - confirm the
        # value range T.Normalize expects.
        img, target = normalize(img, target)
        img, target = resize(img, target)

        # NOTE(review): (2, 1, 0) yields (channels, width, height), unlike the (2, 0, 1)
        # used by _image_to_tensor; after the square resize this swaps the image x/y axes
        # but not the bboxes. Kept as-is because trained models may depend on it.
        img = img.transpose(2, 1, 0)
        img = torch.FloatTensor(img / 255.0)
        bboxes_ = target["boxes"]
        classes_ = target["classes"]
        desired_size = img.shape[1]

        if has_targets:
            # Re-attach the class column so each row is [x1, y1, x2, y2, class]
            bboxes_ = np.concatenate(
                (bboxes_, np.expand_dims(classes_, axis=1)), axis=1
            )
            bboxes_ = self.xyxy_to_xcycwh(bboxes_, desired_size)

        return img, bboxes_
|