docling 2.46.0__py3-none-any.whl → 2.47.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- docling/backend/html_backend.py +111 -13
- docling/backend/msword_backend.py +126 -16
- docling/cli/main.py +14 -0
- docling/cli/models.py +56 -0
- docling/datamodel/base_models.py +1 -1
- docling/datamodel/pipeline_options.py +3 -0
- docling/datamodel/pipeline_options_vlm_model.py +5 -0
- docling/datamodel/vlm_model_specs.py +114 -1
- docling/models/base_model.py +95 -2
- docling/models/page_preprocessing_model.py +5 -1
- docling/models/picture_description_vlm_model.py +4 -2
- docling/models/vlm_models_inline/__init__.py +1 -0
- docling/models/vlm_models_inline/hf_transformers_model.py +179 -79
- docling/models/vlm_models_inline/mlx_model.py +179 -68
- docling/models/vlm_models_inline/vllm_model.py +235 -0
- docling/pipeline/threaded_standard_pdf_pipeline.py +1 -1
- docling/pipeline/vlm_pipeline.py +14 -1
- docling/utils/layout_postprocessor.py +51 -43
- {docling-2.46.0.dist-info → docling-2.47.0.dist-info}/METADATA +2 -1
- {docling-2.46.0.dist-info → docling-2.47.0.dist-info}/RECORD +24 -23
- {docling-2.46.0.dist-info → docling-2.47.0.dist-info}/WHEEL +0 -0
- {docling-2.46.0.dist-info → docling-2.47.0.dist-info}/entry_points.txt +0 -0
- {docling-2.46.0.dist-info → docling-2.47.0.dist-info}/licenses/LICENSE +0 -0
- {docling-2.46.0.dist-info → docling-2.47.0.dist-info}/top_level.txt +0 -0
docling/datamodel/vlm_model_specs.py
CHANGED
@@ -12,6 +12,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
     ResponseFormat,
     TransformersModelType,
+    TransformersPromptStyle,
 )
 
 _log = logging.getLogger(__name__)
@@ -26,6 +27,7 @@ SMOLDOCLING_MLX = InlineVlmOptions(
     supported_devices=[AcceleratorDevice.MPS],
     scale=2.0,
     temperature=0.0,
+    stop_strings=["</doctag>", "<end_of_utterance>"],
 )
 
 SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
@@ -33,16 +35,74 @@ SMOLDOCLING_TRANSFORMERS = InlineVlmOptions(
     prompt="Convert this page to docling.",
     response_format=ResponseFormat.DOCTAGS,
     inference_framework=InferenceFramework.TRANSFORMERS,
-    transformers_model_type=TransformersModelType.
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
     supported_devices=[
         AcceleratorDevice.CPU,
         AcceleratorDevice.CUDA,
+    ],
+    torch_dtype="bfloat16",
+    scale=2.0,
+    temperature=0.0,
+    stop_strings=["</doctag>", "<end_of_utterance>"],
+)
+
+SMOLDOCLING_VLLM = InlineVlmOptions(
+    repo_id="ds4sd/SmolDocling-256M-preview",
+    prompt="Convert this page to docling.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.VLLM,
+    supported_devices=[
+        AcceleratorDevice.CUDA,
+    ],
+    scale=2.0,
+    temperature=0.0,
+    stop_strings=["</doctag>", "<end_of_utterance>"],
+)
+
+# SmolVLM-256M-Instruct
+SMOLVLM256_TRANSFORMERS = InlineVlmOptions(
+    repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",
+    prompt="Transcribe this image to plain text.",
+    response_format=ResponseFormat.PLAINTEXT,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    supported_devices=[
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+        # AcceleratorDevice.MPS,
+    ],
+    torch_dtype="bfloat16",
+    scale=2.0,
+    temperature=0.0,
+)
+
+# SmolVLM2-2.2b-Instruct
+SMOLVLM256_MLX = InlineVlmOptions(
+    repo_id="moot20/SmolVLM-256M-Instruct-MLX",
+    prompt="Extract the text.",
+    response_format=ResponseFormat.DOCTAGS,
+    inference_framework=InferenceFramework.MLX,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    supported_devices=[
         AcceleratorDevice.MPS,
     ],
     scale=2.0,
     temperature=0.0,
 )
 
+SMOLVLM256_VLLM = InlineVlmOptions(
+    repo_id="HuggingFaceTB/SmolVLM-256M-Instruct",
+    prompt="Transcribe this image to plain text.",
+    response_format=ResponseFormat.PLAINTEXT,
+    inference_framework=InferenceFramework.VLLM,
+    supported_devices=[
+        AcceleratorDevice.CUDA,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
+
 # GraniteVision
 GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
     repo_id="ibm-granite/granite-vision-3.2-2b",
@@ -59,6 +119,18 @@ GRANITE_VISION_TRANSFORMERS = InlineVlmOptions(
     temperature=0.0,
 )
 
+GRANITE_VISION_VLLM = InlineVlmOptions(
+    repo_id="ibm-granite/granite-vision-3.2-2b",
+    prompt="Convert this page to markdown. Do not miss any text and only output the bare markdown!",
+    response_format=ResponseFormat.MARKDOWN,
+    inference_framework=InferenceFramework.VLLM,
+    supported_devices=[
+        AcceleratorDevice.CUDA,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
 GRANITE_VISION_OLLAMA = ApiVlmOptions(
     url=AnyUrl("http://localhost:11434/v1/chat/completions"),
     params={"model": "granite3.2-vision:2b"},
@@ -116,6 +188,26 @@ QWEN25_VL_3B_MLX = InlineVlmOptions(
     temperature=0.0,
 )
 
+# GoT 2.0
+GOT2_TRANSFORMERS = InlineVlmOptions(
+    repo_id="stepfun-ai/GOT-OCR-2.0-hf",
+    prompt="",
+    response_format=ResponseFormat.MARKDOWN,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_prompt_style=TransformersPromptStyle.NONE,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    supported_devices=[
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.CUDA,
+        # AcceleratorDevice.MPS,
+    ],
+    scale=2.0,
+    temperature=0.0,
+    stop_strings=["<|im_end|>"],
+    extra_processor_kwargs={"format": True},
+)
+
+
 # Gemma-3
 GEMMA3_12B_MLX = InlineVlmOptions(
     repo_id="mlx-community/gemma-3-12b-it-bf16",
@@ -137,8 +229,29 @@ GEMMA3_27B_MLX = InlineVlmOptions(
     temperature=0.0,
 )
 
+# Dolphin
+
+DOLPHIN_TRANSFORMERS = InlineVlmOptions(
+    repo_id="ByteDance/Dolphin",
+    prompt="<s>Read text in the image. <Answer/>",
+    response_format=ResponseFormat.MARKDOWN,
+    inference_framework=InferenceFramework.TRANSFORMERS,
+    transformers_model_type=TransformersModelType.AUTOMODEL_IMAGETEXTTOTEXT,
+    transformers_prompt_style=TransformersPromptStyle.RAW,
+    supported_devices=[
+        AcceleratorDevice.CUDA,
+        AcceleratorDevice.CPU,
+        AcceleratorDevice.MPS,
+    ],
+    scale=2.0,
+    temperature=0.0,
+)
+
 
 class VlmModelType(str, Enum):
     SMOLDOCLING = "smoldocling"
+    SMOLDOCLING_VLLM = "smoldocling_vllm"
     GRANITE_VISION = "granite_vision"
+    GRANITE_VISION_VLLM = "granite_vision_vllm"
     GRANITE_VISION_OLLAMA = "granite_vision_ollama"
+    GOT_OCR_2 = "got_ocr_2"
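The new `*_VLLM` specs plug into the existing `VlmPipeline` selection mechanism. A minimal sketch of picking the vLLM-backed SmolDocling preset (assumes a CUDA machine with the vLLM dependencies installed; the converter wiring follows docling's documented `VlmPipeline` usage):

```python
from docling.datamodel import vlm_model_specs
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

# Select the new vLLM-backed SmolDocling preset added in this release.
pipeline_options = VlmPipelineOptions(vlm_options=vlm_model_specs.SMOLDOCLING_VLLM)

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
doc = converter.convert("page.pdf").document
print(doc.export_to_markdown())
```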
docling/models/base_model.py
CHANGED
@@ -1,13 +1,24 @@
+import logging
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Generic, Optional, Protocol, Type
+from typing import Any, Generic, Optional, Protocol, Type, Union
 
+import numpy as np
 from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
+from PIL.Image import Image
 from typing_extensions import TypeVar
 
-from docling.datamodel.base_models import
+from docling.datamodel.base_models import (
+    ItemAndImageEnrichmentElement,
+    Page,
+    VlmPrediction,
+)
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options import BaseOptions
+from docling.datamodel.pipeline_options_vlm_model import (
+    InlineVlmOptions,
+    TransformersPromptStyle,
+)
 from docling.datamodel.settings import settings
 
 
@@ -26,6 +37,88 @@ class BasePageModel(ABC):
         pass
 
 
+class BaseVlmModel(ABC):
+    """Base class for Vision-Language Models that adds image processing capability."""
+
+    @abstractmethod
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Either:
+                - str: Single prompt used for all images
+                - list[str]: List of prompts (one per image, must match image count)
+
+        Raises:
+            ValueError: If prompt list length doesn't match image count.
+        """
+
+
+class BaseVlmPageModel(BasePageModel, BaseVlmModel):
+    """Base implementation for VLM models that inherit from BasePageModel.
+
+    Provides a default __call__ implementation that extracts images from pages,
+    processes them using process_images, and attaches results back to pages.
+    """
+
+    # Type annotations for attributes that subclasses must initialize
+    vlm_options: InlineVlmOptions
+    processor: Any
+
+    @abstractmethod
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        """Extract images from pages, process them, and attach results back."""
+
+    def formulate_prompt(self, user_prompt: str) -> str:
+        """Formulate a prompt for the VLM."""
+        _log = logging.getLogger(__name__)
+
+        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
+            return user_prompt
+
+        elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
+            _log.debug("Using specialized prompt for Phi-4")
+            # Note: This might need adjustment for VLLM vs transformers
+            user_prompt_prefix = "<|user|>"
+            assistant_prompt = "<|assistant|>"
+            prompt_suffix = "<|end|>"
+
+            prompt = f"{user_prompt_prefix}<|image_1|>{user_prompt}{prompt_suffix}{assistant_prompt}"
+            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")
+
+            return prompt
+
+        elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.CHAT:
+            messages = [
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "This is a page from a document.",
+                        },
+                        {"type": "image"},
+                        {"type": "text", "text": user_prompt},
+                    ],
+                }
+            ]
+            prompt = self.processor.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            return prompt
+
+        raise RuntimeError(
+            f"Unknown prompt style `{self.vlm_options.transformers_prompt_style}`. Valid values are {', '.join(s.value for s in TransformersPromptStyle)}."
+        )
+
+
 EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
 
 
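To make the new contract concrete, here is a hypothetical minimal subclass (EchoVlmModel is invented for illustration and is not part of the package) that satisfies the `BaseVlmModel.process_images` interface shown above:

```python
from collections.abc import Iterable
from typing import Union

import numpy as np
from PIL.Image import Image

from docling.datamodel.base_models import VlmPrediction
from docling.models.base_model import BaseVlmModel


class EchoVlmModel(BaseVlmModel):
    """Hypothetical stub demonstrating the process_images contract."""

    def process_images(
        self,
        image_batch: Iterable[Union[Image, np.ndarray]],
        prompt: Union[str, list[str]],
    ) -> Iterable[VlmPrediction]:
        images = list(image_batch)
        # A str prompt is broadcast to every image; a list must match 1:1.
        prompts = [prompt] * len(images) if isinstance(prompt, str) else prompt
        if len(prompts) != len(images):
            raise ValueError("prompt list length must match image count")
        for p in prompts:
            yield VlmPrediction(text=f"(would run a VLM with prompt: {p})")
```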
docling/models/page_preprocessing_model.py
CHANGED
@@ -17,6 +17,9 @@ from docling.utils.profiling import TimeRecorder
 
 class PagePreprocessingOptions(BaseModel):
     images_scale: Optional[float]
+    skip_cell_extraction: bool = (
+        False  # Skip text cell extraction for VLM-only processing
+    )
 
 
 class PagePreprocessingModel(BasePageModel):
@@ -41,7 +44,8 @@ class PagePreprocessingModel(BasePageModel):
         else:
             with TimeRecorder(conv_res, "page_parse"):
                 page = self._populate_page_images(page)
-                page = self._parse_page_cells(conv_res, page)
+                if not self.options.skip_cell_extraction:
+                    page = self._parse_page_cells(conv_res, page)
             yield page
 
     # Generate the page image and store it in the page object
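A sketch of the new flag in isolation (pipelines normally construct this model internally, so the direct instantiation below is illustrative only):

```python
from docling.models.page_preprocessing_model import (
    PagePreprocessingModel,
    PagePreprocessingOptions,
)

# Illustrative: with skip_cell_extraction=True the model still renders the
# page images but no longer parses programmatic text cells, which is what
# a VLM-only pipeline wants.
preprocessor = PagePreprocessingModel(
    options=PagePreprocessingOptions(images_scale=2.0, skip_cell_extraction=True)
)
```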
docling/models/picture_description_vlm_model.py
CHANGED
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import Optional, Type, Union
 
 from PIL import Image
+from transformers import AutoModelForImageTextToText
 
 from docling.datamodel.accelerator_options import AcceleratorOptions
 from docling.datamodel.pipeline_options import (
@@ -63,7 +64,7 @@ class PictureDescriptionVlmModel(
         # Initialize processor and model
         with _model_init_lock:
             self.processor = AutoProcessor.from_pretrained(artifacts_path)
-            self.model =
+            self.model = AutoModelForImageTextToText.from_pretrained(
                 artifacts_path,
                 device_map=self.device,
                 torch_dtype=torch.bfloat16,
@@ -71,9 +72,10 @@ class PictureDescriptionVlmModel(
                     "flash_attention_2"
                     if self.device.startswith("cuda")
                     and accelerator_options.cuda_use_flash_attention2
-                    else "
+                    else "sdpa"
                 ),
             )
+            self.model = torch.compile(self.model)  # type: ignore
 
         self.provenance = f"{self.options.repo_id}"
 
docling/models/vlm_models_inline/__init__.py
CHANGED
@@ -0,0 +1 @@
+
docling/models/vlm_models_inline/hf_transformers_model.py
CHANGED
@@ -3,7 +3,11 @@ import logging
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Union
+
+import numpy as np
+from PIL.Image import Image
+from transformers import StoppingCriteriaList, StopStringCriteria
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,7 +19,7 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersModelType,
     TransformersPromptStyle,
 )
-from docling.models.base_model import
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -25,7 +29,7 @@ from docling.utils.profiling import TimeRecorder
 _log = logging.getLogger(__name__)
 
 
-class HuggingFaceTransformersVlmModel(
+class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -103,6 +107,8 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
             artifacts_path,
             trust_remote_code=vlm_options.trust_remote_code,
         )
+        self.processor.tokenizer.padding_side = "left"
+
         self.vlm_model = model_cls.from_pretrained(
             artifacts_path,
             device_map=self.device,
@@ -111,10 +117,11 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
                 "flash_attention_2"
                 if self.device.startswith("cuda")
                 and accelerator_options.cuda_use_flash_attention2
-                else "
+                else "sdpa"
             ),
             trust_remote_code=vlm_options.trust_remote_code,
         )
+        self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
 
         # Load generation config
         self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
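Setting `padding_side = "left"` is what makes the later `generated_ids[:, input_len:]` slice valid for a whole batch. An illustrative sketch with a plain tokenizer (the repo choice is arbitrary):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instruct")
tok.padding_side = "left"

# With left padding, every prompt in the batch ends at the same index,
# so slicing off a common prompt length isolates only the new tokens.
batch = tok(["short prompt", "a much longer prompt here"], padding=True)
print([len(ids) for ids in batch["input_ids"]])  # equal lengths, pads on the left
```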
@@ -122,93 +129,186 @@ class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMix
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
-        [old line 125 not rendered in the source diff]
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        valid_pages = []
+        invalid_pages = []
+
+        for page in page_list:
             assert page._backend is not None
             if not page._backend.is_valid():
-        [old line 128 not rendered]
+                invalid_pages.append(page)
             else:
-        [old line 130 not rendered]
-                assert page.size is not None
+                valid_pages.append(page)
 
+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, "vlm"):
+                # Prepare images and prompts for batch processing
+                images = []
+                user_prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
                 hi_res_image = page.get_image(
                     scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                 )
 
-                #
-        [old lines 138-140 not rendered]
-                inputs = self.processor(
-                    text=prompt, images=[hi_res_image], return_tensors="pt"
-                ).to(self.device)
-        [old line 144 not rendered]
-                start_time = time.time()
-                # Call model to generate:
-                generated_ids = self.vlm_model.generate(
-                    **inputs,
-                    max_new_tokens=self.max_new_tokens,
-                    use_cache=self.use_cache,
-                    temperature=self.temperature,
-                    generation_config=self.generation_config,
-                    **self.vlm_options.extra_generation_config,
-                )
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)
 
-        [old lines 156-157 not rendered]
-                    generated_ids[:, inputs["input_ids"].shape[1] :],
-                    skip_special_tokens=False,
-                )[0]
+                        # Define prompt structure
+                        user_prompt = self.vlm_options.build_prompt(page.parsed_page)
 
-        [old lines 162-169 not rendered]
+                        user_prompts.append(user_prompt)
+                        pages_with_images.append(page)
+
+            # Use process_images for the actual inference
+            if images:  # Only if we have valid images
+                predictions = list(self.process_images(images, user_prompts))
+
+                # Attach results to pages
+                for page, prediction in zip(pages_with_images, predictions):
+                    page.predictions.vlm_response = prediction
+
+        # Yield all pages (valid and invalid)
+        for page in invalid_pages:
+            yield page
+        for page in valid_pages:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """
+        Batched inference for Hugging Face Image-Text-to-Text VLMs (e.g., SmolDocling / SmolVLM).
+        - Lets the processor handle all padding & batching for text+images.
+        - Trims generated sequences per row using attention_mask (no pad-id fallbacks).
+        - Keeps your formulate_prompt() exactly as-is.
+        """
+        import numpy as np
+        import torch
+        from PIL import Image as PILImage
+
+        # -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
+        pil_images: list[Image] = []
+        for img in image_batch:
+            if isinstance(img, np.ndarray):
+                if img.ndim == 3 and img.shape[2] in (3, 4):
+                    pil_img = PILImage.fromarray(img.astype(np.uint8))
+                elif img.ndim == 2:
+                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+            else:
+                pil_img = img
+            if pil_img.mode != "RGB":
+                pil_img = pil_img.convert("RGB")
+            pil_images.append(pil_img)
+
+        if not pil_images:
+            return
+
+        # -- Normalize prompts (1 per image)
+        if isinstance(prompt, str):
+            user_prompts = [prompt] * len(pil_images)
+        else:
+            if len(prompt) != len(pil_images):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
+                )
+            user_prompts = prompt
+
+        # Use your prompt formatter verbatim
+        if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
+            inputs = self.processor(
+                pil_images,
+                return_tensors="pt",
+                padding=True,  # pad across batch for both text and vision
+                **self.vlm_options.extra_processor_kwargs,
+            )
+        else:
+            prompts: list[str] = [self.formulate_prompt(p) for p in user_prompts]
+
+            # -- Processor performs BOTH text+image preprocessing + batch padding (recommended)
+            inputs = self.processor(
+                text=prompts,
+                images=pil_images,
+                return_tensors="pt",
+                padding=True,  # pad across batch for both text and vision
+                **self.vlm_options.extra_processor_kwargs,
+            )
+        inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+        # -- Optional stopping criteria
+        stopping_criteria = None
+        if self.vlm_options.stop_strings:
+            stopping_criteria = StoppingCriteriaList(
+                [
+                    StopStringCriteria(
+                        stop_strings=self.vlm_options.stop_strings,
+                        tokenizer=self.processor.tokenizer,
                     )
+                ]
+            )
 
-        [old lines 172-203 not rendered]
-                    ],
-                }
-            ]
-            prompt = self.processor.apply_chat_template(
-                messages, add_generation_prompt=False
+        # -- Generate (Image-Text-to-Text class expects these inputs from processor)
+        gen_kwargs = {
+            **inputs,
+            "max_new_tokens": self.max_new_tokens,
+            "use_cache": self.use_cache,
+            "generation_config": self.generation_config,
+            **self.vlm_options.extra_generation_config,
+        }
+        if self.temperature > 0:
+            gen_kwargs["do_sample"] = True
+            gen_kwargs["temperature"] = self.temperature
+        else:
+            gen_kwargs["do_sample"] = False
+
+        if stopping_criteria is not None:
+            gen_kwargs["stopping_criteria"] = stopping_criteria
+
+        start_time = time.time()
+        with torch.inference_mode():
+            generated_ids = self.vlm_model.generate(**gen_kwargs)
+        generation_time = time.time() - start_time
+
+        input_len = inputs["input_ids"].shape[1]  # common right-aligned prompt length
+        trimmed_sequences = generated_ids[:, input_len:]  # only newly generated tokens
+
+        # -- Decode with the processor/tokenizer (skip specials, keep DocTags as text)
+        decode_fn = getattr(self.processor, "batch_decode", None)
+        if decode_fn is None and getattr(self.processor, "tokenizer", None) is not None:
+            decode_fn = self.processor.tokenizer.batch_decode
+        if decode_fn is None:
+            raise RuntimeError(
+                "Neither processor.batch_decode nor tokenizer.batch_decode is available."
             )
-        return prompt
 
-        [old lines 212-213 not rendered]
+        decoded_texts: list[str] = decode_fn(
+            trimmed_sequences, skip_special_tokens=False
         )
+
+        # -- Clip off pad tokens from decoded texts
+        pad_token = self.processor.tokenizer.pad_token
+        if pad_token:
+            decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]
+
+        # -- Optional logging
+        if generated_ids.shape[0] > 0:
+            _log.debug(
+                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"for batch size {generated_ids.shape[0]}."
+            )
+
+        for text in decoded_texts:
+            # Apply decode_response to the output text
+            decoded_text = self.vlm_options.decode_response(text)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time)