docling 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling/__init__.py +0 -0
- docling/backend/__init__.py +0 -0
- docling/backend/abstract_backend.py +59 -0
- docling/backend/docling_parse_backend.py +207 -0
- docling/backend/pypdfium2_backend.py +233 -0
- docling/datamodel/__init__.py +0 -0
- docling/datamodel/base_models.py +312 -0
- docling/datamodel/document.py +363 -0
- docling/datamodel/settings.py +32 -0
- docling/document_converter.py +276 -0
- docling/models/__init__.py +0 -0
- docling/models/base_ocr_model.py +124 -0
- docling/models/ds_glm_model.py +82 -0
- docling/models/easyocr_model.py +70 -0
- docling/models/layout_model.py +328 -0
- docling/models/page_assemble_model.py +148 -0
- docling/models/table_structure_model.py +144 -0
- docling/pipeline/__init__.py +0 -0
- docling/pipeline/base_model_pipeline.py +17 -0
- docling/pipeline/standard_model_pipeline.py +38 -0
- docling/utils/__init__.py +0 -0
- docling/utils/layout_utils.py +806 -0
- docling/utils/utils.py +41 -0
- docling-1.6.2.dist-info/LICENSE +21 -0
- docling-1.6.2.dist-info/METADATA +192 -0
- docling-1.6.2.dist-info/RECORD +27 -0
- docling-1.6.2.dist-info/WHEEL +4 -0
@@ -0,0 +1,328 @@
|
|
1
|
+
import copy
|
2
|
+
import logging
|
3
|
+
import random
|
4
|
+
import time
|
5
|
+
from typing import Iterable, List
|
6
|
+
|
7
|
+
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
8
|
+
from PIL import ImageDraw
|
9
|
+
|
10
|
+
from docling.datamodel.base_models import (
|
11
|
+
BoundingBox,
|
12
|
+
Cell,
|
13
|
+
Cluster,
|
14
|
+
CoordOrigin,
|
15
|
+
LayoutPrediction,
|
16
|
+
Page,
|
17
|
+
)
|
18
|
+
from docling.utils import layout_utils as lu
|
19
|
+
|
20
|
+
_log = logging.getLogger(__name__)


class LayoutModel:
    """Layout-analysis stage of the page pipeline.

    For every page it runs ``LayoutPredictor`` on the page image, converts the
    detections into ``Cluster`` objects, post-processes them together with the
    page's text cells and stores the result in ``page.predictions.layout``.
    """

    # Labels whose clusters are assembled as plain text elements downstream.
    TEXT_ELEM_LABELS = [
        "Text",
        "Footnote",
        "Caption",
        "Checkbox-Unselected",
        "Checkbox-Selected",
        "Section-header",
        "Page-header",
        "Page-footer",
        "Code",
        "List-item",
        # "Formula",
    ]
    # Text labels routed to the page "headers" list instead of the body.
    PAGE_HEADER_LABELS = ["Page-header", "Page-footer"]

    TABLE_LABEL = "Table"
    FIGURE_LABEL = "Picture"
    FORMULA_LABEL = "Formula"

    def __init__(self, config):
        """Create the layout predictor.

        Args:
            config: Mapping with at least ``"artifacts_path"``, the location of
                the layout model weights passed to ``LayoutPredictor``.
        """
        self.config = config
        self.layout_predictor = LayoutPredictor(
            config["artifacts_path"]
        )  # TODO temporary

    def postprocess(self, clusters: List[Cluster], cells: List[Cell], page_height):
        """Filter, clean up and order layout clusters and assign page cells to them.

        The numbered steps below mirror the debug log messages. Internally all
        boxes are handled as plain dicts in bottom-left-origin coordinates
        (as expected by ``layout_utils``) and converted back at the end.

        Args:
            clusters: Raw layout detections. NOTE: the ``label`` of clusters that
                pass their confidence threshold is remapped in place (see
                ``CLASS_REMAPPINGS``).
            cells: Text cells of the page (top-left-origin boxes).
            page_height: Page height, used to flip between top-left and
                bottom-left coordinate origins.

        Returns:
            ``(clusters, cells)``: new ``Cluster`` objects (top-left origin) with
            their member cells attached, and the page cells re-created with
            ``id`` equal to their index in ``cells``.

        Raises:
            KeyError: if a cluster label is missing from ``CLASS_THRESHOLDS``.
        """
        MIN_INTERSECTION = 0.2
        CLASS_THRESHOLDS = {
            "Caption": 0.35,
            "Footnote": 0.35,
            "Formula": 0.35,
            "List-item": 0.35,
            "Page-footer": 0.35,
            "Page-header": 0.35,
            "Picture": 0.2,  # low threshold adjust to capture chemical structures for examples.
            "Section-header": 0.45,
            "Table": 0.35,
            "Text": 0.45,
            "Title": 0.45,
            "Document Index": 0.45,
            "Code": 0.45,
            "Checkbox-Selected": 0.45,
            "Checkbox-Unselected": 0.45,
            "Form": 0.45,
            "Key-Value Region": 0.45,
        }

        CLASS_REMAPPINGS = {
            "Document Index": "Table",
        }

        _log.debug("================= Start postprocess function ====================")
        start_time = time.time()

        # Keep only clusters above their per-class confidence threshold.
        clusters_out = []
        for cluster in clusters:
            if cluster.confidence >= CLASS_THRESHOLDS[cluster.label]:
                # Remap class labels where needed (mutates the input cluster).
                if cluster.label in CLASS_REMAPPINGS:
                    cluster.label = CLASS_REMAPPINGS[cluster.label]
                clusters_out.append(cluster)

        # Map to dictionary clusters and cells, with bottom left origin.
        # ``clusters`` keeps *all* detections (incl. low-confidence ones); it is
        # consulted later to place orphan cells.
        clusters = [
            {
                "id": c.id,
                "bbox": list(
                    c.bbox.to_bottom_left_origin(page_height).as_tuple()
                ),  # TODO
                "confidence": c.confidence,
                "cell_ids": [],
                "type": c.label,
            }
            for c in clusters
        ]

        clusters_out = [
            {
                "id": c.id,
                "bbox": list(
                    c.bbox.to_bottom_left_origin(page_height).as_tuple()
                ),  # TODO
                "confidence": c.confidence,
                "created_by": "high_conf_pred",
                "cell_ids": [],
                "type": c.label,
            }
            for c in clusters_out
        ]

        raw_cells = [
            {
                "id": c.id,
                "bbox": list(
                    c.bbox.to_bottom_left_origin(page_height).as_tuple()
                ),  # TODO
                "text": c.text,
            }
            for c in cells
        ]
        cell_count = len(raw_cells)

        _log.debug("---- 0. Treat cluster overlaps ------")
        clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8)

        _log.debug(
            "---- 1. Initially assign cells to clusters based on minimum intersection ------"
        )
        # Check for cells included in or touched by clusters:
        clusters_out = lu.assigning_cell_ids_to_clusters(
            clusters_out, raw_cells, MIN_INTERSECTION
        )

        _log.debug("---- 2. Assign Orphans with Low Confidence Detections")
        # Creates a map of cell_id->cluster_id
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        # Assign orphan cells with lower confidence predictions
        clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred(
            clusters_out, clusters, raw_cells, orphan_cell_indices
        )

        # Refresh the cell_ids assignment, after creating new clusters using low conf predictions
        clusters_out = lu.assigning_cell_ids_to_clusters(
            clusters_out, raw_cells, MIN_INTERSECTION
        )

        _log.debug("---- 3. Settle Ambigous Cells")
        # Creates an update map after assignment of cell_id->cluster_id
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        # Settle pdf cells that belong to multiple clusters
        clusters_out, ambiguous_cell_indices = lu.remove_ambigous_pdf_cell_by_conf(
            clusters_out, raw_cells, ambiguous_cell_indices
        )

        _log.debug("---- 4. Set Orphans as Text")
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        clusters_out, orphan_cell_indices = lu.set_orphan_as_text(
            clusters_out, clusters, raw_cells, orphan_cell_indices
        )

        _log.debug("---- 5. Merge Cells & and adapt the bounding boxes")
        # Merge cells orphan cells
        clusters_out = lu.merge_cells(clusters_out)

        # Clean up clusters that remain from merged and unreasonable clusters
        clusters_out = lu.clean_up_clusters(
            clusters_out,
            raw_cells,
            merge_cells=True,
            img_table=True,
            one_cell_table=True,
        )

        clusters_out = lu.adapt_bboxes(raw_cells, clusters_out, orphan_cell_indices)

        # Rebuild where every cell is now. We write into a fresh prediction
        # cells list rather than the raw cells list, because the latter might
        # have been sorted differently.
        (
            clusters_around_cells,
            orphan_cell_indices,
            ambiguous_cell_indices,
        ) = lu.cell_id_state_map(clusters_out, cell_count)

        target_cells = []
        for ix, cell in enumerate(raw_cells):
            new_cell = {
                "id": ix,
                "rawcell_id": ix,
                "label": "None",
                "bbox": cell["bbox"],
                "text": cell["text"],
            }
            # By previous analysis, this is always 1 cluster.
            for cluster_index in clusters_around_cells[ix]:
                new_cell["label"] = clusters_out[cluster_index]["type"]
            target_cells.append(new_cell)
        cells_out = target_cells

        # Sort clusters into reasonable reading order, and sort the cells inside each cluster
        _log.debug("---- 5. Sort clusters in reading order ------")
        clusters_out = lu.produce_reading_order(
            clusters_out, "raw_cell_ids", "raw_cell_ids", True
        )

        _log.debug("---- End of postprocessing function ------")
        end_time = time.time() - start_time
        _log.debug(f"Finished post processing in seconds={end_time:.3f}")

        # Convert back to model objects in top-left-origin coordinates.
        cells_out = [
            Cell(
                id=c["id"],
                bbox=BoundingBox.from_tuple(
                    coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
                ).to_top_left_origin(page_height),
                text=c["text"],
            )
            for c in cells_out
        ]

        clusters_out_new = []
        for c in clusters_out:
            # Set membership keeps this linear in the number of cells per
            # cluster; iterating cells_out preserves the original cell order.
            member_ids = set(c["cell_ids"])
            cluster_cells = [ccell for ccell in cells_out if ccell.id in member_ids]
            clusters_out_new.append(
                Cluster(
                    id=c["id"],
                    bbox=BoundingBox.from_tuple(
                        coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT
                    ).to_top_left_origin(page_height),
                    confidence=c["confidence"],
                    label=c["type"],
                    cells=cluster_cells,
                )
            )

        return clusters_out_new, cells_out

    @staticmethod
    def _draw_clusters_and_cells(page: Page, clusters: List[Cluster]):
        """Debug helper: show ``page.image`` with cluster boxes and their cells.

        Cluster boxes are drawn in green; the cells of each cluster share one
        random dark colour. Draws on a deep copy so the page image is untouched.
        """
        image = copy.deepcopy(page.image)
        draw = ImageDraw.Draw(image)
        for c in clusters:
            x0, y0, x1, y1 = c.bbox.as_tuple()
            draw.rectangle([(x0, y0), (x1, y1)], outline="green")

            cell_color = (
                random.randint(30, 140),
                random.randint(30, 140),
                random.randint(30, 140),
            )
            for tc in c.cells:  # [:1]:
                x0, y0, x1, y1 = tc.bbox.as_tuple()
                draw.rectangle([(x0, y0), (x1, y1)], outline=cell_color)
        image.show()

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Run layout prediction and post-processing on each page.

        For every page the predictor is applied to the scale-1.0 page image,
        ``page.cells`` is replaced by the cells returned from ``postprocess``
        and the resulting clusters are stored as
        ``page.predictions.layout``. Pages are yielded one by one.
        """
        for page in page_batch:
            clusters = []
            for ix, pred_item in enumerate(
                self.layout_predictor.predict(page.get_image(scale=1.0))
            ):
                cluster = Cluster(
                    id=ix,
                    label=pred_item["label"],
                    confidence=pred_item["confidence"],
                    bbox=BoundingBox.model_validate(pred_item),
                    cells=[],
                )
                clusters.append(cluster)

            # Map cells to clusters
            # TODO: Remove, postprocess should take care of it anyway.
            for cell in page.cells:
                cell_area = cell.bbox.area()
                # Degenerate cells never reach the 0.5 overlap fraction.
                if not cell_area > 0:
                    continue
                for cluster in clusters:
                    overlap_frac = (
                        cell.bbox.intersection_area_with(cluster.bbox) / cell_area
                    )
                    if overlap_frac > 0.5:
                        cluster.cells.append(cell)

            # Pre-sort clusters
            # clusters = self.sort_clusters_by_cell_order(clusters)

            # DEBUG: self._draw_clusters_and_cells(page, clusters)

            clusters, page.cells = self.postprocess(
                clusters, page.cells, page.size.height
            )

            # DEBUG: self._draw_clusters_and_cells(page, clusters)

            page.predictions.layout = LayoutPrediction(clusters=clusters)

            yield page
|
@@ -0,0 +1,148 @@
|
|
1
|
+
import logging
|
2
|
+
import re
|
3
|
+
from typing import Iterable, List
|
4
|
+
|
5
|
+
from docling.datamodel.base_models import (
|
6
|
+
AssembledUnit,
|
7
|
+
FigureElement,
|
8
|
+
Page,
|
9
|
+
PageElement,
|
10
|
+
TableElement,
|
11
|
+
TextElement,
|
12
|
+
)
|
13
|
+
from docling.models.layout_model import LayoutModel
|
14
|
+
|
15
|
+
_log = logging.getLogger(__name__)


class PageAssembleModel:
    """Turn a page's layout clusters into text, table, figure and formula elements."""

    def __init__(self, config):
        self.config = config

    @staticmethod
    def _is_hyphen_break(prev_line, line):
        """Return True if the trailing hyphen of ``prev_line`` splits a word across lines."""
        prev_words = re.findall(r"\b[\w]+\b", prev_line)
        line_words = re.findall(r"\b[\w]+\b", line)
        return bool(
            prev_words
            and line_words
            and prev_words[-1].isalnum()
            and line_words[0].isalnum()
            # The hyphen must be glued to the preceding word ("exam-", not
            # "2020 -") and the next line must start with the continuation of
            # that word, not with punctuation such as "(".
            and prev_line[-2:-1].isalnum()
            and line[:1].isalnum()
        )

    def sanitize_text(self, lines):
        """Join the text lines of a cluster into a single string.

        A line ending in ``-`` is treated as a hyphenated word break, and joined
        to the next line without the hyphen, only when the hyphen directly
        follows a word and the next line starts with a word character.
        In every other case consecutive lines are separated by one space.

        Args:
            lines: Text lines in reading order. The list is not modified.

        Returns:
            The joined text. With two or more lines, leading and trailing
            whitespace is stripped; zero or one line is returned as
            ``" ".join(lines)``.
        """
        if len(lines) <= 1:
            return " ".join(lines)

        # Work on a copy so the caller's list is left untouched.
        pieces = list(lines)
        for ix, line in enumerate(lines[1:]):
            prev_line = pieces[ix]
            if prev_line.endswith("-") and self._is_hyphen_break(prev_line, line):
                pieces[ix] = prev_line[:-1]
            else:
                # Also covers lines ending in a dash that is not a word break,
                # which previously got glued to the next line without a space.
                pieces[ix] = prev_line + " "

        return "".join(pieces).strip()  # Strip any leading or trailing whitespace

    def _cluster_text(self, cluster):
        """Return the sanitized text of the non-blank cells of ``cluster``.

        ``"\\x02"`` characters in cell text are replaced with ``"-"`` first.
        """
        return self.sanitize_text(
            [
                cell.text.replace("\x02", "-").strip()
                for cell in cluster.cells
                if cell.text.strip()
            ]
        )

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Assemble page elements for each page and yield the page.

        Sets ``page.assembled`` to an ``AssembledUnit`` where ``elements`` holds
        every produced element, ``headers`` the page header/footer text elements
        and ``body`` everything else. Tables, figures and formulas fall back to
        empty/plain elements when no structure, classification or equation
        prediction is available for the cluster. Clusters whose label matches
        none of the handled categories are skipped.
        """
        for page in page_batch:
            # assembles some JSON output page by page.

            elements: List[PageElement] = []
            headers: List[PageElement] = []
            body: List[PageElement] = []

            for cluster in page.predictions.layout.clusters:
                if cluster.label in LayoutModel.TEXT_ELEM_LABELS:
                    text_el = TextElement(
                        label=cluster.label,
                        id=cluster.id,
                        text=self._cluster_text(cluster),
                        page_no=page.page_no,
                        cluster=cluster,
                    )
                    elements.append(text_el)

                    if cluster.label in LayoutModel.PAGE_HEADER_LABELS:
                        headers.append(text_el)
                    else:
                        body.append(text_el)
                elif cluster.label == LayoutModel.TABLE_LABEL:
                    tbl = None
                    if page.predictions.tablestructure:
                        tbl = page.predictions.tablestructure.table_map.get(
                            cluster.id, None
                        )
                    if not tbl:
                        # fallback: add table without structure, if it isn't present
                        tbl = TableElement(
                            label=cluster.label,
                            id=cluster.id,
                            text="",
                            otsl_seq=[],
                            table_cells=[],
                            cluster=cluster,
                            page_no=page.page_no,
                        )

                    elements.append(tbl)
                    body.append(tbl)
                elif cluster.label == LayoutModel.FIGURE_LABEL:
                    fig = None
                    if page.predictions.figures_classification:
                        fig = page.predictions.figures_classification.figure_map.get(
                            cluster.id, None
                        )
                    if not fig:
                        # fallback: add figure without classification, if it isn't present
                        fig = FigureElement(
                            label=cluster.label,
                            id=cluster.id,
                            text="",
                            data=None,
                            cluster=cluster,
                            page_no=page.page_no,
                        )
                    elements.append(fig)
                    body.append(fig)
                elif cluster.label == LayoutModel.FORMULA_LABEL:
                    equation = None
                    if page.predictions.equations_prediction:
                        equation = (
                            page.predictions.equations_prediction.equation_map.get(
                                cluster.id, None
                            )
                        )
                    if not equation:
                        # fallback: use the raw cell text, if no equation is present
                        equation = TextElement(
                            label=cluster.label,
                            id=cluster.id,
                            cluster=cluster,
                            page_no=page.page_no,
                            text=self._cluster_text(cluster),
                        )
                    elements.append(equation)
                    body.append(equation)

            page.assembled = AssembledUnit(
                elements=elements, headers=headers, body=body
            )

            yield page
|
@@ -0,0 +1,144 @@
|
|
1
|
+
import copy
|
2
|
+
from typing import Iterable, List
|
3
|
+
|
4
|
+
import numpy
|
5
|
+
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
|
6
|
+
from PIL import ImageDraw
|
7
|
+
|
8
|
+
from docling.datamodel.base_models import (
|
9
|
+
BoundingBox,
|
10
|
+
Page,
|
11
|
+
TableCell,
|
12
|
+
TableElement,
|
13
|
+
TableStructurePrediction,
|
14
|
+
)
|
15
|
+
|
16
|
+
|
17
|
+
class TableStructureModel:
    """Predict the structure of "Table" layout clusters with TableFormer.

    For each page the table clusters and the page cells overlapping them are
    passed to ``TFPredictor.multi_table_predict``; the result is stored per
    cluster id in ``page.predictions.tablestructure.table_map``.
    """

    def __init__(self, config):
        """Load the TableFormer predictor if the model is enabled.

        Args:
            config: Mapping with ``"do_cell_matching"`` and ``"enabled"``, and,
                when enabled, ``"artifacts_path"`` (directory holding
                ``tm_config.json``).
        """
        self.config = config
        self.do_cell_matching = config["do_cell_matching"]
        # Scale up table input images to 144 dpi. Set unconditionally so the
        # attribute exists even when the model is disabled.
        self.scale = 2.0

        self.enabled = config["enabled"]
        if self.enabled:
            artifacts_path = config["artifacts_path"]
            # Third Party
            import docling_ibm_models.tableformer.common as c

            self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
            self.tm_config["model"]["save_dir"] = artifacts_path
            self.tm_model_type = self.tm_config["model"]["type"]

            self.tf_predictor = TFPredictor(self.tm_config)

    def draw_table_and_cells(self, page: Page, tbl_list: List[TableElement]):
        """Debug helper: show the page image with table boxes (red) and cell boxes (blue)."""
        image = (
            page._backend.get_page_image()
        )  # make new image to avoid drawing on the saved ones
        draw = ImageDraw.Draw(image)

        for table_element in tbl_list:
            x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple()
            draw.rectangle([(x0, y0), (x1, y1)], outline="red")

            for tc in table_element.table_cells:
                x0, y0, x1, y1 = tc.bbox.as_tuple()
                draw.rectangle([(x0, y0), (x1, y1)], outline="blue")

        image.show()

    def __call__(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Attach table-structure predictions to each page and yield it.

        When the model is disabled, pages are passed through unchanged.
        Otherwise every page gets a fresh ``TableStructurePrediction``; pages
        without table clusters are yielded with it left empty.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:

            page.predictions.tablestructure = TableStructurePrediction()  # dummy

            # (cluster, bbox scaled to the predictor's image resolution)
            in_tables = [
                (
                    cluster,
                    [
                        round(cluster.bbox.l) * self.scale,
                        round(cluster.bbox.t) * self.scale,
                        round(cluster.bbox.r) * self.scale,
                        round(cluster.bbox.b) * self.scale,
                    ],
                )
                for cluster in page.predictions.layout.clusters
                if cluster.label == "Table"
            ]
            if not in_tables:
                yield page
                continue

            tokens = []
            for c in page.cells:
                # Only allow non empty stings (spaces) into the cells of a table
                if not c.text.strip():
                    continue
                c_area = c.bbox.area()
                if not c_area > 0:
                    continue
                for cluster, _ in in_tables:
                    if c.bbox.intersection_area_with(cluster.bbox) / c_area > 0.2:
                        new_cell = copy.deepcopy(c)
                        new_cell.bbox = new_cell.bbox.scaled(scale=self.scale)
                        tokens.append(new_cell.model_dump())
                        # Add each cell at most once, even if it overlaps several
                        # tables; otherwise the same token is sent twice.
                        break

            page_input = {
                "tokens": tokens,
                "width": page.size.width * self.scale,
                "height": page.size.height * self.scale,
            }
            page_input["image"] = numpy.asarray(page.get_image(scale=self.scale))

            # in_tables is non-empty here, so both tuples are non-empty.
            table_clusters, table_bboxes = zip(*in_tables)

            tf_output = self.tf_predictor.multi_table_predict(
                page_input, table_bboxes, do_matching=self.do_cell_matching
            )

            for table_cluster, table_out in zip(table_clusters, tf_output):
                table_cells = []
                for element in table_out["tf_responses"]:

                    if not self.do_cell_matching:
                        # Look up the text under the predicted cell, in page coordinates.
                        the_bbox = BoundingBox.model_validate(
                            element["bbox"]
                        ).scaled(1 / self.scale)
                        text_piece = page._backend.get_text_in_rect(the_bbox)
                        element["bbox"]["token"] = text_piece

                    tc = TableCell.model_validate(element)
                    # NOTE(review): only the cell-matching path scales tc.bbox back
                    # to page coordinates; without matching it keeps the scaled
                    # values even though the text lookup above scales down first.
                    # Confirm whether this is intended.
                    if self.do_cell_matching:
                        tc.bbox = tc.bbox.scaled(1 / self.scale)
                    table_cells.append(tc)

                # Retrieving cols/rows, after post processing:
                num_rows = table_out["predict_details"]["num_rows"]
                num_cols = table_out["predict_details"]["num_cols"]
                otsl_seq = table_out["predict_details"]["prediction"]["rs_seq"]

                tbl = TableElement(
                    otsl_seq=otsl_seq,
                    table_cells=table_cells,
                    num_rows=num_rows,
                    num_cols=num_cols,
                    id=table_cluster.id,
                    page_no=page.page_no,
                    cluster=table_cluster,
                    label="Table",
                )

                page.predictions.tablestructure.table_map[table_cluster.id] = tbl

            # For debugging purposes:
            # self.draw_table_and_cells(page, page.predictions.tablestructure.table_map.values())

            yield page
|
File without changes
|
@@ -0,0 +1,17 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from typing import Iterable
|
3
|
+
|
4
|
+
from docling.datamodel.base_models import Page, PipelineOptions
|
5
|
+
|
6
|
+
|
7
|
+
class BaseModelPipeline:
    """Common base for pipelines built from a chain of page models.

    Subclasses fill ``model_pipe`` with callables; each one receives an
    iterable of pages and returns an iterable of pages.
    """

    def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions):
        self.pipeline_options = pipeline_options
        self.artifacts_path = artifacts_path
        self.model_pipe = []

    def apply(self, page_batch: Iterable[Page]) -> Iterable[Page]:
        """Lazily stream ``page_batch`` through every model of ``model_pipe`` in order."""
        stream = page_batch
        for stage in self.model_pipe:
            # Each stage wraps the output of the previous one.
            stream = stage(stream)

        yield from stream
|
@@ -0,0 +1,38 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
from docling.datamodel.base_models import PipelineOptions
|
4
|
+
from docling.models.easyocr_model import EasyOcrModel
|
5
|
+
from docling.models.layout_model import LayoutModel
|
6
|
+
from docling.models.table_structure_model import TableStructureModel
|
7
|
+
from docling.pipeline.base_model_pipeline import BaseModelPipeline
|
8
|
+
|
9
|
+
|
10
|
+
class StandardModelPipeline(BaseModelPipeline):
    """Default model chain: EasyOCR, then layout analysis, then table structure."""

    _layout_model_path = "model_artifacts/layout/beehive_v0.0.5"
    _table_model_path = "model_artifacts/tableformer"

    def __init__(self, artifacts_path: Path, pipeline_options: PipelineOptions):
        super().__init__(artifacts_path, pipeline_options)

        # Models are constructed in pipeline order.
        ocr_model = EasyOcrModel(
            config={
                "lang": ["fr", "de", "es", "en"],
                "enabled": pipeline_options.do_ocr,
            }
        )
        layout_model = LayoutModel(
            config={
                "artifacts_path": artifacts_path
                / StandardModelPipeline._layout_model_path
            }
        )
        table_model = TableStructureModel(
            config={
                "artifacts_path": artifacts_path
                / StandardModelPipeline._table_model_path,
                "enabled": pipeline_options.do_table_structure,
                "do_cell_matching": pipeline_options.table_structure_options.do_cell_matching,
            }
        )

        self.model_pipe = [ocr_model, layout_model, table_model]
|
File without changes
|