docling 2.46.0__py3-none-any.whl → 2.47.0__py3-none-any.whl

@@ -12,6 +12,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
     ResponseFormat,
     TransformersModelType,
+    TransformersPromptStyle,
 )
 
 _log = logging.getLogger(__name__)
@@ -26,6 +27,7 @@ SMOLDOCLING_MLX = InlineVlmOptions(
     supported_devices=[AcceleratorDevice.MPS],
     scale=2.0,
     temperature=0.0,
+    stop_strings=["</doctag>", "<end_of_utterance>"],
 )
 
 SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
@@ -33,16 +35,74 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
     prompt="Convert this page to docling.",
     response_format=ResponseFormat.DOCTAGS,
     inference_framework=InferenceFramework.TRANSFORMERS,
-    transformers_model_type=TransformersModelType.AUTOMODEL_VISION2SEQ,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
     supported_devices=[
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
+    ],
+    torch_dtype="bfloat16",
+    scale=2.0,
+    temperature=0.0,
+    stop_strings=["</doctag>", "<end_of_utterance>"],
+)
+
+SMOLDOCLING_VLLM = InlineVlmOptions(
+    repo_id="ds4sd/SmolDocling-256M-preview",
+    prompt="Convert this page to docling.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.VLLM,
+    supported_devices=[
+        AcceleratorDevice.CUDA,
+    ],
+    scale=2.0,
+    temperature=0.0,
+    stop_strings=["</doctag>", "<end_of_utterance>"],
+)
+
+# SmolVLM-256M-Instruct
+SMOLVLM256_TRANSFORMERS = InlineVlmOptions(
+    repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",
+    prompt="Transcribe this image to plain text.",
+    response_format=ResponseFormat.PLAINTEXT,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    supported_devices=[
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+        # AcceleratorDevice.MPS,
+    ],
+    torch_dtype="bfloat16",
+    scale=2.0,
+    temperature=0.0,
+)
+
+# SmolVLM2-2.2b-Instruct
+SMOLVLM256_MLX = InlineVlmOptions(
+    repo_id="moot20/SmolVLM-256M-Instruct-MLX",
+    prompt="Extract the text.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.MLX,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    supported_devices=[
         AcceleratorDevice.MPS,
     ],
     scale=2.0,
     temperature=0.0,
 )
 
+SMOLVLM256_VLLM = InlineVlmOptions(
+    repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",
+    prompt="Transcribe this image to plain text.",
+    response_format=ResponseFormat.PLAINTEXT,
+    inference_framework=InferenceFramework.VLLM,
+    supported_devices=[
+        AcceleratorDevice.CUDA,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
+
 # GraniteVision
 GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
     repo_id="ibm-granite/granite-vision-3.2-2b",
@@ -59,6 +119,18 @@ GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
     temperature=0.0,
 )
 
+GRANITE_VISION_VLLM = InlineVlmOptions(
+    repo_id="ibm-granite/granite-vision-3.2-2b",
+    prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
+    response_format=ResponseFormat.MARKDOWN,
+    inference_framework=InferenceFramework.VLLM,
+    supported_devices=[
+        AcceleratorDevice.CUDA,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
 GRANITE_VISION_OLLAMA = ApiVlmOptions(
     url=AnyUrl("http://localhost:11434/v1/chat/completions"),
     params={"model": "granite3.2-vision:2b"},
@@ -116,6 +188,26 @@ QWEN25_VL_3B_MLX = InlineVlmOptions(
     temperature=0.0,
 )
 
+# GoT 2.0
+GOT2_TRANSFORMERS = InlineVlmOptions(
+    repo_id="stepfun-ai/GOT-OCR-2.0-hf",
+    prompt="",
+    response_format=ResponseFormat.MARKDOWN,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_prompt_style=TransformersPromptStyle.NONE,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    supported_devices=[
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+        # AcceleratorDevice.MPS,
+    ],
+    scale=2.0,
+    temperature=0.0,
+    stop_strings=["<|im_end|>"],
+    extra_processor_kwargs={"format": True},
+)
+
+
 # Gemma-3
 GEMMA3_12B_MLX = InlineVlmOptions(
     repo_id="mlx-community/gemma-3-12b-it-bf16",
@@ -137,8 +229,29 @@ GEMMA3_27B_MLX = InlineVlmOptions(
     temperature=0.0,
 )
 
+# Dolphin
+
+DOLPHIN_TRANSFORMERS = InlineVlmOptions(
+    repo_id="ByteDance/Dolphin",
+    prompt="<s>Read text in the image. <Answer/>",
+    response_format=ResponseFormat.MARKDOWN,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    transformers_prompt_style=TransformersPromptStyle.RAW,
+    supported_devices=[
+        AcceleratorDevice.CUDA,
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.MPS,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
 
 class VlmModelType(str, Enum):
     SMOLDOCLING = "smoldocling"
+    SMOLDOCLING_VLLM = "smoldocling_vllm"
     GRANITE_VISION = "granite_vision"
+    GRANITE_VISION_VLLM = "granite_vision_vllm"
     GRANITE_VISION_OLLAMA = "granite_vision_ollama"
+    GOT_OCR_2 = "got_ocr_2"
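
The new vLLM presets and VlmModelType members are selected like the existing ones. A minimal sketch, assuming the constants above are exposed as docling.datamodel.vlm_model_specs and the documented VlmPipeline wiring (VlmPipelineOptions, PdfFormatOption), neither of which is shown in this diff:

```python
from docling.datamodel import vlm_model_specs  # assumed module path for the specs above
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Pick one of the presets added in 2.47.0, e.g. SmolDocling served through vLLM
# (CUDA-only per its supported_devices above).
pipeline_options = VlmPipelineOptions(vlm_options=vlm_model_specs.SMOLDOCLING_VLLM)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
result = converter.convert("example.pdf")
print(result.document.export_to_markdown())
```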
@@ -1,13 +1,24 @@
+import logging
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Generic, Optional, Protocol, Type
+from typing import Any, Generic, Optional, Protocol, Type, Union
 
+import numpy as np
 from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
+from PIL.Image import Image
 from typing_extensions import TypeVar
 
-from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
+from docling.datamodel.base_models import (
+    ItemAndImageEnrichmentElement,
+    Page,
+    VlmPrediction,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import BaseOptions
+from docling.datamodel.pipeline_options_vlm_model import (
+    InlineVlmOptions,
+    TransformersPromptStyle,
+)
 from docling.datamodel.settings import settings
 
 
@@ -26,6 +37,88 @@ class BasePageModel(ABC):
         pass
 
 
+class BaseVlmModel(ABC):
+    """Base class for Vision-Language Models that adds image processing capability."""
+
+    @abstractmethod
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Either:
+                - str: Single prompt used for all images
+                - list[str]: List of prompts (one per image, must match image count)
+
+        Raises:
+            ValueError: If prompt list length doesn't match image count.
+        """
+
+
+class BaseVlmPageModel(BasePageModel, BaseVlmModel):
+    """Base implementation for VLM models that inherit from BasePageModel.
+
+    Provides a default __call__ implementation that extracts images from pages,
+    processes them using process_images, and attaches results back to pages.
+    """
+
+    # Type annotations for attributes that subclasses must initialize
+    vlm_options: InlineVlmOptions
+    processor: Any
+
+    @abstractmethod
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        """Extract images from pages, process them, and attach results back."""
+
+    def formulate_prompt(self, user_prompt: str) -> str:
+        """Formulate a prompt for the VLM."""
+        _log = logging.getLogger(__name__)
+
+        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
+            return user_prompt
+
+        elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
+            _log.debug("Using specialized prompt for Phi-4")
+            # Note: This might need adjustment for VLLM vs transformers
+            user_prompt_prefix = "<|user|>"
+            assistant_prompt = "<|assistant|>"
+            prompt_suffix = "<|end|>"
+
+            prompt = f"{user_prompt_prefix}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
+            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
+
+            return prompt
+
+        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "This is a page from a document.",
+                        },
+                        {"type": "image"},
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+            ]
+            prompt = self.processor.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            return prompt
+
+        raise RuntimeError(
+            f"Unknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
+        )
+
+
 EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
 
 
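
To illustrate the process_images contract introduced above (a single prompt is broadcast to every image, a list must match the images one to one), a toy subclass that echoes prompts instead of running a model; the import paths are the ones appearing in this diff:

```python
from collections.abc import Iterable
from typing import Union

import numpy as np
from PIL.Image import Image

from docling.datamodel.base_models import VlmPrediction
from docling.models.base_model import BaseVlmModel


class EchoVlmModel(BaseVlmModel):
    """Toy model: yields the prompt text instead of running inference."""

    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Union[str, list[str]],
    ) -> Iterable[VlmPrediction]:
        images = list(image_batch)
        # A single string is reused for every image; a list must match 1:1.
        prompts = [prompt] * len(images) if isinstance(prompt, str) else list(prompt)
        if len(prompts) != len(images):
            raise ValueError("prompt list length must match image count")
        for p in prompts:
            yield VlmPrediction(text=f"[echo] {p}", generation_time=0.0)
```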
@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder
 
 class PagePreprocessingOptions(BaseModel):
     images_scale: Optional[float]
+    skip_cell_extraction: bool = (
+        False  # Skip text cell extraction for VLM-only processing
+    )
 
 
 class PagePreprocessingModel(BasePageModel):
@@ -41,7 +44,8 @@ class PagePreprocessingModel(BasePageModel):
             else:
                 with TimeRecorder(conv_res, "page_parse"):
                     page = self._populate_page_images(page)
-                    page = self._parse_page_cells(conv_res, page)
+                    if not self.options.skip_cell_extraction:
+                        page = self._parse_page_cells(conv_res, page)
                 yield page
 
     # Generate the page image and store it in the page object
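
A minimal sketch of the new preprocessing flag; the import path for PagePreprocessingOptions is an assumption, since file names are not shown in this diff:

```python
# Assumed import path; PagePreprocessingOptions is the BaseModel shown above.
from docling.models.page_preprocessing_model import PagePreprocessingOptions

# With skip_cell_extraction=True the preprocessing step only renders page images
# and skips _parse_page_cells, which suits VLM-only pipelines.
options = PagePreprocessingOptions(images_scale=2.0, skip_cell_extraction=True)
```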
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import Optional, Type, Union
 
 from PIL import Image
+from transformers import AutoModelForImageTextToText
 
 from docling.datamodel.accelerator_options import AcceleratorOptions
 from docling.datamodel.pipeline_options import (
@@ -63,7 +64,7 @@ class PictureDescriptionVlmModel(
         # Initialize processor and model
         with _model_init_lock:
             self.processor = AutoProcessor.from_pretrained(artifacts_path)
-            self.model = AutoModelForVision2Seq.from_pretrained(
+            self.model = AutoModelForImageTextToText.from_pretrained(
                 artifacts_path,
                 device_map=self.device,
                 torch_dtype=torch.bfloat16,
@@ -71,9 +72,10 @@ class PictureDescriptionVlmModel(
                     "flash_attention_2"
                     if self.device.startswith("cuda")
                     and accelerator_options.cuda_use_flash_attention2
-                    else "eager"
+                    else "sdpa"
                 ),
             )
+            self.model = torch.compile(self.model)  # type: ignore
 
         self.provenance = f"{self.options.repo_id}"
 
@@ -0,0 +1 @@
+
@@ -3,7 +3,11 @@ import logging
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Union
+
+import numpy as np
+from PIL.Image import Image
+from transformers import StoppingCriteriaList, StopStringCriteria
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,7 +19,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersModelType,
     TransformersPromptStyle,
 )
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -25,7 +29,7 @@ from docling.utils.profiling import TimeRecorder
 _log = logging.getLogger(__name__)
 
 
-class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
+class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -103,6 +107,8 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
             artifacts_path,
             trust_remote_code=vlm_options.trust_remote_code,
         )
+        self.processor.tokenizer.padding_side = "left"
+
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
             device_map=self.device,
@@ -111,10 +117,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                 "flash_attention_2"
                 if self.device.startswith("cuda")
                 and accelerator_options.cuda_use_flash_attention2
-                else "eager"
+                else "sdpa"
             ),
             trust_remote_code=vlm_options.trust_remote_code,
         )
+        self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
 
         # Load generation config
         self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
@@ -122,93 +129,186 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
-        for page in page_batch:
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        valid_pages = []
+        invalid_pages = []
+
+        for page in page_list:
             assert page._backend is not None
             if not page._backend.is_valid():
-                yield page
+                invalid_pages.append(page)
             else:
-                with TimeRecorder(conv_res, "vlm"):
-                    assert page.size is not None
+                valid_pages.append(page)
 
+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, "vlm"):
+                # Prepare images and prompts for batch processing
+                images = []
+                user_prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
                     hi_res_image = page.get_image(
                         scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                     )
 
-                    # Define prompt structure
-                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)
-                    prompt = self.formulate_prompt(user_prompt)
-
-                    inputs = self.processor(
-                        text=prompt, images=[hi_res_image], return_tensors="pt"
-                    ).to(self.device)
-
-                    start_time = time.time()
-                    # Call model to generate:
-                    generated_ids = self.vlm_model.generate(
-                        **inputs,
-                        max_new_tokens=self.max_new_tokens,
-                        use_cache=self.use_cache,
-                        temperature=self.temperature,
-                        generation_config=self.generation_config,
-                        **self.vlm_options.extra_generation_config,
-                    )
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)
 
-                    generation_time = time.time() - start_time
-                    generated_texts = self.processor.batch_decode(
-                        generated_ids[:, inputs["input_ids"].shape[1] :],
-                        skip_special_tokens=False,
-                    )[0]
+                        # Define prompt structure
+                        user_prompt = self.vlm_options.build_prompt(page.parsed_page)
 
-                    num_tokens = len(generated_ids[0])
-                    _log.debug(
-                        f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
-                    )
-                    generated_texts = self.vlm_options.decode_response(generated_texts)
-                    page.predictions.vlm_response = VlmPrediction(
-                        text=generated_texts,
-                        generation_time=generation_time,
+                        user_prompts.append(user_prompt)
+                        pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    predictions = list(self.process_images(images, user_prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield all pages (valid and invalid)
+        for page in invalid_pages:
+            yield page
+        for page in valid_pages:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """
+        Batched inference for Hugging Face Image-Text-to-Text VLMs (e.g., SmolDocling / SmolVLM).
+        - Lets the processor handle all padding & batching for text+images.
+        - Trims generated sequences per row using attention_mask (no pad-id fallbacks).
+        - Keeps your formulate_prompt() exactly as-is.
+        """
+        import numpy as np
+        import torch
+        from PIL import Image as PILImage
+
+        # -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
+        pil_images: list[Image] = []
+        for img in image_batch:
+            if isinstance(img, np.ndarray):
+                if img.ndim == 3 and img.shape[2] in (3, 4):
+                    pil_img = PILImage.fromarray(img.astype(np.uint8))
+                elif img.ndim == 2:
+                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+            else:
+                pil_img = img
+            if pil_img.mode != "RGB":
+                pil_img = pil_img.convert("RGB")
+            pil_images.append(pil_img)
+
+        if not pil_images:
+            return
+
+        # -- Normalize prompts (1 per image)
+        if isinstance(prompt, str):
+            user_prompts = [prompt] * len(pil_images)
+        else:
+            if len(prompt) != len(pil_images):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
+                )
+            user_prompts = prompt
+
+        # Use your prompt formatter verbatim
+        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
+            inputs = self.processor(
+                pil_images,
+                return_tensors="pt",
+                padding=True,  # pad across batch for both text and vision
+                **self.vlm_options.extra_processor_kwargs,
+            )
+        else:
+            prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
+
+            # -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
+            inputs = self.processor(
+                text=prompts,
+                images=pil_images,
+                return_tensors="pt",
+                padding=True,  # pad across batch for both text and vision
+                **self.vlm_options.extra_processor_kwargs,
+            )
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        # -- Optional stopping criteria
+        stopping_criteria = None
+        if self.vlm_options.stop_strings:
+            stopping_criteria = StoppingCriteriaList(
+                [
+                    StopStringCriteria(
+                        stop_strings=self.vlm_options.stop_strings,
+                        tokenizer=self.processor.tokenizer,
                     )
+                ]
+            )
 
-                yield page
-
-    def formulate_prompt(self, user_prompt: str) -> str:
-        """Formulate a prompt for the VLM."""
-
-        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
-            return user_prompt
-
-        elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
-            _log.debug("Using specialized prompt for Phi-4")
-            # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally
-
-            user_prompt = "<|user|>"
-            assistant_prompt = "<|assistant|>"
-            prompt_suffix = "<|end|>"
-
-            prompt = f"{user_prompt}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
-            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
-
-            return prompt
-
-        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
-            messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "text",
-                            "text": "This is a page from a document.",
-                        },
-                        {"type": "image"},
-                        {"type": "text", "text": user_prompt},
-                    ],
-                }
-            ]
-            prompt = self.processor.apply_chat_template(
-                messages, add_generation_prompt=False
+        # -- Generate (Image-Text-to-Text class expects these inputs from processor)
+        gen_kwargs = {
+            **inputs,
+            "max_new_tokens": self.max_new_tokens,
+            "use_cache": self.use_cache,
+            "generation_config": self.generation_config,
+            **self.vlm_options.extra_generation_config,
+        }
+        if self.temperature > 0:
+            gen_kwargs["do_sample"] = True
+            gen_kwargs["temperature"] = self.temperature
+        else:
+            gen_kwargs["do_sample"] = False
+
+        if stopping_criteria is not None:
+            gen_kwargs["stopping_criteria"] = stopping_criteria
+
+        start_time = time.time()
+        with torch.inference_mode():
+            generated_ids = self.vlm_model.generate(**gen_kwargs)
+        generation_time = time.time() - start_time
+
+        input_len = inputs["input_ids"].shape[1]  # common right-aligned prompt length
+        trimmed_sequences = generated_ids[:, input_len:]  # only newly generated tokens
+
+        # -- Decode with the processor/tokenizer (skip specials, keep DocTags as text)
+        decode_fn = getattr(self.processor, "batch_decode", None)
+        if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None:
+            decode_fn = self.processor.tokenizer.batch_decode
+        if decode_fn is None:
+            raise RuntimeError(
+                "Neither processor.batch_decode nor tokenizer.batch_decode is available."
             )
-            return prompt
 
-        raise RuntimeError(
-            f"Uknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
+        decoded_texts: list[str] = decode_fn(
+            trimmed_sequences, skip_special_tokens=False
         )
+
+        # -- Clip off pad tokens from decoded texts
+        pad_token = self.processor.tokenizer.pad_token
+        if pad_token:
+            decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
+
+        # -- Optional logging
+        if generated_ids.shape[0] > 0:
+            _log.debug(
+                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"for batch size {generated_ids.shape[0]}."
+            )
+
+        for text in decoded_texts:
+            # Apply decode_response to the output text
+            decoded_text = self.vlm_options.decode_response(text)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
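
The stop_strings fields added to the model specs above feed directly into the StoppingCriteriaList built in process_images. A standalone sketch of that wiring, assuming the SmolDocling repo ships a tokenizer loadable via AutoTokenizer and that network access is available:

```python
from transformers import AutoTokenizer, StoppingCriteriaList, StopStringCriteria

# repo_id and stop strings taken from the SMOLDOCLING_* specs in this diff
tokenizer = AutoTokenizer.from_pretrained("ds4sd/SmolDocling-256M-preview")
stopping_criteria = StoppingCriteriaList(
    [
        StopStringCriteria(
            stop_strings=["</doctag>", "<end_of_utterance>"],
            tokenizer=tokenizer,
        )
    ]
)
# Passed as generate(..., stopping_criteria=stopping_criteria), decoding halts
# as soon as either string appears in the generated text.
```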