docling-ibm-models 1.3.1__py3-none-any.whl → 1.3.3__py3-none-any.whl
This diff shows the changes between two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- docling_ibm_models/tableformer/common.py +0 -84
- docling_ibm_models/tableformer/data_management/functional.py +1 -478
- docling_ibm_models/tableformer/data_management/tf_predictor.py +5 -71
- docling_ibm_models/tableformer/data_management/transforms.py +0 -305
- docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +1 -1
- {docling_ibm_models-1.3.1.dist-info → docling_ibm_models-1.3.3.dist-info}/METADATA +1 -1
- {docling_ibm_models-1.3.1.dist-info → docling_ibm_models-1.3.3.dist-info}/RECORD +9 -13
- docling_ibm_models/tableformer/data_management/data_transformer.py +0 -504
- docling_ibm_models/tableformer/data_management/tf_dataset.py +0 -1233
- docling_ibm_models/tableformer/test_dataset_cache.py +0 -37
- docling_ibm_models/tableformer/test_prepare_image.py +0 -99
- {docling_ibm_models-1.3.1.dist-info → docling_ibm_models-1.3.3.dist-info}/LICENSE +0 -0
- {docling_ibm_models-1.3.1.dist-info → docling_ibm_models-1.3.3.dist-info}/WHEEL +0 -0
docling_ibm_models/tableformer/data_management/tf_dataset.py
@@ -1,1233 +0,0 @@
-#
-# Copyright IBM Corp. 2024 - 2024
-# SPDX-License-Identifier: MIT
-#
-import json
-import logging
-import os
-from glob import glob
-from html import escape
-
-import jsonlines
-import numpy as np
-import torch
-import torch.utils.data
-from lxml import html
-from PIL import Image
-from torch.nn.utils.rnn import pad_sequence
-from torch.utils.data import Dataset
-from tqdm import tqdm
-
-import docling_ibm_models.tableformer.common as c
-import docling_ibm_models.tableformer.settings as s
-import docling_ibm_models.tableformer.utils.utils as u
-from docling_ibm_models.tableformer.data_management.data_transformer import (
-    DataTransformer,
-)
-
-LOG_LEVEL = logging.INFO
-# LOG_LEVEL = logging.DEBUG
-
-
-class TFDataset(Dataset):
-    def __init__(self, config, purpose, fixed_padding=False):
-        r"""
-        Parameters
-        ----------
-        config : Dictionary
-            The input configuration file
-        purpose : string
-            One of s.TRAIN_PURPOSE, s.VAL_PURPOSE, s.PREDICT_PURPOSE
-        fixed_padding : bool
-            If False (default), the produced tag sequences will be truncated to the maximum
-            actual length among the tag sequences of the batch
-            If True the produced tag and cell sequences will have fixed length equal to
-            max_tag_len and max_cell_len respectively.
-        """
-        self.cml_task = {}
-        self.cml_logger = {}
-        self._config = config
-        self._fixed_padding = fixed_padding
-        self._index = 0  # Index to the current image file
-        self._max_tag_len = c.safe_get_parameter(
-            config, ["preparation", "max_tag_len"], required=True
-        )
-        self._max_cell_len = c.safe_get_parameter(
-            config, ["preparation", "max_cell_len"], required=True
-        )
-        self._resized_image = c.safe_get_parameter(
-            config, ["dataset", "resized_image"], required=True
-        )
-        self.annotation = c.safe_get_parameter(
-            config, ["preparation", "annotation"], required=True
-        )
-        self._image_normalization = config["dataset"]["image_normalization"]
-
-        self._load_cells = c.safe_get_parameter(
-            config, ["dataset", "load_cells"], required=False
-        )
-        self._predict_dir = c.safe_get_parameter(
-            config, ["predict", "predict_dir"], required=False
-        )
-        self._train_bbox = c.safe_get_parameter(
-            config, ["train", "bbox"], required=False
-        )
-        self._predict_bbox = c.safe_get_parameter(
-            config, ["predict", "bbox"], required=False
-        )
-
-        self._log().debug("purpose: {}".format(purpose))
-        self._log().debug("resized_image: {}".format(self._resized_image))
-        self._log().debug("image_normalization: {}".format(self._image_normalization))
-
-        # Check the type of the dataset
-        dataset_type = c.safe_get_parameter(config, ["dataset", "type"])
-        if dataset_type not in s.supported_datasets:
-            msg = "Unsupported dataset type: " + dataset_type
-            self._log().error(msg)
-            raise NotImplementedError(msg)
-        self._dataset_type = dataset_type
-
-        # Check the purpose of the object
-        if purpose not in [
-            s.TRAIN_PURPOSE,
-            s.VAL_PURPOSE,
-            s.TEST_PURPOSE,
-            s.PREDICT_PURPOSE,
-        ]:
-            msg = "Unsupported purpose: " + purpose
-            self._log().error(msg)
-            raise Exception(msg)
-        self._purpose = purpose
-
-        # The batch_size is grounded to 1 in case of VAL, PREDICT
-        if purpose == s.TRAIN_PURPOSE:
-            self._batch_size = c.safe_get_parameter(
-                config, ["train", "batch_size"], required=True
-            )
-        else:
-            self._batch_size = 1
-        self._log().debug("batch_size: {}".format(self._batch_size))
-
-        self._transformer = DataTransformer(config)
-        if purpose == s.PREDICT_PURPOSE:
-            self._build_predict_cache()
-        else:
-            self._build_cache()
-
-        self._index = 0  # Index to the current image file
-        self._ind = np.array(range(self._dataset_size))
-
-    def set_device(self, device):
-        r"""
-        Set the device to be used to place the tensors when looping over the data
-
-        Parameters
-        ----------
-        device : torch.device or int
-            The device to do the training
-        """
-        self._device = device
-
-    def _log(self):
-        # Setup a custom logger
-        return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
-
-    def __len__(self):
-        return int(self._dataset_size)
-
-    def get_batch_num(self):
-        return int(np.ceil(self._dataset_size / self._batch_size))
-
-    def _pad_tensors_from_batch(self, batch, field):
-        r"""
-        Retrieves data about tensor - "field" from the raw DataLoader batch list
-        And pads it to maximum batch length, to be further collated in the
-        custom collator function
-
-        Parameters
-        ----------
-        batch : list
-            The list of samples, obtained by DataLoader from the sampler
-            a non-collated batch
-        field: string
-            Name of the tensor data-point from the sample, that has to be collated,
-            and because of that has to be padded to max length of the tensor in the batch
-
-        Returns
-        -------
-        batchoftensors: list
-            The list of padded tensors
-        """
-        # Make list of selected items by field name
-        list_of_items = [x[field] for x in batch]
-        tagseqlens = []
-        # Identify lengths of tensor "lists"
-        for i, ten in enumerate(list_of_items):
-            tagnum = ten[0].size()[0]
-            tagseqlens.append(tagnum)
-        # Get the biggest length
-        maxbatchtaglen = max(tagseqlens)
-
-        # Prepare new list with padded tensors
-        batchoftensors = []
-        for i, ten in enumerate(list_of_items):
-            tagtensor = ten
-            newtagtensor = torch.zeros(
-                1, maxbatchtaglen, dtype=torch.long, device=self._device
-            )
-            newtagtensor[:, : tagtensor.size()[1]] = tagtensor
-            batchoftensors.append(newtagtensor)
-
-        return batchoftensors
-
-    def bcol(self, batch):
-        r"""
-        Custom collate fucntion for Pytorch DataLoader, to collate items prepared
-        by TFDataset into batches
-
-        Parameters
-        ----------
-        batch : list
-            The list of samples, obtained by DataLoader from the sampler
-            a non-collated batch
-
-        Returns
-        -------
-        batchoftensors: tuple
-            Tuple of lists of items, properly collated into batches
-        """
-        collated = {}
-        test_gt = {}
-        if bool(batch[0]["test_gt"]):
-            for x in batch:
-                test_gt.update(x["test_gt"])
-
-        if len(batch) > 1:
-            # In case batch length is more than 1 we want to collate all elements in batch
-            # Every element has it's own rule how to collate
-
-            if self._load_cells:
-                cells = [
-                    item for sublist in [x["cells"] for x in batch] for item in sublist
-                ]
-            else:
-                cells = []
-            cell_lens = [
-                item for sublist in [x["cell_lens"] for x in batch] for item in sublist
-            ]
-            cell_bboxes = []
-
-            if "cell_bboxes" in batch[0] and batch[0]["cell_bboxes"]:
-                for x in batch:
-                    cell_bboxes.append(
-                        {
-                            "boxes": x["cell_bboxes"][0][0],
-                            "labels": x["cell_bboxes"][0][1],
-                        }
-                    )
-
-            samples = [
-                item for sublist in [x["samples"] for x in batch] for item in sublist
-            ]
-
-            # Sequences of tags have to be padded and then collated:
-            batchoftags = self._pad_tensors_from_batch(batch, "tags")
-            tags = pad_sequence(batchoftags, batch_first=True)
-            tags = torch.squeeze(tags, 1)
-
-            tag_lens = pad_sequence([x["tag_lens"] for x in batch], batch_first=True)
-            tag_lens = torch.squeeze(tag_lens, 1)
-            num_cells = pad_sequence([x["num_cells"] for x in batch], batch_first=True)
-            num_cells = torch.squeeze(num_cells, 1)
-
-            collated = (
-                torch.cat([x["imgs"] for x in batch], dim=0),
-                tags,
-                tag_lens,
-                num_cells,
-                cells,
-                cell_lens,
-                cell_bboxes,
-                samples,
-                test_gt,
-            )
-        else:
-            # In case batch length is 1 we just formulate the expected output
-            cell_bboxes = None
-            if "cell_bboxes" in batch[0] and batch[0]["cell_bboxes"]:
-                cell_bboxes = {
-                    "boxes": batch[0]["cell_bboxes"][0][0],
-                    "labels": batch[0]["cell_bboxes"][0][1],
-                }
-
-            collated = (
-                batch[0]["imgs"],
-                batch[0]["tags"],
-                batch[0]["tag_lens"],
-                batch[0]["num_cells"],
-                batch[0]["cells"],
-                batch[0]["cell_lens"],
-                cell_bboxes,
-                batch[0]["samples"],
-                test_gt,
-            )
-        return collated
-
-    def __getitem__(self, idx):
-        r"""
-        Retrieve data by the specific index from the cache
-        Required for Pytorch DataSampler, and to be used together with DataLoader
-        Depending on the "purpose" different returned objects can be None.
-
-        Returns
-        (All data points presented in a dictionary, each wrapped into a list,
-        for easier batching and collating later on)
-        -------
-        imgs : tensor (batch_size, image_channels, resized_image, resized_image)
-            Batch with the rescaled images
-        tags : The tags of the images. The object is one of:
-            None : If purpose is not "train"
-            (batch_size, max_tag_len + 2) : If purpose is "train" and "fixed_padding" is true.
-                                            The +2 over the max_tag_len is for the <start> <stop>
-            (batch_size, batch_max_tag_len) : If purpose is "train" and "fixed_padding" is false,
-                                              where "batch_max_tag_len" is the max length of the
-                                              tags in batch
-        tag_lens : The real length of the tags per image in the batch. The object is one of:
-            None : If purpose is not "train"
-            (batch_size, 1) : If purpose is "train"
-        num_cells : The number of cells per image in the batch. The object is one of:
-            None : If purpose is not "train"
-            (batch_size, 1) : If purpose is "train"
-        cells : The cell tags for the images in the batch. The object is one of:
-            None : If purpose is not "train"
-            list with LongTensor: If purpose is "train"
-        cell_lens : The length of the cell tags per image in the batch. The object is one of:
-            None : If purpose is not "train"
-            list with LongTensor: If purpose is "train"
-        cell_bboxes : The transformed (rescaled, padded, etc) bboxes of the cells for all images in
-                      batch. Each list is a bbox in the format [xc, yc, w, h] where xc, yc are the
-                      coords of the center, w, h the width and height of the bbox and all are
-                      normalized to the scaled size of the image. The object is one of:
-            None : If purpose is not "train"
-            list of lists: If purpose is "train"
-        samples : list of string
-            The filenames in the batch
-        val_gt : The ground turth raw attributes for the validation split. The object is one of:
-            None : If purpose is not "val"
-            dictionary : If purpose is "val"
-        """
-
-        # Check if index out of bounds
-        if idx >= self._dataset_size:
-            return None
-
-        # Move current _index to requested idx
-        # self._index = idx
-
-        # Images, __getitem__ provides only 1 image for specified index
-        # (batch_size, image_channels, resized_image, resized_image)
-        imgs = torch.zeros(
-            1, self._img_ch, self._resized_image, self._resized_image, dtype=torch.float
-        ).to(self._device)
-
-        # Initialize all output objects to None
-        # Depending on the "purpose" some of them will be populated
-        tags = None
-        tag_lens = None
-        num_cells = None
-        cells = None
-        cell_lens = None
-        cell_bboxes = None
-
-        val_tags = None
-        val_tag_lens = None
-        val_num_cells = None
-        val_cells = None
-        val_cell_lens = None
-        val_cell_bboxes = None
-
-        test_gt = {}
-        test_tags = None
-        test_tag_lens = None
-
-        # Train specific output
-        if self._purpose == s.TRAIN_PURPOSE:
-            tag_len = self._taglens[idx]
-            fixed_tag_len = self._max_tag_len + 2  # <start>...<end>
-            if self._fixed_padding:
-                tags = torch.zeros(
-                    1, fixed_tag_len, dtype=torch.long, device=self._device
-                )
-            else:
-                tags = torch.zeros(1, tag_len, dtype=torch.long, device=self._device)
-
-            tag_lens = torch.zeros(1, 1, dtype=torch.long).to(self._device)
-            num_cells = torch.zeros(1, 1, dtype=torch.long).to(self._device)
-            cells = []
-            cell_lens = []
-            cell_bboxes = []
-        # val specific output
-        elif self._purpose == s.VAL_PURPOSE:
-            val_tag_len = self._val_taglens[idx]
-            val_fixed_tag_len = self._max_tag_len + 2  # <start>...<end>
-            if self._fixed_padding:
-                val_tags = torch.zeros(
-                    1, val_fixed_tag_len, dtype=torch.long, device=self._device
-                )
-            else:
-                val_tags = torch.zeros(
-                    1, val_tag_len, dtype=torch.long, device=self._device
-                )
-
-            val_tag_lens = torch.zeros(1, 1, dtype=torch.long).to(self._device)
-            val_num_cells = torch.zeros(1, 1, dtype=torch.long).to(self._device)
-            val_cells = []
-            val_cell_lens = []
-            val_cell_bboxes = []
-
-        elif self._purpose == s.TEST_PURPOSE:
-            if len(self._test_taglens) > 0:
-                # Dictionary with the raw attributes for the groundtruth. Keys are the filenames
-                test_gt = {}
-                tag_len = self._test_taglens[idx]
-                fixed_tag_len = self._max_tag_len + 2  # <start>...<end>
-                if self._fixed_padding:
-                    test_tags = torch.zeros(
-                        1, fixed_tag_len, dtype=torch.long, device=self._device
-                    )
-                else:
-                    test_tags = torch.zeros(
-                        1, tag_len, dtype=torch.long, device=self._device
-                    )
-                test_tag_lens = torch.zeros(1, 1, dtype=torch.long).to(self._device)
-            cells = []
-            cell_lens = []
-            cell_bboxes = []
-
-        sample = self._image_fns[idx]
-        # Rescale/convert the image and bboxes
-        bboxes = self._bboxes[sample]
-        sample_fn = self._get_image_path(sample)
-
-        if not self._table_bboxes:
-            table_bbox = None
-        else:
-            if sample in self._table_bboxes:
-                table_bbox = self._table_bboxes[sample]
-            else:
-                table_bbox = None
-        scaled_img, scaled_bboxes = self._transformer.sample_preprocessor(
-            sample_fn, bboxes, self._purpose, table_bbox
-        )
-
-        imgs[0] = scaled_img.to(self._device)
-
-        # Train specific output
-        if self._purpose == s.TRAIN_PURPOSE:
-            # Remove the padding from tags and cells
-            if self._fixed_padding:
-                tags[0] = torch.LongTensor(self._tags[idx]).to(self._device)
-            else:
-                tags[0] = torch.LongTensor(self._tags[idx][:tag_len]).to(self._device)
-
-            tag_lens[0] = torch.LongTensor([self._taglens[idx]]).to(self._device)
-            num_cells[0] = len(self._cell_lens[idx])
-
-            if len(self._cell_lens[idx]) > 0:
-                sample_max_cell_len = max(self._cell_lens[idx])
-            else:
-                sample_max_cell_len = 0
-
-            if self._load_cells:
-                image_trimmed_cells = [
-                    self._cells[idx][x][0:sample_max_cell_len]
-                    for x in range(0, len(self._cells[idx]))
-                ]
-                cells.append(torch.LongTensor(image_trimmed_cells).to(self._device))
-
-            cell_lens.append(torch.LongTensor(self._cell_lens[idx]).to(self._device))
-            if self._train_bbox:
-
-                cell_bboxes.append(
-                    [
-                        torch.from_numpy(
-                            np.array(scaled_bboxes, dtype=np.float32)[:, :4]
-                        ).to(self._device),
-                        torch.from_numpy(
-                            np.array(scaled_bboxes, dtype=np.compat.long)[:, -1]
-                        ).to(self._device),
-                    ]
-                )
-
-        elif self._purpose == s.VAL_PURPOSE:
-            # Remove the padding from tags and cells
-            if self._fixed_padding:
-                val_tags[0] = torch.LongTensor(self._val_tags[idx]).to(self._device)
-            else:
-                val_tags[0] = torch.LongTensor(self._val_tags[idx][:val_tag_len]).to(
-                    self._device
-                )
-
-            val_tag_lens[0] = torch.LongTensor([self._val_taglens[idx]]).to(
-                self._device
-            )
-            val_num_cells[0] = len(self._val_cell_lens[idx])
-
-            if len(self._val_cell_lens[idx]) > 0:
-                sample_max_cell_len = max(self._val_cell_lens[idx])
-            else:
-                sample_max_cell_len = 0
-
-            if self._load_cells:
-                val_image_trimmed_cells = [
-                    self._val_cells[idx][x][0:sample_max_cell_len]
-                    for x in range(0, len(self._cells[idx]))
-                ]
-                val_cells.append(
-                    torch.LongTensor(val_image_trimmed_cells).to(self._device)
-                )
-
-            val_cell_lens.append(
-                torch.LongTensor(self._val_cell_lens[idx]).to(self._device)
-            )
-            if self._train_bbox:
-                val_cell_bboxes.append(
-                    [
-                        torch.from_numpy(
-                            np.array(scaled_bboxes, dtype=np.float32)[:, :4]
-                        ).to(self._device),
-                        torch.from_numpy(
-                            np.array(scaled_bboxes, dtype=np.compat.long)[:, -1]
-                        ).to(self._device),
-                    ]
-                )
-        # val specific output
-        elif self._purpose == s.TEST_PURPOSE:
-            if test_gt is not None:
-                test_gt[sample] = self._test[sample]
-
-                # Remove the padding from tags and cells
-                if self._fixed_padding:
-                    test_tags[0] = torch.LongTensor(self._test_tags[idx]).to(
-                        self._device
-                    )
-                else:
-                    test_tags[0] = torch.LongTensor(self._test_tags[idx][:tag_len]).to(
-                        self._device
-                    )
-                test_tag_lens[0] = torch.LongTensor([self._test_taglens[idx]]).to(
-                    self._device
-                )
-            if self._predict_bbox:
-                cell_bboxes.append(
-                    [
-                        torch.from_numpy(
-                            np.array(scaled_bboxes, dtype=np.float32)[:, :4]
-                        ).to(self._device),
-                        torch.from_numpy(
-                            np.array(scaled_bboxes, dtype=np.compat.long)[:, -1]
-                        ).to(self._device),
-                    ]
-                )
-
-        output = {}
-
-        # Samples is a list with the given image filename
-        samples = [self._image_fns[idx]]
-        # All data points presented in a dictionary, each wrapped into a list,
-        # for easier batching and collating later on
-        if self._purpose == s.TRAIN_PURPOSE:
-            output["imgs"] = imgs
-            output["tags"] = tags
-            output["tag_lens"] = tag_lens
-            output["num_cells"] = num_cells
-            output["cells"] = cells
-            output["cell_lens"] = cell_lens
-            output["cell_bboxes"] = cell_bboxes
-            output["samples"] = samples
-            output["test_gt"] = test_gt
-        elif self._purpose == s.VAL_PURPOSE:
-            output["imgs"] = imgs
-            output["tags"] = val_tags
-            output["tag_lens"] = val_tag_lens
-            output["num_cells"] = val_num_cells
-            output["cells"] = val_cells
-            output["cell_lens"] = val_cell_lens
-            output["cell_bboxes"] = val_cell_bboxes
-            output["samples"] = samples
-            output["test_gt"] = test_gt
-        elif self._purpose == s.TEST_PURPOSE:
-            output["imgs"] = imgs
-            output["tags"] = test_tags
-            output["tag_lens"] = test_tag_lens
-            output["num_cells"] = num_cells
-            output["cells"] = cells
-            output["cell_lens"] = cell_lens
-            output["cell_bboxes"] = cell_bboxes
-            output["samples"] = samples
-            output["test_gt"] = test_gt
-        else:
-            output["imgs"] = imgs
-            output["tags"] = tags
-            output["tag_lens"] = tag_lens
-            output["num_cells"] = num_cells
-            output["cells"] = cells
-            output["cell_lens"] = cell_lens
-            output["cell_bboxes"] = cell_bboxes
-            output["samples"] = samples
-            output["test_gt"] = test_gt
-        return output
-
-    def get_batch_size(self):
-        r"""
-        Return the actual batch_size
-        """
-        return self._batch_size
-
-    def reset(self):
-        self._index = 0
-
-    def __iter__(self):
-        return self
-
-    def is_valid(self, img, config):
-        max_tag_len = config["preparation"]["max_tag_len"]
-        max_cell_len = config["preparation"]["max_cell_len"]
-        check_limits = True
-        if "check_limits" in config["preparation"]:
-            check_limits = config["preparation"]["check_limits"]
-        if check_limits:
-            if len(img["html"]["structure"]["tokens"]) > max_tag_len:
-                self._log().debug(
-                    "========================================= TAG LEN REJECTED"
-                )
-                self._log().debug("File name: {}".format(img["filename"]))
-                tokens_len = len(img["html"]["structure"]["tokens"])
-                self._log().debug("Structure token len: {}".format(tokens_len))
-                self._log().debug("More than max_tag_len: {}".format(max_tag_len))
-                self._log().debug(
-                    "=========================================================="
-                )
-                return False
-            for cell in img["html"]["cells"]:
-                if len(cell["tokens"]) > max_cell_len:
-                    self._log().debug(
-                        "======================================== CELL LEN REJECTED"
-                    )
-                    self._log().debug("File name: {}".format(img["filename"]))
-                    self._log().debug("Cell len: {}".format(len(cell["tokens"])))
-                    self._log().debug("More than max_cell_len: {}".format(max_cell_len))
-                    self._log().debug(
-                        "=========================================================="
-                    )
-                    return False
-        self.raw_data_dir = config["preparation"]["raw_data_dir"]
-        with Image.open(
-            os.path.join(self.raw_data_dir, img["split"], img["filename"])
-        ) as im:
-            max_image_size = config["preparation"]["max_image_size"]
-            if im.width > max_image_size or im.height > max_image_size:
-                # IMG SIZE REJECTED
-                return False
-        return True
-
-    def __next__(self):
-        r"""
-        Get the next batch or raise the StopIteration
-
-        In order to have the batch size fixed also in the last iteration, we wrap over the dataset
-        and repeat some of the first elements.
-
-        Depending on the "purpose" different returned objects can be None.
-
-        Returns
-        -------
-        imgs : tensor (batch_size, image_channels, resized_image, resized_image)
-            Batch with the rescaled images
-        tags : The tags of the images. The object is one of:
-            None : If purpose is not "train"
-            (batch_size, max_tag_len + 2) : If purpose is "train" and "fixed_padding" is true.
-                                            The +2 over the max_tag_len is for the <start> <stop>
-            (batch_size, batch_max_tag_len) : If purpose is "train" and "fixed_padding" is false,
-                                              where "batch_max_tag_len" is the max length of the
-                                              tags in batch
-        tag_lens : The real length of the tags per image in the batch. The object is one of:
-            None : If purpose is not "train"
-            (batch_size, 1) : If purpose is "train"
-        num_cells : The number of cells per image in the batch. The object is one of:
-            None : If purpose is not "train"
-            (batch_size, 1) : If purpose is "train"
-        cells : The cell tags for the images in the batch. The object is one of:
-            None : If purpose is not "train"
-            list with LongTensor: If purpose is "train"
-        cell_lens : The length of the cell tags per image in the batch. The object is one of:
-            None : If purpose is not "train"
-            list with LongTensor: If purpose is "train"
-        cell_bboxes : The transformed (rescaled, padded, etc) bboxes of the cells for all images in
-                      batch. Each list is a bbox in the format [xc, yc, w, h] where xc, yc are the
-                      coords of the center, w, h the width and height of the bbox and all are
-                      normalized to the scaled size of the image. The object is one of:
-            None : If purpose is not "train"
-            list of lists: If purpose is "train"
-        samples : list of string
-            The filenames in the batch
-        val_gt : The ground turth raw attributes for the validation split. The object is one of:
-            None : If purpose is not "val"
-            dictionary : If purpose is "val"
-        """
-
-        if self._index >= self._dataset_size:
-            raise StopIteration()
-
-        # Compute the next sample
-        if (
-            self._dataset_size - self._index >= self._batch_size
-        ):  # Full batch_size sample
-            step = self._batch_size
-            sample_indices = self._ind[self._index : self._index + step]
-        else:
-            # skip last batch
-            raise StopIteration()
-
-        batch = []
-        # Loop over the batch indices and collect items for the batch
-        for i, idx in enumerate(sample_indices):
-            item = self.__getitem__(idx)
-            batch.append(item)
-        # Collate batch
-        output = self.bcol(batch)
-        self._index += step
-
-        return output
-
-    def shuffle(self):
-        r"""
-        Shuffle the training images
-        This takes place only in case of training, otherwise it just returns
-
-        Output: True in case a shuffling took place, False otherwise
-        """
-        if self._purpose != s.TRAIN_PURPOSE:
-            return False
-
-        # image_fns_np = np.asarray(self._image_fns)
-        # To get a deterministic random shuffle, we need to seed our random
-        # with a deterministic seed (int)
-        np.random.seed(42)
-        # Then shuffle after seeding
-        self._ind = np.random.permutation(self._dataset_size)
-        self._index = 0
-        return True
-
-    def get_init_data(self):
-        r"""
-        Create a dictionary with all kind of initialization data necessary for all models.
-        This data should not be served by the __next__ method.
-        """
-        init_data = {"word_map": self._word_map, "statistics": self._statistics}
-        return init_data
-
-    def _get_image_path(self, img_fn):
-        r"""
-        Get the full image path out of the image file name
-        """
-        if self._dataset_type == "TF_prepared":
-            if self._purpose == s.TRAIN_PURPOSE:
-                full_fn = os.path.join(self._raw_data_dir, "train", img_fn)
-            elif self._purpose == s.VAL_PURPOSE:
-                full_fn = os.path.join(self._raw_data_dir, "val", img_fn)
-            elif self._purpose == s.TEST_PURPOSE:
-                full_fn = os.path.join(self._raw_data_dir, "test", img_fn)
-            else:
-                full_fn = os.path.join(self._raw_data_dir, img_fn)
-
-        if full_fn is None or not os.path.isfile(full_fn):
-            self._log().error("File not found: " + full_fn)
-            return None
-
-        return full_fn
-
-    def format_html(self, img):
-        r"""
-        Formats HTML code from tokenized annotation of img
-        """
-        tag_len = len(img["html"]["structure"]["tokens"])
-        if self._load_cells:
-            cell_len_max = max([len(c["tokens"]) for c in img["html"]["cells"]])
-        else:
-            cell_len_max = 0
-
-        HTML = img["html"]["structure"]["tokens"].copy()
-        to_insert = [i for i, tag in enumerate(HTML) if tag in ("<td>", ">")]
-
-        if self._load_cells:
-            for i, cell in zip(to_insert[::-1], img["html"]["cells"][::-1]):
-                if cell:
-                    cell = "".join(
-                        [
-                            escape(token) if len(token) == 1 else token
-                            for token in cell["tokens"]
-                        ]
-                    )
-                    HTML.insert(i + 1, cell)
-
-        HTML = "<html><body><table>%s</table></body></html>" % "".join(HTML)
-        root = html.fromstring(HTML)
-        if self._predict_bbox:
-            for td, cell in zip(root.iter("td"), img["html"]["cells"]):
-                if "bbox" in cell:
-                    bbox = cell["bbox"]
-                    td.attrib["x"] = str(bbox[0])
-                    td.attrib["y"] = str(bbox[1])
-                    td.attrib["width"] = str(bbox[2] - bbox[0])
-                    td.attrib["height"] = str(bbox[3] - bbox[1])
-            HTML = html.tostring(root, encoding="utf-8").decode()
-        return HTML, tag_len, cell_len_max
-
-    def _build_predict_cache(self):
-        r"""
-        populate cache with image file names that need to be predicted
-        """
-        self._prepared_data_dir = c.safe_get_parameter(
-            self._config, ["dataset", "prepared_data_dir"], required=False
-        )
-        self._data_name = c.safe_get_parameter(
-            self._config, ["dataset", "name"], required=True
-        )
-
-        if self._prepared_data_dir is None:
-
-            self._statistics = c.safe_get_parameter(
-                self._config, ["dataset", "image_normalization"], required=True
-            )
-
-            self._word_map = c.safe_get_parameter(
-                self._config, ["dataset_wordmap"], required=True
-            )
-
-        else:
-            # Load statistics
-            statistics_fn = c.get_prepared_data_filename("STATISTICS", self._data_name)
-            with open(os.path.join(self._prepared_data_dir, statistics_fn), "r") as f:
-                self._log().debug("Load statistics from: {}".format(statistics_fn))
-                self._statistics = json.load(f)
-
-            # Load word_map
-            word_map_fn = c.get_prepared_data_filename("WORDMAP", self._data_name)
-            with open(os.path.join(self._prepared_data_dir, word_map_fn), "r") as f:
-                self._log().debug("Load WORDMAP from: {}".format(word_map_fn))
-                self._word_map = json.load(f)
-
-        # Get Image File Names for Prediction
-        self._image_fns = []
-        self._bboxes = {}
-        self._bboxes_table = {}
-        self._raw_data_dir = self._predict_dir
-        self._table_bboxes = {}
-
-        if self._predict_dir[-1] != "/":
-            self._predict_dir += "/"
-        for path in list(glob(self._predict_dir + "**/*.png", recursive=True)):
-            filename = os.path.basename(path)
-            self._image_fns.append(filename)
-            self._log().info("Image found: {}".format(filename))
-            self._bboxes[filename] = []
-
-        # Get size of a dataset to predict
-        self._dataset_size = len(self._image_fns)
-
-        # Get the number of image channels
-        self._log().info(
-            "To test load... {}".format(self._predict_dir + self._image_fns[0])
-        )
-        img = u.load_image(self._predict_dir + self._image_fns[0])
-        if img is None:
-            msg = "Cannot load image"
-            self._log().error(msg)
-            raise Exception(msg)
-        self._img_ch = img.shape[0]
-
-    def _build_cache(self):
-        r"""
-        Cache with small data
-        """
-        all_bboxes = {}  # Keep original bboxes for all images
-        table_bboxes = {}
-        self._log().info("Building the cache...")
-        self._raw_data_dir = c.safe_get_parameter(
-            self._config, ["dataset", "raw_data_dir"], required=True
-        )
-        self._prepared_data_dir = c.safe_get_parameter(
-            self._config, ["dataset", "prepared_data_dir"], required=False
-        )
-
-        self._data_name = c.safe_get_parameter(
-            self._config, ["dataset", "name"], required=True
-        )
-
-        if self._prepared_data_dir is None:
-
-            self._statistics = c.safe_get_parameter(
-                self._config, ["dataset", "image_normalization"], required=True
-            )
-
-            self._word_map = c.safe_get_parameter(
-                self._config, ["dataset_wordmap"], required=True
-            )
-
-        else:
-            # Load statistics
-            statistics_fn = c.get_prepared_data_filename("STATISTICS", self._data_name)
-            with open(os.path.join(self._prepared_data_dir, statistics_fn), "r") as f:
-                self._log().debug("Load statistics from: {}".format(statistics_fn))
-                self._statistics = json.load(f)
-
-            # Load word_map
-            word_map_fn = c.get_prepared_data_filename("WORDMAP", self._data_name)
-            with open(os.path.join(self._prepared_data_dir, word_map_fn), "r") as f:
-                self._log().debug("Load WORDMAP from: {}".format(word_map_fn))
-                self._word_map = json.load(f)
-
-        word_map_cell = self._word_map["word_map_cell"]
-        word_map_tag = self._word_map["word_map_tag"]
-        # Read image paths and captions for each image
-        train_image_paths = []
-        train_images = []
-        val_image_paths = []
-        val_images = []
-        test_image_paths = []
-        test_images = []
-        predict_images = []
-
-        train_image_tags = (
-            []
-        )  # List of list of structure tokens for each image in the train set
-        train_image_cells = []
-        train_image_cells_len = []
-
-        val_image_tags = (
-            []
-        )  # List of list of structure tokens for each image in the train set
-        val_image_cells = []
-        val_image_cells_len = []
-
-        test_image_tags = []
-        test_gt = dict()
-
-        invalid_files = 0
-        total_files = 0
-
-        self._log().debug(
-            "Create lists with image filenames per split, train tags/cells and GT"
-        )
-        with jsonlines.open(self.annotation, "r") as reader:
-            for img in tqdm(reader):
-                total_files += 1
-                img_filename = img["filename"]
-                path = os.path.join(self._raw_data_dir, img["split"], img_filename)
-
-                # Keep bboxes for all images
-                all_cell_bboxes = []
-                for cell in img["html"]["cells"]:
-                    if "bbox" not in cell:
-                        continue
-                    all_cell_bboxes.append(cell["bbox"])
-                all_bboxes[img_filename] = all_cell_bboxes
-
-                # if dataset does include bbox for the table itself
-                if "table_bbox" in img:
-                    table_bboxes[img_filename] = img["table_bbox"]
-                if img["split"] == "train":
-                    if self._purpose == s.TRAIN_PURPOSE:
-                        # Skip invalid images
-                        if not self.is_valid(img, self._config):
-                            invalid_files += 1
-                            continue
-                        tags = []
-                        cells = []
-                        cell_lens = []
-                        tags.append(img["html"]["structure"]["tokens"])
-
-                        if self._load_cells:
-                            for cell in img["html"]["cells"]:
-                                cells.append(cell["tokens"])
-                                cell_lens.append(len(cell["tokens"]) + 2)
-                        else:
-                            for cell in img["html"]["cells"]:
-                                cell_lens.append(len(cell["tokens"]) + 2)
-
-                        train_images.append(img_filename)
-                        train_image_paths.append(path)
-                        train_image_tags.append(tags)
-                        train_image_cells.append(cells)
-                        train_image_cells_len.append(cell_lens)
-                if img["split"] == "val":
-                    if self._purpose == s.VAL_PURPOSE:
-                        # Skip invalid images
-                        if not self.is_valid(img, self._config):
-                            invalid_files += 1
-                            continue
-
-                        val_tags = []
-                        val_cells = []
-                        val_cell_lens = []
-                        val_tags.append(img["html"]["structure"]["tokens"])
-
-                        if self._load_cells:
-                            for cell in img["html"]["cells"]:
-                                val_cells.append(cell["tokens"])
-                                val_cell_lens.append(len(cell["tokens"]) + 2)
-                        else:
-                            for cell in img["html"]["cells"]:
-                                val_cell_lens.append(len(cell["tokens"]) + 2)
-
-                        with Image.open(path) as im:
-                            HTML, tag_len, cell_len_max = self.format_html(img)
-                            lt1 = [">", "lcel", "ucel", "xcel"]
-                            lt2 = img["html"]["structure"]["tokens"]
-                            tcheck = any(item in lt1 for item in lt2)
-                            if tcheck:
-                                gtt = "complex"
-                            else:
-                                gtt = "simple"
-                            test_gt[img_filename] = {
-                                "html": HTML,
-                                "tag_len": tag_len,
-                                "cell_len_max": cell_len_max,
-                                "width": im.width,
-                                "height": im.height,
-                                "type": gtt,
-                                "html_tags": img["html"]["structure"]["tokens"],
-                            }
-
-                        val_images.append(img_filename)
-                        val_image_paths.append(path)
-                        val_image_tags.append(val_tags)
-                        val_image_cells.append(val_cells)
-                        val_image_cells_len.append(val_cell_lens)
-
-                elif img["split"] == "test":
-                    if self._purpose == s.TEST_PURPOSE:
-                        # Skip invalid images
-                        if not self.is_valid(img, self._config):
-                            invalid_files += 1
-                            continue
-
-                        with Image.open(path) as im:
-                            HTML, tag_len, cell_len_max = self.format_html(img)
-                            lt1 = [">", "lcel", "ucel", "xcel"]
-                            lt2 = img["html"]["structure"]["tokens"]
-                            tcheck = any(item in lt1 for item in lt2)
-                            if tcheck:
-                                gtt = "complex"
-                            else:
-                                gtt = "simple"
-                            test_gt[img_filename] = {
-                                "html": HTML,
-                                "tag_len": tag_len,
-                                "cell_len_max": cell_len_max,
-                                "width": im.width,
-                                "height": im.height,
-                                "type": gtt,
-                                "html_tags": img["html"]["structure"]["tokens"],
-                            }
-                        test_images.append(img_filename)
-
-                        test_tags = []
-                        test_tags.append(img["html"]["structure"]["tokens"])
-                        test_image_paths.append(path)
-                        test_image_tags.append(test_tags)
-                else:
-                    if self._purpose == s.PREDICT_PURPOSE:
-                        predict_images.append(img_filename)
-
-        image_fns = {
-            s.TRAIN_SPLIT: train_images,
-            s.VAL_SPLIT: val_images,
-            s.TEST_SPLIT: test_images,
-        }
-
-        self._log().debug("Keep the split data pointed by the purpose")
-        # Images
-        # Filter out the images for the particual split
-        self._image_fns = image_fns[self._purpose]
-        self._dataset_size = len(self._image_fns)
-        assert len(self._image_fns) > 0, "Empty image split: " + self._purpose
-
-        # Get the number of image channels
-        img = u.load_image(self._get_image_path(self._image_fns[0]))
-        if img is None:
-            msg = "Cannot load image"
-            self._log().error(msg)
-            raise Exception(msg)
-        self._img_ch = img.shape[0]
-
-        # img_name -> list of bboxes, each bbox is a list with x1y1x2y2
-        split_bboxes = {}
-        img_names = set(self._image_fns)  # Set will speed up search
-        for img_name, bbox in all_bboxes.items():
-            if img_name not in img_names:
-                continue
-            if img_name not in split_bboxes:
-                split_bboxes[img_name] = []
-            # we should use extend not append, otherwise we get list within list
-            split_bboxes[img_name].extend(bbox)
-        self._bboxes = split_bboxes
-        self._table_bboxes = table_bboxes
-        # -------------------------------------------------------------------------------
-        # Train specific
-        # -------------------------------------------------------------------------------
-        # Compute encoded tags and cells
-        enc_tags = []
-        tag_lens = []
-        enc_cells = []
-        cell_lens = []
-
-        val_enc_tags = []
-        val_tag_lens = []
-        val_enc_cells = []
-        val_cell_lens = []
-
-        test_enc_tags = []
-        test_tag_lens = []
-
-        # Based on the "purpose"
-        if self._purpose == s.TRAIN_PURPOSE:
-            self._log().debug("Convert train tags and cell tags to indices")
-            for i, path in enumerate(tqdm(train_image_paths)):
-                for tag in train_image_tags[i]:
-                    # Encode tags
-                    # Notice that at this point we don't have images longer than max_tag_length
-                    # The same happens with the cell tokens
-                    enc_tag = (
-                        [word_map_tag["<start>"]]
-                        + [
-                            word_map_tag.get(word, word_map_tag["<unk>"])
-                            for word in tag
-                        ]
-                        + [word_map_tag["<end>"]]
-                        + [word_map_tag["<pad>"]] * (self._max_tag_len - len(tag))
-                    )
-                    # Find caption lengths
-                    tag_len = len(tag) + 2
-
-                    enc_tags.append(enc_tag)
-                    tag_lens.append(tag_len)
-
-                enc_cell_seq = []
-                cell_seq_len = []
-
-                if self._load_cells:
-                    for cell in train_image_cells[i]:
-                        # Encode captions
-                        enc_cell = (
-                            [word_map_cell["<start>"]]
-                            + [
-                                word_map_cell.get(word, word_map_cell["<unk>"])
-                                for word in cell
-                            ]
-                            + [word_map_cell["<end>"]]
-                            + [word_map_cell["<pad>"]]
-                            * (self._max_cell_len - len(cell))
-                        )
-                        enc_cell_seq.append(enc_cell)
-                        # Find caption lengths
-                        cell_len = len(cell) + 2
-                        cell_seq_len.append(cell_len)
-                else:
-                    for cell in train_image_cells_len[i]:
-                        cell_seq_len.append(cell)
-                enc_cells.append(enc_cell_seq)
-                cell_lens.append(cell_seq_len)
-
-        if self._purpose == s.VAL_PURPOSE:
-            self._log().debug("Convert train tags and cell tags to indices")
-            for i, path in enumerate(tqdm(val_image_paths)):
-                for tag in val_image_tags[i]:
-                    # Encode tags
-                    # Notice that at this point we don't have images longer than max_tag_length
-                    # The same happens with the cell tokens
-                    val_enc_tag = (
-                        [word_map_tag["<start>"]]
-                        + [
-                            word_map_tag.get(word, word_map_tag["<unk>"])
-                            for word in tag
-                        ]
-                        + [word_map_tag["<end>"]]
-                        + [word_map_tag["<pad>"]] * (self._max_tag_len - len(tag))
-                    )
-                    # Find caption lengths
-                    val_tag_len = len(tag) + 2
-
-                    val_enc_tags.append(val_enc_tag)
-                    val_tag_lens.append(val_tag_len)
-
-                val_enc_cell_seq = []
-                val_cell_seq_len = []
-
-                if self._load_cells:
-                    for cell in val_image_cells[i]:
-                        # Encode captions
-                        val_enc_cell = (
-                            [word_map_cell["<start>"]]
-                            + [
-                                word_map_cell.get(word, word_map_cell["<unk>"])
-                                for word in cell
-                            ]
-                            + [word_map_cell["<end>"]]
-                            + [word_map_cell["<pad>"]]
-                            * (self._max_cell_len - len(cell))
-                        )
-                        val_enc_cell_seq.append(val_enc_cell)
-
-                        # Find caption lengths
-                        cell_len = len(cell) + 2
-                        val_cell_seq_len.append(cell_len)
-                else:
-                    for cell in val_image_cells_len[i]:
-                        val_cell_seq_len.append(cell)
-                val_enc_cells.append(val_enc_cell_seq)
-                val_cell_lens.append(val_cell_seq_len)
-
-        if self._purpose == s.TEST_PURPOSE:
-            self._log().debug("Convert val tags to indices")
-            for i, path in enumerate(tqdm(test_image_paths)):
-                for tag in test_image_tags[i]:
-                    # Encode tags
-                    # Notice that at this point we don't have images longer than max_tag_length
-                    # The same happens with the cell tokens
-                    test_enc_tag = (
-                        [word_map_tag["<start>"]]
-                        + [
-                            word_map_tag.get(word, word_map_tag["<unk>"])
-                            for word in tag
-                        ]
-                        + [word_map_tag["<end>"]]
-                        + [word_map_tag["<pad>"]] * (self._max_tag_len - len(tag))
-                    )
-
-                    # Find caption lengths
-                    test_tag_len = len(tag) + 2
-
-                    test_enc_tags.append(test_enc_tag)
-                    test_tag_lens.append(test_tag_len)
-
-        self._tags = enc_tags
-        self._taglens = tag_lens
-        self._cells = enc_cells
-        self._cell_lens = cell_lens
-
-        # -------------------------------------------------------------------------------
-        # val specific
-        # -------------------------------------------------------------------------------
-        self._val_tags = val_enc_tags
-        self._val_taglens = val_tag_lens
-        self._val_cells = val_enc_cells
-        self._val_cell_lens = val_cell_lens
-        # -------------------------------------------------------------------------------
-        # test / evaluation specific
-        # -------------------------------------------------------------------------------
-        self._test = test_gt
-        self._test_tags = test_enc_tags
-        self._test_taglens = test_tag_lens
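For reference, the deleted tf_dataset.py implemented a PyTorch Dataset with a custom collate method (bcol) and a set_device() hook. The sketch below is illustrative only and not part of the package diff: it assumes the pre-1.3.2 import path removed above, a caller-provided TableFormer config dictionary, and a hypothetical "tableformer_config.json" path; the TFDataset/bcol/set_device/get_batch_size API is taken from the removed source shown in this diff.

# Illustrative usage sketch only -- not part of the package diff.
import json

import torch
from torch.utils.data import DataLoader

import docling_ibm_models.tableformer.settings as s
from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset

# Hypothetical config file; in practice the caller supplies the TableFormer config dict.
with open("tableformer_config.json") as f:
    config = json.load(f)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset = TFDataset(config, s.TRAIN_PURPOSE, fixed_padding=False)
dataset.set_device(device)  # __getitem__ places its tensors on this device

# bcol pads each sample's tag tensor to the longest sequence in the batch and
# groups cells, bboxes, filenames and ground truth into a single tuple.
loader = DataLoader(
    dataset, batch_size=dataset.get_batch_size(), collate_fn=dataset.bcol
)

for imgs, tags, tag_lens, num_cells, cells, cell_lens, cell_bboxes, samples, test_gt in loader:
    pass  # feed the collated batch to the TableFormer training loop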