model-library 0.1.7-py3-none-any.whl → 0.1.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. model_library/base/base.py +139 -62
  2. model_library/base/delegate_only.py +77 -10
  3. model_library/base/output.py +43 -0
  4. model_library/base/utils.py +35 -0
  5. model_library/config/alibaba_models.yaml +44 -57
  6. model_library/config/all_models.json +253 -126
  7. model_library/config/kimi_models.yaml +30 -3
  8. model_library/config/openai_models.yaml +15 -23
  9. model_library/config/zai_models.yaml +24 -3
  10. model_library/exceptions.py +3 -77
  11. model_library/providers/ai21labs.py +12 -8
  12. model_library/providers/alibaba.py +17 -8
  13. model_library/providers/amazon.py +49 -16
  14. model_library/providers/anthropic.py +93 -40
  15. model_library/providers/azure.py +22 -10
  16. model_library/providers/cohere.py +7 -7
  17. model_library/providers/deepseek.py +8 -8
  18. model_library/providers/fireworks.py +7 -8
  19. model_library/providers/google/batch.py +14 -10
  20. model_library/providers/google/google.py +48 -29
  21. model_library/providers/inception.py +7 -7
  22. model_library/providers/kimi.py +18 -8
  23. model_library/providers/minimax.py +15 -17
  24. model_library/providers/mistral.py +20 -8
  25. model_library/providers/openai.py +99 -22
  26. model_library/providers/openrouter.py +34 -0
  27. model_library/providers/perplexity.py +7 -7
  28. model_library/providers/together.py +7 -8
  29. model_library/providers/vals.py +12 -6
  30. model_library/providers/xai.py +47 -42
  31. model_library/providers/zai.py +38 -8
  32. model_library/registry_utils.py +39 -15
  33. model_library/retriers/__init__.py +0 -0
  34. model_library/retriers/backoff.py +73 -0
  35. model_library/retriers/base.py +225 -0
  36. model_library/retriers/token.py +427 -0
  37. model_library/retriers/utils.py +11 -0
  38. model_library/settings.py +1 -1
  39. model_library/utils.py +13 -0
  40. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/METADATA +2 -1
  41. model_library-0.1.8.dist-info/RECORD +70 -0
  42. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
  43. model_library-0.1.7.dist-info/RECORD +0 -64
  44. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
  45. {model_library-0.1.7.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/exceptions.py
@@ -1,13 +1,9 @@
- import logging
- import random
  import re
- from typing import Any, Callable
+ from typing import Any
 
- import backoff
  from ai21 import TooManyRequestsError as AI21RateLimitError
  from anthropic import InternalServerError
  from anthropic import RateLimitError as AnthropicRateLimitError
- from backoff._typing import Details
  from httpcore import ReadError as HTTPCoreReadError
  from httpx import ConnectError as HTTPXConnectError
  from httpx import ReadError as HTTPXReadError
@@ -75,12 +71,14 @@ CONTEXT_WINDOW_PATTERN = re.compile(
      r"maximum context length is \d+ tokens|"
      r"context length is \d+ tokens|"
      r"exceed.* context (limit|window|length)|"
+     r"context window exceeds|"
      r"exceeds maximum length|"
      r"too long.*tokens.*maximum|"
      r"too large for model with \d+ maximum context length|"
      r"longer than the model's context length|"
      r"too many tokens.*size limit exceeded|"
      r"prompt is too long|"
+     r"maximum prompt length|"
      r"input length should be|"
      r"sent message larger than max|"
      r"input tokens exceeded|"
@@ -222,75 +220,3 @@ def exception_message(exception: Exception | Any) -> str:
          if str(exception)
          else type(exception).__name__
      )
-
-
- RETRY_MAX_TRIES: int = 20
- RETRY_INITIAL: float = 10.0
- RETRY_EXPO: float = 1.4
- RETRY_MAX_BACKOFF_WAIT: float = 240.0  # 4 minutes (more with jitter)
-
-
- def jitter(wait: float) -> float:
-     """
-     Increase or decrease the wait time by up to 20%.
-     """
-     jitter_fraction = 0.2
-     min_wait = wait * (1 - jitter_fraction)
-     max_wait = wait * (1 + jitter_fraction)
-     return random.uniform(min_wait, max_wait)
-
-
- def retry_llm_call(
-     logger: logging.Logger,
-     max_tries: int = RETRY_MAX_TRIES,
-     max_time: float | None = None,
-     backoff_callback: (
-         Callable[[int, Exception | None, float, float], None] | None
-     ) = None,
- ):
-     def on_backoff(details: Details):
-         exception = details.get("exception")
-         tries = details.get("tries", 0)
-         elapsed = details.get("elapsed", 0.0)
-         wait = details.get("wait", 0.0)
-
-         logger.warning(
-             f"[Retrying] Exception: {exception_message(exception)} | Attempt: {tries} | "
-             + f"Elapsed: {elapsed:.1f}s | Next wait: {wait:.1f}s"
-         )
-
-         if backoff_callback:
-             backoff_callback(tries, exception, elapsed, wait)
-
-     def giveup(e: Exception) -> bool:
-         return not is_retriable_error(e)
-
-     def on_giveup(details: Details) -> None:
-         exception: Exception | None = details.get("exception", None)
-         if not exception:
-             return
-
-         logger.error(
-             f"Giving up after retries. Final exception: {exception_message(exception)}"
-         )
-
-         if is_context_window_error(exception):
-             message = exception.args[0] if exception.args else str(exception)
-             raise MaxContextWindowExceededError(message)
-
-         raise exception
-
-     return backoff.on_exception(
-         wait_gen=lambda: backoff.expo(
-             base=RETRY_EXPO,
-             factor=RETRY_INITIAL,
-             max_value=RETRY_MAX_BACKOFF_WAIT,
-         ),
-         exception=Exception,
-         max_tries=max_tries,
-         max_time=max_time,
-         giveup=giveup,
-         on_backoff=on_backoff,
-         on_giveup=on_giveup,
-         jitter=jitter,
-     )
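For reference, the removed retry_llm_call returned the decorator built by backoff.on_exception, so a 0.1.7 call site would have looked roughly like the sketch below (query_model and its body are hypothetical); in 0.1.8 this role moves to the new model_library/retriers package listed above (files 33-37):

import logging

from model_library.exceptions import retry_llm_call  # 0.1.7 only; removed in 0.1.8

logger = logging.getLogger(__name__)

# retry_llm_call(...) returns a decorator that retries with jittered
# exponential backoff and re-raises non-retriable errors via giveup().
@retry_llm_call(logger, max_tries=5)
async def query_model(prompt: str) -> str:
    ...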
model_library/providers/ai21labs.py
@@ -16,13 +16,13 @@ from model_library.base import (
      LLMConfig,
      QueryResult,
      QueryResultMetadata,
+     RawResponse,
      TextInput,
      ToolBody,
      ToolCall,
      ToolDefinition,
      ToolResult,
  )
- from model_library.base.input import RawResponse
  from model_library.exceptions import (
      BadInputError,
      MaxOutputTokensExceededError,
@@ -34,17 +34,21 @@ from model_library.utils import default_httpx_client
 
  @register_provider("ai21labs")
  class AI21LabsModel(LLM):
-     _client: AsyncAI21Client | None = None
+     @override
+     def _get_default_api_key(self) -> str:
+         return model_library_settings.AI21LABS_API_KEY
 
      @override
-     def get_client(self) -> AsyncAI21Client:
-         if not AI21LabsModel._client:
-             AI21LabsModel._client = AsyncAI21Client(
-                 api_key=model_library_settings.AI21LABS_API_KEY,
+     def get_client(self, api_key: str | None = None) -> AsyncAI21Client:
+         if not self.has_client():
+             assert api_key
+             client = AsyncAI21Client(
+                 api_key=api_key,
                  http_client=default_httpx_client(),
-                 num_retries=1,
+                 num_retries=3,
              )
-         return AI21LabsModel._client
+             self.assign_client(client)
+         return super().get_client()
 
      def __init__(
          self,
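The has_client/assign_client/get_client triple recurs in every provider below. A sketch of the base-class contract this implies, assuming a simple cached-client protocol (the real implementation lives in model_library/base/base.py, which is not shown in this section):

from typing import Any

class ClientCacheSketch:
    # Hypothetical reading of the LLM base-class client protocol.
    _client: Any = None

    def has_client(self) -> bool:
        return self._client is not None

    def assign_client(self, client: Any) -> None:
        self._client = client

    def get_client(self, api_key: str | None = None) -> Any:
        # Providers build and assign the client on first use, then
        # delegate back to the base class to return the cached instance.
        assert self._client is not None
        return self._client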
model_library/providers/alibaba.py
@@ -1,17 +1,17 @@
- from typing import Literal
+ from typing import Any, Literal
 
+ from pydantic import SecretStr
  from typing_extensions import override
 
  from model_library import model_library_settings
  from model_library.base import (
+     DelegateConfig,
      DelegateOnly,
      LLMConfig,
      QueryResultCost,
      QueryResultMetadata,
  )
- from model_library.providers.openai import OpenAIModel
  from model_library.register_models import register_provider
- from model_library.utils import create_openai_client_with_defaults
 
 
  @register_provider("alibaba")
@@ -26,17 +26,26 @@ class AlibabaModel(DelegateOnly):
          super().__init__(model_name, provider, config=config)
 
          # https://www.alibabacloud.com/help/en/model-studio/first-api-call-to-qwen
-         self.delegate = OpenAIModel(
-             model_name=self.model_name,
-             provider=self.provider,
+         self.init_delegate(
              config=config,
-             custom_client=create_openai_client_with_defaults(
-                 api_key=model_library_settings.DASHSCOPE_API_KEY,
+             delegate_config=DelegateConfig(
                  base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
+                 api_key=SecretStr(model_library_settings.DASHSCOPE_API_KEY),
              ),
              use_completions=True,
+             delegate_provider="openai",
          )
 
+     @override
+     def _get_extra_body(self) -> dict[str, Any]:
+         """Build extra body parameters for Qwen-specific features."""
+         extra: dict[str, Any] = {}
+         # Enable thinking mode for Qwen3 reasoning models
+         # https://www.alibabacloud.com/help/en/model-studio/use-qwen-by-calling-api
+         if self.reasoning:
+             extra["enable_thinking"] = True
+         return extra
+
      @override
      async def _calculate_cost(
          self,
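When the OpenAI delegate forwards _get_extra_body(), the flag lands in the body of the completions request. A sketch of the equivalent raw call against DashScope's compatible mode, assuming the SDK's extra_body pass-through (the model name is illustrative):

import asyncio

from openai import AsyncOpenAI

async def main() -> None:
    client = AsyncOpenAI(
        base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        api_key="sk-...",  # DASHSCOPE_API_KEY placeholder
    )
    # enable_thinking mirrors what _get_extra_body() returns when self.reasoning is set
    response = await client.chat.completions.create(
        model="qwen-plus",  # illustrative model name
        messages=[{"role": "user", "content": "Hello"}],
        extra_body={"enable_thinking": True},
    )
    print(response.choices[0].message.content)

asyncio.run(main())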
model_library/providers/amazon.py
@@ -11,6 +11,7 @@ import botocore
  from botocore.client import BaseClient
  from typing_extensions import override
 
+ from model_library import model_library_settings
  from model_library.base import (
      LLM,
      FileBase,
@@ -41,20 +42,46 @@ from model_library.register_models import register_provider
  @register_provider("amazon")
  @register_provider("bedrock")
  class AmazonModel(LLM):
-     _client: BaseClient | None = None
+     @override
+     def _get_default_api_key(self) -> str:
+         if getattr(model_library_settings, "AWS_ACCESS_KEY_ID", None):
+             return json.dumps(
+                 {
+                     "AWS_ACCESS_KEY_ID": model_library_settings.AWS_ACCESS_KEY_ID,
+                     "AWS_SECRET_ACCESS_KEY": model_library_settings.AWS_SECRET_ACCESS_KEY,
+                     "AWS_DEFAULT_REGION": model_library_settings.AWS_DEFAULT_REGION,
+                 }
+             )
+         return "using-environment"
 
      @override
-     def get_client(self) -> BaseClient:
-         if not AmazonModel._client:
-             AmazonModel._client = cast(
-                 BaseClient,
-                 boto3.client(
-                     "bedrock-runtime",
-                     # default connection pool is 10
-                     config=botocore.config.Config(max_pool_connections=1000),  # pyright: ignore[reportAttributeAccessIssue]
-                 ),
-             )  # pyright: ignore[reportUnknownMemberType]
-         return AmazonModel._client
+     def get_client(self, api_key: str | None = None) -> BaseClient:
+         if not self.has_client():
+             assert api_key
+             if api_key != "using-environment":
+                 creds = json.loads(api_key)
+                 client = cast(
+                     BaseClient,
+                     boto3.client(
+                         "bedrock-runtime",
+                         aws_access_key_id=creds["AWS_ACCESS_KEY_ID"],
+                         aws_secret_access_key=creds["AWS_SECRET_ACCESS_KEY"],
+                         region_name=creds["AWS_DEFAULT_REGION"],
+                         config=botocore.config.Config(max_pool_connections=1000),  # pyright: ignore[reportAttributeAccessIssue]
+                     ),
+                 )
+             else:
+                 client = cast(
+                     BaseClient,
+                     boto3.client(
+                         "bedrock-runtime",
+                         # default connection pool is 10
+                         config=botocore.config.Config(max_pool_connections=1000),  # pyright: ignore[reportAttributeAccessIssue]
+                     ),
+                 )
+
+             self.assign_client(client)
+         return super().get_client()
 
      def __init__(
          self,
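Packing the three AWS settings into one JSON string lets Bedrock credentials ride the same api_key channel every other provider uses, with "using-environment" as a sentinel for boto3's default credential chain. A sketch of the round trip (values are placeholders):

import json

api_key = json.dumps(
    {
        "AWS_ACCESS_KEY_ID": "AKIA...",      # placeholder
        "AWS_SECRET_ACCESS_KEY": "secret",   # placeholder
        "AWS_DEFAULT_REGION": "us-east-1",   # placeholder
    }
)

if api_key != "using-environment":
    creds = json.loads(api_key)
    print(creds["AWS_DEFAULT_REGION"])  # us-east-1
# With the sentinel, boto3.client("bedrock-runtime") is built without explicit
# keys and resolves credentials from the environment or instance profile instead.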
@@ -70,6 +97,11 @@ class AmazonModel(LLM):
          ) # supported but no access yet
          self.supports_tool_cache = self.supports_cache and "claude" in self.model_name
 
+         if config and config.custom_api_key:
+             raise Exception(
+                 "custom_api_key is not currently supported for Amazon models"
+             )
+
      cache_control = {"type": "default"}
 
      async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
@@ -238,7 +270,7 @@ class AmazonModel(LLM):
          if self.supports_cache:
              body["system"].append({"cachePoint": self.cache_control})
 
-         if self.reasoning:
+         if self.reasoning and self.max_tokens:
              if self.max_tokens < 1024:
                  self.max_tokens = 2048
              budget_tokens = kwargs.pop(
@@ -251,9 +283,10 @@ class AmazonModel(LLM):
              }
          }
 
-         inference: dict[str, Any] = {
-             "maxTokens": self.max_tokens,
-         }
+         inference: dict[str, Any] = {}
+
+         if self.max_tokens:
+             inference["maxTokens"] = self.max_tokens
 
          # Only set temperature for models where supports_temperature is True.
          # For example, "thinking" models don't support temperature: https://docs.claude.com/en/docs/build-with-claude/extended-thinking#feature-compatibility
model_library/providers/anthropic.py
@@ -1,16 +1,20 @@
+ import datetime
  import io
  import logging
+ import time
  from typing import Any, Literal, Sequence, cast
 
- from anthropic import AsyncAnthropic
+ from anthropic import APIConnectionError, AsyncAnthropic
  from anthropic.types.beta.beta_tool_use_block import BetaToolUseBlock
  from anthropic.types.beta.parsed_beta_message import ParsedBetaMessage
+ from pydantic import SecretStr
  from typing_extensions import override
 
  from model_library import model_library_settings
  from model_library.base import (
      LLM,
      BatchResult,
+     DelegateConfig,
      FileBase,
      FileInput,
      FileWithBase64,
@@ -22,6 +26,7 @@ from model_library.base import (
      QueryResult,
      QueryResultCost,
      QueryResultMetadata,
+     RateLimit,
      RawInput,
      RawResponse,
      TextInput,
@@ -31,6 +36,7 @@ from model_library.base import (
      ToolResult,
  )
  from model_library.exceptions import (
+     ImmediateRetryException,
      MaxOutputTokensExceededError,
      NoMatchingToolCallError,
  )
@@ -38,8 +44,7 @@ from model_library.model_utils import get_default_budget_tokens
  from model_library.providers.openai import OpenAIModel
  from model_library.register_models import register_provider
  from model_library.utils import (
-     create_openai_client_with_defaults,
-     default_httpx_client,
+     create_anthropic_client_with_defaults,
  )
 
 
@@ -246,21 +251,25 @@ class AnthropicBatchMixin(LLMBatchMixin):
 
  @register_provider("anthropic")
  class AnthropicModel(LLM):
-     _client: AsyncAnthropic | None = None
+     def _get_default_api_key(self) -> str:
+         if self.delegate_config:
+             return self.delegate_config.api_key.get_secret_value()
+         return model_library_settings.ANTHROPIC_API_KEY
 
      @override
-     def get_client(self) -> AsyncAnthropic:
-         if self._delegate_client:
-             return self._delegate_client
-         if not AnthropicModel._client:
+     def get_client(self, api_key: str | None = None) -> AsyncAnthropic:
+         if not self.has_client():
+             assert api_key
              headers: dict[str, str] = {}
-             AnthropicModel._client = AsyncAnthropic(
-                 api_key=model_library_settings.ANTHROPIC_API_KEY,
-                 http_client=default_httpx_client(),
-                 max_retries=1,
+             client = create_anthropic_client_with_defaults(
+                 base_url=self.delegate_config.base_url
+                 if self.delegate_config
+                 else None,
+                 api_key=api_key,
                  default_headers=headers,
              )
-         return AnthropicModel._client
+             self.assign_client(client)
+         return super().get_client()
 
      def __init__(
          self,
@@ -268,33 +277,32 @@ class AnthropicModel(LLM):
          provider: str = "anthropic",
          *,
          config: LLMConfig | None = None,
-         custom_client: AsyncAnthropic | None = None,
+         delegate_config: DelegateConfig | None = None,
      ):
-         super().__init__(model_name, provider, config=config)
+         self.delegate_config = delegate_config
 
-         # allow custom client to act as delegate (native)
-         self._delegate_client: AsyncAnthropic | None = custom_client
+         super().__init__(model_name, provider, config=config)
 
          # https://docs.anthropic.com/en/api/openai-sdk
          self.delegate = (
              None
-             if self.native or custom_client
+             if self.native or self.delegate_config
              else OpenAIModel(
                  model_name=self.model_name,
-                 provider=provider,
+                 provider=self.provider,
                  config=config,
-                 custom_client=create_openai_client_with_defaults(
-                     api_key=model_library_settings.ANTHROPIC_API_KEY,
+                 use_completions=True,
+                 delegate_config=DelegateConfig(
                      base_url="https://api.anthropic.com/v1/",
+                     api_key=SecretStr(model_library_settings.ANTHROPIC_API_KEY),
                  ),
-                 use_completions=True,
              )
          )
 
          # Initialize batch support if enabled
          # Disable batch when using custom_client (similar to OpenAI)
          self.supports_batch: bool = (
-             self.supports_batch and self.native and not custom_client
+             self.supports_batch and self.native and not self.delegate_config
          )
          self.batch: LLMBatchMixin | None = (
              AnthropicBatchMixin(self) if self.supports_batch else None
@@ -520,7 +528,6 @@ class AnthropicModel(LLM):
          **kwargs: object,
      ) -> dict[str, Any]:
          body: dict[str, Any] = {
-             "max_tokens": self.max_tokens,
              "model": self.model_name,
              "messages": await self.parse_input(input),
          }
@@ -534,6 +541,11 @@ class AnthropicModel(LLM):
              }
          ]
 
+         if not self.max_tokens:
+             raise Exception("Anthropic models require a max_tokens parameter")
+
+         body["max_tokens"] = self.max_tokens
+
          if self.reasoning:
              budget_tokens = kwargs.pop(
                  "budget_tokens", get_default_budget_tokens(self.max_tokens)
@@ -577,7 +589,7 @@ class AnthropicModel(LLM):
          client = self.get_client()
 
          # only send betas for the official Anthropic endpoint
-         is_anthropic_endpoint = self._delegate_client is None
+         is_anthropic_endpoint = self.delegate_config is None
          if not is_anthropic_endpoint:
              client_base_url = getattr(client, "_base_url", None) or getattr(
                  client, "base_url", None
@@ -592,11 +604,14 @@ class AnthropicModel(LLM):
              betas.append("context-1m-2025-08-07")
          stream_kwargs["betas"] = betas
 
-         async with client.beta.messages.stream(
-             **stream_kwargs,
-         ) as stream:  # pyright: ignore[reportAny]
-             message = await stream.get_final_message()
-             self.logger.info(f"Anthropic Response finished: {message.id}")
+         try:
+             async with client.beta.messages.stream(
+                 **stream_kwargs,
+             ) as stream:  # pyright: ignore[reportAny]
+                 message = await stream.get_final_message()
+                 self.logger.info(f"Anthropic Response finished: {message.id}")
+         except APIConnectionError:
+             raise ImmediateRetryException("Failed to connect to Anthropic")
 
          text = ""
          reasoning = ""
@@ -632,6 +647,38 @@ class AnthropicModel(LLM):
              history=[*input, RawResponse(response=message)],
          )
 
+     @override
+     async def get_rate_limit(self) -> RateLimit:
+         response = await self.get_client().messages.with_raw_response.create(
+             max_tokens=1,
+             messages=[
+                 {
+                     "role": "user",
+                     "content": "Ping",
+                 }
+             ],
+             model=self.model_name,
+         )
+         headers = response.headers
+
+         server_time_str = headers.get("date")
+         if server_time_str:
+             server_time = datetime.datetime.strptime(
+                 server_time_str, "%a, %d %b %Y %H:%M:%S GMT"
+             ).replace(tzinfo=datetime.timezone.utc)
+             timestamp = server_time.timestamp()
+         else:
+             timestamp = time.time()
+
+         return RateLimit(
+             unix_timestamp=timestamp,
+             raw=headers,
+             request_limit=int(headers["anthropic-ratelimit-requests-limit"]),
+             request_remaining=int(headers["anthropic-ratelimit-requests-remaining"]),
+             token_limit=int(response.headers["anthropic-ratelimit-tokens-limit"]),
+             token_remaining=int(headers["anthropic-ratelimit-tokens-remaining"]),
+         )
+
      @override
      async def count_tokens(
          self,
@@ -645,20 +692,26 @@ class AnthropicModel(LLM):
          Count the number of tokens using Anthropic's native token counting API.
          https://docs.anthropic.com/en/docs/build-with-claude/token-counting
          """
-         input = [*history, *input]
-         if not input:
-             return 0
+         try:
+             input = [*history, *input]
+             if not input:
+                 return 0
 
-         body = await self.build_body(input, tools=tools, **kwargs)
+             body = await self.build_body(input, tools=tools, **kwargs)
 
-         # Remove fields not supported by count_tokens endpoint
-         body.pop("max_tokens", None)
-         body.pop("temperature", None)
+             # Remove fields not supported by count_tokens endpoint
+             body.pop("max_tokens", None)
+             body.pop("temperature", None)
 
-         client = self.get_client()
-         response = await client.messages.count_tokens(**body)
+             client = self.get_client()
+             response = await client.messages.count_tokens(**body)
 
-         return response.input_tokens
+             return response.input_tokens
+         except Exception as e:
+             self.logger.error(f"Error counting tokens: {e}")
+             return await super().count_tokens(
+                 input, history=history, tools=tools, **kwargs
+             )
 
      @override
      async def _calculate_cost(
model_library/providers/azure.py
@@ -1,3 +1,4 @@
+ import json
  from typing import Literal
 
  from openai.lib.azure import AsyncAzureOpenAI
@@ -14,21 +15,32 @@ from model_library.utils import default_httpx_client
 
  @register_provider("azure")
  class AzureOpenAIModel(OpenAIModel):
-     _azure_client: AsyncAzureOpenAI | None = None
-
      @override
-     def get_client(self) -> AsyncAzureOpenAI:
-         if not AzureOpenAIModel._azure_client:
-             AzureOpenAIModel._azure_client = AsyncAzureOpenAI(
-                 api_key=model_library_settings.AZURE_API_KEY,
-                 azure_endpoint=model_library_settings.AZURE_ENDPOINT,
-                 api_version=model_library_settings.get(
+     def _get_default_api_key(self) -> str:
+         return json.dumps(
+             {
+                 "AZURE_API_KEY": model_library_settings.AZURE_API_KEY,
+                 "AZURE_ENDPOINT": model_library_settings.AZURE_ENDPOINT,
+                 "AZURE_API_VERSION": model_library_settings.get(
                      "AZURE_API_VERSION", "2025-04-01-preview"
                  ),
+             }
+         )
+
+     @override
+     def get_client(self, api_key: str | None = None) -> AsyncAzureOpenAI:
+         if not self.has_client():
+             assert api_key
+             creds = json.loads(api_key)
+             client = AsyncAzureOpenAI(
+                 api_key=creds["AZURE_API_KEY"],
+                 azure_endpoint=creds["AZURE_ENDPOINT"],
+                 api_version=creds["AZURE_API_VERSION"],
                  http_client=default_httpx_client(),
-                 max_retries=1,
+                 max_retries=3,
              )
-         return AzureOpenAIModel._azure_client
+             self.assign_client(client)
+         return super(OpenAIModel, self).get_client(api_key)
 
      def __init__(
          self,
model_library/providers/cohere.py
@@ -1,13 +1,14 @@
  from typing import Literal
 
+ from pydantic import SecretStr
+
  from model_library import model_library_settings
  from model_library.base import (
+     DelegateConfig,
      DelegateOnly,
      LLMConfig,
  )
- from model_library.providers.openai import OpenAIModel
  from model_library.register_models import register_provider
- from model_library.utils import create_openai_client_with_defaults
 
 
  @register_provider("cohere")
@@ -22,13 +23,12 @@ class CohereModel(DelegateOnly):
          super().__init__(model_name, provider, config=config)
 
          # https://docs.cohere.com/docs/compatibility-api
-         self.delegate = OpenAIModel(
-             model_name=self.model_name,
-             provider=self.provider,
+         self.init_delegate(
              config=config,
-             custom_client=create_openai_client_with_defaults(
-                 api_key=model_library_settings.COHERE_API_KEY,
+             delegate_config=DelegateConfig(
                  base_url="https://api.cohere.ai/compatibility/v1",
+                 api_key=SecretStr(model_library_settings.COHERE_API_KEY),
              ),
              use_completions=True,
+             delegate_provider="openai",
          )
model_library/providers/deepseek.py
@@ -5,14 +5,15 @@ https://cdn.deepseek.com/policies/en-US/deepseek-privacy-policy.html
 
  from typing import Literal
 
+ from pydantic import SecretStr
+
  from model_library import model_library_settings
  from model_library.base import (
+     DelegateConfig,
      DelegateOnly,
      LLMConfig,
  )
- from model_library.providers.openai import OpenAIModel
  from model_library.register_models import register_provider
- from model_library.utils import create_openai_client_with_defaults
 
 
  @register_provider("deepseek")
18
19
  @register_provider("deepseek")
@@ -27,13 +28,12 @@ class DeepSeekModel(DelegateOnly):
27
28
  super().__init__(model_name, provider, config=config)
28
29
 
29
30
  # https://api-docs.deepseek.com/
30
- self.delegate = OpenAIModel(
31
- model_name=self.model_name,
32
- provider=self.provider,
31
+ self.init_delegate(
33
32
  config=config,
34
- custom_client=create_openai_client_with_defaults(
35
- api_key=model_library_settings.DEEPSEEK_API_KEY,
36
- base_url="https://api.deepseek.com",
33
+ delegate_config=DelegateConfig(
34
+ base_url="https://api.deepseek.com/v1",
35
+ api_key=SecretStr(model_library_settings.DEEPSEEK_API_KEY),
37
36
  ),
38
37
  use_completions=True,
38
+ delegate_provider="openai",
39
39
  )