docling-2.69.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of docling might be problematic.

Files changed (138)
  1. docling/__init__.py +0 -0
  2. docling/backend/__init__.py +0 -0
  3. docling/backend/abstract_backend.py +84 -0
  4. docling/backend/asciidoc_backend.py +443 -0
  5. docling/backend/csv_backend.py +125 -0
  6. docling/backend/docling_parse_backend.py +237 -0
  7. docling/backend/docling_parse_v2_backend.py +276 -0
  8. docling/backend/docling_parse_v4_backend.py +260 -0
  9. docling/backend/docx/__init__.py +0 -0
  10. docling/backend/docx/drawingml/utils.py +131 -0
  11. docling/backend/docx/latex/__init__.py +0 -0
  12. docling/backend/docx/latex/latex_dict.py +274 -0
  13. docling/backend/docx/latex/omml.py +459 -0
  14. docling/backend/html_backend.py +1502 -0
  15. docling/backend/image_backend.py +188 -0
  16. docling/backend/json/__init__.py +0 -0
  17. docling/backend/json/docling_json_backend.py +58 -0
  18. docling/backend/md_backend.py +618 -0
  19. docling/backend/mets_gbs_backend.py +399 -0
  20. docling/backend/msexcel_backend.py +686 -0
  21. docling/backend/mspowerpoint_backend.py +398 -0
  22. docling/backend/msword_backend.py +1663 -0
  23. docling/backend/noop_backend.py +51 -0
  24. docling/backend/pdf_backend.py +82 -0
  25. docling/backend/pypdfium2_backend.py +417 -0
  26. docling/backend/webvtt_backend.py +572 -0
  27. docling/backend/xml/__init__.py +0 -0
  28. docling/backend/xml/jats_backend.py +819 -0
  29. docling/backend/xml/uspto_backend.py +1905 -0
  30. docling/chunking/__init__.py +12 -0
  31. docling/cli/__init__.py +0 -0
  32. docling/cli/main.py +974 -0
  33. docling/cli/models.py +196 -0
  34. docling/cli/tools.py +17 -0
  35. docling/datamodel/__init__.py +0 -0
  36. docling/datamodel/accelerator_options.py +69 -0
  37. docling/datamodel/asr_model_specs.py +494 -0
  38. docling/datamodel/backend_options.py +102 -0
  39. docling/datamodel/base_models.py +493 -0
  40. docling/datamodel/document.py +699 -0
  41. docling/datamodel/extraction.py +39 -0
  42. docling/datamodel/layout_model_specs.py +91 -0
  43. docling/datamodel/pipeline_options.py +457 -0
  44. docling/datamodel/pipeline_options_asr_model.py +78 -0
  45. docling/datamodel/pipeline_options_vlm_model.py +136 -0
  46. docling/datamodel/settings.py +65 -0
  47. docling/datamodel/vlm_model_specs.py +365 -0
  48. docling/document_converter.py +559 -0
  49. docling/document_extractor.py +327 -0
  50. docling/exceptions.py +10 -0
  51. docling/experimental/__init__.py +5 -0
  52. docling/experimental/datamodel/__init__.py +1 -0
  53. docling/experimental/datamodel/table_crops_layout_options.py +13 -0
  54. docling/experimental/datamodel/threaded_layout_vlm_pipeline_options.py +45 -0
  55. docling/experimental/models/__init__.py +3 -0
  56. docling/experimental/models/table_crops_layout_model.py +114 -0
  57. docling/experimental/pipeline/__init__.py +1 -0
  58. docling/experimental/pipeline/threaded_layout_vlm_pipeline.py +439 -0
  59. docling/models/__init__.py +0 -0
  60. docling/models/base_layout_model.py +39 -0
  61. docling/models/base_model.py +230 -0
  62. docling/models/base_ocr_model.py +241 -0
  63. docling/models/base_table_model.py +45 -0
  64. docling/models/extraction/__init__.py +0 -0
  65. docling/models/extraction/nuextract_transformers_model.py +305 -0
  66. docling/models/factories/__init__.py +47 -0
  67. docling/models/factories/base_factory.py +122 -0
  68. docling/models/factories/layout_factory.py +7 -0
  69. docling/models/factories/ocr_factory.py +11 -0
  70. docling/models/factories/picture_description_factory.py +11 -0
  71. docling/models/factories/table_factory.py +7 -0
  72. docling/models/picture_description_base_model.py +149 -0
  73. docling/models/plugins/__init__.py +0 -0
  74. docling/models/plugins/defaults.py +60 -0
  75. docling/models/stages/__init__.py +0 -0
  76. docling/models/stages/code_formula/__init__.py +0 -0
  77. docling/models/stages/code_formula/code_formula_model.py +342 -0
  78. docling/models/stages/layout/__init__.py +0 -0
  79. docling/models/stages/layout/layout_model.py +249 -0
  80. docling/models/stages/ocr/__init__.py +0 -0
  81. docling/models/stages/ocr/auto_ocr_model.py +132 -0
  82. docling/models/stages/ocr/easyocr_model.py +200 -0
  83. docling/models/stages/ocr/ocr_mac_model.py +145 -0
  84. docling/models/stages/ocr/rapid_ocr_model.py +328 -0
  85. docling/models/stages/ocr/tesseract_ocr_cli_model.py +331 -0
  86. docling/models/stages/ocr/tesseract_ocr_model.py +262 -0
  87. docling/models/stages/page_assemble/__init__.py +0 -0
  88. docling/models/stages/page_assemble/page_assemble_model.py +156 -0
  89. docling/models/stages/page_preprocessing/__init__.py +0 -0
  90. docling/models/stages/page_preprocessing/page_preprocessing_model.py +145 -0
  91. docling/models/stages/picture_classifier/__init__.py +0 -0
  92. docling/models/stages/picture_classifier/document_picture_classifier.py +246 -0
  93. docling/models/stages/picture_description/__init__.py +0 -0
  94. docling/models/stages/picture_description/picture_description_api_model.py +66 -0
  95. docling/models/stages/picture_description/picture_description_vlm_model.py +123 -0
  96. docling/models/stages/reading_order/__init__.py +0 -0
  97. docling/models/stages/reading_order/readingorder_model.py +431 -0
  98. docling/models/stages/table_structure/__init__.py +0 -0
  99. docling/models/stages/table_structure/table_structure_model.py +305 -0
  100. docling/models/utils/__init__.py +0 -0
  101. docling/models/utils/generation_utils.py +157 -0
  102. docling/models/utils/hf_model_download.py +45 -0
  103. docling/models/vlm_pipeline_models/__init__.py +1 -0
  104. docling/models/vlm_pipeline_models/api_vlm_model.py +180 -0
  105. docling/models/vlm_pipeline_models/hf_transformers_model.py +391 -0
  106. docling/models/vlm_pipeline_models/mlx_model.py +325 -0
  107. docling/models/vlm_pipeline_models/vllm_model.py +344 -0
  108. docling/pipeline/__init__.py +0 -0
  109. docling/pipeline/asr_pipeline.py +431 -0
  110. docling/pipeline/base_extraction_pipeline.py +72 -0
  111. docling/pipeline/base_pipeline.py +326 -0
  112. docling/pipeline/extraction_vlm_pipeline.py +207 -0
  113. docling/pipeline/legacy_standard_pdf_pipeline.py +262 -0
  114. docling/pipeline/simple_pipeline.py +55 -0
  115. docling/pipeline/standard_pdf_pipeline.py +859 -0
  116. docling/pipeline/threaded_standard_pdf_pipeline.py +5 -0
  117. docling/pipeline/vlm_pipeline.py +416 -0
  118. docling/py.typed +1 -0
  119. docling/utils/__init__.py +0 -0
  120. docling/utils/accelerator_utils.py +97 -0
  121. docling/utils/api_image_request.py +205 -0
  122. docling/utils/deepseekocr_utils.py +388 -0
  123. docling/utils/export.py +146 -0
  124. docling/utils/glm_utils.py +361 -0
  125. docling/utils/layout_postprocessor.py +683 -0
  126. docling/utils/locks.py +3 -0
  127. docling/utils/model_downloader.py +168 -0
  128. docling/utils/ocr_utils.py +69 -0
  129. docling/utils/orientation.py +65 -0
  130. docling/utils/profiling.py +65 -0
  131. docling/utils/utils.py +65 -0
  132. docling/utils/visualization.py +85 -0
  133. docling-2.69.0.dist-info/METADATA +237 -0
  134. docling-2.69.0.dist-info/RECORD +138 -0
  135. docling-2.69.0.dist-info/WHEEL +5 -0
  136. docling-2.69.0.dist-info/entry_points.txt +6 -0
  137. docling-2.69.0.dist-info/licenses/LICENSE +21 -0
  138. docling-2.69.0.dist-info/top_level.txt +1 -0
docling/models/vlm_pipeline_models/mlx_model.py
@@ -0,0 +1,325 @@
+import logging
+import sys
+import threading
+import time
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Optional, Union
+
+import numpy as np
+from PIL.Image import Image
+from transformers import StoppingCriteria
+
+from docling.datamodel.accelerator_options import (
+    AcceleratorOptions,
+)
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
+from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import GenerationStopper
+from docling.models.utils.hf_model_download import (
+    HuggingFaceModelDownloadMixin,
+)
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+# Global lock for MLX model calls - MLX models are not thread-safe
+# All MLX models share this lock to prevent concurrent MLX operations
+_MLX_GLOBAL_LOCK = threading.Lock()
+
+
+class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        vlm_options: InlineVlmOptions,
+    ):
+        self.enabled = enabled
+
+        self.vlm_options = vlm_options
+        self.max_tokens = vlm_options.max_new_tokens
+        self.temperature = vlm_options.temperature
+
+        if self.enabled:
+            try:
+                from mlx_vlm import generate, load, stream_generate  # type: ignore
+                from mlx_vlm.prompt_utils import apply_chat_template  # type: ignore
+                from mlx_vlm.utils import load_config  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    "mlx-vlm is not installed. Please install it via `pip install mlx-vlm` to use MLX VLM models."
+                )
+
+            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+
+            self.apply_chat_template = apply_chat_template
+            self.stream_generate = stream_generate
+
+            # PARAMETERS:
+            if artifacts_path is None:
+                artifacts_path = self.download_models(
+                    self.vlm_options.repo_id,
+                    revision=self.vlm_options.revision,
+                )
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            ## Load the model
+            self.vlm_model, self.processor = load(artifacts_path)
+            self.config = load_config(artifacts_path)
+
+            # Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
+            if self.vlm_options.custom_stopping_criteria:
+                for criteria in self.vlm_options.custom_stopping_criteria:
+                    if isinstance(criteria, StoppingCriteria):
+                        raise ValueError(
+                            f"MLX models do not support HuggingFace StoppingCriteria instances. "
+                            f"Found {type(criteria).__name__}. Use GenerationStopper instead."
+                        )
+                    elif isinstance(criteria, type) and issubclass(
+                        criteria, StoppingCriteria
+                    ):
+                        raise ValueError(
+                            f"MLX models do not support HuggingFace StoppingCriteria classes. "
+                            f"Found {criteria.__name__}. Use GenerationStopper instead."
+                        )
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        valid_pages = []
+        invalid_pages = []
+
+        for page in page_list:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                invalid_pages.append(page)
+            else:
+                valid_pages.append(page)
+
+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
+                # Prepare images and prompts for batch processing
+                images = []
+                user_prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
+                    hi_res_image = page.get_image(
+                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                    )
+
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)
+
+                        # Define prompt structure
+                        user_prompt = self._build_prompt_safe(page)
+
+                        user_prompts.append(user_prompt)
+                        pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    predictions = list(self.process_images(images, user_prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield all pages (valid and invalid)
+        for page in invalid_pages:
+            yield page
+        for page in valid_pages:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Either:
+                - str: Single prompt used for all images
+                - list[str]: List of prompts (one per image, must match image count)
+
+        Raises:
+            ValueError: If prompt list length doesn't match image count.
+        """
+        # Convert image batch to list for length validation
+        image_list = list(image_batch)
+
+        if len(image_list) == 0:
+            return
+
+        # Handle prompt parameter
+        if isinstance(prompt, str):
+            # Single prompt for all images
+            user_prompts = [prompt] * len(image_list)
+        elif isinstance(prompt, list):
+            # List of prompts (one per image)
+            if len(prompt) != len(image_list):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(image_list)})"
+                )
+            user_prompts = prompt
+        else:
+            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
+
+        # MLX models are not thread-safe - use global lock to serialize access
+        with _MLX_GLOBAL_LOCK:
+            _log.debug("MLX model: Acquired global lock for thread safety")
+            for image, user_prompt in zip(image_list, user_prompts):
+                # Convert numpy array to PIL Image if needed
+                if isinstance(image, np.ndarray):
+                    if image.ndim == 3 and image.shape[2] in [3, 4]:
+                        # RGB or RGBA array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8))
+                    elif image.ndim == 2:
+                        # Grayscale array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8), mode="L")
+                    else:
+                        raise ValueError(
+                            f"Unsupported numpy array shape: {image.shape}"
+                        )
+
+                # Ensure image is in RGB mode (handles RGBA, L, etc.)
+                if image.mode != "RGB":
+                    image = image.convert("RGB")
+
+                # Use the MLX chat template approach like in the __call__ method
+                formatted_prompt = self.apply_chat_template(
+                    self.processor, self.config, user_prompt, num_images=1
+                )
+
+                # Stream generate with stop strings and custom stopping criteria support
+                start_time = time.time()
+                _log.debug("start generating ...")
+
+                tokens: list[VlmPredictionToken] = []
+                output = ""
+
+                # Use stream_generate for proper stop string handling
+                for token in self.stream_generate(
+                    self.vlm_model,
+                    self.processor,
+                    formatted_prompt,
+                    [image],  # MLX stream_generate expects list of images
+                    max_tokens=self.max_tokens,
+                    verbose=False,
+                    temp=self.temperature,
+                ):
+                    # Collect token information
+                    if len(token.logprobs.shape) == 1:
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[token.token],
+                            )
+                        )
+                    elif (
+                        len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
+                    ):
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[0, token.token],
+                            )
+                        )
+                    else:
+                        _log.warning(
+                            f"incompatible shape for logprobs: {token.logprobs.shape}"
+                        )
+
+                    output += token.text
+
+                    # Check for any configured stop strings
+                    if self.vlm_options.stop_strings:
+                        if any(
+                            stop_str in output
+                            for stop_str in self.vlm_options.stop_strings
+                        ):
+                            _log.debug("Stopping generation due to stop string match")
+                            break
+
+                    # Check for custom stopping criteria (GenerationStopper instances)
+                    if self.vlm_options.custom_stopping_criteria:
+                        for criteria in self.vlm_options.custom_stopping_criteria:
+                            # Handle both instances and classes of GenerationStopper
+                            if isinstance(criteria, GenerationStopper):
+                                stopper = criteria
+                            elif isinstance(criteria, type) and issubclass(
+                                criteria, GenerationStopper
+                            ):
+                                stopper = criteria()
+
+                            # Determine the text window to check based on lookback_tokens
+                            lookback_tokens = stopper.lookback_tokens()
+                            # Check only the last N characters worth of text
+                            # This is a simplified approach - in practice, you might want to
+                            # decode the last N tokens from the token list for more accuracy
+                            text_to_check = (
+                                output[-lookback_tokens:]
+                                if len(output) > lookback_tokens
+                                else output
+                            )
+
+                            try:
+                                if stopper.should_stop(text_to_check):
+                                    _log.info(
+                                        f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
+                                    )
+                                    break
+                            except Exception as e:
+                                _log.warning(
+                                    f"Error in GenerationStopper.should_stop: {e}"
+                                )
+                                continue
+                        else:  # note: for-else idiom
+                            continue  # Only executed if the inner loop didn't break
+                        break  # Break the outer loop if any stopper triggered
+
+                generation_time = time.time() - start_time
+
+                _log.debug(
+                    f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
+                )
+
+                # Apply decode_response to the output before yielding
+                decoded_output = self.vlm_options.decode_response(output)
+                input_prompt = (
+                    formatted_prompt if self.vlm_options.track_input_prompt else None
+                )
+                yield VlmPrediction(
+                    text=decoded_output,
+                    generation_time=generation_time,
+                    generated_tokens=tokens,
+                    num_tokens=len(tokens),
+                    stop_reason=VlmStopReason.UNSPECIFIED,
+                    input_prompt=input_prompt,
+                )
+            _log.debug("MLX model: Released global lock")
docling/models/vlm_pipeline_models/vllm_model.py
@@ -0,0 +1,344 @@
+import logging
+import sys
+import time
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+from PIL.Image import Image
+
+from docling.datamodel.accelerator_options import AcceleratorOptions
+from docling.datamodel.base_models import (
+    Page,
+    VlmPrediction,
+    VlmPredictionToken,
+    VlmStopReason,
+)
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options_vlm_model import (
+    InlineVlmOptions,
+    TransformersPromptStyle,
+)
+from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.hf_model_download import HuggingFaceModelDownloadMixin
+from docling.utils.accelerator_utils import decide_device
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
+    """
+    vLLM-backed vision-language model that accepts PIL images (or numpy arrays)
+    via vLLM's multi_modal_data, with prompt formatting handled by formulate_prompt().
+    """
+
+    # --------- Allowlist of vLLM args ---------
+    # SamplingParams (runtime generation controls)
+    _VLLM_SAMPLING_KEYS = {
+        # Core
+        "max_tokens",
+        "temperature",
+        "top_p",
+        "top_k",
+        # Penalties
+        "presence_penalty",
+        "frequency_penalty",
+        "repetition_penalty",
+        # Stops / outputs
+        "stop",
+        "stop_token_ids",
+        "skip_special_tokens",
+        "spaces_between_special_tokens",
+        # Search / length
+        "n",
+        "best_of",
+        "length_penalty",
+        "early_stopping",
+        # Misc
+        "logprobs",
+        "prompt_logprobs",
+        "min_p",
+        "seed",
+    }
+
+    # LLM(...) / EngineArgs (engine/load-time controls)
+    _VLLM_ENGINE_KEYS = {
+        # Model/tokenizer/impl
+        "tokenizer",
+        "tokenizer_mode",
+        "download_dir",
+        # Parallelism / memory / lengths
+        "tensor_parallel_size",
+        "pipeline_parallel_size",
+        "gpu_memory_utilization",
+        "max_model_len",
+        "max_num_batched_tokens",
+        "kv_cache_dtype",
+        "dtype",
+        # Quantization (coarse switch)
+        "quantization",
+        # Multimodal limits
+        "limit_mm_per_prompt",
+        # Execution toggles
+        "enforce_eager",
+    }
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        vlm_options: InlineVlmOptions,
+    ):
+        self.enabled = enabled
+        self.vlm_options: InlineVlmOptions = vlm_options
+
+        self.llm = None
+        self.sampling_params = None
+        self.processor = None  # used for CHAT templating in formulate_prompt()
+        self.device = "cpu"
+        self.max_new_tokens = vlm_options.max_new_tokens
+        self.temperature = vlm_options.temperature
+
+        if not self.enabled:
+            return
+
+        from transformers import AutoProcessor
+
+        try:
+            from vllm import LLM, SamplingParams
+        except ImportError:
+            if sys.version_info < (3, 14):
+                raise ImportError(
+                    "vllm is not installed. Please install it via `pip install vllm`."
+                )
+            else:
+                raise ImportError(
+                    "vllm is not installed. It is not yet available on Python 3.14."
+                )
+
+        # Device selection
+        self.device = decide_device(
+            accelerator_options.device, supported_devices=vlm_options.supported_devices
+        )
+        _log.debug(f"Available device for VLM: {self.device}")
+
+        # Resolve artifacts path / cache folder
+        repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+        if artifacts_path is None:
+            artifacts_path = self.download_models(
+                self.vlm_options.repo_id, revision=self.vlm_options.revision
+            )
+        elif (artifacts_path / repo_cache_folder).exists():
+            artifacts_path = artifacts_path / repo_cache_folder
+
+        # --------- Strict split & validation of extra_generation_config ---------
+        extra_cfg = self.vlm_options.extra_generation_config
+
+        load_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_ENGINE_KEYS}
+        gen_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_SAMPLING_KEYS}
+
+        unknown = sorted(
+            k
+            for k in extra_cfg.keys()
+            if k not in self._VLLM_ENGINE_KEYS and k not in self._VLLM_SAMPLING_KEYS
+        )
+        if unknown:
+            _log.warning(
+                "Ignoring unknown extra_generation_config keys for vLLM: %s", unknown
+            )
+
+        # --------- Construct LLM kwargs (engine/load-time) ---------
+        llm_kwargs: Dict[str, Any] = {
+            "model": str(artifacts_path),
+            "model_impl": "transformers",
+            "limit_mm_per_prompt": {"image": 1},
+            "revision": self.vlm_options.revision,
+            "trust_remote_code": self.vlm_options.trust_remote_code,
+            **load_cfg,
+        }
+
+        if self.device == "cpu":
+            llm_kwargs.setdefault("enforce_eager", True)
+        else:
+            llm_kwargs.setdefault(
+                "gpu_memory_utilization", 0.3
+            )  # room for other models
+
+        # Quantization (kept as-is; coarse)
+        if self.vlm_options.quantized and self.vlm_options.load_in_8bit:
+            llm_kwargs.setdefault("quantization", "bitsandbytes")
+
+        # Initialize vLLM LLM
+        self.llm = LLM(**llm_kwargs)
+
+        # Initialize processor for prompt templating (needed for CHAT style)
+        self.processor = AutoProcessor.from_pretrained(
+            artifacts_path,
+            trust_remote_code=self.vlm_options.trust_remote_code,
+            revision=self.vlm_options.revision,
+        )
+
+        # --------- SamplingParams (runtime) ---------
+        self.sampling_params = SamplingParams(
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            stop=(self.vlm_options.stop_strings or None),
+            **gen_cfg,
+        )
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        # If disabled, pass-through
+        if not self.enabled:
+            for page in page_batch:
+                yield page
+            return
+
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        # Preserve original order
+        original_order = page_list[:]
+
+        # Separate valid/invalid
+        valid_pages: list[Page] = []
+        invalid_pages: list[Page] = []
+        for page in page_list:
+            assert page._backend is not None
+            if page._backend.is_valid():
+                valid_pages.append(page)
+            else:
+                invalid_pages.append(page)
+
+        if valid_pages:
+            with TimeRecorder(conv_res, "vlm"):
+                images: list[Image] = []
+                user_prompts: list[str] = []
+                pages_with_images: list[Page] = []
+
+                for page in valid_pages:
+                    assert page.size is not None
+                    hi_res_image = page.get_image(
+                        scale=self.vlm_options.scale,
+                        max_size=self.vlm_options.max_size,
+                    )
+                    if hi_res_image is None:
+                        continue
+
+                    images.append(hi_res_image)
+
+                    # Define prompt structure
+                    user_prompt = self._build_prompt_safe(page)
+
+                    user_prompts.append(user_prompt)
+                    pages_with_images.append(page)
+
+                if images:
+                    with TimeRecorder(conv_res, "vlm_inference"):
+                        predictions = list(self.process_images(images, user_prompts))
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield in original order
+        for page in original_order:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process images in a single batched vLLM inference call."""
+        import numpy as np
+        from PIL import Image as PILImage
+
+        # -- Normalize images to RGB PIL
+        pil_images: list[Image] = []
+        for img in image_batch:
+            if isinstance(img, np.ndarray):
+                if img.ndim == 3 and img.shape[2] in (3, 4):
+                    pil_img = PILImage.fromarray(img.astype(np.uint8))
+                elif img.ndim == 2:
+                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+            else:
+                pil_img = img
+            if pil_img.mode != "RGB":
+                pil_img = pil_img.convert("RGB")
+            pil_images.append(pil_img)
+
+        if not pil_images:
+            return
+
+        # Normalize prompts
+        if isinstance(prompt, str):
+            user_prompts = [prompt] * len(pil_images)
+        elif isinstance(prompt, list):
+            if len(prompt) != len(pil_images):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
+                )
+            user_prompts = prompt
+        else:
+            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
+
+        # Format prompts
+        prompts: list[str] = [self.formulate_prompt(up) for up in user_prompts]
+
+        # Build vLLM inputs
+        llm_inputs = [
+            {"prompt": p, "multi_modal_data": {"image": im}}
+            for p, im in zip(prompts, pil_images)
+        ]
+
+        # Generate
+        assert self.llm is not None and self.sampling_params is not None
+        start_time = time.time()
+        outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params)  # type: ignore
+        generation_time = time.time() - start_time
+
+        # Optional debug
+        if outputs:
+            try:
+                num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
+                _log.debug(
+                    f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
+                )
+            except Exception:
+                num_tokens_within_batch = 0
+
+        # Emit predictions
+        for i, output in enumerate(outputs):
+            text = output.outputs[0].text if output.outputs else ""
+            stop_reason = (
+                VlmStopReason.END_OF_SEQUENCE
+                if output.outputs[0].stop_reason
+                else VlmStopReason.LENGTH
+            )
+
+            generated_tokens = [
+                VlmPredictionToken(token=int(t)) for t in output.outputs[0].token_ids
+            ]
+            num_tokens = len(generated_tokens)
+
+            if not self.vlm_options.track_generated_tokens:
+                generated_tokens = []
+
+            input_prompt = prompts[i] if self.vlm_options.track_input_prompt else None
+            _log.debug(f"VLM generated response carries input prompt: {input_prompt}")
+
+            decoded_text = self.vlm_options.decode_response(text)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+                stop_reason=stop_reason,
+                generated_tokens=generated_tokens,
+                input_prompt=input_prompt,
+            )
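
Note on the file above: the constructor strictly splits vlm_options.extra_generation_config between engine/load-time kwargs (forwarded to LLM(...)) and runtime SamplingParams kwargs, warning about and dropping any key outside the two allowlists. The sketch below replays that split outside the class; it assumes docling is importable and the config values are illustrative, not docling defaults.

from docling.models.vlm_pipeline_models.vllm_model import VllmVlmModel

# Example config mixing engine, sampling, and unknown keys (illustrative values).
extra_cfg = {
    "max_model_len": 16384,         # engine/load-time -> LLM(...)
    "gpu_memory_utilization": 0.5,  # engine/load-time -> LLM(...)
    "temperature": 0.0,             # runtime -> SamplingParams(...)
    "repetition_penalty": 1.05,     # runtime -> SamplingParams(...)
    "not_a_vllm_option": True,      # unknown -> warned about and ignored
}

# Same routing logic as in __init__, using the class-level allowlists.
load_cfg = {k: v for k, v in extra_cfg.items() if k in VllmVlmModel._VLLM_ENGINE_KEYS}
gen_cfg = {k: v for k, v in extra_cfg.items() if k in VllmVlmModel._VLLM_SAMPLING_KEYS}
unknown = sorted(k for k in extra_cfg if k not in load_cfg and k not in gen_cfg)

print(load_cfg)  # {'max_model_len': 16384, 'gpu_memory_utilization': 0.5}
print(gen_cfg)   # {'temperature': 0.0, 'repetition_penalty': 1.05}
print(unknown)   # ['not_a_vllm_option']

In the model itself, load_cfg is merged into the LLM(...) kwargs (after the built-in defaults such as model_impl="transformers" and limit_mm_per_prompt={"image": 1}) and gen_cfg is merged into SamplingParams alongside temperature, max_tokens, and stop strings.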