docling 2.45.0__py3-none-any.whl → 2.47.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their respective public registries, and is provided for informational purposes only.
- docling/backend/docling_parse_v4_backend.py +61 -27
- docling/backend/html_backend.py +119 -17
- docling/backend/msword_backend.py +126 -16
- docling/cli/main.py +14 -0
- docling/cli/models.py +56 -0
- docling/datamodel/base_models.py +1 -1
- docling/datamodel/pipeline_options.py +4 -3
- docling/datamodel/pipeline_options_vlm_model.py +5 -0
- docling/datamodel/vlm_model_specs.py +114 -1
- docling/models/base_model.py +95 -2
- docling/models/code_formula_model.py +87 -76
- docling/models/page_preprocessing_model.py +5 -1
- docling/models/picture_description_vlm_model.py +4 -2
- docling/models/tesseract_ocr_cli_model.py +4 -2
- docling/models/vlm_models_inline/__init__.py +1 -0
- docling/models/vlm_models_inline/hf_transformers_model.py +179 -79
- docling/models/vlm_models_inline/mlx_model.py +179 -68
- docling/models/vlm_models_inline/vllm_model.py +235 -0
- docling/pipeline/base_pipeline.py +7 -1
- docling/pipeline/threaded_standard_pdf_pipeline.py +7 -5
- docling/pipeline/vlm_pipeline.py +14 -1
- docling/utils/layout_postprocessor.py +51 -43
- {docling-2.45.0.dist-info → docling-2.47.0.dist-info}/METADATA +3 -2
- {docling-2.45.0.dist-info → docling-2.47.0.dist-info}/RECORD +28 -27
- {docling-2.45.0.dist-info → docling-2.47.0.dist-info}/WHEEL +0 -0
- {docling-2.45.0.dist-info → docling-2.47.0.dist-info}/entry_points.txt +0 -0
- {docling-2.45.0.dist-info → docling-2.47.0.dist-info}/licenses/LICENSE +0 -0
- {docling-2.45.0.dist-info → docling-2.47.0.dist-info}/top_level.txt +0 -0
docling/models/code_formula_model.py
@@ -1,5 +1,4 @@
 import re
-from collections import Counter
 from collections.abc import Iterable
 from pathlib import Path
 from typing import List, Literal, Optional, Tuple, Union
@@ -13,10 +12,11 @@ from docling_core.types.doc import (
     TextItem,
 )
 from docling_core.types.doc.labels import CodeLanguageLabel
-from PIL import Image
+from PIL import Image
 from pydantic import BaseModel
+from transformers import AutoModelForImageTextToText, AutoProcessor

-from docling.datamodel.accelerator_options import AcceleratorOptions
+from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
 from docling.datamodel.base_models import ItemAndImageEnrichmentElement
 from docling.models.base_model import BaseItemAndImageEnrichmentModel
 from docling.models.utils.hf_model_download import download_hf_model
@@ -65,9 +65,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
         Processes the given batch of elements and enriches them with predictions.
     """

-    _model_repo_folder = "ds4sd--
+    _model_repo_folder = "ds4sd--CodeFormulaV2"
     elements_batch_size = 5
-    images_scale = 1.
+    images_scale = 1.67  # = 120 dpi, aligned with training data resolution
     expansion_factor = 0.18

     def __init__(
@@ -95,10 +95,9 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
         self.options = options

         if self.enabled:
-            device = decide_device(
-
-
-                CodeFormulaPredictor,
+            self.device = decide_device(
+                accelerator_options.device,
+                supported_devices=[AcceleratorDevice.CPU, AcceleratorDevice.CUDA],
             )

             if artifacts_path is None:
@@ -106,11 +105,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
             else:
                 artifacts_path = artifacts_path / self._model_repo_folder

-            self.
-                artifacts_path
-
-
+            self._processor = AutoProcessor.from_pretrained(
+                artifacts_path,
+            )
+            self._model_max_length = self._processor.tokenizer.model_max_length
+            self._model = AutoModelForImageTextToText.from_pretrained(
+                artifacts_path, device_map=self.device
             )
+            self._model.eval()

     @staticmethod
     def download_models(
@@ -119,8 +121,8 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
         progress: bool = False,
     ) -> Path:
         return download_hf_model(
-            repo_id="ds4sd/
-            revision="
+            repo_id="ds4sd/CodeFormulaV2",
+            revision="main",
             local_dir=local_dir,
             force=force,
             progress=progress,
@@ -172,7 +174,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
            - The second element is the extracted language if a match is found;
              otherwise, `None`.
         """
-        pattern = r"^<_([^_>]+)_>\s(.*)"
+        pattern = r"^<_([^_>]+)_>\s*(.*)"
         match = re.match(pattern, input_string, flags=re.DOTALL)
         if match:
             language = str(match.group(1))  # the captured programming language
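The only functional change in this hunk is the prompt-prefix regex: `\s` becomes `\s*`, so a language tag followed immediately by content (with no whitespace after the tag) now also matches. A minimal sketch of the effect, using made-up input strings:

    import re

    pattern = r"^<_([^_>]+)_>\s*(.*)"
    # With the old `\s`, the second string below would not match, because at least
    # one whitespace character was required after the language tag.
    for s in ("<_python_> print(1)", "<_python_>print(1)"):
        m = re.match(pattern, s, flags=re.DOTALL)
        print(m.group(1), m.group(2))  # -> "python print(1)" in both cases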
@@ -203,81 +205,74 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
         except ValueError:
             return CodeLanguageLabel.UNKNOWN

-    def
+    def _get_prompt(self, label: str) -> str:
         """
-
+        Constructs the prompt for the model based on the input label.

         Parameters
         ----------
-
-
+        label : str
+            The type of input, either 'code' or 'formula'.

         Returns
         -------
-
-
+        str
+            The constructed prompt including necessary tokens and query.
+
+        Raises
+        ------
+        NotImplementedError
+            If the label is not 'code' or 'formula'.
         """
-
-
+        if label == "code":
+            query = "<code>"
+        elif label == "formula":
+            query = "<formula>"
+        else:
+            raise NotImplementedError("Label must be either code or formula")

-
-
-
-
-
-
-            right = img_np[:, -1]  # shape (H,)
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "image"}, {"type": "text", "text": query}],
+            },
+        ]

-
-
+        prompt = self._processor.apply_chat_template(
+            messages, add_generation_prompt=True
+        )

-
-            freq = Counter(edges.tolist())
-            most_common_value, _ = freq.most_common(1)[0]
-            return int(most_common_value)  # single channel color
+        return prompt

-
-        # Color image: shape (H, W, C)
-        top = img_np[0, :, :]  # shape (W, C)
-        bottom = img_np[-1, :, :]  # shape (W, C)
-        left = img_np[:, 0, :]  # shape (H, C)
-        right = img_np[:, -1, :]  # shape (H, C)
-
-        # Concatenate edges along first axis
-        edges = np.concatenate([top, bottom, left, right], axis=0)
-
-        # Convert each color to a tuple for counting
-        edges_as_tuples = [tuple(pixel) for pixel in edges]
-        freq = Counter(edges_as_tuples)
-        most_common_value, _ = freq.most_common(1)[0]
-        return most_common_value  # e.g. (R, G, B) or (R, G, B, A)
-
-    def _pad_with_most_frequent_edge_color(
-        self, img: Union[Image.Image, np.ndarray], padding: Tuple[int, int, int, int]
-    ):
+    def _post_process(self, texts: list[str]) -> list[str]:
         """
-
+        Processes a list of text strings by truncating at '<end_of_utterance>' and
+        removing a predefined set of unwanted substrings.

         Parameters
         ----------
-
-
-        padding : tuple
-            Padding (left, top, right, bottom) in pixels.
+        texts : list[str]
+            A list of strings to be post-processed.

         Returns
         -------
-
+        list[str]
+            A list of cleaned strings with specified substrings removed and truncated at
+            '<end_of_utterance>' if present.
         """
-
-            pil_img = Image.fromarray(img)
-        else:
-            pil_img = img
+        to_remove = ["</code>", "</formula>", "<loc_0><loc_0><loc_500><loc_500>"]

-
+        def clean_text(text: str) -> str:
+            idx = text.find("<end_of_utterance>")
+            if idx != -1:
+                text = text[:idx]

-
-
+            for token in to_remove:
+                if token in text:
+                    text = text.replace(token, "")
+            return text.lstrip()
+
+        return [clean_text(t) for t in texts]

     def __call__(
         self,
@@ -308,14 +303,30 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
         images: List[Union[Image.Image, np.ndarray]] = []
         elements: List[TextItem] = []
         for el in element_batch:
-
-
-
-
-
-
+            elements.append(el.item)  # type: ignore[arg-type]
+            labels.append(el.item.label)  # type: ignore[attr-defined]
+            images.append(el.image)
+
+        prompts = [self._get_prompt(label) for label in labels]
+        inputs = self._processor(
+            text=prompts,
+            images=images,
+            return_tensors="pt",
+        )
+        inputs = inputs.to(self.device)

-
+        gen_kwargs = dict(
+            max_new_tokens=self._model_max_length - inputs.input_ids.shape[1],
+            use_cache=True,
+            do_sample=False,
+        )
+
+        generated_ids = self._model.generate(**inputs, **gen_kwargs)
+
+        outputs = self._processor.batch_decode(
+            generated_ids[:, inputs.input_ids.shape[1] :], skip_special_tokens=False
+        )
+        outputs = self._post_process(outputs)

         for item, output in zip(elements, outputs):
             if isinstance(item, CodeItem):
docling/models/page_preprocessing_model.py
@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder

 class PagePreprocessingOptions(BaseModel):
     images_scale: Optional[float]
+    skip_cell_extraction: bool = (
+        False  # Skip text cell extraction for VLM-only processing
+    )


 class PagePreprocessingModel(BasePageModel):
@@ -41,7 +44,8 @@ class PagePreprocessingModel(BasePageModel):
             else:
                 with TimeRecorder(conv_res, "page_parse"):
                     page = self._populate_page_images(page)
-
+                    if not self.options.skip_cell_extraction:
+                        page = self._parse_page_cells(conv_res, page)
                 yield page

     # Generate the page image and store it in the page object
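The new skip_cell_extraction flag lets a VLM-only pipeline skip the programmatic text-cell parsing while still producing page images. A minimal sketch of how the option could be wired up (the PagePreprocessingModel constructor signature is assumed, it is not shown in this diff):

    from docling.models.page_preprocessing_model import (
        PagePreprocessingModel,
        PagePreprocessingOptions,
    )

    # Cells are not needed when a VLM re-reads the rendered page image anyway.
    options = PagePreprocessingOptions(images_scale=2.0, skip_cell_extraction=True)
    preprocessing = PagePreprocessingModel(options=options)  # assumed signature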
docling/models/picture_description_vlm_model.py
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import Optional, Type, Union

 from PIL import Image
+from transformers import AutoModelForImageTextToText

 from docling.datamodel.accelerator_options import AcceleratorOptions
 from docling.datamodel.pipeline_options import (
@@ -63,7 +64,7 @@ class PictureDescriptionVlmModel(
         # Initialize processor and model
         with _model_init_lock:
             self.processor = AutoProcessor.from_pretrained(artifacts_path)
-            self.model =
+            self.model = AutoModelForImageTextToText.from_pretrained(
                 artifacts_path,
                 device_map=self.device,
                 torch_dtype=torch.bfloat16,
@@ -71,9 +72,10 @@ class PictureDescriptionVlmModel(
                     "flash_attention_2"
                     if self.device.startswith("cuda")
                     and accelerator_options.cuda_use_flash_attention2
-                    else "
+                    else "sdpa"
                 ),
             )
+            self.model = torch.compile(self.model)  # type: ignore

         self.provenance = f"{self.options.repo_id}"

docling/models/tesseract_ocr_cli_model.py
@@ -320,6 +320,8 @@ class TesseractOcrCliModel(BaseOcrModel):


 def _parse_orientation(df_osd: pd.DataFrame) -> int:
-
-
+    # For strictly optimal performance with invariant dataframe format:
+    mask = df_osd["key"].to_numpy() == "Orientation in degrees"
+    orientation_val = df_osd["value"].to_numpy()[mask][0]
+    orientation = parse_tesseract_orientation(orientation_val.strip())
     return orientation
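The rewritten _parse_orientation replaces a row lookup with a vectorized NumPy mask over the Tesseract OSD dataframe. A small illustration with an invented two-row OSD table (only the column names and the lookup key come from the diff):

    import pandas as pd

    df_osd = pd.DataFrame(
        {"key": ["Page number", "Orientation in degrees"], "value": ["0", " 90"]}
    )
    mask = df_osd["key"].to_numpy() == "Orientation in degrees"
    orientation_val = df_osd["value"].to_numpy()[mask][0]
    print(orientation_val.strip())  # "90", then normalized by parse_tesseract_orientation()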
docling/models/vlm_models_inline/__init__.py
@@ -0,0 +1 @@
+
docling/models/vlm_models_inline/hf_transformers_model.py
@@ -3,7 +3,11 @@ import logging
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Union
+
+import numpy as np
+from PIL.Image import Image
+from transformers import StoppingCriteriaList, StopStringCriteria

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,7 +19,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersModelType,
     TransformersPromptStyle,
 )
-from docling.models.base_model import
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -25,7 +29,7 @@ from docling.utils.profiling import TimeRecorder
 _log = logging.getLogger(__name__)


-class HuggingFaceTransformersVlmModel(
+class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -103,6 +107,8 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
             artifacts_path,
             trust_remote_code=vlm_options.trust_remote_code,
         )
+        self.processor.tokenizer.padding_side = "left"
+
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
             device_map=self.device,
@@ -111,10 +117,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                 "flash_attention_2"
                 if self.device.startswith("cuda")
                 and accelerator_options.cuda_use_flash_attention2
-                else "
+                else "sdpa"
             ),
             trust_remote_code=vlm_options.trust_remote_code,
         )
+        self.vlm_model = torch.compile(self.vlm_model)  # type: ignore

         # Load generation config
         self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
@@ -122,93 +129,186 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
-
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        valid_pages = []
+        invalid_pages = []
+
+        for page in page_list:
             assert page._backend is not None
             if not page._backend.is_valid():
-
+                invalid_pages.append(page)
             else:
-
-                assert page.size is not None
+                valid_pages.append(page)

+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, "vlm"):
+                # Prepare images and prompts for batch processing
+                images = []
+                user_prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
                     hi_res_image = page.get_image(
                         scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                     )

-                #
-
-
-
-                inputs = self.processor(
-                    text=prompt, images=[hi_res_image], return_tensors="pt"
-                ).to(self.device)
-
-                start_time = time.time()
-                # Call model to generate:
-                generated_ids = self.vlm_model.generate(
-                    **inputs,
-                    max_new_tokens=self.max_new_tokens,
-                    use_cache=self.use_cache,
-                    temperature=self.temperature,
-                    generation_config=self.generation_config,
-                    **self.vlm_options.extra_generation_config,
-                )
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)

-
-
-                    generated_ids[:, inputs["input_ids"].shape[1] :],
-                    skip_special_tokens=False,
-                )[0]
+                    # Define prompt structure
+                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)

-
-
-
-
-
-
-
-
+                    user_prompts.append(user_prompt)
+                    pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    predictions = list(self.process_images(images, user_prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield all pages (valid and invalid)
+        for page in invalid_pages:
+            yield page
+        for page in valid_pages:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """
+        Batched inference for Hugging Face Image-Text-to-Text VLMs (e.g., SmolDocling / SmolVLM).
+        - Lets the processor handle all padding & batching for text+images.
+        - Trims generated sequences per row using attention_mask (no pad-id fallbacks).
+        - Keeps your formulate_prompt() exactly as-is.
+        """
+        import numpy as np
+        import torch
+        from PIL import Image as PILImage
+
+        # -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
+        pil_images: list[Image] = []
+        for img in image_batch:
+            if isinstance(img, np.ndarray):
+                if img.ndim == 3 and img.shape[2] in (3, 4):
+                    pil_img = PILImage.fromarray(img.astype(np.uint8))
+                elif img.ndim == 2:
+                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+            else:
+                pil_img = img
+            if pil_img.mode != "RGB":
+                pil_img = pil_img.convert("RGB")
+            pil_images.append(pil_img)
+
+        if not pil_images:
+            return
+
+        # -- Normalize prompts (1 per image)
+        if isinstance(prompt, str):
+            user_prompts = [prompt] * len(pil_images)
+        else:
+            if len(prompt) != len(pil_images):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
+                )
+            user_prompts = prompt
+
+        # Use your prompt formatter verbatim
+        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
+            inputs = self.processor(
+                pil_images,
+                return_tensors="pt",
+                padding=True,  # pad across batch for both text and vision
+                **self.vlm_options.extra_processor_kwargs,
+            )
+        else:
+            prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
+
+            # -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
+            inputs = self.processor(
+                text=prompts,
+                images=pil_images,
+                return_tensors="pt",
+                padding=True,  # pad across batch for both text and vision
+                **self.vlm_options.extra_processor_kwargs,
+            )
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        # -- Optional stopping criteria
+        stopping_criteria = None
+        if self.vlm_options.stop_strings:
+            stopping_criteria = StoppingCriteriaList(
+                [
+                    StopStringCriteria(
+                        stop_strings=self.vlm_options.stop_strings,
+                        tokenizer=self.processor.tokenizer,
                     )
+                ]
+            )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        ],
-                    }
-                ]
-                prompt = self.processor.apply_chat_template(
-                    messages, add_generation_prompt=False
+        # -- Generate (Image-Text-to-Text class expects these inputs from processor)
+        gen_kwargs = {
+            **inputs,
+            "max_new_tokens": self.max_new_tokens,
+            "use_cache": self.use_cache,
+            "generation_config": self.generation_config,
+            **self.vlm_options.extra_generation_config,
+        }
+        if self.temperature > 0:
+            gen_kwargs["do_sample"] = True
+            gen_kwargs["temperature"] = self.temperature
+        else:
+            gen_kwargs["do_sample"] = False
+
+        if stopping_criteria is not None:
+            gen_kwargs["stopping_criteria"] = stopping_criteria
+
+        start_time = time.time()
+        with torch.inference_mode():
+            generated_ids = self.vlm_model.generate(**gen_kwargs)
+        generation_time = time.time() - start_time
+
+        input_len = inputs["input_ids"].shape[1]  # common right-aligned prompt length
+        trimmed_sequences = generated_ids[:, input_len:]  # only newly generated tokens
+
+        # -- Decode with the processor/tokenizer (skip specials, keep DocTags as text)
+        decode_fn = getattr(self.processor, "batch_decode", None)
+        if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None:
+            decode_fn = self.processor.tokenizer.batch_decode
+        if decode_fn is None:
+            raise RuntimeError(
+                "Neither processor.batch_decode nor tokenizer.batch_decode is available."
             )
-        return prompt

-
-
+        decoded_texts: list[str] = decode_fn(
+            trimmed_sequences, skip_special_tokens=False
         )
+
+        # -- Clip off pad tokens from decoded texts
+        pad_token = self.processor.tokenizer.pad_token
+        if pad_token:
+            decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
+
+        # -- Optional logging
+        if generated_ids.shape[0] > 0:
+            _log.debug(
+                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"for batch size {generated_ids.shape[0]}."
+            )
+
+        for text in decoded_texts:
+            # Apply decode_response to the output text
+            decoded_text = self.vlm_options.decode_response(text)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time)