docling-2.69.0-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of docling might be problematic.
- docling/__init__.py +0 -0
- docling/backend/__init__.py +0 -0
- docling/backend/abstract_backend.py +84 -0
- docling/backend/asciidoc_backend.py +443 -0
- docling/backend/csv_backend.py +125 -0
- docling/backend/docling_parse_backend.py +237 -0
- docling/backend/docling_parse_v2_backend.py +276 -0
- docling/backend/docling_parse_v4_backend.py +260 -0
- docling/backend/docx/__init__.py +0 -0
- docling/backend/docx/drawingml/utils.py +131 -0
- docling/backend/docx/latex/__init__.py +0 -0
- docling/backend/docx/latex/latex_dict.py +274 -0
- docling/backend/docx/latex/omml.py +459 -0
- docling/backend/html_backend.py +1502 -0
- docling/backend/image_backend.py +188 -0
- docling/backend/json/__init__.py +0 -0
- docling/backend/json/docling_json_backend.py +58 -0
- docling/backend/md_backend.py +618 -0
- docling/backend/mets_gbs_backend.py +399 -0
- docling/backend/msexcel_backend.py +686 -0
- docling/backend/mspowerpoint_backend.py +398 -0
- docling/backend/msword_backend.py +1663 -0
- docling/backend/noop_backend.py +51 -0
- docling/backend/pdf_backend.py +82 -0
- docling/backend/pypdfium2_backend.py +417 -0
- docling/backend/webvtt_backend.py +572 -0
- docling/backend/xml/__init__.py +0 -0
- docling/backend/xml/jats_backend.py +819 -0
- docling/backend/xml/uspto_backend.py +1905 -0
- docling/chunking/__init__.py +12 -0
- docling/cli/__init__.py +0 -0
- docling/cli/main.py +974 -0
- docling/cli/models.py +196 -0
- docling/cli/tools.py +17 -0
- docling/datamodel/__init__.py +0 -0
- docling/datamodel/accelerator_options.py +69 -0
- docling/datamodel/asr_model_specs.py +494 -0
- docling/datamodel/backend_options.py +102 -0
- docling/datamodel/base_models.py +493 -0
- docling/datamodel/document.py +699 -0
- docling/datamodel/extraction.py +39 -0
- docling/datamodel/layout_model_specs.py +91 -0
- docling/datamodel/pipeline_options.py +457 -0
- docling/datamodel/pipeline_options_asr_model.py +78 -0
- docling/datamodel/pipeline_options_vlm_model.py +136 -0
- docling/datamodel/settings.py +65 -0
- docling/datamodel/vlm_model_specs.py +365 -0
- docling/document_converter.py +559 -0
- docling/document_extractor.py +327 -0
- docling/exceptions.py +10 -0
- docling/experimental/__init__.py +5 -0
- docling/experimental/datamodel/__init__.py +1 -0
- docling/experimental/datamodel/table_crops_layout_options.py +13 -0
- docling/experimental/datamodel/threaded_layout_vlm_pipeline_options.py +45 -0
- docling/experimental/models/__init__.py +3 -0
- docling/experimental/models/table_crops_layout_model.py +114 -0
- docling/experimental/pipeline/__init__.py +1 -0
- docling/experimental/pipeline/threaded_layout_vlm_pipeline.py +439 -0
- docling/models/__init__.py +0 -0
- docling/models/base_layout_model.py +39 -0
- docling/models/base_model.py +230 -0
- docling/models/base_ocr_model.py +241 -0
- docling/models/base_table_model.py +45 -0
- docling/models/extraction/__init__.py +0 -0
- docling/models/extraction/nuextract_transformers_model.py +305 -0
- docling/models/factories/__init__.py +47 -0
- docling/models/factories/base_factory.py +122 -0
- docling/models/factories/layout_factory.py +7 -0
- docling/models/factories/ocr_factory.py +11 -0
- docling/models/factories/picture_description_factory.py +11 -0
- docling/models/factories/table_factory.py +7 -0
- docling/models/picture_description_base_model.py +149 -0
- docling/models/plugins/__init__.py +0 -0
- docling/models/plugins/defaults.py +60 -0
- docling/models/stages/__init__.py +0 -0
- docling/models/stages/code_formula/__init__.py +0 -0
- docling/models/stages/code_formula/code_formula_model.py +342 -0
- docling/models/stages/layout/__init__.py +0 -0
- docling/models/stages/layout/layout_model.py +249 -0
- docling/models/stages/ocr/__init__.py +0 -0
- docling/models/stages/ocr/auto_ocr_model.py +132 -0
- docling/models/stages/ocr/easyocr_model.py +200 -0
- docling/models/stages/ocr/ocr_mac_model.py +145 -0
- docling/models/stages/ocr/rapid_ocr_model.py +328 -0
- docling/models/stages/ocr/tesseract_ocr_cli_model.py +331 -0
- docling/models/stages/ocr/tesseract_ocr_model.py +262 -0
- docling/models/stages/page_assemble/__init__.py +0 -0
- docling/models/stages/page_assemble/page_assemble_model.py +156 -0
- docling/models/stages/page_preprocessing/__init__.py +0 -0
- docling/models/stages/page_preprocessing/page_preprocessing_model.py +145 -0
- docling/models/stages/picture_classifier/__init__.py +0 -0
- docling/models/stages/picture_classifier/document_picture_classifier.py +246 -0
- docling/models/stages/picture_description/__init__.py +0 -0
- docling/models/stages/picture_description/picture_description_api_model.py +66 -0
- docling/models/stages/picture_description/picture_description_vlm_model.py +123 -0
- docling/models/stages/reading_order/__init__.py +0 -0
- docling/models/stages/reading_order/readingorder_model.py +431 -0
- docling/models/stages/table_structure/__init__.py +0 -0
- docling/models/stages/table_structure/table_structure_model.py +305 -0
- docling/models/utils/__init__.py +0 -0
- docling/models/utils/generation_utils.py +157 -0
- docling/models/utils/hf_model_download.py +45 -0
- docling/models/vlm_pipeline_models/__init__.py +1 -0
- docling/models/vlm_pipeline_models/api_vlm_model.py +180 -0
- docling/models/vlm_pipeline_models/hf_transformers_model.py +391 -0
- docling/models/vlm_pipeline_models/mlx_model.py +325 -0
- docling/models/vlm_pipeline_models/vllm_model.py +344 -0
- docling/pipeline/__init__.py +0 -0
- docling/pipeline/asr_pipeline.py +431 -0
- docling/pipeline/base_extraction_pipeline.py +72 -0
- docling/pipeline/base_pipeline.py +326 -0
- docling/pipeline/extraction_vlm_pipeline.py +207 -0
- docling/pipeline/legacy_standard_pdf_pipeline.py +262 -0
- docling/pipeline/simple_pipeline.py +55 -0
- docling/pipeline/standard_pdf_pipeline.py +859 -0
- docling/pipeline/threaded_standard_pdf_pipeline.py +5 -0
- docling/pipeline/vlm_pipeline.py +416 -0
- docling/py.typed +1 -0
- docling/utils/__init__.py +0 -0
- docling/utils/accelerator_utils.py +97 -0
- docling/utils/api_image_request.py +205 -0
- docling/utils/deepseekocr_utils.py +388 -0
- docling/utils/export.py +146 -0
- docling/utils/glm_utils.py +361 -0
- docling/utils/layout_postprocessor.py +683 -0
- docling/utils/locks.py +3 -0
- docling/utils/model_downloader.py +168 -0
- docling/utils/ocr_utils.py +69 -0
- docling/utils/orientation.py +65 -0
- docling/utils/profiling.py +65 -0
- docling/utils/utils.py +65 -0
- docling/utils/visualization.py +85 -0
- docling-2.69.0.dist-info/METADATA +237 -0
- docling-2.69.0.dist-info/RECORD +138 -0
- docling-2.69.0.dist-info/WHEEL +5 -0
- docling-2.69.0.dist-info/entry_points.txt +6 -0
- docling-2.69.0.dist-info/licenses/LICENSE +21 -0
- docling-2.69.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,391 @@
import importlib.metadata
import logging
import sys
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Optional, Union

import numpy as np
from PIL.Image import Image
from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria

from docling.datamodel.accelerator_options import (
    AcceleratorOptions,
)
from docling.datamodel.base_models import Page, VlmPrediction, VlmStopReason
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options_vlm_model import (
    InlineVlmOptions,
    TransformersModelType,
    TransformersPromptStyle,
)
from docling.models.base_model import BaseVlmPageModel
from docling.models.utils.generation_utils import (
    GenerationStopper,
    HFStoppingCriteriaWrapper,
)
from docling.models.utils.hf_model_download import (
    HuggingFaceModelDownloadMixin,
)
from docling.utils.accelerator_utils import decide_device
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)


class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        accelerator_options: AcceleratorOptions,
        vlm_options: InlineVlmOptions,
    ):
        self.enabled = enabled

        self.vlm_options = vlm_options

        if self.enabled:
            import torch
            from transformers import (
                AutoModel,
                AutoModelForCausalLM,
                AutoModelForImageTextToText,
                AutoModelForVision2Seq,
                AutoProcessor,
                BitsAndBytesConfig,
                GenerationConfig,
            )

            transformers_version = importlib.metadata.version("transformers")
            if (
                self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct"
                and transformers_version >= "4.52.0"
            ):
                raise NotImplementedError(
                    f"Phi 4 only works with transformers<4.52.0 but you have {transformers_version=}. Please downgrage running pip install -U 'transformers<4.52.0'."
                )

            self.device = decide_device(
                accelerator_options.device,
                supported_devices=vlm_options.supported_devices,
            )
            _log.debug(f"Available device for VLM: {self.device}")

            self.use_cache = vlm_options.use_kv_cache
            self.max_new_tokens = vlm_options.max_new_tokens
            self.temperature = vlm_options.temperature

            repo_cache_folder = vlm_options.repo_id.replace("/", "--")

            if artifacts_path is None:
                artifacts_path = self.download_models(
                    self.vlm_options.repo_id, revision=self.vlm_options.revision
                )
            elif (artifacts_path / repo_cache_folder).exists():
                artifacts_path = artifacts_path / repo_cache_folder

            self.param_quantization_config: Optional[BitsAndBytesConfig] = None
            if vlm_options.quantized:
                self.param_quantization_config = BitsAndBytesConfig(
                    load_in_8bit=vlm_options.load_in_8bit,
                    llm_int8_threshold=vlm_options.llm_int8_threshold,
                )

            model_cls: Any = AutoModel
            if (
                self.vlm_options.transformers_model_type
                == TransformersModelType.AUTOMODEL_CAUSALLM
            ):
                model_cls = AutoModelForCausalLM
            elif (
                self.vlm_options.transformers_model_type
                == TransformersModelType.AUTOMODEL_VISION2SEQ
            ):
                model_cls = AutoModelForVision2Seq
            elif (
                self.vlm_options.transformers_model_type
                == TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT
            ):
                model_cls = AutoModelForImageTextToText

            self.processor = AutoProcessor.from_pretrained(
                artifacts_path,
                trust_remote_code=vlm_options.trust_remote_code,
                revision=vlm_options.revision,
            )
            self.processor.tokenizer.padding_side = "left"

            self.vlm_model = model_cls.from_pretrained(
                artifacts_path,
                device_map=self.device,
                dtype=self.vlm_options.torch_dtype,
                _attn_implementation=(
                    "flash_attention_2"
                    if self.device.startswith("cuda")
                    and accelerator_options.cuda_use_flash_attention2
                    else "sdpa"
                ),
                trust_remote_code=vlm_options.trust_remote_code,
                revision=vlm_options.revision,
            )
            if sys.version_info < (3, 14):
                self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
            else:
                self.vlm_model.eval()

            # Load generation config
            self.generation_config = GenerationConfig.from_pretrained(
                artifacts_path, revision=vlm_options.revision
            )

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        page_list = list(page_batch)
        if not page_list:
            return

        valid_pages = []
        invalid_pages = []

        for page in page_list:
            assert page._backend is not None
            if not page._backend.is_valid():
                invalid_pages.append(page)
            else:
                valid_pages.append(page)

        # Process valid pages in batch
        if valid_pages:
            with TimeRecorder(conv_res, "vlm"):
                # Prepare images and prompts for batch processing
                images = []
                user_prompts = []
                pages_with_images = []

                for page in valid_pages:
                    assert page.size is not None
                    hi_res_image = page.get_image(
                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                    )

                    # Only process pages with valid images
                    if hi_res_image is not None:
                        images.append(hi_res_image)

                        # Define prompt structure
                        user_prompt = self._build_prompt_safe(page)

                        user_prompts.append(user_prompt)
                        pages_with_images.append(page)

                # Use process_images for the actual inference
                if images:  # Only if we have valid images
                    with TimeRecorder(conv_res, "vlm_inference"):
                        predictions = list(self.process_images(images, user_prompts))

                    # Attach results to pages
                    for page, prediction in zip(pages_with_images, predictions):
                        page.predictions.vlm_response = prediction

        # Yield all pages (valid and invalid)
        for page in invalid_pages:
            yield page
        for page in valid_pages:
            yield page

    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Union[str, list[str]],
    ) -> Iterable[VlmPrediction]:
        """
        Batched inference for Hugging Face Image-Text-to-Text VLMs (e.g., SmolDocling / SmolVLM).
        - Lets the processor handle all padding & batching for text+images.
        - Trims generated sequences per row using attention_mask (no pad-id fallbacks).
        - Keeps your formulate_prompt() exactly as-is.
        """
        import numpy as np
        import torch
        from PIL import Image as PILImage

        # -- Normalize images to RGB PIL
        pil_images: list[Image] = []
        for img in image_batch:
            if isinstance(img, np.ndarray):
                if img.ndim == 3 and img.shape[2] in (3, 4):
                    pil_img = PILImage.fromarray(img.astype(np.uint8))
                elif img.ndim == 2:
                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
                else:
                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
            else:
                pil_img = img
            if pil_img.mode != "RGB":
                pil_img = pil_img.convert("RGB")
            pil_images.append(pil_img)

        if not pil_images:
            return

        # -- Normalize prompts (1 per image)
        if isinstance(prompt, str):
            user_prompts = [prompt] * len(pil_images)
        else:
            if len(prompt) != len(pil_images):
                raise ValueError(
                    f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
                )
            user_prompts = prompt

        # Use your prompt formatter verbatim
        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
            inputs = self.processor(
                pil_images,
                return_tensors="pt",
                padding=True,  # pad across batch for both text and vision
                **self.vlm_options.extra_processor_kwargs,
            )
        else:
            prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]

            # -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
            inputs = self.processor(
                text=prompts,
                images=pil_images,
                return_tensors="pt",
                padding=True,  # pad across batch for both text and vision
                **self.vlm_options.extra_processor_kwargs,
            )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # -- Optional stopping criteria
        stopping_criteria_list: StoppingCriteriaList = StoppingCriteriaList()

        # Add string-based stopping criteria
        if self.vlm_options.stop_strings:
            stopping_criteria_list.append(
                StopStringCriteria(
                    stop_strings=self.vlm_options.stop_strings,
                    tokenizer=self.processor.tokenizer,
                )
            )

        # Add custom stopping criteria
        if self.vlm_options.custom_stopping_criteria:
            for criteria in self.vlm_options.custom_stopping_criteria:
                # If it's a class (not an instance), determine the type and handle accordingly
                if isinstance(criteria, type):
                    # Check if it's a GenerationStopper class
                    if issubclass(criteria, GenerationStopper):
                        # Instantiate GenerationStopper and wrap it
                        stopper_instance = criteria()
                        wrapped_criteria = HFStoppingCriteriaWrapper(
                            self.processor.tokenizer, stopper_instance
                        )
                        stopping_criteria_list.append(wrapped_criteria)
                    elif issubclass(criteria, StoppingCriteria):
                        # It's a StoppingCriteria class, instantiate with tokenizer
                        criteria_instance = criteria(self.processor.tokenizer)
                        stopping_criteria_list.append(criteria_instance)
                elif isinstance(criteria, GenerationStopper):
                    # Wrap GenerationStopper instances in HFStoppingCriteriaWrapper
                    wrapped_criteria = HFStoppingCriteriaWrapper(
                        self.processor.tokenizer, criteria
                    )
                    stopping_criteria_list.append(wrapped_criteria)
                else:
                    # If it's already an instance of StoppingCriteria, use it directly
                    stopping_criteria_list.append(criteria)

        stopping_criteria = (
            StoppingCriteriaList(stopping_criteria_list)
            if stopping_criteria_list
            else None
        )

        # -- Filter out decoder-specific keys from extra_generation_config
        decoder_keys = {
            "skip_special_tokens",
            "clean_up_tokenization_spaces",
            "spaces_between_special_tokens",
        }
        generation_config = {
            k: v
            for k, v in self.vlm_options.extra_generation_config.items()
            if k not in decoder_keys
        }
        decoder_config = {
            k: v
            for k, v in self.vlm_options.extra_generation_config.items()
            if k in decoder_keys
        }

        # -- Generate (Image-Text-to-Text class expects these inputs from processor)
        gen_kwargs = {
            **inputs,
            "max_new_tokens": self.max_new_tokens,
            "use_cache": self.use_cache,
            "generation_config": self.generation_config,
            **generation_config,
        }
        if self.temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = self.temperature
        else:
            gen_kwargs["do_sample"] = False

        if stopping_criteria is not None:
            gen_kwargs["stopping_criteria"] = stopping_criteria

        start_time = time.time()
        with torch.inference_mode():
            generated_ids = self.vlm_model.generate(**gen_kwargs)
        generation_time = time.time() - start_time

        input_len = inputs["input_ids"].shape[1]  # common right-aligned prompt length
        trimmed_sequences = generated_ids[:, input_len:]  # only newly generated tokens

        # -- Decode with the processor/tokenizer (skip specials, keep DocTags as text)
        decode_fn = getattr(self.processor, "batch_decode", None)
        if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None:
            decode_fn = self.processor.tokenizer.batch_decode
        if decode_fn is None:
            raise RuntimeError(
                "Neither processor.batch_decode nor tokenizer.batch_decode is available."
            )

        decoded_texts: list[str] = decode_fn(
            trimmed_sequences,
            **decoder_config,
        )

        # -- Clip off pad tokens from decoded texts
        pad_token = self.processor.tokenizer.pad_token
        if pad_token:
            decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]

        # -- Optional logging
        num_tokens = None
        if generated_ids.shape[0] > 0:
            num_tokens = int(generated_ids[0].shape[0])
            _log.debug(
                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                f"for batch size {generated_ids.shape[0]}."
            )

        for i, text in enumerate(decoded_texts):
            input_prompt = (
                prompts[i] if self.vlm_options.track_input_prompt and prompts else None
            )
            # Apply decode_response to the output text
            decoded_text = self.vlm_options.decode_response(text)
            yield VlmPrediction(
                text=decoded_text,
                generation_time=generation_time,
                num_tokens=num_tokens,
                stop_reason=VlmStopReason.UNSPECIFIED,
                input_prompt=input_prompt,
            )