docling-ibm-models 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
- docling_ibm_models/tableformer/__init__.py +0 -0
- docling_ibm_models/tableformer/common.py +200 -0
- docling_ibm_models/tableformer/data_management/__init__.py +0 -0
- docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
- docling_ibm_models/tableformer/data_management/functional.py +574 -0
- docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
- docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
- docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
- docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
- docling_ibm_models/tableformer/data_management/transforms.py +396 -0
- docling_ibm_models/tableformer/models/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/base_model.py +279 -0
- docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
- docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
- docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
- docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
- docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
- docling_ibm_models/tableformer/otsl.py +541 -0
- docling_ibm_models/tableformer/settings.py +90 -0
- docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
- docling_ibm_models/tableformer/test_prepare_image.py +99 -0
- docling_ibm_models/tableformer/utils/__init__.py +0 -0
- docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
- docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
- docling_ibm_models/tableformer/utils/utils.py +376 -0
- docling_ibm_models/tableformer/utils/variance.py +175 -0
- docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
- docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
- docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
- docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,596 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import copy
|
6
|
+
import logging
|
7
|
+
import re
|
8
|
+
|
9
|
+
import numpy as np
|
10
|
+
|
11
|
+
import docling_ibm_models.tableformer.otsl as otsl
|
12
|
+
import docling_ibm_models.tableformer.settings as s
|
13
|
+
|
14
|
+
# Verbosity of the custom logger used by this module.
# Switch to logging.INFO or logging.DEBUG while troubleshooting.
LOG_LEVEL = logging.WARNING

# Cell labels
BODY = "body"
COL_HEADER = "col_header"
MULTI_COL_HEADER = "multi_col_header"
MULTI_ROW_HEADER = "multi_row_header"
MULTI_ROW = "multi_row"
MULTI_COL = "multi_col"
|
25
|
+
|
26
|
+
|
27
|
+
def validate_bboxes_page(bboxes):
    r"""
    Debugging helper: count the bboxes whose signed area is negative.

    The signed area is (x2 - x1) * (y2 - y1), so only boxes inverted along exactly one axis
    are reported. A box inverted along both axes has a positive product and is not counted.
    Every offending box is printed together with its index, followed by a summary line when
    at least one was found.

    Parameters
    ----------
    bboxes : list of 4
        Each element of the list is expected to be a bbox as (x1, y1, x2, y2)

    Returns
    -------
    int
        The number of invalid bboxes.
    """
    bad = 0
    for idx, box in enumerate(bboxes):
        width = box[2] - box[0]
        height = box[3] - box[1]
        if width * height < 0:
            print("Wrong bbox: {} - {}".format(idx, box))
            bad += 1

    if bad > 0:
        print("Invalid bboxes in total: {}".format(bad))
    return bad
|
54
|
+
|
55
|
+
|
56
|
+
def find_intersection(b1, b2):
    r"""
    Compute the intersection between 2 bboxes

    Both boxes are given as (x1, y1, x2, y2) with x1 <= x2 and y1 <= y2. The comparisons are
    strict, so boxes that merely touch on an edge or a corner are reported as intersecting
    and yield a zero-area box.

    Parameters
    ----------
    b1 : list of 4
        The page x1y1x2y2 coordinates of the bbox
    b2 : list of 4
        The page x1y1x2y2 coordinates of the bbox

    Returns
    -------
    list of 4 or None
        The bbox of the intersection or None if there is no intersection
    """
    # The boxes are disjoint when one lies entirely past the other on either axis.
    # The last test must compare b2 against b1 (it used to compare b2 with itself).
    if b1[2] < b2[0] or b2[2] < b1[0] or b1[1] > b2[3] or b2[1] > b1[3]:
        return None

    return [
        max(b1[0], b2[0]),
        max(b1[1], b2[1]),
        min(b1[2], b2[2]),
        min(b1[3], b2[3]),
    ]
|
82
|
+
|
83
|
+
|
84
|
+
class CellMatcher:
    r"""
    Match the table cells to the pdf page cells.

    NOTICE: PDF page coordinate system vs table coordinate system.
    In both systems the bboxes are described as (x1, y1, x2, y2) with the following meaning:

    Page coordinate system:
    - Origin (0, 0) at the lower-left corner
    - (x1, y1) the lower left corner of the box
    - (x2, y2) the upper right corner of the box

    Table coordinate system:
    - Origin (0, 0) at the upper-left corner
    - (x1, y1) the upper left corner of the box
    - (x2, y2) the lower right corner of the box
    """

    # Token sequence that drives _build_table_cells: "OTSL" (default) or "HTML"
    _TABLE_CELL_MODE = "OTSL"

    # OTSL tags that open a table cell. Every other tag except "nl" only advances the column
    _OTSL_CELL_TAGS = ("fcel", "ecel", "xcel", "ched", "rhed", "srow")

    def __init__(self, config):
        r"""
        Parameters
        ----------
        config : dict
            Configuration. It must provide config["predict"]["pdf_cell_iou_thres"]
        """
        self._config = config
        self._iou_thres = config["predict"]["pdf_cell_iou_thres"]

    def _log(self):
        r"""
        Return the custom logger of this class, configured with LOG_LEVEL
        """
        return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)

    def match_cells(self, iocr_page, table_bbox, prediction):
        r"""
        Convert the tablemodel prediction into the Docling format

        Parameters
        ----------
        iocr_page : dict
            The original Docling provided table data. The keys "tokens", "height" and "width"
            are read. Each token must have a "bbox" dict with the keys "l", "t", "r", "b"
        table_bbox : list of 4
            The bbox of the whole table as (x1, y1, x2, y2)
        prediction : dict
            The dictionary has the keys:
            "bboxes": The bounding boxes of the table cells
            "classes": The classes of the table cells
            "html_seq": The sequence as html tags
            "rs_seq": The sequence as OTSL tags

        Returns
        -------
        matching_details : dict
            Dictionary with all details about the matchings between the table and pdf cells.
            It has the keys: "iou_threshold", "table_bbox", "prediction_bboxes_page",
            "prediction", "pdf_cells", "page_height", "page_width", "table_cells", "matches"
        """
        return self._build_matching_details(iocr_page, table_bbox, prediction, True)

    def match_cells_dummy(self, iocr_page, table_bbox, prediction):
        r"""
        Convert the tablemodel prediction into the Docling format
        DUMMY version doesn't do matching with text cells, but propagates predicted bboxes,
        respecting the rest of the format

        Parameters
        ----------
        iocr_page : dict
            The original Docling provided table data. The keys "tokens", "height" and "width"
            are read. Each token must have a "bbox" dict with the keys "l", "t", "r", "b"
        table_bbox : list of 4
            The bbox of the whole table as (x1, y1, x2, y2)
        prediction : dict
            The dictionary has the keys:
            "bboxes": The bounding boxes of the table cells
            "classes": The classes of the table cells
            "html_seq": The sequence as html tags
            "rs_seq": The sequence as OTSL tags

        Returns
        -------
        matching_details : dict
            Same structure as the one returned by match_cells, but "matches" is always an
            empty dictionary
        """
        return self._build_matching_details(iocr_page, table_bbox, prediction, False)

    def _build_matching_details(self, iocr_page, table_bbox, prediction, do_matching):
        r"""
        Shared implementation of match_cells and match_cells_dummy.

        When do_matching is False the pdf cells are not matched and "matches" stays empty.
        """
        # Work on a copy: the bbox dicts are flattened into [l, t, r, b] lists
        pdf_cells = copy.deepcopy(iocr_page["tokens"])
        for word in pdf_cells:
            box = word["bbox"]
            word["bbox"] = [box["l"], box["t"], box["r"], box["b"]]

        table_bboxes = prediction["bboxes"]
        table_classes = prediction["classes"]
        # Scale the normalized cell bboxes to the table bbox
        table_bboxes_page = self._translate_bboxes(table_bbox, table_bboxes)

        # Combine the table tags and bboxes into TableCells
        table_cells = self._build_table_cells(
            prediction["html_seq"],
            prediction["rs_seq"],
            table_bboxes_page,
            table_classes,
        )

        matches = {}
        if do_matching:
            matches, matches_counter = self._intersection_over_pdf_match(
                table_cells, pdf_cells
            )
            self._log().debug("matches_counter: {}".format(matches_counter))

        return {
            "iou_threshold": self._iou_thres,
            "table_bbox": table_bbox,
            "prediction_bboxes_page": table_bboxes_page,  # Make easier the comparison with c++
            "prediction": prediction,
            "pdf_cells": pdf_cells,
            "page_height": iocr_page["height"],
            "page_width": iocr_page["width"],
            "table_cells": table_cells,
            "matches": matches,
        }

    def _build_table_cells(self, html_seq, otsl_seq, bboxes, table_classes):
        r"""
        Combine the tags and bboxes of the table into unified TableCell objects.
        Each TableCell takes a row_id, column_id index based on the table structure.
        It is assumed that the bboxes appear in the same order as the cells of the sequence.

        Parameters
        ----------
        html_seq : list
            List of html tags
        otsl_seq : list
            List of OTSL tags
        bboxes : list of lists of 4
            Bboxes for the table cells
        table_classes : list
            Class of each table cell, in the same order as the bboxes

        Returns
        -------
        list of dict
            Each value is a dictionary with keys: "cell_id", "row_id", "column_id", "bbox",
            "cell_class", "label", "multicol_tag" and optionally "colspan_val", "rowspan_val"
        """
        if self._TABLE_CELL_MODE == "HTML":
            return self._cells_from_html(html_seq, bboxes, table_classes)

        # The spans are derived from the html sequence by converting it to OTSL
        table_html_structure = {
            "html": {"structure": {"tokens": html_seq}},
            "split": "predict",
            "filename": "memory",
        }
        otsl_spans = {}
        r, o = otsl.html_to_otsl(table_html_structure, None, False, False, True, False)
        if not r:
            # Best effort: continue without span information
            self._log().error(
                "ERR#: COULD NOT CONVERT TO RS THIS TABLE TO COMPUTE SPANS"
            )
        else:
            otsl_spans = o["otsl_spans"]

        return self._cells_from_otsl(otsl_seq, bboxes, table_classes, otsl_spans)

    def _cells_from_otsl(self, otsl_seq, bboxes, table_classes, otsl_spans):
        r"""
        Build the table cells by walking the OTSL sequence.

        otsl_spans is indexed by cell_id and each value holds (colspan, rowspan).
        """
        table_cells = []
        cell_id = 0
        row_id = 0
        column_id = 0

        for tag in otsl_seq:
            if tag == "nl":
                # New line: move to the first column of the next row
                row_id += 1
                column_id = 0
                continue

            if tag in self._OTSL_CELL_TAGS:
                # Fall back to defaults when there are fewer bboxes/classes than cells
                bbox = [0.0, 0.0, 0.0, 0.0]
                if cell_id < len(bboxes):
                    bbox = bboxes[cell_id]
                cell_class = 2
                if cell_id < len(table_classes):
                    cell_class = table_classes[cell_id]

                table_cell = {
                    "cell_id": cell_id,
                    "row_id": row_id,
                    "column_id": column_id,
                    "bbox": bbox,
                    "cell_class": cell_class,
                    "label": tag,
                    "multicol_tag": "",
                }

                if cell_id in otsl_spans:
                    colspan_val = otsl_spans[cell_id][0]
                    rowspan_val = otsl_spans[cell_id][1]
                    if colspan_val > 0:
                        table_cell["colspan_val"] = colspan_val
                    if rowspan_val > 0:
                        table_cell["rowspan_val"] = rowspan_val

                table_cells.append(table_cell)
                cell_id += 1

            # Every tag other than "nl" occupies one column of the grid
            column_id += 1

        return table_cells

    def _cells_from_html(self, html_seq, bboxes, table_classes):
        r"""
        Build the table cells by walking the html sequence.
        A TableCell is created on each closing </td>.

        Raises
        ------
        AssertionError
            If there are more closing </td> tags than bboxes
        """
        table_cells = []
        cell_id = 0
        row_id = -1
        column_id = -1
        in_header = False
        in_body = False
        multicol_tag = ""
        colspan_val = 0
        rowspan_val = 0

        for tag in html_seq:
            # NOTE(review): the label is reset on every tag, so the MULTI_COL / MULTI_ROW
            # labels set on a span tag never reach the closing </td>. Kept as is.
            label = None
            if tag == "<thead>":
                in_header = True
                multicol_tag = ""
                colspan_val = 0
                rowspan_val = 0
            elif tag == "</thead>":
                in_header = False
                multicol_tag = ""
                colspan_val = 0
                rowspan_val = 0
            elif tag == "<tbody>":
                in_body = True
                multicol_tag = ""
                colspan_val = 0
                rowspan_val = 0
            elif tag == "</tbody>":
                in_body = False
                multicol_tag = ""
                colspan_val = 0
                rowspan_val = 0
            elif tag == "<td>" or tag == "<td":
                column_id += 1
                multicol_tag = ""
                colspan_val = 0
                rowspan_val = 0
                if tag == "<td":
                    multicol_tag = tag
            elif tag == "<tr>":
                row_id += 1
                column_id = -1
                multicol_tag = ""
                colspan_val = 0
                rowspan_val = 0
            elif "colspan" in tag:
                label = MULTI_COL
                multicol_tag += tag
                colspan_val = int(re.findall(r'"([^"]*)"', tag)[0])
            elif "rowspan" in tag:
                label = MULTI_ROW
                multicol_tag += tag
                rowspan_val = int(re.findall(r'"([^"]*)"', tag)[0])
            elif tag == "</td>":  # Create a TableCell on each closing td
                if len(multicol_tag) > 0:
                    multicol_tag += tag
                if in_header:
                    if label is None:
                        label = COL_HEADER
                    elif label == MULTI_COL:
                        label = MULTI_COL_HEADER
                    elif label == MULTI_ROW:
                        label = MULTI_ROW_HEADER
                if label is None and in_body:
                    label = BODY

                # Raised explicitly so that the check also runs under "python -O"
                if cell_id >= len(bboxes):
                    raise AssertionError(
                        "Mismatching bboxes with closing TDs {} < {}".format(
                            cell_id, len(bboxes)
                        )
                    )

                table_cell = {
                    "cell_id": cell_id,
                    "row_id": row_id,
                    "column_id": column_id,
                    "bbox": bboxes[cell_id],
                    "cell_class": table_classes[cell_id],
                    "label": label,
                    "multicol_tag": multicol_tag,
                }
                if colspan_val > 0:
                    table_cell["colspan_val"] = colspan_val
                    # Shift column index to account for span
                    column_id += colspan_val - 1
                if rowspan_val > 0:
                    table_cell["rowspan_val"] = rowspan_val

                table_cells.append(table_cell)
                cell_id += 1

        return table_cells

    def _translate_bboxes(self, table_bbox, cell_bboxes):
        r"""
        Scale the normalized table cell bboxes to the table bbox and shift them to its origin.

        The cells of the table are given:
        - Origin at the top left corner
        - Point A: Top left corner
        - Point B: Low right corner
        - Coordinate values are normalized to the table width/height

        The y axis is flipped in the first step and flipped back in the second one, so for
        a cell (bx1, by1, bx2, by2) and a table (x1, y1, x2, y2) of width W and height H the
        result is [x1 + W * bx1, y1 + H * by1, x1 + W * bx2, y1 + H * by2].

        Parameters
        ----------
        table_bbox : list of 4
            The whole table bbox page coordinates
        cell_bboxes : list of lists of 4
            The bboxes of the table cells

        Returns
        -------
        list of lists of 4
            The translated bboxes of the table cells. Empty if there are no cell bboxes
        """
        b = np.asarray(cell_bboxes)
        if b.size == 0:
            # Nothing to translate. An empty array cannot be broadcast against the masks
            return []

        W = table_bbox[2] - table_bbox[0]
        H = table_bbox[3] - table_bbox[1]
        t_mask = np.asarray(
            [table_bbox[0], table_bbox[3], table_bbox[0], table_bbox[3]]
        )
        m = np.asarray([W, -H, W, -H])
        page_bboxes_y_flipped = t_mask + m * b
        page_bboxes = page_bboxes_y_flipped[:, [0, 3, 2, 1]]  # Flip y1' with y2'

        # Flip the y axis back, measuring from the top edge of the table
        t_height = table_bbox[3]
        return [
            [
                page_bbox[0],
                t_height - page_bbox[3] + table_bbox[1],
                page_bbox[2],
                t_height - page_bbox[1] + table_bbox[1],
            ]
            for page_bbox in page_bboxes.tolist()
        ]

    def _intersection_over_pdf_match(self, table_cells, pdf_cells):
        r"""
        Compute the intersection between table cells and pdf cells and record, for each pdf
        cell, every table cell that covers a positive fraction of the pdf cell area.

        First compute and cache the areas of the pdf bboxes.
        Then compute the pairwise intersections

        Parameters
        ----------
        table_cells : list of dict
            Each value is a dictionary with keys: "cell_id", "row_id", "column_id", "bbox", "label"

        pdf_cells : list of dict
            Each element of the list is a dictionary which should have the keys: "id", "bbox"

        Returns
        -------
        dictionary of lists of matches
            Return a dictionary which is indexed by the pdf_cell_id as key and the value is a
            list of dictionaries with the keys "table_cell_id" and "iopdf" (intersection area
            over pdf cell area)
        int
            Number of total matches
        """
        if len(pdf_cells) == 0:
            # No pdf cells: the area computation below needs a 2D array
            return {}, 0

        pdf_bboxes = np.asarray([p["bbox"] for p in pdf_cells])
        pdf_bboxes_areas = (pdf_bboxes[:, 2] - pdf_bboxes[:, 0]) * (
            pdf_bboxes[:, 3] - pdf_bboxes[:, 1]
        )

        # key: pdf_cell_id, value: list of matches with the table cells
        matches = {}
        matches_counter = 0

        # Compute Intersections and build matches
        for table_cell in table_cells:
            table_cell_id = table_cell["cell_id"]
            t_bbox = table_cell["bbox"]

            for j, pdf_cell in enumerate(pdf_cells):
                pdf_cell_id = pdf_cell["id"]

                # Compute intersection
                i_bbox = find_intersection(t_bbox, pdf_cell["bbox"])
                if i_bbox is None:
                    continue

                # Intersection over the pdf cell area. Degenerate pdf cells never match
                i_bbox_area = (i_bbox[2] - i_bbox[0]) * (i_bbox[3] - i_bbox[1])
                pdf_area = float(pdf_bboxes_areas[j])
                iopdf = 0
                if pdf_area > 0:
                    iopdf = i_bbox_area / pdf_area
                if iopdf <= 0:
                    continue

                match = {"table_cell_id": table_cell_id, "iopdf": iopdf}
                cell_matches = matches.setdefault(pdf_cell_id, [])
                # Check if the same match was not already counted
                if match not in cell_matches:
                    cell_matches.append(match)
                    matches_counter += 1

        return matches, matches_counter

    def _iou_match(self, table_cells, pdf_cells):
        r"""
        Use Intersection over Union to decide the matching between table cells and pdf cells

        First compute and cache the areas for all involved bboxes.
        Then compute the pairwise intersections and IOUs and keep those pairs that reach the
        IOU threshold

        Parameters
        ----------
        table_cells : list of dict
            Each value is a dictionary with keys: "cell_id", "row_id", "column_id", "bbox", "label"

        pdf_cells : list of dict
            Each element of the list is a dictionary which should have the keys: "id", "bbox",
            "text"

        Returns
        -------
        dictionary of lists of matches
            Return a dictionary which is indexed by the pdf_cell_id as key and the value is a
            list of dictionaries with the keys "table_cell_id", "iou" and "text"
        int
            Number of total matches
        """
        if len(table_cells) == 0 or len(pdf_cells) == 0:
            # Nothing to match: the area computations below need 2D arrays
            return {}, 0

        table_bboxes = np.asarray([t["bbox"] for t in table_cells])
        pdf_bboxes = np.asarray([p["bbox"] for p in pdf_cells])

        # Cache the areas for table bboxes and pdf bboxes
        table_bboxes_areas = (table_bboxes[:, 2] - table_bboxes[:, 0]) * (
            table_bboxes[:, 3] - table_bboxes[:, 1]
        )
        pdf_bboxes_areas = (pdf_bboxes[:, 2] - pdf_bboxes[:, 0]) * (
            pdf_bboxes[:, 3] - pdf_bboxes[:, 1]
        )

        # key: pdf_cell_id, value: list of matches with the table cells
        matches = {}
        matches_counter = 0

        # Compute IOUs and build matches
        for i, table_cell in enumerate(table_cells):
            table_cell_id = table_cell["cell_id"]
            t_bbox = table_cell["bbox"]

            for j, pdf_cell in enumerate(pdf_cells):
                pdf_cell_id = pdf_cell["id"]
                pdf_cell_text = pdf_cell["text"]

                # Compute intersection
                i_bbox = find_intersection(t_bbox, pdf_cell["bbox"])
                if i_bbox is None:
                    continue

                # Compute IOU and filter on threshold
                i_bbox_area = (i_bbox[2] - i_bbox[0]) * (i_bbox[3] - i_bbox[1])
                iou = 0
                div_area = float(
                    table_bboxes_areas[i] + pdf_bboxes_areas[j] - i_bbox_area
                )
                if div_area > 0:
                    iou = i_bbox_area / div_area
                if iou < self._iou_thres:
                    continue

                matches.setdefault(pdf_cell_id, []).append(
                    {
                        "table_cell_id": table_cell_id,
                        "iou": iou,
                        "text": pdf_cell_text,
                    }
                )
                matches_counter += 1

        return matches, matches_counter
|