docling 2.11.0__py3-none-any.whl → 2.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling/backend/xml/__init__.py +0 -0
- docling/backend/xml/uspto_backend.py +1888 -0
- docling/cli/main.py +8 -0
- docling/datamodel/base_models.py +18 -4
- docling/datamodel/document.py +77 -13
- docling/datamodel/pipeline_options.py +68 -4
- docling/datamodel/settings.py +1 -0
- docling/document_converter.py +11 -2
- docling/models/ds_glm_model.py +34 -4
- docling/models/easyocr_model.py +37 -3
- docling/models/layout_model.py +144 -280
- docling/models/page_assemble_model.py +11 -1
- docling/models/rapid_ocr_model.py +24 -45
- docling/models/table_structure_model.py +49 -33
- docling/pipeline/base_pipeline.py +3 -1
- docling/pipeline/standard_pdf_pipeline.py +7 -3
- docling/utils/accelerator_utils.py +42 -0
- docling/utils/glm_utils.py +11 -3
- docling/utils/layout_postprocessor.py +666 -0
- {docling-2.11.0.dist-info → docling-2.13.0.dist-info}/METADATA +3 -3
- {docling-2.11.0.dist-info → docling-2.13.0.dist-info}/RECORD +24 -21
- docling/utils/layout_utils.py +0 -812
- {docling-2.11.0.dist-info → docling-2.13.0.dist-info}/LICENSE +0 -0
- {docling-2.11.0.dist-info → docling-2.13.0.dist-info}/WHEEL +0 -0
- {docling-2.11.0.dist-info → docling-2.13.0.dist-info}/entry_points.txt +0 -0
docling/models/layout_model.py
CHANGED
@@ -7,7 +7,7 @@ from typing import Iterable, List
|
|
7
7
|
|
8
8
|
from docling_core.types.doc import CoordOrigin, DocItemLabel
|
9
9
|
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
10
|
-
from PIL import ImageDraw
|
10
|
+
from PIL import Image, ImageDraw, ImageFont
|
11
11
|
|
12
12
|
from docling.datamodel.base_models import (
|
13
13
|
BoundingBox,
|
@@ -17,9 +17,11 @@ from docling.datamodel.base_models import (
|
|
17
17
|
Page,
|
18
18
|
)
|
19
19
|
from docling.datamodel.document import ConversionResult
|
20
|
+
from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
|
20
21
|
from docling.datamodel.settings import settings
|
21
22
|
from docling.models.base_model import BasePageModel
|
22
|
-
from docling.utils import
|
23
|
+
from docling.utils.accelerator_utils import decide_device
|
24
|
+
from docling.utils.layout_postprocessor import LayoutPostprocessor
|
23
25
|
from docling.utils.profiling import TimeRecorder
|
24
26
|
|
25
27
|
_log = logging.getLogger(__name__)
|
@@ -42,237 +44,139 @@ class LayoutModel(BasePageModel):
|
|
42
44
|
]
|
43
45
|
PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
|
44
46
|
|
45
|
-
|
47
|
+
TABLE_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
|
46
48
|
FIGURE_LABEL = DocItemLabel.PICTURE
|
47
49
|
FORMULA_LABEL = DocItemLabel.FORMULA
|
50
|
+
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
48
51
|
|
49
|
-
def __init__(self, artifacts_path: Path):
|
50
|
-
|
52
|
+
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
53
|
+
device = decide_device(accelerator_options.device)
|
51
54
|
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
DocItemLabel.FOOTNOTE: 0.35,
|
57
|
-
DocItemLabel.FORMULA: 0.35,
|
58
|
-
DocItemLabel.LIST_ITEM: 0.35,
|
59
|
-
DocItemLabel.PAGE_FOOTER: 0.35,
|
60
|
-
DocItemLabel.PAGE_HEADER: 0.35,
|
61
|
-
DocItemLabel.PICTURE: 0.2, # low threshold adjust to capture chemical structures for examples.
|
62
|
-
DocItemLabel.SECTION_HEADER: 0.45,
|
63
|
-
DocItemLabel.TABLE: 0.35,
|
64
|
-
DocItemLabel.TEXT: 0.45,
|
65
|
-
DocItemLabel.TITLE: 0.45,
|
66
|
-
DocItemLabel.DOCUMENT_INDEX: 0.45,
|
67
|
-
DocItemLabel.CODE: 0.45,
|
68
|
-
DocItemLabel.CHECKBOX_SELECTED: 0.45,
|
69
|
-
DocItemLabel.CHECKBOX_UNSELECTED: 0.45,
|
70
|
-
DocItemLabel.FORM: 0.45,
|
71
|
-
DocItemLabel.KEY_VALUE_REGION: 0.45,
|
72
|
-
}
|
73
|
-
|
74
|
-
CLASS_REMAPPINGS = {
|
75
|
-
DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
|
76
|
-
DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
|
77
|
-
}
|
78
|
-
|
79
|
-
_log.debug("================= Start postprocess function ====================")
|
80
|
-
start_time = time.time()
|
81
|
-
# Apply Confidence Threshold to cluster predictions
|
82
|
-
# confidence = self.conf_threshold
|
83
|
-
clusters_mod = []
|
84
|
-
|
85
|
-
for cluster in clusters_in:
|
86
|
-
confidence = CLASS_THRESHOLDS[cluster.label]
|
87
|
-
if cluster.confidence >= confidence:
|
88
|
-
# annotation["created_by"] = "high_conf_pred"
|
89
|
-
|
90
|
-
# Remap class labels where needed.
|
91
|
-
if cluster.label in CLASS_REMAPPINGS.keys():
|
92
|
-
cluster.label = CLASS_REMAPPINGS[cluster.label]
|
93
|
-
clusters_mod.append(cluster)
|
94
|
-
|
95
|
-
# map to dictionary clusters and cells, with bottom left origin
|
96
|
-
clusters_orig = [
|
97
|
-
{
|
98
|
-
"id": c.id,
|
99
|
-
"bbox": list(
|
100
|
-
c.bbox.to_bottom_left_origin(page_height).as_tuple()
|
101
|
-
), # TODO
|
102
|
-
"confidence": c.confidence,
|
103
|
-
"cell_ids": [],
|
104
|
-
"type": c.label,
|
105
|
-
}
|
106
|
-
for c in clusters_in
|
107
|
-
]
|
108
|
-
|
109
|
-
clusters_out = [
|
110
|
-
{
|
111
|
-
"id": c.id,
|
112
|
-
"bbox": list(
|
113
|
-
c.bbox.to_bottom_left_origin(page_height).as_tuple()
|
114
|
-
), # TODO
|
115
|
-
"confidence": c.confidence,
|
116
|
-
"created_by": "high_conf_pred",
|
117
|
-
"cell_ids": [],
|
118
|
-
"type": c.label,
|
119
|
-
}
|
120
|
-
for c in clusters_mod
|
121
|
-
]
|
122
|
-
|
123
|
-
del clusters_mod
|
124
|
-
|
125
|
-
raw_cells = [
|
126
|
-
{
|
127
|
-
"id": c.id,
|
128
|
-
"bbox": list(
|
129
|
-
c.bbox.to_bottom_left_origin(page_height).as_tuple()
|
130
|
-
), # TODO
|
131
|
-
"text": c.text,
|
132
|
-
}
|
133
|
-
for c in cells
|
134
|
-
]
|
135
|
-
cell_count = len(raw_cells)
|
136
|
-
|
137
|
-
_log.debug("---- 0. Treat cluster overlaps ------")
|
138
|
-
clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8)
|
139
|
-
|
140
|
-
_log.debug(
|
141
|
-
"---- 1. Initially assign cells to clusters based on minimum intersection ------"
|
142
|
-
)
|
143
|
-
## Check for cells included in or touched by clusters:
|
144
|
-
clusters_out = lu.assigning_cell_ids_to_clusters(
|
145
|
-
clusters_out, raw_cells, MIN_INTERSECTION
|
55
|
+
self.layout_predictor = LayoutPredictor(
|
56
|
+
artifact_path=str(artifacts_path),
|
57
|
+
device=device,
|
58
|
+
num_threads=accelerator_options.num_threads,
|
146
59
|
)
|
147
60
|
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
Cell(
|
248
|
-
id=c["id"], # type: ignore
|
249
|
-
bbox=BoundingBox.from_tuple(
|
250
|
-
coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
|
251
|
-
).to_top_left_origin(page_height),
|
252
|
-
text=c["text"], # type: ignore
|
253
|
-
)
|
254
|
-
for c in cells_out
|
255
|
-
]
|
256
|
-
|
257
|
-
del cells_out
|
61
|
+
def draw_clusters_and_cells_side_by_side(
|
62
|
+
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
|
63
|
+
):
|
64
|
+
"""
|
65
|
+
Draws a page image side by side with clusters filtered into two categories:
|
66
|
+
- Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
|
67
|
+
- Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
|
68
|
+
Includes label names and confidence scores for each cluster.
|
69
|
+
"""
|
70
|
+
label_to_color = {
|
71
|
+
DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
|
72
|
+
DocItemLabel.CAPTION: (255, 204, 153), # Light Orange
|
73
|
+
DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple
|
74
|
+
DocItemLabel.FORMULA: (192, 192, 192), # Gray
|
75
|
+
DocItemLabel.TABLE: (255, 204, 204), # Light Pink
|
76
|
+
DocItemLabel.PICTURE: (255, 204, 164), # Light Beige
|
77
|
+
DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red
|
78
|
+
DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green
|
79
|
+
DocItemLabel.PAGE_FOOTER: (
|
80
|
+
204,
|
81
|
+
255,
|
82
|
+
204,
|
83
|
+
), # Light Green (same as Page-Header)
|
84
|
+
DocItemLabel.TITLE: (255, 153, 153), # Light Red (same as Section-Header)
|
85
|
+
DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue
|
86
|
+
DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray
|
87
|
+
DocItemLabel.CODE: (125, 125, 125), # Gray
|
88
|
+
DocItemLabel.CHECKBOX_SELECTED: (255, 182, 193), # Pale Green
|
89
|
+
DocItemLabel.CHECKBOX_UNSELECTED: (255, 182, 193), # Light Pink
|
90
|
+
DocItemLabel.FORM: (200, 255, 255), # Light Cyan
|
91
|
+
DocItemLabel.KEY_VALUE_REGION: (183, 65, 14), # Rusty orange
|
92
|
+
}
|
93
|
+
# Filter clusters for left and right images
|
94
|
+
exclude_labels = {
|
95
|
+
DocItemLabel.FORM,
|
96
|
+
DocItemLabel.KEY_VALUE_REGION,
|
97
|
+
DocItemLabel.PICTURE,
|
98
|
+
}
|
99
|
+
left_clusters = [c for c in clusters if c.label not in exclude_labels]
|
100
|
+
right_clusters = [c for c in clusters if c.label in exclude_labels]
|
101
|
+
# Create a deep copy of the original image for both sides
|
102
|
+
left_image = copy.deepcopy(page.image)
|
103
|
+
right_image = copy.deepcopy(page.image)
|
104
|
+
|
105
|
+
# Function to draw clusters on an image
|
106
|
+
def draw_clusters(image, clusters):
|
107
|
+
draw = ImageDraw.Draw(image, "RGBA")
|
108
|
+
# Create a smaller font for the labels
|
109
|
+
try:
|
110
|
+
font = ImageFont.truetype("arial.ttf", 12)
|
111
|
+
except OSError:
|
112
|
+
# Fallback to default font if arial is not available
|
113
|
+
font = ImageFont.load_default()
|
114
|
+
for c_tl in clusters:
|
115
|
+
all_clusters = [c_tl, *c_tl.children]
|
116
|
+
for c in all_clusters:
|
117
|
+
# Draw cells first (underneath)
|
118
|
+
cell_color = (0, 0, 0, 40) # Transparent black for cells
|
119
|
+
for tc in c.cells:
|
120
|
+
cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
|
121
|
+
draw.rectangle(
|
122
|
+
[(cx0, cy0), (cx1, cy1)],
|
123
|
+
outline=None,
|
124
|
+
fill=cell_color,
|
125
|
+
)
|
126
|
+
# Draw cluster rectangle
|
127
|
+
x0, y0, x1, y1 = c.bbox.as_tuple()
|
128
|
+
cluster_fill_color = (*list(label_to_color.get(c.label)), 70)
|
129
|
+
cluster_outline_color = (*list(label_to_color.get(c.label)), 255)
|
130
|
+
draw.rectangle(
|
131
|
+
[(x0, y0), (x1, y1)],
|
132
|
+
outline=cluster_outline_color,
|
133
|
+
fill=cluster_fill_color,
|
134
|
+
)
|
135
|
+
# Add label name and confidence
|
136
|
+
label_text = f"{c.label.name} ({c.confidence:.2f})"
|
137
|
+
# Create semi-transparent background for text
|
138
|
+
text_bbox = draw.textbbox((x0, y0), label_text, font=font)
|
139
|
+
text_bg_padding = 2
|
140
|
+
draw.rectangle(
|
141
|
+
[
|
142
|
+
(
|
143
|
+
text_bbox[0] - text_bg_padding,
|
144
|
+
text_bbox[1] - text_bg_padding,
|
145
|
+
),
|
146
|
+
(
|
147
|
+
text_bbox[2] + text_bg_padding,
|
148
|
+
text_bbox[3] + text_bg_padding,
|
149
|
+
),
|
150
|
+
],
|
151
|
+
fill=(255, 255, 255, 180), # Semi-transparent white
|
152
|
+
)
|
153
|
+
# Draw text
|
154
|
+
draw.text(
|
155
|
+
(x0, y0),
|
156
|
+
label_text,
|
157
|
+
fill=(0, 0, 0, 255), # Solid black
|
158
|
+
font=font,
|
159
|
+
)
|
258
160
|
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
161
|
+
# Draw clusters on both images
|
162
|
+
draw_clusters(left_image, left_clusters)
|
163
|
+
draw_clusters(right_image, right_clusters)
|
164
|
+
# Combine the images side by side
|
165
|
+
combined_width = left_image.width * 2
|
166
|
+
combined_height = left_image.height
|
167
|
+
combined_image = Image.new("RGB", (combined_width, combined_height))
|
168
|
+
combined_image.paste(left_image, (0, 0))
|
169
|
+
combined_image.paste(right_image, (left_image.width, 0))
|
170
|
+
if show:
|
171
|
+
combined_image.show()
|
172
|
+
else:
|
173
|
+
out_path: Path = (
|
174
|
+
Path(settings.debug.debug_output_path)
|
175
|
+
/ f"debug_{conv_res.input.file.stem}"
|
272
176
|
)
|
273
|
-
|
274
|
-
|
275
|
-
|
177
|
+
out_path.mkdir(parents=True, exist_ok=True)
|
178
|
+
out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png"
|
179
|
+
combined_image.save(str(out_file), format="png")
|
276
180
|
|
277
181
|
def __call__(
|
278
182
|
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
@@ -305,66 +209,26 @@ class LayoutModel(BasePageModel):
|
|
305
209
|
)
|
306
210
|
clusters.append(cluster)
|
307
211
|
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
if not cell.bbox.area() > 0:
|
313
|
-
overlap_frac = 0.0
|
314
|
-
else:
|
315
|
-
overlap_frac = (
|
316
|
-
cell.bbox.intersection_area_with(cluster.bbox)
|
317
|
-
/ cell.bbox.area()
|
318
|
-
)
|
319
|
-
|
320
|
-
if overlap_frac > 0.5:
|
321
|
-
cluster.cells.append(cell)
|
322
|
-
|
323
|
-
# Pre-sort clusters
|
324
|
-
# clusters = self.sort_clusters_by_cell_order(clusters)
|
325
|
-
|
326
|
-
# DEBUG code:
|
327
|
-
def draw_clusters_and_cells(show: bool = False):
|
328
|
-
image = copy.deepcopy(page.image)
|
329
|
-
if image is not None:
|
330
|
-
draw = ImageDraw.Draw(image)
|
331
|
-
for c in clusters:
|
332
|
-
x0, y0, x1, y1 = c.bbox.as_tuple()
|
333
|
-
draw.rectangle([(x0, y0), (x1, y1)], outline="green")
|
334
|
-
|
335
|
-
cell_color = (
|
336
|
-
random.randint(30, 140),
|
337
|
-
random.randint(30, 140),
|
338
|
-
random.randint(30, 140),
|
339
|
-
)
|
340
|
-
for tc in c.cells: # [:1]:
|
341
|
-
x0, y0, x1, y1 = tc.bbox.as_tuple()
|
342
|
-
draw.rectangle(
|
343
|
-
[(x0, y0), (x1, y1)], outline=cell_color
|
344
|
-
)
|
345
|
-
if show:
|
346
|
-
image.show()
|
347
|
-
else:
|
348
|
-
out_path: Path = (
|
349
|
-
Path(settings.debug.debug_output_path)
|
350
|
-
/ f"debug_{conv_res.input.file.stem}"
|
351
|
-
)
|
352
|
-
out_path.mkdir(parents=True, exist_ok=True)
|
212
|
+
if settings.debug.visualize_raw_layout:
|
213
|
+
self.draw_clusters_and_cells_side_by_side(
|
214
|
+
conv_res, page, clusters, mode_prefix="raw"
|
215
|
+
)
|
353
216
|
|
354
|
-
|
355
|
-
out_path / f"layout_page_{page.page_no:05}.png"
|
356
|
-
)
|
357
|
-
image.save(str(out_file), format="png")
|
217
|
+
# Apply postprocessing
|
358
218
|
|
359
|
-
|
219
|
+
processed_clusters, processed_cells = LayoutPostprocessor(
|
220
|
+
page.cells, clusters, page.size
|
221
|
+
).postprocess()
|
222
|
+
# processed_clusters, processed_cells = clusters, page.cells
|
360
223
|
|
361
|
-
|
362
|
-
|
224
|
+
page.cells = processed_cells
|
225
|
+
page.predictions.layout = LayoutPrediction(
|
226
|
+
clusters=processed_clusters
|
363
227
|
)
|
364
228
|
|
365
|
-
page.predictions.layout = LayoutPrediction(clusters=clusters)
|
366
|
-
|
367
229
|
if settings.debug.visualize_layout:
|
368
|
-
|
230
|
+
self.draw_clusters_and_cells_side_by_side(
|
231
|
+
conv_res, page, processed_clusters, mode_prefix="postprocessed"
|
232
|
+
)
|
369
233
|
|
370
234
|
yield page
|
@@ -6,6 +6,7 @@ from pydantic import BaseModel
|
|
6
6
|
|
7
7
|
from docling.datamodel.base_models import (
|
8
8
|
AssembledUnit,
|
9
|
+
ContainerElement,
|
9
10
|
FigureElement,
|
10
11
|
Page,
|
11
12
|
PageElement,
|
@@ -94,7 +95,7 @@ class PageAssembleModel(BasePageModel):
|
|
94
95
|
headers.append(text_el)
|
95
96
|
else:
|
96
97
|
body.append(text_el)
|
97
|
-
elif cluster.label
|
98
|
+
elif cluster.label in LayoutModel.TABLE_LABELS:
|
98
99
|
tbl = None
|
99
100
|
if page.predictions.tablestructure:
|
100
101
|
tbl = page.predictions.tablestructure.table_map.get(
|
@@ -159,6 +160,15 @@ class PageAssembleModel(BasePageModel):
|
|
159
160
|
)
|
160
161
|
elements.append(equation)
|
161
162
|
body.append(equation)
|
163
|
+
elif cluster.label in LayoutModel.CONTAINER_LABELS:
|
164
|
+
container_el = ContainerElement(
|
165
|
+
label=cluster.label,
|
166
|
+
id=cluster.id,
|
167
|
+
page_no=page.page_no,
|
168
|
+
cluster=cluster,
|
169
|
+
)
|
170
|
+
elements.append(container_el)
|
171
|
+
body.append(container_el)
|
162
172
|
|
163
173
|
page.assembled = AssembledUnit(
|
164
174
|
elements=elements, headers=headers, body=body
|
@@ -6,16 +6,26 @@ from docling_core.types.doc import BoundingBox, CoordOrigin
|
|
6
6
|
|
7
7
|
from docling.datamodel.base_models import OcrCell, Page
|
8
8
|
from docling.datamodel.document import ConversionResult
|
9
|
-
from docling.datamodel.pipeline_options import
|
9
|
+
from docling.datamodel.pipeline_options import (
|
10
|
+
AcceleratorDevice,
|
11
|
+
AcceleratorOptions,
|
12
|
+
RapidOcrOptions,
|
13
|
+
)
|
10
14
|
from docling.datamodel.settings import settings
|
11
15
|
from docling.models.base_ocr_model import BaseOcrModel
|
16
|
+
from docling.utils.accelerator_utils import decide_device
|
12
17
|
from docling.utils.profiling import TimeRecorder
|
13
18
|
|
14
19
|
_log = logging.getLogger(__name__)
|
15
20
|
|
16
21
|
|
17
22
|
class RapidOcrModel(BaseOcrModel):
|
18
|
-
def __init__(
|
23
|
+
def __init__(
|
24
|
+
self,
|
25
|
+
enabled: bool,
|
26
|
+
options: RapidOcrOptions,
|
27
|
+
accelerator_options: AcceleratorOptions,
|
28
|
+
):
|
19
29
|
super().__init__(enabled=enabled, options=options)
|
20
30
|
self.options: RapidOcrOptions
|
21
31
|
|
@@ -30,52 +40,21 @@ class RapidOcrModel(BaseOcrModel):
|
|
30
40
|
"Alternatively, Docling has support for other OCR engines. See the documentation."
|
31
41
|
)
|
32
42
|
|
33
|
-
#
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
det_use_dml = True
|
39
|
-
cls_use_dml = True
|
40
|
-
rec_use_dml = True
|
41
|
-
|
42
|
-
# # Same as Defaults in RapidOCR
|
43
|
-
# cls_use_cuda = False
|
44
|
-
# rec_use_cuda = False
|
45
|
-
# det_use_cuda = False
|
46
|
-
# det_use_dml = False
|
47
|
-
# cls_use_dml = False
|
48
|
-
# rec_use_dml = False
|
49
|
-
|
50
|
-
# # If we set everything to true onnx-runtime would automatically choose the fastest accelerator
|
51
|
-
# if self.options.device == self.options.Device.AUTO:
|
52
|
-
# cls_use_cuda = True
|
53
|
-
# rec_use_cuda = True
|
54
|
-
# det_use_cuda = True
|
55
|
-
# det_use_dml = True
|
56
|
-
# cls_use_dml = True
|
57
|
-
# rec_use_dml = True
|
58
|
-
|
59
|
-
# # If we set use_cuda to true onnx would use the cuda device available in runtime if no cuda device is available it would run on CPU.
|
60
|
-
# elif self.options.device == self.options.Device.CUDA:
|
61
|
-
# cls_use_cuda = True
|
62
|
-
# rec_use_cuda = True
|
63
|
-
# det_use_cuda = True
|
64
|
-
|
65
|
-
# # If we set use_dml to true onnx would use the dml device available in runtime if no dml device is available it would work on CPU.
|
66
|
-
# elif self.options.device == self.options.Device.DIRECTML:
|
67
|
-
# det_use_dml = True
|
68
|
-
# cls_use_dml = True
|
69
|
-
# rec_use_dml = True
|
43
|
+
# Decide the accelerator devices
|
44
|
+
device = decide_device(accelerator_options.device)
|
45
|
+
use_cuda = str(AcceleratorDevice.CUDA.value).lower() in device
|
46
|
+
use_dml = accelerator_options.device == AcceleratorDevice.AUTO
|
47
|
+
intra_op_num_threads = accelerator_options.num_threads
|
70
48
|
|
71
49
|
self.reader = RapidOCR(
|
72
50
|
text_score=self.options.text_score,
|
73
|
-
cls_use_cuda=
|
74
|
-
rec_use_cuda=
|
75
|
-
det_use_cuda=
|
76
|
-
det_use_dml=
|
77
|
-
cls_use_dml=
|
78
|
-
rec_use_dml=
|
51
|
+
cls_use_cuda=use_cuda,
|
52
|
+
rec_use_cuda=use_cuda,
|
53
|
+
det_use_cuda=use_cuda,
|
54
|
+
det_use_dml=use_dml,
|
55
|
+
cls_use_dml=use_dml,
|
56
|
+
rec_use_dml=use_dml,
|
57
|
+
intra_op_num_threads=intra_op_num_threads,
|
79
58
|
print_verbose=self.options.print_verbose,
|
80
59
|
det_model_path=self.options.det_model_path,
|
81
60
|
cls_model_path=self.options.cls_model_path,
|