ccs-llmconnector 1.1.2-py3-none-any.whl → 1.1.4-py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
@@ -82,35 +82,35 @@ def _build_parser() -> argparse.ArgumentParser:
         default=32000,
         help="Maximum output tokens (provider-specific meaning)",
     )
-    p_respond.add_argument(
-        "--reasoning-effort",
-        choices=["low", "medium", "high"],
-        default=None,
-        help="Optional reasoning effort hint if supported",
-    )
-    p_respond.add_argument(
-        "--request-id",
-        default=None,
-        help="Optional request identifier for tracing/logging",
-    )
-    p_respond.add_argument(
-        "--timeout-s",
-        type=float,
-        default=None,
-        help="Optional timeout in seconds",
-    )
-    p_respond.add_argument(
-        "--max-retries",
-        type=int,
-        default=0,
-        help="Number of retries for transient failures",
-    )
-    p_respond.add_argument(
-        "--retry-backoff-s",
-        type=float,
-        default=0.5,
-        help="Base delay in seconds for exponential backoff",
-    )
+    p_respond.add_argument(
+        "--reasoning-effort",
+        choices=["low", "medium", "high"],
+        default=None,
+        help="Optional reasoning effort hint if supported",
+    )
+    p_respond.add_argument(
+        "--request-id",
+        default=None,
+        help="Optional request identifier for tracing/logging",
+    )
+    p_respond.add_argument(
+        "--timeout-s",
+        type=float,
+        default=None,
+        help="Optional timeout in seconds",
+    )
+    p_respond.add_argument(
+        "--max-retries",
+        type=int,
+        default=0,
+        help="Number of retries for transient failures",
+    )
+    p_respond.add_argument(
+        "--retry-backoff-s",
+        type=float,
+        default=0.5,
+        help="Base delay in seconds for exponential backoff",
+    )

     # models: list available models
     p_models = subparsers.add_parser(
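The `--max-retries` and `--retry-backoff-s` flags pair a retry count with a base delay for exponential backoff. The actual schedule lives in the package's `run_with_retries` helper, which this diff does not show; the sketch below only illustrates the conventional doubling schedule implied by the help text:

```python
# Hypothetical illustration of an exponential backoff schedule seeded by
# --retry-backoff-s; the package's real run_with_retries logic is not in this diff.
def backoff_delays(base_s: float, attempts: int) -> list[float]:
    return [base_s * (2 ** attempt) for attempt in range(attempts)]

print(backoff_delays(0.5, 3))  # [0.5, 1.0, 2.0]
```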
@@ -172,19 +172,19 @@ def _cmd_respond(args: argparse.Namespace) -> int:
         print("Error: provide a prompt or at least one image.", file=sys.stderr)
         return 2
     try:
-        output = client.generate_response(
-            provider=provider,
-            api_key=api_key,
-            prompt=prompt,
-            model=model,
-            max_tokens=args.max_tokens,
-            reasoning_effort=args.reasoning_effort,
-            images=images,
-            request_id=args.request_id,
-            timeout_s=args.timeout_s,
-            max_retries=args.max_retries,
-            retry_backoff_s=args.retry_backoff_s,
-        )
+        output = client.generate_response(
+            provider=provider,
+            api_key=api_key,
+            prompt=prompt,
+            model=model,
+            max_tokens=args.max_tokens,
+            reasoning_effort=args.reasoning_effort,
+            images=images,
+            request_id=args.request_id,
+            timeout_s=args.timeout_s,
+            max_retries=args.max_retries,
+            retry_backoff_s=args.retry_backoff_s,
+        )
     except Exception as exc:  # pragma: no cover - CLI surface
         print(f"Error: {exc}", file=sys.stderr)
         return 2
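`_cmd_respond` forwards the parsed flags straight into `client.generate_response`. A rough programmatic equivalent of the `respond` subcommand follows; the import path and the `"gemini"` provider string are assumptions, since the diff shows only the Gemini backend:

```python
# Sketch of the programmatic equivalent of the respond subcommand.
# The import path is assumed; the keyword arguments mirror the call above.
from ccs_llmconnector import client  # assumed import path

output = client.generate_response(
    provider="gemini",            # assumed provider identifier
    api_key="YOUR_API_KEY",
    prompt="Summarize this release.",
    model="gemini-2.0-flash",     # illustrative model id
    max_tokens=32000,
    max_retries=2,
    retry_backoff_s=0.5,
)
print(output)
```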
@@ -6,38 +6,45 @@ import base64
 import mimetypes
 from pathlib import Path
 import logging
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Union
 from urllib.request import urlopen
 
 from google import genai
 from google.genai import types
 
-from .types import ImageInput, MessageSequence, normalize_messages
+from .types import (
+    EmbeddingVector,
+    ImageInput,
+    LLMResponse,
+    MessageSequence,
+    TokenUsage,
+    normalize_messages,
+)
 from .utils import clamp_retries, run_sync_in_thread, run_with_retries
 
-logger = logging.getLogger(__name__)
-
-
-_GEMINI_MIN_TIMEOUT_S = 10.0
-_GEMINI_MIN_TIMEOUT_MS = int(_GEMINI_MIN_TIMEOUT_S * 1000)
-
-
-def _normalize_gemini_timeout_ms(timeout_s: float) -> int:
-    """Convert a seconds timeout into the millisecond value expected by google-genai HttpOptions."""
-    # google-genai HttpOptions expects milliseconds, but our public API uses seconds.
-    effective_timeout_s = max(_GEMINI_MIN_TIMEOUT_S, timeout_s)
-    if effective_timeout_s != timeout_s:
-        logger.warning(
-            "Gemini timeout %ss is too short, clamping to %ss.",
-            timeout_s,
-            effective_timeout_s,
-        )
-    timeout_ms = int(effective_timeout_s * 1000)
-    return max(_GEMINI_MIN_TIMEOUT_MS, timeout_ms)
-
-
+logger = logging.getLogger(__name__)
+
+
+_GEMINI_MIN_TIMEOUT_S = 10.0
+_GEMINI_MIN_TIMEOUT_MS = int(_GEMINI_MIN_TIMEOUT_S * 1000)
+
+
+def _normalize_gemini_timeout_ms(timeout_s: float) -> int:
+    """Convert a seconds timeout into the millisecond value expected by google-genai HttpOptions."""
+    # google-genai HttpOptions expects milliseconds, but our public API uses seconds.
+    effective_timeout_s = max(_GEMINI_MIN_TIMEOUT_S, timeout_s)
+    if effective_timeout_s != timeout_s:
+        logger.warning(
+            "Gemini timeout %ss is too short, clamping to %ss.",
+            timeout_s,
+            effective_timeout_s,
+        )
+    timeout_ms = int(effective_timeout_s * 1000)
+    return max(_GEMINI_MIN_TIMEOUT_MS, timeout_ms)
+
+
 class GeminiClient:
-    """Convenience wrapper around the Google Gemini SDK."""
+    """Convenience wrapper around the Google Gemini SDK."""
 
     def generate_response(
         self,
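`_normalize_gemini_timeout_ms` both converts seconds to milliseconds and enforces a 10 second floor, so very short timeouts are raised (with a warning) rather than passed through. Worked examples of the rule as defined above:

```python
# Arithmetic implied by _normalize_gemini_timeout_ms (floor = 10.0 s):
assert int(max(10.0, 2.5) * 1000) == 10000   # 2.5 s request -> clamped to 10000 ms
assert int(max(10.0, 45.0) * 1000) == 45000  # 45 s request  -> 45000 ms
```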
@@ -120,13 +127,13 @@ class GeminiClient:
 
         retry_count = clamp_retries(max_retries)
 
-        def _build_client() -> genai.Client:
-            client_kwargs: dict[str, object] = {"api_key": api_key}
-            if timeout_s is not None:
-                client_kwargs["http_options"] = types.HttpOptions(
-                    timeout=_normalize_gemini_timeout_ms(timeout_s)
-                )
-            return genai.Client(**client_kwargs)
+        def _build_client() -> genai.Client:
+            client_kwargs: dict[str, object] = {"api_key": api_key}
+            if timeout_s is not None:
+                client_kwargs["http_options"] = types.HttpOptions(
+                    timeout=_normalize_gemini_timeout_ms(timeout_s)
+                )
+            return genai.Client(**client_kwargs)
 
         def _run_request() -> str:
             client = _build_client()
@@ -194,17 +201,153 @@ class GeminiClient:
             )
             return ""
 
-        return run_with_retries(
-            func=_run_request,
-            max_retries=retry_count,
-            retry_backoff_s=retry_backoff_s,
-            request_id=request_id,
-        )
-
-    async def async_generate_response(
-        self,
-        *,
-        api_key: str,
+        return run_with_retries(
+            func=_run_request,
+            max_retries=retry_count,
+            retry_backoff_s=retry_backoff_s,
+            request_id=request_id,
+        )
+
+    def generate_response_with_usage(
+        self,
+        *,
+        api_key: str,
+        prompt: Optional[str] = None,
+        model: str,
+        max_tokens: int = 32000,
+        reasoning_effort: Optional[str] = None,
+        images: Optional[Sequence[ImageInput]] = None,
+        messages: Optional[MessageSequence] = None,
+        request_id: Optional[str] = None,
+        timeout_s: Optional[float] = None,
+        max_retries: Optional[int] = None,
+        retry_backoff_s: float = 0.5,
+    ) -> LLMResponse:
+        if not api_key:
+            raise ValueError("api_key must be provided.")
+        if not prompt and not messages and not images:
+            raise ValueError("At least one of prompt, messages, or images must be provided.")
+        if not model:
+            raise ValueError("model must be provided.")
+
+        normalized_messages = normalize_messages(prompt=prompt, messages=messages)
+        contents: list[types.Content] = []
+        for message in normalized_messages:
+            parts: list[types.Part] = []
+            if message["content"]:
+                parts.append(types.Part.from_text(text=message["content"]))
+            contents.append(types.Content(role=message["role"], parts=parts))
+
+        if images:
+            image_parts = [self._to_image_part(image) for image in images]
+            target_index = next(
+                (
+                    index
+                    for index in range(len(contents) - 1, -1, -1)
+                    if contents[index].role == "user"
+                ),
+                None,
+            )
+            if target_index is None:
+                contents.append(types.Content(role="user", parts=image_parts))
+            else:
+                existing_parts = list(contents[target_index].parts or [])
+                existing_parts.extend(image_parts)
+                contents[target_index] = types.Content(role="user", parts=existing_parts)
+
+        if not contents or not any(content.parts for content in contents):
+            raise ValueError("No content provided for response generation.")
+
+        config = types.GenerateContentConfig(max_output_tokens=max_tokens)
+        _ = reasoning_effort  # accepted for API parity; not currently applied by the Gemini SDK.
+
+        retry_count = clamp_retries(max_retries)
+
+        def _build_client() -> genai.Client:
+            client_kwargs: dict[str, object] = {"api_key": api_key}
+            if timeout_s is not None:
+                client_kwargs["http_options"] = types.HttpOptions(
+                    timeout=_normalize_gemini_timeout_ms(timeout_s)
+                )
+            return genai.Client(**client_kwargs)
+
+        def _run_request() -> LLMResponse:
+            client = _build_client()
+            try:
+                try:
+                    response = client.models.generate_content(
+                        model=model,
+                        contents=contents,
+                        config=config,
+                    )
+                except Exception as exc:
+                    logger.exception(
+                        "Gemini generate_content failed: %s request_id=%s",
+                        exc,
+                        request_id,
+                    )
+                    raise
+            finally:
+                closer = getattr(client, "close", None)
+                if callable(closer):
+                    try:
+                        closer()
+                    except Exception:
+                        pass
+
+            usage = _extract_gemini_usage(response)
+
+            if response.text:
+                result_text = response.text
+                logger.info(
+                    "Gemini generate_content succeeded: model=%s images=%d text_len=%d request_id=%s",
+                    model,
+                    len(images or []),
+                    len(result_text or ""),
+                    request_id,
+                )
+                return LLMResponse(text=result_text, usage=usage, provider="gemini", model=model)
+
+            candidate_texts: list[str] = []
+            for candidate in getattr(response, "candidates", []) or []:
+                content_obj = getattr(candidate, "content", None)
+                if not content_obj:
+                    continue
+                for part in getattr(content_obj, "parts", []) or []:
+                    text = getattr(part, "text", None)
+                    if text:
+                        candidate_texts.append(text)
+
+            if candidate_texts:
+                result_text = "\n".join(candidate_texts)
+                logger.info(
+                    "Gemini generate_content succeeded (candidates): model=%s images=%d text_len=%d request_id=%s",
+                    model,
+                    len(images or []),
+                    len(result_text or ""),
+                    request_id,
+                )
+                return LLMResponse(text=result_text, usage=usage, provider="gemini", model=model)
+
+            logger.info(
+                "Gemini generate_content succeeded with no text: model=%s images=%d request_id=%s",
+                model,
+                len(images or []),
+                request_id,
+            )
+            return LLMResponse(text="", usage=usage, provider="gemini", model=model)
+
+        return run_with_retries(
+            func=_run_request,
+            max_retries=retry_count,
+            retry_backoff_s=retry_backoff_s,
+            request_id=request_id,
+        )
+
+    async def async_generate_response(
+        self,
+        *,
+        api_key: str,
         prompt: Optional[str] = None,
         model: str,
         max_tokens: int = 32000,
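The new `generate_response_with_usage` keeps the same request-building path as `generate_response` but returns an `LLMResponse` carrying token usage. A minimal usage sketch, assuming the module is importable as `ccs_llmconnector.gemini` and that `LLMResponse`/`TokenUsage` expose attribute access (both are assumptions; only the constructor calls appear in this diff):

```python
from ccs_llmconnector.gemini import GeminiClient  # assumed module path

client = GeminiClient()
resp = client.generate_response_with_usage(
    api_key="YOUR_API_KEY",
    model="gemini-2.0-flash",  # illustrative model id
    prompt="One sentence on what changed in 1.1.4.",
    max_tokens=256,
)
print(resp.text)
if resp.usage is not None:  # usage metadata may be absent from the SDK response
    print(resp.usage.input_tokens, resp.usage.output_tokens, resp.usage.total_tokens)
```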
@@ -216,21 +359,52 @@ class GeminiClient:
         max_retries: Optional[int] = None,
         retry_backoff_s: float = 0.5,
     ) -> str:
-        return await run_sync_in_thread(
-            lambda: self.generate_response(
-                api_key=api_key,
-                prompt=prompt,
-                model=model,
-                max_tokens=max_tokens,
-                reasoning_effort=reasoning_effort,
-                images=images,
-                messages=messages,
-                request_id=request_id,
-                timeout_s=timeout_s,
-                max_retries=max_retries,
-                retry_backoff_s=retry_backoff_s,
-            )
-        )
+        return await run_sync_in_thread(
+            lambda: self.generate_response(
+                api_key=api_key,
+                prompt=prompt,
+                model=model,
+                max_tokens=max_tokens,
+                reasoning_effort=reasoning_effort,
+                images=images,
+                messages=messages,
+                request_id=request_id,
+                timeout_s=timeout_s,
+                max_retries=max_retries,
+                retry_backoff_s=retry_backoff_s,
+            )
+        )
+
+    async def async_generate_response_with_usage(
+        self,
+        *,
+        api_key: str,
+        prompt: Optional[str] = None,
+        model: str,
+        max_tokens: int = 32000,
+        reasoning_effort: Optional[str] = None,
+        images: Optional[Sequence[ImageInput]] = None,
+        messages: Optional[MessageSequence] = None,
+        request_id: Optional[str] = None,
+        timeout_s: Optional[float] = None,
+        max_retries: Optional[int] = None,
+        retry_backoff_s: float = 0.5,
+    ) -> LLMResponse:
+        return await run_sync_in_thread(
+            lambda: self.generate_response_with_usage(
+                api_key=api_key,
+                prompt=prompt,
+                model=model,
+                max_tokens=max_tokens,
+                reasoning_effort=reasoning_effort,
+                images=images,
+                messages=messages,
+                request_id=request_id,
+                timeout_s=timeout_s,
+                max_retries=max_retries,
+                retry_backoff_s=retry_backoff_s,
+            )
+        )
 
     def generate_image(
         self,
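The async variants are thin wrappers that push the blocking call onto a worker thread via `run_sync_in_thread`, so they can be awaited directly from an event loop. A minimal sketch under the same assumed import path:

```python
import asyncio

from ccs_llmconnector.gemini import GeminiClient  # assumed module path

async def main() -> None:
    client = GeminiClient()
    resp = await client.async_generate_response_with_usage(
        api_key="YOUR_API_KEY",
        model="gemini-2.0-flash",  # illustrative model id
        prompt="Hello from a thread-backed coroutine.",
    )
    print(resp.text)

asyncio.run(main())
```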
@@ -284,13 +458,13 @@ class GeminiClient:
 
         retry_count = clamp_retries(max_retries)
 
-        def _build_client() -> genai.Client:
-            client_kwargs: dict[str, object] = {"api_key": api_key}
-            if timeout_s is not None:
-                client_kwargs["http_options"] = types.HttpOptions(
-                    timeout=_normalize_gemini_timeout_ms(timeout_s)
-                )
-            return genai.Client(**client_kwargs)
+        def _build_client() -> genai.Client:
+            client_kwargs: dict[str, object] = {"api_key": api_key}
+            if timeout_s is not None:
+                client_kwargs["http_options"] = types.HttpOptions(
+                    timeout=_normalize_gemini_timeout_ms(timeout_s)
+                )
+            return genai.Client(**client_kwargs)
 
         def _run_request() -> bytes:
             client = _build_client()
@@ -376,13 +550,13 @@ class GeminiClient:
 
         retry_count = clamp_retries(max_retries)
 
-        def _build_client() -> genai.Client:
-            client_kwargs: dict[str, object] = {"api_key": api_key}
-            if timeout_s is not None:
-                client_kwargs["http_options"] = types.HttpOptions(
-                    timeout=_normalize_gemini_timeout_ms(timeout_s)
-                )
-            return genai.Client(**client_kwargs)
+        def _build_client() -> genai.Client:
+            client_kwargs: dict[str, object] = {"api_key": api_key}
+            if timeout_s is not None:
+                client_kwargs["http_options"] = types.HttpOptions(
+                    timeout=_normalize_gemini_timeout_ms(timeout_s)
+                )
+            return genai.Client(**client_kwargs)
 
         def _run_request() -> list[dict[str, Optional[str]]]:
             models: list[dict[str, Optional[str]]] = []
@@ -435,7 +609,7 @@ class GeminiClient:
             request_id=request_id,
         )
 
-    async def async_list_models(
+    async def async_list_models(
         self,
         *,
         api_key: str,
@@ -444,15 +618,126 @@ class GeminiClient:
         max_retries: Optional[int] = None,
         retry_backoff_s: float = 0.5,
     ) -> list[dict[str, Optional[str]]]:
-        return await run_sync_in_thread(
-            lambda: self.list_models(
-                api_key=api_key,
-                request_id=request_id,
-                timeout_s=timeout_s,
-                max_retries=max_retries,
-                retry_backoff_s=retry_backoff_s,
-            )
-        )
+        return await run_sync_in_thread(
+            lambda: self.list_models(
+                api_key=api_key,
+                request_id=request_id,
+                timeout_s=timeout_s,
+                max_retries=max_retries,
+                retry_backoff_s=retry_backoff_s,
+            )
+        )
+
+    def embed_content(
+        self,
+        *,
+        api_key: str,
+        model: str,
+        contents: Union[str, Sequence[str]],
+        task_type: Optional[str] = None,
+        output_dimensionality: Optional[int] = None,
+        request_id: Optional[str] = None,
+        timeout_s: Optional[float] = None,
+        max_retries: Optional[int] = None,
+        retry_backoff_s: float = 0.5,
+    ) -> list[EmbeddingVector]:
+        if not api_key:
+            raise ValueError("api_key must be provided.")
+        if not model:
+            raise ValueError("model must be provided.")
+
+        if isinstance(contents, str):
+            payload: Union[str, list[str]] = contents
+        else:
+            payload = list(contents)
+        if not payload:
+            raise ValueError("contents must not be empty.")
+
+        retry_count = clamp_retries(max_retries)
+
+        def _build_client() -> genai.Client:
+            client_kwargs: dict[str, object] = {"api_key": api_key}
+            if timeout_s is not None:
+                http_options = getattr(types, "HttpOptions", None)
+                if http_options is not None:
+                    try:
+                        client_kwargs["http_options"] = http_options(timeout=timeout_s)
+                    except Exception:
+                        logger.debug("Gemini HttpOptions timeout not applied.", exc_info=True)
+            return genai.Client(**client_kwargs)
+
+        config_kwargs: dict[str, object] = {}
+        if task_type is not None:
+            config_kwargs["task_type"] = task_type
+        if output_dimensionality is not None:
+            config_kwargs["output_dimensionality"] = output_dimensionality
+        config = types.EmbedContentConfig(**config_kwargs) if config_kwargs else None
+
+        def _run_request() -> list[EmbeddingVector]:
+            client = _build_client()
+            try:
+                result = client.models.embed_content(
+                    model=model,
+                    contents=payload,
+                    config=config,
+                )
+                embeddings = getattr(result, "embeddings", None)
+                if embeddings is None:
+                    raise ValueError("Gemini embeddings response missing embeddings field.")
+                vectors: list[EmbeddingVector] = []
+                for embedding in embeddings:
+                    values = getattr(embedding, "values", None)
+                    if values is None:
+                        raise ValueError("Gemini embedding missing values field.")
+                    vectors.append(list(values))
+                return vectors
+            finally:
+                closer = getattr(client, "close", None)
+                if callable(closer):
+                    try:
+                        closer()
+                    except Exception:
+                        pass
+
+        vectors = run_with_retries(
+            func=_run_request,
+            max_retries=retry_count,
+            retry_backoff_s=retry_backoff_s,
+            request_id=request_id,
+        )
+        logger.info(
+            "Gemini embed_content succeeded: count=%d request_id=%s",
+            len(vectors),
+            request_id,
+        )
+        return vectors
+
+    async def async_embed_content(
+        self,
+        *,
+        api_key: str,
+        model: str,
+        contents: Union[str, Sequence[str]],
+        task_type: Optional[str] = None,
+        output_dimensionality: Optional[int] = None,
+        request_id: Optional[str] = None,
+        timeout_s: Optional[float] = None,
+        max_retries: Optional[int] = None,
+        retry_backoff_s: float = 0.5,
+    ) -> list[EmbeddingVector]:
+        return await run_sync_in_thread(
+            lambda: self.embed_content(
+                api_key=api_key,
+                model=model,
+                contents=contents,
+                task_type=task_type,
+                output_dimensionality=output_dimensionality,
+                request_id=request_id,
+                timeout_s=timeout_s,
+                max_retries=max_retries,
+                retry_backoff_s=retry_backoff_s,
+            )
+        )
 
     @staticmethod
     def _to_image_part(image: ImageInput) -> types.Part:
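`embed_content` is the other new surface in 1.1.4: it accepts a single string or a sequence of strings and returns one vector per input. A usage sketch; the model id and task type are illustrative, and the import path is assumed:

```python
from ccs_llmconnector.gemini import GeminiClient  # assumed module path

client = GeminiClient()
vectors = client.embed_content(
    api_key="YOUR_API_KEY",
    model="text-embedding-004",      # illustrative embedding model id
    contents=["first passage", "second passage"],
    task_type="RETRIEVAL_DOCUMENT",  # optional; forwarded to EmbedContentConfig
    output_dimensionality=256,       # optional; truncates vectors if supported
)
print(len(vectors), len(vectors[0]))  # one vector per input, e.g. 2 x 256
```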
@@ -469,18 +754,18 @@ class GeminiClient:
             return _part_from_path(Path(image))
 
 
-def _part_from_path(path: Path) -> types.Part:
-    """Create an image part from a local filesystem path."""
-    # Ensure common audio types are recognized across platforms (used for transcription as well).
-    mimetypes.add_type("audio/mp4", ".m4a")
-    mimetypes.add_type("audio/mpeg", ".mp3")
-    mimetypes.add_type("audio/wav", ".wav")
-    mimetypes.add_type("audio/aac", ".aac")
-
-    expanded = path.expanduser()
-    data = expanded.read_bytes()
-    mime_type = mimetypes.guess_type(expanded.name)[0] or "application/octet-stream"
-    return types.Part.from_bytes(data=data, mime_type=mime_type)
+def _part_from_path(path: Path) -> types.Part:
+    """Create an image part from a local filesystem path."""
+    # Ensure common audio types are recognized across platforms (used for transcription as well).
+    mimetypes.add_type("audio/mp4", ".m4a")
+    mimetypes.add_type("audio/mpeg", ".mp3")
+    mimetypes.add_type("audio/wav", ".wav")
+    mimetypes.add_type("audio/aac", ".aac")
+
+    expanded = path.expanduser()
+    data = expanded.read_bytes()
+    mime_type = mimetypes.guess_type(expanded.name)[0] or "application/octet-stream"
+    return types.Part.from_bytes(data=data, mime_type=mime_type)
 
 
 def _part_from_url(url: str) -> types.Part:
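The explicit `mimetypes.add_type` registrations guard against platforms whose MIME tables lack common audio extensions, which would otherwise make `_part_from_path` fall back to `application/octet-stream`. A quick self-contained check:

```python
import mimetypes

# On some platforms guess_type("clip.m4a") returns (None, None) until the
# type is registered, which is exactly what _part_from_path defends against.
mimetypes.add_type("audio/mp4", ".m4a")
print(mimetypes.guess_type("clip.m4a")[0])  # audio/mp4
```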
@@ -495,7 +780,7 @@ def _part_from_url(url: str) -> types.Part:
     return types.Part.from_bytes(data=data, mime_type=mime_type)
 
 
-def _part_from_data_url(data_url: str) -> types.Part:
+def _part_from_data_url(data_url: str) -> types.Part:
     """Create an image part from a data URL."""
     header, encoded = data_url.split(",", 1)
     metadata = header[len("data:") :]
@@ -510,5 +795,35 @@ def _part_from_data_url(data_url: str) -> types.Part:
         data = base64.b64decode(encoded)
     else:
         data = encoded.encode("utf-8")
-
-    return types.Part.from_bytes(data=data, mime_type=mime_type or "application/octet-stream")
+
+    return types.Part.from_bytes(data=data, mime_type=mime_type or "application/octet-stream")
+
+
+def _extract_gemini_usage(response: object) -> TokenUsage | None:
+    usage_obj = getattr(response, "usage_metadata", None)
+    if usage_obj is None:
+        usage_obj = getattr(response, "usage", None)
+    if usage_obj is None:
+        return None
+
+    input_tokens = getattr(usage_obj, "prompt_token_count", None)
+    output_tokens = getattr(usage_obj, "candidates_token_count", None)
+    total_tokens = getattr(usage_obj, "total_token_count", None)
+
+    if input_tokens is None:
+        input_tokens = getattr(usage_obj, "input_tokens", None)
+    if output_tokens is None:
+        output_tokens = getattr(usage_obj, "output_tokens", None)
+    if total_tokens is None:
+        total_tokens = getattr(usage_obj, "total_tokens", None)
+
+    if isinstance(usage_obj, dict):
+        input_tokens = usage_obj.get("prompt_token_count", usage_obj.get("input_tokens"))
+        output_tokens = usage_obj.get("candidates_token_count", usage_obj.get("output_tokens"))
+        total_tokens = usage_obj.get("total_token_count", usage_obj.get("total_tokens"))
+
+    return TokenUsage(
+        input_tokens=int(input_tokens) if isinstance(input_tokens, int) else None,
+        output_tokens=int(output_tokens) if isinstance(output_tokens, int) else None,
+        total_tokens=int(total_tokens) if isinstance(total_tokens, int) else None,
+    )
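`_extract_gemini_usage` reads `usage_metadata` first, falls back to `usage`, tolerates both attribute-style objects and plain dicts, and maps the SDK's `prompt_token_count`/`candidates_token_count`/`total_token_count` names onto `TokenUsage`. A small demonstration with a stand-in object; the import path is assumed, and `SimpleNamespace` merely mimics the google-genai response shape:

```python
from types import SimpleNamespace

from ccs_llmconnector.gemini import _extract_gemini_usage  # assumed module path

# Stand-in mirroring google-genai's usage_metadata attribute names.
fake_response = SimpleNamespace(
    usage_metadata=SimpleNamespace(
        prompt_token_count=12,
        candidates_token_count=34,
        total_token_count=46,
    )
)
usage = _extract_gemini_usage(fake_response)
# Expected: TokenUsage(input_tokens=12, output_tokens=34, total_tokens=46)
print(usage)
```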