docling 2.24.0__py3-none-any.whl → 2.25.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -154,6 +154,10 @@ class LayoutPrediction(BaseModel):
     clusters: List[Cluster] = []
 
 
+class VlmPrediction(BaseModel):
+    text: str = ""
+
+
 class ContainerElement(
     BasePageElement
 ):  # Used for Form and Key-Value-Regions, only for typing.
@@ -197,6 +201,7 @@ class PagePredictions(BaseModel):
     tablestructure: Optional[TableStructurePrediction] = None
     figures_classification: Optional[FigureClassificationPrediction] = None
     equations_prediction: Optional[EquationPrediction] = None
+    vlm_response: Optional[VlmPrediction] = None
 
 
 PageElement = Union[TextElement, Table, FigureElement, ContainerElement]
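
VlmPrediction is the new per-page container for raw VLM output, attached to PagePredictions as vlm_response. A minimal sketch of reading it back, using only the names defined in these hunks (the `page` object is assumed to have gone through the VLM model added further down):

# Sketch: `page` is a docling Page processed by the new VLM model.
raw_vlm_text = ""
if page.predictions.vlm_response is not None:
    raw_vlm_text = page.predictions.vlm_response.text  # doctags or markdown, per response_format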
@@ -41,6 +41,7 @@ class AcceleratorOptions(BaseSettings):
 
     num_threads: int = 4
     device: Union[str, AcceleratorDevice] = "auto"
+    cuda_use_flash_attention2: bool = False
 
     @field_validator("device")
     def validate_device(cls, value):
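
The new cuda_use_flash_attention2 flag is consumed by the VLM model code further down: FlashAttention-2 is requested only when the resolved device starts with "cuda" and this flag is set. A minimal sketch of opting in (assumes a CUDA device and that the flash-attn package is installed):

from docling.datamodel.pipeline_options import AcceleratorDevice, AcceleratorOptions

accelerator_options = AcceleratorOptions(
    device=AcceleratorDevice.CUDA,    # or the string "cuda"
    num_threads=8,
    cuda_use_flash_attention2=True,   # has no effect on CPU/MPS devices
)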
@@ -254,6 +255,45 @@ granite_picture_description = PictureDescriptionVlmOptions(
 )
 
 
+class BaseVlmOptions(BaseModel):
+    kind: str
+    prompt: str
+
+
+class ResponseFormat(str, Enum):
+    DOCTAGS = "doctags"
+    MARKDOWN = "markdown"
+
+
+class HuggingFaceVlmOptions(BaseVlmOptions):
+    kind: Literal["hf_model_options"] = "hf_model_options"
+
+    repo_id: str
+    load_in_8bit: bool = True
+    llm_int8_threshold: float = 6.0
+    quantized: bool = False
+
+    response_format: ResponseFormat
+
+    @property
+    def repo_cache_folder(self) -> str:
+        return self.repo_id.replace("/", "--")
+
+
+smoldocling_vlm_conversion_options = HuggingFaceVlmOptions(
+    repo_id="ds4sd/SmolDocling-256M-preview",
+    prompt="Convert this page to docling.",
+    response_format=ResponseFormat.DOCTAGS,
+)
+
+granite_vision_vlm_conversion_options = HuggingFaceVlmOptions(
+    repo_id="ibm-granite/granite-vision-3.1-2b-preview",
+    # prompt="OCR the full page to markdown.",
+    prompt="OCR this image.",
+    response_format=ResponseFormat.MARKDOWN,
+)
+
+
 # Define an enum for the backend options
 class PdfBackend(str, Enum):
     """Enum of valid PDF backends."""
@@ -285,7 +325,24 @@ class PipelineOptions(BaseModel):
     enable_remote_services: bool = False
 
 
-class PdfPipelineOptions(PipelineOptions):
+class PaginatedPipelineOptions(PipelineOptions):
+    images_scale: float = 1.0
+    generate_page_images: bool = False
+    generate_picture_images: bool = False
+
+
+class VlmPipelineOptions(PaginatedPipelineOptions):
+    artifacts_path: Optional[Union[Path, str]] = None
+
+    generate_page_images: bool = True
+    force_backend_text: bool = (
+        False  # (To be used with vlms, or other generative models)
+    )
+    # If True, text from backend will be used instead of generated text
+    vlm_options: Union[HuggingFaceVlmOptions] = smoldocling_vlm_conversion_options
+
+
+class PdfPipelineOptions(PaginatedPipelineOptions):
     """Options for the PDF pipeline."""
 
     artifacts_path: Optional[Union[Path, str]] = None
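
VlmPipelineOptions bundles the page-image settings from PaginatedPipelineOptions with a vlm_options field that defaults to the SmolDocling preset. A short sketch of selecting the Granite Vision preset instead, or describing another checkpoint with the same schema (repo IDs and field names taken from the hunks above; how the options are handed to a converter is outside this excerpt):

from docling.datamodel.pipeline_options import (
    HuggingFaceVlmOptions,
    ResponseFormat,
    VlmPipelineOptions,
    granite_vision_vlm_conversion_options,
)

# Option 1: reuse a shipped preset.
pipeline_options = VlmPipelineOptions(
    vlm_options=granite_vision_vlm_conversion_options,
)

# Option 2: describe a checkpoint explicitly with the same fields.
custom_vlm = HuggingFaceVlmOptions(
    repo_id="ds4sd/SmolDocling-256M-preview",
    prompt="Convert this page to docling.",
    response_format=ResponseFormat.DOCTAGS,
    quantized=False,  # load_in_8bit / llm_int8_threshold only apply when quantized=True
)
pipeline_options = VlmPipelineOptions(
    vlm_options=custom_vlm,
    force_backend_text=False,  # True keeps the backend's text instead of generated text
)

In the released package these options are presumably consumed by the new VLM pipeline that drives the HuggingFaceVlmModel shown further down; that pipeline module itself is not part of the hunks in this diff.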
@@ -295,6 +352,10 @@ class PdfPipelineOptions(PipelineOptions):
     do_formula_enrichment: bool = False  # True: perform formula OCR, return Latex code
     do_picture_classification: bool = False  # True: classify pictures in documents
     do_picture_description: bool = False  # True: run describe pictures in documents
+    force_backend_text: bool = (
+        False  # (To be used with vlms, or other generative models)
+    )
+    # If True, text from backend will be used instead of generated text
 
     table_structure_options: TableStructureOptions = TableStructureOptions()
     ocr_options: Union[
@@ -0,0 +1,180 @@
+import logging
+import time
+from pathlib import Path
+from typing import Iterable, List, Optional
+
+from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.document import ConversionResult
+from docling.datamodel.pipeline_options import (
+    AcceleratorDevice,
+    AcceleratorOptions,
+    HuggingFaceVlmOptions,
+)
+from docling.datamodel.settings import settings
+from docling.models.base_model import BasePageModel
+from docling.utils.accelerator_utils import decide_device
+from docling.utils.profiling import TimeRecorder
+
+_log = logging.getLogger(__name__)
+
+
+class HuggingFaceVlmModel(BasePageModel):
+
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        vlm_options: HuggingFaceVlmOptions,
+    ):
+        self.enabled = enabled
+
+        self.vlm_options = vlm_options
+
+        if self.enabled:
+            import torch
+            from transformers import (  # type: ignore
+                AutoModelForVision2Seq,
+                AutoProcessor,
+                BitsAndBytesConfig,
+            )
+
+            device = decide_device(accelerator_options.device)
+            self.device = device
+
+            _log.debug("Available device for HuggingFace VLM: {}".format(device))
+
+            repo_cache_folder = vlm_options.repo_id.replace("/", "--")
+
+            # PARAMETERS:
+            if artifacts_path is None:
+                artifacts_path = self.download_models(self.vlm_options.repo_id)
+            elif (artifacts_path / repo_cache_folder).exists():
+                artifacts_path = artifacts_path / repo_cache_folder
+
+            self.param_question = vlm_options.prompt  # "Perform Layout Analysis."
+            self.param_quantization_config = BitsAndBytesConfig(
+                load_in_8bit=vlm_options.load_in_8bit,  # True,
+                llm_int8_threshold=vlm_options.llm_int8_threshold,  # 6.0
+            )
+            self.param_quantized = vlm_options.quantized  # False
+
+            self.processor = AutoProcessor.from_pretrained(artifacts_path)
+            if not self.param_quantized:
+                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
+                    artifacts_path,
+                    device_map=device,
+                    torch_dtype=torch.bfloat16,
+                    _attn_implementation=(
+                        "flash_attention_2"
+                        if self.device.startswith("cuda")
+                        and accelerator_options.cuda_use_flash_attention2
+                        else "eager"
+                    ),
+                )  # .to(self.device)
+
+            else:
+                self.vlm_model = AutoModelForVision2Seq.from_pretrained(
+                    artifacts_path,
+                    device_map=device,
+                    torch_dtype="auto",
+                    quantization_config=self.param_quantization_config,
+                    _attn_implementation=(
+                        "flash_attention_2"
+                        if self.device.startswith("cuda")
+                        and accelerator_options.cuda_use_flash_attention2
+                        else "eager"
+                    ),
+                )  # .to(self.device)
+
+    @staticmethod
+    def download_models(
+        repo_id: str,
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        if not progress:
+            disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id=repo_id,
+            force_download=force,
+            local_dir=local_dir,
+            # revision="v0.0.1",
+        )
+
+        return Path(download_path)
+
+    def __call__(
+        self, conv_res: ConversionResult, page_batch: Iterable[Page]
+    ) -> Iterable[Page]:
+        for page in page_batch:
+            assert page._backend is not None
+            if not page._backend.is_valid():
+                yield page
+            else:
+                with TimeRecorder(conv_res, "vlm"):
+                    assert page.size is not None
+
+                    hi_res_image = page.get_image(scale=2.0)  # 144dpi
+                    # hi_res_image = page.get_image(scale=1.0)  # 72dpi
+
+                    if hi_res_image is not None:
+                        im_width, im_height = hi_res_image.size
+
+                    # populate page_tags with predicted doc tags
+                    page_tags = ""
+
+                    if hi_res_image:
+                        if hi_res_image.mode != "RGB":
+                            hi_res_image = hi_res_image.convert("RGB")
+
+                    messages = [
+                        {
+                            "role": "user",
+                            "content": [
+                                {
+                                    "type": "text",
+                                    "text": "This is a page from a document.",
+                                },
+                                {"type": "image"},
+                                {"type": "text", "text": self.param_question},
+                            ],
+                        }
+                    ]
+                    prompt = self.processor.apply_chat_template(
+                        messages, add_generation_prompt=False
+                    )
+                    inputs = self.processor(
+                        text=prompt, images=[hi_res_image], return_tensors="pt"
+                    )
+                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+                    start_time = time.time()
+                    # Call model to generate:
+                    generated_ids = self.vlm_model.generate(
+                        **inputs, max_new_tokens=4096, use_cache=True
+                    )
+
+                    generation_time = time.time() - start_time
+                    generated_texts = self.processor.batch_decode(
+                        generated_ids[:, inputs["input_ids"].shape[1] :],
+                        skip_special_tokens=False,
+                    )[0]
+
+                    num_tokens = len(generated_ids[0])
+                    page_tags = generated_texts
+
+                    # inference_time = time.time() - start_time
+                    # tokens_per_second = num_tokens / generation_time
+                    # print("")
+                    # print(f"Page Inference Time: {inference_time:.2f} seconds")
+                    # print(f"Total tokens on page: {num_tokens:.2f}")
+                    # print(f"Tokens/sec: {tokens_per_second:.2f}")
+                    # print("")
+                    page.predictions.vlm_response = VlmPrediction(text=page_tags)
+
+                yield page
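
The new model file above (its path is not shown in this diff view; docling.models.hf_vlm_model is assumed below) exposes a static download_models helper plus a page-level __call__ that fills page.predictions.vlm_response. A sketch of prefetching weights and instantiating the model directly, under that module-path assumption:

from pathlib import Path

from docling.datamodel.pipeline_options import (
    AcceleratorOptions,
    smoldocling_vlm_conversion_options,
)
from docling.models.hf_vlm_model import HuggingFaceVlmModel  # assumed module path

# Prefetch the checkpoint into a local folder (wraps huggingface_hub.snapshot_download).
artifacts = HuggingFaceVlmModel.download_models(
    repo_id=smoldocling_vlm_conversion_options.repo_id,
    local_dir=Path("models") / smoldocling_vlm_conversion_options.repo_cache_folder,
    progress=True,
)

# Build the page model; torch/transformers are imported lazily inside __init__.
vlm_model = HuggingFaceVlmModel(
    enabled=True,
    artifacts_path=artifacts,
    accelerator_options=AcceleratorOptions(),
    vlm_options=smoldocling_vlm_conversion_options,
)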
@@ -41,9 +41,9 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
             )
 
             # Initialize processor and model
-            self.processor = AutoProcessor.from_pretrained(self.options.repo_id)
+            self.processor = AutoProcessor.from_pretrained(artifacts_path)
             self.model = AutoModelForVision2Seq.from_pretrained(
-                self.options.repo_id,
+                artifacts_path,
                 torch_dtype=torch.bfloat16,
                 _attn_implementation=(
                     "flash_attention_2" if self.device.startswith("cuda") else "eager"