docling-2.69.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docling might be problematic.
- docling/__init__.py +0 -0
- docling/backend/__init__.py +0 -0
- docling/backend/abstract_backend.py +84 -0
- docling/backend/asciidoc_backend.py +443 -0
- docling/backend/csv_backend.py +125 -0
- docling/backend/docling_parse_backend.py +237 -0
- docling/backend/docling_parse_v2_backend.py +276 -0
- docling/backend/docling_parse_v4_backend.py +260 -0
- docling/backend/docx/__init__.py +0 -0
- docling/backend/docx/drawingml/utils.py +131 -0
- docling/backend/docx/latex/__init__.py +0 -0
- docling/backend/docx/latex/latex_dict.py +274 -0
- docling/backend/docx/latex/omml.py +459 -0
- docling/backend/html_backend.py +1502 -0
- docling/backend/image_backend.py +188 -0
- docling/backend/json/__init__.py +0 -0
- docling/backend/json/docling_json_backend.py +58 -0
- docling/backend/md_backend.py +618 -0
- docling/backend/mets_gbs_backend.py +399 -0
- docling/backend/msexcel_backend.py +686 -0
- docling/backend/mspowerpoint_backend.py +398 -0
- docling/backend/msword_backend.py +1663 -0
- docling/backend/noop_backend.py +51 -0
- docling/backend/pdf_backend.py +82 -0
- docling/backend/pypdfium2_backend.py +417 -0
- docling/backend/webvtt_backend.py +572 -0
- docling/backend/xml/__init__.py +0 -0
- docling/backend/xml/jats_backend.py +819 -0
- docling/backend/xml/uspto_backend.py +1905 -0
- docling/chunking/__init__.py +12 -0
- docling/cli/__init__.py +0 -0
- docling/cli/main.py +974 -0
- docling/cli/models.py +196 -0
- docling/cli/tools.py +17 -0
- docling/datamodel/__init__.py +0 -0
- docling/datamodel/accelerator_options.py +69 -0
- docling/datamodel/asr_model_specs.py +494 -0
- docling/datamodel/backend_options.py +102 -0
- docling/datamodel/base_models.py +493 -0
- docling/datamodel/document.py +699 -0
- docling/datamodel/extraction.py +39 -0
- docling/datamodel/layout_model_specs.py +91 -0
- docling/datamodel/pipeline_options.py +457 -0
- docling/datamodel/pipeline_options_asr_model.py +78 -0
- docling/datamodel/pipeline_options_vlm_model.py +136 -0
- docling/datamodel/settings.py +65 -0
- docling/datamodel/vlm_model_specs.py +365 -0
- docling/document_converter.py +559 -0
- docling/document_extractor.py +327 -0
- docling/exceptions.py +10 -0
- docling/experimental/__init__.py +5 -0
- docling/experimental/datamodel/__init__.py +1 -0
- docling/experimental/datamodel/table_crops_layout_options.py +13 -0
- docling/experimental/datamodel/threaded_layout_vlm_pipeline_options.py +45 -0
- docling/experimental/models/__init__.py +3 -0
- docling/experimental/models/table_crops_layout_model.py +114 -0
- docling/experimental/pipeline/__init__.py +1 -0
- docling/experimental/pipeline/threaded_layout_vlm_pipeline.py +439 -0
- docling/models/__init__.py +0 -0
- docling/models/base_layout_model.py +39 -0
- docling/models/base_model.py +230 -0
- docling/models/base_ocr_model.py +241 -0
- docling/models/base_table_model.py +45 -0
- docling/models/extraction/__init__.py +0 -0
- docling/models/extraction/nuextract_transformers_model.py +305 -0
- docling/models/factories/__init__.py +47 -0
- docling/models/factories/base_factory.py +122 -0
- docling/models/factories/layout_factory.py +7 -0
- docling/models/factories/ocr_factory.py +11 -0
- docling/models/factories/picture_description_factory.py +11 -0
- docling/models/factories/table_factory.py +7 -0
- docling/models/picture_description_base_model.py +149 -0
- docling/models/plugins/__init__.py +0 -0
- docling/models/plugins/defaults.py +60 -0
- docling/models/stages/__init__.py +0 -0
- docling/models/stages/code_formula/__init__.py +0 -0
- docling/models/stages/code_formula/code_formula_model.py +342 -0
- docling/models/stages/layout/__init__.py +0 -0
- docling/models/stages/layout/layout_model.py +249 -0
- docling/models/stages/ocr/__init__.py +0 -0
- docling/models/stages/ocr/auto_ocr_model.py +132 -0
- docling/models/stages/ocr/easyocr_model.py +200 -0
- docling/models/stages/ocr/ocr_mac_model.py +145 -0
- docling/models/stages/ocr/rapid_ocr_model.py +328 -0
- docling/models/stages/ocr/tesseract_ocr_cli_model.py +331 -0
- docling/models/stages/ocr/tesseract_ocr_model.py +262 -0
- docling/models/stages/page_assemble/__init__.py +0 -0
- docling/models/stages/page_assemble/page_assemble_model.py +156 -0
- docling/models/stages/page_preprocessing/__init__.py +0 -0
- docling/models/stages/page_preprocessing/page_preprocessing_model.py +145 -0
- docling/models/stages/picture_classifier/__init__.py +0 -0
- docling/models/stages/picture_classifier/document_picture_classifier.py +246 -0
- docling/models/stages/picture_description/__init__.py +0 -0
- docling/models/stages/picture_description/picture_description_api_model.py +66 -0
- docling/models/stages/picture_description/picture_description_vlm_model.py +123 -0
- docling/models/stages/reading_order/__init__.py +0 -0
- docling/models/stages/reading_order/readingorder_model.py +431 -0
- docling/models/stages/table_structure/__init__.py +0 -0
- docling/models/stages/table_structure/table_structure_model.py +305 -0
- docling/models/utils/__init__.py +0 -0
- docling/models/utils/generation_utils.py +157 -0
- docling/models/utils/hf_model_download.py +45 -0
- docling/models/vlm_pipeline_models/__init__.py +1 -0
- docling/models/vlm_pipeline_models/api_vlm_model.py +180 -0
- docling/models/vlm_pipeline_models/hf_transformers_model.py +391 -0
- docling/models/vlm_pipeline_models/mlx_model.py +325 -0
- docling/models/vlm_pipeline_models/vllm_model.py +344 -0
- docling/pipeline/__init__.py +0 -0
- docling/pipeline/asr_pipeline.py +431 -0
- docling/pipeline/base_extraction_pipeline.py +72 -0
- docling/pipeline/base_pipeline.py +326 -0
- docling/pipeline/extraction_vlm_pipeline.py +207 -0
- docling/pipeline/legacy_standard_pdf_pipeline.py +262 -0
- docling/pipeline/simple_pipeline.py +55 -0
- docling/pipeline/standard_pdf_pipeline.py +859 -0
- docling/pipeline/threaded_standard_pdf_pipeline.py +5 -0
- docling/pipeline/vlm_pipeline.py +416 -0
- docling/py.typed +1 -0
- docling/utils/__init__.py +0 -0
- docling/utils/accelerator_utils.py +97 -0
- docling/utils/api_image_request.py +205 -0
- docling/utils/deepseekocr_utils.py +388 -0
- docling/utils/export.py +146 -0
- docling/utils/glm_utils.py +361 -0
- docling/utils/layout_postprocessor.py +683 -0
- docling/utils/locks.py +3 -0
- docling/utils/model_downloader.py +168 -0
- docling/utils/ocr_utils.py +69 -0
- docling/utils/orientation.py +65 -0
- docling/utils/profiling.py +65 -0
- docling/utils/utils.py +65 -0
- docling/utils/visualization.py +85 -0
- docling-2.69.0.dist-info/METADATA +237 -0
- docling-2.69.0.dist-info/RECORD +138 -0
- docling-2.69.0.dist-info/WHEEL +5 -0
- docling-2.69.0.dist-info/entry_points.txt +6 -0
- docling-2.69.0.dist-info/licenses/LICENSE +21 -0
- docling-2.69.0.dist-info/top_level.txt +1 -0
docling/models/stages/table_structure/table_structure_model.py
@@ -0,0 +1,305 @@
+import copy
+import warnings
+from collections.abc import Iterable, Sequence
+from pathlib import Path
+from typing import Optional
+
+import numpy
+from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
+from docling_core.types.doc.page import (
+    BoundingRectangle,
+    TextCellUnit,
+)
+from PIL import ImageDraw
+
+from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
+from docling.datamodel.base_models import Page, Table, TableStructurePrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import (
+    TableFormerMode,
+    TableStructureOptions,
+)
+from docling.datamodel.settings import settings
+from docling.models.base_table_model import BaseTableStructureModel
+from docling.models.utils.hf_model_download import download_hf_model
+from docling.utils.accelerator_utils import decide_device
+from docling.utils.profiling import TimeRecorder
+
+
+class TableStructureModel(BaseTableStructureModel):
+    _model_repo_folder = "docling-project--docling-models"
+    _model_path = "model_artifacts/tableformer"
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        options: TableStructureOptions,
+        accelerator_options: AcceleratorOptions,
+    ):
+        self.options = options
+        self.do_cell_matching = self.options.do_cell_matching
+        self.mode = self.options.mode
+
+        self.enabled = enabled
+        if self.enabled:
+            if artifacts_path is None:
+                artifacts_path = self.download_models() / self._model_path
+            else:
+                # will become the default in the future
+                if (artifacts_path / self._model_repo_folder).exists():
+                    artifacts_path = (
+                        artifacts_path / self._model_repo_folder / self._model_path
+                    )
+                elif (artifacts_path / self._model_path).exists():
+                    warnings.warn(
+                        "The usage of artifacts_path containing directly "
+                        f"{self._model_path} is deprecated. Please point "
+                        "the artifacts_path to the parent containing "
+                        f"the {self._model_repo_folder} folder.",
+                        DeprecationWarning,
+                        stacklevel=3,
+                    )
+                    artifacts_path = artifacts_path / self._model_path
+
+            if self.mode == TableFormerMode.ACCURATE:
+                artifacts_path = artifacts_path / "accurate"
+            else:
+                artifacts_path = artifacts_path / "fast"
+
+            # Third Party
+            import docling_ibm_models.tableformer.common as c
+            from docling_ibm_models.tableformer.data_management.tf_predictor import (
+                TFPredictor,
+            )
+
+            device = decide_device(accelerator_options.device)
+
+            # Disable MPS here, until we know why it makes things slower.
+            if device == AcceleratorDevice.MPS.value:
+                device = AcceleratorDevice.CPU.value
+
+            self.tm_config = c.read_config(f"{artifacts_path}/tm_config.json")
+            self.tm_config["model"]["save_dir"] = artifacts_path
+            self.tm_model_type = self.tm_config["model"]["type"]
+
+            self.tf_predictor = TFPredictor(
+                self.tm_config, device, accelerator_options.num_threads
+            )
+            self.scale = 2.0  # Scale up table input images to 144 dpi
+
+    @classmethod
+    def get_options_type(cls) -> type[TableStructureOptions]:
+        return TableStructureOptions
+
+    @staticmethod
+    def download_models(
+        local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
+    ) -> Path:
+        return download_hf_model(
+            repo_id="docling-project/docling-models",
+            revision="v2.3.0",
+            local_dir=local_dir,
+            force=force,
+            progress=progress,
+        )
+
+    def draw_table_and_cells(
+        self,
+        conv_res: ConversionResult,
+        page: Page,
+        tbl_list: Iterable[Table],
+        show: bool = False,
+    ):
+        assert page._backend is not None
+        assert page.size is not None
+
+        image = (
+            page._backend.get_page_image()
+        )  # make new image to avoid drawing on the saved ones
+
+        scale_x = image.width / page.size.width
+        scale_y = image.height / page.size.height
+
+        draw = ImageDraw.Draw(image)
+
+        for table_element in tbl_list:
+            x0, y0, x1, y1 = table_element.cluster.bbox.as_tuple()
+            y0 *= scale_y
+            y1 *= scale_y
+            x0 *= scale_x
+            x1 *= scale_x
+
+            draw.rectangle([(x0, y0), (x1, y1)], outline="red")
+
+            for cell in table_element.cluster.cells:
+                x0, y0, x1, y1 = cell.rect.to_bounding_box().as_tuple()
+                x0 *= scale_x
+                x1 *= scale_x
+                y0 *= scale_y
+                y1 *= scale_y
+
+                draw.rectangle([(x0, y0), (x1, y1)], outline="green")
+
+            for tc in table_element.table_cells:
+                if tc.bbox is not None:
+                    x0, y0, x1, y1 = tc.bbox.as_tuple()
+                    x0 *= scale_x
+                    x1 *= scale_x
+                    y0 *= scale_y
+                    y1 *= scale_y
+
+                    if tc.column_header:
+                        width = 3
+                    else:
+                        width = 1
+                    draw.rectangle([(x0, y0), (x1, y1)], outline="blue", width=width)
+                    draw.text(
+                        (x0 + 3, y0 + 3),
+                        text=f"{tc.start_row_offset_idx}, {tc.start_col_offset_idx}",
+                        fill="black",
+                    )
+        if show:
+            image.show()
+        else:
+            out_path: Path = (
+                Path(settings.debug.debug_output_path)
+                / f"debug_{conv_res.input.file.stem}"
+            )
+            out_path.mkdir(parents=True, exist_ok=True)
+
+            out_file = out_path / f"table_struct_page_{page.page_no:05}.png"
+            image.save(str(out_file), format="png")
+
+    def predict_tables(
+        self,
+        conv_res: ConversionResult,
+        pages: Sequence[Page],
+    ) -> Sequence[TableStructurePrediction]:
+        pages = list(pages)
+        predictions: list[TableStructurePrediction] = []
+
+        for page in pages:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                existing_prediction = (
+                    page.predictions.tablestructure or TableStructurePrediction()
+                )
+                page.predictions.tablestructure = existing_prediction
+                predictions.append(existing_prediction)
+                continue
+
+            with TimeRecorder(conv_res, "table_structure"):
+                assert page.predictions.layout is not None
+                assert page.size is not None
+
+                table_prediction = TableStructurePrediction()
+                page.predictions.tablestructure = table_prediction
+
+                in_tables = [
+                    (
+                        cluster,
+                        [
+                            round(cluster.bbox.l) * self.scale,
+                            round(cluster.bbox.t) * self.scale,
+                            round(cluster.bbox.r) * self.scale,
+                            round(cluster.bbox.b) * self.scale,
+                        ],
+                    )
+                    for cluster in page.predictions.layout.clusters
+                    if cluster.label
+                    in [DocItemLabel.TABLE, DocItemLabel.DOCUMENT_INDEX]
+                ]
+                if not in_tables:
+                    predictions.append(table_prediction)
+                    continue
+
+                page_input = {
+                    "width": page.size.width * self.scale,
+                    "height": page.size.height * self.scale,
+                    "image": numpy.asarray(page.get_image(scale=self.scale)),
+                }
+
+                for table_cluster, tbl_box in in_tables:
+                    # Check if word-level cells are available from backend:
+                    sp = page._backend.get_segmented_page()
+                    if sp is not None:
+                        tcells = sp.get_cells_in_bbox(
+                            cell_unit=TextCellUnit.WORD,
+                            bbox=table_cluster.bbox,
+                        )
+                        if len(tcells) == 0:
+                            # In case word-level cells yield empty
+                            tcells = table_cluster.cells
+                    else:
+                        # Otherwise - we use normal (line/phrase) cells
+                        tcells = table_cluster.cells
+                    tokens = []
+                    for c in tcells:
+                        # Only allow non empty strings (spaces) into the cells of a table
+                        if len(c.text.strip()) > 0:
+                            new_cell = copy.deepcopy(c)
+                            new_cell.rect = BoundingRectangle.from_bounding_box(
+                                new_cell.rect.to_bounding_box().scaled(scale=self.scale)
+                            )
+                            tokens.append(
+                                {
+                                    "id": new_cell.index,
+                                    "text": new_cell.text,
+                                    "bbox": new_cell.rect.to_bounding_box().model_dump(),
+                                }
+                            )
+                    page_input["tokens"] = tokens
+
+                    tf_output = self.tf_predictor.multi_table_predict(
+                        page_input, [tbl_box], do_matching=self.do_cell_matching
+                    )
+                    table_out = tf_output[0]
+                    table_cells = []
+                    for element in table_out["tf_responses"]:
+                        if not self.do_cell_matching:
+                            the_bbox = BoundingBox.model_validate(
+                                element["bbox"]
+                            ).scaled(1 / self.scale)
+                            text_piece = page._backend.get_text_in_rect(the_bbox)
+                            element["bbox"]["token"] = text_piece
+
+                        tc = TableCell.model_validate(element)
+                        if tc.bbox is not None:
+                            tc.bbox = tc.bbox.scaled(1 / self.scale)
+                        table_cells.append(tc)
+
+                    assert "predict_details" in table_out
+
+                    # Retrieving cols/rows, after post processing:
+                    num_rows = table_out["predict_details"].get("num_rows", 0)
+                    num_cols = table_out["predict_details"].get("num_cols", 0)
+                    otsl_seq = (
+                        table_out["predict_details"]
+                        .get("prediction", {})
+                        .get("rs_seq", [])
+                    )
+
+                    tbl = Table(
+                        otsl_seq=otsl_seq,
+                        table_cells=table_cells,
+                        num_rows=num_rows,
+                        num_cols=num_cols,
+                        id=table_cluster.id,
+                        page_no=page.page_no,
+                        cluster=table_cluster,
+                        label=table_cluster.label,
+                    )
+
+                    table_prediction.table_map[table_cluster.id] = tbl
+
+                    if settings.debug.visualize_tables:
+                        self.draw_table_and_cells(
+                            conv_res,
+                            page,
+                            page.predictions.tablestructure.table_map.values(),
+                        )
+
+            predictions.append(table_prediction)
+
+        return predictions
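For context, a minimal usage sketch (not part of this diff): the TableStructureModel above is normally driven through the standard PDF pipeline, where TableStructureOptions selects the TableFormer mode and cell matching. The class and option names below come from the docling public API shown in this release; the input path is a placeholder.

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions, TableFormerMode
from docling.document_converter import DocumentConverter, PdfFormatOption

# Select the "accurate" TableFormer checkpoint instead of the default "fast" one.
pipeline_options = PdfPipelineOptions(do_table_structure=True)
pipeline_options.table_structure_options.mode = TableFormerMode.ACCURATE
pipeline_options.table_structure_options.do_cell_matching = True

converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert("example.pdf")  # placeholder input file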
File without changes
docling/models/utils/generation_utils.py
@@ -0,0 +1,157 @@
+import logging
+import re
+import sys
+from abc import abstractmethod
+from typing import List
+
+from transformers import StoppingCriteria
+
+_log = logging.getLogger(__name__)
+
+
+class GenerationStopper:
+    """
+    Base interface for stopping logic.
+    - should_stop(s): True to stop given the current decoded text window.
+    - lookback_tokens(): how many tokens should be considered (default: sys.maxsize).
+    """
+
+    @abstractmethod
+    def should_stop(self, s: str) -> bool:
+        pass
+
+    def lookback_tokens(self) -> int:
+        return sys.maxsize
+
+
+class DocTagsRepetitionStopper(GenerationStopper):
+    """
+    Detects repetitive <tag>...<loc_x><loc_y><loc_w><loc_h>text</tag> blocks,
+    but only when repeats are **consecutive** and both tag & inner text are identical.
+
+    Performance:
+    - Heavy check runs every N calls (default 32).
+    - Only decodes the last LOOKBACK_TOKENS tokens per sequence (default 200).
+    """
+
+    def __init__(self, *, N: int = 32, lookback_tokens: int = 200):
+        self.N = max(1, int(N))
+        self._lookback_tokens = max(1, int(lookback_tokens))
+        self._call_count = 0
+
+        # <tag> ... <loc_x><loc_y><loc_w><loc_h> text ... </tag>
+        self._PATTERN = re.compile(
+            r"""
+            <(?P<tag>[a-zA-Z0-9_]+)>\s*
+            (?P<prefix>.*?)?
+            <loc_(?P<x>\d+)><loc_(?P<y>\d+)><loc_(?P<w>\d+)><loc_(?P<h>\d+)>
+            (?P<text>.*?)
+            </(?P=tag)>
+            """,
+            re.DOTALL | re.VERBOSE,
+        )
+
+    # --- small helper ---
+    def _regular(self, vals: List[int]) -> bool:
+        """3+ strictly increasing values with ~regular spacing (±20%)."""
+        if len(vals) < 3:
+            return False
+        diffs = [b - a for a, b in zip(vals, vals[1:])]
+        if any(d <= 0 for d in diffs):
+            return False
+        mean = sum(diffs) / len(diffs)
+        tol = 0.2 * mean
+        return all(abs(d - mean) <= tol for d in diffs)
+
+    def should_stop(self, s: str) -> bool:
+        """
+        Trip only on **consecutive** runs (no other matched blocks between) of ≥3 items
+        with the same <tag> and identical inner text, where within that run we see:
+        - any exact duplicate (x,y,w,h), or
+        - stable X/W with regular Y progression, or
+        - stable Y/H with regular X progression.
+        """
+        # Stream matches and evaluate runs on-the-fly to stay compact and fast.
+        prev_tag = prev_text = None
+        run = []  # list of (x,y,w,h)
+
+        def run_repetitive(boxes: List[tuple]) -> bool:
+            if len(boxes) < 3:
+                return False
+            # duplicates?
+            if len(set(boxes)) < len(boxes):
+                return True
+            xs, ys, ws, hs = zip(*boxes)
+            x_stable = all(x == xs[0] for x in xs)
+            y_stable = all(y == ys[0] for y in ys)
+            w_stable = all(w == ws[0] for w in ws)
+            h_stable = all(h == hs[0] for h in hs)
+            # horizontal (down the page): X/W stable, Y regular
+            if (x_stable or w_stable) and self._regular(list(ys)):
+                return True
+            # vertical (across): Y/H stable, X regular
+            if (y_stable or h_stable) and self._regular(list(xs)):
+                return True
+            return False
+
+        for m in self._PATTERN.finditer(s):
+            tag, text = m.group("tag"), m.group("text")
+            box = (
+                int(m.group("x")),
+                int(m.group("y")),
+                int(m.group("w")),
+                int(m.group("h")),
+            )
+
+            if prev_tag == tag and prev_text == text:
+                run.append(box)  # consecutive same-tag+text
+            else:
+                # evaluate previous run before starting a new one
+                if run_repetitive(run):
+                    return True
+                prev_tag, prev_text = tag, text
+                run = [box]
+
+        # check the last run
+        return run_repetitive(run)
+
+
+class HFStoppingCriteriaWrapper(StoppingCriteria):
+    """
+    Adapts any GenerationStopper to HuggingFace Transformers.
+    Decodes exactly min(seq_len, stopper.lookback_tokens()) tokens from the end.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        stopper: GenerationStopper,
+        *,
+        skip_special_tokens: bool = False,
+    ):
+        self.tokenizer = tokenizer
+        self.stopper = stopper
+        self.skip_special_tokens = skip_special_tokens
+
+    def __call__(self, input_ids, scores, **kwargs) -> bool:
+        lb = max(1, int(self.stopper.lookback_tokens()))
+        for seq in input_ids:  # (batch, seq_len)
+            window = seq[-lb:]  # slicing handles lb > len(seq)
+            try:
+                text = self.tokenizer.decode(
+                    window, skip_special_tokens=self.skip_special_tokens
+                )
+            except Exception as e:
+                _log.info(f"Decoding failed for stopping check: {e}")
+                continue
+
+            try:
+                if self.stopper.should_stop(text):
+                    _log.info(
+                        "HF wrapper: stopping due to TextStopper.should_stop==True"
+                    )
+                    return True
+            except Exception as e:
+                _log.info(f"Error in TextStopper.should_stop: {e}")
+                continue
+        return False
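As an illustration of the GenerationStopper contract above (a sketch, not code from this release): a custom stopper only needs should_stop and, optionally, lookback_tokens; HFStoppingCriteriaWrapper then adapts it to a transformers StoppingCriteriaList. The StopOnClosingTag class and the tokenizer id below are hypothetical.

from transformers import AutoTokenizer, StoppingCriteriaList

from docling.models.utils.generation_utils import (
    GenerationStopper,
    HFStoppingCriteriaWrapper,
)


class StopOnClosingTag(GenerationStopper):
    """Hypothetical stopper: abort once the decoded window contains </doctag>."""

    def should_stop(self, s: str) -> bool:
        return "</doctag>" in s

    def lookback_tokens(self) -> int:
        return 64  # only the last 64 tokens are decoded per check


tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
stopping = StoppingCriteriaList(
    [HFStoppingCriteriaWrapper(tokenizer, StopOnClosingTag())]
)
# `stopping` can then be passed as stopping_criteria=... to a transformers generate() call.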
docling/models/utils/hf_model_download.py
@@ -0,0 +1,45 @@
+import logging
+from pathlib import Path
+from typing import Optional
+
+_log = logging.getLogger(__name__)
+
+
+def download_hf_model(
+    repo_id: str,
+    local_dir: Optional[Path] = None,
+    force: bool = False,
+    progress: bool = False,
+    revision: Optional[str] = None,
+) -> Path:
+    from huggingface_hub import snapshot_download
+    from huggingface_hub.utils import disable_progress_bars
+
+    if not progress:
+        disable_progress_bars()
+    download_path = snapshot_download(
+        repo_id=repo_id,
+        force_download=force,
+        local_dir=local_dir,
+        revision=revision,
+    )
+
+    return Path(download_path)
+
+
+class HuggingFaceModelDownloadMixin:
+    @staticmethod
+    def download_models(
+        repo_id: str,
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+        revision: Optional[str] = None,
+    ) -> Path:
+        return download_hf_model(
+            repo_id=repo_id,
+            local_dir=local_dir,
+            force=force,
+            progress=progress,
+            revision=revision,
+        )
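A small prefetch sketch tying the helpers above together (assumptions: network access and a writable models/ directory; the repo id and revision are the ones used by TableStructureModel.download_models earlier in this diff).

from pathlib import Path

from docling.models.utils.hf_model_download import download_hf_model

# Fetch the TableFormer artifacts once, e.g. for offline use; the local_dir layout
# mirroring the _model_repo_folder name is an assumption for illustration.
artifacts = download_hf_model(
    repo_id="docling-project/docling-models",
    revision="v2.3.0",
    local_dir=Path("models/docling-project--docling-models"),
    progress=True,
)
print(f"models downloaded to {artifacts}")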
@@ -0,0 +1 @@
+
docling/models/vlm_pipeline_models/api_vlm_model.py
@@ -0,0 +1,180 @@
+from collections.abc import Iterable
+from concurrent.futures import ThreadPoolExecutor
+from typing import Union
+
+import numpy as np
+from PIL.Image import Image
+
+from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
+from docling.exceptions import OperationNotAllowed
+from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import GenerationStopper
+from docling.utils.api_image_request import (
+    api_image_request,
+    api_image_request_streaming,
+)
+from docling.utils.profiling import TimeRecorder
+
+
+class ApiVlmModel(BaseVlmPageModel):
+    # Override the vlm_options type annotation from BaseVlmPageModel
+    vlm_options: ApiVlmOptions  # type: ignore[assignment]
+
+    def __init__(
+        self,
+        enabled: bool,
+        enable_remote_services: bool,
+        vlm_options: ApiVlmOptions,
+    ):
+        self.enabled = enabled
+        self.vlm_options = vlm_options
+        if self.enabled:
+            if not enable_remote_services:
+                raise OperationNotAllowed(
+                    "Connections to remote services is only allowed when set explicitly. "
+                    "pipeline_options.enable_remote_services=True, or using the CLI "
+                    "--enable-remote-services."
+                )
+
+            self.timeout = self.vlm_options.timeout
+            self.concurrency = self.vlm_options.concurrency
+            self.params = {
+                **self.vlm_options.params,
+                "temperature": self.vlm_options.temperature,
+            }
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        original_order = page_list[:]
+        valid_pages = []
+
+        for page in page_list:
+            assert page._backend is not None
+            if page._backend.is_valid():
+                valid_pages.append(page)
+
+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, "vlm"):
+                # Prepare images and prompts for batch processing
+                images = []
+                prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
+                    hi_res_image = page.get_image(
+                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                    )
+
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)
+                        prompt = self._build_prompt_safe(page)
+                        prompts.append(prompt)
+                        pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    with TimeRecorder(conv_res, "vlm_inference"):
+                        predictions = list(self.process_images(images, prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield pages preserving original order
+        for page in original_order:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata."""
+        images = list(image_batch)
+
+        # Handle prompt parameter
+        if isinstance(prompt, str):
+            prompts = [prompt] * len(images)
+        elif isinstance(prompt, list):
+            if len(prompt) != len(images):
+                raise ValueError(
+                    f"Prompt list length ({len(prompt)}) must match image count ({len(images)})"
+                )
+            prompts = prompt
+
+        def _process_single_image(image_prompt_pair):
+            image, prompt_text = image_prompt_pair
+
+            # Convert numpy array to PIL Image if needed
+            if isinstance(image, np.ndarray):
+                if image.ndim == 3 and image.shape[2] in [3, 4]:
+                    from PIL import Image as PILImage
+
+                    image = PILImage.fromarray(image.astype(np.uint8))
+                elif image.ndim == 2:
+                    from PIL import Image as PILImage
+
+                    image = PILImage.fromarray(image.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {image.shape}")
+
+            # Ensure image is in RGB mode
+            if image.mode != "RGB":
+                image = image.convert("RGB")
+
+            stop_reason = VlmStopReason.UNSPECIFIED
+
+            if self.vlm_options.custom_stopping_criteria:
+                # Instantiate any GenerationStopper classes before passing to streaming
+                instantiated_stoppers = []
+                for criteria in self.vlm_options.custom_stopping_criteria:
+                    if isinstance(criteria, GenerationStopper):
+                        instantiated_stoppers.append(criteria)
+                    elif isinstance(criteria, type) and issubclass(
+                        criteria, GenerationStopper
+                    ):
+                        instantiated_stoppers.append(criteria())
+                    # Skip non-GenerationStopper criteria (should have been caught in validation)
+
+                # Streaming path with early abort support
+                page_tags, num_tokens = api_image_request_streaming(
+                    image=image,
+                    prompt=prompt_text,
+                    url=self.vlm_options.url,
+                    timeout=self.timeout,
+                    headers=self.vlm_options.headers,
+                    generation_stoppers=instantiated_stoppers,
+                    **self.params,
+                )
+            else:
+                # Non-streaming fallback (existing behavior)
+                page_tags, num_tokens, stop_reason = api_image_request(
+                    image=image,
+                    prompt=prompt_text,
+                    url=self.vlm_options.url,
+                    timeout=self.timeout,
+                    headers=self.vlm_options.headers,
+                    **self.params,
+                )
+
+            page_tags = self.vlm_options.decode_response(page_tags)
+            input_prompt = prompt_text if self.vlm_options.track_input_prompt else None
+            return VlmPrediction(
+                text=page_tags,
+                num_tokens=num_tokens,
+                stop_reason=stop_reason,
+                input_prompt=input_prompt,
+            )
+
+        with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
+            yield from executor.map(_process_single_image, zip(images, prompts))
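Finally, a hedged sketch of how ApiVlmModel is typically reached from user code: the VLM pipeline with ApiVlmOptions pointing at an OpenAI-compatible endpoint, and a DocTagsRepetitionStopper from generation_utils attached via custom_stopping_criteria so the streaming early-abort path above is used. The endpoint URL and model name are placeholders, field defaults are assumed where not shown in this diff, and enable_remote_services must be set explicitly, as enforced in __init__.

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions, ResponseFormat
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.models.utils.generation_utils import DocTagsRepetitionStopper
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions(enable_remote_services=True)
pipeline_options.vlm_options = ApiVlmOptions(
    url="http://localhost:8000/v1/chat/completions",  # placeholder endpoint
    params={"model": "my-vlm"},  # placeholder model name
    prompt="Convert this page to DocTags.",
    timeout=90,
    response_format=ResponseFormat.DOCTAGS,
    custom_stopping_criteria=[DocTagsRepetitionStopper],  # class or instance, per process_images above
)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
        )
    }
)
result = converter.convert("example.pdf")  # placeholder input file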