docling-ibm-models 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
  2. docling_ibm_models/tableformer/__init__.py +0 -0
  3. docling_ibm_models/tableformer/common.py +200 -0
  4. docling_ibm_models/tableformer/data_management/__init__.py +0 -0
  5. docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
  6. docling_ibm_models/tableformer/data_management/functional.py +574 -0
  7. docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
  8. docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
  9. docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
  10. docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
  11. docling_ibm_models/tableformer/data_management/transforms.py +396 -0
  12. docling_ibm_models/tableformer/models/__init__.py +0 -0
  13. docling_ibm_models/tableformer/models/common/__init__.py +0 -0
  14. docling_ibm_models/tableformer/models/common/base_model.py +279 -0
  15. docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
  16. docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
  17. docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
  18. docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
  19. docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
  20. docling_ibm_models/tableformer/otsl.py +541 -0
  21. docling_ibm_models/tableformer/settings.py +90 -0
  22. docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
  23. docling_ibm_models/tableformer/test_prepare_image.py +99 -0
  24. docling_ibm_models/tableformer/utils/__init__.py +0 -0
  25. docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
  26. docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
  27. docling_ibm_models/tableformer/utils/utils.py +376 -0
  28. docling_ibm_models/tableformer/utils/variance.py +175 -0
  29. docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
  30. docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
  31. docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
  32. docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,541 @@
1
+ #
2
+ # Copyright IBM Corp. 2024 - 2024
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ import copy
6
+ import logging
7
+ from itertools import groupby
8
+
9
+ import docling_ibm_models.tableformer.settings as s
10
+
11
# Module-wide logging level; flip to DEBUG for verbose conversion traces.
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
# Shared logger for this consolidation/conversion module.
logger = s.get_custom_logger("consolidate", LOG_LEVEL)
png_files = {}  # Evaluation files
total_pics = 0  # running count of processed pictures
16
+
17
+
18
class bcolors:
    """ANSI terminal escape codes used to colorize console diagnostics."""

    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"
    ENDC = "\033[0m"  # resets all terminal attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
28
+
29
+
30
def otsl_clean(rs_list):
    """Strip special vocabulary tokens from an OTSL tag sequence.

    Removes ``<pad>``, ``<unk>``, ``<start>`` and ``<end>`` markers while
    preserving the order of all remaining structural tags.
    """
    special_tokens = ("<pad>", "<unk>", "<start>", "<end>")
    return [tag for tag in rs_list if tag not in special_tokens]
37
+
38
+
39
def otsl_sqr_chk(rs_list, name, logdebug):
    """Check that an OTSL tag sequence describes a rectangular table.

    The flat sequence is cut into rows at every "nl" tag; the table is
    considered square when every row has the same length as the first.
    A success message is printed only when ``logdebug`` is True; failure
    diagnostics (including ``name``) are always printed.
    """
    rows = [
        list(grp) for is_nl, grp in groupby(rs_list, lambda t: t == "nl") if not is_nl
    ]
    is_square = True
    if rows:
        # +1 accounts for the "nl" terminator re-appended to each row below.
        expected_len = len(rows[0]) + 1
        for row in rows:
            row.append("nl")
            if len(row) != expected_len:
                is_square = False
    if is_square:
        if logdebug:
            print(
                "{}*OK* Table is square! *OK*{}".format(
                    bcolors.OKGREEN, bcolors.ENDC
                )
            )
    else:
        print(("{}*ERR* " + name + " *ERR*{}").format(bcolors.FAIL, bcolors.ENDC))
        print(
            "{}*ERR* Table is not square! *ERR*{}".format(
                bcolors.FAIL, bcolors.ENDC
            )
        )
    return is_square
66
+
67
+
68
def otsl_pad_to_sqr(rs_list, pad_tag):
    """Pad every row of an OTSL sequence so the table becomes rectangular.

    The sequence is split into rows at "nl" tags, each row is extended
    with ``pad_tag`` up to the longest row's length, and the rows are
    re-joined (with their "nl" terminators) into one flat list.
    """
    rows = [
        list(grp) for is_nl, grp in groupby(rs_list, lambda t: t == "nl") if not is_nl
    ]
    widest = max((len(row) for row in rows), default=0)
    padded = []
    for row in rows:
        padded.extend(row + [pad_tag] * (widest - len(row)))
        padded.append("nl")
    return padded
82
+
83
+
84
def otsl_tags_cells_sync_chk(rs_list, cells, name, logdebug):
    """Verify that the count of content-bearing OTSL tags matches ``cells``.

    Each "fcel"/"ched"/"rhed"/"srow"/"ecel" tag must correspond to exactly
    one entry in ``cells``; a mismatch prints an error (with ``name``) and
    returns False.
    """
    content_tags = ("fcel", "ched", "rhed", "srow", "ecel")
    tag_count = sum(1 for tag in rs_list if tag in content_tags)
    if tag_count == len(cells):
        return True
    print(("{}*!ERR* " + name + " *ERR!*{}").format(bcolors.FAIL, bcolors.ENDC))
    print(
        "{}*!ERR* Tags are not in sync with cells! *ERR!*{}".format(
            bcolors.FAIL, bcolors.ENDC
        )
    )
    return False
97
+
98
+
99
def otsl_check_down(rs_split, x, y):
    """Measure the vertical (row) span of the cell at column ``x``, row ``y``.

    Walks downward through the 2D tag grid while "ucel" continuation tags
    are found; any other tag (or the grid's bottom edge) ends the span.
    """
    terminators = ["fcel", "ched", "rhed", "srow", "ecel", "lcel", "nl"]
    span = 1
    tag = "ucel"
    last_row = len(rs_split) - 1
    while tag not in terminators and y < last_row:
        y += 1
        span += 1
        tag = rs_split[y][x]
    # A terminator means we stepped one row past the span's end.
    return span - 1 if tag in terminators else span
110
+
111
+
112
def otsl_check_right(rs_split, x, y):
    """Measure the horizontal (column) span of the cell at column ``x``, row ``y``.

    Walks rightward through row ``y`` while "lcel" continuation tags are
    found; any other tag (or the row's right edge) ends the span.
    """
    terminators = ["fcel", "ched", "rhed", "srow", "ecel", "ucel", "nl"]
    span = 1
    tag = "lcel"
    while tag not in terminators and x < (len(rs_split[y]) - 1):
        x += 1
        span += 1
        tag = rs_split[y][x]
    # A terminator means we stepped one column past the span's end.
    return span - 1 if tag in terminators else span
123
+
124
+
125
def otsl_to_html(rs_list, logdebug):
    """Reconstruct an HTML table-structure token list from an OTSL sequence.

    Pads the table to a rectangle if needed, then walks the 2D tag grid and
    emits <thead>/<tr>/<td ...> tokens, resolving colspans (lcel), rowspans
    (ucel) and 2D spans (xcel) via otsl_check_right / otsl_check_down.
    If the first token is not a cell tag the input is assumed to already be
    HTML and is returned unchanged.
    """
    if rs_list[0] not in ["fcel", "ched", "rhed", "srow", "ecel"]:
        # Most likely already HTML...
        return rs_list
    html_table = []
    if logdebug:
        print("{}*Reconstructing HTML...*{}".format(bcolors.WARNING, bcolors.ENDC))

    if not otsl_sqr_chk(rs_list, "---", logdebug):
        # PAD TABLE TO SQUARE
        print("{}*Padding to square...*{}".format(bcolors.WARNING, bcolors.ENDC))
        rs_list = otsl_pad_to_sqr(rs_list, "lcel")

    # 2D structure, line by line:
    rs_list_split = [
        list(group) for k, group in groupby(rs_list, lambda x: x == "nl") if not k
    ]

    if logdebug:
        print("")

    # Sequentially store indexes of 2D spans that were registered to avoid re-registering them
    registry_2d_span = []

    # Iterate all elements in the rs line, and look right / down to detect spans
    # If span detected - run function to find size of the span
    # repeat with all cells
    thead_present = False

    for rs_row_ind, rs_row in enumerate(rs_list_split):
        html_list = []

        # Open <thead> on the first row that contains a column header tag.
        if not thead_present:
            if "ched" in rs_list_split[rs_row_ind]:
                html_list.append("<thead>")
                thead_present = True

        # Close </thead> on the first subsequent row without a header tag.
        if thead_present:
            if "ched" not in rs_list_split[rs_row_ind]:
                html_list.append("</thead>")
                thead_present = False

        html_list.append("<tr>")
        for rs_cell_ind, rs_cell in enumerate(rs_list_split[rs_row_ind]):
            if rs_cell in ["fcel", "ched", "rhed", "srow", "ecel"]:
                rdist = 0
                ddist = 0
                xrdist = 0
                xddist = 0
                span = False
                # Check if it has horizontal span:
                if rs_cell_ind + 1 < len(rs_list_split[rs_row_ind]):
                    if rs_list_split[rs_row_ind][rs_cell_ind + 1] == "lcel":
                        rdist = otsl_check_right(rs_list_split, rs_cell_ind, rs_row_ind)
                        span = True
                # Check if it has vertical span:
                if rs_row_ind + 1 < len(rs_list_split):
                    if rs_list_split[rs_row_ind + 1][rs_cell_ind] == "ucel":
                        ddist = otsl_check_down(rs_list_split, rs_cell_ind, rs_row_ind)
                        span = True
                # Check if it has 2D span:
                if rs_cell_ind + 1 < len(rs_list_split[rs_row_ind]):
                    if rs_list_split[rs_row_ind][rs_cell_ind + 1] == "xcel":
                        xrdist = otsl_check_right(
                            rs_list_split, rs_cell_ind, rs_row_ind
                        )
                        xddist = otsl_check_down(rs_list_split, rs_cell_ind, rs_row_ind)
                        span = True
                        # Check if this 2D span was already registered,
                        # If not - register, if yes - cancel span
                        for x in range(rs_cell_ind, xrdist + rs_cell_ind):
                            for y in range(rs_row_ind, xddist + rs_row_ind):
                                reg2dind = str(x) + "_" + str(y)
                                if reg2dind in registry_2d_span:
                                    # Cell of the span is already in, cancel current span
                                    span = False
                        if span:
                            # None of the span cells were previously registered
                            # Register an entire span
                            for x in range(rs_cell_ind, xrdist + rs_cell_ind):
                                for y in range(rs_row_ind, xddist + rs_row_ind):
                                    reg2dind = str(x) + "_" + str(y)
                                    registry_2d_span.append(reg2dind)
                if span:
                    html_list.append("<td")
                    if rdist > 1:
                        html_list.append(' colspan="' + str(rdist) + '"')
                    if ddist > 1:
                        html_list.append(' rowspan="' + str(ddist) + '"')
                    # NOTE(review): the rowspan below is guarded by xrdist (the
                    # horizontal extent), not xddist — presumably a 2D span
                    # always has both extents > 1; confirm against callers.
                    if xrdist > 1:
                        html_list.append(' rowspan="' + str(xddist) + '"')
                        html_list.append(' colspan="' + str(xrdist) + '"')
                    html_list.append(">")
                    html_list.append("</td>")
                else:
                    html_list.append("<td>")
                    html_list.append("</td>")
        html_list.append("</tr>")
        html_table.extend(html_list)

    if logdebug:
        print("*********************** registry_2d_span ***************************")
        print(registry_2d_span)
        print("********************************************************************")

    return html_table
237
+
238
+
239
def html_to_otsl(table, writer, logdebug, extra_debug, include_html, use_writer):
    r"""
    Converts table structure from HTML to RS

    Parameters
    ----------
    table : json
        line from jsonl
    writer : writer
        Writes lines into output jsonl
    logdebug : bool
        Print verbose conversion output
    extra_debug : bool
        Print additional diagnostics (input HTML, extra padding info)
    include_html : bool
        Keep the original and restored HTML structures in the output line
    use_writer : bool
        Write the converted line with `writer` when the result is consistent

    Returns
    -------
    tuple (bool, dict)
        (False, {}) when the HTML structure is inconsistent with the
        row-span bookkeeping; otherwise (True, converted line).
    """

    table_html_structure = copy.deepcopy(table["html"]["structure"])
    out_line = table
    if include_html:
        out_line["html"]["html_structure"] = table_html_structure
        out_line["html"]["html_restored_structure"] = {"tokens": []}

    out_line["html"]["structure"] = {"tokens": []}
    # possible colspans
    pos_colspans = {
        ' colspan="20"': 20,
        ' colspan="19"': 19,
        ' colspan="18"': 18,
        ' colspan="17"': 17,
        ' colspan="16"': 16,
        ' colspan="15"': 15,
        ' colspan="14"': 14,
        ' colspan="13"': 13,
        ' colspan="12"': 12,
        ' colspan="11"': 11,
        ' colspan="10"': 10,
        ' colspan="2"': 2,
        ' colspan="3"': 3,
        ' colspan="4"': 4,
        ' colspan="5"': 5,
        ' colspan="6"': 6,
        ' colspan="7"': 7,
        ' colspan="8"': 8,
        ' colspan="9"': 9,
    }
    # possible rowspans
    pos_rowspans = {
        ' rowspan="20"': 20,
        ' rowspan="19"': 19,
        ' rowspan="18"': 18,
        ' rowspan="17"': 17,
        ' rowspan="16"': 16,
        ' rowspan="15"': 15,
        ' rowspan="14"': 14,
        ' rowspan="13"': 13,
        ' rowspan="12"': 12,
        ' rowspan="11"': 11,
        ' rowspan="10"': 10,
        ' rowspan="2"': 2,
        ' rowspan="3"': 3,
        ' rowspan="4"': 4,
        ' rowspan="5"': 5,
        ' rowspan="6"': 6,
        ' rowspan="7"': 7,
        ' rowspan="8"': 8,
        ' rowspan="9"': 9,
    }

    t_cells = []  # 2D structure
    tl_cells = []  # 1D structure
    t_expands = []  # 2D structure
    tl_spans = {}  # MAP, POPULATE WITH ACTUAL SPANS VALUES, IN SYNC WITH tl_cells

    current_line = 0
    current_column = 0
    current_html_cell_ind = 0

    current_line_tags = []
    current_line_expands = []

    if logdebug:
        print("")
        print("*** {}: {} ***".format(table["split"], table["filename"]))

    colnum = 0

    if extra_debug:
        print("========================== Input HTML ============================")
        print(table_html_structure["tokens"])
        print("==================================================================")

    if logdebug:
        print("********")
        print("* OTSL *")
        print("********")

    for i in range(len(table_html_structure["tokens"])):
        html_tag = table_html_structure["tokens"][i]
        prev_html_tag = ""
        next_html_tag = ""
        if i > 0:
            prev_html_tag = table_html_structure["tokens"][i - 1]
        if i < len(table_html_structure["tokens"]) - 1:
            next_html_tag = table_html_structure["tokens"][i + 1]

        if html_tag not in ["<thead>", "<tbody>"]:
            # rules of conversion
            # Check up-cell in t_expands, in case row-spans have to be inserted
            if html_tag in ["<td>", "<td", "</tr>"]:
                if current_line > 0:
                    if current_column >= len(t_expands[current_line - 1]):
                        # Inconsistent structure: current row is wider than
                        # the previous row's expansion record.
                        return False, {}
                    up_expand = t_expands[current_line - 1][current_column]

                    # Insert "ucel"/"xcel" continuations inherited from spans
                    # that started on an earlier row.
                    while up_expand[1] > 0:
                        if up_expand[0] == 0:
                            # ucel
                            current_line_tags.append("ucel")
                            current_line_expands.append([0, up_expand[1] - 1])
                            current_column += 1
                        else:
                            # xcel
                            for ci in range(up_expand[0]):
                                current_line_tags.append("xcel")
                                current_line_expands.append(
                                    [up_expand[0] - ci, up_expand[1] - 1]
                                )
                                current_column += 1
                        up_expand = t_expands[current_line - 1][current_column]
            # ======================================================================================
            # Fix for trailing "ucel" in a row
            if html_tag in ["</tr>"]:
                if current_line > 0:
                    cur_line_len = len(current_line_expands)
                    pre_line_len = len(t_expands[current_line - 1])

                    if cur_line_len < pre_line_len:
                        extra_columns = pre_line_len - cur_line_len - 1
                        if extra_columns > 0:
                            if extra_debug:
                                print(
                                    "Extra columns needed in row: {}".format(
                                        extra_columns
                                    )
                                )

                            for clm in range(extra_columns):
                                up_expand = t_expands[current_line - 1][
                                    cur_line_len + clm
                                ]
                                if up_expand[0] == 0:
                                    # ucel
                                    current_line_tags.append("ucel")
                                    current_line_expands.append([0, up_expand[1] - 1])
                                else:
                                    # xcel
                                    current_line_tags.append("xcel")
                                    current_line_expands.append(
                                        [up_expand[0], up_expand[1] - 1]
                                    )
            # ======================================================================================

            # 1. Opening cell tags
            if html_tag in ["<td>", "<td"]:
                # check if cell is empty...
                cell_is_empty = True
                if "cells" in table["html"]:
                    cell_tokens = table["html"]["cells"][current_html_cell_ind][
                        "tokens"
                    ]
                else:
                    # No cell content available: use a non-empty placeholder so
                    # the cell is classified as "fcel" below.
                    cell_tokens = "f"

                # Clean cell_tokens from trash:
                cell_tokens = list(filter(lambda a: a != "<i>", cell_tokens))
                cell_tokens = list(filter(lambda a: a != "<I>", cell_tokens))
                cell_tokens = list(filter(lambda a: a != "<b>", cell_tokens))
                cell_tokens = list(filter(lambda a: a != "<B>", cell_tokens))
                cell_tokens = list(filter(lambda a: a != " ", cell_tokens))
                cell_tokens = list(filter(lambda a: a != "</b>", cell_tokens))
                cell_tokens = list(filter(lambda a: a != "</B>", cell_tokens))
                cell_tokens = list(filter(lambda a: a != "</i>", cell_tokens))
                cell_tokens = list(filter(lambda a: a != "</I>", cell_tokens))

                # Check if empty
                if len(cell_tokens) > 0:
                    cell_is_empty = False
                if cell_is_empty:
                    out_line["html"]["cells"][current_html_cell_ind]["tokens"] = []
                    current_line_tags.append("ecel")
                    current_line_expands.append([0, 0])
                else:
                    current_line_tags.append("fcel")
                    current_line_expands.append([0, 0])
                current_html_cell_ind += 1
                current_column += 1

            # 2. Closing row tags
            if html_tag == "</tr>":
                if len(current_line_tags) > colnum:
                    colnum = len(current_line_tags)
                # Save everything we read about the line to t_cells
                current_line_tags.append("nl")
                t_cells.append(copy.deepcopy(current_line_tags))
                tl_cells.extend(copy.deepcopy(current_line_tags))
                if logdebug:
                    print(current_line_tags)
                current_line_tags = []

                # Deal with expands
                current_line_expands.append([-1, -1])
                # Output spans metadata
                t_expands.append(copy.deepcopy(current_line_expands))
                current_line_expands = []

                current_column = 0
                current_line += 1
            # 3. Colspans only
            if html_tag in pos_colspans:
                if prev_html_tag not in pos_rowspans:
                    if next_html_tag not in pos_rowspans:
                        colspan_len = pos_colspans[html_tag]
                        tl_spans[current_html_cell_ind - 1] = [colspan_len, 1]
                        current_line_expands[len(current_line_expands) - 1] = [
                            colspan_len,
                            0,
                        ]
                        for ci in range(colspan_len - 1):
                            current_line_tags.append("lcel")
                            current_line_expands.append([colspan_len - ci - 1, 0])
                            current_column += 1

            # 4. Rowspans only
            if html_tag in pos_rowspans:
                if prev_html_tag not in pos_colspans:
                    if next_html_tag not in pos_colspans:
                        rowspan_len = pos_rowspans[html_tag]
                        tl_spans[current_html_cell_ind - 1] = [1, rowspan_len]
                        current_line_expands[len(current_line_expands) - 1] = [
                            0,
                            rowspan_len - 1,
                        ]

            # 5. 2D spans
            if html_tag in pos_rowspans:
                rowspan_len = pos_rowspans[html_tag]
                if prev_html_tag in pos_colspans:
                    colspan_len = pos_colspans[prev_html_tag]
                    tl_spans[current_html_cell_ind - 1] = [colspan_len, rowspan_len]
                    newexp = [colspan_len, rowspan_len - 1]
                    current_line_expands[len(current_line_expands) - 1] = newexp
                    for ci in range(colspan_len - 1):
                        current_line_tags.append("xcel")
                        current_line_expands.append(
                            [colspan_len - ci - 1, rowspan_len - 1]
                        )
                if next_html_tag in pos_colspans:
                    colspan_len = pos_colspans[next_html_tag]
                    tl_spans[current_html_cell_ind - 1] = [colspan_len, rowspan_len]
                    newexp = [colspan_len, rowspan_len - 1]
                    current_line_expands[len(current_line_expands) - 1] = newexp
                    for ci in range(colspan_len - 1):
                        current_line_tags.append("xcel")
                        current_line_expands.append(
                            [colspan_len - ci - 1, rowspan_len - 1]
                        )

    t_name = "*** {}: {} ***".format(table["split"], table["filename"])
    # check if square
    isSquare = otsl_sqr_chk(tl_cells, t_name, logdebug)
    # TODO: pad if not square?
    if not isSquare:
        tl_cells = otsl_pad_to_sqr(tl_cells, "fcel")
    # check if cells (bboxes) in sync:
    # FIX: isGood was previously assigned only inside the "cells" branch but
    # read unconditionally below, raising NameError for inputs without cells.
    isGood = True
    if "cells" in out_line["html"]:
        isGood = otsl_tags_cells_sync_chk(
            tl_cells, out_line["html"]["cells"], t_name, logdebug
        )
    # convert back to HTML
    rHTML = []
    if isSquare:
        rHTML = otsl_to_html(tl_cells, logdebug)
        # NOTE(review): "html_restored_structure" is created only when
        # include_html is True — confirm callers always pass include_html=True.
        out_line["html"]["html_restored_structure"]["tokens"] = rHTML

    out_line["html"]["structure"]["tokens"] = tl_cells
    out_line["otsl_spans"] = tl_spans
    out_line["cols"] = colnum
    out_line["rows"] = len(t_cells)
    out_line["html_len"] = len(table_html_structure["tokens"])
    out_line["rs_len"] = len(tl_cells)
    # save converted line only when both consistency checks passed
    if use_writer:
        if isSquare:
            if isGood:
                writer.write(out_line)

    if logdebug:
        print("{}Reconstructed HTML:{}".format(bcolors.OKGREEN, bcolors.ENDC))
        print(rHTML)
        # original HTML
        oHTML = out_line["html"]["html_structure"]
        print("{}Original HTML:{}".format(bcolors.OKBLUE, bcolors.ENDC))
        print(oHTML)

    return True, out_line
@@ -0,0 +1,90 @@
1
+ #
2
+ # Copyright IBM Corp. 2024 - 2024
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ import logging
6
+ import sys
7
+
8
+
9
def get_custom_logger(logger_name, level, stream=sys.stdout):
    r"""
    Build a logger with the project's standard message format.

    Inputs:
    - logger_name: Name of the logger. You can get the class name as self.__class__.__name__
    - level: logging level (e.g. logging.INFO, logging.DEBUG, etc.)
    - stream: One of sys.stdout or sys.stderr

    Outputs:
    logger
    """
    custom_logger = logging.getLogger(logger_name)
    custom_logger.setLevel(level)

    # Attach a stream handler only once; hasHandlers() also detects
    # handlers inherited from ancestor loggers.
    if not custom_logger.hasHandlers():
        stream_handler = logging.StreamHandler(stream)
        stream_handler.setFormatter(
            logging.Formatter("%(asctime)s %(name)-12s %(levelname)-8s %(message)s")
        )
        custom_logger.addHandler(stream_handler)

    return custom_logger
34
+
35
+
36
###################################################################################
# System constants
#

r"""
This is a "generic" logger available to all scripts.
It is encouraged that each class has its own custom logger with the name of the class.
You can use the "get_custom_logger" function to build a custom logger with a standard format.
"""
LOGGER = get_custom_logger("docling-pm", logging.INFO)

# Supported dataset types
supported_datasets = ["TF_prepared"]  # TF prepared dataset

# Split names
TRAIN_SPLIT = "train"
VAL_SPLIT = "val"
TEST_SPLIT = "test"

# Prepared data parts and filename templates
PREPARED_DATA_PARTS = {
    # Array with the bboxes (x1y1x2y2) for all cells of the images across all splits.
    # The bboxes are indexed with the filename.
    # Notices:
    # - The bboxes are NOT transformed.
    # - If the image filenames are the same across splits, there will be only one entry in the file
    "BBOXES": "BBOXES.json",
    # Image filenames used for train and val
    "IMAGES": "IMAGES.json",
    # Mean, std, variance as arrays of 3 (for each color)
    "STATISTICS": "STATISTICS_<POSTFIX>.json",  # PRECOMPUTED
    # Bboxes of the cells in the form [1, x1, x2, y1, y2] or [0, 0, 0, 0, 0] in case of no box.
    "TRAIN_CELLBBOXES": "TRAIN_CELLBBOXES_<POSTFIX>.json",  # NOT USED.
    # Array with arrays of the length + 2 of the original cells per image.
    "TRAIN_CELLLENS": "TRAIN_CELLLENS_<POSTFIX>.json",
    # Indices of the cells between <start> <end> and <pad> at the end.
    "TRAIN_CELLS": "TRAIN_CELLS_<POSTFIX>.json",
    # Array with the length + 2 of the original tags per image.
    "TRAIN_TAGLENS": "TRAIN_TAGLENS_<POSTFIX>.json",
    # Indices of the tags between <start> <end> and <pad> at the end.
    "TRAIN_TAGS": "TRAIN_TAGS_<POSTFIX>.json",
    # Ground truth for the evaluation dataset per eval image.
    "VAL": "VAL.json",
    # Vocabulary: Indices of the word_map_cells and word_map_tags
    "WORDMAP": "WORDMAP_<POSTFIX>.json",  # PRECOMPUTED
}

# Purposes
TRAIN_PURPOSE = "train"
VAL_PURPOSE = "val"
TEST_PURPOSE = "test"
PREDICT_PURPOSE = "predict"

# The DDP world size when we train in CPU with DDP enabled
DDP_CPU_WORLD_SIZE = 2
@@ -0,0 +1,37 @@
1
+ #
2
+ # Copyright IBM Corp. 2024 - 2024
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ import logging
6
+
7
+ import docling_ibm_models.tableformer.common as c
8
+ from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset
9
+
10
# Logging level for this smoke-test script; flip to DEBUG for verbose output.
LOG_LEVEL = logging.INFO
# LOG_LEVEL = logging.DEBUG
12
+
13
+
14
def dataset_test(config):
    r"""
    Smoke-test the prepared dataset by iterating over it once.

    Parameters
    ----------
    config : dictionary
        The configuration settings
    """

    # Build the training dataset on CPU.
    dataset = TFDataset(config, "train")
    dataset.set_device("cpu")

    # Walk through every batch once, reporting progress.
    dataset.reset()
    dataset.shuffle()
    for batch_idx, _batch in enumerate(dataset):
        print("Loading batch: {}".format(batch_idx))
33
+
34
+
35
if __name__ == "__main__":
    # Parse CLI arguments into a config dict and run the dataset smoke test.
    config = c.parse_arguments()
    dataset_test(config)