docling 2.54.0__py3-none-any.whl → 2.55.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
 
 import numpy as np
 from PIL.Image import Image
-from transformers import StoppingCriteriaList, StopStringCriteria
+from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -20,6 +20,10 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import (
+    GenerationStopper,
+    HFStoppingCriteriaWrapper,
+)
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -75,7 +79,9 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             repo_cache_folder = vlm_options.repo_id.replace("/", "--")
 
             if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
+                artifacts_path = self.download_models(
+                    self.vlm_options.repo_id, revision=self.vlm_options.revision
+                )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
 
@@ -106,6 +112,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             self.processor = AutoProcessor.from_pretrained(
                 artifacts_path,
                 trust_remote_code=vlm_options.trust_remote_code,
+                revision=vlm_options.revision,
             )
             self.processor.tokenizer.padding_side = "left"
 
@@ -120,11 +127,14 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                     else "sdpa"
                 ),
                 trust_remote_code=vlm_options.trust_remote_code,
+                revision=vlm_options.revision,
             )
             self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
 
             # Load generation config
-            self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
+            self.generation_config = GenerationConfig.from_pretrained(
+                artifacts_path, revision=vlm_options.revision
+            )
 
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
@@ -196,7 +206,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         import torch
         from PIL import Image as PILImage
 
-        # -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
+        # -- Normalize images to RGB PIL
         pil_images: list[Image] = []
         for img in image_batch:
             if isinstance(img, np.ndarray):
@@ -247,24 +257,74 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
         # -- Optional stopping criteria
-        stopping_criteria = None
+        stopping_criteria_list: StoppingCriteriaList = StoppingCriteriaList()
+
+        # Add string-based stopping criteria
         if self.vlm_options.stop_strings:
-            stopping_criteria = StoppingCriteriaList(
-                [
-                    StopStringCriteria(
-                        stop_strings=self.vlm_options.stop_strings,
-                        tokenizer=self.processor.tokenizer,
-                    )
-                ]
+            stopping_criteria_list.append(
+                StopStringCriteria(
+                    stop_strings=self.vlm_options.stop_strings,
+                    tokenizer=self.processor.tokenizer,
+                )
             )
 
+        # Add custom stopping criteria
+        if self.vlm_options.custom_stopping_criteria:
+            for criteria in self.vlm_options.custom_stopping_criteria:
+                # If it's a class (not an instance), determine the type and handle accordingly
+                if isinstance(criteria, type):
+                    # Check if it's a GenerationStopper class
+                    if issubclass(criteria, GenerationStopper):
+                        # Instantiate GenerationStopper and wrap it
+                        stopper_instance = criteria()
+                        wrapped_criteria = HFStoppingCriteriaWrapper(
+                            self.processor.tokenizer, stopper_instance
+                        )
+                        stopping_criteria_list.append(wrapped_criteria)
+                    elif issubclass(criteria, StoppingCriteria):
+                        # It's a StoppingCriteria class, instantiate with tokenizer
+                        criteria_instance = criteria(self.processor.tokenizer)
+                        stopping_criteria_list.append(criteria_instance)
+                elif isinstance(criteria, GenerationStopper):
+                    # Wrap GenerationStopper instances in HFStoppingCriteriaWrapper
+                    wrapped_criteria = HFStoppingCriteriaWrapper(
+                        self.processor.tokenizer, criteria
+                    )
+                    stopping_criteria_list.append(wrapped_criteria)
+                else:
+                    # If it's already an instance of StoppingCriteria, use it directly
+                    stopping_criteria_list.append(criteria)
+
+        stopping_criteria = (
+            StoppingCriteriaList(stopping_criteria_list)
+            if stopping_criteria_list
+            else None
+        )
+
+        # -- Filter out decoder-specific keys from extra_generation_config
+        decoder_keys = {
+            "skip_special_tokens",
+            "clean_up_tokenization_spaces",
+            "spaces_between_special_tokens",
+        }
+        generation_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k not in decoder_keys
+        }
+        decoder_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k in decoder_keys
+        }
+
         # -- Generate (Image-Text-to-Text class expects these inputs from processor)
         gen_kwargs = {
             **inputs,
             "max_new_tokens": self.max_new_tokens,
             "use_cache": self.use_cache,
             "generation_config": self.generation_config,
-            **self.vlm_options.extra_generation_config,
+            **generation_config,
         }
         if self.temperature > 0:
             gen_kwargs["do_sample"] = True
@@ -293,7 +353,8 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         )
 
         decoded_texts: list[str] = decode_fn(
-            trimmed_sequences, skip_special_tokens=False
+            trimmed_sequences,
+            **decoder_config,
         )
 
         # -- Clip off pad tokens from decoded texts
@@ -1,4 +1,5 @@
 import logging
+import sys
 import threading
 import time
 from collections.abc import Iterable
@@ -7,6 +8,7 @@ from typing import Optional, Union
 
 import numpy as np
 from PIL.Image import Image
+from transformers import StoppingCriteria
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,6 +17,7 @@ from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToke
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import GenerationStopper
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -60,6 +63,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             if artifacts_path is None:
                 artifacts_path = self.download_models(
                     self.vlm_options.repo_id,
+                    revision=self.vlm_options.revision,
                 )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
@@ -68,6 +72,22 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)
 
+            # Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
+            if self.vlm_options.custom_stopping_criteria:
+                for criteria in self.vlm_options.custom_stopping_criteria:
+                    if isinstance(criteria, StoppingCriteria):
+                        raise ValueError(
+                            f"MLX models do not support HuggingFace StoppingCriteria instances. "
+                            f"Found {type(criteria).__name__}. Use GenerationStopper instead."
+                        )
+                    elif isinstance(criteria, type) and issubclass(
+                        criteria, StoppingCriteria
+                    ):
+                        raise ValueError(
+                            f"MLX models do not support HuggingFace StoppingCriteria classes. "
+                            f"Found {criteria.__name__}. Use GenerationStopper instead."
+                        )
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -192,7 +212,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             self.processor, self.config, user_prompt, num_images=1
         )
 
-        # Stream generate with stop strings support
+        # Stream generate with stop strings and custom stopping criteria support
        start_time = time.time()
        _log.debug("start generating ...")
 
@@ -244,6 +264,43 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     _log.debug("Stopping generation due to stop string match")
                     break
 
+            # Check for custom stopping criteria (GenerationStopper instances)
+            if self.vlm_options.custom_stopping_criteria:
+                for criteria in self.vlm_options.custom_stopping_criteria:
+                    # Handle both instances and classes of GenerationStopper
+                    if isinstance(criteria, GenerationStopper):
+                        stopper = criteria
+                    elif isinstance(criteria, type) and issubclass(
+                        criteria, GenerationStopper
+                    ):
+                        stopper = criteria()
+
+                    # Determine the text window to check based on lookback_tokens
+                    lookback_tokens = stopper.lookback_tokens()
+                    # Check only the last N characters worth of text
+                    # This is a simplified approach - in practice, you might want to
+                    # decode the last N tokens from the token list for more accuracy
+                    text_to_check = (
+                        output[-lookback_tokens:]
+                        if len(output) > lookback_tokens
+                        else output
+                    )
+
+                    try:
+                        if stopper.should_stop(text_to_check):
+                            _log.info(
+                                f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
+                            )
+                            break
+                    except Exception as e:
+                        _log.warning(
+                            f"Error in GenerationStopper.should_stop: {e}"
+                        )
+                        continue
+                else:  # note: for-else idiom
+                    continue  # Only executed if the inner loop didn't break
+                break  # Break the outer loop if any stopper triggered
+
         generation_time = time.time() - start_time
 
         _log.debug(
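The for-else at the end of the hunk above is standard Python: the else suite runs only when the inner loop finishes without hitting break, so a stopper that fires breaks the criteria loop and the trailing break then stops the outer streaming loop, while a clean pass falls into continue. A standalone illustration of the same pattern, with no docling types involved:

# Standalone illustration of the for-else pattern used above.
stream = ["alpha", "beta", "stop here", "gamma"]
stoppers = ["halt", "stop"]

for chunk in stream:          # outer loop: the token stream
    for word in stoppers:     # inner loop: the stopping criteria
        if word in chunk:
            break             # a criterion fired
    else:
        continue              # inner loop did not break: keep streaming
    break                     # inner loop broke: stop the outer loop too

print(chunk)  # -> "stop here"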
@@ -7,9 +7,7 @@ from typing import Any, Dict, Optional, Union
 import numpy as np
 from PIL.Image import Image
 
-from docling.datamodel.accelerator_options import (
-    AcceleratorOptions,
-)
+from docling.datamodel.accelerator_options import AcceleratorOptions
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
@@ -17,9 +15,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
-from docling.models.utils.hf_model_download import (
-    HuggingFaceModelDownloadMixin,
-)
+from docling.models.utils.hf_model_download import HuggingFaceModelDownloadMixin
 from docling.utils.accelerator_utils import decide_device
 from docling.utils.profiling import TimeRecorder
 
@@ -27,6 +23,62 @@ _log = logging.getLogger(__name__)
 
 
 class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
+    """
+    vLLM-backed vision-language model that accepts PIL images (or numpy arrays)
+    via vLLM's multi_modal_data, with prompt formatting handled by formulate_prompt().
+    """
+
+    # --------- Allowlist of vLLM args ---------
+    # SamplingParams (runtime generation controls)
+    _VLLM_SAMPLING_KEYS = {
+        # Core
+        "max_tokens",
+        "temperature",
+        "top_p",
+        "top_k",
+        # Penalties
+        "presence_penalty",
+        "frequency_penalty",
+        "repetition_penalty",
+        # Stops / outputs
+        "stop",
+        "stop_token_ids",
+        "skip_special_tokens",
+        "spaces_between_special_tokens",
+        # Search / length
+        "n",
+        "best_of",
+        "length_penalty",
+        "early_stopping",
+        # Misc
+        "logprobs",
+        "prompt_logprobs",
+        "min_p",
+        "seed",
+    }
+
+    # LLM(...) / EngineArgs (engine/load-time controls)
+    _VLLM_ENGINE_KEYS = {
+        # Model/tokenizer/impl
+        "tokenizer",
+        "tokenizer_mode",
+        "download_dir",
+        # Parallelism / memory / lengths
+        "tensor_parallel_size",
+        "pipeline_parallel_size",
+        "gpu_memory_utilization",
+        "max_model_len",
+        "max_num_batched_tokens",
+        "kv_cache_dtype",
+        "dtype",
+        # Quantization (coarse switch)
+        "quantization",
+        # Multimodal limits
+        "limit_mm_per_prompt",
+        # Execution toggles
+        "enforce_eager",
+    }
+
     def __init__(
         self,
         enabled: bool,
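The two allowlists above drive the split that the rewritten __init__ (next hunk) applies to extra_generation_config: keys in _VLLM_ENGINE_KEYS go into the LLM(...) constructor, keys in _VLLM_SAMPLING_KEYS go into SamplingParams, and anything else is logged and dropped. A self-contained sketch of that routing, with trimmed-down key sets and made-up values:

# Sketch of the extra_generation_config routing; the key sets stand in for
# _VLLM_ENGINE_KEYS / _VLLM_SAMPLING_KEYS and the values are made up.
engine_keys = {"gpu_memory_utilization", "max_model_len"}
sampling_keys = {"max_tokens", "temperature"}

extra_cfg = {
    "max_tokens": 512,       # -> SamplingParams
    "temperature": 0.0,      # -> SamplingParams
    "max_model_len": 8192,   # -> LLM(...)
    "num_beams": 4,          # not on either allowlist: warned about and ignored
}

load_cfg = {k: v for k, v in extra_cfg.items() if k in engine_keys}
gen_cfg = {k: v for k, v in extra_cfg.items() if k in sampling_keys}
unknown = sorted(k for k in extra_cfg if k not in engine_keys | sampling_keys)

print(load_cfg)  # {'max_model_len': 8192}
print(gen_cfg)   # {'max_tokens': 512, 'temperature': 0.0}
print(unknown)   # ['num_beams']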
@@ -35,120 +87,147 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         vlm_options: InlineVlmOptions,
     ):
         self.enabled = enabled
-
         self.vlm_options = vlm_options
 
-        if self.enabled:
-            from transformers import AutoProcessor
-            from vllm import LLM, SamplingParams
-
-            self.device = decide_device(
-                accelerator_options.device,
-                supported_devices=vlm_options.supported_devices,
-            )
-            _log.debug(f"Available device for VLM: {self.device}")
-
-            self.max_new_tokens = vlm_options.max_new_tokens
-            self.temperature = vlm_options.temperature
-
-            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+        self.llm = None
+        self.sampling_params = None
+        self.processor = None  # used for CHAT templating in formulate_prompt()
+        self.device = "cpu"
+        self.max_new_tokens = vlm_options.max_new_tokens
+        self.temperature = vlm_options.temperature
 
-            if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
-            elif (artifacts_path / repo_cache_folder).exists():
-                artifacts_path = artifacts_path / repo_cache_folder
-
-            # Initialize VLLM LLM
-            llm_kwargs: Dict[str, Any] = {
-                "model": str(artifacts_path),
-                "limit_mm_per_prompt": {"image": 1},
-                "trust_remote_code": vlm_options.trust_remote_code,
-                "model_impl": "transformers",
-                "gpu_memory_utilization": 0.3,  # hardcoded for now, leaves room for ~3 different models.
-            }
-
-            # Add device-specific configurations
-
-            if self.device == "cpu":
-                llm_kwargs["device"] = "cpu"
+        if not self.enabled:
+            return
 
-            # Add quantization if specified
-            if vlm_options.quantized:
-                if vlm_options.load_in_8bit:
-                    llm_kwargs["quantization"] = "bitsandbytes"
+        from transformers import AutoProcessor
+        from vllm import LLM, SamplingParams
 
-            self.llm = LLM(**llm_kwargs)
+        # Device selection
+        self.device = decide_device(
+            accelerator_options.device, supported_devices=vlm_options.supported_devices
+        )
+        _log.debug(f"Available device for VLM: {self.device}")
 
-            # Initialize processor for prompt formatting
-            self.processor = AutoProcessor.from_pretrained(
-                artifacts_path,
-                trust_remote_code=vlm_options.trust_remote_code,
+        # Resolve artifacts path / cache folder
+        repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+        if artifacts_path is None:
+            artifacts_path = self.download_models(
+                self.vlm_options.repo_id, revision=self.vlm_options.revision
             )
-
-            # Set up sampling parameters
-            self.sampling_params = SamplingParams(
-                temperature=self.temperature,
-                max_tokens=self.max_new_tokens,
-                stop=vlm_options.stop_strings if vlm_options.stop_strings else None,
-                **vlm_options.extra_generation_config,
+        elif (artifacts_path / repo_cache_folder).exists():
+            artifacts_path = artifacts_path / repo_cache_folder
+
+        # --------- Strict split & validation of extra_generation_config ---------
+        extra_cfg = self.vlm_options.extra_generation_config
+
+        load_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_ENGINE_KEYS}
+        gen_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_SAMPLING_KEYS}
+
+        unknown = sorted(
+            k
+            for k in extra_cfg.keys()
+            if k not in self._VLLM_ENGINE_KEYS and k not in self._VLLM_SAMPLING_KEYS
+        )
+        if unknown:
+            _log.warning(
+                "Ignoring unknown extra_generation_config keys for vLLM: %s", unknown
             )
 
+        # --------- Construct LLM kwargs (engine/load-time) ---------
+        llm_kwargs: Dict[str, Any] = {
+            "model": str(artifacts_path),
+            "model_impl": "transformers",
+            "limit_mm_per_prompt": {"image": 1},
+            "revision": self.vlm_options.revision,
+            "trust_remote_code": self.vlm_options.trust_remote_code,
+            **load_cfg,
+        }
+
+        if self.device == "cpu":
+            llm_kwargs.setdefault("enforce_eager", True)
+        else:
+            llm_kwargs.setdefault(
+                "gpu_memory_utilization", 0.3
+            )  # room for other models
+
+        # Quantization (kept as-is; coarse)
+        if self.vlm_options.quantized and self.vlm_options.load_in_8bit:
+            llm_kwargs.setdefault("quantization", "bitsandbytes")
+
+        # Initialize vLLM LLM
+        self.llm = LLM(**llm_kwargs)
+
+        # Initialize processor for prompt templating (needed for CHAT style)
+        self.processor = AutoProcessor.from_pretrained(
+            artifacts_path,
+            trust_remote_code=self.vlm_options.trust_remote_code,
+            revision=self.vlm_options.revision,
+        )
+
+        # --------- SamplingParams (runtime) ---------
+        self.sampling_params = SamplingParams(
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            stop=(self.vlm_options.stop_strings or None),
+            **gen_cfg,
+        )
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
+        # If disabled, pass-through
+        if not self.enabled:
+            for page in page_batch:
+                yield page
+            return
+
         page_list = list(page_batch)
         if not page_list:
             return
 
-        valid_pages = []
-        invalid_pages = []
+        # Preserve original order
+        original_order = page_list[:]
 
+        # Separate valid/invalid
+        valid_pages: list[Page] = []
+        invalid_pages: list[Page] = []
         for page in page_list:
             assert page._backend is not None
-            if not page._backend.is_valid():
-                invalid_pages.append(page)
-            else:
+            if page._backend.is_valid():
                 valid_pages.append(page)
+            else:
+                invalid_pages.append(page)
 
-        # Process valid pages in batch
         if valid_pages:
             with TimeRecorder(conv_res, "vlm"):
-                # Prepare images and prompts for batch processing
-                images = []
-                user_prompts = []
-                pages_with_images = []
+                images: list[Image] = []
+                user_prompts: list[str] = []
+                pages_with_images: list[Page] = []
 
                 for page in valid_pages:
                     assert page.size is not None
                     hi_res_image = page.get_image(
-                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                        scale=self.vlm_options.scale,
+                        max_size=self.vlm_options.max_size,
                     )
+                    if hi_res_image is None:
+                        continue
 
-                    # Only process pages with valid images
-                    if hi_res_image is not None:
-                        images.append(hi_res_image)
+                    images.append(hi_res_image)
 
-                        # Define prompt structure
-                        if callable(self.vlm_options.prompt):
-                            user_prompt = self.vlm_options.prompt(page.parsed_page)
-                        else:
-                            user_prompt = self.vlm_options.prompt
+                    # Define prompt structure
+                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)
 
-                        user_prompts.append(user_prompt)
-                        pages_with_images.append(page)
+                    user_prompts.append(user_prompt)
+                    pages_with_images.append(page)
 
-                # Use process_images for the actual inference
-                if images:  # Only if we have valid images
+                if images:
                     predictions = list(self.process_images(images, user_prompts))
-
-                    # Attach results to pages
                     for page, prediction in zip(pages_with_images, predictions):
                         page.predictions.vlm_response = prediction
 
-        # Yield all pages (valid and invalid)
-        for page in invalid_pages:
-            yield page
-        for page in valid_pages:
+        # Yield in original order
+        for page in original_order:
             yield page
 
     def process_images(
@@ -156,50 +235,33 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         image_batch: Iterable[Union[Image, np.ndarray]],
         prompt: Union[str, list[str]],
     ) -> Iterable[VlmPrediction]:
-        """Process raw images without page metadata in a single batched inference call.
-
-        Args:
-            image_batch: Iterable of PIL Images or numpy arrays
-            prompt: Either:
-                - str: Single prompt used for all images
-                - list[str]: List of prompts (one per image, must match image count)
+        """Process images in a single batched vLLM inference call."""
+        import numpy as np
+        from PIL import Image as PILImage
 
-        Raises:
-            ValueError: If prompt list length doesn't match image count.
-        """
+        # -- Normalize images to RGB PIL
         pil_images: list[Image] = []
-
         for img in image_batch:
-            # Convert numpy array to PIL Image if needed
             if isinstance(img, np.ndarray):
-                if img.ndim == 3 and img.shape[2] in [3, 4]:
-                    from PIL import Image as PILImage
-
+                if img.ndim == 3 and img.shape[2] in (3, 4):
                     pil_img = PILImage.fromarray(img.astype(np.uint8))
                 elif img.ndim == 2:
-                    from PIL import Image as PILImage
-
                     pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
                 else:
                     raise ValueError(f"Unsupported numpy array shape: {img.shape}")
             else:
                 pil_img = img
-
-            # Ensure image is in RGB mode (handles RGBA, L, etc.)
             if pil_img.mode != "RGB":
                 pil_img = pil_img.convert("RGB")
-
             pil_images.append(pil_img)
 
-        if len(pil_images) == 0:
+        if not pil_images:
            return
 
-        # Handle prompt parameter
+        # Normalize prompts
         if isinstance(prompt, str):
-            # Single prompt for all images
             user_prompts = [prompt] * len(pil_images)
         elif isinstance(prompt, list):
-            # List of prompts (one per image)
             if len(prompt) != len(pil_images):
                 raise ValueError(
                     f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
@@ -208,28 +270,31 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         else:
             raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
 
-        # Format prompts individually
-        prompts: list[str] = [
-            self.formulate_prompt(user_prompt) for user_prompt in user_prompts
-        ]
+        # Format prompts
+        prompts: list[str] = [self.formulate_prompt(up) for up in user_prompts]
 
-        # Prepare VLLM inputs
-        llm_inputs = []
-        for prompt, image in zip(prompts, pil_images):
-            llm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
+        # Build vLLM inputs
+        llm_inputs = [
+            {"prompt": p, "multi_modal_data": {"image": im}}
+            for p, im in zip(prompts, pil_images)
+        ]
 
+        # Generate
+        assert self.llm is not None and self.sampling_params is not None
         start_time = time.time()
         outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params)  # type: ignore
         generation_time = time.time() - start_time
 
-        # Logging tokens count for the first sample as a representative metric
-        if len(outputs) > 0:
-            num_tokens = len(outputs[0].outputs[0].token_ids)
-            _log.debug(
-                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
-            )
+        # Optional debug
+        if outputs:
+            try:
+                num_tokens = len(outputs[0].outputs[0].token_ids)
+                _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
+            except Exception:
+                pass
 
+        # Emit predictions
         for output in outputs:
-            # Apply decode_response to the output text
-            decoded_text = self.vlm_options.decode_response(output.outputs[0].text)
+            text = output.outputs[0].text if output.outputs else ""
+            decoded_text = self.vlm_options.decode_response(text)
             yield VlmPrediction(text=decoded_text, generation_time=generation_time)
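For reference, process_images can also be driven directly, outside the page pipeline. The sketch below assumes an already-constructed, enabled VllmVlmModel (its constructor arguments appear in the __init__ hunk above); the prompt text and file paths are illustrative.

# Sketch only: model is assumed to be an already-initialized, enabled VllmVlmModel.
from PIL import Image


def convert_pages(model, paths):
    images = [Image.open(p) for p in paths]
    # A single prompt string is reused for every image; a list[str] with one
    # prompt per image is also accepted (lengths must match, see the ValueError above).
    for pred in model.process_images(images, "Convert this page to markdown."):
        print(f"{pred.generation_time:.2f}s  {pred.text[:80]}")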