docling-ibm-models 1.3.0__py3-none-any.whl → 1.3.2__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
@@ -1,1233 +0,0 @@
1
- #
2
- # Copyright IBM Corp. 2024 - 2024
3
- # SPDX-License-Identifier: MIT
4
- #
5
- import json
6
- import logging
7
- import os
8
- from glob import glob
9
- from html import escape
10
-
11
- import jsonlines
12
- import numpy as np
13
- import torch
14
- import torch.utils.data
15
- from lxml import html
16
- from PIL import Image
17
- from torch.nn.utils.rnn import pad_sequence
18
- from torch.utils.data import Dataset
19
- from tqdm import tqdm
20
-
21
- import docling_ibm_models.tableformer.common as c
22
- import docling_ibm_models.tableformer.settings as s
23
- import docling_ibm_models.tableformer.utils.utils as u
24
- from docling_ibm_models.tableformer.data_management.data_transformer import (
25
- DataTransformer,
26
- )
27
-
28
- LOG_LEVEL = logging.INFO
29
- # LOG_LEVEL = logging.DEBUG
30
-
31
-
32
- class TFDataset(Dataset):
33
- def __init__(self, config, purpose, fixed_padding=False):
34
- r"""
35
- Parameters
36
- ----------
37
- config : Dictionary
38
- The input configuration file
39
- purpose : string
40
- One of s.TRAIN_PURPOSE, s.VAL_PURPOSE, s.PREDICT_PURPOSE
41
- fixed_padding : bool
42
- If False (default), the produced tag sequences will be truncated to the maximum
43
- actual length among the tag sequences of the batch
44
-             If True, the produced tag and cell sequences will have a fixed length equal to
45
- max_tag_len and max_cell_len respectively.
46
- """
47
- self.cml_task = {}
48
- self.cml_logger = {}
49
- self._config = config
50
- self._fixed_padding = fixed_padding
51
- self._index = 0 # Index to the current image file
52
- self._max_tag_len = c.safe_get_parameter(
53
- config, ["preparation", "max_tag_len"], required=True
54
- )
55
- self._max_cell_len = c.safe_get_parameter(
56
- config, ["preparation", "max_cell_len"], required=True
57
- )
58
- self._resized_image = c.safe_get_parameter(
59
- config, ["dataset", "resized_image"], required=True
60
- )
61
- self.annotation = c.safe_get_parameter(
62
- config, ["preparation", "annotation"], required=True
63
- )
64
- self._image_normalization = config["dataset"]["image_normalization"]
65
-
66
- self._load_cells = c.safe_get_parameter(
67
- config, ["dataset", "load_cells"], required=False
68
- )
69
- self._predict_dir = c.safe_get_parameter(
70
- config, ["predict", "predict_dir"], required=False
71
- )
72
- self._train_bbox = c.safe_get_parameter(
73
- config, ["train", "bbox"], required=False
74
- )
75
- self._predict_bbox = c.safe_get_parameter(
76
- config, ["predict", "bbox"], required=False
77
- )
78
-
79
- self._log().debug("purpose: {}".format(purpose))
80
- self._log().debug("resized_image: {}".format(self._resized_image))
81
- self._log().debug("image_normalization: {}".format(self._image_normalization))
82
-
83
- # Check the type of the dataset
84
- dataset_type = c.safe_get_parameter(config, ["dataset", "type"])
85
- if dataset_type not in s.supported_datasets:
86
- msg = "Unsupported dataset type: " + dataset_type
87
- self._log().error(msg)
88
- raise NotImplementedError(msg)
89
- self._dataset_type = dataset_type
90
-
91
- # Check the purpose of the object
92
- if purpose not in [
93
- s.TRAIN_PURPOSE,
94
- s.VAL_PURPOSE,
95
- s.TEST_PURPOSE,
96
- s.PREDICT_PURPOSE,
97
- ]:
98
- msg = "Unsupported purpose: " + purpose
99
- self._log().error(msg)
100
- raise Exception(msg)
101
- self._purpose = purpose
102
-
103
-         # The batch_size is forced to 1 for any purpose other than TRAIN (i.e. VAL, TEST, PREDICT)
104
- if purpose == s.TRAIN_PURPOSE:
105
- self._batch_size = c.safe_get_parameter(
106
- config, ["train", "batch_size"], required=True
107
- )
108
- else:
109
- self._batch_size = 1
110
- self._log().debug("batch_size: {}".format(self._batch_size))
111
-
112
- self._transformer = DataTransformer(config)
113
- if purpose == s.PREDICT_PURPOSE:
114
- self._build_predict_cache()
115
- else:
116
- self._build_cache()
117
-
118
- self._index = 0 # Index to the current image file
119
- self._ind = np.array(range(self._dataset_size))
120
-
121
- def set_device(self, device):
122
- r"""
123
- Set the device to be used to place the tensors when looping over the data
124
-
125
- Parameters
126
- ----------
127
- device : torch.device or int
128
- The device to do the training
129
- """
130
- self._device = device
131
-
132
- def _log(self):
133
- # Setup a custom logger
134
- return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
135
-
136
- def __len__(self):
137
- return int(self._dataset_size)
138
-
139
- def get_batch_num(self):
140
- return int(np.ceil(self._dataset_size / self._batch_size))
141
-
142
- def _pad_tensors_from_batch(self, batch, field):
143
- r"""
144
-         Retrieve the tensors stored under "field" from the raw DataLoader batch list
145
-         and pad them to the maximum length in the batch, so they can be collated by the
146
-         custom collate function
147
-
148
- Parameters
149
- ----------
150
- batch : list
151
- The list of samples, obtained by DataLoader from the sampler
152
- a non-collated batch
153
- field: string
154
-             Name of the tensor data point in each sample that has to be collated
155
-             and therefore padded to the maximum tensor length in the batch
156
-
157
- Returns
158
- -------
159
- batchoftensors: list
160
- The list of padded tensors
161
- """
162
- # Make list of selected items by field name
163
- list_of_items = [x[field] for x in batch]
164
- tagseqlens = []
165
- # Identify lengths of tensor "lists"
166
- for i, ten in enumerate(list_of_items):
167
- tagnum = ten[0].size()[0]
168
- tagseqlens.append(tagnum)
169
- # Get the biggest length
170
- maxbatchtaglen = max(tagseqlens)
171
-
172
- # Prepare new list with padded tensors
173
- batchoftensors = []
174
- for i, ten in enumerate(list_of_items):
175
- tagtensor = ten
176
- newtagtensor = torch.zeros(
177
- 1, maxbatchtaglen, dtype=torch.long, device=self._device
178
- )
179
- newtagtensor[:, : tagtensor.size()[1]] = tagtensor
180
- batchoftensors.append(newtagtensor)
181
-
182
- return batchoftensors
183
-
184
- def bcol(self, batch):
185
- r"""
186
-         Custom collate function for the PyTorch DataLoader, to collate items prepared
187
- by TFDataset into batches
188
-
189
- Parameters
190
- ----------
191
- batch : list
192
- The list of samples, obtained by DataLoader from the sampler
193
- a non-collated batch
194
-
195
- Returns
196
- -------
197
- batchoftensors: tuple
198
- Tuple of lists of items, properly collated into batches
199
- """
200
- collated = {}
201
- test_gt = {}
202
- if bool(batch[0]["test_gt"]):
203
- for x in batch:
204
- test_gt.update(x["test_gt"])
205
-
206
- if len(batch) > 1:
207
- # In case batch length is more than 1 we want to collate all elements in batch
208
-             # Every element has its own rule for how it is collated
209
-
210
- if self._load_cells:
211
- cells = [
212
- item for sublist in [x["cells"] for x in batch] for item in sublist
213
- ]
214
- else:
215
- cells = []
216
- cell_lens = [
217
- item for sublist in [x["cell_lens"] for x in batch] for item in sublist
218
- ]
219
- cell_bboxes = []
220
-
221
- if "cell_bboxes" in batch[0] and batch[0]["cell_bboxes"]:
222
- for x in batch:
223
- cell_bboxes.append(
224
- {
225
- "boxes": x["cell_bboxes"][0][0],
226
- "labels": x["cell_bboxes"][0][1],
227
- }
228
- )
229
-
230
- samples = [
231
- item for sublist in [x["samples"] for x in batch] for item in sublist
232
- ]
233
-
234
- # Sequences of tags have to be padded and then collated:
235
- batchoftags = self._pad_tensors_from_batch(batch, "tags")
236
- tags = pad_sequence(batchoftags, batch_first=True)
237
- tags = torch.squeeze(tags, 1)
238
-
239
- tag_lens = pad_sequence([x["tag_lens"] for x in batch], batch_first=True)
240
- tag_lens = torch.squeeze(tag_lens, 1)
241
- num_cells = pad_sequence([x["num_cells"] for x in batch], batch_first=True)
242
- num_cells = torch.squeeze(num_cells, 1)
243
-
244
- collated = (
245
- torch.cat([x["imgs"] for x in batch], dim=0),
246
- tags,
247
- tag_lens,
248
- num_cells,
249
- cells,
250
- cell_lens,
251
- cell_bboxes,
252
- samples,
253
- test_gt,
254
- )
255
- else:
256
- # In case batch length is 1 we just formulate the expected output
257
- cell_bboxes = None
258
- if "cell_bboxes" in batch[0] and batch[0]["cell_bboxes"]:
259
- cell_bboxes = {
260
- "boxes": batch[0]["cell_bboxes"][0][0],
261
- "labels": batch[0]["cell_bboxes"][0][1],
262
- }
263
-
264
- collated = (
265
- batch[0]["imgs"],
266
- batch[0]["tags"],
267
- batch[0]["tag_lens"],
268
- batch[0]["num_cells"],
269
- batch[0]["cells"],
270
- batch[0]["cell_lens"],
271
- cell_bboxes,
272
- batch[0]["samples"],
273
- test_gt,
274
- )
275
- return collated
276
-
277
- def __getitem__(self, idx):
278
- r"""
279
-         Retrieve the data for the specified index from the cache.
280
-         Required by the PyTorch sampler, and meant to be used together with a DataLoader.
281
-         Depending on the "purpose", different returned objects can be None.
282
-
283
- Returns
284
- (All data points presented in a dictionary, each wrapped into a list,
285
- for easier batching and collating later on)
286
- -------
287
- imgs : tensor (batch_size, image_channels, resized_image, resized_image)
288
- Batch with the rescaled images
289
- tags : The tags of the images. The object is one of:
290
- None : If purpose is not "train"
291
- (batch_size, max_tag_len + 2) : If purpose is "train" and "fixed_padding" is true.
292
- The +2 over the max_tag_len is for the <start> <stop>
293
- (batch_size, batch_max_tag_len) : If purpose is "train" and "fixed_padding" is false,
294
- where "batch_max_tag_len" is the max length of the
295
- tags in batch
296
- tag_lens : The real length of the tags per image in the batch. The object is one of:
297
- None : If purpose is not "train"
298
- (batch_size, 1) : If purpose is "train"
299
- num_cells : The number of cells per image in the batch. The object is one of:
300
- None : If purpose is not "train"
301
- (batch_size, 1) : If purpose is "train"
302
- cells : The cell tags for the images in the batch. The object is one of:
303
- None : If purpose is not "train"
304
- list with LongTensor: If purpose is "train"
305
- cell_lens : The length of the cell tags per image in the batch. The object is one of:
306
- None : If purpose is not "train"
307
- list with LongTensor: If purpose is "train"
308
- cell_bboxes : The transformed (rescaled, padded, etc) bboxes of the cells for all images in
309
- batch. Each list is a bbox in the format [xc, yc, w, h] where xc, yc are the
310
- coords of the center, w, h the width and height of the bbox and all are
311
- normalized to the scaled size of the image. The object is one of:
312
- None : If purpose is not "train"
313
- list of lists: If purpose is "train"
314
- samples : list of string
315
- The filenames in the batch
316
-         val_gt : The ground truth raw attributes for the validation split. The object is one of:
317
- None : If purpose is not "val"
318
- dictionary : If purpose is "val"
319
- """
320
-
321
- # Check if index out of bounds
322
- if idx >= self._dataset_size:
323
- return None
324
-
325
- # Move current _index to requested idx
326
- # self._index = idx
327
-
328
- # Images, __getitem__ provides only 1 image for specified index
329
- # (batch_size, image_channels, resized_image, resized_image)
330
- imgs = torch.zeros(
331
- 1, self._img_ch, self._resized_image, self._resized_image, dtype=torch.float
332
- ).to(self._device)
333
-
334
- # Initialize all output objects to None
335
- # Depending on the "purpose" some of them will be populated
336
- tags = None
337
- tag_lens = None
338
- num_cells = None
339
- cells = None
340
- cell_lens = None
341
- cell_bboxes = None
342
-
343
- val_tags = None
344
- val_tag_lens = None
345
- val_num_cells = None
346
- val_cells = None
347
- val_cell_lens = None
348
- val_cell_bboxes = None
349
-
350
- test_gt = {}
351
- test_tags = None
352
- test_tag_lens = None
353
-
354
- # Train specific output
355
- if self._purpose == s.TRAIN_PURPOSE:
356
- tag_len = self._taglens[idx]
357
- fixed_tag_len = self._max_tag_len + 2 # <start>...<end>
358
- if self._fixed_padding:
359
- tags = torch.zeros(
360
- 1, fixed_tag_len, dtype=torch.long, device=self._device
361
- )
362
- else:
363
- tags = torch.zeros(1, tag_len, dtype=torch.long, device=self._device)
364
-
365
- tag_lens = torch.zeros(1, 1, dtype=torch.long).to(self._device)
366
- num_cells = torch.zeros(1, 1, dtype=torch.long).to(self._device)
367
- cells = []
368
- cell_lens = []
369
- cell_bboxes = []
370
- # val specific output
371
- elif self._purpose == s.VAL_PURPOSE:
372
- val_tag_len = self._val_taglens[idx]
373
- val_fixed_tag_len = self._max_tag_len + 2 # <start>...<end>
374
- if self._fixed_padding:
375
- val_tags = torch.zeros(
376
- 1, val_fixed_tag_len, dtype=torch.long, device=self._device
377
- )
378
- else:
379
- val_tags = torch.zeros(
380
- 1, val_tag_len, dtype=torch.long, device=self._device
381
- )
382
-
383
- val_tag_lens = torch.zeros(1, 1, dtype=torch.long).to(self._device)
384
- val_num_cells = torch.zeros(1, 1, dtype=torch.long).to(self._device)
385
- val_cells = []
386
- val_cell_lens = []
387
- val_cell_bboxes = []
388
-
389
- elif self._purpose == s.TEST_PURPOSE:
390
- if len(self._test_taglens) > 0:
391
- # Dictionary with the raw attributes for the groundtruth. Keys are the filenames
392
- test_gt = {}
393
- tag_len = self._test_taglens[idx]
394
- fixed_tag_len = self._max_tag_len + 2 # <start>...<end>
395
- if self._fixed_padding:
396
- test_tags = torch.zeros(
397
- 1, fixed_tag_len, dtype=torch.long, device=self._device
398
- )
399
- else:
400
- test_tags = torch.zeros(
401
- 1, tag_len, dtype=torch.long, device=self._device
402
- )
403
- test_tag_lens = torch.zeros(1, 1, dtype=torch.long).to(self._device)
404
- cells = []
405
- cell_lens = []
406
- cell_bboxes = []
407
-
408
- sample = self._image_fns[idx]
409
- # Rescale/convert the image and bboxes
410
- bboxes = self._bboxes[sample]
411
- sample_fn = self._get_image_path(sample)
412
-
413
- if not self._table_bboxes:
414
- table_bbox = None
415
- else:
416
- if sample in self._table_bboxes:
417
- table_bbox = self._table_bboxes[sample]
418
- else:
419
- table_bbox = None
420
- scaled_img, scaled_bboxes = self._transformer.sample_preprocessor(
421
- sample_fn, bboxes, self._purpose, table_bbox
422
- )
423
-
424
- imgs[0] = scaled_img.to(self._device)
425
-
426
- # Train specific output
427
- if self._purpose == s.TRAIN_PURPOSE:
428
- # Remove the padding from tags and cells
429
- if self._fixed_padding:
430
- tags[0] = torch.LongTensor(self._tags[idx]).to(self._device)
431
- else:
432
- tags[0] = torch.LongTensor(self._tags[idx][:tag_len]).to(self._device)
433
-
434
- tag_lens[0] = torch.LongTensor([self._taglens[idx]]).to(self._device)
435
- num_cells[0] = len(self._cell_lens[idx])
436
-
437
- if len(self._cell_lens[idx]) > 0:
438
- sample_max_cell_len = max(self._cell_lens[idx])
439
- else:
440
- sample_max_cell_len = 0
441
-
442
- if self._load_cells:
443
- image_trimmed_cells = [
444
- self._cells[idx][x][0:sample_max_cell_len]
445
- for x in range(0, len(self._cells[idx]))
446
- ]
447
- cells.append(torch.LongTensor(image_trimmed_cells).to(self._device))
448
-
449
- cell_lens.append(torch.LongTensor(self._cell_lens[idx]).to(self._device))
450
- if self._train_bbox:
451
-
452
- cell_bboxes.append(
453
- [
454
- torch.from_numpy(
455
- np.array(scaled_bboxes, dtype=np.float32)[:, :4]
456
- ).to(self._device),
457
- torch.from_numpy(
458
- np.array(scaled_bboxes, dtype=np.compat.long)[:, -1]
459
- ).to(self._device),
460
- ]
461
- )
462
-
463
- elif self._purpose == s.VAL_PURPOSE:
464
- # Remove the padding from tags and cells
465
- if self._fixed_padding:
466
- val_tags[0] = torch.LongTensor(self._val_tags[idx]).to(self._device)
467
- else:
468
- val_tags[0] = torch.LongTensor(self._val_tags[idx][:val_tag_len]).to(
469
- self._device
470
- )
471
-
472
- val_tag_lens[0] = torch.LongTensor([self._val_taglens[idx]]).to(
473
- self._device
474
- )
475
- val_num_cells[0] = len(self._val_cell_lens[idx])
476
-
477
- if len(self._val_cell_lens[idx]) > 0:
478
- sample_max_cell_len = max(self._val_cell_lens[idx])
479
- else:
480
- sample_max_cell_len = 0
481
-
482
- if self._load_cells:
483
- val_image_trimmed_cells = [
484
- self._val_cells[idx][x][0:sample_max_cell_len]
485
-                     for x in range(0, len(self._val_cells[idx]))
486
- ]
487
- val_cells.append(
488
- torch.LongTensor(val_image_trimmed_cells).to(self._device)
489
- )
490
-
491
- val_cell_lens.append(
492
- torch.LongTensor(self._val_cell_lens[idx]).to(self._device)
493
- )
494
- if self._train_bbox:
495
- val_cell_bboxes.append(
496
- [
497
- torch.from_numpy(
498
- np.array(scaled_bboxes, dtype=np.float32)[:, :4]
499
- ).to(self._device),
500
- torch.from_numpy(
501
- np.array(scaled_bboxes, dtype=np.compat.long)[:, -1]
502
- ).to(self._device),
503
- ]
504
- )
505
- # val specific output
506
- elif self._purpose == s.TEST_PURPOSE:
507
- if test_gt is not None:
508
- test_gt[sample] = self._test[sample]
509
-
510
- # Remove the padding from tags and cells
511
- if self._fixed_padding:
512
- test_tags[0] = torch.LongTensor(self._test_tags[idx]).to(
513
- self._device
514
- )
515
- else:
516
- test_tags[0] = torch.LongTensor(self._test_tags[idx][:tag_len]).to(
517
- self._device
518
- )
519
- test_tag_lens[0] = torch.LongTensor([self._test_taglens[idx]]).to(
520
- self._device
521
- )
522
- if self._predict_bbox:
523
- cell_bboxes.append(
524
- [
525
- torch.from_numpy(
526
- np.array(scaled_bboxes, dtype=np.float32)[:, :4]
527
- ).to(self._device),
528
- torch.from_numpy(
529
- np.array(scaled_bboxes, dtype=np.compat.long)[:, -1]
530
- ).to(self._device),
531
- ]
532
- )
533
-
534
- output = {}
535
-
536
- # Samples is a list with the given image filename
537
- samples = [self._image_fns[idx]]
538
- # All data points presented in a dictionary, each wrapped into a list,
539
- # for easier batching and collating later on
540
- if self._purpose == s.TRAIN_PURPOSE:
541
- output["imgs"] = imgs
542
- output["tags"] = tags
543
- output["tag_lens"] = tag_lens
544
- output["num_cells"] = num_cells
545
- output["cells"] = cells
546
- output["cell_lens"] = cell_lens
547
- output["cell_bboxes"] = cell_bboxes
548
- output["samples"] = samples
549
- output["test_gt"] = test_gt
550
- elif self._purpose == s.VAL_PURPOSE:
551
- output["imgs"] = imgs
552
- output["tags"] = val_tags
553
- output["tag_lens"] = val_tag_lens
554
- output["num_cells"] = val_num_cells
555
- output["cells"] = val_cells
556
- output["cell_lens"] = val_cell_lens
557
- output["cell_bboxes"] = val_cell_bboxes
558
- output["samples"] = samples
559
- output["test_gt"] = test_gt
560
- elif self._purpose == s.TEST_PURPOSE:
561
- output["imgs"] = imgs
562
- output["tags"] = test_tags
563
- output["tag_lens"] = test_tag_lens
564
- output["num_cells"] = num_cells
565
- output["cells"] = cells
566
- output["cell_lens"] = cell_lens
567
- output["cell_bboxes"] = cell_bboxes
568
- output["samples"] = samples
569
- output["test_gt"] = test_gt
570
- else:
571
- output["imgs"] = imgs
572
- output["tags"] = tags
573
- output["tag_lens"] = tag_lens
574
- output["num_cells"] = num_cells
575
- output["cells"] = cells
576
- output["cell_lens"] = cell_lens
577
- output["cell_bboxes"] = cell_bboxes
578
- output["samples"] = samples
579
- output["test_gt"] = test_gt
580
- return output
581
-
582
- def get_batch_size(self):
583
- r"""
584
- Return the actual batch_size
585
- """
586
- return self._batch_size
587
-
588
- def reset(self):
589
- self._index = 0
590
-
591
- def __iter__(self):
592
- return self
593
-
594
- def is_valid(self, img, config):
595
- max_tag_len = config["preparation"]["max_tag_len"]
596
- max_cell_len = config["preparation"]["max_cell_len"]
597
- check_limits = True
598
- if "check_limits" in config["preparation"]:
599
- check_limits = config["preparation"]["check_limits"]
600
- if check_limits:
601
- if len(img["html"]["structure"]["tokens"]) > max_tag_len:
602
- self._log().debug(
603
- "========================================= TAG LEN REJECTED"
604
- )
605
- self._log().debug("File name: {}".format(img["filename"]))
606
- tokens_len = len(img["html"]["structure"]["tokens"])
607
- self._log().debug("Structure token len: {}".format(tokens_len))
608
- self._log().debug("More than max_tag_len: {}".format(max_tag_len))
609
- self._log().debug(
610
- "=========================================================="
611
- )
612
- return False
613
- for cell in img["html"]["cells"]:
614
- if len(cell["tokens"]) > max_cell_len:
615
- self._log().debug(
616
- "======================================== CELL LEN REJECTED"
617
- )
618
- self._log().debug("File name: {}".format(img["filename"]))
619
- self._log().debug("Cell len: {}".format(len(cell["tokens"])))
620
- self._log().debug("More than max_cell_len: {}".format(max_cell_len))
621
- self._log().debug(
622
- "=========================================================="
623
- )
624
- return False
625
- self.raw_data_dir = config["preparation"]["raw_data_dir"]
626
- with Image.open(
627
- os.path.join(self.raw_data_dir, img["split"], img["filename"])
628
- ) as im:
629
- max_image_size = config["preparation"]["max_image_size"]
630
- if im.width > max_image_size or im.height > max_image_size:
631
- # IMG SIZE REJECTED
632
- return False
633
- return True
634
-
635
- def __next__(self):
636
- r"""
637
- Get the next batch or raise the StopIteration
638
-
639
-         To keep the batch size fixed across iterations, a last incomplete batch is skipped
640
-         rather than padded with repeated elements.
641
-
642
- Depending on the "purpose" different returned objects can be None.
643
-
644
- Returns
645
- -------
646
- imgs : tensor (batch_size, image_channels, resized_image, resized_image)
647
- Batch with the rescaled images
648
- tags : The tags of the images. The object is one of:
649
- None : If purpose is not "train"
650
- (batch_size, max_tag_len + 2) : If purpose is "train" and "fixed_padding" is true.
651
- The +2 over the max_tag_len is for the <start> <stop>
652
- (batch_size, batch_max_tag_len) : If purpose is "train" and "fixed_padding" is false,
653
- where "batch_max_tag_len" is the max length of the
654
- tags in batch
655
- tag_lens : The real length of the tags per image in the batch. The object is one of:
656
- None : If purpose is not "train"
657
- (batch_size, 1) : If purpose is "train"
658
- num_cells : The number of cells per image in the batch. The object is one of:
659
- None : If purpose is not "train"
660
- (batch_size, 1) : If purpose is "train"
661
- cells : The cell tags for the images in the batch. The object is one of:
662
- None : If purpose is not "train"
663
- list with LongTensor: If purpose is "train"
664
- cell_lens : The length of the cell tags per image in the batch. The object is one of:
665
- None : If purpose is not "train"
666
- list with LongTensor: If purpose is "train"
667
- cell_bboxes : The transformed (rescaled, padded, etc) bboxes of the cells for all images in
668
- batch. Each list is a bbox in the format [xc, yc, w, h] where xc, yc are the
669
- coords of the center, w, h the width and height of the bbox and all are
670
- normalized to the scaled size of the image. The object is one of:
671
- None : If purpose is not "train"
672
- list of lists: If purpose is "train"
673
- samples : list of string
674
- The filenames in the batch
675
-         val_gt : The ground truth raw attributes for the validation split. The object is one of:
676
- None : If purpose is not "val"
677
- dictionary : If purpose is "val"
678
- """
679
-
680
- if self._index >= self._dataset_size:
681
- raise StopIteration()
682
-
683
- # Compute the next sample
684
- if (
685
- self._dataset_size - self._index >= self._batch_size
686
- ): # Full batch_size sample
687
- step = self._batch_size
688
- sample_indices = self._ind[self._index : self._index + step]
689
- else:
690
- # skip last batch
691
- raise StopIteration()
692
-
693
- batch = []
694
- # Loop over the batch indices and collect items for the batch
695
- for i, idx in enumerate(sample_indices):
696
- item = self.__getitem__(idx)
697
- batch.append(item)
698
- # Collate batch
699
- output = self.bcol(batch)
700
- self._index += step
701
-
702
- return output
703
-
704
- def shuffle(self):
705
- r"""
706
- Shuffle the training images
707
-         This takes place only in the case of training; otherwise it returns False without shuffling
708
-
709
- Output: True in case a shuffling took place, False otherwise
710
- """
711
- if self._purpose != s.TRAIN_PURPOSE:
712
- return False
713
-
714
- # image_fns_np = np.asarray(self._image_fns)
715
- # To get a deterministic random shuffle, we need to seed our random
716
- # with a deterministic seed (int)
717
- np.random.seed(42)
718
- # Then shuffle after seeding
719
- self._ind = np.random.permutation(self._dataset_size)
720
- self._index = 0
721
- return True
722
-
723
- def get_init_data(self):
724
- r"""
725
-         Create a dictionary with all kinds of initialization data necessary for the models.
726
-         This data is not served by the __next__ method.
727
- """
728
- init_data = {"word_map": self._word_map, "statistics": self._statistics}
729
- return init_data
730
-
731
- def _get_image_path(self, img_fn):
732
- r"""
733
- Get the full image path out of the image file name
734
- """
735
- if self._dataset_type == "TF_prepared":
736
- if self._purpose == s.TRAIN_PURPOSE:
737
- full_fn = os.path.join(self._raw_data_dir, "train", img_fn)
738
- elif self._purpose == s.VAL_PURPOSE:
739
- full_fn = os.path.join(self._raw_data_dir, "val", img_fn)
740
- elif self._purpose == s.TEST_PURPOSE:
741
- full_fn = os.path.join(self._raw_data_dir, "test", img_fn)
742
- else:
743
- full_fn = os.path.join(self._raw_data_dir, img_fn)
744
-
745
- if full_fn is None or not os.path.isfile(full_fn):
746
- self._log().error("File not found: " + full_fn)
747
- return None
748
-
749
- return full_fn
750
-
751
- def format_html(self, img):
752
- r"""
753
- Formats HTML code from tokenized annotation of img
754
- """
755
- tag_len = len(img["html"]["structure"]["tokens"])
756
- if self._load_cells:
757
- cell_len_max = max([len(c["tokens"]) for c in img["html"]["cells"]])
758
- else:
759
- cell_len_max = 0
760
-
761
- HTML = img["html"]["structure"]["tokens"].copy()
762
- to_insert = [i for i, tag in enumerate(HTML) if tag in ("<td>", ">")]
763
-
764
- if self._load_cells:
765
- for i, cell in zip(to_insert[::-1], img["html"]["cells"][::-1]):
766
- if cell:
767
- cell = "".join(
768
- [
769
- escape(token) if len(token) == 1 else token
770
- for token in cell["tokens"]
771
- ]
772
- )
773
- HTML.insert(i + 1, cell)
774
-
775
- HTML = "<html><body><table>%s</table></body></html>" % "".join(HTML)
776
- root = html.fromstring(HTML)
777
- if self._predict_bbox:
778
- for td, cell in zip(root.iter("td"), img["html"]["cells"]):
779
- if "bbox" in cell:
780
- bbox = cell["bbox"]
781
- td.attrib["x"] = str(bbox[0])
782
- td.attrib["y"] = str(bbox[1])
783
- td.attrib["width"] = str(bbox[2] - bbox[0])
784
- td.attrib["height"] = str(bbox[3] - bbox[1])
785
- HTML = html.tostring(root, encoding="utf-8").decode()
786
- return HTML, tag_len, cell_len_max
787
-
788
- def _build_predict_cache(self):
789
- r"""
790
-         Populate the cache with the image file names on which prediction will be run
791
- """
792
- self._prepared_data_dir = c.safe_get_parameter(
793
- self._config, ["dataset", "prepared_data_dir"], required=False
794
- )
795
- self._data_name = c.safe_get_parameter(
796
- self._config, ["dataset", "name"], required=True
797
- )
798
-
799
- if self._prepared_data_dir is None:
800
-
801
- self._statistics = c.safe_get_parameter(
802
- self._config, ["dataset", "image_normalization"], required=True
803
- )
804
-
805
- self._word_map = c.safe_get_parameter(
806
- self._config, ["dataset_wordmap"], required=True
807
- )
808
-
809
- else:
810
- # Load statistics
811
- statistics_fn = c.get_prepared_data_filename("STATISTICS", self._data_name)
812
- with open(os.path.join(self._prepared_data_dir, statistics_fn), "r") as f:
813
- self._log().debug("Load statistics from: {}".format(statistics_fn))
814
- self._statistics = json.load(f)
815
-
816
- # Load word_map
817
- word_map_fn = c.get_prepared_data_filename("WORDMAP", self._data_name)
818
- with open(os.path.join(self._prepared_data_dir, word_map_fn), "r") as f:
819
- self._log().debug("Load WORDMAP from: {}".format(word_map_fn))
820
- self._word_map = json.load(f)
821
-
822
- # Get Image File Names for Prediction
823
- self._image_fns = []
824
- self._bboxes = {}
825
- self._bboxes_table = {}
826
- self._raw_data_dir = self._predict_dir
827
- self._table_bboxes = {}
828
-
829
- if self._predict_dir[-1] != "/":
830
- self._predict_dir += "/"
831
- for path in list(glob(self._predict_dir + "**/*.png", recursive=True)):
832
- filename = os.path.basename(path)
833
- self._image_fns.append(filename)
834
- self._log().info("Image found: {}".format(filename))
835
- self._bboxes[filename] = []
836
-
837
- # Get size of a dataset to predict
838
- self._dataset_size = len(self._image_fns)
839
-
840
- # Get the number of image channels
841
- self._log().info(
842
- "To test load... {}".format(self._predict_dir + self._image_fns[0])
843
- )
844
- img = u.load_image(self._predict_dir + self._image_fns[0])
845
- if img is None:
846
- msg = "Cannot load image"
847
- self._log().error(msg)
848
- raise Exception(msg)
849
- self._img_ch = img.shape[0]
850
-
851
- def _build_cache(self):
852
- r"""
853
-         Build the in-memory cache: image filenames per split, bboxes, encoded tags/cells and ground truth
854
- """
855
- all_bboxes = {} # Keep original bboxes for all images
856
- table_bboxes = {}
857
- self._log().info("Building the cache...")
858
- self._raw_data_dir = c.safe_get_parameter(
859
- self._config, ["dataset", "raw_data_dir"], required=True
860
- )
861
- self._prepared_data_dir = c.safe_get_parameter(
862
- self._config, ["dataset", "prepared_data_dir"], required=False
863
- )
864
-
865
- self._data_name = c.safe_get_parameter(
866
- self._config, ["dataset", "name"], required=True
867
- )
868
-
869
- if self._prepared_data_dir is None:
870
-
871
- self._statistics = c.safe_get_parameter(
872
- self._config, ["dataset", "image_normalization"], required=True
873
- )
874
-
875
- self._word_map = c.safe_get_parameter(
876
- self._config, ["dataset_wordmap"], required=True
877
- )
878
-
879
- else:
880
- # Load statistics
881
- statistics_fn = c.get_prepared_data_filename("STATISTICS", self._data_name)
882
- with open(os.path.join(self._prepared_data_dir, statistics_fn), "r") as f:
883
- self._log().debug("Load statistics from: {}".format(statistics_fn))
884
- self._statistics = json.load(f)
885
-
886
- # Load word_map
887
- word_map_fn = c.get_prepared_data_filename("WORDMAP", self._data_name)
888
- with open(os.path.join(self._prepared_data_dir, word_map_fn), "r") as f:
889
- self._log().debug("Load WORDMAP from: {}".format(word_map_fn))
890
- self._word_map = json.load(f)
891
-
892
- word_map_cell = self._word_map["word_map_cell"]
893
- word_map_tag = self._word_map["word_map_tag"]
894
- # Read image paths and captions for each image
895
- train_image_paths = []
896
- train_images = []
897
- val_image_paths = []
898
- val_images = []
899
- test_image_paths = []
900
- test_images = []
901
- predict_images = []
902
-
903
- train_image_tags = (
904
- []
905
- ) # List of list of structure tokens for each image in the train set
906
- train_image_cells = []
907
- train_image_cells_len = []
908
-
909
- val_image_tags = (
910
- []
911
-         ) # List of list of structure tokens for each image in the val set
912
- val_image_cells = []
913
- val_image_cells_len = []
914
-
915
- test_image_tags = []
916
- test_gt = dict()
917
-
918
- invalid_files = 0
919
- total_files = 0
920
-
921
- self._log().debug(
922
- "Create lists with image filenames per split, train tags/cells and GT"
923
- )
924
- with jsonlines.open(self.annotation, "r") as reader:
925
- for img in tqdm(reader):
926
- total_files += 1
927
- img_filename = img["filename"]
928
- path = os.path.join(self._raw_data_dir, img["split"], img_filename)
929
-
930
- # Keep bboxes for all images
931
- all_cell_bboxes = []
932
- for cell in img["html"]["cells"]:
933
- if "bbox" not in cell:
934
- continue
935
- all_cell_bboxes.append(cell["bbox"])
936
- all_bboxes[img_filename] = all_cell_bboxes
937
-
938
-                 # If the dataset includes a bbox for the table itself
939
- if "table_bbox" in img:
940
- table_bboxes[img_filename] = img["table_bbox"]
941
- if img["split"] == "train":
942
- if self._purpose == s.TRAIN_PURPOSE:
943
- # Skip invalid images
944
- if not self.is_valid(img, self._config):
945
- invalid_files += 1
946
- continue
947
- tags = []
948
- cells = []
949
- cell_lens = []
950
- tags.append(img["html"]["structure"]["tokens"])
951
-
952
- if self._load_cells:
953
- for cell in img["html"]["cells"]:
954
- cells.append(cell["tokens"])
955
- cell_lens.append(len(cell["tokens"]) + 2)
956
- else:
957
- for cell in img["html"]["cells"]:
958
- cell_lens.append(len(cell["tokens"]) + 2)
959
-
960
- train_images.append(img_filename)
961
- train_image_paths.append(path)
962
- train_image_tags.append(tags)
963
- train_image_cells.append(cells)
964
- train_image_cells_len.append(cell_lens)
965
- if img["split"] == "val":
966
- if self._purpose == s.VAL_PURPOSE:
967
- # Skip invalid images
968
- if not self.is_valid(img, self._config):
969
- invalid_files += 1
970
- continue
971
-
972
- val_tags = []
973
- val_cells = []
974
- val_cell_lens = []
975
- val_tags.append(img["html"]["structure"]["tokens"])
976
-
977
- if self._load_cells:
978
- for cell in img["html"]["cells"]:
979
- val_cells.append(cell["tokens"])
980
- val_cell_lens.append(len(cell["tokens"]) + 2)
981
- else:
982
- for cell in img["html"]["cells"]:
983
- val_cell_lens.append(len(cell["tokens"]) + 2)
984
-
985
- with Image.open(path) as im:
986
- HTML, tag_len, cell_len_max = self.format_html(img)
987
- lt1 = [">", "lcel", "ucel", "xcel"]
988
- lt2 = img["html"]["structure"]["tokens"]
989
- tcheck = any(item in lt1 for item in lt2)
990
- if tcheck:
991
- gtt = "complex"
992
- else:
993
- gtt = "simple"
994
- test_gt[img_filename] = {
995
- "html": HTML,
996
- "tag_len": tag_len,
997
- "cell_len_max": cell_len_max,
998
- "width": im.width,
999
- "height": im.height,
1000
- "type": gtt,
1001
- "html_tags": img["html"]["structure"]["tokens"],
1002
- }
1003
-
1004
- val_images.append(img_filename)
1005
- val_image_paths.append(path)
1006
- val_image_tags.append(val_tags)
1007
- val_image_cells.append(val_cells)
1008
- val_image_cells_len.append(val_cell_lens)
1009
-
1010
- elif img["split"] == "test":
1011
- if self._purpose == s.TEST_PURPOSE:
1012
- # Skip invalid images
1013
- if not self.is_valid(img, self._config):
1014
- invalid_files += 1
1015
- continue
1016
-
1017
- with Image.open(path) as im:
1018
- HTML, tag_len, cell_len_max = self.format_html(img)
1019
- lt1 = [">", "lcel", "ucel", "xcel"]
1020
- lt2 = img["html"]["structure"]["tokens"]
1021
- tcheck = any(item in lt1 for item in lt2)
1022
- if tcheck:
1023
- gtt = "complex"
1024
- else:
1025
- gtt = "simple"
1026
- test_gt[img_filename] = {
1027
- "html": HTML,
1028
- "tag_len": tag_len,
1029
- "cell_len_max": cell_len_max,
1030
- "width": im.width,
1031
- "height": im.height,
1032
- "type": gtt,
1033
- "html_tags": img["html"]["structure"]["tokens"],
1034
- }
1035
- test_images.append(img_filename)
1036
-
1037
- test_tags = []
1038
- test_tags.append(img["html"]["structure"]["tokens"])
1039
- test_image_paths.append(path)
1040
- test_image_tags.append(test_tags)
1041
- else:
1042
- if self._purpose == s.PREDICT_PURPOSE:
1043
- predict_images.append(img_filename)
1044
-
1045
- image_fns = {
1046
- s.TRAIN_SPLIT: train_images,
1047
- s.VAL_SPLIT: val_images,
1048
- s.TEST_SPLIT: test_images,
1049
- }
1050
-
1051
- self._log().debug("Keep the split data pointed by the purpose")
1052
- # Images
1053
-         # Keep only the images for the particular split
1054
- self._image_fns = image_fns[self._purpose]
1055
- self._dataset_size = len(self._image_fns)
1056
- assert len(self._image_fns) > 0, "Empty image split: " + self._purpose
1057
-
1058
- # Get the number of image channels
1059
- img = u.load_image(self._get_image_path(self._image_fns[0]))
1060
- if img is None:
1061
- msg = "Cannot load image"
1062
- self._log().error(msg)
1063
- raise Exception(msg)
1064
- self._img_ch = img.shape[0]
1065
-
1066
- # img_name -> list of bboxes, each bbox is a list with x1y1x2y2
1067
- split_bboxes = {}
1068
- img_names = set(self._image_fns) # Set will speed up search
1069
- for img_name, bbox in all_bboxes.items():
1070
- if img_name not in img_names:
1071
- continue
1072
- if img_name not in split_bboxes:
1073
- split_bboxes[img_name] = []
1074
-             # Use extend, not append; otherwise we would get a list within a list
1075
- split_bboxes[img_name].extend(bbox)
1076
- self._bboxes = split_bboxes
1077
- self._table_bboxes = table_bboxes
1078
- # -------------------------------------------------------------------------------
1079
- # Train specific
1080
- # -------------------------------------------------------------------------------
1081
- # Compute encoded tags and cells
1082
- enc_tags = []
1083
- tag_lens = []
1084
- enc_cells = []
1085
- cell_lens = []
1086
-
1087
- val_enc_tags = []
1088
- val_tag_lens = []
1089
- val_enc_cells = []
1090
- val_cell_lens = []
1091
-
1092
- test_enc_tags = []
1093
- test_tag_lens = []
1094
-
1095
- # Based on the "purpose"
1096
- if self._purpose == s.TRAIN_PURPOSE:
1097
- self._log().debug("Convert train tags and cell tags to indices")
1098
- for i, path in enumerate(tqdm(train_image_paths)):
1099
- for tag in train_image_tags[i]:
1100
- # Encode tags
1101
-                     # Notice that at this point there are no tag sequences longer than max_tag_len
1102
- # The same happens with the cell tokens
1103
- enc_tag = (
1104
- [word_map_tag["<start>"]]
1105
- + [
1106
- word_map_tag.get(word, word_map_tag["<unk>"])
1107
- for word in tag
1108
- ]
1109
- + [word_map_tag["<end>"]]
1110
- + [word_map_tag["<pad>"]] * (self._max_tag_len - len(tag))
1111
- )
1112
- # Find caption lengths
1113
- tag_len = len(tag) + 2
1114
-
1115
- enc_tags.append(enc_tag)
1116
- tag_lens.append(tag_len)
1117
-
1118
- enc_cell_seq = []
1119
- cell_seq_len = []
1120
-
1121
- if self._load_cells:
1122
- for cell in train_image_cells[i]:
1123
- # Encode captions
1124
- enc_cell = (
1125
- [word_map_cell["<start>"]]
1126
- + [
1127
- word_map_cell.get(word, word_map_cell["<unk>"])
1128
- for word in cell
1129
- ]
1130
- + [word_map_cell["<end>"]]
1131
- + [word_map_cell["<pad>"]]
1132
- * (self._max_cell_len - len(cell))
1133
- )
1134
- enc_cell_seq.append(enc_cell)
1135
- # Find caption lengths
1136
- cell_len = len(cell) + 2
1137
- cell_seq_len.append(cell_len)
1138
- else:
1139
- for cell in train_image_cells_len[i]:
1140
- cell_seq_len.append(cell)
1141
- enc_cells.append(enc_cell_seq)
1142
- cell_lens.append(cell_seq_len)
1143
-
1144
- if self._purpose == s.VAL_PURPOSE:
1145
-             self._log().debug("Convert val tags and cell tags to indices")
1146
- for i, path in enumerate(tqdm(val_image_paths)):
1147
- for tag in val_image_tags[i]:
1148
- # Encode tags
1149
-                     # Notice that at this point there are no tag sequences longer than max_tag_len
1150
- # The same happens with the cell tokens
1151
- val_enc_tag = (
1152
- [word_map_tag["<start>"]]
1153
- + [
1154
- word_map_tag.get(word, word_map_tag["<unk>"])
1155
- for word in tag
1156
- ]
1157
- + [word_map_tag["<end>"]]
1158
- + [word_map_tag["<pad>"]] * (self._max_tag_len - len(tag))
1159
- )
1160
- # Find caption lengths
1161
- val_tag_len = len(tag) + 2
1162
-
1163
- val_enc_tags.append(val_enc_tag)
1164
- val_tag_lens.append(val_tag_len)
1165
-
1166
- val_enc_cell_seq = []
1167
- val_cell_seq_len = []
1168
-
1169
- if self._load_cells:
1170
- for cell in val_image_cells[i]:
1171
- # Encode captions
1172
- val_enc_cell = (
1173
- [word_map_cell["<start>"]]
1174
- + [
1175
- word_map_cell.get(word, word_map_cell["<unk>"])
1176
- for word in cell
1177
- ]
1178
- + [word_map_cell["<end>"]]
1179
- + [word_map_cell["<pad>"]]
1180
- * (self._max_cell_len - len(cell))
1181
- )
1182
- val_enc_cell_seq.append(val_enc_cell)
1183
-
1184
- # Find caption lengths
1185
- cell_len = len(cell) + 2
1186
- val_cell_seq_len.append(cell_len)
1187
- else:
1188
- for cell in val_image_cells_len[i]:
1189
- val_cell_seq_len.append(cell)
1190
- val_enc_cells.append(val_enc_cell_seq)
1191
- val_cell_lens.append(val_cell_seq_len)
1192
-
1193
- if self._purpose == s.TEST_PURPOSE:
1194
-             self._log().debug("Convert test tags to indices")
1195
- for i, path in enumerate(tqdm(test_image_paths)):
1196
- for tag in test_image_tags[i]:
1197
- # Encode tags
1198
-                     # Notice that at this point there are no tag sequences longer than max_tag_len
1199
- # The same happens with the cell tokens
1200
- test_enc_tag = (
1201
- [word_map_tag["<start>"]]
1202
- + [
1203
- word_map_tag.get(word, word_map_tag["<unk>"])
1204
- for word in tag
1205
- ]
1206
- + [word_map_tag["<end>"]]
1207
- + [word_map_tag["<pad>"]] * (self._max_tag_len - len(tag))
1208
- )
1209
-
1210
- # Find caption lengths
1211
- test_tag_len = len(tag) + 2
1212
-
1213
- test_enc_tags.append(test_enc_tag)
1214
- test_tag_lens.append(test_tag_len)
1215
-
1216
- self._tags = enc_tags
1217
- self._taglens = tag_lens
1218
- self._cells = enc_cells
1219
- self._cell_lens = cell_lens
1220
-
1221
- # -------------------------------------------------------------------------------
1222
- # val specific
1223
- # -------------------------------------------------------------------------------
1224
- self._val_tags = val_enc_tags
1225
- self._val_taglens = val_tag_lens
1226
- self._val_cells = val_enc_cells
1227
- self._val_cell_lens = val_cell_lens
1228
- # -------------------------------------------------------------------------------
1229
- # test / evaluation specific
1230
- # -------------------------------------------------------------------------------
1231
- self._test = test_gt
1232
- self._test_tags = test_enc_tags
1233
- self._test_taglens = test_tag_lens
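
For context, the removed TFDataset is a standard PyTorch Dataset with a custom collate function (bcol). The sketch below shows how it could be wired into a torch DataLoader. It is an illustration only: the import path is assumed (the hunk header above does not show the file path), the config keys mirror the safe_get_parameter lookups in __init__ and _build_cache, and all paths and numeric values are hypothetical.

# Minimal usage sketch (not part of the diff). Import path and config values are assumptions.
import torch
from torch.utils.data import DataLoader

import docling_ibm_models.tableformer.settings as s
from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset  # assumed module name

config = {
    "dataset": {
        "type": "TF_prepared",        # must be one of s.supported_datasets
        "name": "my_dataset",         # hypothetical dataset name
        "raw_data_dir": "/data/raw",  # hypothetical paths
        "prepared_data_dir": None,    # None -> statistics and word map are read from this config
        "resized_image": 448,         # illustrative value
        "image_normalization": {},    # statistics consumed by DataTransformer (structure not shown in this diff)
        "load_cells": False,
    },
    # Word maps must provide at least the special tokens used by _build_cache
    "dataset_wordmap": {
        "word_map_tag": {"<start>": 0, "<end>": 1, "<pad>": 2, "<unk>": 3},
        "word_map_cell": {"<start>": 0, "<end>": 1, "<pad>": 2, "<unk>": 3},
    },
    "preparation": {
        "max_tag_len": 500,           # illustrative limits
        "max_cell_len": 100,
        "max_image_size": 1024,
        "annotation": "/data/annotations.jsonl",
        "raw_data_dir": "/data/raw",
    },
    "train": {"batch_size": 8, "bbox": True},
    "predict": {"predict_dir": "/data/predict", "bbox": True},
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = TFDataset(config, s.TRAIN_PURPOSE, fixed_padding=False)
dataset.set_device(device)  # __getitem__ and bcol place tensors on this device

# bcol pads the tag tensors of each batch to the longest sequence and returns the
# 9-tuple documented in __getitem__ / __next__.
loader = DataLoader(dataset, batch_size=dataset.get_batch_size(), collate_fn=dataset.bcol)
for imgs, tags, tag_lens, num_cells, cells, cell_lens, cell_bboxes, samples, test_gt in loader:
    pass  # feed the batch to the training loop

Alternatively, the class iterates on its own: shuffle() reshuffles the training indices with a fixed seed (42) and __next__ yields collated batches, skipping the last incomplete one.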