docling 2.57.0__py3-none-any.whl → 2.59.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of docling might be problematic.
- docling/backend/abstract_backend.py +24 -3
- docling/backend/asciidoc_backend.py +3 -3
- docling/backend/docling_parse_v4_backend.py +15 -4
- docling/backend/html_backend.py +130 -20
- docling/backend/md_backend.py +27 -5
- docling/backend/msexcel_backend.py +121 -29
- docling/backend/mspowerpoint_backend.py +2 -2
- docling/backend/msword_backend.py +18 -18
- docling/backend/pdf_backend.py +9 -2
- docling/backend/pypdfium2_backend.py +12 -3
- docling/cli/main.py +104 -38
- docling/datamodel/asr_model_specs.py +408 -6
- docling/datamodel/backend_options.py +82 -0
- docling/datamodel/base_models.py +19 -2
- docling/datamodel/document.py +81 -48
- docling/datamodel/pipeline_options_asr_model.py +21 -1
- docling/datamodel/pipeline_options_vlm_model.py +1 -0
- docling/document_converter.py +37 -45
- docling/document_extractor.py +12 -11
- docling/models/api_vlm_model.py +5 -3
- docling/models/picture_description_vlm_model.py +5 -1
- docling/models/readingorder_model.py +6 -7
- docling/models/vlm_models_inline/hf_transformers_model.py +13 -3
- docling/models/vlm_models_inline/mlx_model.py +9 -3
- docling/models/vlm_models_inline/nuextract_transformers_model.py +13 -3
- docling/models/vlm_models_inline/vllm_model.py +42 -8
- docling/pipeline/asr_pipeline.py +149 -6
- docling/utils/api_image_request.py +20 -9
- docling/utils/layout_postprocessor.py +23 -24
- {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/METADATA +11 -8
- {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/RECORD +35 -34
- {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/WHEEL +0 -0
- {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/entry_points.txt +0 -0
- {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/licenses/LICENSE +0 -0
- {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/top_level.txt +0 -0
docling/models/api_vlm_model.py
CHANGED
@@ -73,7 +73,7 @@ class ApiVlmModel(BasePageModel):
                 # Skip non-GenerationStopper criteria (should have been caught in validation)

                 # Streaming path with early abort support
-                page_tags = api_image_request_streaming(
+                page_tags, num_tokens = api_image_request_streaming(
                     image=hi_res_image,
                     prompt=prompt,
                     url=self.vlm_options.url,
@@ -84,7 +84,7 @@ class ApiVlmModel(BasePageModel):
                 )
             else:
                 # Non-streaming fallback (existing behavior)
-                page_tags = api_image_request(
+                page_tags, num_tokens = api_image_request(
                     image=hi_res_image,
                     prompt=prompt,
                     url=self.vlm_options.url,
@@ -94,7 +94,9 @@ class ApiVlmModel(BasePageModel):
                 )

             page_tags = self.vlm_options.decode_response(page_tags)
-            page.predictions.vlm_response = VlmPrediction(
+            page.predictions.vlm_response = VlmPrediction(
+                text=page_tags, num_tokens=num_tokens
+            )
             return page

         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
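Both request helpers now return a (generated_text, num_tokens) tuple instead of a bare string, and the count is carried into the page's VlmPrediction. A minimal sketch of the new calling convention, assuming an OpenAI-compatible server is reachable at a placeholder local URL (the image and prompt are illustrative, not taken from docling):

from PIL import Image

from docling.utils.api_image_request import api_image_request

page_image = Image.new("RGB", (640, 480), "white")  # stand-in for a rendered page
text, num_tokens = api_image_request(
    image=page_image,
    prompt="Convert this page to DocTags.",
    url="http://localhost:8000/v1/chat/completions",  # assumed local endpoint
    timeout=60,
)
print(f"received {num_tokens} tokens")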
docling/models/picture_description_vlm_model.py
CHANGED
@@ -1,3 +1,4 @@
+import sys
 import threading
 from collections.abc import Iterable
 from pathlib import Path
@@ -75,7 +76,10 @@ class PictureDescriptionVlmModel(
                 else "sdpa"
             ),
         )
-
+        if sys.version_info < (3, 14):
+            self.model = torch.compile(self.model)  # type: ignore
+        else:
+            self.model.eval()

         self.provenance = f"{self.options.repo_id}"

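The hunk above introduces the Python-version guard around torch.compile that recurs in the transformers-based models further down: compile where the dynamo backend is available, otherwise just switch the model to eval mode. A standalone sketch of the guard using a toy module rather than a docling model (the helper name prepare_model is made up for illustration):

import sys

import torch


def prepare_model(model: torch.nn.Module) -> torch.nn.Module:
    # torch.compile is skipped on Python 3.14, where its compile stack is not yet supported.
    if sys.version_info < (3, 14):
        return torch.compile(model)  # type: ignore[return-value]
    model.eval()
    return model


model = prepare_model(torch.nn.Linear(4, 2))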
docling/models/readingorder_model.py
CHANGED
@@ -1,5 +1,4 @@
 from pathlib import Path
-from typing import Dict, List

 from docling_core.types.doc import (
     DocItemLabel,
@@ -48,8 +47,8 @@ class ReadingOrderModel:

     def _assembled_to_readingorder_elements(
         self, conv_res: ConversionResult
-    ) ->
-        elements:
+    ) -> list[ReadingOrderPageElement]:
+        elements: list[ReadingOrderPageElement] = []
         page_no_to_pages = {p.page_no: p for p in conv_res.pages}

         for element in conv_res.assembled.elements:
@@ -123,10 +122,10 @@ class ReadingOrderModel:
     def _readingorder_elements_to_docling_doc(
         self,
         conv_res: ConversionResult,
-        ro_elements:
-        el_to_captions_mapping:
-        el_to_footnotes_mapping:
-        el_merges_mapping:
+        ro_elements: list[ReadingOrderPageElement],
+        el_to_captions_mapping: dict[int, list[int]],
+        el_to_footnotes_mapping: dict[int, list[int]],
+        el_merges_mapping: dict[int, list[int]],
     ) -> DoclingDocument:
         id_to_elem = {
             RefItem(cref=f"#/{elem.page_no}/{elem.cluster.id}").cref: elem
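The dropped typing imports above reflect a move to the built-in generics of PEP 585 (Python 3.9+): list[...] and dict[...] replace typing.List and typing.Dict with identical meaning. A generic illustration, unrelated to docling's own types:

# Before: from typing import Dict, List; def f(xs: List[int]) -> Dict[int, List[int]]: ...
def bucket_by_parity(xs: list[int]) -> dict[int, list[int]]:
    out: dict[int, list[int]] = {}
    for x in xs:
        out.setdefault(x % 2, []).append(x)
    return out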
docling/models/vlm_models_inline/hf_transformers_model.py
CHANGED
@@ -1,5 +1,6 @@
 import importlib.metadata
 import logging
+import sys
 import time
 from collections.abc import Iterable
 from pathlib import Path
@@ -129,7 +130,10 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
             trust_remote_code=vlm_options.trust_remote_code,
             revision=vlm_options.revision,
         )
-
+        if sys.version_info < (3, 14):
+            self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
+        else:
+            self.vlm_model.eval()

         # Load generation config
         self.generation_config = GenerationConfig.from_pretrained(
@@ -363,13 +367,19 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]

         # -- Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."
             )

         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )
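The token count logged here (and now attached to VlmPrediction) is taken from the length of the first generated sequence in the batch. A small sketch of that bookkeeping with a plain tensor; the shapes and vocabulary size are illustrative only:

import torch

# Stand-in for the output of model.generate(): 2 sequences of 7 token ids each.
generated_ids = torch.randint(0, 32_000, (2, 7))

num_tokens = None
if generated_ids.shape[0] > 0:
    # Length of the first sequence; with padding this is a per-batch approximation.
    num_tokens = int(generated_ids[0].shape[0])
print(num_tokens)  # 7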
docling/models/vlm_models_inline/mlx_model.py
CHANGED
@@ -50,9 +50,14 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             from mlx_vlm.prompt_utils import apply_chat_template  # type: ignore
             from mlx_vlm.utils import load_config  # type: ignore
         except ImportError:
-
-
-
+            if sys.version_info < (3, 14):
+                raise ImportError(
+                    "mlx-vlm is not installed. Please install it via `pip install mlx-vlm` to use MLX VLM models."
+                )
+            else:
+                raise ImportError(
+                    "mlx-vlm is not installed. It is not yet available on Python 3.14."
+                )

         repo_cache_folder = vlm_options.repo_id.replace("/", "--")

@@ -313,5 +318,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             text=decoded_output,
             generation_time=generation_time,
             generated_tokens=tokens,
+            num_tokens=len(tokens),
         )
         _log.debug("MLX model: Released global lock")
docling/models/vlm_models_inline/nuextract_transformers_model.py
CHANGED
@@ -1,4 +1,5 @@
 import logging
+import sys
 import time
 from collections.abc import Iterable
 from pathlib import Path
@@ -153,7 +154,10 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
             ),
             trust_remote_code=vlm_options.trust_remote_code,
         )
-
+        if sys.version_info < (3, 14):
+            self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
+        else:
+            self.vlm_model.eval()

         # Load generation config
         self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
@@ -278,13 +282,19 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         )

         # Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:  # type: ignore
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."  # type: ignore
             )

         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )
docling/models/vlm_models_inline/vllm_model.py
CHANGED
@@ -1,4 +1,5 @@
 import logging
+import sys
 import time
 from collections.abc import Iterable
 from pathlib import Path
@@ -8,7 +9,7 @@ import numpy as np
 from PIL.Image import Image

 from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -87,7 +88,7 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         vlm_options: InlineVlmOptions,
     ):
         self.enabled = enabled
-        self.vlm_options = vlm_options
+        self.vlm_options: InlineVlmOptions = vlm_options

         self.llm = None
         self.sampling_params = None
@@ -100,7 +101,18 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             return

         from transformers import AutoProcessor
-
+
+        try:
+            from vllm import LLM, SamplingParams
+        except ImportError:
+            if sys.version_info < (3, 14):
+                raise ImportError(
+                    "vllm is not installed. Please install it via `pip install vllm`."
+                )
+            else:
+                raise ImportError(
+                    "vllm is not installed. It is not yet available on Python 3.14."
+                )

         # Device selection
         self.device = decide_device(
@@ -222,7 +234,8 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             pages_with_images.append(page)

         if images:
-
+            with TimeRecorder(conv_res, "vlm_inference"):
+                predictions = list(self.process_images(images, user_prompts))
             for page, prediction in zip(pages_with_images, predictions):
                 page.predictions.vlm_response = prediction

@@ -288,13 +301,34 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         # Optional debug
         if outputs:
             try:
-
-                _log.debug(
+                num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
+                _log.debug(
+                    f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
+                )
             except Exception:
-
+                num_tokens_within_batch = 0

         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
+            stop_reason = output.outputs[0].stop_reason if output.outputs else ""
+            generated_tokens = [
+                VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
+            ]
+            num_tokens = len(generated_tokens)
             decoded_text = self.vlm_options.decode_response(text)
-
+            if self.vlm_options.track_generated_tokens:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                    generated_tokens=generated_tokens,
+                )
+            else:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                )
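In the vLLM path, per-token ids are only attached to the prediction when the options' track_generated_tokens flag is set; the text, token count, and stop reason are always kept. A hedged sketch of inspecting such a prediction, constructed directly here as a stand-in for the object yielded by VllmVlmModel.process_images (field names follow the diff above, values are illustrative):

from docling.datamodel.base_models import VlmPrediction, VlmPredictionToken

prediction = VlmPrediction(
    text="<doctag>...</doctag>",
    generation_time=1.23,
    num_tokens=3,
    stop_reason="stop",
    generated_tokens=[VlmPredictionToken(token=t) for t in (101, 102, 103)],
)

print(prediction.num_tokens, prediction.stop_reason)
for tok in prediction.generated_tokens:
    print(tok.token)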
docling/pipeline/asr_pipeline.py
CHANGED
@@ -1,10 +1,11 @@
 import logging
 import os
 import re
+import sys
 import tempfile
 from io import BytesIO
 from pathlib import Path
-from typing import List, Optional, Union, cast
+from typing import TYPE_CHECKING, List, Optional, Union, cast

 from docling_core.types.doc import DoclingDocument, DocumentOrigin

@@ -32,6 +33,7 @@ from docling.datamodel.pipeline_options (
     AsrPipelineOptions,
 )
 from docling.datamodel.pipeline_options_asr_model import (
+    InlineAsrMlxWhisperOptions,
     InlineAsrNativeWhisperOptions,
     # AsrResponseFormat,
     InlineAsrOptions,
@@ -116,9 +118,15 @@ class _NativeWhisperModel:
         try:
             import whisper  # type: ignore
         except ImportError:
-
-
-
+            if sys.version_info < (3, 14):
+                raise ImportError(
+                    "whisper is not installed. Please install it via `pip install openai-whisper` or do `uv sync --extra asr`."
+                )
+            else:
+                raise ImportError(
+                    "whisper is not installed. Unfortunately its dependencies are not yet available for Python 3.14."
+                )
+
         self.asr_options = asr_options
         self.max_tokens = asr_options.max_new_tokens
         self.temperature = asr_options.temperature
@@ -228,22 +236,157 @@ class _NativeWhisperModel:
         return convo


+class _MlxWhisperModel:
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        asr_options: InlineAsrMlxWhisperOptions,
+    ):
+        """
+        Transcriber using MLX Whisper for Apple Silicon optimization.
+        """
+        self.enabled = enabled
+
+        _log.info(f"artifacts-path: {artifacts_path}")
+        _log.info(f"accelerator_options: {accelerator_options}")
+
+        if self.enabled:
+            try:
+                import mlx_whisper  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    "mlx-whisper is not installed. Please install it via `pip install mlx-whisper` or do `uv sync --extra asr`."
+                )
+            self.asr_options = asr_options
+            self.mlx_whisper = mlx_whisper
+
+            self.device = decide_device(
+                accelerator_options.device,
+                supported_devices=asr_options.supported_devices,
+            )
+            _log.info(f"Available device for MLX Whisper: {self.device}")
+
+            self.model_name = asr_options.repo_id
+            _log.info(f"loading _MlxWhisperModel({self.model_name})")
+
+            # MLX Whisper models are loaded differently - they use HuggingFace repos
+            self.model_path = self.model_name
+
+            # Store MLX-specific options
+            self.language = asr_options.language
+            self.task = asr_options.task
+            self.word_timestamps = asr_options.word_timestamps
+            self.no_speech_threshold = asr_options.no_speech_threshold
+            self.logprob_threshold = asr_options.logprob_threshold
+            self.compression_ratio_threshold = asr_options.compression_ratio_threshold
+
+    def run(self, conv_res: ConversionResult) -> ConversionResult:
+        audio_path: Path = Path(conv_res.input.file).resolve()
+
+        try:
+            conversation = self.transcribe(audio_path)
+
+            # Ensure we have a proper DoclingDocument
+            origin = DocumentOrigin(
+                filename=conv_res.input.file.name or "audio.wav",
+                mimetype="audio/x-wav",
+                binary_hash=conv_res.input.document_hash,
+            )
+            conv_res.document = DoclingDocument(
+                name=conv_res.input.file.stem or "audio.wav", origin=origin
+            )
+
+            for citem in conversation:
+                conv_res.document.add_text(
+                    label=DocItemLabel.TEXT, text=citem.to_string()
+                )
+
+            conv_res.status = ConversionStatus.SUCCESS
+            return conv_res
+
+        except Exception as exc:
+            _log.error(f"MLX Audio transcription has an error: {exc}")
+
+            conv_res.status = ConversionStatus.FAILURE
+            return conv_res
+
+    def transcribe(self, fpath: Path) -> list[_ConversationItem]:
+        """
+        Transcribe audio using MLX Whisper.
+
+        Args:
+            fpath: Path to audio file
+
+        Returns:
+            List of conversation items with timestamps
+        """
+        result = self.mlx_whisper.transcribe(
+            str(fpath),
+            path_or_hf_repo=self.model_path,
+            language=self.language,
+            task=self.task,
+            word_timestamps=self.word_timestamps,
+            no_speech_threshold=self.no_speech_threshold,
+            logprob_threshold=self.logprob_threshold,
+            compression_ratio_threshold=self.compression_ratio_threshold,
+        )
+
+        convo: list[_ConversationItem] = []
+
+        # MLX Whisper returns segments similar to native Whisper
+        for segment in result.get("segments", []):
+            item = _ConversationItem(
+                start_time=segment.get("start"),
+                end_time=segment.get("end"),
+                text=segment.get("text", "").strip(),
+                words=[],
+            )
+
+            # Add word-level timestamps if available
+            if self.word_timestamps and "words" in segment:
+                item.words = []
+                for word_data in segment["words"]:
+                    item.words.append(
+                        _ConversationWord(
+                            start_time=word_data.get("start"),
+                            end_time=word_data.get("end"),
+                            text=word_data.get("word", ""),
+                        )
+                    )
+            convo.append(item)
+
+        return convo
+
+
 class AsrPipeline(BasePipeline):
     def __init__(self, pipeline_options: AsrPipelineOptions):
         super().__init__(pipeline_options)
         self.keep_backend = True

         self.pipeline_options: AsrPipelineOptions = pipeline_options
+        self._model: Union[_NativeWhisperModel, _MlxWhisperModel]

         if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
-
+            native_asr_options: InlineAsrNativeWhisperOptions = (
                 self.pipeline_options.asr_options
             )
             self._model = _NativeWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=self.artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=
+                asr_options=native_asr_options,
+            )
+        elif isinstance(self.pipeline_options.asr_options, InlineAsrMlxWhisperOptions):
+            mlx_asr_options: InlineAsrMlxWhisperOptions = (
+                self.pipeline_options.asr_options
+            )
+            self._model = _MlxWhisperModel(
+                enabled=True,  # must be always enabled for this pipeline to make sense.
+                artifacts_path=self.artifacts_path,
+                accelerator_options=pipeline_options.accelerator_options,
+                asr_options=mlx_asr_options,
             )
         else:
             _log.error(f"No model support for {self.pipeline_options.asr_options}")
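The new _MlxWhisperModel is selected when the pipeline's asr_options is an InlineAsrMlxWhisperOptions instance. A hedged configuration sketch; the field names mirror the attributes read in the diff above, but the repo id is only an assumed example and the released package likely ships ready-made presets in docling.datamodel.asr_model_specs:

from docling.datamodel.pipeline_options import AsrPipelineOptions
from docling.datamodel.pipeline_options_asr_model import InlineAsrMlxWhisperOptions

asr_options = InlineAsrMlxWhisperOptions(
    repo_id="mlx-community/whisper-large-v3-turbo",  # assumed MLX checkpoint
    language="en",
    task="transcribe",
    word_timestamps=True,
)
pipeline_options = AsrPipelineOptions(asr_options=asr_options)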
docling/utils/api_image_request.py
CHANGED
@@ -2,7 +2,7 @@ import base64
 import json
 import logging
 from io import BytesIO
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple

 import requests
 from PIL import Image
@@ -19,9 +19,9 @@ def api_image_request(
     prompt: str,
     url: AnyUrl,
     timeout: float = 20,
-    headers: Optional[
+    headers: Optional[dict[str, str]] = None,
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -60,7 +60,8 @@ def api_image_request(

     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
-
+    num_tokens = api_resp.usage.total_tokens
+    return generated_text, num_tokens


 def api_image_request_streaming(
@@ -69,10 +70,10 @@ def api_image_request_streaming(
     url: AnyUrl,
     *,
     timeout: float = 20,
-    headers: Optional[
-    generation_stoppers:
+    headers: Optional[dict[str, str]] = None,
+    generation_stoppers: list[GenerationStopper] = [],
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     """
     Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
     Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
@@ -150,6 +151,16 @@ def api_image_request_streaming(
             _log.debug("Unexpected SSE chunk shape: %s", e)
             piece = ""

+        # Try to extract token count
+        num_tokens = None
+        try:
+            if "usage" in obj:
+                usage = obj["usage"]
+                num_tokens = usage.get("total_tokens")
+        except Exception as e:
+            num_tokens = None
+            _log.debug("Usage key not included in response: %s", e)
+
         if piece:
             full_text.append(piece)
             for stopper in generation_stoppers:
@@ -162,6 +173,6 @@ def api_image_request_streaming(
     # closing the connection when we exit the 'with' block.
     # vLLM/OpenAI-compatible servers will detect the client disconnect
     # and abort the request server-side.
-    return "".join(full_text)
+    return "".join(full_text), num_tokens

-    return "".join(full_text)
+    return "".join(full_text), num_tokens
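The streaming helper only returns a token count when the server includes a usage object in one of the SSE chunks, and OpenAI-compatible servers generally emit that final usage chunk only when asked to. A hedged sketch, assuming the extra keyword arguments are forwarded into the request payload as in the non-streaming helper and that the server honors stream_options:

from PIL import Image

from docling.utils.api_image_request import api_image_request_streaming

page_image = Image.new("RGB", (640, 480), "white")  # stand-in for a rendered page
text, num_tokens = api_image_request_streaming(
    image=page_image,
    prompt="Convert this page to DocTags.",
    url="http://localhost:8000/v1/chat/completions",  # assumed local vLLM endpoint
    timeout=120,
    stream_options={"include_usage": True},  # ask the server to report usage
)
print(num_tokens)  # None if the server never sent a usage chunk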