docling 2.15.0__py3-none-any.whl → 2.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling/backend/abstract_backend.py +0 -1
- docling/backend/asciidoc_backend.py +0 -1
- docling/backend/docling_parse_backend.py +2 -2
- docling/backend/docling_parse_v2_backend.py +2 -2
- docling/backend/html_backend.py +1 -1
- docling/backend/json/__init__.py +0 -0
- docling/backend/json/docling_json_backend.py +58 -0
- docling/backend/md_backend.py +44 -27
- docling/backend/msexcel_backend.py +50 -38
- docling/backend/msword_backend.py +0 -1
- docling/backend/pdf_backend.py +0 -2
- docling/backend/pypdfium2_backend.py +2 -2
- docling/datamodel/base_models.py +30 -3
- docling/datamodel/document.py +2 -0
- docling/datamodel/pipeline_options.py +7 -10
- docling/document_converter.py +4 -0
- docling/models/base_model.py +62 -6
- docling/models/base_ocr_model.py +15 -12
- docling/models/code_formula_model.py +245 -0
- docling/models/document_picture_classifier.py +187 -0
- docling/models/layout_model.py +10 -86
- docling/models/page_assemble_model.py +1 -33
- docling/models/tesseract_ocr_cli_model.py +0 -1
- docling/models/tesseract_ocr_model.py +63 -15
- docling/pipeline/base_pipeline.py +40 -17
- docling/pipeline/standard_pdf_pipeline.py +31 -2
- docling/utils/glm_utils.py +4 -1
- docling/utils/visualization.py +80 -0
- {docling-2.15.0.dist-info → docling-2.16.0.dist-info}/METADATA +7 -7
- docling-2.16.0.dist-info/RECORD +61 -0
- docling-2.15.0.dist-info/RECORD +0 -56
- {docling-2.15.0.dist-info → docling-2.16.0.dist-info}/LICENSE +0 -0
- {docling-2.15.0.dist-info → docling-2.16.0.dist-info}/WHEEL +0 -0
- {docling-2.15.0.dist-info → docling-2.16.0.dist-info}/entry_points.txt +0 -0
docling/models/base_model.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
-
from typing import Any, Iterable
|
2
|
+
from typing import Any, Generic, Iterable, Optional
|
3
3
|
|
4
|
-
from docling_core.types.doc import DoclingDocument, NodeItem
|
4
|
+
from docling_core.types.doc import BoundingBox, DoclingDocument, NodeItem, TextItem
|
5
|
+
from typing_extensions import TypeVar
|
5
6
|
|
6
|
-
from docling.datamodel.base_models import Page
|
7
|
+
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
|
7
8
|
from docling.datamodel.document import ConversionResult
|
8
9
|
|
9
10
|
|
@@ -15,14 +16,69 @@ class BasePageModel(ABC):
|
|
15
16
|
pass
|
16
17
|
|
17
18
|
|
18
|
-
|
19
|
+
# Type of the prepared elements that a GenericEnrichmentModel produces in
# prepare_element() and consumes in __call__(); defaults to NodeItem.
EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
|
20
|
+
|
21
|
+
|
22
|
+
class GenericEnrichmentModel(ABC, Generic[EnrichElementT]):
    """Abstract interface shared by all enrichment models.

    ``EnrichElementT`` is the type returned by :meth:`prepare_element` and
    received, in batches, by :meth:`__call__`.
    """

    @abstractmethod
    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """Tell whether ``element`` of ``doc`` should be handled by this model."""

    @abstractmethod
    def prepare_element(
        self, conv_res: ConversionResult, element: NodeItem
    ) -> Optional[EnrichElementT]:
        """Build the model input for ``element``; ``None`` is a permitted result."""

    @abstractmethod
    def __call__(
        self, doc: DoclingDocument, element_batch: Iterable[EnrichElementT]
    ) -> Iterable[NodeItem]:
        """Process a batch of prepared elements and return document items."""
|
39
|
+
|
40
|
+
|
41
|
+
class BaseEnrichmentModel(GenericEnrichmentModel[NodeItem]):
    """Enrichment model whose prepared elements are the document items themselves."""

    def prepare_element(
        self, conv_res: ConversionResult, element: NodeItem
    ) -> Optional[NodeItem]:
        """Return ``element`` unchanged when it is processable, otherwise ``None``."""
        processable = self.is_processable(doc=conv_res.document, element=element)
        return element if processable else None
|
49
|
+
|
50
|
+
|
51
|
+
class BaseItemAndImageEnrichmentModel(
    GenericEnrichmentModel[ItemAndImageEnrichmentElement]
):
    """Enrichment model whose input is an item paired with a crop of its page image.

    Subclasses set ``images_scale`` (passed as ``scale`` to the page's
    ``get_image``) and may override ``expansion_factor``, the fraction of the
    bounding-box width/height added on each side before cropping.
    """

    images_scale: float
    expansion_factor: float = 0.0

    def prepare_element(
        self, conv_res: ConversionResult, element: NodeItem
    ) -> Optional[ItemAndImageEnrichmentElement]:
        """Pair ``element`` with the image crop of its expanded bounding box.

        Parameters
        ----------
        conv_res : ConversionResult
            Conversion result providing the document and its pages.
        element : NodeItem
            Candidate element; must be a ``TextItem`` once it is processable.

        Returns
        -------
        Optional[ItemAndImageEnrichmentElement]
            The item together with its cropped image, or ``None`` when the
            element is not processable or carries no provenance.
        """
        if not self.is_processable(doc=conv_res.document, element=element):
            return None

        assert isinstance(element, TextItem)

        # The crop is derived from the first provenance entry; without one
        # there is no page or bounding box to crop, so skip the element
        # instead of failing with an IndexError.
        if not element.prov:
            return None
        element_prov = element.prov[0]

        bbox = element_prov.bbox
        width = bbox.r - bbox.l
        height = bbox.t - bbox.b

        # TODO: move to a utility in the BoundingBox class
        expanded_bbox = BoundingBox(
            l=bbox.l - width * self.expansion_factor,
            t=bbox.t + height * self.expansion_factor,
            r=bbox.r + width * self.expansion_factor,
            b=bbox.b - height * self.expansion_factor,
            coord_origin=bbox.coord_origin,
        )

        # page_no is 1-based while conv_res.pages is indexed from 0.
        page_ix = element_prov.page_no - 1
        cropped_image = conv_res.pages[page_ix].get_image(
            scale=self.images_scale, cropbox=expanded_bbox
        )
        return ItemAndImageEnrichmentElement(item=element, image=cropped_image)
|
docling/models/base_ocr_model.py
CHANGED
@@ -8,7 +8,7 @@ import numpy as np
|
|
8
8
|
from docling_core.types.doc import BoundingBox, CoordOrigin
|
9
9
|
from PIL import Image, ImageDraw
|
10
10
|
from rtree import index
|
11
|
-
from scipy.ndimage import find_objects, label
|
11
|
+
from scipy.ndimage import binary_dilation, find_objects, label
|
12
12
|
|
13
13
|
from docling.datamodel.base_models import Cell, OcrCell, Page
|
14
14
|
from docling.datamodel.document import ConversionResult
|
@@ -43,6 +43,12 @@ class BaseOcrModel(BasePageModel):
|
|
43
43
|
|
44
44
|
np_image = np.array(image)
|
45
45
|
|
46
|
+
# Dilate the image by 10 pixels to merge nearby bitmap rectangles
|
47
|
+
structure = np.ones(
|
48
|
+
(20, 20)
|
49
|
+
) # Create a 20x20 structure element (10 pixels in all directions)
|
50
|
+
np_image = binary_dilation(np_image > 0, structure=structure)
|
51
|
+
|
46
52
|
# Find the connected components
|
47
53
|
labeled_image, num_features = label(
|
48
54
|
np_image > 0
|
@@ -72,7 +78,7 @@ class BaseOcrModel(BasePageModel):
|
|
72
78
|
bitmap_rects = []
|
73
79
|
coverage, ocr_rects = find_ocr_rects(page.size, bitmap_rects)
|
74
80
|
|
75
|
-
# return full-page rectangle if
|
81
|
+
# return full-page rectangle if page is dominantly covered with bitmaps
|
76
82
|
if self.options.force_full_page_ocr or coverage > max(
|
77
83
|
BITMAP_COVERAGE_TRESHOLD, self.options.bitmap_area_threshold
|
78
84
|
):
|
@@ -85,17 +91,11 @@ class BaseOcrModel(BasePageModel):
|
|
85
91
|
coord_origin=CoordOrigin.TOPLEFT,
|
86
92
|
)
|
87
93
|
]
|
88
|
-
# return individual rectangles if the bitmap coverage is
|
89
|
-
|
90
|
-
|
91
|
-
# skip OCR if the bitmap area on the page is smaller than the options threshold
|
92
|
-
ocr_rects = [
|
93
|
-
rect
|
94
|
-
for rect in ocr_rects
|
95
|
-
if rect.area() / (page.size.width * page.size.height)
|
96
|
-
> self.options.bitmap_area_threshold
|
97
|
-
]
|
94
|
+
# return individual rectangles if the bitmap coverage is above the threshold
|
95
|
+
elif coverage > self.options.bitmap_area_threshold:
|
98
96
|
return ocr_rects
|
97
|
+
else: # overall coverage of bitmaps is too low, drop all bitmap rectangles.
|
98
|
+
return []
|
99
99
|
|
100
100
|
# Filters OCR cells by dropping any OCR cell that intersects with an existing programmatic cell.
|
101
101
|
def _filter_ocr_cells(self, ocr_cells, programmatic_cells):
|
@@ -162,6 +162,9 @@ class BaseOcrModel(BasePageModel):
|
|
162
162
|
x0 *= scale_x
|
163
163
|
x1 *= scale_x
|
164
164
|
|
165
|
+
if y1 <= y0:
|
166
|
+
y1, y0 = y0, y1
|
167
|
+
|
165
168
|
color = "gray"
|
166
169
|
if isinstance(tc, OcrCell):
|
167
170
|
color = "magenta"
|
docling/models/code_formula_model.py
ADDED
@@ -0,0 +1,245 @@
|
|
1
|
+
import re
|
2
|
+
from pathlib import Path
|
3
|
+
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
4
|
+
|
5
|
+
from docling_core.types.doc import (
|
6
|
+
CodeItem,
|
7
|
+
DocItemLabel,
|
8
|
+
DoclingDocument,
|
9
|
+
NodeItem,
|
10
|
+
TextItem,
|
11
|
+
)
|
12
|
+
from docling_core.types.doc.labels import CodeLanguageLabel
|
13
|
+
from PIL import Image
|
14
|
+
from pydantic import BaseModel
|
15
|
+
|
16
|
+
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
|
17
|
+
from docling.datamodel.pipeline_options import AcceleratorOptions
|
18
|
+
from docling.models.base_model import BaseItemAndImageEnrichmentModel
|
19
|
+
from docling.utils.accelerator_utils import decide_device
|
20
|
+
|
21
|
+
|
22
|
+
class CodeFormulaModelOptions(BaseModel):
    """
    Settings that control the CodeFormulaModel.

    Attributes
    ----------
    kind : Literal["code_formula"]
        Identifier of this options type; always "code_formula".
    do_code_enrichment : bool
        Whether code items are enriched. Defaults to True.
    do_formula_enrichment : bool
        Whether formula items are enriched. Defaults to True.
    """

    kind: Literal["code_formula"] = "code_formula"
    do_code_enrichment: bool = True
    do_formula_enrichment: bool = True
|
39
|
+
|
40
|
+
|
41
|
+
class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
    """
    Enrichment model that rewrites code and formula items from their page crops.

    Attributes
    ----------
    enabled : bool
        True if the model is enabled, False otherwise.
    options : CodeFormulaModelOptions
        Configuration options for the CodeFormulaModel.
    code_formula_model : CodeFormulaPredictor
        The predictor used for inference; only created when the model is enabled.

    Methods
    -------
    __init__(self, enabled, artifacts_path, options, accelerator_options)
        Stores the configuration and, if enabled, loads the predictor.
    is_processable(self, doc, element)
        Determines if a given element in a document can be processed by the model.
    __call__(self, doc, element_batch)
        Runs the predictor on a batch and writes the predictions into the items.
    """

    images_scale = 1.66  # = 120 dpi, aligned with training data resolution
    expansion_factor = 0.03

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Union[Path, str]],
        options: CodeFormulaModelOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """
        Store the configuration and, when enabled, load the predictor.

        Parameters
        ----------
        enabled : bool
            True if the model is enabled, False otherwise.
        artifacts_path : Optional[Union[Path, str]]
            Directory containing the model artifacts. When None, the artifacts
            are obtained through `download_models_hf()`.
        options : CodeFormulaModelOptions
            Configuration options for the model.
        accelerator_options : AcceleratorOptions
            Supplies the device (resolved with `decide_device`) and the number
            of threads handed to the predictor.
        """
        self.enabled = enabled
        self.options = options

        # A disabled model never touches the predictor, so nothing is loaded.
        if not self.enabled:
            return

        device = decide_device(accelerator_options.device)

        from docling_ibm_models.code_formula_model.code_formula_predictor import (
            CodeFormulaPredictor,
        )

        model_dir = (
            Path(artifacts_path)
            if artifacts_path is not None
            else self.download_models_hf()
        )

        self.code_formula_model = CodeFormulaPredictor(
            artifacts_path=model_dir,
            device=device,
            num_threads=accelerator_options.num_threads,
        )

    @staticmethod
    def download_models_hf(
        local_dir: Optional[Path] = None, force: bool = False
    ) -> Path:
        """
        Fetch the "ds4sd/CodeFormula" snapshot (revision "v1.0.0") from Hugging Face.

        Parameters
        ----------
        local_dir : Optional[Path]
            Forwarded to `snapshot_download` as `local_dir`.
        force : bool
            Forwarded to `snapshot_download` as `force_download`.

        Returns
        -------
        Path
            Location of the downloaded snapshot.
        """
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        disable_progress_bars()
        return Path(
            snapshot_download(
                repo_id="ds4sd/CodeFormula",
                force_download=force,
                local_dir=local_dir,
                revision="v1.0.0",
            )
        )

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """
        Decide whether this model should handle the given element.

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element : NodeItem
            The element within the document to check.

        Returns
        -------
        bool
            True for code items when code enrichment is on, and for text items
            labelled as formula when formula enrichment is on; False otherwise
            or when the model is disabled.
        """
        if not self.enabled:
            return False
        if isinstance(element, CodeItem) and self.options.do_code_enrichment:
            return True
        return (
            isinstance(element, TextItem)
            and element.label == DocItemLabel.FORMULA
            and self.options.do_formula_enrichment
        )

    def _extract_code_language(self, input_string: str) -> Tuple[str, Optional[str]]:
        """Split a leading ``<_language_>`` tag off a string.

        Args:
            input_string (str): Text that may begin with ``<_language_>``.

        Returns:
            Tuple[str, Optional[str]]:
                ``(remainder, language)`` when the text starts with such a tag,
                where ``remainder`` is everything after the tag and any
                whitespace following it; otherwise ``(input_string, None)``.
        """
        found = re.match(r"^<_([^>]+)_>\s*(.*)", input_string, flags=re.DOTALL)
        if found is None:
            return input_string, None

        language = str(found.group(1))  # text between "<_" and "_>"
        remainder = str(found.group(2))  # everything after the tag
        return remainder, language

    def _get_code_language_enum(self, value: Optional[str]) -> CodeLanguageLabel:
        """
        Map a language name onto a `CodeLanguageLabel` member.

        Args:
            value (Optional[str]): The language name, or None.

        Returns:
            CodeLanguageLabel: The member whose value equals `value`, or
            `CodeLanguageLabel.UNKNOWN` when `value` is not a string or does
            not match any member.
        """
        if not isinstance(value, str):
            return CodeLanguageLabel.UNKNOWN

        try:
            label = CodeLanguageLabel(value)
        except ValueError:
            label = CodeLanguageLabel.UNKNOWN
        return label

    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[ItemAndImageEnrichmentElement],
    ) -> Iterable[NodeItem]:
        """
        Run the predictor on a batch and write the results into the items.

        Parameters
        ----------
        doc : DoclingDocument
            The document being processed.
        element_batch : Iterable[ItemAndImageEnrichmentElement]
            Items paired with their cropped images.

        Returns
        -------
        Iterable[NodeItem]
            The items of the batch, in order. When the model is enabled their
            `text` holds the prediction and, for code items, `code_language`
            is set from the language tag of the prediction.
        """
        if not self.enabled:
            # Pass the items through untouched.
            yield from (entry.item for entry in element_batch)
            return

        text_items: List[TextItem] = []
        item_labels: List[str] = []
        crops: List[Image.Image] = []
        for entry in element_batch:
            assert isinstance(entry.item, TextItem)
            text_items.append(entry.item)
            item_labels.append(entry.item.label)
            crops.append(entry.image)

        predictions = self.code_formula_model.predict(crops, item_labels)

        for text_item, predicted_text in zip(text_items, predictions):
            if isinstance(text_item, CodeItem):
                # Code predictions may carry a language tag in front of the code.
                predicted_text, language_tag = self._extract_code_language(
                    predicted_text
                )
                text_item.code_language = self._get_code_language_enum(language_tag)
            text_item.text = predicted_text

            yield text_item
|
docling/models/document_picture_classifier.py
ADDED
@@ -0,0 +1,187 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from typing import Iterable, List, Literal, Optional, Tuple, Union
|
3
|
+
|
4
|
+
from docling_core.types.doc import (
|
5
|
+
DoclingDocument,
|
6
|
+
NodeItem,
|
7
|
+
PictureClassificationClass,
|
8
|
+
PictureClassificationData,
|
9
|
+
PictureItem,
|
10
|
+
)
|
11
|
+
from PIL import Image
|
12
|
+
from pydantic import BaseModel
|
13
|
+
|
14
|
+
from docling.datamodel.pipeline_options import AcceleratorOptions
|
15
|
+
from docling.models.base_model import BaseEnrichmentModel
|
16
|
+
from docling.utils.accelerator_utils import decide_device
|
17
|
+
|
18
|
+
|
19
|
+
class DocumentPictureClassifierOptions(BaseModel):
    """
    Settings for the DocumentPictureClassifier.

    Attributes
    ----------
    kind : Literal["document_picture_classifier"]
        Identifier of this options type; always "document_picture_classifier".
    """

    kind: Literal["document_picture_classifier"] = "document_picture_classifier"
|
30
|
+
|
31
|
+
|
32
|
+
class DocumentPictureClassifier(BaseEnrichmentModel):
    """
    Enrichment model that attaches predicted classes to the pictures of a document.

    Attributes
    ----------
    enabled : bool
        Whether the classifier is enabled for use.
    options : DocumentPictureClassifierOptions
        Configuration options for the classifier.
    document_picture_classifier : DocumentFigureClassifierPredictor
        The underlying prediction model; only created when the classifier is
        enabled.

    Methods
    -------
    __init__(enabled, artifacts_path, options, accelerator_options)
        Stores the configuration and, if enabled, loads the predictor.
    is_processable(doc, element)
        Checks if the given element can be processed by the classifier.
    __call__(doc, element_batch)
        Classifies a batch of pictures and appends the result to their annotations.
    """

    images_scale = 2

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Union[Path, str]],
        options: DocumentPictureClassifierOptions,
        accelerator_options: AcceleratorOptions,
    ):
        """
        Store the configuration and, when enabled, load the predictor.

        Parameters
        ----------
        enabled : bool
            Indicates whether the classifier is enabled.
        artifacts_path : Optional[Union[Path, str]]
            Directory containing the model artifacts. When None, the artifacts
            are obtained through `download_models_hf()`.
        options : DocumentPictureClassifierOptions
            Configuration options for the classifier.
        accelerator_options : AcceleratorOptions
            Supplies the device (resolved with `decide_device`) and the number
            of threads handed to the predictor.
        """
        self.enabled = enabled
        self.options = options

        # A disabled classifier never touches the predictor, so nothing is loaded.
        if not self.enabled:
            return

        device = decide_device(accelerator_options.device)
        from docling_ibm_models.document_figure_classifier_model.document_figure_classifier_predictor import (
            DocumentFigureClassifierPredictor,
        )

        model_dir = (
            Path(artifacts_path)
            if artifacts_path is not None
            else self.download_models_hf()
        )

        self.document_picture_classifier = DocumentFigureClassifierPredictor(
            artifacts_path=model_dir,
            device=device,
            num_threads=accelerator_options.num_threads,
        )

    @staticmethod
    def download_models_hf(
        local_dir: Optional[Path] = None, force: bool = False
    ) -> Path:
        """
        Fetch the "ds4sd/DocumentFigureClassifier" snapshot (revision "v1.0.0").

        Parameters
        ----------
        local_dir : Optional[Path]
            Forwarded to `snapshot_download` as `local_dir`.
        force : bool
            Forwarded to `snapshot_download` as `force_download`.

        Returns
        -------
        Path
            Location of the downloaded snapshot.
        """
        from huggingface_hub import snapshot_download
        from huggingface_hub.utils import disable_progress_bars

        disable_progress_bars()
        return Path(
            snapshot_download(
                repo_id="ds4sd/DocumentFigureClassifier",
                force_download=force,
                local_dir=local_dir,
                revision="v1.0.0",
            )
        )

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        """
        Decide whether the classifier should handle the given element.

        Parameters
        ----------
        doc : DoclingDocument
            The document containing the element.
        element : NodeItem
            The element to be checked.

        Returns
        -------
        bool
            True if the element is a PictureItem and processing is enabled; False otherwise.
        """
        if not self.enabled:
            return False
        return isinstance(element, PictureItem)

    def __call__(
        self,
        doc: DoclingDocument,
        element_batch: Iterable[NodeItem],
    ) -> Iterable[NodeItem]:
        """
        Classify a batch of pictures and record the predictions on them.

        Parameters
        ----------
        doc : DoclingDocument
            The document containing the elements; used to resolve each
            picture's image.
        element_batch : Iterable[NodeItem]
            A batch of pictures to classify.

        Returns
        -------
        Iterable[NodeItem]
            The elements of the batch, in order. When the classifier is
            enabled, a PictureClassificationData entry with the predicted
            classes has been appended to each picture's `annotations`.
        """
        if not self.enabled:
            # Pass the elements through untouched.
            yield from element_batch
            return

        pictures: List[PictureItem] = []
        picture_images: List[Image.Image] = []
        for candidate in element_batch:
            assert isinstance(candidate, PictureItem)
            pictures.append(candidate)
            picture_image = candidate.get_image(doc)
            assert picture_image is not None
            picture_images.append(picture_image)

        outputs = self.document_picture_classifier.predict(picture_images)

        for picture, predictions in zip(pictures, outputs):
            # Each prediction is indexed as (class name, confidence).
            predicted_classes = [
                PictureClassificationClass(
                    class_name=pred[0],
                    confidence=pred[1],
                )
                for pred in predictions
            ]
            picture.annotations.append(
                PictureClassificationData(
                    provenance="DocumentPictureClassifier",
                    predicted_classes=predicted_classes,
                )
            )

            yield picture
|