docling 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling/__init__.py +0 -0
- docling/backend/__init__.py +0 -0
- docling/backend/abstract_backend.py +55 -0
- docling/backend/pypdfium2_backend.py +223 -0
- docling/datamodel/__init__.py +0 -0
- docling/datamodel/base_models.py +247 -0
- docling/datamodel/document.py +351 -0
- docling/datamodel/settings.py +32 -0
- docling/document_converter.py +207 -0
- docling/models/__init__.py +0 -0
- docling/models/ds_glm_model.py +82 -0
- docling/models/easyocr_model.py +77 -0
- docling/models/layout_model.py +318 -0
- docling/models/page_assemble_model.py +160 -0
- docling/models/table_structure_model.py +114 -0
- docling/pipeline/__init__.py +0 -0
- docling/pipeline/base_model_pipeline.py +18 -0
- docling/pipeline/standard_model_pipeline.py +40 -0
- docling/utils/__init__.py +0 -0
- docling/utils/layout_utils.py +806 -0
- docling/utils/utils.py +41 -0
- docling-0.1.0.dist-info/LICENSE +21 -0
- docling-0.1.0.dist-info/METADATA +130 -0
- docling-0.1.0.dist-info/RECORD +25 -0
- docling-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,77 @@
|
|
1
|
+
import copy
|
2
|
+
import logging
|
3
|
+
import random
|
4
|
+
from typing import Iterable
|
5
|
+
|
6
|
+
import numpy
|
7
|
+
from PIL import ImageDraw
|
8
|
+
|
9
|
+
from docling.datamodel.base_models import BoundingBox, CoordOrigin, OcrCell, Page
|
10
|
+
|
11
|
+
_log = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
|
14
|
+
class EasyOcrModel:
    """Replace each page's text cells with cells recognised by EasyOCR.

    ``config`` is indexed with:
      * ``"enabled"`` -- when falsy, pages pass through untouched and
        ``easyocr`` is never imported.
      * ``"lang"`` -- forwarded to ``easyocr.Reader`` (read only when enabled).
    """

    def __init__(self, config):
        self.config = config
        self.enabled = config["enabled"]
        self.scale = 3  # multiplier for 72 dpi == 216 dpi.

        if self.enabled:
            # Imported lazily so easyocr is only required when OCR is on.
            import easyocr

            self.reader = easyocr.Reader(config["lang"])

    @staticmethod
    def _quad_to_ltrb(quad, scale):
        """Return the axis-aligned ``(l, t, r, b)`` box enclosing ``quad``.

        ``quad`` is the sequence of ``(x, y)`` points found in the first
        element of each ``readtext`` result, in pixels of the image rendered
        at ``scale``. The coordinates are divided by ``scale`` to map them
        back to page units.

        This takes min/max over every point rather than assuming points 0
        and 2 are the top-left and bottom-right corners. That keeps
        ``l <= r`` and ``t <= b`` even for skewed polygons, and gives the
        same result for upright boxes.
        """
        xs = [pt[0] for pt in quad]
        ys = [pt[1] for pt in quad]
        return (
            min(xs) / scale,
            min(ys) / scale,
            max(xs) / scale,
            max(ys) / scale,
        )

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Yield pages whose ``cells`` are replaced by OCR results.

        When the model is disabled, the input pages are yielded unchanged.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:
            high_res_image = page._backend.get_page_image(scale=self.scale)
            im = numpy.array(high_res_image)
            result = self.reader.readtext(im)

            # Release the large image buffers before building the cells.
            del high_res_image
            del im

            # Each result entry is indexed as (points, text, confidence).
            cells = [
                OcrCell(
                    id=ix,
                    text=line[1],
                    confidence=line[2],
                    bbox=BoundingBox.from_tuple(
                        coord=self._quad_to_ltrb(line[0], self.scale),
                        origin=CoordOrigin.TOPLEFT,
                    ),
                )
                for ix, line in enumerate(result)
            ]

            page.cells = cells  # For now, just overwrites all digital cells.

            yield page
|
@@ -0,0 +1,318 @@
|
|
1
|
+
import copy
|
2
|
+
import logging
|
3
|
+
import random
|
4
|
+
import time
|
5
|
+
from typing import Iterable, List
|
6
|
+
|
7
|
+
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
8
|
+
from PIL import ImageDraw
|
9
|
+
|
10
|
+
from docling.datamodel.base_models import (
|
11
|
+
BoundingBox,
|
12
|
+
Cell,
|
13
|
+
Cluster,
|
14
|
+
CoordOrigin,
|
15
|
+
LayoutPrediction,
|
16
|
+
Page,
|
17
|
+
)
|
18
|
+
from docling.utils import layout_utils as lu
|
19
|
+
|
20
|
+
_log = logging.getLogger(__name__)
|
21
|
+
|
22
|
+
|
23
|
+
class LayoutModel:
    """Run the layout predictor on page images and post-process its clusters.

    The class-level label constants are also read by the page-assembly step
    to decide how each cluster is turned into a page element.
    """

    # Labels whose clusters are assembled as plain text elements.
    # "Formula" is intentionally absent; it is matched via FORMULA_LABEL.
    # NOTE(review): "Title" has a confidence threshold in postprocess() but is
    # not listed here -- confirm whether title clusters should become text.
    TEXT_ELEM_LABELS = [
        "Text",
        "Footnote",
        "Caption",
        "Checkbox-Unselected",
        "Checkbox-Selected",
        "Section-header",
        "Page-header",
        "Page-footer",
        "Code",
        "List-item",
    ]
    # Text labels that page assembly routes to the page "headers" list.
    PAGE_HEADER_LABELS = ["Page-header", "Page-footer"]

    TABLE_LABEL = "Table"
    FIGURE_LABEL = "Picture"
    FORMULA_LABEL = "Formula"

    def __init__(self, config):
        self.config = config
        self.layout_predictor = LayoutPredictor(
            config["artifacts_path"]
        )  # TODO temporary

    def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height):
        """Filter, de-duplicate and re-assign layout clusters and page cells.

        Args:
            clusters: raw layout predictions, with top-left-origin bboxes.
            cells: the page's text cells, with top-left-origin bboxes.
            page_height: page height, used to flip between top-left and
                bottom-left coordinate origins.

        Returns:
            ``(clusters, cells)`` as freshly built ``Cluster`` / ``Cell``
            objects with top-left-origin bboxes. Clusters come back in the
            order returned by ``lu.produce_reading_order``. Each cluster
            carries the output cells whose ``id`` appears in its
            ``cell_ids``.

        Raises:
            KeyError: if a cluster label has no entry in ``CLASS_THRESHOLDS``.
        """
        MIN_INTERSECTION = 0.2
        CLASS_THRESHOLDS = {
            "Caption": 0.35,
            "Footnote": 0.35,
            "Formula": 0.35,
            "List-item": 0.35,
            "Page-footer": 0.35,
            "Page-header": 0.35,
            "Picture": 0.2,  # low threshold adjust to capture chemical structures for examples.
            "Section-header": 0.45,
            "Table": 0.35,
            "Text": 0.45,
            "Title": 0.45,
            "Document Index": 0.45,
            "Code": 0.45,
            "Checkbox-Selected": 0.45,
            "Checkbox-Unselected": 0.45,
            "Form": 0.45,
            "Key-Value Region": 0.45,
        }

        _log.debug("================= Start postprocess function ====================")
        start_time = time.perf_counter()

        def bl_bbox(item):
            # The layout_utils helpers operate on plain lists in bottom-left origin.
            return list(item.bbox.to_bottom_left_origin(page_height).as_tuple())

        # Keep only predictions at or above their per-label confidence threshold.
        high_conf = [
            c for c in clusters if c.confidence >= CLASS_THRESHOLDS[c.label]
        ]

        # All predictions (low-confidence ones included) are kept as dicts:
        # they are passed to the orphan-assignment helpers below.
        clusters = [
            {
                "id": c.id,
                "bbox": bl_bbox(c),
                "confidence": c.confidence,
                "cell_ids": [],
                "type": c.label,
            }
            for c in clusters
        ]

        clusters_out = [
            {
                "id": c.id,
                "bbox": bl_bbox(c),
                "confidence": c.confidence,
                "created_by": "high_conf_pred",
                "cell_ids": [],
                "type": c.label,
            }
            for c in high_conf
        ]

        raw_cells = [
            {"id": c.id, "bbox": bl_bbox(c), "text": c.text} for c in cells
        ]
        cell_count = len(raw_cells)

        _log.debug("---- 0. Treat cluster overlaps ------")
        clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8)

        _log.debug(
            "---- 1. Initially assign cells to clusters based on minimum intersection ------"
        )
        # Check for cells included in or touched by clusters.
        clusters_out = lu.assigning_cell_ids_to_clusters(
            clusters_out, raw_cells, MIN_INTERSECTION
        )

        _log.debug("---- 2. Assign Orphans with Low Confidence Detections")
        _, orphan_cell_indices, _ = lu.cell_id_state_map(clusters_out, cell_count)
        clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred(
            clusters_out, clusters, raw_cells, orphan_cell_indices
        )
        # Refresh the cell_ids assignment after new clusters were created
        # from low-confidence predictions.
        clusters_out = lu.assigning_cell_ids_to_clusters(
            clusters_out, raw_cells, MIN_INTERSECTION
        )

        _log.debug("---- 3. Settle Ambigous Cells")
        _, _, ambiguous_cell_indices = lu.cell_id_state_map(clusters_out, cell_count)
        # Settle pdf cells that belong to multiple clusters.
        clusters_out, _ = lu.remove_ambigous_pdf_cell_by_conf(
            clusters_out, raw_cells, ambiguous_cell_indices
        )

        _log.debug("---- 4. Set Orphans as Text")
        _, orphan_cell_indices, _ = lu.cell_id_state_map(clusters_out, cell_count)
        clusters_out, orphan_cell_indices = lu.set_orphan_as_text(
            clusters_out, clusters, raw_cells, orphan_cell_indices
        )

        _log.debug("---- 5. Merge Cells & and adapt the bounding boxes")
        clusters_out = lu.merge_cells(clusters_out)
        # Clean up clusters that remain from merged and unreasonable clusters.
        clusters_out = lu.clean_up_clusters(
            clusters_out,
            raw_cells,
            merge_cells=True,
            img_table=True,
            one_cell_table=True,
        )
        clusters_out = lu.adapt_bboxes(raw_cells, clusters_out, orphan_cell_indices)

        # Rebuild the cell list from scratch, since earlier orderings may differ.
        clusters_around_cells, _, _ = lu.cell_id_state_map(clusters_out, cell_count)
        target_cells = []
        for ix, cell in enumerate(raw_cells):
            label = "None"
            # Expected to hold a single cluster at this point; the last one wins.
            # The label is not carried into the returned Cell objects.
            for cluster_index in clusters_around_cells[ix]:
                label = clusters_out[cluster_index]["type"]
            target_cells.append(
                {
                    "id": ix,
                    "rawcell_id": ix,
                    "label": label,
                    "bbox": cell["bbox"],
                    "text": cell["text"],
                }
            )

        _log.debug("---- 6. Sort clusters in reading order ------")
        clusters_out = lu.produce_reading_order(
            clusters_out, "raw_cell_ids", "raw_cell_ids", True
        )

        _log.debug("---- End of postprocessing function ------")
        elapsed = time.perf_counter() - start_time
        _log.debug("Finished post processing in seconds=%.3f", elapsed)

        cells_out = [
            Cell(
                id=c["id"],
                bbox=BoundingBox.from_tuple(
                    coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
                ).to_top_left_origin(page_height),
                text=c["text"],
            )
            for c in target_cells
        ]

        clusters_out_new = []
        for c in clusters_out:
            # Set lookup avoids rescanning the id list for every cell.
            member_ids = set(c["cell_ids"])
            clusters_out_new.append(
                Cluster(
                    id=c["id"],
                    bbox=BoundingBox.from_tuple(
                        coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
                    ).to_top_left_origin(page_height),
                    confidence=c["confidence"],
                    label=c["type"],
                    cells=[cell for cell in cells_out if cell.id in member_ids],
                )
            )

        return clusters_out_new, cells_out

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Attach post-processed layout predictions to each page.

        For every page this replaces ``page.cells`` with the post-processed
        cells and sets ``page.predictions.layout``.
        """
        for page in page_batch:
            clusters = [
                Cluster(
                    id=ix,
                    label=pred_item["label"],
                    confidence=pred_item["confidence"],
                    bbox=BoundingBox.model_validate(pred_item),
                    cells=[],
                )
                for ix, pred_item in enumerate(
                    self.layout_predictor.predict(page.image)
                )
            ]

            # Cells are assigned to clusters inside postprocess(); it does not
            # read the input clusters' ``cells`` lists, so no pre-mapping here.
            clusters, page.cells = self.postprocess(
                clusters, page.cells, page.size.height
            )

            page.predictions.layout = LayoutPrediction(clusters=clusters)

            yield page
|
@@ -0,0 +1,160 @@
|
|
1
|
+
import logging
|
2
|
+
import re
|
3
|
+
from typing import Iterable, List
|
4
|
+
|
5
|
+
from docling.datamodel.base_models import (
|
6
|
+
AssembledUnit,
|
7
|
+
FigureElement,
|
8
|
+
Page,
|
9
|
+
PageElement,
|
10
|
+
TableElement,
|
11
|
+
TextElement,
|
12
|
+
)
|
13
|
+
from docling.models.layout_model import LayoutModel
|
14
|
+
|
15
|
+
_log = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class PageAssembleModel:
    """Turn a page's layout clusters into text, table and figure elements."""

    # Word tokens used by the hyphenation check in sanitize_text().
    _WORD_RE = re.compile(r"\b[\w]+\b")

    def __init__(self, config):
        self.config = config

    def sanitize_text(self, lines):
        """Join text lines into one string, undoing hyphenated line wraps.

        A line ending in ``-`` is glued to the next line without the hyphen
        when the word before the hyphen and the first word of the next line
        are both alphanumeric. Any other line ending in ``-`` is concatenated
        as-is. Lines not ending in ``-`` are followed by a single space. The
        joined result is stripped, except when ``lines`` has at most one
        entry; then it is returned via ``" ".join`` without stripping.

        ``lines`` itself is not modified.
        """
        if len(lines) <= 1:
            return " ".join(lines)

        # Work on a copy so the caller's list is left untouched.
        pieces = list(lines)
        for ix, (prev_line, line) in enumerate(zip(lines, lines[1:])):
            if not prev_line.endswith("-"):
                pieces[ix] = prev_line + " "
                continue

            prev_words = self._WORD_RE.findall(prev_line)
            line_words = self._WORD_RE.findall(line)
            if (
                prev_words
                and line_words
                and prev_words[-1].isalnum()
                and line_words[0].isalnum()
            ):
                pieces[ix] = prev_line[:-1]

        return "".join(pieces).strip()  # Strip any leading or trailing whitespace

    @staticmethod
    def _cluster_textlines(cluster):
        """Return the stripped, non-blank cell texts of ``cluster`` in order.

        ``"\\x02"`` characters are replaced by ``"-"`` (presumably a hyphen
        marker produced upstream -- confirm with the PDF backend).
        """
        return [
            cell.text.replace("\x02", "-").strip()
            for cell in cluster.cells
            if cell.text.strip()
        ]

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Assemble each page's layout clusters into page elements.

        Sets ``page.assembled`` to an ``AssembledUnit``. ``elements`` holds
        every assembled element in cluster order. Page header/footer text
        also goes to ``headers``; everything else goes to ``body``. Clusters
        whose label matches none of the branches below are not assembled.
        """
        for page in page_batch:
            elements: List[PageElement] = []
            headers: List[PageElement] = []
            body: List[PageElement] = []

            for cluster in page.predictions.layout.clusters:
                if cluster.label in LayoutModel.TEXT_ELEM_LABELS:
                    text_el = TextElement(
                        label=cluster.label,
                        id=cluster.id,
                        text=self.sanitize_text(self._cluster_textlines(cluster)),
                        page_no=page.page_no,
                        cluster=cluster,
                    )
                    elements.append(text_el)

                    if cluster.label in LayoutModel.PAGE_HEADER_LABELS:
                        headers.append(text_el)
                    else:
                        body.append(text_el)

                elif cluster.label == LayoutModel.TABLE_LABEL:
                    tbl = None
                    if page.predictions.tablestructure:
                        tbl = page.predictions.tablestructure.table_map.get(
                            cluster.id, None
                        )
                    if not tbl:
                        # Fallback: add the table without structure.
                        tbl = TableElement(
                            label=cluster.label,
                            id=cluster.id,
                            text="",
                            otsl_seq=[],
                            table_cells=[],
                            cluster=cluster,
                            page_no=page.page_no,
                        )
                    elements.append(tbl)
                    body.append(tbl)

                elif cluster.label == LayoutModel.FIGURE_LABEL:
                    fig = None
                    if page.predictions.figures_classification:
                        fig = page.predictions.figures_classification.figure_map.get(
                            cluster.id, None
                        )
                    if not fig:
                        # Fallback: add the figure without classification.
                        fig = FigureElement(
                            label=cluster.label,
                            id=cluster.id,
                            text="",
                            data=None,
                            cluster=cluster,
                            page_no=page.page_no,
                        )
                    elements.append(fig)
                    body.append(fig)

                elif cluster.label == LayoutModel.FORMULA_LABEL:
                    equation = None
                    if page.predictions.equations_prediction:
                        equation = (
                            page.predictions.equations_prediction.equation_map.get(
                                cluster.id, None
                            )
                        )
                    if not equation:
                        # Fallback: keep the formula's raw cell text.
                        equation = TextElement(
                            label=cluster.label,
                            id=cluster.id,
                            cluster=cluster,
                            page_no=page.page_no,
                            text=self.sanitize_text(self._cluster_textlines(cluster)),
                        )
                    elements.append(equation)
                    body.append(equation)

            page.assembled = AssembledUnit(
                elements=elements, headers=headers, body=body
            )

            yield page
|
@@ -0,0 +1,114 @@
|
|
1
|
+
from typing import Iterable
|
2
|
+
|
3
|
+
import numpy
|
4
|
+
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
|
5
|
+
|
6
|
+
from docling.datamodel.base_models import (
|
7
|
+
BoundingBox,
|
8
|
+
Page,
|
9
|
+
TableCell,
|
10
|
+
TableElement,
|
11
|
+
TableStructurePrediction,
|
12
|
+
)
|
13
|
+
|
14
|
+
|
15
|
+
class TableStructureModel:
    """Predict table structure for "Table" layout clusters with TFPredictor.

    ``config`` is indexed with ``"do_cell_matching"`` and ``"enabled"``, and
    with ``"artifacts_path"`` when enabled.
    """

    def __init__(self, config):
        self.config = config
        self.do_cell_matching = config["do_cell_matching"]

        self.enabled = config["enabled"]
        if self.enabled:
            artifacts_path = config["artifacts_path"]
            # Third Party
            import docling_ibm_models.tableformer.common as c

            self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
            self.tm_config["model"]["save_dir"] = artifacts_path
            self.tm_model_type = self.tm_config["model"]["type"]

            self.tf_predictor = TFPredictor(self.tm_config)

    @staticmethod
    def _collect_table_tokens(cells, table_clusters, min_overlap=0.2):
        """Return ``model_dump()`` of each non-blank cell lying in any table.

        A cell counts as inside a table when more than ``min_overlap`` of its
        area intersects that table cluster's bbox. Cells with zero area or
        whitespace-only text are skipped. Each cell is emitted at most once,
        even when it overlaps several tables, so the predictor never receives
        duplicate tokens.
        """
        tokens = []
        for cell in cells:
            area = cell.bbox.area()
            # Only allow non-empty strings into the cells of a table.
            if not area > 0 or not cell.text.strip():
                continue
            if any(
                cell.bbox.intersection_area_with(cluster.bbox) / area > min_overlap
                for cluster in table_clusters
            ):
                tokens.append(cell.model_dump())
        return tokens

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Attach table-structure predictions to each page.

        Every page gets a fresh ``TableStructurePrediction``. When enabled,
        its ``table_map`` is filled with one ``TableElement`` per "Table"
        cluster, keyed by cluster id. When disabled, pages are yielded
        unchanged.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:
            # Start with an empty prediction so every page carries one.
            page.predictions.tablestructure = TableStructurePrediction()

            in_tables = [
                (
                    cluster,
                    [
                        round(cluster.bbox.l),
                        round(cluster.bbox.t),
                        round(cluster.bbox.r),
                        round(cluster.bbox.b),
                    ],
                )
                for cluster in page.predictions.layout.clusters
                if cluster.label == "Table"
            ]
            if not in_tables:
                yield page
                continue

            table_clusters, table_bboxes = zip(*in_tables)

            iocr_page = {
                "image": numpy.asarray(page.image),
                "tokens": self._collect_table_tokens(page.cells, table_clusters),
                "width": page.size.width,
                "height": page.size.height,
            }

            tf_output = self.tf_predictor.multi_table_predict(
                iocr_page, table_bboxes, do_matching=self.do_cell_matching
            )

            for table_cluster, table_out in zip(table_clusters, tf_output):
                table_cells = []
                for element in table_out["tf_responses"]:
                    if not self.do_cell_matching:
                        # Without cell matching, read each predicted cell's
                        # text from the backend inside its predicted bbox.
                        the_bbox = BoundingBox.model_validate(element["bbox"])
                        text_piece = page._backend.get_text_in_rect(the_bbox)
                        element["bbox"]["token"] = text_piece

                    table_cells.append(TableCell.model_validate(element))

                # Retrieving cols/rows, after post processing:
                details = table_out["predict_details"]
                tbl = TableElement(
                    otsl_seq=details["prediction"]["rs_seq"],
                    table_cells=table_cells,
                    num_rows=details["num_rows"],
                    num_cols=details["num_cols"],
                    id=table_cluster.id,
                    page_no=page.page_no,
                    cluster=table_cluster,
                    label="Table",
                )

                page.predictions.tablestructure.table_map[table_cluster.id] = tbl

            yield page
|
File without changes
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from abc import abstractmethod
|
2
|
+
from pathlib import Path
|
3
|
+
from typing import Iterable
|
4
|
+
|
5
|
+
from docling.datamodel.base_models import Page, PipelineOptions
|
6
|
+
|
7
|
+
|
8
|
+
class BaseModelPipeline:
    """Chain of page-processing models applied lazily to a page stream.

    ``model_pipe`` starts empty. Each entry is a callable that takes the
    page iterable produced by the previous entry and returns a new one.
    """

    def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions):
        self.artifacts_path = artifacts_path
        self.pipeline_options = pipeline_options
        self.model_pipe = []

    def apply(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Feed ``page_batch`` through every model in order, yielding the result."""
        stream = page_batch
        for stage in self.model_pipe:
            stream = stage(stream)
        yield from stream
|