docling 2.53.0__py3-none-any.whl → 2.55.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,157 @@
+import logging
+import re
+import sys
+from abc import abstractmethod
+from typing import List
+
+from transformers import StoppingCriteria
+
+_log = logging.getLogger(__name__)
+
+
+class GenerationStopper:
+    """
+    Base interface for stopping logic.
+    - should_stop(s): True to stop given the current decoded text window.
+    - lookback_tokens(): how many tokens should be considered (default: sys.maxsize).
+    """
+
+    @abstractmethod
+    def should_stop(self, s: str) -> bool:
+        pass
+
+    def lookback_tokens(self) -> int:
+        return sys.maxsize
+
+
+class DocTagsRepetitionStopper(GenerationStopper):
+    """
+    Detects repetitive <tag>...<loc_x><loc_y><loc_w><loc_h>text</tag> blocks,
+    but only when repeats are **consecutive** and both tag & inner text are identical.
+
+    Performance:
+    - Heavy check runs every N calls (default 32).
+    - Only decodes the last LOOKBACK_TOKENS tokens per sequence (default 200).
+    """
+
+    def __init__(self, *, N: int = 32, lookback_tokens: int = 200):
+        self.N = max(1, int(N))
+        self._lookback_tokens = max(1, int(lookback_tokens))
+        self._call_count = 0
+
+        # <tag> ... <loc_x><loc_y><loc_w><loc_h> text ... </tag>
+        self._PATTERN = re.compile(
+            r"""
+            <(?P<tag>[a-zA-Z0-9_]+)>\s*
+            (?P<prefix>.*?)?
+            <loc_(?P<x>\d+)><loc_(?P<y>\d+)><loc_(?P<w>\d+)><loc_(?P<h>\d+)>
+            (?P<text>.*?)
+            </(?P=tag)>
+            """,
+            re.DOTALL | re.VERBOSE,
+        )
+
+    # --- small helper ---
+    def _regular(self, vals: List[int]) -> bool:
+        """3+ strictly increasing values with ~regular spacing (±20%)."""
+        if len(vals) < 3:
+            return False
+        diffs = [b - a for a, b in zip(vals, vals[1:])]
+        if any(d <= 0 for d in diffs):
+            return False
+        mean = sum(diffs) / len(diffs)
+        tol = 0.2 * mean
+        return all(abs(d - mean) <= tol for d in diffs)
+
+    def should_stop(self, s: str) -> bool:
+        """
+        Trip only on **consecutive** runs (no other matched blocks between) of ≥3 items
+        with the same <tag> and identical inner text, where within that run we see:
+        - any exact duplicate (x,y,w,h), or
+        - stable X/W with regular Y progression, or
+        - stable Y/H with regular X progression.
+        """
+        # Stream matches and evaluate runs on-the-fly to stay compact and fast.
+        prev_tag = prev_text = None
+        run = []  # list of (x,y,w,h)
+
+        def run_repetitive(boxes: List[tuple]) -> bool:
+            if len(boxes) < 3:
+                return False
+            # duplicates?
+            if len(set(boxes)) < len(boxes):
+                return True
+            xs, ys, ws, hs = zip(*boxes)
+            x_stable = all(x == xs[0] for x in xs)
+            y_stable = all(y == ys[0] for y in ys)
+            w_stable = all(w == ws[0] for w in ws)
+            h_stable = all(h == hs[0] for h in hs)
+            # horizontal (down the page): X/W stable, Y regular
+            if (x_stable or w_stable) and self._regular(list(ys)):
+                return True
+            # vertical (across): Y/H stable, X regular
+            if (y_stable or h_stable) and self._regular(list(xs)):
+                return True
+            return False
+
+        for m in self._PATTERN.finditer(s):
+            tag, text = m.group("tag"), m.group("text")
+            box = (
+                int(m.group("x")),
+                int(m.group("y")),
+                int(m.group("w")),
+                int(m.group("h")),
+            )
+
+            if prev_tag == tag and prev_text == text:
+                run.append(box)  # consecutive same-tag+text
+            else:
+                # evaluate previous run before starting a new one
+                if run_repetitive(run):
+                    return True
+                prev_tag, prev_text = tag, text
+                run = [box]
+
+        # check the last run
+        return run_repetitive(run)
+
+
+class HFStoppingCriteriaWrapper(StoppingCriteria):
+    """
+    Adapts any GenerationStopper to HuggingFace Transformers.
+    Decodes exactly min(seq_len, stopper.lookback_tokens()) tokens from the end.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        stopper: GenerationStopper,
+        *,
+        skip_special_tokens: bool = False,
+    ):
+        self.tokenizer = tokenizer
+        self.stopper = stopper
+        self.skip_special_tokens = skip_special_tokens
+
+    def __call__(self, input_ids, scores, **kwargs) -> bool:
+        lb = max(1, int(self.stopper.lookback_tokens()))
+        for seq in input_ids:  # (batch, seq_len)
+            window = seq[-lb:]  # slicing handles lb > len(seq)
+            try:
+                text = self.tokenizer.decode(
+                    window, skip_special_tokens=self.skip_special_tokens
+                )
+            except Exception as e:
+                _log.info(f"Decoding failed for stopping check: {e}")
+                continue
+
+            try:
+                if self.stopper.should_stop(text):
+                    _log.info(
+                        "HF wrapper: stopping due to GenerationStopper.should_stop==True"
+                    )
+                    return True
+            except Exception as e:
+                _log.info(f"Error in GenerationStopper.should_stop: {e}")
+                continue
+        return False
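
The hunk above adds the new module `docling.models.utils.generation_utils` (the path follows from the import hunks further down). `GenerationStopper` is the framework-neutral interface, `DocTagsRepetitionStopper` a concrete detector for degenerate repeated DocTags blocks, and `HFStoppingCriteriaWrapper` the transformers adapter. A minimal sketch of the pieces working together; `PageBreakStopper`, the `<page_break>` tag, and the commented wiring are illustrative, not docling API:

```python
from transformers import StoppingCriteriaList

from docling.models.utils.generation_utils import (
    DocTagsRepetitionStopper,
    GenerationStopper,
    HFStoppingCriteriaWrapper,
)


class PageBreakStopper(GenerationStopper):
    """Hypothetical stopper: halt once a page-break tag is emitted."""

    def should_stop(self, s: str) -> bool:
        return "<page_break>" in s

    def lookback_tokens(self) -> int:
        return 64  # only the tail of the sequence needs decoding


# DocTagsRepetitionStopper trips on >=3 consecutive blocks with identical
# tag, inner text, and (here) identical location boxes:
stopper = DocTagsRepetitionStopper()
repeated = "<text><loc_10><loc_20><loc_100><loc_12>Lorem</text>" * 3
assert stopper.should_stop(repeated)

# Wiring into a transformers generate() call (model/tokenizer loading elided):
# criteria = StoppingCriteriaList(
#     [HFStoppingCriteriaWrapper(tokenizer, PageBreakStopper())]
# )
# model.generate(**inputs, stopping_criteria=criteria)
```

Note that the wrapper swallows exceptions raised by `should_stop` (logged at INFO), so a faulty stopper degrades into a no-op rather than aborting generation.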
@@ -34,7 +34,12 @@ class HuggingFaceModelDownloadMixin:
         local_dir: Optional[Path] = None,
         force: bool = False,
         progress: bool = False,
+        revision: Optional[str] = None,
     ) -> Path:
         return download_hf_model(
-            repo_id=repo_id, local_dir=local_dir, force=force, progress=progress
+            repo_id=repo_id,
+            local_dir=local_dir,
+            force=force,
+            progress=progress,
+            revision=revision,
         )
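
The new `revision` argument is passed straight through to `download_hf_model`, letting callers pin a branch, tag, or commit SHA instead of always fetching the default branch. A hedged sketch; the host class, repo id, and revision below are placeholders:

```python
from docling.models.utils.hf_model_download import HuggingFaceModelDownloadMixin


class _Downloader(HuggingFaceModelDownloadMixin):
    """Minimal host for the mixin, for illustration only."""


artifacts_path = _Downloader().download_models(
    "org/some-vlm",       # placeholder repo id
    revision="1a2b3c4d",  # branch, tag, or commit SHA; None keeps the old behavior
)
```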
@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
 
 import numpy as np
 from PIL.Image import Image
-from transformers import StoppingCriteriaList, StopStringCriteria
+from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -20,6 +20,10 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import (
+    GenerationStopper,
+    HFStoppingCriteriaWrapper,
+)
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -75,7 +79,9 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         repo_cache_folder = vlm_options.repo_id.replace("/", "--")
 
         if artifacts_path is None:
-            artifacts_path = self.download_models(self.vlm_options.repo_id)
+            artifacts_path = self.download_models(
+                self.vlm_options.repo_id, revision=self.vlm_options.revision
+            )
         elif (artifacts_path / repo_cache_folder).exists():
             artifacts_path = artifacts_path / repo_cache_folder
 
@@ -106,6 +112,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         self.processor = AutoProcessor.from_pretrained(
             artifacts_path,
             trust_remote_code=vlm_options.trust_remote_code,
+            revision=vlm_options.revision,
         )
         self.processor.tokenizer.padding_side = "left"
 
@@ -120,11 +127,14 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                 else "sdpa"
             ),
             trust_remote_code=vlm_options.trust_remote_code,
+            revision=vlm_options.revision,
         )
         self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
 
         # Load generation config
-        self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
+        self.generation_config = GenerationConfig.from_pretrained(
+            artifacts_path, revision=vlm_options.revision
+        )
 
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
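
With the processor, the weights, and now the generation config all loading with `revision=vlm_options.revision`, a single option pins every Hub artifact to the same snapshot. For reference, the equivalent direct transformers calls (repo id and revision are placeholders):

```python
from transformers import AutoProcessor, GenerationConfig

rev = "1a2b3c4d"  # placeholder: branch, tag, or commit SHA
processor = AutoProcessor.from_pretrained("org/some-vlm", revision=rev)
generation_config = GenerationConfig.from_pretrained("org/some-vlm", revision=rev)
```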
@@ -196,7 +206,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         import torch
         from PIL import Image as PILImage
 
-        # -- Normalize images to RGB PIL (SmolDocling & friends accept PIL/np via processor)
+        # -- Normalize images to RGB PIL
         pil_images: list[Image] = []
         for img in image_batch:
             if isinstance(img, np.ndarray):
@@ -247,24 +257,74 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
         # -- Optional stopping criteria
-        stopping_criteria = None
+        stopping_criteria_list: StoppingCriteriaList = StoppingCriteriaList()
+
+        # Add string-based stopping criteria
         if self.vlm_options.stop_strings:
-            stopping_criteria = StoppingCriteriaList(
-                [
-                    StopStringCriteria(
-                        stop_strings=self.vlm_options.stop_strings,
-                        tokenizer=self.processor.tokenizer,
-                    )
-                ]
+            stopping_criteria_list.append(
+                StopStringCriteria(
+                    stop_strings=self.vlm_options.stop_strings,
+                    tokenizer=self.processor.tokenizer,
+                )
             )
 
+        # Add custom stopping criteria
+        if self.vlm_options.custom_stopping_criteria:
+            for criteria in self.vlm_options.custom_stopping_criteria:
+                # If it's a class (not an instance), determine the type and handle accordingly
+                if isinstance(criteria, type):
+                    # Check if it's a GenerationStopper class
+                    if issubclass(criteria, GenerationStopper):
+                        # Instantiate GenerationStopper and wrap it
+                        stopper_instance = criteria()
+                        wrapped_criteria = HFStoppingCriteriaWrapper(
+                            self.processor.tokenizer, stopper_instance
+                        )
+                        stopping_criteria_list.append(wrapped_criteria)
+                    elif issubclass(criteria, StoppingCriteria):
+                        # It's a StoppingCriteria class, instantiate with tokenizer
+                        criteria_instance = criteria(self.processor.tokenizer)
+                        stopping_criteria_list.append(criteria_instance)
+                elif isinstance(criteria, GenerationStopper):
+                    # Wrap GenerationStopper instances in HFStoppingCriteriaWrapper
+                    wrapped_criteria = HFStoppingCriteriaWrapper(
+                        self.processor.tokenizer, criteria
+                    )
+                    stopping_criteria_list.append(wrapped_criteria)
+                else:
+                    # If it's already an instance of StoppingCriteria, use it directly
+                    stopping_criteria_list.append(criteria)
+
+        stopping_criteria = (
+            StoppingCriteriaList(stopping_criteria_list)
+            if stopping_criteria_list
+            else None
+        )
+
+        # -- Filter out decoder-specific keys from extra_generation_config
+        decoder_keys = {
+            "skip_special_tokens",
+            "clean_up_tokenization_spaces",
+            "spaces_between_special_tokens",
+        }
+        generation_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k not in decoder_keys
+        }
+        decoder_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k in decoder_keys
+        }
+
         # -- Generate (Image-Text-to-Text class expects these inputs from processor)
         gen_kwargs = {
             **inputs,
             "max_new_tokens": self.max_new_tokens,
             "use_cache": self.use_cache,
             "generation_config": self.generation_config,
-            **self.vlm_options.extra_generation_config,
+            **generation_config,
         }
         if self.temperature > 0:
             gen_kwargs["do_sample"] = True
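
Two things happen in this hunk: the dispatch accepts `custom_stopping_criteria` entries in four shapes (`GenerationStopper` class or instance, `StoppingCriteria` class or instance), and `extra_generation_config` is split so decoder-only keys are routed to the decode step rather than `model.generate()`. A sketch of what a caller might set; `opts` stands in for the model's `InlineVlmOptions`, with all unrelated fields elided:

```python
from types import SimpleNamespace

from docling.models.utils.generation_utils import DocTagsRepetitionStopper

opts = SimpleNamespace()  # stand-in for InlineVlmOptions (hypothetical values)
opts.custom_stopping_criteria = [
    DocTagsRepetitionStopper,        # GenerationStopper class -> instantiated, then wrapped
    DocTagsRepetitionStopper(N=16),  # GenerationStopper instance -> wrapped directly
    # SomeHFCriteria,                # StoppingCriteria class -> SomeHFCriteria(tokenizer)
    # some_hf_criteria_instance,     # StoppingCriteria instance -> appended as-is
]
opts.extra_generation_config = {
    "num_beams": 1,                  # forwarded into model.generate(...)
    "skip_special_tokens": True,     # routed to the decode step instead
}
```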
@@ -293,7 +353,8 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         )
 
         decoded_texts: list[str] = decode_fn(
-            trimmed_sequences, skip_special_tokens=False
+            trimmed_sequences,
+            **decoder_config,
         )
 
         # -- Clip off pad tokens from decoded texts
@@ -1,4 +1,5 @@
 import logging
+import sys
 import threading
 import time
 from collections.abc import Iterable
@@ -7,6 +8,7 @@ from typing import Optional, Union
 
 import numpy as np
 from PIL.Image import Image
+from transformers import StoppingCriteria
 
 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,6 +17,7 @@ from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToke
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import GenerationStopper
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -60,6 +63,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         if artifacts_path is None:
             artifacts_path = self.download_models(
                 self.vlm_options.repo_id,
+                revision=self.vlm_options.revision,
             )
         elif (artifacts_path / repo_cache_folder).exists():
             artifacts_path = artifacts_path / repo_cache_folder
@@ -68,6 +72,22 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         self.vlm_model, self.processor = load(artifacts_path)
         self.config = load_config(artifacts_path)
 
+        # Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
+        if self.vlm_options.custom_stopping_criteria:
+            for criteria in self.vlm_options.custom_stopping_criteria:
+                if isinstance(criteria, StoppingCriteria):
+                    raise ValueError(
+                        f"MLX models do not support HuggingFace StoppingCriteria instances. "
+                        f"Found {type(criteria).__name__}. Use GenerationStopper instead."
+                    )
+                elif isinstance(criteria, type) and issubclass(
+                    criteria, StoppingCriteria
+                ):
+                    raise ValueError(
+                        f"MLX models do not support HuggingFace StoppingCriteria classes. "
+                        f"Found {criteria.__name__}. Use GenerationStopper instead."
+                    )
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
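
MLX's `stream_generate` loop cannot invoke transformers `StoppingCriteria`, so the constructor now rejects them up front instead of silently ignoring them; only text-based `GenerationStopper`s are allowed. A sketch of the contract (`opts` again stands in for a configured `InlineVlmOptions`):

```python
from types import SimpleNamespace

from docling.models.utils.generation_utils import DocTagsRepetitionStopper

opts = SimpleNamespace()
opts.custom_stopping_criteria = [DocTagsRepetitionStopper()]  # accepted: text-based stopper
# Any transformers StoppingCriteria class or instance in this list makes
# HuggingFaceMlxModel.__init__ raise ValueError, as the hunk above shows.
```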
@@ -192,7 +212,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             self.processor, self.config, user_prompt, num_images=1
         )
 
-        # Stream generate with stop strings support
+        # Stream generate with stop strings and custom stopping criteria support
         start_time = time.time()
         _log.debug("start generating ...")
 
@@ -244,6 +264,43 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                     _log.debug("Stopping generation due to stop string match")
                     break
 
+            # Check for custom stopping criteria (GenerationStopper instances)
+            if self.vlm_options.custom_stopping_criteria:
+                for criteria in self.vlm_options.custom_stopping_criteria:
+                    # Handle both instances and classes of GenerationStopper
+                    if isinstance(criteria, GenerationStopper):
+                        stopper = criteria
+                    elif isinstance(criteria, type) and issubclass(
+                        criteria, GenerationStopper
+                    ):
+                        stopper = criteria()
+
+                    # Determine the text window to check based on lookback_tokens
+                    lookback_tokens = stopper.lookback_tokens()
+                    # Check only the last N characters worth of text
+                    # This is a simplified approach - in practice, you might want to
+                    # decode the last N tokens from the token list for more accuracy
+                    text_to_check = (
+                        output[-lookback_tokens:]
+                        if len(output) > lookback_tokens
+                        else output
+                    )
+
+                    try:
+                        if stopper.should_stop(text_to_check):
+                            _log.info(
+                                f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
+                            )
+                            break
+                    except Exception as e:
+                        _log.warning(
+                            f"Error in GenerationStopper.should_stop: {e}"
+                        )
+                        continue
+                else:  # note: for-else idiom
+                    continue  # Only executed if the inner loop didn't break
+                break  # Break the outer loop if any stopper triggered
+
             generation_time = time.time() - start_time
 
             _log.debug(
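
The `for`/`else` at the end of the new block carries the control flow: the `else` clause of a `for` loop runs only when the loop finished without `break`, so the `continue` there resumes the outer streaming loop, while falling through reaches the final `break` and stops streaming. The idiom in isolation:

```python
stoppers = [lambda s: "<stop>" in s]  # stand-ins for GenerationStopper checks

for chunk in ["abc", "def", "<stop>", "ghi"]:  # stand-in for the token stream
    for stopper in stoppers:
        if stopper(chunk):
            break       # a stopper fired
    else:
        continue        # no inner break -> keep consuming the stream
    break               # promote the inner break: stop the outer loop
# the outer loop exits at "<stop>"; "ghi" is never processed
```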