docling 2.57.0__py3-none-any.whl → 2.59.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (35)
  1. docling/backend/abstract_backend.py +24 -3
  2. docling/backend/asciidoc_backend.py +3 -3
  3. docling/backend/docling_parse_v4_backend.py +15 -4
  4. docling/backend/html_backend.py +130 -20
  5. docling/backend/md_backend.py +27 -5
  6. docling/backend/msexcel_backend.py +121 -29
  7. docling/backend/mspowerpoint_backend.py +2 -2
  8. docling/backend/msword_backend.py +18 -18
  9. docling/backend/pdf_backend.py +9 -2
  10. docling/backend/pypdfium2_backend.py +12 -3
  11. docling/cli/main.py +104 -38
  12. docling/datamodel/asr_model_specs.py +408 -6
  13. docling/datamodel/backend_options.py +82 -0
  14. docling/datamodel/base_models.py +19 -2
  15. docling/datamodel/document.py +81 -48
  16. docling/datamodel/pipeline_options_asr_model.py +21 -1
  17. docling/datamodel/pipeline_options_vlm_model.py +1 -0
  18. docling/document_converter.py +37 -45
  19. docling/document_extractor.py +12 -11
  20. docling/models/api_vlm_model.py +5 -3
  21. docling/models/picture_description_vlm_model.py +5 -1
  22. docling/models/readingorder_model.py +6 -7
  23. docling/models/vlm_models_inline/hf_transformers_model.py +13 -3
  24. docling/models/vlm_models_inline/mlx_model.py +9 -3
  25. docling/models/vlm_models_inline/nuextract_transformers_model.py +13 -3
  26. docling/models/vlm_models_inline/vllm_model.py +42 -8
  27. docling/pipeline/asr_pipeline.py +149 -6
  28. docling/utils/api_image_request.py +20 -9
  29. docling/utils/layout_postprocessor.py +23 -24
  30. {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/METADATA +11 -8
  31. {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/RECORD +35 -34
  32. {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/WHEEL +0 -0
  33. {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/entry_points.txt +0 -0
  34. {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/licenses/LICENSE +0 -0
  35. {docling-2.57.0.dist-info → docling-2.59.0.dist-info}/top_level.txt +0 -0
docling/models/api_vlm_model.py

@@ -73,7 +73,7 @@ class ApiVlmModel(BasePageModel):
                 # Skip non-GenerationStopper criteria (should have been caught in validation)

                 # Streaming path with early abort support
-                page_tags = api_image_request_streaming(
+                page_tags, num_tokens = api_image_request_streaming(
                     image=hi_res_image,
                     prompt=prompt,
                     url=self.vlm_options.url,
@@ -84,7 +84,7 @@ class ApiVlmModel(BasePageModel):
                 )
             else:
                 # Non-streaming fallback (existing behavior)
-                page_tags = api_image_request(
+                page_tags, num_tokens = api_image_request(
                     image=hi_res_image,
                     prompt=prompt,
                     url=self.vlm_options.url,
@@ -94,7 +94,9 @@ class ApiVlmModel(BasePageModel):
             )

             page_tags = self.vlm_options.decode_response(page_tags)
-            page.predictions.vlm_response = VlmPrediction(text=page_tags)
+            page.predictions.vlm_response = VlmPrediction(
+                text=page_tags, num_tokens=num_tokens
+            )
             return page

         with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
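
With `VlmPrediction` now carrying `num_tokens`, downstream code can read a per-page token count off the conversion result. A minimal, hypothetical usage sketch (the attribute names follow the diff; the surrounding setup is assumed):

    # Hypothetical sketch: inspect per-page token counts after a VLM-based conversion.
    # Assumes `conv_res` is a ConversionResult produced with ApiVlmModel in the pipeline.
    for page in conv_res.pages:
        pred = page.predictions.vlm_response
        if pred is not None and pred.num_tokens is not None:
            print(f"page {page.page_no}: {pred.num_tokens} generated tokens")
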
docling/models/picture_description_vlm_model.py

@@ -1,3 +1,4 @@
+import sys
 import threading
 from collections.abc import Iterable
 from pathlib import Path
@@ -75,7 +76,10 @@ class PictureDescriptionVlmModel(
                     else "sdpa"
                 ),
             )
-            self.model = torch.compile(self.model)  # type: ignore
+            if sys.version_info < (3, 14):
+                self.model = torch.compile(self.model)  # type: ignore
+            else:
+                self.model.eval()

            self.provenance = f"{self.options.repo_id}"

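
The Python 3.14 guard above reappears later in this diff for the HF transformers and NuExtract models. A small standalone restatement of the pattern, using a hypothetical helper name:

    import sys

    import torch

    def _maybe_compile(model):
        # Mirrors the guard in the diff: torch.compile is only applied on
        # Python < 3.14; on newer interpreters the model is left uncompiled
        # and simply switched to eval mode.
        if sys.version_info < (3, 14):
            return torch.compile(model)  # type: ignore
        model.eval()
        return model
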
docling/models/readingorder_model.py

@@ -1,5 +1,4 @@
 from pathlib import Path
-from typing import Dict, List

 from docling_core.types.doc import (
     DocItemLabel,
@@ -48,8 +47,8 @@ class ReadingOrderModel:

     def _assembled_to_readingorder_elements(
         self, conv_res: ConversionResult
-    ) -> List[ReadingOrderPageElement]:
-        elements: List[ReadingOrderPageElement] = []
+    ) -> list[ReadingOrderPageElement]:
+        elements: list[ReadingOrderPageElement] = []
         page_no_to_pages = {p.page_no: p for p in conv_res.pages}

         for element in conv_res.assembled.elements:
@@ -123,10 +122,10 @@ class ReadingOrderModel:
     def _readingorder_elements_to_docling_doc(
         self,
         conv_res: ConversionResult,
-        ro_elements: List[ReadingOrderPageElement],
-        el_to_captions_mapping: Dict[int, List[int]],
-        el_to_footnotes_mapping: Dict[int, List[int]],
-        el_merges_mapping: Dict[int, List[int]],
+        ro_elements: list[ReadingOrderPageElement],
+        el_to_captions_mapping: dict[int, list[int]],
+        el_to_footnotes_mapping: dict[int, list[int]],
+        el_merges_mapping: dict[int, list[int]],
     ) -> DoclingDocument:
         id_to_elem = {
             RefItem(cref=f"#/{elem.page_no}/{elem.cluster.id}").cref: elem
docling/models/vlm_models_inline/hf_transformers_model.py

@@ -1,5 +1,6 @@
 import importlib.metadata
 import logging
+import sys
 import time
 from collections.abc import Iterable
 from pathlib import Path
@@ -129,7 +130,10 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
            trust_remote_code=vlm_options.trust_remote_code,
            revision=vlm_options.revision,
        )
-        self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
+        if sys.version_info < (3, 14):
+            self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
+        else:
+            self.vlm_model.eval()

        # Load generation config
        self.generation_config = GenerationConfig.from_pretrained(
@@ -363,13 +367,19 @@ class HuggingFaceTransformersVlmModel(BaseVlmPageModel, HuggingFaceModelDownload
         decoded_texts = [text.rstrip(pad_token) for text in decoded_texts]

         # -- Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."
             )

         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )
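
The count recorded here is just the length of the first generated sequence. A toy restatement with a dummy tensor (the real `generated_ids` comes from the model's generate call):

    import torch

    # Toy illustration of the counting logic above; `generated_ids` is assumed to be
    # a (batch, seq_len) tensor of generated token ids.
    generated_ids = torch.tensor([[101, 2054, 2003, 102]])
    num_tokens = None
    if generated_ids.shape[0] > 0:
        num_tokens = int(generated_ids[0].shape[0])  # tokens in the first sequence
    print(num_tokens)  # 4
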
docling/models/vlm_models_inline/mlx_model.py

@@ -50,9 +50,14 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
            from mlx_vlm.prompt_utils import apply_chat_template  # type: ignore
            from mlx_vlm.utils import load_config  # type: ignore
        except ImportError:
-            raise ImportError(
-                "mlx-vlm is not installed. Please install it via `pip install mlx-vlm` to use MLX VLM models."
-            )
+            if sys.version_info < (3, 14):
+                raise ImportError(
+                    "mlx-vlm is not installed. Please install it via `pip install mlx-vlm` to use MLX VLM models."
+                )
+            else:
+                raise ImportError(
+                    "mlx-vlm is not installed. It is not yet available on Python 3.14."
+                )

        repo_cache_folder = vlm_options.repo_id.replace("/", "--")

@@ -313,5 +318,6 @@ class HuggingFaceMlxModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                text=decoded_output,
                generation_time=generation_time,
                generated_tokens=tokens,
+                num_tokens=len(tokens),
            )
            _log.debug("MLX model: Released global lock")
docling/models/vlm_models_inline/nuextract_transformers_model.py

@@ -1,4 +1,5 @@
 import logging
+import sys
 import time
 from collections.abc import Iterable
 from pathlib import Path
@@ -153,7 +154,10 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
            ),
            trust_remote_code=vlm_options.trust_remote_code,
        )
-        self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
+        if sys.version_info < (3, 14):
+            self.vlm_model = torch.compile(self.vlm_model)  # type: ignore
+        else:
+            self.vlm_model.eval()

        # Load generation config
        self.generation_config = GenerationConfig.from_pretrained(artifacts_path)
@@ -278,13 +282,19 @@ class NuExtractTransformersModel(BaseVlmModel, HuggingFaceModelDownloadMixin):
         )

         # Optional logging
+        num_tokens = None
         if generated_ids.shape[0] > 0:  # type: ignore
+            num_tokens = int(generated_ids[0].shape[0])
             _log.debug(
-                f"Generated {int(generated_ids[0].shape[0])} tokens in {generation_time:.2f}s "
+                f"Generated {num_tokens} tokens in {generation_time:.2f}s "
                 f"for batch size {generated_ids.shape[0]}."  # type: ignore
             )

         for text in decoded_texts:
             # Apply decode_response to the output text
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            yield VlmPrediction(
+                text=decoded_text,
+                generation_time=generation_time,
+                num_tokens=num_tokens,
+            )
docling/models/vlm_models_inline/vllm_model.py

@@ -1,4 +1,5 @@
 import logging
+import sys
 import time
 from collections.abc import Iterable
 from pathlib import Path
@@ -8,7 +9,7 @@ import numpy as np
 from PIL.Image import Image

 from docling.datamodel.accelerator_options import AcceleratorOptions
-from docling.datamodel.base_models import Page, VlmPrediction
+from docling.datamodel.base_models import Page, VlmPrediction, VlmPredictionToken
 from docling.datamodel.document import ConversionResult
 from docling.datamodel.pipeline_options_vlm_model import (
     InlineVlmOptions,
@@ -87,7 +88,7 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         vlm_options: InlineVlmOptions,
     ):
         self.enabled = enabled
-        self.vlm_options = vlm_options
+        self.vlm_options: InlineVlmOptions = vlm_options

         self.llm = None
         self.sampling_params = None
@@ -100,7 +101,18 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
             return

         from transformers import AutoProcessor
-        from vllm import LLM, SamplingParams
+
+        try:
+            from vllm import LLM, SamplingParams
+        except ImportError:
+            if sys.version_info < (3, 14):
+                raise ImportError(
+                    "vllm is not installed. Please install it via `pip install vllm`."
+                )
+            else:
+                raise ImportError(
+                    "vllm is not installed. It is not yet available on Python 3.14."
+                )

         # Device selection
         self.device = decide_device(
@@ -222,7 +234,8 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
                 pages_with_images.append(page)

         if images:
-            predictions = list(self.process_images(images, user_prompts))
+            with TimeRecorder(conv_res, "vlm_inference"):
+                predictions = list(self.process_images(images, user_prompts))
             for page, prediction in zip(pages_with_images, predictions):
                 page.predictions.vlm_response = prediction

@@ -288,13 +301,34 @@ class VllmVlmModel(BaseVlmPageModel, HuggingFaceModelDownloadMixin):
         # Optional debug
         if outputs:
             try:
-                num_tokens = len(outputs[0].outputs[0].token_ids)
-                _log.debug(f"Generated {num_tokens} tokens in {generation_time:.2f}s.")
+                num_tokens_within_batch = len(outputs[0].outputs[0].token_ids)
+                _log.debug(
+                    f"Generated {num_tokens_within_batch} tokens for batch in {generation_time:.2f}s."
+                )
             except Exception:
-                pass
+                num_tokens_within_batch = 0

         # Emit predictions
         for output in outputs:
             text = output.outputs[0].text if output.outputs else ""
+            stop_reason = output.outputs[0].stop_reason if output.outputs else ""
+            generated_tokens = [
+                VlmPredictionToken(token=int(p)) for p in output.outputs[0].token_ids
+            ]
+            num_tokens = len(generated_tokens)
             decoded_text = self.vlm_options.decode_response(text)
-            yield VlmPrediction(text=decoded_text, generation_time=generation_time)
+            if self.vlm_options.track_generated_tokens:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                    generated_tokens=generated_tokens,
+                )
+            else:
+                yield VlmPrediction(
+                    text=decoded_text,
+                    generation_time=generation_time,
+                    num_tokens=num_tokens,
+                    stop_reason=stop_reason,
+                )
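
Besides timing inference with `TimeRecorder`, the vLLM model now attaches `stop_reason` and a token count to every prediction, and includes the full `VlmPredictionToken` list only when `vlm_options.track_generated_tokens` is enabled. A hypothetical sketch of reading the new fields from a yielded prediction:

    # Hypothetical sketch: fields added to VlmPrediction by this release.
    # `prediction` stands for one item yielded by VllmVlmModel.process_images(...).
    print(prediction.num_tokens, prediction.stop_reason)
    if prediction.generated_tokens:  # populated only when track_generated_tokens is set
        first = prediction.generated_tokens[0]
        print(first.token)  # raw token id wrapped in a VlmPredictionToken
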
docling/pipeline/asr_pipeline.py

@@ -1,10 +1,11 @@
 import logging
 import os
 import re
+import sys
 import tempfile
 from io import BytesIO
 from pathlib import Path
-from typing import List, Optional, Union, cast
+from typing import TYPE_CHECKING, List, Optional, Union, cast

 from docling_core.types.doc import DoclingDocument, DocumentOrigin

@@ -32,6 +33,7 @@ from docling.datamodel.pipeline_options import (
     AsrPipelineOptions,
 )
 from docling.datamodel.pipeline_options_asr_model import (
+    InlineAsrMlxWhisperOptions,
     InlineAsrNativeWhisperOptions,
     # AsrResponseFormat,
     InlineAsrOptions,
@@ -116,9 +118,15 @@ class _NativeWhisperModel:
            try:
                import whisper  # type: ignore
            except ImportError:
-                raise ImportError(
-                    "whisper is not installed. Please install it via `pip install openai-whisper` or do `uv sync --extra asr`."
-                )
+                if sys.version_info < (3, 14):
+                    raise ImportError(
+                        "whisper is not installed. Please install it via `pip install openai-whisper` or do `uv sync --extra asr`."
+                    )
+                else:
+                    raise ImportError(
+                        "whisper is not installed. Unfortunately its dependencies are not yet available for Python 3.14."
+                    )
+
            self.asr_options = asr_options
            self.max_tokens = asr_options.max_new_tokens
            self.temperature = asr_options.temperature
@@ -228,22 +236,157 @@ class _NativeWhisperModel:
         return convo


+class _MlxWhisperModel:
+    def __init__(
+        self,
+        enabled: bool,
+        artifacts_path: Optional[Path],
+        accelerator_options: AcceleratorOptions,
+        asr_options: InlineAsrMlxWhisperOptions,
+    ):
+        """
+        Transcriber using MLX Whisper for Apple Silicon optimization.
+        """
+        self.enabled = enabled
+
+        _log.info(f"artifacts-path: {artifacts_path}")
+        _log.info(f"accelerator_options: {accelerator_options}")
+
+        if self.enabled:
+            try:
+                import mlx_whisper  # type: ignore
+            except ImportError:
+                raise ImportError(
+                    "mlx-whisper is not installed. Please install it via `pip install mlx-whisper` or do `uv sync --extra asr`."
+                )
+            self.asr_options = asr_options
+            self.mlx_whisper = mlx_whisper
+
+            self.device = decide_device(
+                accelerator_options.device,
+                supported_devices=asr_options.supported_devices,
+            )
+            _log.info(f"Available device for MLX Whisper: {self.device}")
+
+            self.model_name = asr_options.repo_id
+            _log.info(f"loading _MlxWhisperModel({self.model_name})")
+
+            # MLX Whisper models are loaded differently - they use HuggingFace repos
+            self.model_path = self.model_name
+
+            # Store MLX-specific options
+            self.language = asr_options.language
+            self.task = asr_options.task
+            self.word_timestamps = asr_options.word_timestamps
+            self.no_speech_threshold = asr_options.no_speech_threshold
+            self.logprob_threshold = asr_options.logprob_threshold
+            self.compression_ratio_threshold = asr_options.compression_ratio_threshold
+
+    def run(self, conv_res: ConversionResult) -> ConversionResult:
+        audio_path: Path = Path(conv_res.input.file).resolve()
+
+        try:
+            conversation = self.transcribe(audio_path)
+
+            # Ensure we have a proper DoclingDocument
+            origin = DocumentOrigin(
+                filename=conv_res.input.file.name or "audio.wav",
+                mimetype="audio/x-wav",
+                binary_hash=conv_res.input.document_hash,
+            )
+            conv_res.document = DoclingDocument(
+                name=conv_res.input.file.stem or "audio.wav", origin=origin
+            )
+
+            for citem in conversation:
+                conv_res.document.add_text(
+                    label=DocItemLabel.TEXT, text=citem.to_string()
+                )
+
+            conv_res.status = ConversionStatus.SUCCESS
+            return conv_res
+
+        except Exception as exc:
+            _log.error(f"MLX Audio transcription has an error: {exc}")
+
+            conv_res.status = ConversionStatus.FAILURE
+            return conv_res
+
+    def transcribe(self, fpath: Path) -> list[_ConversationItem]:
+        """
+        Transcribe audio using MLX Whisper.
+
+        Args:
+            fpath: Path to audio file
+
+        Returns:
+            List of conversation items with timestamps
+        """
+        result = self.mlx_whisper.transcribe(
+            str(fpath),
+            path_or_hf_repo=self.model_path,
+            language=self.language,
+            task=self.task,
+            word_timestamps=self.word_timestamps,
+            no_speech_threshold=self.no_speech_threshold,
+            logprob_threshold=self.logprob_threshold,
+            compression_ratio_threshold=self.compression_ratio_threshold,
+        )
+
+        convo: list[_ConversationItem] = []
+
+        # MLX Whisper returns segments similar to native Whisper
+        for segment in result.get("segments", []):
+            item = _ConversationItem(
+                start_time=segment.get("start"),
+                end_time=segment.get("end"),
+                text=segment.get("text", "").strip(),
+                words=[],
+            )
+
+            # Add word-level timestamps if available
+            if self.word_timestamps and "words" in segment:
+                item.words = []
+                for word_data in segment["words"]:
+                    item.words.append(
+                        _ConversationWord(
+                            start_time=word_data.get("start"),
+                            end_time=word_data.get("end"),
+                            text=word_data.get("word", ""),
+                        )
+                    )
+            convo.append(item)
+
+        return convo
+
+
 class AsrPipeline(BasePipeline):
     def __init__(self, pipeline_options: AsrPipelineOptions):
         super().__init__(pipeline_options)
         self.keep_backend = True

         self.pipeline_options: AsrPipelineOptions = pipeline_options
+        self._model: Union[_NativeWhisperModel, _MlxWhisperModel]

         if isinstance(self.pipeline_options.asr_options, InlineAsrNativeWhisperOptions):
-            asr_options: InlineAsrNativeWhisperOptions = (
+            native_asr_options: InlineAsrNativeWhisperOptions = (
                 self.pipeline_options.asr_options
             )
             self._model = _NativeWhisperModel(
                 enabled=True,  # must be always enabled for this pipeline to make sense.
                 artifacts_path=self.artifacts_path,
                 accelerator_options=pipeline_options.accelerator_options,
-                asr_options=asr_options,
+                asr_options=native_asr_options,
+            )
+        elif isinstance(self.pipeline_options.asr_options, InlineAsrMlxWhisperOptions):
+            mlx_asr_options: InlineAsrMlxWhisperOptions = (
+                self.pipeline_options.asr_options
+            )
+            self._model = _MlxWhisperModel(
+                enabled=True,  # must be always enabled for this pipeline to make sense.
+                artifacts_path=self.artifacts_path,
+                accelerator_options=pipeline_options.accelerator_options,
+                asr_options=mlx_asr_options,
             )
         else:
             _log.error(f"No model support for {self.pipeline_options.asr_options}")
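
`AsrPipeline` now dispatches on the options class: `InlineAsrNativeWhisperOptions` keeps routing to the native openai-whisper transcriber, while the new `InlineAsrMlxWhisperOptions` selects the MLX Whisper transcriber for Apple Silicon. A hypothetical configuration sketch (the option classes and the `asr_options` field come from the diff; the repo id and any other required fields are assumptions):

    # Hypothetical sketch: selecting the MLX Whisper backend for the ASR pipeline.
    from docling.datamodel.pipeline_options import AsrPipelineOptions
    from docling.datamodel.pipeline_options_asr_model import InlineAsrMlxWhisperOptions

    pipeline_options = AsrPipelineOptions(
        asr_options=InlineAsrMlxWhisperOptions(
            repo_id="mlx-community/whisper-large-v3-turbo",  # assumed example repo id
        )
    )

The much larger docling/datamodel/asr_model_specs.py (+408 lines) suggests predefined specs for these models also ship with this release, though their names are not visible in the hunks shown here.
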
docling/utils/api_image_request.py

@@ -2,7 +2,7 @@ import base64
 import json
 import logging
 from io import BytesIO
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple

 import requests
 from PIL import Image
@@ -19,9 +19,9 @@ def api_image_request(
     prompt: str,
     url: AnyUrl,
     timeout: float = 20,
-    headers: Optional[Dict[str, str]] = None,
+    headers: Optional[dict[str, str]] = None,
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     img_io = BytesIO()
     image.save(img_io, "PNG")
     image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
@@ -60,7 +60,8 @@ def api_image_request(

     api_resp = OpenAiApiResponse.model_validate_json(r.text)
     generated_text = api_resp.choices[0].message.content.strip()
-    return generated_text
+    num_tokens = api_resp.usage.total_tokens
+    return generated_text, num_tokens


 def api_image_request_streaming(
@@ -69,10 +70,10 @@ def api_image_request_streaming(
     url: AnyUrl,
     *,
     timeout: float = 20,
-    headers: Optional[Dict[str, str]] = None,
-    generation_stoppers: List[GenerationStopper] = [],
+    headers: Optional[dict[str, str]] = None,
+    generation_stoppers: list[GenerationStopper] = [],
     **params,
-) -> str:
+) -> Tuple[str, Optional[int]]:
     """
     Stream a chat completion from an OpenAI-compatible server (e.g., vLLM).
     Parses SSE lines: 'data: {json}\\n\\n', terminated by 'data: [DONE]'.
@@ -150,6 +151,16 @@ def api_image_request_streaming(
                 _log.debug("Unexpected SSE chunk shape: %s", e)
                 piece = ""

+            # Try to extract token count
+            num_tokens = None
+            try:
+                if "usage" in obj:
+                    usage = obj["usage"]
+                    num_tokens = usage.get("total_tokens")
+            except Exception as e:
+                num_tokens = None
+                _log.debug("Usage key not included in response: %s", e)
+
             if piece:
                 full_text.append(piece)
                 for stopper in generation_stoppers:
@@ -162,6 +173,6 @@ def api_image_request_streaming(
                     # closing the connection when we exit the 'with' block.
                     # vLLM/OpenAI-compatible servers will detect the client disconnect
                     # and abort the request server-side.
-                    return "".join(full_text)
+                    return "".join(full_text), num_tokens

-    return "".join(full_text)
+    return "".join(full_text), num_tokens
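
Both helpers now return a `(text, num_tokens)` pair: the non-streaming variant reads `total_tokens` from the OpenAI-style `usage` object, and the streaming variant picks it up from the final SSE chunk when the server includes one, otherwise it stays `None`. A hypothetical call sketch (the image, URL, and prompt are placeholders):

    # Hypothetical sketch: calling the updated non-streaming helper.
    from docling.utils.api_image_request import api_image_request

    text, num_tokens = api_image_request(
        image=pil_image,  # a PIL.Image.Image, assumed to exist
        prompt="Convert this page to markdown.",
        url="http://localhost:8000/v1/chat/completions",  # assumed OpenAI-compatible endpoint
        timeout=60,
    )
    print(num_tokens)  # total_tokens as reported in the server's usage field
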