docling 2.54.0__py3-none-any.whl → 2.55.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling/backend/asciidoc_backend.py +1 -1
- docling/backend/html_backend.py +254 -136
- docling/backend/md_backend.py +8 -2
- docling/backend/msword_backend.py +1 -1
- docling/backend/xml/jats_backend.py +111 -7
- docling/backend/xml/uspto_backend.py +1 -1
- docling/cli/main.py +13 -1
- docling/datamodel/pipeline_options_vlm_model.py +13 -2
- docling/datamodel/vlm_model_specs.py +9 -0
- docling/models/api_vlm_model.py +45 -16
- docling/models/base_model.py +2 -1
- docling/models/readingorder_model.py +57 -6
- docling/models/utils/generation_utils.py +157 -0
- docling/models/utils/hf_model_download.py +6 -1
- docling/models/vlm_models_inline/hf_transformers_model.py +75 -14
- docling/models/vlm_models_inline/mlx_model.py +58 -1
- docling/models/vlm_models_inline/vllm_model.py +189 -124
- docling/utils/api_image_request.py +107 -1
- {docling-2.54.0.dist-info → docling-2.55.1.dist-info}/METADATA +2 -2
- {docling-2.54.0.dist-info → docling-2.55.1.dist-info}/RECORD +24 -23
- {docling-2.54.0.dist-info → docling-2.55.1.dist-info}/WHEEL +0 -0
- {docling-2.54.0.dist-info → docling-2.55.1.dist-info}/entry_points.txt +0 -0
- {docling-2.54.0.dist-info → docling-2.55.1.dist-info}/licenses/LICENSE +0 -0
- {docling-2.54.0.dist-info → docling-2.55.1.dist-info}/top_level.txt +0 -0
docling/models/vlm_models_inline/hf_transformers_model.py

@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
 
 import numpy as np
 from PIL.Image import Image
-from transformers import StoppingCriteriaList, StopStringCriteria
+from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -20,6 +20,10 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import (
+    GenerationStopper,
+    HFStoppingCriteriaWrapper,
+)
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -75,7 +79,9 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             repo_cache_folder = vlm_options.repo_id.replace("/", "--")
 
             if artifacts_path is None:
-                artifacts_path = self.download_models(self.vlm_options.repo_id)
+                artifacts_path = self.download_models(
+                    self.vlm_options.repo_id, revision=self.vlm_options.revision
+                )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
 
@@ -106,6 +112,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             self.processor = AutoProcessor.from_pretrained(
                 artifacts_path,
                 trust_remote_code=vlm_options.trust_remote_code,
+                revision=vlm_options.revision,
             )
             self.processor.tokenizer.padding_side = "left"
 
@@ -120,11 +127,14 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                     else "sdpa"
                 ),
                 trust_remote_code=vlm_options.trust_remote_code,
+                revision=vlm_options.revision,
             )
             self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
 
             # Load generation config
-            self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
+            self.generation_config = GenerationConfig.from_pretrained(
+                artifacts_path, revision=vlm_options.revision
+            )
 
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
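The revision argument threaded through these loaders is the standard Hugging Face from_pretrained() parameter, so a checkpoint can be pinned to a tag or commit instead of whatever the default branch currently resolves to. A minimal sketch of the same calls outside docling; the repo id and revision value below are illustrative, not taken from this release:

from transformers import AutoProcessor, GenerationConfig

repo_id = "ds4sd/SmolDocling-256M-preview"  # illustrative repo id
revision = "main"                           # could equally be a tag or a commit sha

# Both loaders resolve their files from the pinned revision only.
processor = AutoProcessor.from_pretrained(repo_id, revision=revision)
generation_config = GenerationConfig.from_pretrained(repo_id, revision=revision)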
@@ -196,7 +206,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         import torch
         from PIL import Image as PILImage
 
-        # -- Normalize images to RGB PIL
+        # -- Normalize images to RGB PIL
         pil_images: list[Image] = []
         for img in image_batch:
             if isinstance(img, np.ndarray):
@@ -247,24 +257,74 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
         # -- Optional stopping criteria
-
+        stopping_criteria_list: StoppingCriteriaList = StoppingCriteriaList()
+
+        # Add string-based stopping criteria
         if self.vlm_options.stop_strings:
-            stopping_criteria = StoppingCriteriaList(
-                [
-                    StopStringCriteria(
-                        stop_strings=self.vlm_options.stop_strings,
-                        tokenizer=self.processor.tokenizer,
-                    )
-                ]
+            stopping_criteria_list.append(
+                StopStringCriteria(
+                    stop_strings=self.vlm_options.stop_strings,
+                    tokenizer=self.processor.tokenizer,
+                )
             )
 
+        # Add custom stopping criteria
+        if self.vlm_options.custom_stopping_criteria:
+            for criteria in self.vlm_options.custom_stopping_criteria:
+                # If it's a class (not an instance), determine the type and handle accordingly
+                if isinstance(criteria, type):
+                    # Check if it's a GenerationStopper class
+                    if issubclass(criteria, GenerationStopper):
+                        # Instantiate GenerationStopper and wrap it
+                        stopper_instance = criteria()
+                        wrapped_criteria = HFStoppingCriteriaWrapper(
+                            self.processor.tokenizer, stopper_instance
+                        )
+                        stopping_criteria_list.append(wrapped_criteria)
+                    elif issubclass(criteria, StoppingCriteria):
+                        # It's a StoppingCriteria class, instantiate with tokenizer
+                        criteria_instance = criteria(self.processor.tokenizer)
+                        stopping_criteria_list.append(criteria_instance)
+                elif isinstance(criteria, GenerationStopper):
+                    # Wrap GenerationStopper instances in HFStoppingCriteriaWrapper
+                    wrapped_criteria = HFStoppingCriteriaWrapper(
+                        self.processor.tokenizer, criteria
+                    )
+                    stopping_criteria_list.append(wrapped_criteria)
+                else:
+                    # If it's already an instance of StoppingCriteria, use it directly
+                    stopping_criteria_list.append(criteria)
+
+        stopping_criteria = (
+            StoppingCriteriaList(stopping_criteria_list)
+            if stopping_criteria_list
+            else None
+        )
+
+        # -- Filter out decoder-specific keys from extra_generation_config
+        decoder_keys = {
+            "skip_special_tokens",
+            "clean_up_tokenization_spaces",
+            "spaces_between_special_tokens",
+        }
+        generation_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k not in decoder_keys
+        }
+        decoder_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k in decoder_keys
+        }
+
         # -- Generate (Image-Text-to-Text class expects these inputs from processor)
         gen_kwargs = {
             **inputs,
             "max_new_tokens": self.max_new_tokens,
             "use_cache": self.use_cache,
             "generation_config": self.generation_config,
-            **self.vlm_options.extra_generation_config,
+            **generation_config,
         }
         if self.temperature > 0:
             gen_kwargs["do_sample"] = True
@@ -293,7 +353,8 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         )
 
         decoded_texts: list[str] = decode_fn(
-            trimmed_sequences,
+            trimmed_sequences,
+            **decoder_config,
         )
 
         # -- Clip off pad tokens from decoded texts
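The branches above accept either transformers StoppingCriteria objects or docling's backend-agnostic GenerationStopper, wrapping the latter in HFStoppingCriteriaWrapper together with the processor's tokenizer. A minimal sketch of a custom stopper, assuming the GenerationStopper interface used in this diff (should_stop() over a window of decoded text, lookback_tokens() for the window size); the class name and stopping rule are invented for illustration:

from docling.models.utils.generation_utils import GenerationStopper


class RepetitionStopper(GenerationStopper):  # hypothetical example stopper
    """Stop once the same non-empty line appears three times in a row."""

    def lookback_tokens(self) -> int:
        # Only roughly the last 200 tokens' worth of decoded text is inspected.
        return 200

    def should_stop(self, text: str) -> bool:
        lines = [ln for ln in text.splitlines() if ln.strip()]
        return len(lines) >= 3 and lines[-1] == lines[-2] == lines[-3]

Passed through vlm_options.custom_stopping_criteria (as an instance or as the class itself), such a stopper is wrapped for transformers generation here and called directly on the decoded tail in the MLX loop below.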
docling/models/vlm_models_inline/mlx_model.py

@@ -1,4 +1,5 @@
 import logging
+import sys
 import threading
 import time
 from collections.abc import Iterable
@@ -7,6 +8,7 @@ from typing import Optional, Union
 
 import numpy as np
 from PIL.Image import Image
+from transformers import StoppingCriteria
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,6 +17,7 @@ from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToke
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import GenerationStopper
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -60,6 +63,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             if artifacts_path is None:
                 artifacts_path = self.download_models(
                     self.vlm_options.repo_id,
+                    revision=self.vlm_options.revision,
                 )
             elif (artifacts_path / repo_cache_folder).exists():
                 artifacts_path = artifacts_path / repo_cache_folder
@@ -68,6 +72,22 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             self.vlm_model, self.processor = load(artifacts_path)
             self.config = load_config(artifacts_path)
 
+            # Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
+            if self.vlm_options.custom_stopping_criteria:
+                for criteria in self.vlm_options.custom_stopping_criteria:
+                    if isinstance(criteria, StoppingCriteria):
+                        raise ValueError(
+                            f"MLX models do not support HuggingFace StoppingCriteria instances. "
+                            f"Found {type(criteria).__name__}. Use GenerationStopper instead."
+                        )
+                    elif isinstance(criteria, type) and issubclass(
+                        criteria, StoppingCriteria
+                    ):
+                        raise ValueError(
+                            f"MLX models do not support HuggingFace StoppingCriteria classes. "
+                            f"Found {criteria.__name__}. Use GenerationStopper instead."
+                        )
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
@@ -192,7 +212,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                    self.processor, self.config, user_prompt, num_images=1
                )
 
-               # Stream generate with stop strings support
+               # Stream generate with stop strings and custom stopping criteria support
                start_time = time.time()
                _log.debug("start generating ...")
 
@@ -244,6 +264,43 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                        _log.debug("Stopping generation due to stop string match")
                        break
 
+                   # Check for custom stopping criteria (GenerationStopper instances)
+                   if self.vlm_options.custom_stopping_criteria:
+                       for criteria in self.vlm_options.custom_stopping_criteria:
+                           # Handle both instances and classes of GenerationStopper
+                           if isinstance(criteria, GenerationStopper):
+                               stopper = criteria
+                           elif isinstance(criteria, type) and issubclass(
+                               criteria, GenerationStopper
+                           ):
+                               stopper = criteria()
+
+                           # Determine the text window to check based on lookback_tokens
+                           lookback_tokens = stopper.lookback_tokens()
+                           # Check only the last N characters worth of text
+                           # This is a simplified approach - in practice, you might want to
+                           # decode the last N tokens from the token list for more accuracy
+                           text_to_check = (
+                               output[-lookback_tokens:]
+                               if len(output) > lookback_tokens
+                               else output
+                           )
+
+                           try:
+                               if stopper.should_stop(text_to_check):
+                                   _log.info(
+                                       f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
+                                   )
+                                   break
+                           except Exception as e:
+                               _log.warning(
+                                   f"Error in GenerationStopper.should_stop: {e}"
+                               )
+                               continue
+                       else:  # note: for-else idiom
+                           continue  # Only executed if the inner loop didn't break
+                       break  # Break the outer loop if any stopper triggered
+
                generation_time = time.time() - start_time
 
                _log.debug(
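The for ... else construct called out in the comments above is what lets a break in the stopper loop also terminate the streaming loop: the else branch runs only when the inner loop completes without breaking. A small self-contained illustration of that control flow; the token stream and stopper list are stand-ins, not docling code:

def stream(tokens, stoppers):
    emitted = []
    for token in tokens:                      # outer "streaming" loop
        emitted.append(token)
        for stopper in stoppers:              # inner stopper loop
            if stopper("".join(emitted)):
                break                         # a stopper fired
        else:
            continue                          # no stopper fired: keep streaming
        break                                 # propagate the inner break outward
    return "".join(emitted)


print(stream("abcdef", [lambda text: text.endswith("d")]))  # prints "abcd"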
docling/models/vlm_models_inline/vllm_model.py

@@ -7,9 +7,7 @@ from typing import Any, Dict, Optional, Union
 import numpy as np
 from PIL.Image import Image
 
-from docling.datamodel.accelerator_options import (
-    AcceleratorOptions,
-)
+from docling.datamodel.accelerator_options import AcceleratorOptions
 from docling.datamodel.base_models import Page, VlmPrediction
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
@@ -17,9 +15,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
-from docling.models.utils.hf_model_download import (
-    HuggingFaceModelDownloadMixin,
-)
+from docling.models.utils.hf_model_download import HuggingFaceModelDownloadMixin
 from docling.utils.accelerator_utils import decide_device
 from docling.utils.profiling import TimeRecorder
 
@@ -27,6 +23,62 @@ _log = logging.getLogger(__name__)
 
 
 class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
+    """
+    vLLM-backed vision-language model that accepts PIL images (or numpy arrays)
+    via vLLM's multi_modal_data, with prompt formatting handled by formulate_prompt().
+    """
+
+    # --------- Allowlist of vLLM args ---------
+    # SamplingParams (runtime generation controls)
+    _VLLM_SAMPLING_KEYS = {
+        # Core
+        "max_tokens",
+        "temperature",
+        "top_p",
+        "top_k",
+        # Penalties
+        "presence_penalty",
+        "frequency_penalty",
+        "repetition_penalty",
+        # Stops / outputs
+        "stop",
+        "stop_token_ids",
+        "skip_special_tokens",
+        "spaces_between_special_tokens",
+        # Search / length
+        "n",
+        "best_of",
+        "length_penalty",
+        "early_stopping",
+        # Misc
+        "logprobs",
+        "prompt_logprobs",
+        "min_p",
+        "seed",
+    }
+
+    # LLM(...) / EngineArgs (engine/load-time controls)
+    _VLLM_ENGINE_KEYS = {
+        # Model/tokenizer/impl
+        "tokenizer",
+        "tokenizer_mode",
+        "download_dir",
+        # Parallelism / memory / lengths
+        "tensor_parallel_size",
+        "pipeline_parallel_size",
+        "gpu_memory_utilization",
+        "max_model_len",
+        "max_num_batched_tokens",
+        "kv_cache_dtype",
+        "dtype",
+        # Quantization (coarse switch)
+        "quantization",
+        # Multimodal limits
+        "limit_mm_per_prompt",
+        # Execution toggles
+        "enforce_eager",
+    }
+
     def __init__(
         self,
         enabled: bool,
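The two class-level sets above act as an allowlist that routes extra_generation_config entries either to the LLM(...) constructor or to SamplingParams, warning about and dropping everything else. A standalone re-creation of that split with a made-up config dict (only the key names come from the sets above):

SAMPLING_KEYS = {"max_tokens", "temperature", "top_p", "top_k", "seed"}
ENGINE_KEYS = {"gpu_memory_utilization", "max_model_len", "dtype"}

extra_cfg = {
    "max_tokens": 512,              # runtime -> SamplingParams
    "top_p": 0.9,                   # runtime -> SamplingParams
    "gpu_memory_utilization": 0.5,  # load-time -> LLM(...)
    "max_model_len": 8192,          # load-time -> LLM(...)
    "not_a_vllm_option": True,      # unknown -> warned about and ignored
}

load_cfg = {k: v for k, v in extra_cfg.items() if k in ENGINE_KEYS}
gen_cfg = {k: v for k, v in extra_cfg.items() if k in SAMPLING_KEYS}
unknown = sorted(k for k in extra_cfg if k not in ENGINE_KEYS | SAMPLING_KEYS)

print(load_cfg)  # {'gpu_memory_utilization': 0.5, 'max_model_len': 8192}
print(gen_cfg)   # {'max_tokens': 512, 'top_p': 0.9}
print(unknown)   # ['not_a_vllm_option']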
@@ -35,120 +87,147 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         vlm_options: InlineVlmOptions,
     ):
         self.enabled = enabled
-
         self.vlm_options = vlm_options
 
-
-
-
-
-
-
-                supported_devices=vlm_options.supported_devices,
-            )
-            _log.debug(f"Available device for VLM: {self.device}")
-
-            self.max_new_tokens = vlm_options.max_new_tokens
-            self.temperature = vlm_options.temperature
-
-            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+        self.llm = None
+        self.sampling_params = None
+        self.processor = None  # used for CHAT templating in formulate_prompt()
+        self.device = "cpu"
+        self.max_new_tokens = vlm_options.max_new_tokens
+        self.temperature = vlm_options.temperature
 
-
-
-            elif (artifacts_path / repo_cache_folder).exists():
-                artifacts_path = artifacts_path / repo_cache_folder
-
-            # Initialize VLLM LLM
-            llm_kwargs: Dict[str, Any] = {
-                "model": str(artifacts_path),
-                "limit_mm_per_prompt": {"image": 1},
-                "trust_remote_code": vlm_options.trust_remote_code,
-                "model_impl": "transformers",
-                "gpu_memory_utilization": 0.3,  # hardcoded for now, leaves room for ~3 different models.
-            }
-
-            # Add device-specific configurations
-
-            if self.device == "cpu":
-                llm_kwargs["device"] = "cpu"
+        if not self.enabled:
+            return
 
-
-
-            if vlm_options.load_in_8bit:
-                llm_kwargs["quantization"] = "bitsandbytes"
+        from transformers import AutoProcessor
+        from vllm import LLM, SamplingParams
 
-
+        # Device selection
+        self.device = decide_device(
+            accelerator_options.device, supported_devices=vlm_options.supported_devices
+        )
+        _log.debug(f"Available device for VLM: {self.device}")
 
-
-
-
-
+        # Resolve artifacts path / cache folder
+        repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+        if artifacts_path is None:
+            artifacts_path = self.download_models(
+                self.vlm_options.repo_id, revision=self.vlm_options.revision
             )
-
-
-
-
-
-
-
+        elif (artifacts_path / repo_cache_folder).exists():
+            artifacts_path = artifacts_path / repo_cache_folder
+
+        # --------- Strict split & validation of extra_generation_config ---------
+        extra_cfg = self.vlm_options.extra_generation_config
+
+        load_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_ENGINE_KEYS}
+        gen_cfg = {k: v for k, v in extra_cfg.items() if k in self._VLLM_SAMPLING_KEYS}
+
+        unknown = sorted(
+            k
+            for k in extra_cfg.keys()
+            if k not in self._VLLM_ENGINE_KEYS and k not in self._VLLM_SAMPLING_KEYS
+        )
+        if unknown:
+            _log.warning(
+                "Ignoring unknown extra_generation_config keys for vLLM: %s", unknown
             )
 
+        # --------- Construct LLM kwargs (engine/load-time) ---------
+        llm_kwargs: Dict[str, Any] = {
+            "model": str(artifacts_path),
+            "model_impl": "transformers",
+            "limit_mm_per_prompt": {"image": 1},
+            "revision": self.vlm_options.revision,
+            "trust_remote_code": self.vlm_options.trust_remote_code,
+            **load_cfg,
+        }
+
+        if self.device == "cpu":
+            llm_kwargs.setdefault("enforce_eager", True)
+        else:
+            llm_kwargs.setdefault(
+                "gpu_memory_utilization", 0.3
+            )  # room for other models
+
+        # Quantization (kept as-is; coarse)
+        if self.vlm_options.quantized and self.vlm_options.load_in_8bit:
+            llm_kwargs.setdefault("quantization", "bitsandbytes")
+
+        # Initialize vLLM LLM
+        self.llm = LLM(**llm_kwargs)
+
+        # Initialize processor for prompt templating (needed for CHAT style)
+        self.processor = AutoProcessor.from_pretrained(
+            artifacts_path,
+            trust_remote_code=self.vlm_options.trust_remote_code,
+            revision=self.vlm_options.revision,
+        )
+
+        # --------- SamplingParams (runtime) ---------
+        self.sampling_params = SamplingParams(
+            temperature=self.temperature,
+            max_tokens=self.max_new_tokens,
+            stop=(self.vlm_options.stop_strings or None),
+            **gen_cfg,
+        )
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
+        # If disabled, pass-through
+        if not self.enabled:
+            for page in page_batch:
+                yield page
+            return
+
         page_list = list(page_batch)
         if not page_list:
             return
 
-
-
+        # Preserve original order
+        original_order = page_list[:]
 
+        # Separate valid/invalid
+        valid_pages: list[Page] = []
+        invalid_pages: list[Page] = []
         for page in page_list:
             assert page._backend is not None
-            if not page._backend.is_valid():
-                invalid_pages.append(page)
-            else:
+            if page._backend.is_valid():
                 valid_pages.append(page)
+            else:
+                invalid_pages.append(page)
 
-        # Process valid pages in batch
         if valid_pages:
             with TimeRecorder(conv_res, "vlm"):
-
-
-
-                pages_with_images = []
+                images: list[Image] = []
+                user_prompts: list[str] = []
+                pages_with_images: list[Page] = []
 
                 for page in valid_pages:
                     assert page.size is not None
                     hi_res_image = page.get_image(
-                        scale=self.vlm_options.scale,
+                        scale=self.vlm_options.scale,
+                        max_size=self.vlm_options.max_size,
                     )
+                    if hi_res_image is None:
+                        continue
 
-
-                    if hi_res_image is not None:
-                        images.append(hi_res_image)
+                    images.append(hi_res_image)
 
-
-
-                        user_prompt = self.vlm_options.prompt(page.parsed_page)
-                    else:
-                        user_prompt = self.vlm_options.prompt
+                    # Define prompt structure
+                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)
 
-
-
+                    user_prompts.append(user_prompt)
+                    pages_with_images.append(page)
 
-
-                if images:  # Only if we have valid images
+                if images:
                     predictions = list(self.process_images(images, user_prompts))
-
-                    # Attach results to pages
                     for page, prediction in zip(pages_with_images, predictions):
                         page.predictions.vlm_response = prediction
 
-        # Yield
-        for page in invalid_pages:
-            yield page
-        for page in valid_pages:
+        # Yield in original order
+        for page in original_order:
             yield page
 
     def process_images(
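One consequence of merging **load_cfg into llm_kwargs before the setdefault() calls is that a user-supplied engine key such as gpu_memory_utilization overrides the 0.3 fallback rather than the other way around. A small sketch of that precedence; the values and model path are illustrative:

# gpu_memory_utilization arrived via **load_cfg from extra_generation_config
llm_kwargs = {
    "model": "/models/ds4sd--SmolDocling-256M-preview",  # illustrative path
    "gpu_memory_utilization": 0.8,
}

# The non-CPU branch then applies the fallback, which is a no-op here.
llm_kwargs.setdefault("gpu_memory_utilization", 0.3)

print(llm_kwargs["gpu_memory_utilization"])  # 0.8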
@@ -156,50 +235,33 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         image_batch: Iterable[Union[Image, np.ndarray]],
         prompt: Union[str, list[str]],
     ) -> Iterable[VlmPrediction]:
-        """Process
-
-
-            image_batch: Iterable of PIL Images or numpy arrays
-            prompt: Either:
-                - str: Single prompt used for all images
-                - list[str]: List of prompts (one per image, must match image count)
+        """Process images in a single batched vLLM inference call."""
+        import numpy as np
+        from PIL import Image as PILImage
 
-
-            ValueError: If prompt list length doesn't match image count.
-        """
+        # -- Normalize images to RGB PIL
         pil_images: list[Image] = []
-
         for img in image_batch:
-            # Convert numpy array to PIL Image if needed
             if isinstance(img, np.ndarray):
-                if img.ndim == 3 and img.shape[2] in
-                    from PIL import Image as PILImage
-
+                if img.ndim == 3 and img.shape[2] in (3, 4):
                     pil_img = PILImage.fromarray(img.astype(np.uint8))
                 elif img.ndim == 2:
-                    from PIL import Image as PILImage
-
                     pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
                 else:
                     raise ValueError(f"Unsupported numpy array shape: {img.shape}")
             else:
                 pil_img = img
-
-            # Ensure image is in RGB mode (handles RGBA, L, etc.)
             if pil_img.mode != "RGB":
                 pil_img = pil_img.convert("RGB")
-
             pil_images.append(pil_img)
 
-        if
+        if not pil_images:
             return
 
-        #
+        # Normalize prompts
         if isinstance(prompt, str):
-            # Single prompt for all images
             user_prompts = [prompt] * len(pil_images)
         elif isinstance(prompt, list):
-            # List of prompts (one per image)
             if len(prompt) != len(pil_images):
                 raise ValueError(
                     f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
@@ -208,28 +270,31 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         else:
             raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
 
-        # Format prompts
-        prompts: list[str] = [
-            self.formulate_prompt(user_prompt) for user_prompt in user_prompts
-        ]
+        # Format prompts
+        prompts: list[str] = [self.formulate_prompt(up) for up in user_prompts]
 
-        #
-        llm_inputs = [
-
-
+        # Build vLLM inputs
+        llm_inputs = [
+            {"prompt": p, "multi_modal_data": {"image": im}}
+            for p, im in zip(prompts, pil_images)
+        ]
 
+        # Generate
+        assert self.llm is not None and self.sampling_params is not None
         start_time = time.time()
         outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params)  # type: ignore
         generation_time = time.time() - start_time
 
-        #
-        if
-
-
-            f"Generated {num_tokens} tokens in
-
+        # Optional debug
+        if outputs:
+            try:
+                num_tokens = len(outputs[0].outputs[0].token_ids)
+                _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
+            except Exception:
+                pass
 
+        # Emit predictions
         for output in outputs:
-
-            decoded_text = self.vlm_options.decode_response(
+            text = output.outputs[0].text if output.outputs else ""
+            decoded_text = self.vlm_options.decode_response(text)
             yield VlmPrediction(text=decoded_text, generation_time=generation_time)