docling 2.45.0__py3-none-any.whl → 2.47.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,4 @@
  import re
- from collections import Counter
  from collections.abc import Iterable
  from pathlib import Path
  from typing import List, Literal, Optional, Tuple, Union
@@ -13,10 +12,11 @@ from docling_core.types.doc import (
  TextItem,
  )
  from docling_core.types.doc.labels import CodeLanguageLabel
- from PIL import Image, ImageOps
+ from PIL import Image
  from pydantic import BaseModel
+ from transformers import AutoModelForImageTextToText, AutoProcessor

- from docling.datamodel.accelerator_options import AcceleratorOptions
+ from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
  from docling.datamodel.base_models import ItemAndImageEnrichmentElement
  from docling.models.base_model import BaseItemAndImageEnrichmentModel
  from docling.models.utils.hf_model_download import download_hf_model
@@ -65,9 +65,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
  Processes the given batch of elements and enriches them with predictions.
  """

- _model_repo_folder = "ds4sd--CodeFormula"
+ _model_repo_folder = "ds4sd--CodeFormulaV2"
  elements_batch_size = 5
- images_scale = 1.66 # = 120 dpi, aligned with training data resolution
+ images_scale = 1.67 # = 120 dpi, aligned with training data resolution
  expansion_factor = 0.18

  def __init__(
@@ -95,10 +95,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
  self.options = options

  if self.enabled:
- device = decide_device(accelerator_options.device)
-
- from docling_ibm_models.code_formula_model.code_formula_predictor import (
- CodeFormulaPredictor,
+ self.device = decide_device(
+ accelerator_options.device,
+ supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
  )

  if artifacts_path is None:
@@ -106,11 +105,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
  else:
  artifacts_path = artifacts_path / self._model_repo_folder

- self.code_formula_model = CodeFormulaPredictor(
- artifacts_path=str(artifacts_path),
- device=device,
- num_threads=accelerator_options.num_threads,
+ self._processor = AutoProcessor.from_pretrained(
+ artifacts_path,
+ )
+ self._model_max_length = self._processor.tokenizer.model_max_length
+ self._model = AutoModelForImageTextToText.from_pretrained(
+ artifacts_path, device_map=self.device
  )
+ self._model.eval()

  @staticmethod
  def download_models(
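The predictor wrapper from docling-ibm-models is replaced by plain Hugging Face classes here. A minimal standalone sketch of the equivalent loading step, assuming a local snapshot of the CodeFormulaV2 weights (the path and device are illustrative):

```python
from transformers import AutoModelForImageTextToText, AutoProcessor

# Illustrative path; inside docling this is `artifacts_path / "ds4sd--CodeFormulaV2"`.
artifacts_path = "./models/ds4sd--CodeFormulaV2"

processor = AutoProcessor.from_pretrained(artifacts_path)
model = AutoModelForImageTextToText.from_pretrained(artifacts_path, device_map="cpu")
model.eval()  # inference only, as in the hunk above

# The tokenizer's model_max_length later caps prompt length + generated tokens.
print(processor.tokenizer.model_max_length)
```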
@@ -119,8 +121,8 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
  progress: bool = False,
  ) -> Path:
  return download_hf_model(
- repo_id="ds4sd/CodeFormula",
- revision="v1.0.2",
+ repo_id="ds4sd/CodeFormulaV2",
+ revision="main",
  local_dir=local_dir,
  force=force,
  progress=progress,
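`download_hf_model` is docling's own helper; for orientation only, fetching the same snapshot directly would look roughly like the sketch below, assuming the helper resolves to a huggingface_hub snapshot download (the local directory is made up):

```python
from huggingface_hub import snapshot_download

# Hypothetical direct equivalent of the download call above.
local_path = snapshot_download(
    repo_id="ds4sd/CodeFormulaV2",
    revision="main",
    local_dir="./models/ds4sd--CodeFormulaV2",
)
print(local_path)
```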
@@ -172,7 +174,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
  - The second element is the extracted language if a match is found;
  otherwise, `None`.
  """
- pattern = r"^<_([^_>]+)_>\s(.*)"
+ pattern = r"^<_([^_>]+)_>\s*(.*)"
  match = re.match(pattern, input_string, flags=re.DOTALL)
  if match:
  language = str(match.group(1)) # the captured programming language
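The only change to the pattern is `\s` becoming `\s*`: the old regex required at least one whitespace character between the language tag and the payload, while the new one also accepts outputs where the content follows the tag directly. A small illustration with a made-up model output:

```python
import re

old_pattern = r"^<_([^_>]+)_>\s(.*)"
new_pattern = r"^<_([^_>]+)_>\s*(.*)"

sample = "<_Python_>print('hi')"  # fabricated output with no space after the tag

print(re.match(old_pattern, sample, flags=re.DOTALL))           # None
print(re.match(new_pattern, sample, flags=re.DOTALL).groups())  # ('Python', "print('hi')")
```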
@@ -203,81 +205,74 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
  except ValueError:
  return CodeLanguageLabel.UNKNOWN

- def _get_most_frequent_edge_color(self, pil_img: Image.Image):
+ def _get_prompt(self, label: str) -> str:
  """
- Compute the most frequent color along the outer edges of a PIL image.
+ Constructs the prompt for the model based on the input label.

  Parameters
  ----------
- pil_img : Image.Image
- A PIL Image in any mode (L, RGB, RGBA, etc.).
+ label : str
+ The type of input, either 'code' or 'formula'.

  Returns
  -------
- (int) or (tuple): The most common edge color as a scalar (for grayscale) or
- tuple (for RGB/RGBA).
+ str
+ The constructed prompt including necessary tokens and query.
+
+ Raises
+ ------
+ NotImplementedError
+ If the label is not 'code' or 'formula'.
  """
- # Convert to NumPy array for easy pixel access
- img_np = np.array(pil_img)
+ if label == "code":
+ query = "<code>"
+ elif label == "formula":
+ query = "<formula>"
+ else:
+ raise NotImplementedError("Label must be either code or formula")

- if img_np.ndim == 2:
- # Grayscale-like image: shape (H, W)
- # Extract edges: top row, bottom row, left col, right col
- top = img_np[0, :] # shape (W,)
- bottom = img_np[-1, :] # shape (W,)
- left = img_np[:, 0] # shape (H,)
- right = img_np[:, -1] # shape (H,)
+ messages = [
+ {
+ "role": "user",
+ "content": [{"type": "image"}, {"type": "text", "text": query}],
+ },
+ ]

- # Concatenate all edges
- edges = np.concatenate([top, bottom, left, right])
+ prompt = self._processor.apply_chat_template(
+ messages, add_generation_prompt=True
+ )

- # Count frequencies
- freq = Counter(edges.tolist())
- most_common_value, _ = freq.most_common(1)[0]
- return int(most_common_value) # single channel color
+ return prompt

- else:
- # Color image: shape (H, W, C)
- top = img_np[0, :, :] # shape (W, C)
- bottom = img_np[-1, :, :] # shape (W, C)
- left = img_np[:, 0, :] # shape (H, C)
- right = img_np[:, -1, :] # shape (H, C)
-
- # Concatenate edges along first axis
- edges = np.concatenate([top, bottom, left, right], axis=0)
-
- # Convert each color to a tuple for counting
- edges_as_tuples = [tuple(pixel) for pixel in edges]
- freq = Counter(edges_as_tuples)
- most_common_value, _ = freq.most_common(1)[0]
- return most_common_value # e.g. (R, G, B) or (R, G, B, A)
-
- def _pad_with_most_frequent_edge_color(
- self, img: Union[Image.Image, np.ndarray], padding: Tuple[int, int, int, int]
- ):
+ def _post_process(self, texts: list[str]) -> list[str]:
  """
- Pads an image (PIL or NumPy array) using the most frequent edge color.
+ Processes a list of text strings by truncating at '<end_of_utterance>' and
+ removing a predefined set of unwanted substrings.

  Parameters
  ----------
- img : Union[Image.Image, np.ndarray]
- The original image.
- padding : tuple
- Padding (left, top, right, bottom) in pixels.
+ texts : list[str]
+ A list of strings to be post-processed.

  Returns
  -------
- Image.Image: A new PIL image with the specified padding.
+ list[str]
+ A list of cleaned strings with specified substrings removed and truncated at
+ '<end_of_utterance>' if present.
  """
- if isinstance(img, np.ndarray):
- pil_img = Image.fromarray(img)
- else:
- pil_img = img
+ to_remove = ["</code>", "</formula>", "<loc_0><loc_0><loc_500><loc_500>"]

- most_freq_color = self._get_most_frequent_edge_color(pil_img)
+ def clean_text(text: str) -> str:
+ idx = text.find("<end_of_utterance>")
+ if idx != -1:
+ text = text[:idx]

- padded_img = ImageOps.expand(pil_img, border=padding, fill=most_freq_color)
- return padded_img
+ for token in to_remove:
+ if token in text:
+ text = text.replace(token, "")
+ return text.lstrip()
+
+ return [clean_text(t) for t in texts]

  def __call__(
  self,
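`_get_prompt` wraps a `<code>` or `<formula>` query plus the image placeholder in the processor's chat template, and `_post_process` trims the raw generations. A standalone re-implementation of the clean-up step on a fabricated output string, to make the effect concrete:

```python
# Same logic as _post_process above, outside the class; the raw string is made up.
to_remove = ["</code>", "</formula>", "<loc_0><loc_0><loc_500><loc_500>"]

def clean_text(text: str) -> str:
    idx = text.find("<end_of_utterance>")
    if idx != -1:
        text = text[:idx]  # cut everything from the end-of-utterance marker onward
    for token in to_remove:
        text = text.replace(token, "")  # no-op when the token is absent
    return text.lstrip()

raw = "<loc_0><loc_0><loc_500><loc_500><_Python_> x = 1</code><end_of_utterance>"
print(clean_text(raw))  # "<_Python_> x = 1"
```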
@@ -308,14 +303,30 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
  images: List[Union[Image.Image, np.ndarray]] = []
  elements: List[TextItem] = []
  for el in element_batch:
- assert isinstance(el.item, TextItem)
- elements.append(el.item)
- labels.append(el.item.label)
- images.append(
- self._pad_with_most_frequent_edge_color(el.image, (20, 10, 20, 10))
- )
+ elements.append(el.item) # type: ignore[arg-type]
+ labels.append(el.item.label) # type: ignore[attr-defined]
+ images.append(el.image)
+
+ prompts = [self._get_prompt(label) for label in labels]
+ inputs = self._processor(
+ text=prompts,
+ images=images,
+ return_tensors="pt",
+ )
+ inputs = inputs.to(self.device)

- outputs = self.code_formula_model.predict(images, labels)
+ gen_kwargs = dict(
+ max_new_tokens=self._model_max_length - inputs.input_ids.shape[1],
+ use_cache=True,
+ do_sample=False,
+ )
+
+ generated_ids = self._model.generate(**inputs, **gen_kwargs)
+
+ outputs = self._processor.batch_decode(
+ generated_ids[:, inputs.input_ids.shape[1] :], skip_special_tokens=False
+ )
+ outputs = self._post_process(outputs)

  for item, output in zip(elements, outputs):
  if isinstance(item, CodeItem):
@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder

  class PagePreprocessingOptions(BaseModel):
  images_scale: Optional[float]
+ skip_cell_extraction: bool = (
+ False # Skip text cell extraction for VLM-only processing
+ )


  class PagePreprocessingModel(BasePageModel):
@@ -41,7 +44,8 @@ class PagePreprocessingModel(BasePageModel):
  else:
  with TimeRecorder(conv_res, "page_parse"):
  page = self._populate_page_images(page)
- page = self._parse_page_cells(conv_res, page)
+ if not self.options.skip_cell_extraction:
+ page = self._parse_page_cells(conv_res, page)
  yield page

  # Generate the page image and store it in the page object
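The new `skip_cell_extraction` flag lets a VLM-only pipeline keep rendering page images while skipping programmatic text-cell parsing. A minimal sketch of the option object itself (the import path is assumed from the module this hunk belongs to, and the values are illustrative):

```python
from docling.models.page_preprocessing_model import PagePreprocessingOptions

# With the flag set, _parse_page_cells is never called during preprocessing.
options = PagePreprocessingOptions(images_scale=2.0, skip_cell_extraction=True)
print(options.skip_cell_extraction)  # True
```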
@@ -4,6 +4,7 @@ from pathlib import Path
  from typing import Optional, Type, Union

  from PIL import Image
+ from transformers import AutoModelForImageTextToText

  from docling.datamodel.accelerator_options import AcceleratorOptions
  from docling.datamodel.pipeline_options import (
@@ -63,7 +64,7 @@ class PictureDescriptionVlmModel(
  # Initialize processor and model
  with _model_init_lock:
  self.processor = AutoProcessor.from_pretrained(artifacts_path)
- self.model = AutoModelForVision2Seq.from_pretrained(
+ self.model = AutoModelForImageTextToText.from_pretrained(
  artifacts_path,
  device_map=self.device,
  torch_dtype=torch.bfloat16,
@@ -71,9 +72,10 @@ class PictureDescriptionVlmModel(
  "flash_attention_2"
  if self.device.startswith("cuda")
  and accelerator_options.cuda_use_flash_attention2
- else "eager"
+ else "sdpa"
  ),
  )
+ self.model = torch.compile(self.model) # type: ignore

  self.provenance = f"{self.options.repo_id}"

@@ -320,6 +320,8 @@ class TesseractOcrCliModel(BaseOcrModel):


  def _parse_orientation(df_osd: pd.DataFrame) -> int:
- orientations = df_osd.loc[df_osd["key"] == "Orientation in degrees"].value.tolist()
- orientation = parse_tesseract_orientation(orientations[0].strip())
+ # For strictly optimal performance with invariant dataframe format:
+ mask = df_osd["key"].to_numpy() == "Orientation in degrees"
+ orientation_val = df_osd["value"].to_numpy()[mask][0]
+ orientation = parse_tesseract_orientation(orientation_val.strip())
  return orientation
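The rewrite swaps the label-based `.loc` lookup for a boolean mask over the raw NumPy arrays, which is cheaper on this per-page hot path while selecting the same cell. A small self-contained check with a fabricated OSD dataframe:

```python
import pandas as pd

# Fabricated Tesseract OSD output in the key/value layout the parser expects.
df_osd = pd.DataFrame(
    {
        "key": ["Page number", "Orientation in degrees", "Rotate"],
        "value": ["0", "90", "270"],
    }
)

# Old form: pandas label-based indexing.
old_val = df_osd.loc[df_osd["key"] == "Orientation in degrees"].value.tolist()[0]

# New form: boolean mask over plain NumPy arrays.
mask = df_osd["key"].to_numpy() == "Orientation in degrees"
new_val = df_osd["value"].to_numpy()[mask][0]

assert old_val == new_val == "90"
```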
@@ -0,0 +1 @@
+
@@ -3,7 +3,11 @@ import logging
  import time
  from collections.abc import Iterable
  from pathlib import Path
- from typing import Any, Optional
+ from typing import Any, Optional, Union
+
+ import numpy as np
+ from PIL.Image import Image
+ from transformers import StoppingCriteriaList, StopStringCriteria

  from docling.datamodel.accelerator_options import (
  AcceleratorOptions,
@@ -15,7 +19,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
  TransformersModelType,
  TransformersPromptStyle,
  )
- from docling.models.base_model import BasePageModel
+ from docling.models.base_model import BaseVlmPageModel
  from docling.models.utils.hf_model_download import (
  HuggingFaceModelDownloadMixin,
  )
@@ -25,7 +29,7 @@ from docling.utils.profiling import TimeRecorder
  _log = logging.getLogger(__name__)


- class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
+ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
  def __init__(
  self,
  enabled: bool,
@@ -103,6 +107,8 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
  artifacts_path,
  trust_remote_code=vlm_options.trust_remote_code,
  )
+ self.processor.tokenizer.padding_side = "left"
+
  self.vlm_model = model_cls.from_pretrained(
  artifacts_path,
  device_map=self.device,
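`padding_side = "left"` matters once prompts of different lengths are batched for decoder-only generation: the pad tokens have to sit before the prompt so every row ends on a real token and `generated_ids[:, input_len:]` slices off only newly generated tokens. A toy illustration with fabricated token ids (0 stands in for the pad id):

```python
# Two fabricated prompts of different lengths; 0 stands in for the pad token id.
prompts = [[7, 8, 9, 10], [5, 6]]
width = max(len(p) for p in prompts)

right_padded = [p + [0] * (width - len(p)) for p in prompts]  # [[7, 8, 9, 10], [5, 6, 0, 0]]
left_padded = [[0] * (width - len(p)) + p for p in prompts]   # [[7, 8, 9, 10], [0, 0, 5, 6]]

# With left padding, generation continues from the last real token in every row,
# and slicing at the common prompt width keeps only the new tokens.
print(right_padded)
print(left_padded)
```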
@@ -111,10 +117,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
  "flash_attention_2"
  if self.device.startswith("cuda")
  and accelerator_options.cuda_use_flash_attention2
- else "eager"
+ else "sdpa"
  ),
  trust_remote_code=vlm_options.trust_remote_code,
  )
+ self.vlm_model = torch.compile(self.vlm_model) # type: ignore

  # Load generation config
  self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
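Two performance-oriented changes appear in both model loaders: the non-flash attention fallback moves from "eager" to PyTorch's SDPA kernels, and the loaded module is wrapped with `torch.compile`. A condensed, hedged sketch of that selection logic with the configuration flags replaced by plain variables (the repo id is illustrative):

```python
import torch
from transformers import AutoModelForImageTextToText

device = "cuda:0"                  # stand-in for self.device
cuda_use_flash_attention2 = False  # stand-in for accelerator_options.cuda_use_flash_attention2

attn_implementation = (
    "flash_attention_2"
    if device.startswith("cuda") and cuda_use_flash_attention2
    else "sdpa"  # previously "eager"
)

model = AutoModelForImageTextToText.from_pretrained(
    "HuggingFaceTB/SmolVLM-256M-Instruct",  # illustrative repo id
    device_map=device,
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_implementation,
)
model = torch.compile(model)  # compilation happens lazily on the first forward pass
```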
@@ -122,93 +129,186 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
  def __call__(
  self, conv_res: ConversionResult, page_batch: Iterable[Page]
  ) -> Iterable[Page]:
- for page in page_batch:
+ page_list = list(page_batch)
+ if not page_list:
+ return
+
+ valid_pages = []
+ invalid_pages = []
+
+ for page in page_list:
  assert page._backend is not None
  if not page._backend.is_valid():
- yield page
+ invalid_pages.append(page)
  else:
- with TimeRecorder(conv_res, "vlm"):
- assert page.size is not None
+ valid_pages.append(page)

+ # Process valid pages in batch
+ if valid_pages:
+ with TimeRecorder(conv_res, "vlm"):
+ # Prepare images and prompts for batch processing
+ images = []
+ user_prompts = []
+ pages_with_images = []
+
+ for page in valid_pages:
+ assert page.size is not None
  hi_res_image = page.get_image(
  scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
  )

- # Define prompt structure
- user_prompt = self.vlm_options.build_prompt(page.parsed_page)
- prompt = self.formulate_prompt(user_prompt)
-
- inputs = self.processor(
- text=prompt, images=[hi_res_image], return_tensors="pt"
- ).to(self.device)
-
- start_time = time.time()
- # Call model to generate:
- generated_ids = self.vlm_model.generate(
- **inputs,
- max_new_tokens=self.max_new_tokens,
- use_cache=self.use_cache,
- temperature=self.temperature,
- generation_config=self.generation_config,
- **self.vlm_options.extra_generation_config,
- )
+ # Only process pages with valid images
+ if hi_res_image is not None:
+ images.append(hi_res_image)

- generation_time = time.time() - start_time
- generated_texts = self.processor.batch_decode(
- generated_ids[:, inputs["input_ids"].shape[1] :],
- skip_special_tokens=False,
- )[0]
+ # Define prompt structure
+ user_prompt = self.vlm_options.build_prompt(page.parsed_page)

- num_tokens = len(generated_ids[0])
- _log.debug(
- f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
- )
- generated_texts = self.vlm_options.decode_response(generated_texts)
- page.predictions.vlm_response = VlmPrediction(
- text=generated_texts,
- generation_time=generation_time,
+ user_prompts.append(user_prompt)
+ pages_with_images.append(page)
+
+ # Use process_images for the actual inference
+ if images: # Only if we have valid images
+ predictions = list(self.process_images(images, user_prompts))
+
+ # Attach results to pages
+ for page, prediction in zip(pages_with_images, predictions):
+ page.predictions.vlm_response = prediction
+
+ # Yield all pages (valid and invalid)
+ for page in invalid_pages:
+ yield page
+ for page in valid_pages:
+ yield page
+
+ def process_images(
+ self,
+ image_batch: Iterable[Union[Image, np.ndarray]],
+ prompt: Union[str, list[str]],
+ ) -> Iterable[VlmPrediction]:
+ """
+ Batched inference for Hugging Face Image-Text-to-Text VLMs (e.g., SmolDocling / SmolVLM).
+ - Lets the processor handle all padding & batching for text+images.
+ - Trims generated sequences per row using attention_mask (no pad-id fallbacks).
+ - Keeps your formulate_prompt() exactly as-is.
+ """
+ import numpy as np
+ import torch
+ from PIL import Image as PILImage
+
+ # -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
+ pil_images: list[Image] = []
+ for img in image_batch:
+ if isinstance(img, np.ndarray):
+ if img.ndim == 3 and img.shape[2] in (3, 4):
+ pil_img = PILImage.fromarray(img.astype(np.uint8))
+ elif img.ndim == 2:
+ pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+ else:
+ raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+ else:
+ pil_img = img
+ if pil_img.mode != "RGB":
+ pil_img = pil_img.convert("RGB")
+ pil_images.append(pil_img)
+
+ if not pil_images:
+ return
+
+ # -- Normalize prompts (1 per image)
+ if isinstance(prompt, str):
+ user_prompts = [prompt] * len(pil_images)
+ else:
+ if len(prompt) != len(pil_images):
+ raise ValueError(
+ f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
+ )
+ user_prompts = prompt
+
+ # Use your prompt formatter verbatim
+ if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
+ inputs = self.processor(
+ pil_images,
+ return_tensors="pt",
+ padding=True, # pad across batch for both text and vision
+ **self.vlm_options.extra_processor_kwargs,
+ )
+ else:
+ prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
+
+ # -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
+ inputs = self.processor(
+ text=prompts,
+ images=pil_images,
+ return_tensors="pt",
+ padding=True, # pad across batch for both text and vision
+ **self.vlm_options.extra_processor_kwargs,
+ )
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+ # -- Optional stopping criteria
+ stopping_criteria = None
+ if self.vlm_options.stop_strings:
+ stopping_criteria = StoppingCriteriaList(
+ [
+ StopStringCriteria(
+ stop_strings=self.vlm_options.stop_strings,
+ tokenizer=self.processor.tokenizer,
  )
+ ]
+ )

- yield page
-
- def formulate_prompt(self, user_prompt: str) -> str:
- """Formulate a prompt for the VLM."""
-
- if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
- return user_prompt
-
- elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
- _log.debug("Using specialized prompt for Phi-4")
- # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
-
- user_prompt = "<|user|>"
- assistant_prompt = "<|assistant|>"
- prompt_suffix = "<|end|>"
-
- prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
- _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
-
- return prompt
-
- elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
- messages = [
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": "This is a page from a document.",
- },
- {"type": "image"},
- {"type": "text", "text": user_prompt},
- ],
- }
- ]
- prompt = self.processor.apply_chat_template(
- messages, add_generation_prompt=False
+ # -- Generate (Image-Text-to-Text class expects these inputs from processor)
+ gen_kwargs = {
+ **inputs,
+ "max_new_tokens": self.max_new_tokens,
+ "use_cache": self.use_cache,
+ "generation_config": self.generation_config,
+ **self.vlm_options.extra_generation_config,
+ }
+ if self.temperature > 0:
+ gen_kwargs["do_sample"] = True
+ gen_kwargs["temperature"] = self.temperature
+ else:
+ gen_kwargs["do_sample"] = False
+
+ if stopping_criteria is not None:
+ gen_kwargs["stopping_criteria"] = stopping_criteria
+
+ start_time = time.time()
+ with torch.inference_mode():
+ generated_ids = self.vlm_model.generate(**gen_kwargs)
+ generation_time = time.time() - start_time
+
+ input_len = inputs["input_ids"].shape[1] # common right-aligned prompt length
+ trimmed_sequences = generated_ids[:, input_len:] # only newly generated tokens
+
+ # -- Decode with the processor/tokenizer (skip specials, keep DocTags as text)
+ decode_fn = getattr(self.processor, "batch_decode", None)
+ if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None:
+ decode_fn = self.processor.tokenizer.batch_decode
+ if decode_fn is None:
+ raise RuntimeError(
+ "Neither processor.batch_decode nor tokenizer.batch_decode is available."
  )
- return prompt

- raise RuntimeError(
- f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
+ decoded_texts: list[str] = decode_fn(
+ trimmed_sequences, skip_special_tokens=False
  )
+
+ # -- Clip off pad tokens from decoded texts
+ pad_token = self.processor.tokenizer.pad_token
+ if pad_token:
+ decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
+
+ # -- Optional logging
+ if generated_ids.shape[0] > 0:
+ _log.debug(
+ f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+ f"for batch size {generated_ids.shape[0]}."
+ )
+
+ for text in decoded_texts:
+ # Apply decode_response to the output text
+ decoded_text = self.vlm_options.decode_response(text)
+ yield VlmPrediction(text=decoded_text, generation_time=generation_time)
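The page loop above now delegates all inference to `process_images`, which can also be exercised on its own. A rough usage sketch, assuming `vlm` is an already-initialized HuggingFaceTransformersVlmModel and that the file names and prompt text are made up:

```python
from PIL import Image

pages = [Image.open("page_001.png"), Image.open("page_002.png")]  # illustrative inputs
prompts = ["Convert this page to docling.", "Convert this page to docling."]

for prediction in vlm.process_images(pages, prompts):
    # VlmPrediction carries the decoded text and the wall-clock generation time.
    print(f"{prediction.generation_time:.2f}s", prediction.text[:80])
```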