docling 2.35.0__py3-none-any.whl → 2.36.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. docling/backend/xml/jats_backend.py +0 -0
  2. docling/cli/main.py +12 -15
  3. docling/datamodel/accelerator_options.py +68 -0
  4. docling/datamodel/base_models.py +10 -8
  5. docling/datamodel/pipeline_options.py +29 -161
  6. docling/datamodel/pipeline_options_vlm_model.py +81 -0
  7. docling/datamodel/vlm_model_specs.py +144 -0
  8. docling/document_converter.py +5 -0
  9. docling/models/api_vlm_model.py +1 -1
  10. docling/models/base_ocr_model.py +2 -1
  11. docling/models/code_formula_model.py +6 -11
  12. docling/models/document_picture_classifier.py +6 -11
  13. docling/models/easyocr_model.py +1 -2
  14. docling/models/layout_model.py +6 -11
  15. docling/models/ocr_mac_model.py +1 -1
  16. docling/models/picture_description_api_model.py +1 -1
  17. docling/models/picture_description_base_model.py +1 -1
  18. docling/models/picture_description_vlm_model.py +7 -22
  19. docling/models/rapid_ocr_model.py +1 -2
  20. docling/models/table_structure_model.py +6 -12
  21. docling/models/tesseract_ocr_cli_model.py +1 -1
  22. docling/models/tesseract_ocr_model.py +1 -1
  23. docling/models/utils/__init__.py +0 -0
  24. docling/models/utils/hf_model_download.py +40 -0
  25. docling/models/vlm_models_inline/__init__.py +0 -0
  26. docling/models/vlm_models_inline/hf_transformers_model.py +194 -0
  27. docling/models/{hf_mlx_model.py → vlm_models_inline/mlx_model.py} +56 -44
  28. docling/pipeline/vlm_pipeline.py +228 -61
  29. docling/utils/accelerator_utils.py +17 -2
  30. docling/utils/model_downloader.py +13 -12
  31. {docling-2.35.0.dist-info → docling-2.36.0.dist-info}/METADATA +54 -55
  32. {docling-2.35.0.dist-info → docling-2.36.0.dist-info}/RECORD +46 -39
  33. {docling-2.35.0.dist-info → docling-2.36.0.dist-info}/WHEEL +2 -1
  34. docling-2.36.0.dist-info/entry_points.txt +6 -0
  35. docling-2.36.0.dist-info/top_level.txt +1 -0
  36. docling/models/hf_vlm_model.py +0 -182
  37. docling-2.35.0.dist-info/entry_points.txt +0 -7
  38. {docling-2.35.0.dist-info → docling-2.36.0.dist-info/licenses}/LICENSE +0 -0
@@ -4,29 +4,34 @@ from collections.abc import Iterable
4
4
  from pathlib import Path
5
5
  from typing import Optional
6
6
 
7
- from docling.datamodel.base_models import Page, VlmPrediction
8
- from docling.datamodel.document import ConversionResult
9
- from docling.datamodel.pipeline_options import (
7
+ from docling.datamodel.accelerator_options import (
10
8
  AcceleratorOptions,
11
- HuggingFaceVlmOptions,
12
9
  )
10
+ from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
11
+ from docling.datamodel.document import ConversionResult
12
+ from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
13
13
  from docling.models.base_model import BasePageModel
14
+ from docling.models.utils.hf_model_download import (
15
+ HuggingFaceModelDownloadMixin,
16
+ )
14
17
  from docling.utils.profiling import TimeRecorder
15
18
 
16
19
  _log = logging.getLogger(__name__)
17
20
 
18
21
 
19
- class HuggingFaceMlxModel(BasePageModel):
22
+ class HuggingFaceMlxModel(BasePageModel, HuggingFaceModelDownloadMixin):
20
23
  def __init__(
21
24
  self,
22
25
  enabled: bool,
23
26
  artifacts_path: Optional[Path],
24
27
  accelerator_options: AcceleratorOptions,
25
- vlm_options: HuggingFaceVlmOptions,
28
+ vlm_options: InlineVlmOptions,
26
29
  ):
27
30
  self.enabled = enabled
28
31
 
29
32
  self.vlm_options = vlm_options
33
+ self.max_tokens = vlm_options.max_new_tokens
34
+ self.temperature = vlm_options.temperature
30
35
 
31
36
  if self.enabled:
32
37
  try:
@@ -39,42 +44,24 @@ class HuggingFaceMlxModel(BasePageModel):
39
44
  )
40
45
 
41
46
  repo_cache_folder = vlm_options.repo_id.replace("/", "--")
47
+
42
48
  self.apply_chat_template = apply_chat_template
43
49
  self.stream_generate = stream_generate
44
50
 
45
51
  # PARAMETERS:
46
52
  if artifacts_path is None:
47
- artifacts_path = self.download_models(self.vlm_options.repo_id)
53
+ artifacts_path = self.download_models(
54
+ self.vlm_options.repo_id,
55
+ )
48
56
  elif (artifacts_path / repo_cache_folder).exists():
49
57
  artifacts_path = artifacts_path / repo_cache_folder
50
58
 
51
- self.param_question = vlm_options.prompt # "Perform Layout Analysis."
59
+ self.param_question = vlm_options.prompt
52
60
 
53
61
  ## Load the model
54
62
  self.vlm_model, self.processor = load(artifacts_path)
55
63
  self.config = load_config(artifacts_path)
56
64
 
57
- @staticmethod
58
- def download_models(
59
- repo_id: str,
60
- local_dir: Optional[Path] = None,
61
- force: bool = False,
62
- progress: bool = False,
63
- ) -> Path:
64
- from huggingface_hub import snapshot_download
65
- from huggingface_hub.utils import disable_progress_bars
66
-
67
- if not progress:
68
- disable_progress_bars()
69
- download_path = snapshot_download(
70
- repo_id=repo_id,
71
- force_download=force,
72
- local_dir=local_dir,
73
- # revision="v0.0.1",
74
- )
75
-
76
- return Path(download_path)
77
-
78
65
  def __call__(
79
66
  self, conv_res: ConversionResult, page_batch: Iterable[Page]
80
67
  ) -> Iterable[Page]:
@@ -83,12 +70,10 @@ class HuggingFaceMlxModel(BasePageModel):
83
70
  if not page._backend.is_valid():
84
71
  yield page
85
72
  else:
86
- with TimeRecorder(conv_res, "vlm"):
73
+ with TimeRecorder(conv_res, f"vlm-mlx-{self.vlm_options.repo_id}"):
87
74
  assert page.size is not None
88
75
 
89
- hi_res_image = page.get_image(scale=2.0) # 144dpi
90
- # hi_res_image = page.get_image(scale=1.0) # 72dpi
91
-
76
+ hi_res_image = page.get_image(scale=self.vlm_options.scale)
92
77
  if hi_res_image is not None:
93
78
  im_width, im_height = hi_res_image.size
94
79
 
@@ -104,16 +89,45 @@ class HuggingFaceMlxModel(BasePageModel):
104
89
  )
105
90
 
106
91
  start_time = time.time()
92
+ _log.debug("start generating ...")
93
+
107
94
  # Call model to generate:
95
+ tokens: list[VlmPredictionToken] = []
96
+
108
97
  output = ""
109
98
  for token in self.stream_generate(
110
99
  self.vlm_model,
111
100
  self.processor,
112
101
  prompt,
113
102
  [hi_res_image],
114
- max_tokens=4096,
103
+ max_tokens=self.max_tokens,
115
104
  verbose=False,
105
+ temp=self.temperature,
116
106
  ):
107
+ if len(token.logprobs.shape) == 1:
108
+ tokens.append(
109
+ VlmPredictionToken(
110
+ text=token.text,
111
+ token=token.token,
112
+ logprob=token.logprobs[token.token],
113
+ )
114
+ )
115
+ elif (
116
+ len(token.logprobs.shape) == 2
117
+ and token.logprobs.shape[0] == 1
118
+ ):
119
+ tokens.append(
120
+ VlmPredictionToken(
121
+ text=token.text,
122
+ token=token.token,
123
+ logprob=token.logprobs[0, token.token],
124
+ )
125
+ )
126
+ else:
127
+ _log.warning(
128
+ f"incompatible shape for logprobs: {token.logprobs.shape}"
129
+ )
130
+
117
131
  output += token.text
118
132
  if "</doctag>" in token.text:
119
133
  break
@@ -121,15 +135,13 @@ class HuggingFaceMlxModel(BasePageModel):
121
135
  generation_time = time.time() - start_time
122
136
  page_tags = output
123
137
 
124
- _log.debug(f"Generation time {generation_time:.2f} seconds.")
125
-
126
- # inference_time = time.time() - start_time
127
- # tokens_per_second = num_tokens / generation_time
128
- # print("")
129
- # print(f"Page Inference Time: {inference_time:.2f} seconds")
130
- # print(f"Total tokens on page: {num_tokens:.2f}")
131
- # print(f"Tokens/sec: {tokens_per_second:.2f}")
132
- # print("")
133
- page.predictions.vlm_response = VlmPrediction(text=page_tags)
138
+ _log.debug(
139
+ f"{generation_time:.2f} seconds for {len(tokens)} tokens ({len(tokens) / generation_time} tokens/sec)."
140
+ )
141
+ page.predictions.vlm_response = VlmPrediction(
142
+ text=page_tags,
143
+ generation_time=generation_time,
144
+ generated_tokens=tokens,
145
+ )
134
146
 
135
147
  yield page
@@ -1,29 +1,46 @@
1
1
  import logging
2
+ import re
2
3
  from io import BytesIO
3
4
  from pathlib import Path
4
5
  from typing import List, Optional, Union, cast
5
6
 
6
- from docling_core.types import DoclingDocument
7
- from docling_core.types.doc import BoundingBox, DocItem, ImageRef, PictureItem, TextItem
7
+ from docling_core.types.doc import (
8
+ BoundingBox,
9
+ DocItem,
10
+ DoclingDocument,
11
+ ImageRef,
12
+ PictureItem,
13
+ ProvenanceItem,
14
+ TextItem,
15
+ )
16
+ from docling_core.types.doc.base import (
17
+ BoundingBox,
18
+ Size,
19
+ )
8
20
  from docling_core.types.doc.document import DocTagsDocument
9
21
  from PIL import Image as PILImage
10
22
 
11
23
  from docling.backend.abstract_backend import AbstractDocumentBackend
24
+ from docling.backend.html_backend import HTMLDocumentBackend
12
25
  from docling.backend.md_backend import MarkdownDocumentBackend
13
26
  from docling.backend.pdf_backend import PdfDocumentBackend
14
27
  from docling.datamodel.base_models import InputFormat, Page
15
28
  from docling.datamodel.document import ConversionResult, InputDocument
16
29
  from docling.datamodel.pipeline_options import (
30
+ VlmPipelineOptions,
31
+ )
32
+ from docling.datamodel.pipeline_options_vlm_model import (
17
33
  ApiVlmOptions,
18
- HuggingFaceVlmOptions,
19
34
  InferenceFramework,
35
+ InlineVlmOptions,
20
36
  ResponseFormat,
21
- VlmPipelineOptions,
22
37
  )
23
38
  from docling.datamodel.settings import settings
24
39
  from docling.models.api_vlm_model import ApiVlmModel
25
- from docling.models.hf_mlx_model import HuggingFaceMlxModel
26
- from docling.models.hf_vlm_model import HuggingFaceVlmModel
40
+ from docling.models.vlm_models_inline.hf_transformers_model import (
41
+ HuggingFaceTransformersVlmModel,
42
+ )
43
+ from docling.models.vlm_models_inline.mlx_model import HuggingFaceMlxModel
27
44
  from docling.pipeline.base_pipeline import PaginatedPipeline
28
45
  from docling.utils.profiling import ProfilingScope, TimeRecorder
29
46
 
@@ -66,8 +83,8 @@ class VlmPipeline(PaginatedPipeline):
66
83
  vlm_options=cast(ApiVlmOptions, self.pipeline_options.vlm_options),
67
84
  ),
68
85
  ]
69
- elif isinstance(self.pipeline_options.vlm_options, HuggingFaceVlmOptions):
70
- vlm_options = cast(HuggingFaceVlmOptions, self.pipeline_options.vlm_options)
86
+ elif isinstance(self.pipeline_options.vlm_options, InlineVlmOptions):
87
+ vlm_options = cast(InlineVlmOptions, self.pipeline_options.vlm_options)
71
88
  if vlm_options.inference_framework == InferenceFramework.MLX:
72
89
  self.build_pipe = [
73
90
  HuggingFaceMlxModel(
@@ -77,15 +94,19 @@ class VlmPipeline(PaginatedPipeline):
77
94
  vlm_options=vlm_options,
78
95
  ),
79
96
  ]
80
- else:
97
+ elif vlm_options.inference_framework == InferenceFramework.TRANSFORMERS:
81
98
  self.build_pipe = [
82
- HuggingFaceVlmModel(
99
+ HuggingFaceTransformersVlmModel(
83
100
  enabled=True, # must be always enabled for this pipeline to make sense.
84
101
  artifacts_path=artifacts_path,
85
102
  accelerator_options=pipeline_options.accelerator_options,
86
103
  vlm_options=vlm_options,
87
104
  ),
88
105
  ]
106
+ else:
107
+ raise ValueError(
108
+ f"Could not instantiate the right type of VLM pipeline: {vlm_options.inference_framework}"
109
+ )
89
110
 
90
111
  self.enrichment_pipe = [
91
112
  # Other models working on `NodeItem` elements in the DoclingDocument
@@ -116,49 +137,19 @@ class VlmPipeline(PaginatedPipeline):
116
137
  self.pipeline_options.vlm_options.response_format
117
138
  == ResponseFormat.DOCTAGS
118
139
  ):
119
- doctags_list = []
120
- image_list = []
121
- for page in conv_res.pages:
122
- predicted_doctags = ""
123
- img = PILImage.new("RGB", (1, 1), "rgb(255,255,255)")
124
- if page.predictions.vlm_response:
125
- predicted_doctags = page.predictions.vlm_response.text
126
- if page.image:
127
- img = page.image
128
- image_list.append(img)
129
- doctags_list.append(predicted_doctags)
130
-
131
- doctags_list_c = cast(List[Union[Path, str]], doctags_list)
132
- image_list_c = cast(List[Union[Path, PILImage.Image]], image_list)
133
- doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(
134
- doctags_list_c, image_list_c
135
- )
136
- conv_res.document = DoclingDocument.load_from_doctags(doctags_doc)
137
-
138
- # If forced backend text, replace model predicted text with backend one
139
- if self.force_backend_text:
140
- scale = self.pipeline_options.images_scale
141
- for element, _level in conv_res.document.iterate_items():
142
- if not isinstance(element, TextItem) or len(element.prov) == 0:
143
- continue
144
- page_ix = element.prov[0].page_no - 1
145
- page = conv_res.pages[page_ix]
146
- if not page.size:
147
- continue
148
- crop_bbox = (
149
- element.prov[0]
150
- .bbox.scaled(scale=scale)
151
- .to_top_left_origin(page_height=page.size.height * scale)
152
- )
153
- txt = self.extract_text_from_backend(page, crop_bbox)
154
- element.text = txt
155
- element.orig = txt
140
+ conv_res.document = self._turn_dt_into_doc(conv_res)
141
+
156
142
  elif (
157
143
  self.pipeline_options.vlm_options.response_format
158
144
  == ResponseFormat.MARKDOWN
159
145
  ):
160
146
  conv_res.document = self._turn_md_into_doc(conv_res)
161
147
 
148
+ elif (
149
+ self.pipeline_options.vlm_options.response_format == ResponseFormat.HTML
150
+ ):
151
+ conv_res.document = self._turn_html_into_doc(conv_res)
152
+
162
153
  else:
163
154
  raise RuntimeError(
164
155
  f"Unsupported VLM response format {self.pipeline_options.vlm_options.response_format}"
@@ -192,23 +183,199 @@ class VlmPipeline(PaginatedPipeline):
192
183
 
193
184
  return conv_res
194
185
 
195
- def _turn_md_into_doc(self, conv_res):
196
- predicted_text = ""
197
- for pg_idx, page in enumerate(conv_res.pages):
186
+ def _turn_dt_into_doc(self, conv_res) -> DoclingDocument:
187
+ doctags_list = []
188
+ image_list = []
189
+ for page in conv_res.pages:
190
+ predicted_doctags = ""
191
+ img = PILImage.new("RGB", (1, 1), "rgb(255,255,255)")
198
192
  if page.predictions.vlm_response:
199
- predicted_text += page.predictions.vlm_response.text + "\n\n"
200
- response_bytes = BytesIO(predicted_text.encode("utf8"))
201
- out_doc = InputDocument(
202
- path_or_stream=response_bytes,
203
- filename=conv_res.input.file.name,
204
- format=InputFormat.MD,
205
- backend=MarkdownDocumentBackend,
193
+ predicted_doctags = page.predictions.vlm_response.text
194
+ if page.image:
195
+ img = page.image
196
+ image_list.append(img)
197
+ doctags_list.append(predicted_doctags)
198
+
199
+ doctags_list_c = cast(List[Union[Path, str]], doctags_list)
200
+ image_list_c = cast(List[Union[Path, PILImage.Image]], image_list)
201
+ doctags_doc = DocTagsDocument.from_doctags_and_image_pairs(
202
+ doctags_list_c, image_list_c
206
203
  )
207
- backend = MarkdownDocumentBackend(
208
- in_doc=out_doc,
209
- path_or_stream=response_bytes,
204
+ conv_res.document = DoclingDocument.load_from_doctags(
205
+ doctag_document=doctags_doc
210
206
  )
211
- return backend.convert()
207
+
208
+ # If forced backend text, replace model predicted text with backend one
209
+ if page.size:
210
+ if self.force_backend_text:
211
+ scale = self.pipeline_options.images_scale
212
+ for element, _level in conv_res.document.iterate_items():
213
+ if not isinstance(element, TextItem) or len(element.prov) == 0:
214
+ continue
215
+ crop_bbox = (
216
+ element.prov[0]
217
+ .bbox.scaled(scale=scale)
218
+ .to_top_left_origin(page_height=page.size.height * scale)
219
+ )
220
+ txt = self.extract_text_from_backend(page, crop_bbox)
221
+ element.text = txt
222
+ element.orig = txt
223
+
224
+ return conv_res.document
225
+
226
+ def _turn_md_into_doc(self, conv_res):
227
+ def _extract_markdown_code(text):
228
+ """
229
+ Extracts text from markdown code blocks (enclosed in triple backticks).
230
+ If no code blocks are found, returns the original text.
231
+
232
+ Args:
233
+ text (str): Input text that may contain markdown code blocks
234
+
235
+ Returns:
236
+ str: Extracted code if code blocks exist, otherwise original text
237
+ """
238
+ # Regex pattern to match content between triple backticks
239
+ # This handles multiline content and optional language specifier
240
+ pattern = r"^```(?:\w*\n)?(.*?)```(\n)*$"
241
+
242
+ # Search with DOTALL flag to match across multiple lines
243
+ mtch = re.search(pattern, text, re.DOTALL)
244
+
245
+ if mtch:
246
+ # Return only the content of the first capturing group
247
+ return mtch.group(1)
248
+ else:
249
+ # No code blocks found, return original text
250
+ return text
251
+
252
+ for pg_idx, page in enumerate(conv_res.pages):
253
+ page_no = pg_idx + 1 # FIXME: might be incorrect
254
+
255
+ predicted_text = ""
256
+ if page.predictions.vlm_response:
257
+ predicted_text = page.predictions.vlm_response.text + "\n\n"
258
+
259
+ predicted_text = _extract_markdown_code(text=predicted_text)
260
+
261
+ response_bytes = BytesIO(predicted_text.encode("utf8"))
262
+ out_doc = InputDocument(
263
+ path_or_stream=response_bytes,
264
+ filename=conv_res.input.file.name,
265
+ format=InputFormat.MD,
266
+ backend=MarkdownDocumentBackend,
267
+ )
268
+ backend = MarkdownDocumentBackend(
269
+ in_doc=out_doc,
270
+ path_or_stream=response_bytes,
271
+ )
272
+ page_doc = backend.convert()
273
+
274
+ if page.image is not None:
275
+ pg_width = page.image.width
276
+ pg_height = page.image.height
277
+ else:
278
+ pg_width = 1
279
+ pg_height = 1
280
+
281
+ conv_res.document.add_page(
282
+ page_no=page_no,
283
+ size=Size(width=pg_width, height=pg_height),
284
+ image=ImageRef.from_pil(image=page.image, dpi=72)
285
+ if page.image
286
+ else None,
287
+ )
288
+
289
+ for item, level in page_doc.iterate_items():
290
+ item.prov = [
291
+ ProvenanceItem(
292
+ page_no=pg_idx + 1,
293
+ bbox=BoundingBox(
294
+ t=0.0, b=0.0, l=0.0, r=0.0
295
+ ), # FIXME: would be nice not to have to "fake" it
296
+ charspan=[0, 0],
297
+ )
298
+ ]
299
+ conv_res.document.append_child_item(child=item)
300
+
301
+ return conv_res.document
302
+
303
+ def _turn_html_into_doc(self, conv_res):
304
+ def _extract_html_code(text):
305
+ """
306
+ Extracts text from markdown code blocks (enclosed in triple backticks).
307
+ If no code blocks are found, returns the original text.
308
+
309
+ Args:
310
+ text (str): Input text that may contain markdown code blocks
311
+
312
+ Returns:
313
+ str: Extracted code if code blocks exist, otherwise original text
314
+ """
315
+ # Regex pattern to match content between triple backticks
316
+ # This handles multiline content and optional language specifier
317
+ pattern = r"^```(?:\w*\n)?(.*?)```(\n)*$"
318
+
319
+ # Search with DOTALL flag to match across multiple lines
320
+ mtch = re.search(pattern, text, re.DOTALL)
321
+
322
+ if mtch:
323
+ # Return only the content of the first capturing group
324
+ return mtch.group(1)
325
+ else:
326
+ # No code blocks found, return original text
327
+ return text
328
+
329
+ for pg_idx, page in enumerate(conv_res.pages):
330
+ page_no = pg_idx + 1 # FIXME: might be incorrect
331
+
332
+ predicted_text = ""
333
+ if page.predictions.vlm_response:
334
+ predicted_text = page.predictions.vlm_response.text + "\n\n"
335
+
336
+ predicted_text = _extract_html_code(text=predicted_text)
337
+
338
+ response_bytes = BytesIO(predicted_text.encode("utf8"))
339
+ out_doc = InputDocument(
340
+ path_or_stream=response_bytes,
341
+ filename=conv_res.input.file.name,
342
+ format=InputFormat.MD,
343
+ backend=HTMLDocumentBackend,
344
+ )
345
+ backend = HTMLDocumentBackend(
346
+ in_doc=out_doc,
347
+ path_or_stream=response_bytes,
348
+ )
349
+ page_doc = backend.convert()
350
+
351
+ if page.image is not None:
352
+ pg_width = page.image.width
353
+ pg_height = page.image.height
354
+ else:
355
+ pg_width = 1
356
+ pg_height = 1
357
+
358
+ conv_res.document.add_page(
359
+ page_no=page_no,
360
+ size=Size(width=pg_width, height=pg_height),
361
+ image=ImageRef.from_pil(image=page.image, dpi=72)
362
+ if page.image
363
+ else None,
364
+ )
365
+
366
+ for item, level in page_doc.iterate_items():
367
+ item.prov = [
368
+ ProvenanceItem(
369
+ page_no=pg_idx + 1,
370
+ bbox=BoundingBox(
371
+ t=0.0, b=0.0, l=0.0, r=0.0
372
+ ), # FIXME: would be nice not to have to "fake" it
373
+ charspan=[0, 0],
374
+ )
375
+ ]
376
+ conv_res.document.append_child_item(child=item)
377
+
378
+ return conv_res.document
212
379
 
213
380
  @classmethod
214
381
  def get_default_options(cls) -> VlmPipelineOptions:
@@ -1,13 +1,16 @@
1
1
  import logging
2
+ from typing import List, Optional
2
3
 
3
4
  import torch
4
5
 
5
- from docling.datamodel.pipeline_options import AcceleratorDevice
6
+ from docling.datamodel.accelerator_options import AcceleratorDevice
6
7
 
7
8
  _log = logging.getLogger(__name__)
8
9
 
9
10
 
10
- def decide_device(accelerator_device: str) -> str:
11
+ def decide_device(
12
+ accelerator_device: str, supported_devices: Optional[List[AcceleratorDevice]] = None
13
+ ) -> str:
11
14
  r"""
12
15
  Resolve the device based on the acceleration options and the available devices in the system.
13
16
 
@@ -20,6 +23,18 @@ def decide_device(accelerator_device: str) -> str:
20
23
  has_cuda = torch.backends.cuda.is_built() and torch.cuda.is_available()
21
24
  has_mps = torch.backends.mps.is_built() and torch.backends.mps.is_available()
22
25
 
26
+ if supported_devices is not None:
27
+ if has_cuda and AcceleratorDevice.CUDA not in supported_devices:
28
+ _log.info(
29
+ f"Removing CUDA from available devices because it is not in {supported_devices=}"
30
+ )
31
+ has_cuda = False
32
+ if has_mps and AcceleratorDevice.MPS not in supported_devices:
33
+ _log.info(
34
+ f"Removing MPS from available devices because it is not in {supported_devices=}"
35
+ )
36
+ has_mps = False
37
+
23
38
  if accelerator_device == AcceleratorDevice.AUTO.value: # Handle 'auto'
24
39
  if has_cuda:
25
40
  device = "cuda:0"
@@ -4,18 +4,20 @@ from typing import Optional
4
4
 
5
5
  from docling.datamodel.pipeline_options import (
6
6
  granite_picture_description,
7
- smoldocling_vlm_conversion_options,
8
- smoldocling_vlm_mlx_conversion_options,
9
7
  smolvlm_picture_description,
10
8
  )
11
9
  from docling.datamodel.settings import settings
10
+ from docling.datamodel.vlm_model_specs import (
11
+ SMOLDOCLING_MLX,
12
+ SMOLDOCLING_TRANSFORMERS,
13
+ )
12
14
  from docling.models.code_formula_model import CodeFormulaModel
13
15
  from docling.models.document_picture_classifier import DocumentPictureClassifier
14
16
  from docling.models.easyocr_model import EasyOcrModel
15
- from docling.models.hf_vlm_model import HuggingFaceVlmModel
16
17
  from docling.models.layout_model import LayoutModel
17
18
  from docling.models.picture_description_vlm_model import PictureDescriptionVlmModel
18
19
  from docling.models.table_structure_model import TableStructureModel
20
+ from docling.models.utils.hf_model_download import download_hf_model
19
21
 
20
22
  _log = logging.getLogger(__name__)
21
23
 
@@ -75,7 +77,7 @@ def download_models(
75
77
 
76
78
  if with_smolvlm:
77
79
  _log.info("Downloading SmolVlm model...")
78
- PictureDescriptionVlmModel.download_models(
80
+ download_hf_model(
79
81
  repo_id=smolvlm_picture_description.repo_id,
80
82
  local_dir=output_dir / smolvlm_picture_description.repo_cache_folder,
81
83
  force=force,
@@ -84,26 +86,25 @@ def download_models(
84
86
 
85
87
  if with_smoldocling:
86
88
  _log.info("Downloading SmolDocling model...")
87
- HuggingFaceVlmModel.download_models(
88
- repo_id=smoldocling_vlm_conversion_options.repo_id,
89
- local_dir=output_dir / smoldocling_vlm_conversion_options.repo_cache_folder,
89
+ download_hf_model(
90
+ repo_id=SMOLDOCLING_TRANSFORMERS.repo_id,
91
+ local_dir=output_dir / SMOLDOCLING_TRANSFORMERS.repo_cache_folder,
90
92
  force=force,
91
93
  progress=progress,
92
94
  )
93
95
 
94
96
  if with_smoldocling_mlx:
95
97
  _log.info("Downloading SmolDocling MLX model...")
96
- HuggingFaceVlmModel.download_models(
97
- repo_id=smoldocling_vlm_mlx_conversion_options.repo_id,
98
- local_dir=output_dir
99
- / smoldocling_vlm_mlx_conversion_options.repo_cache_folder,
98
+ download_hf_model(
99
+ repo_id=SMOLDOCLING_MLX.repo_id,
100
+ local_dir=output_dir / SMOLDOCLING_MLX.repo_cache_folder,
100
101
  force=force,
101
102
  progress=progress,
102
103
  )
103
104
 
104
105
  if with_granite_vision:
105
106
  _log.info("Downloading Granite Vision model...")
106
- PictureDescriptionVlmModel.download_models(
107
+ download_hf_model(
107
108
  repo_id=granite_picture_description.repo_id,
108
109
  local_dir=output_dir / granite_picture_description.repo_cache_folder,
109
110
  force=force,