docling 2.11.0__py3-none-any.whl → 2.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -7,7 +7,7 @@ from typing import Iterable, List
7
7
 
8
8
  from docling_core.types.doc import CoordOrigin, DocItemLabel
9
9
  from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
10
- from PIL import ImageDraw
10
+ from PIL import Image, ImageDraw, ImageFont
11
11
 
12
12
  from docling.datamodel.base_models import (
13
13
  BoundingBox,
@@ -17,9 +17,11 @@ from docling.datamodel.base_models import (
17
17
  Page,
18
18
  )
19
19
  from docling.datamodel.document import ConversionResult
20
+ from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions
20
21
  from docling.datamodel.settings import settings
21
22
  from docling.models.base_model import BasePageModel
22
- from docling.utils import layout_utils as lu
23
+ from docling.utils.accelerator_utils import decide_device
24
+ from docling.utils.layout_postprocessor import LayoutPostprocessor
23
25
  from docling.utils.profiling import TimeRecorder
24
26
 
25
27
  _log = logging.getLogger(__name__)
@@ -42,237 +44,139 @@ class LayoutModel(BasePageModel):
42
44
  ]
43
45
  PAGE_HEADER_LABELS = [DocItemLabel.PAGE_HEADER, DocItemLabel.PAGE_FOOTER]
44
46
 
45
- TABLE_LABEL = DocItemLabel.TABLE
47
+ TABLE_LABELS = [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
46
48
  FIGURE_LABEL = DocItemLabel.PICTURE
47
49
  FORMULA_LABEL = DocItemLabel.FORMULA
50
+ CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
48
51
 
49
- def __init__(self, artifacts_path: Path):
50
- self.layout_predictor = LayoutPredictor(artifacts_path) # TODO temporary
52
+ def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
53
+ device = decide_device(accelerator_options.device)
51
54
 
52
- def postprocess(self, clusters_in: List[Cluster], cells: List[Cell], page_height):
53
- MIN_INTERSECTION = 0.2
54
- CLASS_THRESHOLDS = {
55
- DocItemLabel.CAPTION: 0.35,
56
- DocItemLabel.FOOTNOTE: 0.35,
57
- DocItemLabel.FORMULA: 0.35,
58
- DocItemLabel.LIST_ITEM: 0.35,
59
- DocItemLabel.PAGE_FOOTER: 0.35,
60
- DocItemLabel.PAGE_HEADER: 0.35,
61
- DocItemLabel.PICTURE: 0.2, # low threshold adjust to capture chemical structures for examples.
62
- DocItemLabel.SECTION_HEADER: 0.45,
63
- DocItemLabel.TABLE: 0.35,
64
- DocItemLabel.TEXT: 0.45,
65
- DocItemLabel.TITLE: 0.45,
66
- DocItemLabel.DOCUMENT_INDEX: 0.45,
67
- DocItemLabel.CODE: 0.45,
68
- DocItemLabel.CHECKBOX_SELECTED: 0.45,
69
- DocItemLabel.CHECKBOX_UNSELECTED: 0.45,
70
- DocItemLabel.FORM: 0.45,
71
- DocItemLabel.KEY_VALUE_REGION: 0.45,
72
- }
73
-
74
- CLASS_REMAPPINGS = {
75
- DocItemLabel.DOCUMENT_INDEX: DocItemLabel.TABLE,
76
- DocItemLabel.TITLE: DocItemLabel.SECTION_HEADER,
77
- }
78
-
79
- _log.debug("================= Start postprocess function ====================")
80
- start_time = time.time()
81
- # Apply Confidence Threshold to cluster predictions
82
- # confidence = self.conf_threshold
83
- clusters_mod = []
84
-
85
- for cluster in clusters_in:
86
- confidence = CLASS_THRESHOLDS[cluster.label]
87
- if cluster.confidence >= confidence:
88
- # annotation["created_by"] = "high_conf_pred"
89
-
90
- # Remap class labels where needed.
91
- if cluster.label in CLASS_REMAPPINGS.keys():
92
- cluster.label = CLASS_REMAPPINGS[cluster.label]
93
- clusters_mod.append(cluster)
94
-
95
- # map to dictionary clusters and cells, with bottom left origin
96
- clusters_orig = [
97
- {
98
- "id": c.id,
99
- "bbox": list(
100
- c.bbox.to_bottom_left_origin(page_height).as_tuple()
101
- ), # TODO
102
- "confidence": c.confidence,
103
- "cell_ids": [],
104
- "type": c.label,
105
- }
106
- for c in clusters_in
107
- ]
108
-
109
- clusters_out = [
110
- {
111
- "id": c.id,
112
- "bbox": list(
113
- c.bbox.to_bottom_left_origin(page_height).as_tuple()
114
- ), # TODO
115
- "confidence": c.confidence,
116
- "created_by": "high_conf_pred",
117
- "cell_ids": [],
118
- "type": c.label,
119
- }
120
- for c in clusters_mod
121
- ]
122
-
123
- del clusters_mod
124
-
125
- raw_cells = [
126
- {
127
- "id": c.id,
128
- "bbox": list(
129
- c.bbox.to_bottom_left_origin(page_height).as_tuple()
130
- ), # TODO
131
- "text": c.text,
132
- }
133
- for c in cells
134
- ]
135
- cell_count = len(raw_cells)
136
-
137
- _log.debug("---- 0. Treat cluster overlaps ------")
138
- clusters_out = lu.remove_cluster_duplicates_by_conf(clusters_out, 0.8)
139
-
140
- _log.debug(
141
- "---- 1. Initially assign cells to clusters based on minimum intersection ------"
142
- )
143
- ## Check for cells included in or touched by clusters:
144
- clusters_out = lu.assigning_cell_ids_to_clusters(
145
- clusters_out, raw_cells, MIN_INTERSECTION
55
+ self.layout_predictor = LayoutPredictor(
56
+ artifact_path=str(artifacts_path),
57
+ device=device,
58
+ num_threads=accelerator_options.num_threads,
146
59
  )
147
60
 
148
- _log.debug("---- 2. Assign Orphans with Low Confidence Detections")
149
- # Creates a map of cell_id->cluster_id
150
- (
151
- clusters_around_cells,
152
- orphan_cell_indices,
153
- ambiguous_cell_indices,
154
- ) = lu.cell_id_state_map(clusters_out, cell_count)
155
-
156
- # Assign orphan cells with lower confidence predictions
157
- clusters_out, orphan_cell_indices = lu.assign_orphans_with_low_conf_pred(
158
- clusters_out, clusters_orig, raw_cells, orphan_cell_indices
159
- )
160
-
161
- # Refresh the cell_ids assignment, after creating new clusters using low conf predictions
162
- clusters_out = lu.assigning_cell_ids_to_clusters(
163
- clusters_out, raw_cells, MIN_INTERSECTION
164
- )
165
-
166
- _log.debug("---- 3. Settle Ambigous Cells")
167
- # Creates an update map after assignment of cell_id->cluster_id
168
- (
169
- clusters_around_cells,
170
- orphan_cell_indices,
171
- ambiguous_cell_indices,
172
- ) = lu.cell_id_state_map(clusters_out, cell_count)
173
-
174
- # Settle pdf cells that belong to multiple clusters
175
- clusters_out, ambiguous_cell_indices = lu.remove_ambigous_pdf_cell_by_conf(
176
- clusters_out, raw_cells, ambiguous_cell_indices
177
- )
178
-
179
- _log.debug("---- 4. Set Orphans as Text")
180
- (
181
- clusters_around_cells,
182
- orphan_cell_indices,
183
- ambiguous_cell_indices,
184
- ) = lu.cell_id_state_map(clusters_out, cell_count)
185
-
186
- clusters_out, orphan_cell_indices = lu.set_orphan_as_text(
187
- clusters_out, clusters_orig, raw_cells, orphan_cell_indices
188
- )
189
-
190
- _log.debug("---- 5. Merge Cells & and adapt the bounding boxes")
191
- # Merge cells orphan cells
192
- clusters_out = lu.merge_cells(clusters_out)
193
-
194
- # Clean up clusters that remain from merged and unreasonable clusters
195
- clusters_out = lu.clean_up_clusters(
196
- clusters_out,
197
- raw_cells,
198
- merge_cells=True,
199
- img_table=True,
200
- one_cell_table=True,
201
- )
202
-
203
- new_clusters = lu.adapt_bboxes(raw_cells, clusters_out, orphan_cell_indices)
204
- clusters_out = new_clusters
205
-
206
- ## We first rebuild where every cell is now:
207
- ## Now we write into a prediction cells list, not into the raw cells list.
208
- ## As we don't need previous labels, we best overwrite any old list, because that might
209
- ## have been sorted differently.
210
- (
211
- clusters_around_cells,
212
- orphan_cell_indices,
213
- ambiguous_cell_indices,
214
- ) = lu.cell_id_state_map(clusters_out, cell_count)
215
-
216
- target_cells = []
217
- for ix, cell in enumerate(raw_cells):
218
- new_cell = {
219
- "id": ix,
220
- "rawcell_id": ix,
221
- "label": "None",
222
- "bbox": cell["bbox"],
223
- "text": cell["text"],
224
- }
225
- for cluster_index in clusters_around_cells[
226
- ix
227
- ]: # By previous analysis, this is always 1 cluster.
228
- new_cell["label"] = clusters_out[cluster_index]["type"]
229
- target_cells.append(new_cell)
230
- # _log.debug("New label of cell " + str(ix) + " is " + str(new_cell["label"]))
231
- cells_out = target_cells
232
-
233
- ## -------------------------------
234
- ## Sort clusters into reasonable reading order, and sort the cells inside each cluster
235
- _log.debug("---- 5. Sort clusters in reading order ------")
236
- sorted_clusters = lu.produce_reading_order(
237
- clusters_out, "raw_cell_ids", "raw_cell_ids", True
238
- )
239
- clusters_out = sorted_clusters
240
-
241
- # end_time = timer()
242
- _log.debug("---- End of postprocessing function ------")
243
- end_time = time.time() - start_time
244
- _log.debug(f"Finished post processing in seconds={end_time:.3f}")
245
-
246
- cells_out_new = [
247
- Cell(
248
- id=c["id"], # type: ignore
249
- bbox=BoundingBox.from_tuple(
250
- coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
251
- ).to_top_left_origin(page_height),
252
- text=c["text"], # type: ignore
253
- )
254
- for c in cells_out
255
- ]
256
-
257
- del cells_out
61
+ def draw_clusters_and_cells_side_by_side(
62
+ self, conv_res, page, clusters, mode_prefix: str, show: bool = False
63
+ ):
64
+ """
65
+ Draws a page image side by side with clusters filtered into two categories:
66
+ - Left: Clusters excluding FORM, KEY_VALUE_REGION, and PICTURE.
67
+ - Right: Clusters including FORM, KEY_VALUE_REGION, and PICTURE.
68
+ Includes label names and confidence scores for each cluster.
69
+ """
70
+ label_to_color = {
71
+ DocItemLabel.TEXT: (255, 255, 153), # Light Yellow
72
+ DocItemLabel.CAPTION: (255, 204, 153), # Light Orange
73
+ DocItemLabel.LIST_ITEM: (153, 153, 255), # Light Purple
74
+ DocItemLabel.FORMULA: (192, 192, 192), # Gray
75
+ DocItemLabel.TABLE: (255, 204, 204), # Light Pink
76
+ DocItemLabel.PICTURE: (255, 204, 164), # Light Beige
77
+ DocItemLabel.SECTION_HEADER: (255, 153, 153), # Light Red
78
+ DocItemLabel.PAGE_HEADER: (204, 255, 204), # Light Green
79
+ DocItemLabel.PAGE_FOOTER: (
80
+ 204,
81
+ 255,
82
+ 204,
83
+ ), # Light Green (same as Page-Header)
84
+ DocItemLabel.TITLE: (255, 153, 153), # Light Red (same as Section-Header)
85
+ DocItemLabel.FOOTNOTE: (200, 200, 255), # Light Blue
86
+ DocItemLabel.DOCUMENT_INDEX: (220, 220, 220), # Light Gray
87
+ DocItemLabel.CODE: (125, 125, 125), # Gray
88
+ DocItemLabel.CHECKBOX_SELECTED: (255, 182, 193), # Pale Green
89
+ DocItemLabel.CHECKBOX_UNSELECTED: (255, 182, 193), # Light Pink
90
+ DocItemLabel.FORM: (200, 255, 255), # Light Cyan
91
+ DocItemLabel.KEY_VALUE_REGION: (183, 65, 14), # Rusty orange
92
+ }
93
+ # Filter clusters for left and right images
94
+ exclude_labels = {
95
+ DocItemLabel.FORM,
96
+ DocItemLabel.KEY_VALUE_REGION,
97
+ DocItemLabel.PICTURE,
98
+ }
99
+ left_clusters = [c for c in clusters if c.label not in exclude_labels]
100
+ right_clusters = [c for c in clusters if c.label in exclude_labels]
101
+ # Create a deep copy of the original image for both sides
102
+ left_image = copy.deepcopy(page.image)
103
+ right_image = copy.deepcopy(page.image)
104
+
105
+ # Function to draw clusters on an image
106
+ def draw_clusters(image, clusters):
107
+ draw = ImageDraw.Draw(image, "RGBA")
108
+ # Create a smaller font for the labels
109
+ try:
110
+ font = ImageFont.truetype("arial.ttf", 12)
111
+ except OSError:
112
+ # Fallback to default font if arial is not available
113
+ font = ImageFont.load_default()
114
+ for c_tl in clusters:
115
+ all_clusters = [c_tl, *c_tl.children]
116
+ for c in all_clusters:
117
+ # Draw cells first (underneath)
118
+ cell_color = (0, 0, 0, 40) # Transparent black for cells
119
+ for tc in c.cells:
120
+ cx0, cy0, cx1, cy1 = tc.bbox.as_tuple()
121
+ draw.rectangle(
122
+ [(cx0, cy0), (cx1, cy1)],
123
+ outline=None,
124
+ fill=cell_color,
125
+ )
126
+ # Draw cluster rectangle
127
+ x0, y0, x1, y1 = c.bbox.as_tuple()
128
+ cluster_fill_color = (*list(label_to_color.get(c.label)), 70)
129
+ cluster_outline_color = (*list(label_to_color.get(c.label)), 255)
130
+ draw.rectangle(
131
+ [(x0, y0), (x1, y1)],
132
+ outline=cluster_outline_color,
133
+ fill=cluster_fill_color,
134
+ )
135
+ # Add label name and confidence
136
+ label_text = f"{c.label.name} ({c.confidence:.2f})"
137
+ # Create semi-transparent background for text
138
+ text_bbox = draw.textbbox((x0, y0), label_text, font=font)
139
+ text_bg_padding = 2
140
+ draw.rectangle(
141
+ [
142
+ (
143
+ text_bbox[0] - text_bg_padding,
144
+ text_bbox[1] - text_bg_padding,
145
+ ),
146
+ (
147
+ text_bbox[2] + text_bg_padding,
148
+ text_bbox[3] + text_bg_padding,
149
+ ),
150
+ ],
151
+ fill=(255, 255, 255, 180), # Semi-transparent white
152
+ )
153
+ # Draw text
154
+ draw.text(
155
+ (x0, y0),
156
+ label_text,
157
+ fill=(0, 0, 0, 255), # Solid black
158
+ font=font,
159
+ )
258
160
 
259
- clusters_out_new = []
260
- for c in clusters_out:
261
- cluster_cells = [
262
- ccell for ccell in cells_out_new if ccell.id in c["cell_ids"] # type: ignore
263
- ]
264
- c_new = Cluster(
265
- id=c["id"], # type: ignore
266
- bbox=BoundingBox.from_tuple(
267
- coord=c["bbox"], origin=CoordOrigin.BOTTOMLEFT # type: ignore
268
- ).to_top_left_origin(page_height),
269
- confidence=c["confidence"], # type: ignore
270
- label=DocItemLabel(c["type"]),
271
- cells=cluster_cells,
161
+ # Draw clusters on both images
162
+ draw_clusters(left_image, left_clusters)
163
+ draw_clusters(right_image, right_clusters)
164
+ # Combine the images side by side
165
+ combined_width = left_image.width * 2
166
+ combined_height = left_image.height
167
+ combined_image = Image.new("RGB", (combined_width, combined_height))
168
+ combined_image.paste(left_image, (0, 0))
169
+ combined_image.paste(right_image, (left_image.width, 0))
170
+ if show:
171
+ combined_image.show()
172
+ else:
173
+ out_path: Path = (
174
+ Path(settings.debug.debug_output_path)
175
+ / f"debug_{conv_res.input.file.stem}"
272
176
  )
273
- clusters_out_new.append(c_new)
274
-
275
- return clusters_out_new, cells_out_new
177
+ out_path.mkdir(parents=True, exist_ok=True)
178
+ out_file = out_path / f"{mode_prefix}_layout_page_{page.page_no:05}.png"
179
+ combined_image.save(str(out_file), format="png")
276
180
 
277
181
  def __call__(
278
182
  self, conv_res: ConversionResult, page_batch: Iterable[Page]
@@ -305,66 +209,26 @@ class LayoutModel(BasePageModel):
305
209
  )
306
210
  clusters.append(cluster)
307
211
 
308
- # Map cells to clusters
309
- # TODO: Remove, postprocess should take care of it anyway.
310
- for cell in page.cells:
311
- for cluster in clusters:
312
- if not cell.bbox.area() > 0:
313
- overlap_frac = 0.0
314
- else:
315
- overlap_frac = (
316
- cell.bbox.intersection_area_with(cluster.bbox)
317
- / cell.bbox.area()
318
- )
319
-
320
- if overlap_frac > 0.5:
321
- cluster.cells.append(cell)
322
-
323
- # Pre-sort clusters
324
- # clusters = self.sort_clusters_by_cell_order(clusters)
325
-
326
- # DEBUG code:
327
- def draw_clusters_and_cells(show: bool = False):
328
- image = copy.deepcopy(page.image)
329
- if image is not None:
330
- draw = ImageDraw.Draw(image)
331
- for c in clusters:
332
- x0, y0, x1, y1 = c.bbox.as_tuple()
333
- draw.rectangle([(x0, y0), (x1, y1)], outline="green")
334
-
335
- cell_color = (
336
- random.randint(30, 140),
337
- random.randint(30, 140),
338
- random.randint(30, 140),
339
- )
340
- for tc in c.cells: # [:1]:
341
- x0, y0, x1, y1 = tc.bbox.as_tuple()
342
- draw.rectangle(
343
- [(x0, y0), (x1, y1)], outline=cell_color
344
- )
345
- if show:
346
- image.show()
347
- else:
348
- out_path: Path = (
349
- Path(settings.debug.debug_output_path)
350
- / f"debug_{conv_res.input.file.stem}"
351
- )
352
- out_path.mkdir(parents=True, exist_ok=True)
212
+ if settings.debug.visualize_raw_layout:
213
+ self.draw_clusters_and_cells_side_by_side(
214
+ conv_res, page, clusters, mode_prefix="raw"
215
+ )
353
216
 
354
- out_file = (
355
- out_path / f"layout_page_{page.page_no:05}.png"
356
- )
357
- image.save(str(out_file), format="png")
217
+ # Apply postprocessing
358
218
 
359
- # draw_clusters_and_cells()
219
+ processed_clusters, processed_cells = LayoutPostprocessor(
220
+ page.cells, clusters, page.size
221
+ ).postprocess()
222
+ # processed_clusters, processed_cells = clusters, page.cells
360
223
 
361
- clusters, page.cells = self.postprocess(
362
- clusters, page.cells, page.size.height
224
+ page.cells = processed_cells
225
+ page.predictions.layout = LayoutPrediction(
226
+ clusters=processed_clusters
363
227
  )
364
228
 
365
- page.predictions.layout = LayoutPrediction(clusters=clusters)
366
-
367
229
  if settings.debug.visualize_layout:
368
- draw_clusters_and_cells()
230
+ self.draw_clusters_and_cells_side_by_side(
231
+ conv_res, page, processed_clusters, mode_prefix="postprocessed"
232
+ )
369
233
 
370
234
  yield page
@@ -6,6 +6,7 @@ from pydantic import BaseModel
6
6
 
7
7
  from docling.datamodel.base_models import (
8
8
  AssembledUnit,
9
+ ContainerElement,
9
10
  FigureElement,
10
11
  Page,
11
12
  PageElement,
@@ -94,7 +95,7 @@ class PageAssembleModel(BasePageModel):
94
95
  headers.append(text_el)
95
96
  else:
96
97
  body.append(text_el)
97
- elif cluster.label == LayoutModel.TABLE_LABEL:
98
+ elif cluster.label in LayoutModel.TABLE_LABELS:
98
99
  tbl = None
99
100
  if page.predictions.tablestructure:
100
101
  tbl = page.predictions.tablestructure.table_map.get(
@@ -159,6 +160,15 @@ class PageAssembleModel(BasePageModel):
159
160
  )
160
161
  elements.append(equation)
161
162
  body.append(equation)
163
+ elif cluster.label in LayoutModel.CONTAINER_LABELS:
164
+ container_el = ContainerElement(
165
+ label=cluster.label,
166
+ id=cluster.id,
167
+ page_no=page.page_no,
168
+ cluster=cluster,
169
+ )
170
+ elements.append(container_el)
171
+ body.append(container_el)
162
172
 
163
173
  page.assembled = AssembledUnit(
164
174
  elements=elements, headers=headers, body=body
@@ -6,16 +6,26 @@ from docling_core.types.doc import BoundingBox, CoordOrigin
6
6
 
7
7
  from docling.datamodel.base_models import OcrCell, Page
8
8
  from docling.datamodel.document import ConversionResult
9
- from docling.datamodel.pipeline_options import RapidOcrOptions
9
+ from docling.datamodel.pipeline_options import (
10
+ AcceleratorDevice,
11
+ AcceleratorOptions,
12
+ RapidOcrOptions,
13
+ )
10
14
  from docling.datamodel.settings import settings
11
15
  from docling.models.base_ocr_model import BaseOcrModel
16
+ from docling.utils.accelerator_utils import decide_device
12
17
  from docling.utils.profiling import TimeRecorder
13
18
 
14
19
  _log = logging.getLogger(__name__)
15
20
 
16
21
 
17
22
  class RapidOcrModel(BaseOcrModel):
18
- def __init__(self, enabled: bool, options: RapidOcrOptions):
23
+ def __init__(
24
+ self,
25
+ enabled: bool,
26
+ options: RapidOcrOptions,
27
+ accelerator_options: AcceleratorOptions,
28
+ ):
19
29
  super().__init__(enabled=enabled, options=options)
20
30
  self.options: RapidOcrOptions
21
31
 
@@ -30,52 +40,21 @@ class RapidOcrModel(BaseOcrModel):
30
40
  "Alternatively, Docling has support for other OCR engines. See the documentation."
31
41
  )
32
42
 
33
- # This configuration option will be revamped while introducing device settings for all models.
34
- # For the moment we will default to auto and let onnx-runtime pick the best.
35
- cls_use_cuda = True
36
- rec_use_cuda = True
37
- det_use_cuda = True
38
- det_use_dml = True
39
- cls_use_dml = True
40
- rec_use_dml = True
41
-
42
- # # Same as Defaults in RapidOCR
43
- # cls_use_cuda = False
44
- # rec_use_cuda = False
45
- # det_use_cuda = False
46
- # det_use_dml = False
47
- # cls_use_dml = False
48
- # rec_use_dml = False
49
-
50
- # # If we set everything to true onnx-runtime would automatically choose the fastest accelerator
51
- # if self.options.device == self.options.Device.AUTO:
52
- # cls_use_cuda = True
53
- # rec_use_cuda = True
54
- # det_use_cuda = True
55
- # det_use_dml = True
56
- # cls_use_dml = True
57
- # rec_use_dml = True
58
-
59
- # # If we set use_cuda to true onnx would use the cuda device available in runtime if no cuda device is available it would run on CPU.
60
- # elif self.options.device == self.options.Device.CUDA:
61
- # cls_use_cuda = True
62
- # rec_use_cuda = True
63
- # det_use_cuda = True
64
-
65
- # # If we set use_dml to true onnx would use the dml device available in runtime if no dml device is available it would work on CPU.
66
- # elif self.options.device == self.options.Device.DIRECTML:
67
- # det_use_dml = True
68
- # cls_use_dml = True
69
- # rec_use_dml = True
43
+ # Decide the accelerator devices
44
+ device = decide_device(accelerator_options.device)
45
+ use_cuda = str(AcceleratorDevice.CUDA.value).lower() in device
46
+ use_dml = accelerator_options.device == AcceleratorDevice.AUTO
47
+ intra_op_num_threads = accelerator_options.num_threads
70
48
 
71
49
  self.reader = RapidOCR(
72
50
  text_score=self.options.text_score,
73
- cls_use_cuda=cls_use_cuda,
74
- rec_use_cuda=rec_use_cuda,
75
- det_use_cuda=det_use_cuda,
76
- det_use_dml=det_use_dml,
77
- cls_use_dml=cls_use_dml,
78
- rec_use_dml=rec_use_dml,
51
+ cls_use_cuda=use_cuda,
52
+ rec_use_cuda=use_cuda,
53
+ det_use_cuda=use_cuda,
54
+ det_use_dml=use_dml,
55
+ cls_use_dml=use_dml,
56
+ rec_use_dml=use_dml,
57
+ intra_op_num_threads=intra_op_num_threads,
79
58
  print_verbose=self.options.print_verbose,
80
59
  det_model_path=self.options.det_model_path,
81
60
  cls_model_path=self.options.cls_model_path,