docling 2.53.0__py3-none-any.whl → 2.55.0__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- docling/backend/asciidoc_backend.py +1 -1
- docling/backend/html_backend.py +254 -136
- docling/backend/md_backend.py +4 -1
- docling/backend/msword_backend.py +177 -76
- docling/backend/webvtt_backend.py +572 -0
- docling/backend/xml/jats_backend.py +111 -7
- docling/backend/xml/uspto_backend.py +1 -1
- docling/cli/main.py +5 -0
- docling/datamodel/base_models.py +23 -23
- docling/datamodel/document.py +2 -0
- docling/datamodel/pipeline_options_vlm_model.py +13 -2
- docling/datamodel/vlm_model_specs.py +9 -0
- docling/document_converter.py +4 -0
- docling/models/api_vlm_model.py +45 -16
- docling/models/base_model.py +2 -1
- docling/models/readingorder_model.py +1 -1
- docling/models/table_structure_model.py +3 -3
- docling/models/utils/generation_utils.py +157 -0
- docling/models/utils/hf_model_download.py +6 -1
- docling/models/vlm_models_inline/hf_transformers_model.py +75 -14
- docling/models/vlm_models_inline/mlx_model.py +58 -1
- docling/models/vlm_models_inline/vllm_model.py +189 -124
- docling/utils/api_image_request.py +107 -1
- {docling-2.53.0.dist-info → docling-2.55.0.dist-info}/METADATA +5 -5
- {docling-2.53.0.dist-info → docling-2.55.0.dist-info}/RECORD +29 -27
- {docling-2.53.0.dist-info → docling-2.55.0.dist-info}/WHEEL +0 -0
- {docling-2.53.0.dist-info → docling-2.55.0.dist-info}/entry_points.txt +0 -0
- {docling-2.53.0.dist-info → docling-2.55.0.dist-info}/licenses/LICENSE +0 -0
- {docling-2.53.0.dist-info → docling-2.55.0.dist-info}/top_level.txt +0 -0
docling/models/utils/generation_utils.py (new file)

@@ -0,0 +1,157 @@
+import logging
+import re
+import sys
+from abc import abstractmethod
+from typing import List
+
+from transformers import StoppingCriteria
+
+_log = logging.getLogger(__name__)
+
+
+class GenerationStopper:
+    """
+    Base interface for stopping logic.
+    - should_stop(s): True to stop given the current decoded text window.
+    - lookback_tokens(): how many tokens should be considered (default: sys.maxsize).
+    """
+
+    @abstractmethod
+    def should_stop(self, s: str) -> bool:
+        pass
+
+    def lookback_tokens(self) -> int:
+        return sys.maxsize
+
+
+class DocTagsRepetitionStopper(GenerationStopper):
+    """
+    Detects repetitive <tag>...<loc_x><loc_y><loc_w><loc_h>text</tag> blocks,
+    but only when repeats are **consecutive** and both tag & inner text are identical.
+
+    Performance:
+    - Heavy check runs every N calls (default 32).
+    - Only decodes the last LOOKBACK_TOKENS tokens per sequence (default 200).
+    """
+
+    def __init__(self, *, N: int = 32, lookback_tokens: int = 200):
+        self.N = max(1, int(N))
+        self._lookback_tokens = max(1, int(lookback_tokens))
+        self._call_count = 0
+
+        # <tag> ... <loc_x><loc_y><loc_w><loc_h> text ... </tag>
+        self._PATTERN = re.compile(
+            r"""
+            <(?P<tag>[a-zA-Z0-9_]+)>\s*
+            (?P<prefix>.*?)?
+            <loc_(?P<x>\d+)><loc_(?P<y>\d+)><loc_(?P<w>\d+)><loc_(?P<h>\d+)>
+            (?P<text>.*?)
+            </(?P=tag)>
+            """,
+            re.DOTALL | re.VERBOSE,
+        )
+
+    # --- small helper ---
+    def _regular(self, vals: List[int]) -> bool:
+        """3+ strictly increasing values with ~regular spacing (±20%)."""
+        if len(vals) < 3:
+            return False
+        diffs = [b - a for a, b in zip(vals, vals[1:])]
+        if any(d <= 0 for d in diffs):
+            return False
+        mean = sum(diffs) / len(diffs)
+        tol = 0.2 * mean
+        return all(abs(d - mean) <= tol for d in diffs)
+
+    def should_stop(self, s: str) -> bool:
+        """
+        Trip only on **consecutive** runs (no other matched blocks between) of ≥3 items
+        with the same <tag> and identical inner text, where within that run we see:
+        - any exact duplicate (x,y,w,h), or
+        - stable X/W with regular Y progression, or
+        - stable Y/H with regular X progression.
+        """
+        # Stream matches and evaluate runs on-the-fly to stay compact and fast.
+        prev_tag = prev_text = None
+        run = []  # list of (x,y,w,h)
+
+        def run_repetitive(boxes: List[tuple]) -> bool:
+            if len(boxes) < 3:
+                return False
+            # duplicates?
+            if len(set(boxes)) < len(boxes):
+                return True
+            xs, ys, ws, hs = zip(*boxes)
+            x_stable = all(x == xs[0] for x in xs)
+            y_stable = all(y == ys[0] for y in ys)
+            w_stable = all(w == ws[0] for w in ws)
+            h_stable = all(h == hs[0] for h in hs)
+            # horizontal (down the page): X/W stable, Y regular
+            if (x_stable or w_stable) and self._regular(list(ys)):
+                return True
+            # vertical (across): Y/H stable, X regular
+            if (y_stable or h_stable) and self._regular(list(xs)):
+                return True
+            return False
+
+        for m in self._PATTERN.finditer(s):
+            tag, text = m.group("tag"), m.group("text")
+            box = (
+                int(m.group("x")),
+                int(m.group("y")),
+                int(m.group("w")),
+                int(m.group("h")),
+            )
+
+            if prev_tag == tag and prev_text == text:
+                run.append(box)  # consecutive same-tag+text
+            else:
+                # evaluate previous run before starting a new one
+                if run_repetitive(run):
+                    return True
+                prev_tag, prev_text = tag, text
+                run = [box]
+
+        # check the last run
+        return run_repetitive(run)
+
+
+class HFStoppingCriteriaWrapper(StoppingCriteria):
+    """
+    Adapts any GenerationStopper to HuggingFace Transformers.
+    Decodes exactly min(seq_len, stopper.lookback_tokens()) tokens from the end.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        stopper: GenerationStopper,
+        *,
+        skip_special_tokens: bool = False,
+    ):
+        self.tokenizer = tokenizer
+        self.stopper = stopper
+        self.skip_special_tokens = skip_special_tokens
+
+    def __call__(self, input_ids, scores, **kwargs) -> bool:
+        lb = max(1, int(self.stopper.lookback_tokens()))
+        for seq in input_ids:  # (batch, seq_len)
+            window = seq[-lb:]  # slicing handles lb > len(seq)
+            try:
+                text = self.tokenizer.decode(
+                    window, skip_special_tokens=self.skip_special_tokens
+                )
+            except Exception as e:
+                _log.info(f"Decoding failed for stopping check: {e}")
+                continue
+
+            try:
+                if self.stopper.should_stop(text):
+                    _log.info(
+                        "HF wrapper: stopping due to TextStopper.should_stop==True"
+                    )
+                    return True
+            except Exception as e:
+                _log.info(f"Error in TextStopper.should_stop: {e}")
+                continue
+        return False
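The new `GenerationStopper` interface is the extension point: implement `should_stop` over a decoded text window and optionally bound that window with `lookback_tokens`. A minimal sketch of how it composes (the degenerate doctags string and the `BudgetStopper` class below are illustrative, not part of the package):

```python
from docling.models.utils.generation_utils import (
    DocTagsRepetitionStopper,
    GenerationStopper,
)

# Three consecutive blocks with identical tag, inner text, and box trip the
# exact-duplicate check in DocTagsRepetitionStopper.should_stop().
degenerate = (
    "<text><loc_10><loc_20><loc_90><loc_30>Total</text>"
    "<text><loc_10><loc_20><loc_90><loc_30>Total</text>"
    "<text><loc_10><loc_20><loc_90><loc_30>Total</text>"
)
assert DocTagsRepetitionStopper().should_stop(degenerate)


class BudgetStopper(GenerationStopper):
    """Hypothetical stopper: abort once the decoded window grows too large."""

    def should_stop(self, s: str) -> bool:
        return len(s) > 50_000

    def lookback_tokens(self) -> int:
        return 256  # only the last 256 tokens are decoded and checked
```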
docling/models/utils/hf_model_download.py

@@ -34,7 +34,12 @@ class HuggingFaceModelDownloadMixin:
         local_dir: Optional[Path] = None,
         force: bool = False,
         progress: bool = False,
+        revision: Optional[str] = None,
     ) -> Path:
         return download_hf_model(
-            repo_id=repo_id,
+            repo_id=repo_id,
+            local_dir=local_dir,
+            force=force,
+            progress=progress,
+            revision=revision,
         )
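With `revision` now threaded through `download_models` into `download_hf_model`, model weights can be pinned to a branch, tag, or commit. A hedged sketch of pinning it via a model spec (assuming the new `revision` field added to the VLM options in this release, and that the specs are pydantic models as elsewhere in docling; the revision value is illustrative):

```python
from docling.datamodel.vlm_model_specs import SMOLDOCLING_TRANSFORMERS

# Pin the Hugging Face revision the converter downloads; "v1.0" is illustrative.
pinned_spec = SMOLDOCLING_TRANSFORMERS.model_copy(update={"revision": "v1.0"})
```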
docling/models/vlm_models_inline/hf_transformers_model.py

@@ -7,7 +7,7 @@ from typing import Any, Optional, Union

 import numpy as np
 from PIL.Image import Image
-from transformers import StoppingCriteriaList, StopStringCriteria
+from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -20,6 +20,10 @@ from docling.datamodel.pipeline_options_vlm_model import (
     TransformersPromptStyle,
 )
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import (
+    GenerationStopper,
+    HFStoppingCriteriaWrapper,
+)
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -75,7 +79,9 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         repo_cache_folder = vlm_options.repo_id.replace("/", "--")

         if artifacts_path is None:
-            artifacts_path = self.download_models(
+            artifacts_path = self.download_models(
+                self.vlm_options.repo_id, revision=self.vlm_options.revision
+            )
         elif (artifacts_path / repo_cache_folder).exists():
             artifacts_path = artifacts_path / repo_cache_folder

@@ -106,6 +112,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         self.processor = AutoProcessor.from_pretrained(
             artifacts_path,
             trust_remote_code=vlm_options.trust_remote_code,
+            revision=vlm_options.revision,
         )
         self.processor.tokenizer.padding_side = "left"

@@ -120,11 +127,14 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
                 else "sdpa"
             ),
             trust_remote_code=vlm_options.trust_remote_code,
+            revision=vlm_options.revision,
         )
         self.vlm_model = torch.compile(self.vlm_model)  # type: ignore

         # Load generation config
-        self.generation_config = GenerationConfig.from_pretrained(
+        self.generation_config = GenerationConfig.from_pretrained(
+            artifacts_path, revision=vlm_options.revision
+        )

     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
@@ -196,7 +206,7 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         import torch
         from PIL import Image as PILImage

-        # -- Normalize images to RGB PIL
+        # -- Normalize images to RGB PIL
         pil_images: list[Image] = []
         for img in image_batch:
             if isinstance(img, np.ndarray):
@@ -247,24 +257,74 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         inputs = {k: v.to(self.device) for k, v in inputs.items()}

         # -- Optional stopping criteria
-
+        stopping_criteria_list: StoppingCriteriaList = StoppingCriteriaList()
+
+        # Add string-based stopping criteria
         if self.vlm_options.stop_strings:
-
-
-
-
-
-            )
-        ]
+            stopping_criteria_list.append(
+                StopStringCriteria(
+                    stop_strings=self.vlm_options.stop_strings,
+                    tokenizer=self.processor.tokenizer,
+                )
             )

+        # Add custom stopping criteria
+        if self.vlm_options.custom_stopping_criteria:
+            for criteria in self.vlm_options.custom_stopping_criteria:
+                # If it's a class (not an instance), determine the type and handle accordingly
+                if isinstance(criteria, type):
+                    # Check if it's a GenerationStopper class
+                    if issubclass(criteria, GenerationStopper):
+                        # Instantiate GenerationStopper and wrap it
+                        stopper_instance = criteria()
+                        wrapped_criteria = HFStoppingCriteriaWrapper(
+                            self.processor.tokenizer, stopper_instance
+                        )
+                        stopping_criteria_list.append(wrapped_criteria)
+                    elif issubclass(criteria, StoppingCriteria):
+                        # It's a StoppingCriteria class, instantiate with tokenizer
+                        criteria_instance = criteria(self.processor.tokenizer)
+                        stopping_criteria_list.append(criteria_instance)
+                elif isinstance(criteria, GenerationStopper):
+                    # Wrap GenerationStopper instances in HFStoppingCriteriaWrapper
+                    wrapped_criteria = HFStoppingCriteriaWrapper(
+                        self.processor.tokenizer, criteria
+                    )
+                    stopping_criteria_list.append(wrapped_criteria)
+                else:
+                    # If it's already an instance of StoppingCriteria, use it directly
+                    stopping_criteria_list.append(criteria)
+
+        stopping_criteria = (
+            StoppingCriteriaList(stopping_criteria_list)
+            if stopping_criteria_list
+            else None
+        )
+
+        # -- Filter out decoder-specific keys from extra_generation_config
+        decoder_keys = {
+            "skip_special_tokens",
+            "clean_up_tokenization_spaces",
+            "spaces_between_special_tokens",
+        }
+        generation_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k not in decoder_keys
+        }
+        decoder_config = {
+            k: v
+            for k, v in self.vlm_options.extra_generation_config.items()
+            if k in decoder_keys
+        }
+
         # -- Generate (Image-Text-to-Text class expects these inputs from processor)
         gen_kwargs = {
             **inputs,
             "max_new_tokens": self.max_new_tokens,
             "use_cache": self.use_cache,
             "generation_config": self.generation_config,
-            **
+            **generation_config,
         }
         if self.temperature > 0:
             gen_kwargs["do_sample"] = True
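Putting the hunk above to work: `custom_stopping_criteria` accepts `GenerationStopper` instances or classes (which get wrapped in `HFStoppingCriteriaWrapper`) as well as raw HF `StoppingCriteria`, and `extra_generation_config` is split between `generate()` and the decode step. A hedged configuration sketch (the option field names appear in this diff; the concrete values are illustrative):

```python
from docling.datamodel.vlm_model_specs import SMOLDOCLING_TRANSFORMERS
from docling.models.utils.generation_utils import DocTagsRepetitionStopper

opts = SMOLDOCLING_TRANSFORMERS.model_copy(
    update={
        # Instances or classes both work; GenerationStopper objects are wrapped
        # in HFStoppingCriteriaWrapper together with the model's tokenizer.
        "custom_stopping_criteria": [DocTagsRepetitionStopper(N=32)],
        "extra_generation_config": {
            "repetition_penalty": 1.05,    # not a decoder key -> model.generate()
            "skip_special_tokens": False,  # decoder key -> tokenizer decode step
        },
    }
)
```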
@@ -293,7 +353,8 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         )

         decoded_texts: list[str] = decode_fn(
-            trimmed_sequences,
+            trimmed_sequences,
+            **decoder_config,
         )

         # -- Clip off pad tokens from decoded texts
docling/models/vlm_models_inline/mlx_model.py

@@ -1,4 +1,5 @@
 import logging
+import sys
 import threading
 import time
 from collections.abc import Iterable
@@ -7,6 +8,7 @@ from typing import Optional, Union

 import numpy as np
 from PIL.Image import Image
+from transformers import StoppingCriteria

 from docling.datamodel.accelerator_options import (
     AcceleratorOptions,
@@ -15,6 +17,7 @@ from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToke
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import InlineVlmOptions
 from docling.models.base_model import BaseVlmPageModel
+from docling.models.utils.generation_utils import GenerationStopper
 from docling.models.utils.hf_model_download import (
     HuggingFaceModelDownloadMixin,
 )
@@ -60,6 +63,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         if artifacts_path is None:
             artifacts_path = self.download_models(
                 self.vlm_options.repo_id,
+                revision=self.vlm_options.revision,
             )
         elif (artifacts_path / repo_cache_folder).exists():
             artifacts_path = artifacts_path / repo_cache_folder
@@ -68,6 +72,22 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         self.vlm_model, self.processor = load(artifacts_path)
         self.config = load_config(artifacts_path)

+        # Validate custom stopping criteria - MLX doesn't support HF StoppingCriteria
+        if self.vlm_options.custom_stopping_criteria:
+            for criteria in self.vlm_options.custom_stopping_criteria:
+                if isinstance(criteria, StoppingCriteria):
+                    raise ValueError(
+                        f"MLX models do not support HuggingFace StoppingCriteria instances. "
+                        f"Found {type(criteria).__name__}. Use GenerationStopper instead."
+                    )
+                elif isinstance(criteria, type) and issubclass(
+                    criteria, StoppingCriteria
+                ):
+                    raise ValueError(
+                        f"MLX models do not support HuggingFace StoppingCriteria classes. "
+                        f"Found {criteria.__name__}. Use GenerationStopper instead."
+                    )
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
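The MLX backend therefore only accepts `GenerationStopper`-based criteria. An illustrative sketch of what passes and what fails this guard (`MaxLengthCriteria` is a standard transformers stopping-criteria class, used here only as an example of a rejected input):

```python
from transformers import MaxLengthCriteria

from docling.models.utils.generation_utils import DocTagsRepetitionStopper

accepted = [DocTagsRepetitionStopper()]          # GenerationStopper: fine on MLX
rejected = [MaxLengthCriteria(max_length=4096)]  # HF StoppingCriteria: ValueError at init
```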
@@ -192,7 +212,7 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             self.processor, self.config, user_prompt, num_images=1
         )

-        # Stream generate with stop strings support
+        # Stream generate with stop strings and custom stopping criteria support
         start_time = time.time()
         _log.debug("start generating ...")

@@ -244,6 +264,43 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 _log.debug("Stopping generation due to stop string match")
                 break

+            # Check for custom stopping criteria (GenerationStopper instances)
+            if self.vlm_options.custom_stopping_criteria:
+                for criteria in self.vlm_options.custom_stopping_criteria:
+                    # Handle both instances and classes of GenerationStopper
+                    if isinstance(criteria, GenerationStopper):
+                        stopper = criteria
+                    elif isinstance(criteria, type) and issubclass(
+                        criteria, GenerationStopper
+                    ):
+                        stopper = criteria()
+
+                    # Determine the text window to check based on lookback_tokens
+                    lookback_tokens = stopper.lookback_tokens()
+                    # Check only the last N characters worth of text
+                    # This is a simplified approach - in practice, you might want to
+                    # decode the last N tokens from the token list for more accuracy
+                    text_to_check = (
+                        output[-lookback_tokens:]
+                        if len(output) > lookback_tokens
+                        else output
+                    )
+
+                    try:
+                        if stopper.should_stop(text_to_check):
+                            _log.info(
+                                f"Stopping generation due to GenerationStopper: {type(stopper).__name__}"
+                            )
+                            break
+                    except Exception as e:
+                        _log.warning(
+                            f"Error in GenerationStopper.should_stop: {e}"
+                        )
+                        continue
+                else:  # note: for-else idiom
+                    continue  # Only executed if the inner loop didn't break
+                break  # Break the outer loop if any stopper triggered
+
         generation_time = time.time() - start_time

         _log.debug(