docling-2.69.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of docling might be problematic.
- docling/__init__.py +0 -0
- docling/backend/__init__.py +0 -0
- docling/backend/abstract_backend.py +84 -0
- docling/backend/asciidoc_backend.py +443 -0
- docling/backend/csv_backend.py +125 -0
- docling/backend/docling_parse_backend.py +237 -0
- docling/backend/docling_parse_v2_backend.py +276 -0
- docling/backend/docling_parse_v4_backend.py +260 -0
- docling/backend/docx/__init__.py +0 -0
- docling/backend/docx/drawingml/utils.py +131 -0
- docling/backend/docx/latex/__init__.py +0 -0
- docling/backend/docx/latex/latex_dict.py +274 -0
- docling/backend/docx/latex/omml.py +459 -0
- docling/backend/html_backend.py +1502 -0
- docling/backend/image_backend.py +188 -0
- docling/backend/json/__init__.py +0 -0
- docling/backend/json/docling_json_backend.py +58 -0
- docling/backend/md_backend.py +618 -0
- docling/backend/mets_gbs_backend.py +399 -0
- docling/backend/msexcel_backend.py +686 -0
- docling/backend/mspowerpoint_backend.py +398 -0
- docling/backend/msword_backend.py +1663 -0
- docling/backend/noop_backend.py +51 -0
- docling/backend/pdf_backend.py +82 -0
- docling/backend/pypdfium2_backend.py +417 -0
- docling/backend/webvtt_backend.py +572 -0
- docling/backend/xml/__init__.py +0 -0
- docling/backend/xml/jats_backend.py +819 -0
- docling/backend/xml/uspto_backend.py +1905 -0
- docling/chunking/__init__.py +12 -0
- docling/cli/__init__.py +0 -0
- docling/cli/main.py +974 -0
- docling/cli/models.py +196 -0
- docling/cli/tools.py +17 -0
- docling/datamodel/__init__.py +0 -0
- docling/datamodel/accelerator_options.py +69 -0
- docling/datamodel/asr_model_specs.py +494 -0
- docling/datamodel/backend_options.py +102 -0
- docling/datamodel/base_models.py +493 -0
- docling/datamodel/document.py +699 -0
- docling/datamodel/extraction.py +39 -0
- docling/datamodel/layout_model_specs.py +91 -0
- docling/datamodel/pipeline_options.py +457 -0
- docling/datamodel/pipeline_options_asr_model.py +78 -0
- docling/datamodel/pipeline_options_vlm_model.py +136 -0
- docling/datamodel/settings.py +65 -0
- docling/datamodel/vlm_model_specs.py +365 -0
- docling/document_converter.py +559 -0
- docling/document_extractor.py +327 -0
- docling/exceptions.py +10 -0
- docling/experimental/__init__.py +5 -0
- docling/experimental/datamodel/__init__.py +1 -0
- docling/experimental/datamodel/table_crops_layout_options.py +13 -0
- docling/experimental/datamodel/threaded_layout_vlm_pipeline_options.py +45 -0
- docling/experimental/models/__init__.py +3 -0
- docling/experimental/models/table_crops_layout_model.py +114 -0
- docling/experimental/pipeline/__init__.py +1 -0
- docling/experimental/pipeline/threaded_layout_vlm_pipeline.py +439 -0
- docling/models/__init__.py +0 -0
- docling/models/base_layout_model.py +39 -0
- docling/models/base_model.py +230 -0
- docling/models/base_ocr_model.py +241 -0
- docling/models/base_table_model.py +45 -0
- docling/models/extraction/__init__.py +0 -0
- docling/models/extraction/nuextract_transformers_model.py +305 -0
- docling/models/factories/__init__.py +47 -0
- docling/models/factories/base_factory.py +122 -0
- docling/models/factories/layout_factory.py +7 -0
- docling/models/factories/ocr_factory.py +11 -0
- docling/models/factories/picture_description_factory.py +11 -0
- docling/models/factories/table_factory.py +7 -0
- docling/models/picture_description_base_model.py +149 -0
- docling/models/plugins/__init__.py +0 -0
- docling/models/plugins/defaults.py +60 -0
- docling/models/stages/__init__.py +0 -0
- docling/models/stages/code_formula/__init__.py +0 -0
- docling/models/stages/code_formula/code_formula_model.py +342 -0
- docling/models/stages/layout/__init__.py +0 -0
- docling/models/stages/layout/layout_model.py +249 -0
- docling/models/stages/ocr/__init__.py +0 -0
- docling/models/stages/ocr/auto_ocr_model.py +132 -0
- docling/models/stages/ocr/easyocr_model.py +200 -0
- docling/models/stages/ocr/ocr_mac_model.py +145 -0
- docling/models/stages/ocr/rapid_ocr_model.py +328 -0
- docling/models/stages/ocr/tesseract_ocr_cli_model.py +331 -0
- docling/models/stages/ocr/tesseract_ocr_model.py +262 -0
- docling/models/stages/page_assemble/__init__.py +0 -0
- docling/models/stages/page_assemble/page_assemble_model.py +156 -0
- docling/models/stages/page_preprocessing/__init__.py +0 -0
- docling/models/stages/page_preprocessing/page_preprocessing_model.py +145 -0
- docling/models/stages/picture_classifier/__init__.py +0 -0
- docling/models/stages/picture_classifier/document_picture_classifier.py +246 -0
- docling/models/stages/picture_description/__init__.py +0 -0
- docling/models/stages/picture_description/picture_description_api_model.py +66 -0
- docling/models/stages/picture_description/picture_description_vlm_model.py +123 -0
- docling/models/stages/reading_order/__init__.py +0 -0
- docling/models/stages/reading_order/readingorder_model.py +431 -0
- docling/models/stages/table_structure/__init__.py +0 -0
- docling/models/stages/table_structure/table_structure_model.py +305 -0
- docling/models/utils/__init__.py +0 -0
- docling/models/utils/generation_utils.py +157 -0
- docling/models/utils/hf_model_download.py +45 -0
- docling/models/vlm_pipeline_models/__init__.py +1 -0
- docling/models/vlm_pipeline_models/api_vlm_model.py +180 -0
- docling/models/vlm_pipeline_models/hf_transformers_model.py +391 -0
- docling/models/vlm_pipeline_models/mlx_model.py +325 -0
- docling/models/vlm_pipeline_models/vllm_model.py +344 -0
- docling/pipeline/__init__.py +0 -0
- docling/pipeline/asr_pipeline.py +431 -0
- docling/pipeline/base_extraction_pipeline.py +72 -0
- docling/pipeline/base_pipeline.py +326 -0
- docling/pipeline/extraction_vlm_pipeline.py +207 -0
- docling/pipeline/legacy_standard_pdf_pipeline.py +262 -0
- docling/pipeline/simple_pipeline.py +55 -0
- docling/pipeline/standard_pdf_pipeline.py +859 -0
- docling/pipeline/threaded_standard_pdf_pipeline.py +5 -0
- docling/pipeline/vlm_pipeline.py +416 -0
- docling/py.typed +1 -0
- docling/utils/__init__.py +0 -0
- docling/utils/accelerator_utils.py +97 -0
- docling/utils/api_image_request.py +205 -0
- docling/utils/deepseekocr_utils.py +388 -0
- docling/utils/export.py +146 -0
- docling/utils/glm_utils.py +361 -0
- docling/utils/layout_postprocessor.py +683 -0
- docling/utils/locks.py +3 -0
- docling/utils/model_downloader.py +168 -0
- docling/utils/ocr_utils.py +69 -0
- docling/utils/orientation.py +65 -0
- docling/utils/profiling.py +65 -0
- docling/utils/utils.py +65 -0
- docling/utils/visualization.py +85 -0
- docling-2.69.0.dist-info/METADATA +237 -0
- docling-2.69.0.dist-info/RECORD +138 -0
- docling-2.69.0.dist-info/WHEEL +5 -0
- docling-2.69.0.dist-info/entry_points.txt +6 -0
- docling-2.69.0.dist-info/licenses/LICENSE +21 -0
- docling-2.69.0.dist-info/top_level.txt +1 -0
docling/models/vlm_pipeline_models/mlx_model.py
@@ -0,0 +1,325 @@
import logging
import sys
import threading
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Optional, Union

import numpy as np
from PIL.Image import Image
from transformers import StoppingCriteria

from docling.datamodel.accelerator_options import (
    AcceleratorOptions,
)
from docling.datamodel.base_models import (
    Page,
    VlmPrediction,
    VlmPredictionToken,
    VlmStopReason,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.generation_utils import GenerationStopper
from docling.models.utils.hf_model_download import (
    HuggingFaceModelDownloadMixin,
)
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)

# Global lock for MLX model calls - MLX models are not thread-safe
# All MLX models share this lock to prevent concurrent MLX operations
_MLX_GLOBAL_LOCK = threading.Lock()


class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        accelerator_options: AcceleratorOptions,
        vlm_options: InlineVlmOptions,
    ):
        self.enabled = enabled

        self.vlm_options = vlm_options
        self.max_tokens = vlm_options.max_new_tokens
        self.temperature = vlm_options.temperature

        if self.enabled:
            try:
                from mlx_vlm import generate, load, stream_generate  # type: ignore
                from mlx_vlm.prompt_utils import apply_chat_template  # type: ignore
                from mlx_vlm.utils import load_config  # type: ignore
            except ImportError:
                raise ImportError(
                    "mlx-vlm is not installed. Please install it via `pip install mlx-vlm` to use MLX VLM models."
                )

            repo_cache_folder = vlm_options.repo_id.replace("/", "--")

            self.apply_chat_template = apply_chat_template
            self.stream_generate = stream_generate

            # PARAMETERS:
            if artifacts_path is None:
                artifacts_path = self.download_models(
                    self.vlm_options.repo_id,
                    revision=self.vlm_options.revision,
                )
            elif (artifacts_path / repo_cache_folder).exists():
                artifacts_path = artifacts_path / repo_cache_folder

            ## Load the model
            self.vlm_model, self.processor = load(artifacts_path)
            self.config = load_config(artifacts_path)

            # Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
            if self.vlm_options.custom_stopping_criteria:
                for criteria in self.vlm_options.custom_stopping_criteria:
                    if isinstance(criteria, StoppingCriteria):
                        raise ValueError(
                            f"MLX models do not support HuggingFace StoppingCriteria instances. "
                            f"Found {type(criteria).__name__}. Use GenerationStopper instead."
                        )
                    elif isinstance(criteria, type) and issubclass(
                        criteria, StoppingCriteria
                    ):
                        raise ValueError(
                            f"MLX models do not support HuggingFace StoppingCriteria classes. "
                            f"Found {criteria.__name__}. Use GenerationStopper instead."
                        )

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        page_list = list(page_batch)
        if not page_list:
            return

        valid_pages = []
        invalid_pages = []

        for page in page_list:
            assert page._backend is not None
            if not page._backend.is_valid():
                invalid_pages.append(page)
            else:
                valid_pages.append(page)

        # Process valid pages in batch
        if valid_pages:
            with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
                # Prepare images and prompts for batch processing
                images = []
                user_prompts = []
                pages_with_images = []

                for page in valid_pages:
                    assert page.size is not None
                    hi_res_image = page.get_image(
                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                    )

                    # Only process pages with valid images
                    if hi_res_image is not None:
                        images.append(hi_res_image)

                        # Define prompt structure
                        user_prompt = self._build_prompt_safe(page)

                        user_prompts.append(user_prompt)
                        pages_with_images.append(page)

                # Use process_images for the actual inference
                if images:  # Only if we have valid images
                    predictions = list(self.process_images(images, user_prompts))

                    # Attach results to pages
                    for page, prediction in zip(pages_with_images, predictions):
                        page.predictions.vlm_response = prediction

        # Yield all pages (valid and invalid)
        for page in invalid_pages:
            yield page
        for page in valid_pages:
            yield page

    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Union[str, list[str]],
    ) -> Iterable[VlmPrediction]:
        """Process raw images without page metadata.

        Args:
            image_batch: Iterable of PIL Images or numpy arrays
            prompt: Either:
                - str: Single prompt used for all images
                - list[str]: List of prompts (one per image, must match image count)

        Raises:
            ValueError: If prompt list length doesn't match image count.
        """
        # Convert image batch to list for length validation
        image_list = list(image_batch)

        if len(image_list) == 0:
            return

        # Handle prompt parameter
        if isinstance(prompt, str):
            # Single prompt for all images
            user_prompts = [prompt] * len(image_list)
        elif isinstance(prompt, list):
            # List of prompts (one per image)
            if len(prompt) != len(image_list):
                raise ValueError(
                    f"Number of prompts ({len(prompt)}) must match number of images ({len(image_list)})"
                )
            user_prompts = prompt
        else:
            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")

        # MLX models are not thread-safe - use global lock to serialize access
        with _MLX_GLOBAL_LOCK:
            _log.debug("MLX model: Acquired global lock for thread safety")
            for image, user_prompt in zip(image_list, user_prompts):
                # Convert numpy array to PIL Image if needed
                if isinstance(image, np.ndarray):
                    if image.ndim == 3 and image.shape[2] in [3, 4]:
                        # RGB or RGBA array
                        from PIL import Image as PILImage

                        image = PILImage.fromarray(image.astype(np.uint8))
                    elif image.ndim == 2:
                        # Grayscale array
                        from PIL import Image as PILImage

                        image = PILImage.fromarray(image.astype(np.uint8), mode="L")
                    else:
                        raise ValueError(
                            f"Unsupported numpy array shape: {image.shape}"
                        )

                # Ensure image is in RGB mode (handles RGBA, L, etc.)
                if image.mode != "RGB":
                    image = image.convert("RGB")

                # Use the MLX chat template approach like in the __call__ method
                formatted_prompt = self.apply_chat_template(
                    self.processor, self.config, user_prompt, num_images=1
                )

                # Stream generate with stop strings and custom stopping criteria support
                start_time = time.time()
                _log.debug("start generating ...")

                tokens: list[VlmPredictionToken] = []
                output = ""

                # Use stream_generate for proper stop string handling
                for token in self.stream_generate(
                    self.vlm_model,
                    self.processor,
                    formatted_prompt,
                    [image],  # MLX stream_generate expects list of images
                    max_tokens=self.max_tokens,
                    verbose=False,
                    temp=self.temperature,
                ):
                    # Collect token information
                    if len(token.logprobs.shape) == 1:
                        tokens.append(
                            VlmPredictionToken(
                                text=token.text,
                                token=token.token,
                                logprob=token.logprobs[token.token],
                            )
                        )
                    elif (
                        len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
                    ):
                        tokens.append(
                            VlmPredictionToken(
                                text=token.text,
                                token=token.token,
                                logprob=token.logprobs[0, token.token],
                            )
                        )
                    else:
                        _log.warning(
                            f"incompatible shape for logprobs: {token.logprobs.shape}"
                        )

                    output += token.text

                    # Check for any configured stop strings
                    if self.vlm_options.stop_strings:
                        if any(
                            stop_str in output
                            for stop_str in self.vlm_options.stop_strings
                        ):
                            _log.debug("Stopping generation due to stop string match")
                            break

                    # Check for custom stopping criteria (GenerationStopper instances)
                    if self.vlm_options.custom_stopping_criteria:
                        for criteria in self.vlm_options.custom_stopping_criteria:
                            # Handle both instances and classes of GenerationStopper
                            if isinstance(criteria, GenerationStopper):
                                stopper = criteria
                            elif isinstance(criteria, type) and issubclass(
                                criteria, GenerationStopper
                            ):
                                stopper = criteria()

                            # Determine the text window to check based on lookback_tokens
                            lookback_tokens = stopper.lookback_tokens()
                            # Check only the last N characters worth of text
                            # This is a simplified approach - in practice, you might want to
                            # decode the last N tokens from the token list for more accuracy
                            text_to_check = (
                                output[-lookback_tokens:]
                                if len(output) > lookback_tokens
                                else output
                            )

                            try:
                                if stopper.should_stop(text_to_check):
                                    _log.info(
                                        f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
                                    )
                                    break
                            except Exception as e:
                                _log.warning(
                                    f"Error in GenerationStopper.should_stop: {e}"
                                )
                                continue
                        else:  # note: for-else idiom
                            continue  # Only executed if the inner loop didn't break
                        break  # Break the outer loop if any stopper triggered

                generation_time = time.time() - start_time

                _log.debug(
                    f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
                )

                # Apply decode_response to the output before yielding
                decoded_output = self.vlm_options.decode_response(output)
                input_prompt = (
                    formatted_prompt if self.vlm_options.track_input_prompt else None
                )
                yield VlmPrediction(
                    text=decoded_output,
                    generation_time=generation_time,
                    generated_tokens=tokens,
                    num_tokens=len(tokens),
                    stop_reason=VlmStopReason.UNSPECIFIED,
                    input_prompt=input_prompt,
                )
            _log.debug("MLX model: Released global lock")
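
For orientation: the process_images() API above accepts either a single prompt string (reused for every image) or one prompt per image, and serializes all MLX inference behind the module-level lock. The following usage sketch is illustrative only and not part of the wheel; the options object, image path, and prompt text are placeholders.

# Hypothetical usage sketch (not part of this release); assumes Apple Silicon with
# `mlx-vlm` installed, and that `my_vlm_options` stands in for a concrete
# InlineVlmOptions configuration (repo_id, prompt, scale, ...).
from PIL import Image

from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.vlm_pipeline_models.mlx_model import HuggingFaceMlxModel

my_vlm_options: InlineVlmOptions = ...  # placeholder, e.g. an MLX preset from docling.datamodel.vlm_model_specs

model = HuggingFaceMlxModel(
    enabled=True,  # loads (and, if needed, downloads) the weights at construction time
    artifacts_path=None,
    accelerator_options=AcceleratorOptions(),  # default accelerator settings
    vlm_options=my_vlm_options,
)

page_image = Image.open("page_0.png")  # placeholder image
# A single prompt string is broadcast to every image in the batch.
predictions = list(model.process_images([page_image], "Convert this page to markdown."))
print(predictions[0].text)
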
docling/models/vlm_pipeline_models/vllm_model.py
@@ -0,0 +1,344 @@
import logging
import sys
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Dict, Optional, Union

import numpy as np
from PIL.Image import Image

from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.base_models import (
    Page,
    VlmPrediction,
    VlmPredictionToken,
    VlmStopReason,
)
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
    InlineVlmOptions,
    TransformersPromptStyle,
)
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.hf_model_download import HuggingFaceModelDownloadMixin
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)


class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
    """
    vLLM-backed vision-language model that accepts PIL images (or numpy arrays)
    via vLLM's multi_modal_data, with prompt formatting handled by formulate_prompt().
    """

    # --------- Allowlist of vLLM args ---------
    # SamplingParams (runtime generation controls)
    _VLLM_SAMPLING_KEYS = {
        # Core
        "max_tokens",
        "temperature",
        "top_p",
        "top_k",
        # Penalties
        "presence_penalty",
        "frequency_penalty",
        "repetition_penalty",
        # Stops / outputs
        "stop",
        "stop_token_ids",
        "skip_special_tokens",
        "spaces_between_special_tokens",
        # Search / length
        "n",
        "best_of",
        "length_penalty",
        "early_stopping",
        # Misc
        "logprobs",
        "prompt_logprobs",
        "min_p",
        "seed",
    }

    # LLM(...) / EngineArgs (engine/load-time controls)
    _VLLM_ENGINE_KEYS = {
        # Model/tokenizer/impl
        "tokenizer",
        "tokenizer_mode",
        "download_dir",
        # Parallelism / memory / lengths
        "tensor_parallel_size",
        "pipeline_parallel_size",
        "gpu_memory_utilization",
        "max_model_len",
        "max_num_batched_tokens",
        "kv_cache_dtype",
        "dtype",
        # Quantization (coarse switch)
        "quantization",
        # Multimodal limits
        "limit_mm_per_prompt",
        # Execution toggles
        "enforce_eager",
    }

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        accelerator_options: AcceleratorOptions,
        vlm_options: InlineVlmOptions,
    ):
        self.enabled = enabled
        self.vlm_options: InlineVlmOptions = vlm_options

        self.llm = None
        self.sampling_params = None
        self.processor = None  # used for CHAT templating in formulate_prompt()
        self.device = "cpu"
        self.max_new_tokens = vlm_options.max_new_tokens
        self.temperature = vlm_options.temperature

        if not self.enabled:
            return

        from transformers import AutoProcessor

        try:
            from vllm import LLM, SamplingParams
        except ImportError:
            if sys.version_info < (3, 14):
                raise ImportError(
                    "vllm is not installed. Please install it via `pip install vllm`."
                )
            else:
                raise ImportError(
                    "vllm is not installed. It is not yet available on Python 3.14."
                )

        # Device selection
        self.device = decide_device(
            accelerator_options.device, supported_devices=vlm_options.supported_devices
        )
        _log.debug(f"Available device for VLM: {self.device}")

        # Resolve artifacts path / cache folder
        repo_cache_folder = vlm_options.repo_id.replace("/", "--")
        if artifacts_path is None:
            artifacts_path = self.download_models(
                self.vlm_options.repo_id, revision=self.vlm_options.revision
            )
        elif (artifacts_path / repo_cache_folder).exists():
            artifacts_path = artifacts_path / repo_cache_folder

        # --------- Strict split & validation of extra_generation_config ---------
        extra_cfg = self.vlm_options.extra_generation_config

        load_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_ENGINE_KEYS}
        gen_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_SAMPLING_KEYS}

        unknown = sorted(
            k
            for k in extra_cfg.keys()
            if k not in self._VLLM_ENGINE_KEYS and k not in self._VLLM_SAMPLING_KEYS
        )
        if unknown:
            _log.warning(
                "Ignoring unknown extra_generation_config keys for vLLM: %s", unknown
            )

        # --------- Construct LLM kwargs (engine/load-time) ---------
        llm_kwargs: Dict[str, Any] = {
            "model": str(artifacts_path),
            "model_impl": "transformers",
            "limit_mm_per_prompt": {"image": 1},
            "revision": self.vlm_options.revision,
            "trust_remote_code": self.vlm_options.trust_remote_code,
            **load_cfg,
        }

        if self.device == "cpu":
            llm_kwargs.setdefault("enforce_eager", True)
        else:
            llm_kwargs.setdefault(
                "gpu_memory_utilization", 0.3
            )  # room for other models

        # Quantization (kept as-is; coarse)
        if self.vlm_options.quantized and self.vlm_options.load_in_8bit:
            llm_kwargs.setdefault("quantization", "bitsandbytes")

        # Initialize vLLM LLM
        self.llm = LLM(**llm_kwargs)

        # Initialize processor for prompt templating (needed for CHAT style)
        self.processor = AutoProcessor.from_pretrained(
            artifacts_path,
            trust_remote_code=self.vlm_options.trust_remote_code,
            revision=self.vlm_options.revision,
        )

        # --------- SamplingParams (runtime) ---------
        self.sampling_params = SamplingParams(
            temperature=self.temperature,
            max_tokens=self.max_new_tokens,
            stop=(self.vlm_options.stop_strings or None),
            **gen_cfg,
        )

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        # If disabled, pass-through
        if not self.enabled:
            for page in page_batch:
                yield page
            return

        page_list = list(page_batch)
        if not page_list:
            return

        # Preserve original order
        original_order = page_list[:]

        # Separate valid/invalid
        valid_pages: list[Page] = []
        invalid_pages: list[Page] = []
        for page in page_list:
            assert page._backend is not None
            if page._backend.is_valid():
                valid_pages.append(page)
            else:
                invalid_pages.append(page)

        if valid_pages:
            with TimeRecorder(conv_res, "vlm"):
                images: list[Image] = []
                user_prompts: list[str] = []
                pages_with_images: list[Page] = []

                for page in valid_pages:
                    assert page.size is not None
                    hi_res_image = page.get_image(
                        scale=self.vlm_options.scale,
                        max_size=self.vlm_options.max_size,
                    )
                    if hi_res_image is None:
                        continue

                    images.append(hi_res_image)

                    # Define prompt structure
                    user_prompt = self._build_prompt_safe(page)

                    user_prompts.append(user_prompt)
                    pages_with_images.append(page)

                if images:
                    with TimeRecorder(conv_res, "vlm_inference"):
                        predictions = list(self.process_images(images, user_prompts))
                    for page, prediction in zip(pages_with_images, predictions):
                        page.predictions.vlm_response = prediction

        # Yield in original order
        for page in original_order:
            yield page

    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Union[str, list[str]],
    ) -> Iterable[VlmPrediction]:
        """Process images in a single batched vLLM inference call."""
        import numpy as np
        from PIL import Image as PILImage

        # -- Normalize images to RGB PIL
        pil_images: list[Image] = []
        for img in image_batch:
            if isinstance(img, np.ndarray):
                if img.ndim == 3 and img.shape[2] in (3, 4):
                    pil_img = PILImage.fromarray(img.astype(np.uint8))
                elif img.ndim == 2:
                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
                else:
                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
            else:
                pil_img = img
            if pil_img.mode != "RGB":
                pil_img = pil_img.convert("RGB")
            pil_images.append(pil_img)

        if not pil_images:
            return

        # Normalize prompts
        if isinstance(prompt, str):
            user_prompts = [prompt] * len(pil_images)
        elif isinstance(prompt, list):
            if len(prompt) != len(pil_images):
                raise ValueError(
                    f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
                )
            user_prompts = prompt
        else:
            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")

        # Format prompts
        prompts: list[str] = [self.formulate_prompt(up) for up in user_prompts]

        # Build vLLM inputs
        llm_inputs = [
            {"prompt": p, "multi_modal_data": {"image": im}}
            for p, im in zip(prompts, pil_images)
        ]

        # Generate
        assert self.llm is not None and self.sampling_params is not None
        start_time = time.time()
        outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params)  # type: ignore
        generation_time = time.time() - start_time

        # Optional debug
        if outputs:
            try:
                num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
                _log.debug(
                    f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
                )
            except Exception:
                num_tokens_within_batch = 0

        # Emit predictions
        for i, output in enumerate(outputs):
            text = output.outputs[0].text if output.outputs else ""
            stop_reason = (
                VlmStopReason.END_OF_SEQUENCE
                if output.outputs[0].stop_reason
                else VlmStopReason.LENGTH
            )

            generated_tokens = [
                VlmPredictionToken(token=int(t)) for t in output.outputs[0].token_ids
            ]
            num_tokens = len(generated_tokens)

            if not self.vlm_options.track_generated_tokens:
                generated_tokens = []

            input_prompt = prompts[i] if self.vlm_options.track_input_prompt else None
            _log.debug(f"VLM generated response carries input prompt: {input_prompt}")

            decoded_text = self.vlm_options.decode_response(text)
            yield VlmPrediction(
                text=decoded_text,
                generation_time=generation_time,
                num_tokens=num_tokens,
                stop_reason=stop_reason,
                generated_tokens=generated_tokens,
                input_prompt=input_prompt,
            )
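
As its docstring states, VllmVlmModel hands PIL images to vLLM through multi_modal_data and runs the whole page batch in a single llm.generate() call with one shared SamplingParams object. The sketch below is illustrative only and not part of the wheel; the options object, image files, and the assumption of a vLLM-capable host are placeholders.

# Hypothetical usage sketch (not part of this release); assumes a host where vLLM runs
# (CUDA GPU, or CPU with enforce_eager), and `my_vlm_options` is a placeholder
# InlineVlmOptions configuration.
from PIL import Image

from docling.datamodel.accelerator_options import AcceleratorOptions
from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
from docling.models.vlm_pipeline_models.vllm_model import VllmVlmModel

my_vlm_options: InlineVlmOptions = ...  # placeholder, e.g. a preset from docling.datamodel.vlm_model_specs

model = VllmVlmModel(
    enabled=True,
    artifacts_path=None,
    accelerator_options=AcceleratorOptions(),  # default accelerator settings
    vlm_options=my_vlm_options,
)

# One prompt per image; passing a single str instead would broadcast it to all images.
images = [Image.open("page_0.png"), Image.open("page_1.png")]  # placeholder images
prompts = ["Convert page 1 to markdown.", "Convert page 2 to markdown."]
for prediction in model.process_images(images, prompts):
    print(prediction.stop_reason, prediction.text[:80])
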