docling 2.45.0__py3-none-any.whl → 2.47.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,12 @@
 import logging
+import threading
 import time
 from collections.abc import Iterable
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union
+
+import numpy as np
+from PIL.Image import Image
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -10,7 +14,7 @@ from docling.datamodel.accelerator_options import (
 from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
-from docling.models.base_model import BasePageModel
+from docling.models.base_model import BaseVlmPageModel
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -18,8 +22,12 @@ from docling.utils.profiling import TimeRecorder
 
 _log = logging.getLogger(__name__)
 
+# Global lock for MLX model calls - MLX models are not thread-safe
+# All MLX models share this lock to prevent concurrent MLX operations
+_MLX_GLOBAL_LOCK = threading.Lock()
+
 
-class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
+class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
     def __init__(
         self,
         enabled: bool,
@@ -63,87 +71,190 @@ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
-        for page in page_batch:
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        valid_pages = []
+        invalid_pages = []
+
+        for page in page_list:
             assert page._backend is not None
             if not page._backend.is_valid():
-                yield page
+                invalid_pages.append(page)
             else:
-                with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
-                    assert page.size is not None
+                valid_pages.append(page)
 
+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
+                # Prepare images and prompts for batch processing
+                images = []
+                user_prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
                     hi_res_image = page.get_image(
                         scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
                     )
+
+                    # Only process pages with valid images
                     if hi_res_image is not None:
-                        im_width, im_height = hi_res_image.size
+                        images.append(hi_res_image)
 
-                    # populate page_tags with predicted doc tags
-                    page_tags = ""
+                        # Define prompt structure
+                        if callable(self.vlm_options.prompt):
+                            user_prompt = self.vlm_options.prompt(page.parsed_page)
+                        else:
+                            user_prompt = self.vlm_options.prompt
 
-                    if hi_res_image:
-                        if hi_res_image.mode != "RGB":
-                            hi_res_image = hi_res_image.convert("RGB")
+                        user_prompts.append(user_prompt)
+                        pages_with_images.append(page)
 
-                    user_prompt = self.vlm_options.build_prompt(page.parsed_page)
-                    prompt = self.apply_chat_template(
-                        self.processor, self.config, user_prompt, num_images=1
-                    )
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    predictions = list(self.process_images(images, user_prompts))
 
-                    start_time = time.time()
-                    _log.debug("start generating ...")
-
-                    # Call model to generate:
-                    tokens: list[VlmPredictionToken] = []
-
-                    output = ""
-                    for token in self.stream_generate(
-                        self.vlm_model,
-                        self.processor,
-                        prompt,
-                        [hi_res_image],
-                        max_tokens=self.max_tokens,
-                        verbose=False,
-                        temp=self.temperature,
-                    ):
-                        if len(token.logprobs.shape) == 1:
-                            tokens.append(
-                                VlmPredictionToken(
-                                    text=token.text,
-                                    token=token.token,
-                                    logprob=token.logprobs[token.token],
-                                )
-                            )
-                        elif (
-                            len(token.logprobs.shape) == 2
-                            and token.logprobs.shape[0] == 1
-                        ):
-                            tokens.append(
-                                VlmPredictionToken(
-                                    text=token.text,
-                                    token=token.token,
-                                    logprob=token.logprobs[0, token.token],
-                                )
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield all pages (valid and invalid)
+        for page in invalid_pages:
+            yield page
+        for page in valid_pages:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Either:
+                - str: Single prompt used for all images
+                - list[str]: List of prompts (one per image, must match image count)
+
+        Raises:
+            ValueError: If prompt list length doesn't match image count.
+        """
+        # Convert image batch to list for length validation
+        image_list = list(image_batch)
+
+        if len(image_list) == 0:
+            return
+
+        # Handle prompt parameter
+        if isinstance(prompt, str):
+            # Single prompt for all images
+            user_prompts = [prompt] * len(image_list)
+        elif isinstance(prompt, list):
+            # List of prompts (one per image)
+            if len(prompt) != len(image_list):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(image_list)})"
+                )
+            user_prompts = prompt
+        else:
+            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
+
+        # MLX models are not thread-safe - use global lock to serialize access
+        with _MLX_GLOBAL_LOCK:
+            _log.debug("MLX model: Acquired global lock for thread safety")
+            for image, user_prompt in zip(image_list, user_prompts):
+                # Convert numpy array to PIL Image if needed
+                if isinstance(image, np.ndarray):
+                    if image.ndim == 3 and image.shape[2] in [3, 4]:
+                        # RGB or RGBA array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8))
+                    elif image.ndim == 2:
+                        # Grayscale array
+                        from PIL import Image as PILImage
+
+                        image = PILImage.fromarray(image.astype(np.uint8), mode="L")
+                    else:
+                        raise ValueError(
+                            f"Unsupported numpy array shape: {image.shape}"
+                        )
+
+                # Ensure image is in RGB mode (handles RGBA, L, etc.)
+                if image.mode != "RGB":
+                    image = image.convert("RGB")
+
+                # Use the MLX chat template approach like in the __call__ method
+                formatted_prompt = self.apply_chat_template(
+                    self.processor, self.config, user_prompt, num_images=1
+                )
+
+                # Stream generate with stop strings support
+                start_time = time.time()
+                _log.debug("start generating ...")
+
+                tokens: list[VlmPredictionToken] = []
+                output = ""
+
+                # Use stream_generate for proper stop string handling
+                for token in self.stream_generate(
+                    self.vlm_model,
+                    self.processor,
+                    formatted_prompt,
+                    [image],  # MLX stream_generate expects list of images
+                    max_tokens=self.max_tokens,
+                    verbose=False,
+                    temp=self.temperature,
+                ):
+                    # Collect token information
+                    if len(token.logprobs.shape) == 1:
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[token.token],
                             )
-                        else:
-                            _log.warning(
-                                f"incompatible shape for logprobs: {token.logprobs.shape}"
+                        )
+                    elif (
+                        len(token.logprobs.shape) == 2 and token.logprobs.shape[0] == 1
+                    ):
+                        tokens.append(
+                            VlmPredictionToken(
+                                text=token.text,
+                                token=token.token,
+                                logprob=token.logprobs[0, token.token],
                             )
+                        )
+                    else:
+                        _log.warning(
+                            f"incompatible shape for logprobs: {token.logprobs.shape}"
+                        )
 
-                        output += token.text
-                        if "</doctag>" in token.text:
+                    output += token.text
+
+                    # Check for any configured stop strings
+                    if self.vlm_options.stop_strings:
+                        if any(
+                            stop_str in output
+                            for stop_str in self.vlm_options.stop_strings
+                        ):
+                            _log.debug("Stopping generation due to stop string match")
                             break
 
-                    generation_time = time.time() - start_time
-                    page_tags = output
+                generation_time = time.time() - start_time
 
-                    _log.debug(
-                        f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
-                    )
-                    page_tags = self.vlm_options.decode_response(page_tags)
-                    page.predictions.vlm_response = VlmPrediction(
-                        text=page_tags,
-                        generation_time=generation_time,
-                        generated_tokens=tokens,
-                    )
+                _log.debug(
+                    f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time:.1f} tokens/sec)."
+                )
 
-            yield page
+                # Apply decode_response to the output before yielding
+                decoded_output = self.vlm_options.decode_response(output)
+                yield VlmPrediction(
+                    text=decoded_output,
+                    generation_time=generation_time,
+                    generated_tokens=tokens,
+                )
+            _log.debug("MLX model: Released global lock")
@@ -0,0 +1,235 @@
+import logging
+import time
+from collections.abc import Iterable
+from pathlib import Path
+from typing import Any, Dict, Optional, Union
+
+import numpy as np
+from PIL.Image import Image
+
+from docling.datamodel.accelerator_options import (
+    AcceleratorOptions,
+)
+from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options_vlm_model import (
+    InlineVlmOptions,
+    TransformersPromptStyle,
+)
+from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.hf_model_download import (
+    HuggingFaceModelDownloadMixin,
+)
+from docling.utils.accelerator_utils import decide_device
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        vlm_options: InlineVlmOptions,
+    ):
+        self.enabled = enabled
+
+        self.vlm_options = vlm_options
+
+        if self.enabled:
+            from transformers import AutoProcessor
+            from vllm import LLM, SamplingParams
+
+            self.device = decide_device(
+                accelerator_options.device,
+                supported_devices=vlm_options.supported_devices,
+            )
+            _log.debug(f"Available device for VLM: {self.device}")
+
+            self.max_new_tokens = vlm_options.max_new_tokens
+            self.temperature = vlm_options.temperature
+
+            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+
+            if artifacts_path is None:
+                artifacts_path = self.download_models(self.vlm_options.repo_id)
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            # Initialize VLLM LLM
+            llm_kwargs: Dict[str, Any] = {
+                "model": str(artifacts_path),
+                "limit_mm_per_prompt": {"image": 1},
+                "trust_remote_code": vlm_options.trust_remote_code,
+                "model_impl": "transformers",
+                "gpu_memory_utilization": 0.3,  # hardcoded for now, leaves room for ~3 different models.
+            }
+
+            # Add device-specific configurations
+
+            if self.device == "cpu":
+                llm_kwargs["device"] = "cpu"
+
+            # Add quantization if specified
+            if vlm_options.quantized:
+                if vlm_options.load_in_8bit:
+                    llm_kwargs["quantization"] = "bitsandbytes"
+
+            self.llm = LLM(**llm_kwargs)
+
+            # Initialize processor for prompt formatting
+            self.processor = AutoProcessor.from_pretrained(
+                artifacts_path,
+                trust_remote_code=vlm_options.trust_remote_code,
+            )
+
+            # Set up sampling parameters
+            self.sampling_params = SamplingParams(
+                temperature=self.temperature,
+                max_tokens=self.max_new_tokens,
+                stop=vlm_options.stop_strings if vlm_options.stop_strings else None,
+                **vlm_options.extra_generation_config,
+            )
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        page_list = list(page_batch)
+        if not page_list:
+            return
+
+        valid_pages = []
+        invalid_pages = []
+
+        for page in page_list:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                invalid_pages.append(page)
+            else:
+                valid_pages.append(page)
+
+        # Process valid pages in batch
+        if valid_pages:
+            with TimeRecorder(conv_res, "vlm"):
+                # Prepare images and prompts for batch processing
+                images = []
+                user_prompts = []
+                pages_with_images = []
+
+                for page in valid_pages:
+                    assert page.size is not None
+                    hi_res_image = page.get_image(
+                        scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
+                    )
+
+                    # Only process pages with valid images
+                    if hi_res_image is not None:
+                        images.append(hi_res_image)
+
+                        # Define prompt structure
+                        if callable(self.vlm_options.prompt):
+                            user_prompt = self.vlm_options.prompt(page.parsed_page)
+                        else:
+                            user_prompt = self.vlm_options.prompt
+
+                        user_prompts.append(user_prompt)
+                        pages_with_images.append(page)
+
+                # Use process_images for the actual inference
+                if images:  # Only if we have valid images
+                    predictions = list(self.process_images(images, user_prompts))
+
+                    # Attach results to pages
+                    for page, prediction in zip(pages_with_images, predictions):
+                        page.predictions.vlm_response = prediction
+
+        # Yield all pages (valid and invalid)
+        for page in invalid_pages:
+            yield page
+        for page in valid_pages:
+            yield page
+
+    def process_images(
+        self,
+        image_batch: Iterable[Union[Image, np.ndarray]],
+        prompt: Union[str, list[str]],
+    ) -> Iterable[VlmPrediction]:
+        """Process raw images without page metadata in a single batched inference call.
+
+        Args:
+            image_batch: Iterable of PIL Images or numpy arrays
+            prompt: Either:
+                - str: Single prompt used for all images
+                - list[str]: List of prompts (one per image, must match image count)
+
+        Raises:
+            ValueError: If prompt list length doesn't match image count.
+        """
+        pil_images: list[Image] = []
+
+        for img in image_batch:
+            # Convert numpy array to PIL Image if needed
+            if isinstance(img, np.ndarray):
+                if img.ndim == 3 and img.shape[2] in [3, 4]:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8))
+                elif img.ndim == 2:
+                    from PIL import Image as PILImage
+
+                    pil_img = PILImage.fromarray(img.astype(np.uint8), mode="L")
+                else:
+                    raise ValueError(f"Unsupported numpy array shape: {img.shape}")
+            else:
+                pil_img = img
+
+            # Ensure image is in RGB mode (handles RGBA, L, etc.)
+            if pil_img.mode != "RGB":
+                pil_img = pil_img.convert("RGB")
+
+            pil_images.append(pil_img)
+
+        if len(pil_images) == 0:
+            return
+
+        # Handle prompt parameter
+        if isinstance(prompt, str):
+            # Single prompt for all images
+            user_prompts = [prompt] * len(pil_images)
+        elif isinstance(prompt, list):
+            # List of prompts (one per image)
+            if len(prompt) != len(pil_images):
+                raise ValueError(
+                    f"Number of prompts ({len(prompt)}) must match number of images ({len(pil_images)})"
+                )
+            user_prompts = prompt
+        else:
+            raise ValueError(f"prompt must be str or list[str], got {type(prompt)}")
+
+        # Format prompts individually
+        prompts: list[str] = [
+            self.formulate_prompt(user_prompt) for user_prompt in user_prompts
+        ]
+
+        # Prepare VLLM inputs
+        llm_inputs = []
+        for prompt, image in zip(prompts, pil_images):
+            llm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
+
+        start_time = time.time()
+        outputs = self.llm.generate(llm_inputs, sampling_params=self.sampling_params)  # type: ignore
+        generation_time = time.time() - start_time
+
+        # Logging tokens count for the first sample as a representative metric
+        if len(outputs) > 0:
+            num_tokens = len(outputs[0].outputs[0].token_ids)
+            _log.debug(
+                f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
+            )
+
+        for output in outputs:
+            # Apply decode_response to the output text
+            decoded_text = self.vlm_options.decode_response(output.outputs[0].text)
+            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
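Note: the new module above adds a vLLM-backed VLM model. It is selected by VlmPipeline when the configured InlineVlmOptions uses InferenceFramework.VLLM (see the VlmPipeline hunk further down). A minimal sketch of the surrounding wiring, assuming docling's usual VlmPipeline setup; only InferenceFramework.VLLM and VllmVlmModel come from this diff, and the input path is a placeholder:

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import VlmPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.pipeline.vlm_pipeline import VlmPipeline

pipeline_options = VlmPipelineOptions()
# Routing through VllmVlmModel happens when pipeline_options.vlm_options is an
# InlineVlmOptions whose inference_framework is InferenceFramework.VLLM; the
# default vlm_options are left untouched in this sketch.

converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(
            pipeline_cls=VlmPipeline,
            pipeline_options=pipeline_options,
        )
    }
)
result = converter.convert("report.pdf")  # placeholder input path
print(result.document.export_to_markdown())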
@@ -20,7 +20,7 @@ from docling.datamodel.base_models import (
     Page,
 )
 from docling.datamodel.document import ConversionResult, InputDocument
-from docling.datamodel.pipeline_options import PipelineOptions
+from docling.datamodel.pipeline_options import PdfPipelineOptions, PipelineOptions
 from docling.datamodel.settings import settings
 from docling.models.base_model import GenericEnrichmentModel
 from docling.utils.profiling import ProfilingScope, TimeRecorder
@@ -168,6 +168,12 @@ class PaginatedPipeline(BasePipeline):  # TODO this is a bad name.
                         # Cleanup page backends
                         if not self.keep_backend and p._backend is not None:
                             p._backend.unload()
+                        if (
+                            isinstance(self.pipeline_options, PdfPipelineOptions)
+                            and not self.pipeline_options.generate_parsed_pages
+                        ):
+                            del p.parsed_page
+                            p.parsed_page = None
 
                     end_batch_time = time.monotonic()
                     total_elapsed_time += end_batch_time - start_batch_time
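Note: with the hunk above, the paginated pipeline now drops page.parsed_page during batch cleanup unless PdfPipelineOptions.generate_parsed_pages is set. A minimal sketch of keeping parsed pages around, assuming the standard PDF pipeline wiring (only the generate_parsed_pages flag and the cleanup behaviour come from this diff); the input path is a placeholder:

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption

pipeline_options = PdfPipelineOptions()
pipeline_options.generate_parsed_pages = True  # otherwise parsed_page is freed after each batch

converter = DocumentConverter(
    format_options={InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)}
)
result = converter.convert("report.pdf")  # placeholder input path
for page in result.pages:
    print(page.page_no, page.parsed_page is not None)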
@@ -194,7 +194,7 @@ class ThreadedPipelineStage:
             return
         self._running = True
         self._thread = threading.Thread(
-            target=self._run, name=f"Stage-{self.name}", daemon=False
+            target=self._run, name=f"Stage-{self.name}", daemon=True
         )
         self._thread.start()
 
@@ -565,10 +565,12 @@ class ThreadedStandardPdfPipeline(BasePipeline):
         if not self.keep_images:
             for p in conv_res.pages:
                 p._image_cache = {}
-        if not self.keep_backend:
-            for p in conv_res.pages:
-                if p._backend is not None:
-                    p._backend.unload()
+        for p in conv_res.pages:
+            if not self.keep_backend and p._backend is not None:
+                p._backend.unload()
+            if not self.pipeline_options.generate_parsed_pages:
+                del p.parsed_page
+                p.parsed_page = None
 
     # ---------------------------------------------------------------- assemble
     def _assemble_document(self, conv_res: ConversionResult) -> ConversionResult:
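Note: the first hunk above switches stage worker threads to daemon=True, so a stalled stage no longer blocks interpreter shutdown. A standalone illustration of the flag's effect (plain Python, not docling code):

import threading
import time


def worker() -> None:
    # Simulates a stage blocked waiting for work that never arrives.
    while True:
        time.sleep(1)


t = threading.Thread(target=worker, name="Stage-demo", daemon=True)
t.start()
print("main thread done")
# With daemon=True the process exits here even though worker() never returns;
# with daemon=False, interpreter shutdown would wait on the worker thread forever.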
@@ -103,6 +103,17 @@ class VlmPipeline(PaginatedPipeline):
                         vlm_options=vlm_options,
                     ),
                 ]
+            elif vlm_options.inference_framework == InferenceFramework.VLLM:
+                from docling.models.vlm_models_inline.vllm_model import VllmVlmModel
+
+                self.build_pipe = [
+                    VllmVlmModel(
+                        enabled=True,  # must be always enabled for this pipeline to make sense.
+                        artifacts_path=artifacts_path,
+                        accelerator_options=pipeline_options.accelerator_options,
+                        vlm_options=vlm_options,
+                    ),
+                ]
             else:
                 raise ValueError(
                     f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
@@ -117,7 +128,9 @@
             page._backend = conv_res.input._backend.load_page(page.page_no)  # type: ignore
             if page._backend is not None and page._backend.is_valid():
                 page.size = page._backend.get_size()
-                page.parsed_page = page._backend.get_segmented_page()
+
+                if self.force_backend_text:
+                    page.parsed_page = page._backend.get_segmented_page()
 
         return page
 
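Note: the last hunk makes VlmPipeline call page._backend.get_segmented_page() only when backend text is forced. A minimal sketch, assuming VlmPipelineOptions exposes a force_backend_text flag (the diff only shows the pipeline reading self.force_backend_text); the converter wiring is the same as in the vLLM sketch above:

from docling.datamodel.pipeline_options import VlmPipelineOptions

pipeline_options = VlmPipelineOptions()
pipeline_options.force_backend_text = True  # keep extracting backend text into page.parsed_page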