docling-ibm-models 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
- docling_ibm_models/tableformer/__init__.py +0 -0
- docling_ibm_models/tableformer/common.py +200 -0
- docling_ibm_models/tableformer/data_management/__init__.py +0 -0
- docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
- docling_ibm_models/tableformer/data_management/functional.py +574 -0
- docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
- docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
- docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
- docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
- docling_ibm_models/tableformer/data_management/transforms.py +396 -0
- docling_ibm_models/tableformer/models/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/base_model.py +279 -0
- docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
- docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
- docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
- docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
- docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
- docling_ibm_models/tableformer/otsl.py +541 -0
- docling_ibm_models/tableformer/settings.py +90 -0
- docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
- docling_ibm_models/tableformer/test_prepare_image.py +99 -0
- docling_ibm_models/tableformer/utils/__init__.py +0 -0
- docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
- docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
- docling_ibm_models/tableformer/utils/utils.py +376 -0
- docling_ibm_models/tableformer/utils/variance.py +175 -0
- docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
- docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
- docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
- docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,541 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import copy
|
6
|
+
import logging
|
7
|
+
from itertools import groupby
|
8
|
+
|
9
|
+
import docling_ibm_models.tableformer.settings as s
|
10
|
+
|
11
|
+
# Module-wide log level; flip the commented line below for verbose output.
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
logger = s.get_custom_logger("consolidate", LOG_LEVEL)
png_files = {}  # Evaluation files
total_pics = 0  # running count of processed images
|
16
|
+
|
17
|
+
|
18
|
+
class bcolors:
    """ANSI terminal escape codes used to colorize console diagnostics."""

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"  # reset all attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
|
28
|
+
|
29
|
+
|
30
|
+
def otsl_clean(rs_list):
    """Strip special vocabulary tokens from an OTSL tag sequence.

    Removes "<pad>", "<unk>", "<start>" and "<end>" markers while
    preserving the order of all structural tags.
    """
    special_tokens = ["<pad>", "<unk>", "<start>", "<end>"]
    return [tag for tag in rs_list if tag not in special_tokens]
|
37
|
+
|
38
|
+
|
39
|
+
def otsl_sqr_chk(rs_list, name, logdebug):
    """Check that an OTSL tag sequence describes a rectangular table.

    The sequence is split into rows at "nl" markers; the table counts as
    "square" when every row has the same number of tags as the first one.
    A failure is always reported on stdout (tagged with *name*); success
    is only reported when *logdebug* is set.

    Returns True when all rows have equal length, False otherwise.
    """
    rows = [
        list(grp) for is_nl, grp in groupby(rs_list, lambda t: t == "nl") if not is_nl
    ]
    isSquare = True
    if rows:
        # Each row gets its "nl" terminator back before the comparison.
        expected_len = len(rows[0]) + 1
        for row in rows:
            row.append("nl")
            if len(row) != expected_len:
                isSquare = False
    if isSquare:
        if logdebug:
            print(
                "{}*OK* Table is square! *OK*{}".format(
                    bcolors.OKGREEN, bcolors.ENDC
                )
            )
    else:
        print(("{}*ERR* " + name + " *ERR*{}").format(bcolors.FAIL, bcolors.ENDC))
        print(
            "{}*ERR* Table is not square! *ERR*{}".format(
                bcolors.FAIL, bcolors.ENDC
            )
        )
    return isSquare
|
66
|
+
|
67
|
+
|
68
|
+
def otsl_pad_to_sqr(rs_list, pad_tag):
    """Pad every row of an OTSL sequence to the length of the longest row.

    Rows are delimited by "nl" tags. Shorter rows are extended with
    *pad_tag* so that the result describes a rectangular table; "nl"
    delimiters are re-appended after each (padded) row.

    Returns the flattened, padded tag sequence.
    """
    rows = [
        list(grp) for is_nl, grp in groupby(rs_list, lambda t: t == "nl") if not is_nl
    ]
    widest = max((len(row) for row in rows), default=0)
    padded = []
    for row in rows:
        padded.extend(row + [pad_tag] * (widest - len(row)))
        padded.append("nl")
    return padded
|
82
|
+
|
83
|
+
|
84
|
+
def otsl_tags_cells_sync_chk(rs_list, cells, name, logdebug):
    """Check that the number of cell-opening OTSL tags matches *cells*.

    Counts tags that start a new cell ("fcel", "ched", "rhed", "srow",
    "ecel") and compares against len(cells). On mismatch, diagnostics
    are printed (tagged with *name*) and False is returned.
    """
    cell_tags = ("fcel", "ched", "rhed", "srow", "ecel")
    countCellTags = sum(1 for tag in rs_list if tag in cell_tags)
    if countCellTags == len(cells):
        return True
    print(("{}*!ERR* " + name + " *ERR!*{}").format(bcolors.FAIL, bcolors.ENDC))
    print(
        "{}*!ERR* Tags are not in sync with cells! *ERR!*{}".format(
            bcolors.FAIL, bcolors.ENDC
        )
    )
    return False
|
97
|
+
|
98
|
+
|
99
|
+
def otsl_check_down(rs_split, x, y):
    """Measure the vertical span (rowspan) of the cell at column x, row y.

    Walks down column *x* starting below row *y* for as long as "ucel"
    (up-extension) tags continue, stopping at any regular cell tag or at
    the bottom edge of the table.

    Returns the number of rows the span covers (>= 1).
    """
    stop_tags = ["fcel", "ched", "rhed", "srow", "ecel", "lcel", "nl"]
    span = 1
    tag = "ucel"
    while tag not in stop_tags and y < len(rs_split) - 1:
        y += 1
        span += 1
        tag = rs_split[y][x]
    # The last inspected cell starts the next (non-span) cell, so it is
    # not counted -- unless the scan ran off the bottom edge.
    if tag in stop_tags:
        span -= 1
    return span
|
110
|
+
|
111
|
+
|
112
|
+
def otsl_check_right(rs_split, x, y):
    """Measure the horizontal span (colspan) of the cell at column x, row y.

    Walks right along row *y* starting after column *x* for as long as
    "lcel" (left-extension) tags continue, stopping at any regular cell
    tag or at the right edge of the row.

    Returns the number of columns the span covers (>= 1).
    """
    stop_tags = ["fcel", "ched", "rhed", "srow", "ecel", "ucel", "nl"]
    span = 1
    tag = "lcel"
    while tag not in stop_tags and x < (len(rs_split[y]) - 1):
        x += 1
        span += 1
        tag = rs_split[y][x]
    # The last inspected cell starts the next (non-span) cell, so it is
    # not counted -- unless the scan ran off the right edge.
    if tag in stop_tags:
        span -= 1
    return span
|
123
|
+
|
124
|
+
|
125
|
+
def otsl_to_html(rs_list, logdebug):
    """Convert a (rectangular) OTSL tag sequence into HTML structure tokens.

    If the first token is not a cell-opening tag, the input is assumed to
    already be HTML and is returned unchanged. A non-square table is first
    padded with "lcel" tags. Spans are detected by peeking right ("lcel"),
    down ("ucel") and diagonally ("xcel"); 2D spans are recorded in a
    registry so overlapping spans are emitted only once.
    """
    if rs_list[0] not in ["fcel", "ched", "rhed", "srow", "ecel"]:
        # Most likely already HTML...
        return rs_list
    html_table = []
    if logdebug:
        print("{}*Reconstructing HTML...*{}".format(bcolors.WARNING, bcolors.ENDC))

    if not otsl_sqr_chk(rs_list, "---", logdebug):
        # PAD TABLE TO SQUARE
        print("{}*Padding to square...*{}".format(bcolors.WARNING, bcolors.ENDC))
        rs_list = otsl_pad_to_sqr(rs_list, "lcel")

    # 2D structure, line by line:
    rs_list_split = [
        list(group) for k, group in groupby(rs_list, lambda x: x == "nl") if not k
    ]

    if logdebug:
        print("")

    # Sequentially store indexes of 2D spans that were registered to avoid re-registering them
    registry_2d_span = []

    # Iterate all elements in the rs line, and look right / down to detect spans
    # If span detected - run function to find size of the span
    # repeat with all cells
    thead_present = False

    for rs_row_ind, rs_row in enumerate(rs_list_split):
        html_list = []

        # Open <thead> on the first row containing a column-header tag.
        if not thead_present:
            if "ched" in rs_list_split[rs_row_ind]:
                html_list.append("<thead>")
                thead_present = True

        # Close </thead> on the first subsequent row without a header tag.
        if thead_present:
            if "ched" not in rs_list_split[rs_row_ind]:
                html_list.append("</thead>")
                thead_present = False

        html_list.append("<tr>")
        for rs_cell_ind, rs_cell in enumerate(rs_list_split[rs_row_ind]):
            if rs_cell in ["fcel", "ched", "rhed", "srow", "ecel"]:
                rdist = 0  # horizontal span length (colspan)
                ddist = 0  # vertical span length (rowspan)
                xrdist = 0  # 2D span: colspan component
                xddist = 0  # 2D span: rowspan component
                span = False
                # Check if it has horizontal span:
                if rs_cell_ind + 1 < len(rs_list_split[rs_row_ind]):
                    if rs_list_split[rs_row_ind][rs_cell_ind + 1] == "lcel":
                        rdist = otsl_check_right(rs_list_split, rs_cell_ind, rs_row_ind)
                        span = True
                # Check if it has vertical span:
                if rs_row_ind + 1 < len(rs_list_split):
                    if rs_list_split[rs_row_ind + 1][rs_cell_ind] == "ucel":
                        ddist = otsl_check_down(rs_list_split, rs_cell_ind, rs_row_ind)
                        span = True
                # Check if it has 2D span:
                if rs_cell_ind + 1 < len(rs_list_split[rs_row_ind]):
                    if rs_list_split[rs_row_ind][rs_cell_ind + 1] == "xcel":
                        xrdist = otsl_check_right(
                            rs_list_split, rs_cell_ind, rs_row_ind
                        )
                        xddist = otsl_check_down(rs_list_split, rs_cell_ind, rs_row_ind)
                        span = True
                        # Check if this 2D span was already registered,
                        # If not - register, if yes - cancel span
                        for x in range(rs_cell_ind, xrdist + rs_cell_ind):
                            for y in range(rs_row_ind, xddist + rs_row_ind):
                                reg2dind = str(x) + "_" + str(y)
                                if reg2dind in registry_2d_span:
                                    # Cell of the span is already in, cancel current span
                                    span = False
                        if span:
                            # None of the span cells were previously registered
                            # Register an entire span
                            for x in range(rs_cell_ind, xrdist + rs_cell_ind):
                                for y in range(rs_row_ind, xddist + rs_row_ind):
                                    reg2dind = str(x) + "_" + str(y)
                                    registry_2d_span.append(reg2dind)
                if span:
                    # Emit an open <td with the appropriate span attributes.
                    html_list.append("<td")
                    if rdist > 1:
                        html_list.append(' colspan="' + str(rdist) + '"')
                    if ddist > 1:
                        html_list.append(' rowspan="' + str(ddist) + '"')
                    if xrdist > 1:
                        html_list.append(' rowspan="' + str(xddist) + '"')
                        html_list.append(' colspan="' + str(xrdist) + '"')
                    html_list.append(">")
                    html_list.append("</td>")
                else:
                    html_list.append("<td>")
                    html_list.append("</td>")
        html_list.append("</tr>")
        html_table.extend(html_list)

    if logdebug:
        print("*********************** registry_2d_span ***************************")
        print(registry_2d_span)
        print("********************************************************************")

    return html_table
|
237
|
+
|
238
|
+
|
239
|
+
def html_to_otsl(table, writer, logdebug, extra_debug, include_html, use_writer):
    r"""
    Converts table structure from HTML to RS (OTSL)

    Parameters
    ----------
    table : json
        line from jsonl
    writer : writer
        Writes lines into output jsonl
    logdebug : bool
        Print the OTSL tags of every converted row
    extra_debug : bool
        Print additional diagnostics (input HTML, extra-column padding)
    include_html : bool
        Keep the original and restored HTML structures on the output line
    use_writer : bool
        Write the converted line with ``writer`` when it is consistent

    Returns
    -------
    tuple
        (False, {}) when the HTML structure is inconsistent with the
        rowspans recorded so far, otherwise (True, out_line).
    """

    table_html_structure = copy.deepcopy(table["html"]["structure"])
    out_line = table  # NOTE: the input dict is mutated in place
    if include_html:
        out_line["html"]["html_structure"] = table_html_structure
        out_line["html"]["html_restored_structure"] = {"tokens": []}

    out_line["html"]["structure"] = {"tokens": []}
    # Possible span attribute tokens mapped to their numeric length
    # (structure tokens carry span attributes as separate tokens).
    pos_colspans = {' colspan="{}"'.format(n): n for n in range(2, 21)}
    pos_rowspans = {' rowspan="{}"'.format(n): n for n in range(2, 21)}

    t_cells = []  # 2D structure
    tl_cells = []  # 1D structure
    t_expands = []  # 2D structure
    tl_spans = {}  # MAP, POPULATE WITH ACTUAL SPANS VALUES, IN SYNC WITH tl_cells

    current_line = 0
    current_column = 0
    current_html_cell_ind = 0

    current_line_tags = []
    current_line_expands = []

    if logdebug:
        print("")
        print("*** {}: {} ***".format(table["split"], table["filename"]))

    colnum = 0

    if extra_debug:
        print("========================== Input HTML ============================")
        print(table_html_structure["tokens"])
        print("==================================================================")

    if logdebug:
        print("********")
        print("* OTSL *")
        print("********")

    for i in range(len(table_html_structure["tokens"])):
        html_tag = table_html_structure["tokens"][i]
        prev_html_tag = ""
        next_html_tag = ""
        if i > 0:
            prev_html_tag = table_html_structure["tokens"][i - 1]
        if i < len(table_html_structure["tokens"]) - 1:
            next_html_tag = table_html_structure["tokens"][i + 1]

        if html_tag not in ["<thead>", "<tbody>"]:
            # Then check the next tag...
            # rules of conversion
            # Check up-cell in t_expands, in case row-spans have to be inserted
            if html_tag in ["<td>", "<td", "</tr>"]:
                if current_line > 0:
                    if current_column >= len(t_expands[current_line - 1]):
                        # Inconsistent structure: this row already has more
                        # columns than the previous row can account for.
                        return False, {}
                    up_expand = t_expands[current_line - 1][current_column]

                    while up_expand[1] > 0:
                        if up_expand[0] == 0:
                            # ucel
                            current_line_tags.append("ucel")
                            current_line_expands.append([0, up_expand[1] - 1])
                            current_column += 1
                        else:
                            # xcel
                            for ci in range(up_expand[0]):
                                current_line_tags.append("xcel")
                                current_line_expands.append(
                                    [up_expand[0] - ci, up_expand[1] - 1]
                                )
                                current_column += 1
                        up_expand = t_expands[current_line - 1][current_column]
            # ======================================================================================
            # Fix for trailing "ucel" in a row
            if html_tag in ["</tr>"]:
                if current_line > 0:
                    cur_line_len = len(current_line_expands)
                    pre_line_len = len(t_expands[current_line - 1])

                    if cur_line_len < pre_line_len:
                        extra_columns = pre_line_len - cur_line_len - 1
                        if extra_columns > 0:
                            if extra_debug:
                                print(
                                    "Extra columns needed in row: {}".format(
                                        extra_columns
                                    )
                                )

                            for clm in range(extra_columns):
                                up_expand = t_expands[current_line - 1][
                                    cur_line_len + clm
                                ]
                                if up_expand[0] == 0:
                                    # ucel
                                    current_line_tags.append("ucel")
                                    current_line_expands.append([0, up_expand[1] - 1])
                                else:
                                    # xcel
                                    current_line_tags.append("xcel")
                                    current_line_expands.append(
                                        [up_expand[0], up_expand[1] - 1]
                                    )
            # ======================================================================================

            # 1. Opening cell tags
            if html_tag in ["<td>", "<td"]:
                # check if cell is empty...
                if "cells" in table["html"]:
                    cell_tokens = table["html"]["cells"][current_html_cell_ind][
                        "tokens"
                    ]
                else:
                    # No cell content available: treat the cell as non-empty.
                    cell_tokens = "f"

                # Clean cell_tokens from formatting / whitespace tokens:
                ignore_tokens = {
                    "<i>", "<I>", "<b>", "<B>", " ", "</b>", "</B>", "</i>", "</I>"
                }
                cell_tokens = [t for t in cell_tokens if t not in ignore_tokens]

                # Check if empty
                if len(cell_tokens) == 0:
                    out_line["html"]["cells"][current_html_cell_ind]["tokens"] = []
                    current_line_tags.append("ecel")
                else:
                    current_line_tags.append("fcel")
                current_line_expands.append([0, 0])
                current_html_cell_ind += 1
                current_column += 1

            # 2. Closing row tags
            if html_tag == "</tr>":
                if len(current_line_tags) > colnum:
                    colnum = len(current_line_tags)
                # Save everything we read about the line to t_cells
                current_line_tags.append("nl")
                t_cells.append(copy.deepcopy(current_line_tags))
                tl_cells.extend(copy.deepcopy(current_line_tags))
                if logdebug:
                    print(current_line_tags)
                current_line_tags = []

                # Deal with expands; [-1, -1] marks the row terminator
                current_line_expands.append([-1, -1])
                # Output spans metadata
                t_expands.append(copy.deepcopy(current_line_expands))
                current_line_expands = []

                current_column = 0
                current_line += 1
            # 3. Colspans only
            if html_tag in pos_colspans:
                if prev_html_tag not in pos_rowspans:
                    if next_html_tag not in pos_rowspans:
                        colspan_len = pos_colspans[html_tag]
                        tl_spans[current_html_cell_ind - 1] = [colspan_len, 1]
                        current_line_expands[len(current_line_expands) - 1] = [
                            colspan_len,
                            0,
                        ]
                        for ci in range(colspan_len - 1):
                            current_line_tags.append("lcel")
                            current_line_expands.append([colspan_len - ci - 1, 0])
                            current_column += 1

            # 4. Rowspans only
            if html_tag in pos_rowspans:
                if prev_html_tag not in pos_colspans:
                    if next_html_tag not in pos_colspans:
                        rowspan_len = pos_rowspans[html_tag]
                        tl_spans[current_html_cell_ind - 1] = [1, rowspan_len]
                        current_line_expands[len(current_line_expands) - 1] = [
                            0,
                            rowspan_len - 1,
                        ]

            # 5. 2D spans (a rowspan token adjacent to a colspan token)
            if html_tag in pos_rowspans:
                rowspan_len = pos_rowspans[html_tag]
                if prev_html_tag in pos_colspans:
                    colspan_len = pos_colspans[prev_html_tag]
                    tl_spans[current_html_cell_ind - 1] = [colspan_len, rowspan_len]
                    newexp = [colspan_len, rowspan_len - 1]
                    current_line_expands[len(current_line_expands) - 1] = newexp
                    for ci in range(colspan_len - 1):
                        current_line_tags.append("xcel")
                        current_line_expands.append(
                            [colspan_len - ci - 1, rowspan_len - 1]
                        )
                if next_html_tag in pos_colspans:
                    colspan_len = pos_colspans[next_html_tag]
                    tl_spans[current_html_cell_ind - 1] = [colspan_len, rowspan_len]
                    newexp = [colspan_len, rowspan_len - 1]
                    current_line_expands[len(current_line_expands) - 1] = newexp
                    for ci in range(colspan_len - 1):
                        current_line_tags.append("xcel")
                        current_line_expands.append(
                            [colspan_len - ci - 1, rowspan_len - 1]
                        )

    t_name = "*** {}: {} ***".format(table["split"], table["filename"])
    # check if square
    isSquare = otsl_sqr_chk(tl_cells, t_name, logdebug)
    if not isSquare:
        # pad if not square
        tl_cells = otsl_pad_to_sqr(tl_cells, "fcel")
    # check if cells (bboxes) in sync:
    # BUGFIX: isGood previously stayed undefined when "cells" was absent,
    # raising NameError at the use_writer check below.
    isGood = True
    if "cells" in out_line["html"]:
        isGood = otsl_tags_cells_sync_chk(
            tl_cells, out_line["html"]["cells"], t_name, logdebug
        )
    # convert back to HTML
    rHTML = []
    if isSquare:
        rHTML = otsl_to_html(tl_cells, logdebug)
        out_line["html"]["html_restored_structure"]["tokens"] = rHTML

    out_line["html"]["structure"]["tokens"] = tl_cells
    out_line["otsl_spans"] = tl_spans
    out_line["cols"] = colnum
    out_line["rows"] = len(t_cells)
    out_line["html_len"] = len(table_html_structure["tokens"])
    out_line["rs_len"] = len(tl_cells)
    # save converted line
    if use_writer:
        if isSquare:
            if isGood:
                writer.write(out_line)

    if logdebug:
        print("{}Reconstructed HTML:{}".format(bcolors.OKGREEN, bcolors.ENDC))
        print(rHTML)
        # original HTML
        # NOTE(review): "html_structure" is only set when include_html is
        # True -- logdebug without include_html would KeyError here; confirm
        # callers always pair the two flags.
        oHTML = out_line["html"]["html_structure"]
        print("{}Original HTML:{}".format(bcolors.OKBLUE, bcolors.ENDC))
        print(oHTML)

    return True, out_line
|
@@ -0,0 +1,90 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import logging
|
6
|
+
import sys
|
7
|
+
|
8
|
+
|
9
|
+
def get_custom_logger(logger_name, level, stream=sys.stdout):
    r"""
    Create a custom logger with a standard formatting

    Inputs:
    - logger_name: Name of the logger. You can get the class name as self.__class__.__name__
    - level: logging level (e.g. logging.INFO, logging.DEBUG, etc.)
    - stream: One of sys.stdout or sys.stderr

    Outputs:
        logger
    """
    custom_logger = logging.getLogger(logger_name)
    custom_logger.setLevel(level)

    # Attach a stream handler only once; repeated calls for the same name
    # (or a logger whose ancestors already have handlers) add nothing.
    if custom_logger.hasHandlers():
        return custom_logger

    stream_handler = logging.StreamHandler(stream)
    stream_handler.setFormatter(
        logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
    )
    custom_logger.addHandler(stream_handler)

    return custom_logger
|
34
|
+
|
35
|
+
|
36
|
+
###################################################################################
# System constants
#

r"""
This is a "generic" logger available to all scripts.
It is encouraged that each class has its own custom logger with the name of the class.
You can use the "get_custom_logger" function to build a custom logger with a standard format.
"""
LOGGER = get_custom_logger("docling-pm", logging.INFO)

# Supported dataset types
supported_datasets = ["TF_prepared"]  # TF prepared dataset

# Split names
TRAIN_SPLIT = "train"
VAL_SPLIT = "val"
TEST_SPLIT = "test"

# Prepared data parts and filename templates.
# "<POSTFIX>" is a placeholder replaced with a dataset-specific suffix.
PREPARED_DATA_PARTS = {
    # Array with the bboxes (x1y1x2y2) for all cells of the images across all splits.
    # The bboxes are indexed with the filename.
    # Notices:
    # - The bboxes are NOT transformed.
    # - If the image filenames are the same across splits, there will be only one entry in the file
    "BBOXES": "BBOXES.json",
    # Image filenames used for train and val
    "IMAGES": "IMAGES.json",
    # Mean, std, variance as arrays of 3 (for each color)
    "STATISTICS": "STATISTICS_<POSTFIX>.json",  # PRECOMPUTED
    # Bboxes of the cells in the form [1, x1, x2, y1, y2] or [0, 0, 0, 0, 0] in case of no box.
    "TRAIN_CELLBBOXES": "TRAIN_CELLBBOXES_<POSTFIX>.json",  # NOT USED.
    # Array with arrays of the length + 2 of the original cells per image.
    "TRAIN_CELLLENS": "TRAIN_CELLLENS_<POSTFIX>.json",
    # Indices of the cells between <start> <end> and <pad> at the end.
    "TRAIN_CELLS": "TRAIN_CELLS_<POSTFIX>.json",
    # Array with the length + 2 of the original tags per image.
    "TRAIN_TAGLENS": "TRAIN_TAGLENS_<POSTFIX>.json",
    # Indices of the tags between <start> <end> and <pad> at the end.
    "TRAIN_TAGS": "TRAIN_TAGS_<POSTFIX>.json",
    # Ground truth for the evaluation dataset per eval image.
    "VAL": "VAL.json",
    # Vocabulary: Indices of the word_map_cells and word_map_tags
    "WORDMAP": "WORDMAP_<POSTFIX>.json",  # PRECOMPUTED
}

# Purposes
TRAIN_PURPOSE = "train"
VAL_PURPOSE = "val"
TEST_PURPOSE = "test"
PREDICT_PURPOSE = "predict"

# The DDP world size when we train in CPU with DDP enabled
DDP_CPU_WORLD_SIZE = 2
|
@@ -0,0 +1,37 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import logging
|
6
|
+
|
7
|
+
import docling_ibm_models.tableformer.common as c
|
8
|
+
from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset
|
9
|
+
|
10
|
+
LOG_LEVEL = logging.INFO
|
11
|
+
# LOG_LEVEL = logging.DEBUG
|
12
|
+
|
13
|
+
|
14
|
+
def dataset_test(config):
    r"""
    Smoke-test the prepared TableFormer dataset by iterating it once.

    Parameters
    ----------
    config : dictionary
        The configuration settings
    """
    # Build the "train" split dataset on CPU.
    dataset = TFDataset(config, "train")
    dataset.set_device("cpu")

    # Loop over the data once in shuffled order.
    dataset.reset()
    dataset.shuffle()
    for batch_idx, _batch in enumerate(dataset):
        print("Loading batch: {}".format(batch_idx))
|
33
|
+
|
34
|
+
|
35
|
+
if __name__ == "__main__":
    # Parse CLI arguments into a config dict and run the dataset smoke test.
    config = c.parse_arguments()
    dataset_test(config)
|