model-library 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (52)
  1. model_library/base/base.py +141 -62
  2. model_library/base/delegate_only.py +77 -10
  3. model_library/base/output.py +43 -0
  4. model_library/base/utils.py +35 -0
  5. model_library/config/alibaba_models.yaml +49 -57
  6. model_library/config/all_models.json +353 -120
  7. model_library/config/anthropic_models.yaml +2 -1
  8. model_library/config/kimi_models.yaml +30 -3
  9. model_library/config/mistral_models.yaml +2 -0
  10. model_library/config/openai_models.yaml +15 -23
  11. model_library/config/together_models.yaml +2 -0
  12. model_library/config/xiaomi_models.yaml +43 -0
  13. model_library/config/zai_models.yaml +27 -3
  14. model_library/exceptions.py +3 -77
  15. model_library/providers/ai21labs.py +12 -8
  16. model_library/providers/alibaba.py +17 -8
  17. model_library/providers/amazon.py +49 -16
  18. model_library/providers/anthropic.py +128 -48
  19. model_library/providers/azure.py +22 -10
  20. model_library/providers/cohere.py +7 -7
  21. model_library/providers/deepseek.py +8 -8
  22. model_library/providers/fireworks.py +7 -8
  23. model_library/providers/google/batch.py +14 -10
  24. model_library/providers/google/google.py +57 -30
  25. model_library/providers/inception.py +7 -7
  26. model_library/providers/kimi.py +18 -8
  27. model_library/providers/minimax.py +15 -17
  28. model_library/providers/mistral.py +20 -8
  29. model_library/providers/openai.py +99 -22
  30. model_library/providers/openrouter.py +34 -0
  31. model_library/providers/perplexity.py +7 -7
  32. model_library/providers/together.py +7 -8
  33. model_library/providers/vals.py +12 -6
  34. model_library/providers/vercel.py +34 -0
  35. model_library/providers/xai.py +47 -42
  36. model_library/providers/xiaomi.py +34 -0
  37. model_library/providers/zai.py +38 -8
  38. model_library/register_models.py +5 -0
  39. model_library/registry_utils.py +48 -17
  40. model_library/retriers/__init__.py +0 -0
  41. model_library/retriers/backoff.py +73 -0
  42. model_library/retriers/base.py +225 -0
  43. model_library/retriers/token.py +427 -0
  44. model_library/retriers/utils.py +11 -0
  45. model_library/settings.py +1 -1
  46. model_library/utils.py +17 -7
  47. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/METADATA +2 -1
  48. model_library-0.1.9.dist-info/RECORD +73 -0
  49. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/WHEEL +1 -1
  50. model_library-0.1.7.dist-info/RECORD +0 -64
  51. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/licenses/LICENSE +0 -0
  52. {model_library-0.1.7.dist-info → model_library-0.1.9.dist-info}/top_level.txt +0 -0
--- a/model_library/providers/anthropic.py
+++ b/model_library/providers/anthropic.py
@@ -1,16 +1,20 @@
+import datetime
 import io
 import logging
+import time
 from typing import Any, Literal, Sequence, cast
 
-from anthropic import AsyncAnthropic
+from anthropic import APIConnectionError, AsyncAnthropic
 from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
 from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
     BatchResult,
+    DelegateConfig,
     FileBase,
     FileInput,
     FileWithBase64,
@@ -19,9 +23,11 @@ from model_library.base import (
     InputItem,
     LLMBatchMixin,
     LLMConfig,
+    ProviderConfig,
     QueryResult,
     QueryResultCost,
     QueryResultMetadata,
+    RateLimit,
     RawInput,
     RawResponse,
     TextInput,
@@ -31,6 +37,7 @@ from model_library.base import (
     ToolResult,
 )
 from model_library.exceptions import (
+    ImmediateRetryException,
     MaxOutputTokensExceededError,
     NoMatchingToolCallError,
 )
@@ -38,11 +45,15 @@ from model_library.model_utils import get_default_budget_tokens
 from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
 from model_library.utils import (
-    create_openai_client_with_defaults,
-    default_httpx_client,
+    create_anthropic_client_with_defaults,
 )
 
 
+class AnthropicConfig(ProviderConfig):
+    supports_compute_effort: bool = False
+    supports_auto_thinking: bool = False
+
+
 class AnthropicBatchMixin(LLMBatchMixin):
     """Batch processing support for Anthropic's Message Batches API."""
 
@@ -246,21 +257,27 @@
 
 @register_provider("anthropic")
 class AnthropicModel(LLM):
-    _client: AsyncAnthropic | None = None
+    provider_config = AnthropicConfig()
+
+    def _get_default_api_key(self) -> str:
+        if self.delegate_config:
+            return self.delegate_config.api_key.get_secret_value()
+        return model_library_settings.ANTHROPIC_API_KEY
 
     @override
-    def get_client(self) -> AsyncAnthropic:
-        if self._delegate_client:
-            return self._delegate_client
-        if not AnthropicModel._client:
+    def get_client(self, api_key: str | None = None) -> AsyncAnthropic:
+        if not self.has_client():
+            assert api_key
             headers: dict[str, str] = {}
-            AnthropicModel._client = AsyncAnthropic(
-                api_key=model_library_settings.ANTHROPIC_API_KEY,
-                http_client=default_httpx_client(),
-                max_retries=1,
+            client = create_anthropic_client_with_defaults(
+                base_url=self.delegate_config.base_url
+                if self.delegate_config
+                else None,
+                api_key=api_key,
                 default_headers=headers,
             )
-        return AnthropicModel._client
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
@@ -268,33 +285,32 @@ class AnthropicModel(LLM):
         provider: str = "anthropic",
         *,
         config: LLMConfig | None = None,
-        custom_client: AsyncAnthropic | None = None,
+        delegate_config: DelegateConfig | None = None,
     ):
-        super().__init__(model_name, provider, config=config)
+        self.delegate_config = delegate_config
 
-        # allow custom client to act as delegate (native)
-        self._delegate_client: AsyncAnthropic | None = custom_client
+        super().__init__(model_name, provider, config=config)
 
         # https://docs.anthropic.com/en/api/openai-sdk
         self.delegate = (
             None
-            if self.native or custom_client
+            if self.native or self.delegate_config
             else OpenAIModel(
                 model_name=self.model_name,
-                provider=provider,
+                provider=self.provider,
                 config=config,
-                custom_client=create_openai_client_with_defaults(
-                    api_key=model_library_settings.ANTHROPIC_API_KEY,
+                use_completions=True,
+                delegate_config=DelegateConfig(
                     base_url="https://api.anthropic.com/v1/",
+                    api_key=SecretStr(model_library_settings.ANTHROPIC_API_KEY),
                 ),
-                use_completions=True,
             )
         )
 
         # Initialize batch support if enabled
        # Disable batch when using custom_client (similar to OpenAI)
         self.supports_batch: bool = (
-            self.supports_batch and self.native and not custom_client
+            self.supports_batch and self.native and not self.delegate_config
        )
         self.batch: LLMBatchMixin | None = (
             AnthropicBatchMixin(self) if self.supports_batch else None
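
The constructor above replaces 0.1.7's custom_client parameter with delegate_config. A minimal usage sketch of both construction paths, assuming ANTHROPIC_API_KEY is configured; the model name and proxy endpoint below are illustrative, not taken from this diff:

    from pydantic import SecretStr

    from model_library.base import DelegateConfig
    from model_library.providers.anthropic import AnthropicModel

    # Native path: the client is created lazily via _get_default_api_key().
    model = AnthropicModel("claude-sonnet-4-5")

    # Delegate path: route requests through an OpenAI-compatible endpoint.
    # Note that batch support is disabled whenever delegate_config is set.
    proxied = AnthropicModel(
        "claude-sonnet-4-5",
        delegate_config=DelegateConfig(
            base_url="https://proxy.example.com/v1/",  # hypothetical endpoint
            api_key=SecretStr("sk-ant-..."),           # placeholder secret
        ),
    )
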
@@ -520,7 +536,6 @@
         **kwargs: object,
     ) -> dict[str, Any]:
         body: dict[str, Any] = {
-            "max_tokens": self.max_tokens,
             "model": self.model_name,
             "messages": await self.parse_input(input),
         }
@@ -534,14 +549,28 @@
             }
         ]
 
+        if not self.max_tokens:
+            raise Exception("Anthropic models require a max_tokens parameter")
+
+        body["max_tokens"] = self.max_tokens
+
         if self.reasoning:
-            budget_tokens = kwargs.pop(
-                "budget_tokens", get_default_budget_tokens(self.max_tokens)
-            )
-            body["thinking"] = {
-                "type": "enabled",
-                "budget_tokens": budget_tokens,
-            }
+            if self.provider_config.supports_auto_thinking:
+                body["thinking"] = {"type": "auto"}
+            else:
+                budget_tokens = kwargs.pop(
+                    "budget_tokens", get_default_budget_tokens(self.max_tokens)
+                )
+                body["thinking"] = {
+                    "type": "enabled",
+                    "budget_tokens": budget_tokens,
+                }
+
+        # effort controls compute allocation for text, tool calls, and thinking. Opus-4.5+
+        # use instead of reasoning_effort with auto_thinking
+        if self.provider_config.supports_compute_effort and self.compute_effort:
+            # default is "high"
+            body["output_config"] = {"effort": self.compute_effort}
 
         # Thinking models don't support temperature: https://docs.claude.com/en/docs/build-with-claude/extended-thinking#feature-compatibility
         if self.supports_temperature and not self.reasoning:
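
A standalone sketch of the thinking branch above. The auto/enabled payload shapes come from this diff; the half-of-max fallback is only a stand-in for get_default_budget_tokens, whose actual formula is not shown here:

    def thinking_params(
        supports_auto_thinking: bool,
        max_tokens: int,
        budget_tokens: int | None = None,
    ) -> dict[str, object]:
        if supports_auto_thinking:
            # Newer models can let the API choose the thinking budget.
            return {"type": "auto"}
        # Otherwise an explicit budget is required; this default is a stand-in
        # for get_default_budget_tokens(max_tokens).
        return {"type": "enabled", "budget_tokens": budget_tokens or max_tokens // 2}

    assert thinking_params(True, 8192) == {"type": "auto"}
    assert thinking_params(False, 8192) == {"type": "enabled", "budget_tokens": 4096}
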
@@ -577,7 +606,7 @@
         client = self.get_client()
 
         # only send betas for the official Anthropic endpoint
-        is_anthropic_endpoint = self._delegate_client is None
+        is_anthropic_endpoint = self.delegate_config is None
         if not is_anthropic_endpoint:
             client_base_url = getattr(client, "_base_url", None) or getattr(
                 client, "base_url", None
@@ -587,16 +616,29 @@
 
         stream_kwargs = {**body}
         if is_anthropic_endpoint:
-            betas = ["files-api-2025-04-14", "interleaved-thinking-2025-05-14"]
+            betas = ["files-api-2025-04-14"]
+            if self.provider_config.supports_auto_thinking:
+                betas.extend(
+                    [
+                        "auto-thinking-2026-01-12",
+                        "effort-2025-11-24",
+                        "max-effort-2026-01-24",
+                    ]
+                )
+            else:
+                betas.extend(["interleaved-thinking-2025-05-14"])
             if "sonnet-4-5" in self.model_name:
                 betas.append("context-1m-2025-08-07")
             stream_kwargs["betas"] = betas
 
-        async with client.beta.messages.stream(
-            **stream_kwargs,
-        ) as stream:  # pyright: ignore[reportAny]
-            message = await stream.get_final_message()
-            self.logger.info(f"Anthropic Response finished: {message.id}")
+        try:
+            async with client.beta.messages.stream(
+                **stream_kwargs,
+            ) as stream:  # pyright: ignore[reportAny]
+                message = await stream.get_final_message()
+                self.logger.info(f"Anthropic Response finished: {message.id}")
+        except APIConnectionError:
+            raise ImmediateRetryException("Failed to connect to Anthropic")
 
         text = ""
         reasoning = ""
@@ -632,6 +674,38 @@
             history=[*input, RawResponse(response=message)],
         )
 
+    @override
+    async def get_rate_limit(self) -> RateLimit:
+        response = await self.get_client().messages.with_raw_response.create(
+            max_tokens=1,
+            messages=[
+                {
+                    "role": "user",
+                    "content": "Ping",
+                }
+            ],
+            model=self.model_name,
+        )
+        headers = response.headers
+
+        server_time_str = headers.get("date")
+        if server_time_str:
+            server_time = datetime.datetime.strptime(
+                server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
+            ).replace(tzinfo=datetime.timezone.utc)
+            timestamp = server_time.timestamp()
+        else:
+            timestamp = time.time()
+
+        return RateLimit(
+            unix_timestamp=timestamp,
+            raw=headers,
+            request_limit=int(headers["anthropic-ratelimit-requests-limit"]),
+            request_remaining=int(headers["anthropic-ratelimit-requests-remaining"]),
+            token_limit=int(response.headers["anthropic-ratelimit-tokens-limit"]),
+            token_remaining=int(headers["anthropic-ratelimit-tokens-remaining"]),
+        )
+
     @override
     async def count_tokens(
         self,
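
The header parsing inside get_rate_limit runs fine in isolation against a captured header dict; the values below are made up for illustration:

    import datetime
    import time

    headers = {
        "date": "Mon, 12 Jan 2026 10:00:00 GMT",
        "anthropic-ratelimit-requests-limit": "4000",
        "anthropic-ratelimit-requests-remaining": "3999",
        "anthropic-ratelimit-tokens-limit": "400000",
        "anthropic-ratelimit-tokens-remaining": "399000",
    }

    server_time_str = headers.get("date")
    if server_time_str:
        # Prefer the server clock so the timestamp lines up with the quota window.
        timestamp = (
            datetime.datetime.strptime(server_time_str, "%a, %d %b %Y %H:%M:%S GMT")
            .replace(tzinfo=datetime.timezone.utc)
            .timestamp()
        )
    else:
        timestamp = time.time()

    print(timestamp, int(headers["anthropic-ratelimit-requests-remaining"]))
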
@@ -645,20 +719,26 @@
         Count the number of tokens using Anthropic's native token counting API.
         https://docs.anthropic.com/en/docs/build-with-claude/token-counting
         """
-        input = [*history, *input]
-        if not input:
-            return 0
+        try:
+            input = [*history, *input]
+            if not input:
+                return 0
 
-        body = await self.build_body(input, tools=tools, **kwargs)
+            body = await self.build_body(input, tools=tools, **kwargs)
 
-        # Remove fields not supported by count_tokens endpoint
-        body.pop("max_tokens", None)
-        body.pop("temperature", None)
+            # Remove fields not supported by count_tokens endpoint
+            body.pop("max_tokens", None)
+            body.pop("temperature", None)
 
-        client = self.get_client()
-        response = await client.messages.count_tokens(**body)
+            client = self.get_client()
+            response = await client.messages.count_tokens(**body)
 
-        return response.input_tokens
+            return response.input_tokens
+        except Exception as e:
+            self.logger.error(f"Error counting tokens: {e}")
+            return await super().count_tokens(
+                input, history=history, tools=tools, **kwargs
+            )
 
     @override
     async def _calculate_cost(
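
The new try/except wrapper is a simple fall-back-on-failure shape: attempt the provider's exact counter, log, and defer to the base-class approximation. A generic sketch of that pattern (the function names here are illustrative, not library API):

    import logging

    logger = logging.getLogger("model_library.example")

    async def count_with_fallback(native_counter, approximate_counter, *args, **kwargs):
        """Try the provider's exact token counter; fall back to an approximation."""
        try:
            return await native_counter(*args, **kwargs)
        except Exception as e:  # network failures, unsupported body fields, etc.
            logger.error(f"Error counting tokens: {e}")
            return await approximate_counter(*args, **kwargs)
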
--- a/model_library/providers/azure.py
+++ b/model_library/providers/azure.py
@@ -1,3 +1,4 @@
+import json
 from typing import Literal
 
 from openai.lib.azure import AsyncAzureOpenAI
@@ -14,21 +15,32 @@ from model_library.utils import default_httpx_client
 
 @register_provider("azure")
 class AzureOpenAIModel(OpenAIModel):
-    _azure_client: AsyncAzureOpenAI | None = None
-
     @override
-    def get_client(self) -> AsyncAzureOpenAI:
-        if not AzureOpenAIModel._azure_client:
-            AzureOpenAIModel._azure_client = AsyncAzureOpenAI(
-                api_key=model_library_settings.AZURE_API_KEY,
-                azure_endpoint=model_library_settings.AZURE_ENDPOINT,
-                api_version=model_library_settings.get(
+    def _get_default_api_key(self) -> str:
+        return json.dumps(
+            {
+                "AZURE_API_KEY": model_library_settings.AZURE_API_KEY,
+                "AZURE_ENDPOINT": model_library_settings.AZURE_ENDPOINT,
+                "AZURE_API_VERSION": model_library_settings.get(
                     "AZURE_API_VERSION", "2025-04-01-preview"
                 ),
+            }
+        )
+
+    @override
+    def get_client(self, api_key: str | None = None) -> AsyncAzureOpenAI:
+        if not self.has_client():
+            assert api_key
+            creds = json.loads(api_key)
+            client = AsyncAzureOpenAI(
+                api_key=creds["AZURE_API_KEY"],
+                azure_endpoint=creds["AZURE_ENDPOINT"],
+                api_version=creds["AZURE_API_VERSION"],
                 http_client=default_httpx_client(),
-                max_retries=1,
+                max_retries=3,
            )
-        return AzureOpenAIModel._azure_client
+            self.assign_client(client)
+        return super(OpenAIModel, self).get_client(api_key)
 
     def __init__(
         self,
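
Azure needs three settings rather than a single key, so this override packs them into the one api_key slot as JSON and unpacks them again in get_client. A round-trip sketch with illustrative values:

    import json

    packed = json.dumps(
        {
            "AZURE_API_KEY": "azure-key",                          # placeholder
            "AZURE_ENDPOINT": "https://example.openai.azure.com",  # placeholder
            "AZURE_API_VERSION": "2025-04-01-preview",
        }
    )

    creds = json.loads(packed)
    assert creds["AZURE_API_VERSION"] == "2025-04-01-preview"
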
--- a/model_library/providers/cohere.py
+++ b/model_library/providers/cohere.py
@@ -1,13 +1,14 @@
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("cohere")
@@ -22,13 +23,12 @@ class CohereModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://docs.cohere.com/docs/compatibility-api
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.COHERE_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.cohere.ai/compatibility/v1",
+                api_key=SecretStr(model_library_settings.COHERE_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
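
The same init_delegate shape repeats for DeepSeek and Fireworks below; only the base_url and the settings key change. A sketch of the shared call, assuming model_library 0.1.9 is installed; the helper function and endpoint are hypothetical:

    from pydantic import SecretStr

    from model_library.base import DelegateConfig

    def openai_compatible_delegate(base_url: str, api_key: str) -> dict[str, object]:
        """Build the keyword arguments these init_delegate call sites share."""
        return {
            "delegate_config": DelegateConfig(
                base_url=base_url,
                api_key=SecretStr(api_key),
            ),
            "use_completions": True,
            "delegate_provider": "openai",
        }

    kwargs = openai_compatible_delegate("https://api.example.com/v1", "...")
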
--- a/model_library/providers/deepseek.py
+++ b/model_library/providers/deepseek.py
@@ -5,14 +5,15 @@ https://cdn.deepseek.com/policies/en-US/deepseek-privacy-policy.html
 
 from typing import Literal
 
+from pydantic import SecretStr
+
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
     DelegateOnly,
     LLMConfig,
 )
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 @register_provider("deepseek")
@@ -27,13 +28,12 @@ class DeepSeekModel(DelegateOnly):
         super().__init__(model_name, provider, config=config)
 
         # https://api-docs.deepseek.com/
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.DEEPSEEK_API_KEY,
-                base_url="https://api.deepseek.com",
+            delegate_config=DelegateConfig(
+                base_url="https://api.deepseek.com/v1",
+                api_key=SecretStr(model_library_settings.DEEPSEEK_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
--- a/model_library/providers/fireworks.py
+++ b/model_library/providers/fireworks.py
@@ -1,18 +1,18 @@
 from typing import Literal
 
+from pydantic import SecretStr
 from typing_extensions import override
 
 from model_library import model_library_settings
 from model_library.base import (
+    DelegateConfig,
+    DelegateOnly,
     LLMConfig,
     ProviderConfig,
     QueryResultCost,
     QueryResultMetadata,
 )
-from model_library.base.delegate_only import DelegateOnly
-from model_library.providers.openai import OpenAIModel
 from model_library.register_models import register_provider
-from model_library.utils import create_openai_client_with_defaults
 
 
 class FireworksConfig(ProviderConfig):
@@ -38,15 +38,14 @@ class FireworksModel(DelegateOnly):
         self.model_name = "accounts/rayan-936e28/deployedModels/" + self.model_name
 
         # https://docs.fireworks.ai/tools-sdks/openai-compatibility
-        self.delegate = OpenAIModel(
-            model_name=self.model_name,
-            provider=self.provider,
+        self.init_delegate(
             config=config,
-            custom_client=create_openai_client_with_defaults(
-                api_key=model_library_settings.FIREWORKS_API_KEY,
+            delegate_config=DelegateConfig(
                 base_url="https://api.fireworks.ai/inference/v1",
+                api_key=SecretStr(model_library_settings.FIREWORKS_API_KEY),
             ),
             use_completions=True,
+            delegate_provider="openai",
         )
 
     @override
--- a/model_library/providers/google/batch.py
+++ b/model_library/providers/google/batch.py
@@ -24,16 +24,19 @@ from google.genai.types import (
 )
 
 
-def extract_text_from_json_response(response: dict[str, Any]) -> str:
+def extract_text_from_json_response(response: dict[str, Any]) -> tuple[str, str]:
     """Extract concatenated non-thought text from a JSON response structure."""
     # TODO: fix the typing we always ignore
     text = ""
+    reasoning = ""
     for candidate in response.get("candidates", []) or []:  # type: ignore
         content = (candidate or {}).get("content") or {}  # type: ignore
         for part in content.get("parts", []) or []:  # type: ignore
-            if not part.get("thought", False):  # type: ignore
+            if part.get("thought", False):  # type: ignore
+                reasoning += part.get("text", "")  # type: ignore
+            else:
                 text += part.get("text", "")  # type: ignore
-    return text  # type: ignore
+    return text, reasoning  # type: ignore
 
 
 def parse_predictions_jsonl(jsonl: str) -> list[BatchResult]:
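
The reworked extractor now returns a (text, reasoning) pair instead of a bare string. A quick check against a hand-built response in the Gemini candidates shape:

    from model_library.providers.google.batch import extract_text_from_json_response

    response = {
        "candidates": [
            {
                "content": {
                    "parts": [
                        {"thought": True, "text": "step 1... "},
                        {"text": "final answer"},
                    ]
                }
            }
        ]
    }

    text, reasoning = extract_text_from_json_response(response)
    assert text == "final answer"
    assert reasoning == "step 1... "
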
@@ -48,9 +51,10 @@
         custom_id = data.get("key", "unknown")
         if "response" in data:
             response = data["response"]
-            text = extract_text_from_json_response(response)
+            text, reasoning = extract_text_from_json_response(response)
             output = QueryResult()
             output.output_text = text
+            output.reasoning = reasoning
             if "usageMetadata" in response:
                 output.metadata.in_tokens = response["usageMetadata"].get(
                     "promptTokenCount", 0
@@ -196,7 +200,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             custom_id = labels.get("qa_pair_id", f"request-{i}")
             jsonl_lines.append(json.dumps({"key": custom_id, "request": request_data}))
 
-        batch_request_file = self._root.client.files.upload(
+        batch_request_file = self._root.get_client().files.upload(
             file=io.StringIO("\n".join(jsonl_lines)),
             config=UploadFileConfig(mime_type="application/jsonl"),
         )
@@ -205,7 +209,7 @@ class GoogleBatchMixin(LLMBatchMixin):
             raise Exception("Failed to upload batch jsonl")
 
         try:
-            job: BatchJob = await self._root.client.aio.batches.create(
+            job: BatchJob = await self._root.get_client().aio.batches.create(
                 model=self._root.model_name,
                 src=batch_request_file.name,
                 config={"display_name": batch_name},
@@ -224,14 +228,14 @@
     async def get_batch_results(self, batch_id: str) -> list[BatchResult]:
         self._root.logger.info(f"Retrieving batch results for {batch_id}")
 
-        job = await self._root.client.aio.batches.get(name=batch_id)
+        job = await self._root.get_client().aio.batches.get(name=batch_id)
 
         results: list[BatchResult] = []
 
         if job.state == JobState.JOB_STATE_SUCCEEDED:
             if job.dest and job.dest.file_name:
                 results_file_name = job.dest.file_name
-                file_content = await self._root.client.aio.files.download(
+                file_content = await self._root.get_client().aio.files.download(
                     file=results_file_name
                 )
                 decoded = file_content.decode("utf-8")
@@ -250,7 +254,7 @@
     @override
     async def cancel_batch_request(self, batch_id: str):
         self._root.logger.info(f"Cancelling batch {batch_id}")
-        await self._root.client.aio.batches.cancel(name=batch_id)
+        await self._root.get_client().aio.batches.cancel(name=batch_id)
 
     @override
     async def get_batch_progress(self, batch_id: str) -> int:
@@ -262,7 +266,7 @@
 
         try:
             self._root.logger.debug(f"Checking batch status for {batch_id}")
-            job: BatchJob = await self._root.client.aio.batches.get(name=batch_id)
+            job: BatchJob = await self._root.get_client().aio.batches.get(name=batch_id)
             state = job.state
 
             if not state:
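
Every self._root.client access above now goes through get_client(), matching the has_client/assign_client accessors introduced on the Anthropic and Azure models. A minimal sketch of that lazy-accessor pattern, not the library's actual base class:

    class LazyClientHolder:
        """Sketch: cache a client on first use instead of at import time."""

        def __init__(self) -> None:
            self._client: object | None = None

        def has_client(self) -> bool:
            return self._client is not None

        def assign_client(self, client: object) -> None:
            self._client = client

        def get_client(self) -> object:
            if not self.has_client():
                self.assign_client(object())  # stand-in for a real SDK client
            return self._client
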