docling-ibm-models 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
  2. docling_ibm_models/tableformer/__init__.py +0 -0
  3. docling_ibm_models/tableformer/common.py +200 -0
  4. docling_ibm_models/tableformer/data_management/__init__.py +0 -0
  5. docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
  6. docling_ibm_models/tableformer/data_management/functional.py +574 -0
  7. docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
  8. docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
  9. docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
  10. docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
  11. docling_ibm_models/tableformer/data_management/transforms.py +396 -0
  12. docling_ibm_models/tableformer/models/__init__.py +0 -0
  13. docling_ibm_models/tableformer/models/common/__init__.py +0 -0
  14. docling_ibm_models/tableformer/models/common/base_model.py +279 -0
  15. docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
  16. docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
  17. docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
  18. docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
  19. docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
  20. docling_ibm_models/tableformer/otsl.py +541 -0
  21. docling_ibm_models/tableformer/settings.py +90 -0
  22. docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
  23. docling_ibm_models/tableformer/test_prepare_image.py +99 -0
  24. docling_ibm_models/tableformer/utils/__init__.py +0 -0
  25. docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
  26. docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
  27. docling_ibm_models/tableformer/utils/utils.py +376 -0
  28. docling_ibm_models/tableformer/utils/variance.py +175 -0
  29. docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
  30. docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
  31. docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
  32. docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,1020 @@
+ #
+ # Copyright IBM Corp. 2024 - 2024
+ # SPDX-License-Identifier: MIT
+ #
+ import json
+ import logging
+ import os
+ from itertools import groupby
+
+ import cv2
+ import numpy as np
+ import torch
+
+ import docling_ibm_models.tableformer.common as c
+ import docling_ibm_models.tableformer.data_management.functional as F
+ import docling_ibm_models.tableformer.data_management.transforms as T
+ import docling_ibm_models.tableformer.settings as s
+ import docling_ibm_models.tableformer.utils.utils as u
+ from docling_ibm_models.tableformer.data_management.matching_post_processor import (
+     MatchingPostProcessor,
+ )
+ from docling_ibm_models.tableformer.data_management.tf_cell_matcher import CellMatcher
+ from docling_ibm_models.tableformer.models.common.base_model import BaseModel
+ from docling_ibm_models.tableformer.otsl import otsl_to_html
+ from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler
+
+ # LOG_LEVEL = logging.INFO
+ # LOG_LEVEL = logging.DEBUG
+ LOG_LEVEL = logging.WARN
+
+
+ class bcolors:
+     HEADER = "\033[95m"
+     OKBLUE = "\033[94m"
+     OKCYAN = "\033[96m"
+     OKGREEN = "\033[92m"
+     WARNING = "\033[93m"
+     FAIL = "\033[91m"
+     ENDC = "\033[0m"
+     BOLD = "\033[1m"
+     UNDERLINE = "\033[4m"
+
+
+ def otsl_sqr_chk(rs_list, logdebug):
+     rs_list_split = [
+         list(group) for k, group in groupby(rs_list, lambda x: x == "nl") if not k
+     ]
+     isSquare = True
+     if len(rs_list_split) > 0:
+         init_tag_len = len(rs_list_split[0]) + 1
+
+         totcelnum = rs_list.count("fcel") + rs_list.count("ecel")
+         if logdebug:
+             print("Total number of cells = {}".format(totcelnum))
+
+         for ind, ln in enumerate(rs_list_split):
+             ln.append("nl")
+             if logdebug:
+                 print("{}".format(ln))
+             if len(ln) != init_tag_len:
+                 isSquare = False
+         if isSquare:
+             if logdebug:
+                 print(
+                     "{}*OK* Table is square! *OK*{}".format(
+                         bcolors.OKGREEN, bcolors.ENDC
+                     )
+                 )
+         else:
+             if logdebug:
+                 err_name = "{}***** ERR ******{}"
+                 print(err_name.format(bcolors.FAIL, bcolors.ENDC))
+                 print(
+                     "{}*ERR* Table is not square! *ERR*{}".format(
+                         bcolors.FAIL, bcolors.ENDC
+                     )
+                 )
+     return isSquare
+
+
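Note: otsl_sqr_chk treats an OTSL sequence as "square" when every "nl"-terminated row carries the same number of tags. A tiny hypothetical illustration:

    # Two rows of two cells each -> square
    assert otsl_sqr_chk(["fcel", "ecel", "nl", "fcel", "fcel", "nl"], False)
    # Second row has an extra cell -> not square
    assert not otsl_sqr_chk(["fcel", "ecel", "nl", "fcel", "fcel", "fcel", "nl"], False)
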
+ def decide_device(config: dict) -> str:
+     r"""
+     Decide the inference device based on the "predict.device_mode" parameter
+     """
+     device_mode = config["predict"].get("device_mode", "cpu")
+     num_gpus = torch.cuda.device_count()
+
+     if device_mode == "auto":
+         device = "cuda:0" if num_gpus > 0 else "cpu"
+     elif device_mode in ["gpu", "cuda"]:
+         device = "cuda:0"
+     else:
+         device = "cpu"
+     return device
+
+
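For reference, how the three device_mode values resolve (a sketch using only what the function above reads):

    cfg = {"predict": {"device_mode": "auto"}}
    # "auto"         -> "cuda:0" if torch.cuda.device_count() > 0, else "cpu"
    # "gpu" / "cuda" -> "cuda:0" unconditionally
    # anything else  -> "cpu"
    device = decide_device(cfg)
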
+ class TFPredictor:
+     r"""
+     Table predictions for the in-memory Docling API
+     """
+
+     def __init__(self, config):
+         r"""
+         Parameters
+         ----------
+         config : dict
+             Parameters configuration
+         Raises
+         ------
+         ValueError
+             When the model cannot be found
+         """
+         self._device = decide_device(config)
+         self._log().info("Running on device: {}".format(self._device))
+
+         self._config = config
+         self.enable_post_process = True
+
+         self._padding = config["predict"].get("padding", False)
+         self._padding_size = config["predict"].get("padding_size", 10)
+
+         self._cell_matcher = CellMatcher(config)
+         self._post_processor = MatchingPostProcessor(config)
+
+         self._init_word_map()
+         # Load the model
+         self._model = self._load_model()
+         self._model.eval()
+         self._prof = config["predict"].get("profiling", False)
+         self._profiling_agg_window = config["predict"].get("profiling_agg_window", None)
+         if self._profiling_agg_window is not None:
+             AggProfiler(self._profiling_agg_window)
+         else:
+             AggProfiler()
+
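The constructor and the methods below read only a handful of config keys. A minimal, hypothetical config sketch (values are illustrative placeholders, not the package's shipped defaults):

    config = {
        "model": {"type": "TableModel04_rs"},
        "predict": {
            "device_mode": "auto",
            "padding": False,
            "padding_size": 10,
            "profiling": False,
            "max_steps": 1024,  # used by predict() / predict_dummy()
            "beam_size": 5,
            "bbox": True,
        },
        "dataset": {
            "image_normalization": {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]},
            "resized_image": 448,
        },
        "dataset_wordmap": {
            "word_map_tag": {"<start>": 0, "<end>": 1, "fcel": 2, "ecel": 3, "nl": 4}
        },
    }
    predictor = TFPredictor(config)
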
+     def _init_word_map(self):
+         self._prepared_data_dir = c.safe_get_parameter(
+             self._config, ["dataset", "prepared_data_dir"], required=False
+         )
+
+         if self._prepared_data_dir is None:
+             self._word_map = c.safe_get_parameter(
+                 self._config, ["dataset_wordmap"], required=True
+             )
+         else:
+             data_name = c.safe_get_parameter(
+                 self._config, ["dataset", "name"], required=True
+             )
+             word_map_fn = c.get_prepared_data_filename("WORDMAP", data_name)
+
+             # Load word_map
+             with open(os.path.join(self._prepared_data_dir, word_map_fn), "r") as f:
+                 self._log().debug("Load WORDMAP from: {}".format(word_map_fn))
+                 self._word_map = json.load(f)
+
+         self._init_data = {"word_map": self._word_map}
+         # Prepare a reversed index for the word map
+         self._rev_word_map = {v: k for k, v in self._word_map["word_map_tag"].items()}
+
+     def get_init_data(self):
+         r"""
+         Return the initialization data
+         """
+         return self._init_data
+
+     def get_model(self):
+         r"""
+         Return the loaded model
+         """
+         return self._model
+
+     def _load_model(self):
+         r"""
+         Load the proper model
+         """
+
+         self._model_type = self._config["model"]["type"]
+         # Added import here to avoid loading turbotransformer library unnecessarily
+         if self._model_type == "TableModel04_rs":
+             from docling_ibm_models.tableformer.models.table04_rs.tablemodel04_rs import (  # noqa: F401
+                 TableModel04_rs,
+             )
+         model = None  # stays None if no registered subclass matches the name
+         for candidate in BaseModel.__subclasses__():
+             if candidate.__name__ == self._model_type:
+                 model = candidate(
+                     self._config, self._init_data, s.PREDICT_PURPOSE, self._device
+                 )
+
+         if model is None:
+             err_msg = "Unable to instantiate a model for {}".format(self._model_type)
+             self._log().error(err_msg)
+             raise ValueError(err_msg)
+
+         self._remove_padding = False
+         if self._model_type == "TableModel02":
+             self._remove_padding = True
+
+         # Load model from checkpoint
+         success, _, _, _, _ = model.load()
+         if not success:
+             err_msg = "Cannot load the model"
+             self._log().error(err_msg)
+             raise ValueError(err_msg)
+
+         return model
+
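_load_model resolves the class by name through BaseModel.__subclasses__(), so importing a model module is what registers it. The same lookup pattern in isolation (DemoModel is a hypothetical stand-in):

    class Base:
        pass

    class DemoModel(Base):  # defined (and thus registered) when its module is imported
        pass

    model_cls = next((c for c in Base.__subclasses__() if c.__name__ == "DemoModel"), None)
    assert model_cls is DemoModel
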
+     def get_device(self):
+         return self._device
+
+     def get_model_type(self):
+         return self._model_type
+
+     def _log(self):
+         # Setup a custom logger
+         return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
+
+     def _deletebbox(self, listofbboxes, indexes):
+         newlist = []
+         for i, bbox in enumerate(listofbboxes):
+             if i not in indexes:
+                 newlist.append(bbox)
+         return newlist
+
+     def _remove_bbox_span_desync(self, prediction):
+         # Delete 1 extra bbox after span tag
+         index_to_delete_from = 0
+         indexes_to_delete = []
+         newbboxes = []
+         for html_elem in prediction["html_seq"]:
+             if html_elem == "<td>":
+                 index_to_delete_from += 1
+             if html_elem == ">":
+                 index_to_delete_from += 1
+                 # remove element from bboxes
+                 self._log().debug(
+                     "========= DELETE BBOX INDEX: {}".format(index_to_delete_from)
+                 )
+                 indexes_to_delete.append(index_to_delete_from)
+
+         newbboxes = self._deletebbox(prediction["bboxes"], indexes_to_delete)
+         return newbboxes
+
+     def _check_bbox_sync(self, prediction):
+         bboxes = []
+         match = False
+         # count bboxes
+         count_bbox = len(prediction["bboxes"])
+         # count td tags
+         count_td = 0
+         for html_elem in prediction["html_seq"]:
+             if html_elem == "<td>" or html_elem == ">":
+                 count_td += 1
+             if html_elem in ["fcel", "ecel", "ched", "rhed", "srow"]:
+                 count_td += 1
+         self._log().debug(
+             "======================= PREDICTED BBOXES: {}".format(count_bbox)
+         )
+         self._log().debug(
+             "======================= PREDICTED CELLS: {}".format(count_td)
+         )
+         if count_bbox != count_td:
+             bboxes = self._remove_bbox_span_desync(prediction)
+         else:
+             bboxes = prediction["bboxes"]
+             match = True
+         return match, bboxes
+
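_check_bbox_sync compares the bbox count against the number of cell-producing tags; "<td>", ">", "fcel", "ecel", "ched", "rhed" and "srow" each account for one expected bbox. The counting rule on a hypothetical sequence:

    html_seq = ["<thead>", "<td>", "fcel", ">", "</thead>"]
    cell_tags = {"<td>", ">", "fcel", "ecel", "ched", "rhed", "srow"}
    count_td = sum(1 for t in html_seq if t in cell_tags)  # 3
    # A mismatch with len(prediction["bboxes"]) triggers _remove_bbox_span_desync()
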
+     def page_coords_to_table_coords(self, bbox, table_bbox, im_width, im_height):
+         r"""
+         Transforms given bbox from page coordinate system into table image coordinate system
+
+         Parameters
+         ----------
+         bbox : list
+             bbox to transform, in page coordinates
+         table_bbox : list
+             table bbox, in page coordinates
+         im_width : integer
+             width of the image with the rendered table (in pixels)
+         im_height : integer
+             height of the image with the rendered table (in pixels)
+
+         Returns
+         -------
+         bbox : list
+             bbox with transformed coordinates
+         """
+         # Coordinates of given bbox
+         x1 = bbox[0]
+         y1 = bbox[1]
+         x2 = bbox[2]
+         y2 = bbox[3]
+
+         # Coordinates of table bbox
+         t_x1 = table_bbox[0]
+         t_y1 = table_bbox[1]
+         t_x2 = table_bbox[2]
+         t_y2 = table_bbox[3]
+
+         # Table width / height
+         tw = t_x2 - t_x1
+         th = t_y2 - t_y1
+         new_bbox = [0, 0, 0, 0]
+         # Flip corners, subtract table coordinates and rescale to new image size
+         new_bbox[0] = im_width * (x1 - t_x1) / tw
+         new_bbox[1] = im_height * (t_y2 - y2) / th
+         new_bbox[2] = im_width * (x2 - t_x1) / tw
+         new_bbox[3] = im_height * (t_y2 - y1) / th
+
+         return new_bbox
+
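A worked example of the transform above; the y-axis flips because page coordinates grow bottom-up while image coordinates grow top-down:

    # Table occupies page region (100, 100)-(300, 200), rendered at 400x200 px
    table_bbox = [100, 100, 300, 200]
    bbox = [150, 150, 200, 175]  # page coords, origin at bottom-left
    # x1: 400 * (150 - 100) / 200 = 100 ; y1: 200 * (200 - 175) / 100 = 50
    # result: [100.0, 50.0, 200.0, 100.0] (image coords, origin at top-left)
    TFPredictor.page_coords_to_table_coords(None, bbox, table_bbox, 400, 200)
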
+     def _depad_bboxes(self, bboxes, new_image_ratio):
+         r"""
+         Removes padding from predicted bboxes for previously padded image
+
+         Parameters
+         ----------
+         bboxes : list of lists
+             list of bboxes that have to be recalculated to remove implied padding
+         new_image_ratio : float
+             Ratio of padded image size to the original image size
+
+         Returns
+         -------
+         new_bboxes : list
+             bboxes with transformed coordinates
+         """
+         new_bboxes = []
+         c_x = 0.5
+         c_y = 0.5
+
+         self._log().debug("PREDICTED BBOXES: {}".format(bboxes))
+         self._log().debug("new_image_ratio: {}".format(new_image_ratio))
+
+         for bbox in bboxes:
+             # 1. corner coords -> center coords
+             cb_x1 = bbox[0] - c_x
+             cb_y1 = bbox[1] - c_y
+             cb_x2 = bbox[2] - c_x
+             cb_y2 = bbox[3] - c_y
+
+             # 2. center coords * new_image_ratio
+             r_cb_x1 = cb_x1 * new_image_ratio
+             r_cb_y1 = cb_y1 * new_image_ratio
+             r_cb_x2 = cb_x2 * new_image_ratio
+             r_cb_y2 = cb_y2 * new_image_ratio
+
+             # 3. center coords -> corner coords
+             x1 = r_cb_x1 + c_x
+             y1 = r_cb_y1 + c_y
+             x2 = r_cb_x2 + c_x
+             y2 = r_cb_y2 + c_y
+
+             x1 = np.clip(x1, 0.0, 1.0)
+             y1 = np.clip(y1, 0.0, 1.0)
+             x2 = np.clip(x2, 0.0, 1.0)
+             y2 = np.clip(y2, 0.0, 1.0)
+
+             new_bbox = [x1, y1, x2, y2]
+             new_bboxes.append(new_bbox)
+
+         self._log().debug("DEPAD BBOXES: {}".format(new_bboxes))
+
+         return new_bboxes
+
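Per coordinate, the depadding above is a scale about the image center in normalized space, followed by a clip to [0, 1]:

    # x' = clip((x - 0.5) * new_image_ratio + 0.5, 0.0, 1.0), likewise for y
    x, new_image_ratio = 0.25, 1.2  # padded image is 20% larger than the original
    x_depadded = min(max((x - 0.5) * new_image_ratio + 0.5, 0.0), 1.0)  # 0.2
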
+     def _pad_image(self, iocr_page):
+         r"""
+         Adds padding to the image
+
+         Parameters
+         ----------
+         iocr_page : dict
+             Docling provided table data
+
+         Returns
+         -------
+         new_im : PIL image
+             new, padded image
+         new_image_ratio : float
+             Ratio of padded image size to the original image size
+         """
+         _, old_iw, old_ih = iocr_page["image"].shape
+
+         margin_i = self._padding_size  # pixels
+
+         desired_iw = old_iw + (margin_i * 2)
+         desired_ih = old_ih + (margin_i * 2)
+
+         # Ratio of new image size to the original image size
+         new_image_ratio = desired_iw / old_iw
+
+         bcolor = (255, 255, 255)
+         # Create empty canvas of background color and desired size
+         padded_image = F.pad(
+             iocr_page["image"],
+             (desired_iw, desired_ih, desired_iw, desired_ih),
+             fill=bcolor,
+         )
+         return padded_image, new_image_ratio
+
+     def _pre_process_image(self, iocr_page):
+         r"""
+         Pre-process the table image in memory before running prediction.
+         Currently this just removes separate PDF cells that contain only a "$" sign
+         from the image; this is done to avoid confusing the model on financial reports.
+
+         Parameters
+         ----------
+         iocr_page : dict
+             Docling provided table data
+
+         Returns
+         -------
+         iocr_page["image"] : PIL image
+             updated table image with "$" repainted
+         new_image_ratio : float
+             Ratio of padded image size to the original image size
+         """
+
+         new_image_ratio = 1.0
+
+         ic, iw, ih = iocr_page["image"].shape
+
+         return iocr_page["image"], new_image_ratio
+
+     def _merge_tf_output(self, docling_output, pdf_cells):
+         tf_output = []
+         tf_cells_map = {}
+         max_row_idx = 0
+
+         for docling_item in docling_output:
+             r_idx = str(docling_item["start_row_offset_idx"])
+             c_idx = str(docling_item["start_col_offset_idx"])
+             cell_key = c_idx + "_" + r_idx
+             if cell_key in tf_cells_map:
+                 for pdf_cell in pdf_cells:
+                     if pdf_cell["id"] == docling_item["cell_id"]:
+                         text_cell_bbox = {
+                             "b": pdf_cell["bbox"][1],
+                             "l": pdf_cell["bbox"][0],
+                             "r": pdf_cell["bbox"][2],
+                             "t": pdf_cell["bbox"][3],
+                             "token": pdf_cell["text"],
+                         }
+                         tf_cells_map[cell_key]["text_cell_bboxes"].append(
+                             text_cell_bbox
+                         )
+             else:
+                 tf_cells_map[cell_key] = {
+                     "bbox": docling_item["bbox"],
+                     "row_span": docling_item["row_span"],
+                     "col_span": docling_item["col_span"],
+                     "start_row_offset_idx": docling_item["start_row_offset_idx"],
+                     "end_row_offset_idx": docling_item["end_row_offset_idx"],
+                     "start_col_offset_idx": docling_item["start_col_offset_idx"],
+                     "end_col_offset_idx": docling_item["end_col_offset_idx"],
+                     "indentation_level": docling_item["indentation_level"],
+                     "text_cell_bboxes": [],
+                     "column_header": docling_item["column_header"],
+                     "row_header": docling_item["row_header"],
+                     "row_section": docling_item["row_section"],
+                 }
+
+                 if docling_item["start_row_offset_idx"] > max_row_idx:
+                     max_row_idx = docling_item["start_row_offset_idx"]
+
+                 for pdf_cell in pdf_cells:
+                     if pdf_cell["id"] == docling_item["cell_id"]:
+                         text_cell_bbox = {
+                             "b": pdf_cell["bbox"][1],
+                             "l": pdf_cell["bbox"][0],
+                             "r": pdf_cell["bbox"][2],
+                             "t": pdf_cell["bbox"][3],
+                             "token": pdf_cell["text"],
+                         }
+                         tf_cells_map[cell_key]["text_cell_bboxes"].append(
+                             text_cell_bbox
+                         )
+
+         for k in tf_cells_map:
+             tf_output.append(tf_cells_map[k])
+         return tf_output
+
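Deduplication keys cells on "startcol_startrow", so every docling_output entry sharing a grid origin folds into one cell whose text_cell_bboxes accumulates all matched PDF cells. The key construction, with hypothetical values:

    docling_item = {"start_row_offset_idx": 2, "start_col_offset_idx": 0}
    cell_key = str(docling_item["start_col_offset_idx"]) + "_" + str(
        docling_item["start_row_offset_idx"]
    )  # -> "0_2"
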
+     def resize_img(self, image, width=None, height=None, inter=cv2.INTER_AREA):
+         # initialize the dimensions of the image to be resized and
+         # grab the image size
+         dim = None
+         (h, w) = image.shape[:2]
+         sf = 1.0
+         # if both the width and height are None, then return the
+         # original image
+         if width is None and height is None:
+             return image, sf
+         # check to see if the width is None
+         if width is None:
+             # calculate the ratio of the height and construct the
+             # dimensions
+             r = height / float(h)
+             sf = r
+             dim = (int(w * r), height)
+         # otherwise, the height is None
+         else:
+             # calculate the ratio of the width and construct the
+             # dimensions
+             r = width / float(w)
+             sf = r
+             dim = (width, int(h * r))
+         # resize the image
+         resized = cv2.resize(image, dim, interpolation=inter)
+         # return the resized image
+         return resized, sf
+
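resize_img is the familiar imutils-style aspect-preserving resize; the returned scale factor lets callers map bboxes between the two sizes. For instance (assuming cv2 and numpy are available, and predictor from the config sketch above):

    import numpy as np
    img = np.zeros((2048, 1536, 3), dtype=np.uint8)  # (h, w, c)
    resized, sf = predictor.resize_img(img, height=1024)
    # sf == 0.5 and resized.shape == (1024, 768, 3)
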
+     def multi_table_predict(self, iocr_page, table_bboxes, do_matching=True):
+         # def multi_table_predict(self, iocr_page, page_image, table_bboxes):
+         multi_tf_output = []
+         page_image = iocr_page["image"]
+
+         # Prevent large image submission by resizing the input
+         page_image_resized, scale_factor = self.resize_img(page_image, height=1024)
+
+         for table_bbox in table_bboxes:
+             # Downscale table bounding box to the size of the new image
+             table_bbox[0] = table_bbox[0] * scale_factor
+             table_bbox[1] = table_bbox[1] * scale_factor
+             table_bbox[2] = table_bbox[2] * scale_factor
+             table_bbox[3] = table_bbox[3] * scale_factor
+
+             table_image = page_image_resized[
+                 round(table_bbox[1]) : round(table_bbox[3]),
+                 round(table_bbox[0]) : round(table_bbox[2]),
+             ]
+             # table_image = page_image
+             # Predict
+             if do_matching:
+                 tf_responses, predict_details = self.predict(
+                     iocr_page, table_bbox, table_image, scale_factor, None
+                 )
+             else:
+                 tf_responses, predict_details = self.predict_dummy(
+                     iocr_page, table_bbox, table_image, scale_factor, None
+                 )
+
+             # ======================================================================================
+             # PROCESS PREDICTED RESULTS, TO TURN PREDICTED COL/ROW IDs into Indexes
+             # Indexes should be in increasing order, without gaps
+
+             # Fix col/row indexes
+             # Arranges all col/row indexes sequentially without gaps using input IDs
+
+             indexing_start_cols = []  # Index of original start col IDs (not indexes)
+             indexing_end_cols = []  # Index of original end col IDs (not indexes)
+             indexing_start_rows = []  # Index of original start row IDs (not indexes)
+             indexing_end_rows = []  # Index of original end row IDs (not indexes)
+
+             # First, collect all possible predicted IDs, to be used as indexes
+             # IDs returned by Tableformer are sequential, but might contain gaps
+             for tf_response_cell in tf_responses:
+                 start_col_offset_idx = tf_response_cell["start_col_offset_idx"]
+                 end_col_offset_idx = tf_response_cell["end_col_offset_idx"]
+                 start_row_offset_idx = tf_response_cell["start_row_offset_idx"]
+                 end_row_offset_idx = tf_response_cell["end_row_offset_idx"]
+
+                 # Collect all possible col/row IDs:
+                 if start_col_offset_idx not in indexing_start_cols:
+                     indexing_start_cols.append(start_col_offset_idx)
+                 if end_col_offset_idx not in indexing_end_cols:
+                     indexing_end_cols.append(end_col_offset_idx)
+                 if start_row_offset_idx not in indexing_start_rows:
+                     indexing_start_rows.append(start_row_offset_idx)
+                 if end_row_offset_idx not in indexing_end_rows:
+                     indexing_end_rows.append(end_row_offset_idx)
+
+             indexing_start_cols.sort()
+             indexing_end_cols.sort()
+             indexing_start_rows.sort()
+             indexing_end_rows.sort()
+
+             # After this - put actual indexes of IDs back into predicted structure...
+             for tf_response_cell in tf_responses:
+                 tf_response_cell["start_col_offset_idx"] = indexing_start_cols.index(
+                     tf_response_cell["start_col_offset_idx"]
+                 )
+                 tf_response_cell["end_col_offset_idx"] = (
+                     tf_response_cell["start_col_offset_idx"]
+                     + tf_response_cell["col_span"]
+                 )
+                 tf_response_cell["start_row_offset_idx"] = indexing_start_rows.index(
+                     tf_response_cell["start_row_offset_idx"]
+                 )
+                 tf_response_cell["end_row_offset_idx"] = (
+                     tf_response_cell["start_row_offset_idx"]
+                     + tf_response_cell["row_span"]
+                 )
+             # Counting matched cols/rows from actual indexes (and not ids)
+             predict_details["num_cols"] = len(indexing_end_cols)
+             predict_details["num_rows"] = len(indexing_end_rows)
+             # Put results into multi_tf_output
+             multi_tf_output.append(
+                 {"tf_responses": tf_responses, "predict_details": predict_details}
+             )
+             # Upscale table bounding box back, for visualization purposes
+             table_bbox[0] = table_bbox[0] / scale_factor
+             table_bbox[1] = table_bbox[1] / scale_factor
+             table_bbox[2] = table_bbox[2] / scale_factor
+             table_bbox[3] = table_bbox[3] / scale_factor
+         # Return grouped results of predictions
+         return multi_tf_output
+
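The re-indexing above compresses gappy Tableformer IDs into dense 0-based indexes: collect the unique IDs, sort them, and replace each ID by its position; end offsets are then re-derived as start + span. The core trick in isolation:

    start_col_ids = [0, 2, 5, 2]                # predicted IDs, with gaps
    unique_sorted = sorted(set(start_col_ids))  # [0, 2, 5]
    dense = [unique_sorted.index(i) for i in start_col_ids]  # [0, 1, 2, 1]
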
+     def predict_dummy(
+         self, iocr_page, table_bbox, table_image, scale_factor, eval_res_preds=None
+     ):
+         r"""
+         Predict the table structure from an in-memory table image
+
+         Parameters
+         ----------
+         iocr_page : dict
+             Docling provided table data
+         eval_res_preds : dict
+             Ready predictions provided by the evaluation results
+
+         Returns
+         -------
+         docling_output : string
+             json response formatted according to Docling API expectations
+
+         matching_details : string
+             json with details about the matching between the pdf cells and the table cells
+         """
+         AggProfiler().start_agg(self._prof)
+
+         max_steps = self._config["predict"]["max_steps"]
+         beam_size = self._config["predict"]["beam_size"]
+         image_batch = self._prepare_image(table_image)
+         # Make predictions
+         prediction = {}
+
+         with torch.no_grad():
+             # Compute predictions
+             if (
+                 eval_res_preds is not None
+             ):  # Don't run the model, use the provided predictions
+                 prediction["bboxes"] = eval_res_preds["bboxes"]
+                 pred_tag_seq = eval_res_preds["tag_seq"]
+             elif self._config["predict"]["bbox"]:
+                 pred_tag_seq, outputs_class, outputs_coord = self._model.predict(
+                     image_batch, max_steps, beam_size
+                 )
+
+                 if outputs_coord is not None:
+                     bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord)
+                     prediction["bboxes"] = bbox_pred.tolist()
+                 else:
+                     prediction["bboxes"] = []
+                 if outputs_class is not None:
+                     result_class = torch.argmax(outputs_class, dim=1)
+                     prediction["classes"] = result_class.tolist()
+                 else:
+                     prediction["classes"] = []
+                 if self._remove_padding:
+                     pred_tag_seq, _ = u.remove_padding(pred_tag_seq)
+             else:
+                 pred_tag_seq, _, _ = self._model.predict(
+                     image_batch, max_steps, beam_size
+                 )
+                 # Check if padding should be removed
+                 if self._remove_padding:
+                     pred_tag_seq, _ = u.remove_padding(pred_tag_seq)
+
+             prediction["tag_seq"] = pred_tag_seq
+             prediction["rs_seq"] = self._get_html_tags(pred_tag_seq)
+             prediction["html_seq"] = otsl_to_html(prediction["rs_seq"], False)
+         # Remove implied padding from bbox predictions,
+         # which we added at the image pre-processing stage
+         self._log().debug("----- rs_seq -----")
+         self._log().debug(prediction["rs_seq"])
+         self._log().debug(len(prediction["rs_seq"]))
+         otsl_sqr_chk(prediction["rs_seq"], False)
+
+         # Check that bboxes are in sync with predicted tags
+         sync, corrected_bboxes = self._check_bbox_sync(prediction)
+         if not sync:
+             prediction["bboxes"] = corrected_bboxes
+
+         # Match the cells
+         matching_details = {"table_cells": [], "matches": {}}
+
+         # Table bbox upscaling will scale predicted bboxes too within cell matcher
+         scaled_table_bbox = [
+             table_bbox[0] / scale_factor,
+             table_bbox[1] / scale_factor,
+             table_bbox[2] / scale_factor,
+             table_bbox[3] / scale_factor,
+         ]
+
+         if len(prediction["bboxes"]) > 0:
+             matching_details = self._cell_matcher.match_cells_dummy(
+                 iocr_page, scaled_table_bbox, prediction
+             )
+         # Generate the expected Docling responses
+         AggProfiler().begin("generate_docling_response", self._prof)
+         docling_output = self._generate_tf_response_dummy(
+             matching_details["table_cells"]
+         )
+
+         AggProfiler().end("generate_docling_response", self._prof)
+         # Add the docling_output sorted by cell_id into the matching_details
+         docling_output.sort(key=lambda item: item["cell_id"])
+         matching_details["docling_responses"] = docling_output
+         # Merge docling_output and pdf_cells into one TF output,
+         # with deduplicated table cells
+         # tf_output = self._merge_tf_output_dummy(docling_output)
+         tf_output = docling_output
+
+         return tf_output, matching_details
+
+     def predict(
+         self, iocr_page, table_bbox, table_image, scale_factor, eval_res_preds=None
+     ):
+         r"""
+         Predict the table structure from an in-memory table image
+
+         Parameters
+         ----------
+         iocr_page : dict
+             Docling provided table data
+         eval_res_preds : dict
+             Ready predictions provided by the evaluation results
+
+         Returns
+         -------
+         docling_output : string
+             json response formatted according to Docling API expectations
+
+         matching_details : string
+             json with details about the matching between the pdf cells and the table cells
+         """
+         AggProfiler().start_agg(self._prof)
+
+         max_steps = self._config["predict"]["max_steps"]
+         beam_size = self._config["predict"]["beam_size"]
+         image_batch = self._prepare_image(table_image)
+         # Make predictions
+         prediction = {}
+
+         with torch.no_grad():
+             # Compute predictions
+             if (
+                 eval_res_preds is not None
+             ):  # Don't run the model, use the provided predictions
+                 prediction["bboxes"] = eval_res_preds["bboxes"]
+                 pred_tag_seq = eval_res_preds["tag_seq"]
+             elif self._config["predict"]["bbox"]:
+                 pred_tag_seq, outputs_class, outputs_coord = self._model.predict(
+                     image_batch, max_steps, beam_size
+                 )
+
+                 if outputs_coord is not None:
+                     bbox_pred = u.box_cxcywh_to_xyxy(outputs_coord)
+                     prediction["bboxes"] = bbox_pred.tolist()
+                 else:
+                     prediction["bboxes"] = []
+                 if outputs_class is not None:
+                     result_class = torch.argmax(outputs_class, dim=1)
+                     prediction["classes"] = result_class.tolist()
+                 else:
+                     prediction["classes"] = []
+                 if self._remove_padding:
+                     pred_tag_seq, _ = u.remove_padding(pred_tag_seq)
+             else:
+                 pred_tag_seq, _, _ = self._model.predict(
+                     image_batch, max_steps, beam_size
+                 )
+                 # Check if padding should be removed
+                 if self._remove_padding:
+                     pred_tag_seq, _ = u.remove_padding(pred_tag_seq)
+
+             prediction["tag_seq"] = pred_tag_seq
+             prediction["rs_seq"] = self._get_html_tags(pred_tag_seq)
+             prediction["html_seq"] = otsl_to_html(prediction["rs_seq"], False)
+         # Remove implied padding from bbox predictions,
+         # which we added at the image pre-processing stage
+         self._log().debug("----- rs_seq -----")
+         self._log().debug(prediction["rs_seq"])
+         self._log().debug(len(prediction["rs_seq"]))
+         otsl_sqr_chk(prediction["rs_seq"], False)
+
+         sync, corrected_bboxes = self._check_bbox_sync(prediction)
+         if not sync:
+             prediction["bboxes"] = corrected_bboxes
+
+         # Match the cells
+         matching_details = {"table_cells": [], "matches": {}}
+
+         # Table bbox upscaling will scale predicted bboxes too within cell matcher
+         scaled_table_bbox = [
+             table_bbox[0] / scale_factor,
+             table_bbox[1] / scale_factor,
+             table_bbox[2] / scale_factor,
+             table_bbox[3] / scale_factor,
+         ]
+
+         if len(prediction["bboxes"]) > 0:
+             matching_details = self._cell_matcher.match_cells(
+                 iocr_page, scaled_table_bbox, prediction
+             )
+         # Post-processing
+         if len(prediction["bboxes"]) > 0:
+             if self.enable_post_process:
+                 AggProfiler().begin("post_process", self._prof)
+                 matching_details = self._post_processor.process(matching_details)
+                 AggProfiler().end("post_process", self._prof)
+
+         # Generate the expected Docling responses
+         AggProfiler().begin("generate_docling_response", self._prof)
+         docling_output = self._generate_tf_response(
+             matching_details["table_cells"],
+             matching_details["matches"],
+         )
+
+         AggProfiler().end("generate_docling_response", self._prof)
+         # Add the docling_output sorted by cell_id into the matching_details
+         docling_output.sort(key=lambda item: item["cell_id"])
+         matching_details["docling_responses"] = docling_output
+
+         # Merge docling_output and pdf_cells into one TF output,
+         # with deduplicated table cells
+         tf_output = self._merge_tf_output(docling_output, matching_details["pdf_cells"])
+
+         return tf_output, matching_details
+
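Both predict paths convert the model's DETR-style boxes with u.box_cxcywh_to_xyxy. A minimal sketch of that conversion under the usual (cx, cy, w, h) convention; the packaged utility in utils.py is the authoritative version:

    import torch

    def box_cxcywh_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
        # (center_x, center_y, width, height) -> (x1, y1, x2, y2)
        cx, cy, w, h = boxes.unbind(-1)
        return torch.stack(
            [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1
        )
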
+     def _generate_tf_response_dummy(self, table_cells):
+         tf_cell_list = []
+
+         for table_cell in table_cells:
+             colspan_val = 1
+             if "colspan_val" in table_cell:
+                 colspan_val = table_cell["colspan_val"]
+             rowspan_val = 1
+             if "rowspan_val" in table_cell:
+                 rowspan_val = table_cell["rowspan_val"]
+
+             column_header = False
+             if table_cell["label"] == "ched":
+                 column_header = True
+
+             row_header = False
+             if table_cell["label"] == "rhed":
+                 row_header = True
+
+             row_section = False
+             if table_cell["label"] == "srow":
+                 row_section = True
+
+             row_id = table_cell["row_id"]
+             column_id = table_cell["column_id"]
+
+             cell_bbox = {
+                 "b": table_cell["bbox"][3],
+                 "l": table_cell["bbox"][0],
+                 "r": table_cell["bbox"][2],
+                 "t": table_cell["bbox"][1],
+                 "token": "",
+             }
+
+             tf_cell = {
+                 "cell_id": table_cell["cell_id"],
+                 "bbox": cell_bbox,  # b,l,r,t,token
+                 "row_span": rowspan_val,
+                 "col_span": colspan_val,
+                 "start_row_offset_idx": row_id,
+                 "end_row_offset_idx": row_id + rowspan_val,
+                 "start_col_offset_idx": column_id,
+                 "end_col_offset_idx": column_id + colspan_val,
+                 "indentation_level": 0,
+                 # No text cell bboxes, because no matching was done
+                 "text_cell_bboxes": [],
+                 "column_header": column_header,
+                 "row_header": row_header,
+                 "row_section": row_section,
+             }
+             tf_cell_list.append(tf_cell)
+         return tf_cell_list
+
+     def _generate_tf_response(self, table_cells, matches):
+         r"""
+         Convert the matching details to the expected output for Docling
+
+         Parameters
+         ----------
+         table_cells : list of dict
+             Each value is a dictionary with keys: "cell_id", "row_id", "column_id",
+             "bbox", "label", "class"
+         matches : dictionary of lists of table_cells
+             A dictionary which is indexed by the pdf_cell_id as key and the value is a list
+             of the table_cells that fall inside that pdf cell
+
+         Returns
+         -------
+         docling_output : string
+             json response formatted according to Docling API expectations
+         """
+
+         # format output to look similar to tests/examples/tf_gte_output_2.json
+         tf_cell_list = []
+         for pdf_cell_id, pdf_cell_matches in matches.items():
+             tf_cell = {
+                 "bbox": {},  # b,l,r,t,token
+                 "row_span": 1,
+                 "col_span": 1,
+                 "start_row_offset_idx": -1,
+                 "end_row_offset_idx": -1,
+                 "start_col_offset_idx": -1,
+                 "end_col_offset_idx": -1,
+                 "indentation_level": 0,
+                 # return text cell bboxes additionally to the matched index
+                 "text_cell_bboxes": [{}],  # b,l,r,t,token
+                 "column_header": False,
+                 "row_header": False,
+                 "row_section": False,
+             }
+             tf_cell["cell_id"] = int(pdf_cell_id)
+
+             row_ids = set()
+             column_ids = set()
+             labels = set()
+
+             for match in pdf_cell_matches:
+                 tm = match["table_cell_id"]
+                 tcl = list(
+                     filter(lambda table_cell: table_cell["cell_id"] == tm, table_cells)
+                 )
+                 if len(tcl) > 0:
+                     table_cell = tcl[0]
+                     row_ids.add(table_cell["row_id"])
+                     column_ids.add(table_cell["column_id"])
+                     labels.add(table_cell["label"])
+
+                     if table_cell["label"] is not None:
+                         if table_cell["label"] in ["ched"]:
+                             tf_cell["column_header"] = True
+                         if table_cell["label"] in ["rhed"]:
+                             tf_cell["row_header"] = True
+                         if table_cell["label"] in ["srow"]:
+                             tf_cell["row_section"] = True
+
+                     tf_cell["start_col_offset_idx"] = table_cell["column_id"]
+                     tf_cell["end_col_offset_idx"] = table_cell["column_id"] + 1
+                     tf_cell["start_row_offset_idx"] = table_cell["row_id"]
+                     tf_cell["end_row_offset_idx"] = table_cell["row_id"] + 1
+
+                     if "colspan_val" in table_cell:
+                         tf_cell["col_span"] = table_cell["colspan_val"]
+                         tf_cell["start_col_offset_idx"] = table_cell["column_id"]
+                         off_idx = table_cell["column_id"] + tf_cell["col_span"]
+                         tf_cell["end_col_offset_idx"] = off_idx
+                     if "rowspan_val" in table_cell:
+                         tf_cell["row_span"] = table_cell["rowspan_val"]
+                         tf_cell["start_row_offset_idx"] = table_cell["row_id"]
+                         tf_cell["end_row_offset_idx"] = (
+                             table_cell["row_id"] + tf_cell["row_span"]
+                         )
+                     if "bbox" in table_cell:
+                         table_match_bbox = table_cell["bbox"]
+                         tf_bbox = {
+                             "b": table_match_bbox[3],
+                             "l": table_match_bbox[0],
+                             "r": table_match_bbox[2],
+                             "t": table_match_bbox[1],
+                         }
+                         tf_cell["bbox"] = tf_bbox
+
+             tf_cell["row_ids"] = list(row_ids)
+             tf_cell["column_ids"] = list(column_ids)
+             tf_cell["label"] = "None"
+             l_labels = list(labels)
+             if len(l_labels) > 0:
+                 tf_cell["label"] = l_labels[0]
+             tf_cell_list.append(tf_cell)
+         return tf_cell_list
+
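For orientation, the shapes this method consumes, reconstructed from the docstring and the lookups above (values are hypothetical):

    table_cells = [
        {"cell_id": 3, "row_id": 0, "column_id": 1,
         "bbox": [10.0, 20.0, 60.0, 40.0], "label": "ched", "class": 0}
    ]
    matches = {"12": [{"table_cell_id": 3}]}  # pdf_cell_id -> matched table cells
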
+     def _prepare_image(self, mat_image):
+         r"""
+         Rescale the image and prepare a batch of 1 with the image as a tensor
+
+         Parameters
+         ----------
+         mat_image : cv2.Mat
+             The image as an openCV Mat object
+
+         Returns
+         -------
+         tensor (batch_size, image_channels, resized_image, resized_image)
+         """
+         normalize = T.Normalize(
+             mean=self._config["dataset"]["image_normalization"]["mean"],
+             std=self._config["dataset"]["image_normalization"]["std"],
+         )
+         resized_size = self._config["dataset"]["resized_image"]
+         resize = T.Resize([resized_size, resized_size])
+
+         img, _ = normalize(mat_image, None)
+         img, _ = resize(img, None)
+
+         img = img.transpose(2, 1, 0)  # (channels, width, height)
+         img = torch.FloatTensor(img / 255.0)
+         image_batch = img.unsqueeze(dim=0)
+         image_batch = image_batch.to(device=self._device)
+         return image_batch
+
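The shapes through _prepare_image, assuming resized_image is 448: (h, w, 3) uint8 in, (1, 3, 448, 448) float tensor out. A standalone check of the transpose/scale/batch steps:

    import numpy as np
    import torch
    img = np.zeros((448, 448, 3), dtype=np.uint8)  # as if already normalized/resized
    t = torch.FloatTensor(img.transpose(2, 1, 0) / 255.0).unsqueeze(dim=0)
    assert t.shape == (1, 3, 448, 448)
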
+     def _get_html_tags(self, seq):
+         r"""
+         Convert indices to actual html tags
+         """
+         # Map the tag indices back to actual tags (without start, end)
+         html_tags = [self._rev_word_map[ind] for ind in seq[1:-1]]
+
+         return html_tags
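Decoding leans on the reversed word map built in _init_word_map; seq[1:-1] drops the start and end tokens. A toy round trip with the hypothetical word map from the config sketch above:

    word_map_tag = {"<start>": 0, "<end>": 1, "fcel": 2, "ecel": 3, "nl": 4}
    rev = {v: k for k, v in word_map_tag.items()}
    seq = [0, 2, 3, 4, 1]               # <start> fcel ecel nl <end>
    tags = [rev[i] for i in seq[1:-1]]  # ["fcel", "ecel", "nl"]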