model-library 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. model_library/base/base.py +237 -62
  2. model_library/base/delegate_only.py +86 -9
  3. model_library/base/input.py +10 -7
  4. model_library/base/output.py +48 -0
  5. model_library/base/utils.py +56 -7
  6. model_library/config/alibaba_models.yaml +44 -57
  7. model_library/config/all_models.json +253 -126
  8. model_library/config/kimi_models.yaml +30 -3
  9. model_library/config/openai_models.yaml +15 -23
  10. model_library/config/zai_models.yaml +24 -3
  11. model_library/exceptions.py +14 -77
  12. model_library/logging.py +6 -2
  13. model_library/providers/ai21labs.py +30 -14
  14. model_library/providers/alibaba.py +17 -8
  15. model_library/providers/amazon.py +119 -64
  16. model_library/providers/anthropic.py +184 -104
  17. model_library/providers/azure.py +22 -10
  18. model_library/providers/cohere.py +7 -7
  19. model_library/providers/deepseek.py +8 -8
  20. model_library/providers/fireworks.py +7 -8
  21. model_library/providers/google/batch.py +17 -13
  22. model_library/providers/google/google.py +130 -73
  23. model_library/providers/inception.py +7 -7
  24. model_library/providers/kimi.py +18 -8
  25. model_library/providers/minimax.py +30 -13
  26. model_library/providers/mistral.py +61 -35
  27. model_library/providers/openai.py +219 -93
  28. model_library/providers/openrouter.py +34 -0
  29. model_library/providers/perplexity.py +7 -7
  30. model_library/providers/together.py +7 -8
  31. model_library/providers/vals.py +16 -9
  32. model_library/providers/xai.py +157 -144
  33. model_library/providers/zai.py +38 -8
  34. model_library/register_models.py +4 -2
  35. model_library/registry_utils.py +39 -15
  36. model_library/retriers/__init__.py +0 -0
  37. model_library/retriers/backoff.py +73 -0
  38. model_library/retriers/base.py +225 -0
  39. model_library/retriers/token.py +427 -0
  40. model_library/retriers/utils.py +11 -0
  41. model_library/settings.py +1 -1
  42. model_library/utils.py +13 -35
  43. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
  44. model_library-0.1.8.dist-info/RECORD +70 -0
  45. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
  46. model_library-0.1.6.dist-info/RECORD +0 -64
  47. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
  48. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
model_library/exceptions.py CHANGED
@@ -1,13 +1,9 @@
- import logging
- import random
  import re
- from typing import Any, Callable
+ from typing import Any
 
- import backoff
  from ai21 import TooManyRequestsError as AI21RateLimitError
  from anthropic import InternalServerError
  from anthropic import RateLimitError as AnthropicRateLimitError
- from backoff._typing import Details
  from httpcore import ReadError as HTTPCoreReadError
  from httpx import ConnectError as HTTPXConnectError
  from httpx import ReadError as HTTPXReadError
@@ -75,12 +71,14 @@ CONTEXT_WINDOW_PATTERN = re.compile(
      r"maximum context length is \d+ tokens|"
      r"context length is \d+ tokens|"
      r"exceed.* context (limit|window|length)|"
+     r"context window exceeds|"
      r"exceeds maximum length|"
      r"too long.*tokens.*maximum|"
      r"too large for model with \d+ maximum context length|"
      r"longer than the model's context length|"
      r"too many tokens.*size limit exceeded|"
      r"prompt is too long|"
+     r"maximum prompt length|"
      r"input length should be|"
      r"sent message larger than max|"
      r"input tokens exceeded|"
@@ -146,6 +144,17 @@ class BadInputError(Exception):
          super().__init__(message or BadInputError.DEFAULT_MESSAGE)
 
 
+ class NoMatchingToolCallError(Exception):
+     """
+     Raised when a tool call result is provided with no matching tool call
+     """
+
+     DEFAULT_MESSAGE: str = "Tool call result provided with no matching tool call"
+
+     def __init__(self, message: str | None = None):
+         super().__init__(message or NoMatchingToolCallError.DEFAULT_MESSAGE)
+
+
  # Add more retriable exceptions as needed
  # Providers that don't have an explicit rate limit error are handled manually
  # by wrapping errored Http/gRPC requests with a BackoffRetryException
@@ -211,75 +220,3 @@ def exception_message(exception: Exception | Any) -> str:
          if str(exception)
          else type(exception).__name__
      )
-
-
- RETRY_MAX_TRIES: int = 20
- RETRY_INITIAL: float = 10.0
- RETRY_EXPO: float = 1.4
- RETRY_MAX_BACKOFF_WAIT: float = 240.0  # 4 minutes (more with jitter)
-
-
- def jitter(wait: float) -> float:
-     """
-     Increase or decrease the wait time by up to 20%.
-     """
-     jitter_fraction = 0.2
-     min_wait = wait * (1 - jitter_fraction)
-     max_wait = wait * (1 + jitter_fraction)
-     return random.uniform(min_wait, max_wait)
-
-
- def retry_llm_call(
-     logger: logging.Logger,
-     max_tries: int = RETRY_MAX_TRIES,
-     max_time: float | None = None,
-     backoff_callback: (
-         Callable[[int, Exception | None, float, float], None] | None
-     ) = None,
- ):
-     def on_backoff(details: Details):
-         exception = details.get("exception")
-         tries = details.get("tries", 0)
-         elapsed = details.get("elapsed", 0.0)
-         wait = details.get("wait", 0.0)
-
-         logger.warning(
-             f"[Retrying] Exception: {exception_message(exception)} | Attempt: {tries} | "
-             + f"Elapsed: {elapsed:.1f}s | Next wait: {wait:.1f}s"
-         )
-
-         if backoff_callback:
-             backoff_callback(tries, exception, elapsed, wait)
-
-     def giveup(e: Exception) -> bool:
-         return not is_retriable_error(e)
-
-     def on_giveup(details: Details) -> None:
-         exception: Exception | None = details.get("exception", None)
-         if not exception:
-             return
-
-         logger.error(
-             f"Giving up after retries. Final exception: {exception_message(exception)}"
-         )
-
-         if is_context_window_error(exception):
-             message = exception.args[0] if exception.args else str(exception)
-             raise MaxContextWindowExceededError(message)
-
-         raise exception
-
-     return backoff.on_exception(
-         wait_gen=lambda: backoff.expo(
-             base=RETRY_EXPO,
-             factor=RETRY_INITIAL,
-             max_value=RETRY_MAX_BACKOFF_WAIT,
-         ),
-         exception=Exception,
-         max_tries=max_tries,
-         max_time=max_time,
-         giveup=giveup,
-         on_backoff=on_backoff,
-         on_giveup=on_giveup,
-         jitter=jitter,
-     )
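All of this retry machinery moves out of exceptions.py and into the new model_library/retriers/ package (backoff.py, base.py, token.py in the file list above). For reference, the schedule the removed code produced was factor * base**n exponential backoff (10 s initial, base 1.4, capped at 240 s) with ±20% jitter; a standalone sketch:

    import random

    RETRY_INITIAL = 10.0          # factor
    RETRY_EXPO = 1.4              # base
    RETRY_MAX_BACKOFF_WAIT = 240.0

    def wait_before(attempt: int) -> float:
        """Pre-jitter wait in seconds before the given 0-based retry."""
        return min(RETRY_INITIAL * RETRY_EXPO**attempt, RETRY_MAX_BACKOFF_WAIT)

    def jitter(wait: float) -> float:
        """Scale the wait by a random factor in [0.8, 1.2], as the removed code did."""
        return random.uniform(wait * 0.8, wait * 1.2)

    # 10.0, 14.0, 19.6, 27.4, 38.4, ... then flat at 240.0
    print([round(wait_before(n), 1) for n in range(5)])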
model_library/logging.py CHANGED
@@ -6,7 +6,11 @@ from rich.logging import RichHandler
  _llm_logger = logging.getLogger("llm")
 
 
- def set_logging(enable: bool = True, handler: logging.Handler | None = None):
+ def set_logging(
+     enable: bool = True,
+     level: int = logging.INFO,
+     handler: logging.Handler | None = None,
+ ):
      """
      Sets up logging for the model library
 
@@ -15,7 +19,7 @@ def set_logging(enable: bool = True, handler: logging.Handler | None = None):
          handler (logging.Handler, optional): A custom logging handler. Defaults to RichHandler.
      """
      if enable:
-         _llm_logger.setLevel(logging.INFO)
+         _llm_logger.setLevel(level)
      else:
          _llm_logger.setLevel(logging.CRITICAL)
 
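A minimal usage sketch of the new signature (assuming set_logging is imported from model_library.logging, as the file path suggests):

    import logging
    from model_library.logging import set_logging

    # INFO remains the default, so existing callers are unaffected;
    # pass level explicitly to get more verbose output.
    set_logging(enable=True, level=logging.DEBUG)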
model_library/providers/ai21labs.py CHANGED
@@ -16,6 +16,7 @@ from model_library.base import (
      LLMConfig,
      QueryResult,
      QueryResultMetadata,
+     RawResponse,
      TextInput,
      ToolBody,
      ToolCall,
@@ -33,17 +34,21 @@ from model_library.utils import default_httpx_client
 
  @register_provider("ai21labs")
  class AI21LabsModel(LLM):
-     _client: AsyncAI21Client | None = None
+     @override
+     def _get_default_api_key(self) -> str:
+         return model_library_settings.AI21LABS_API_KEY
 
      @override
-     def get_client(self) -> AsyncAI21Client:
-         if not AI21LabsModel._client:
-             AI21LabsModel._client = AsyncAI21Client(
-                 api_key=model_library_settings.AI21LABS_API_KEY,
+     def get_client(self, api_key: str | None = None) -> AsyncAI21Client:
+         if not self.has_client():
+             assert api_key
+             client = AsyncAI21Client(
+                 api_key=api_key,
                  http_client=default_httpx_client(),
-                 num_retries=1,
+                 num_retries=3,
              )
-         return AI21LabsModel._client
+             self.assign_client(client)
+         return super().get_client()
 
      def __init__(
          self,
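The class-level singleton gives way to a per-instance cache driven by has_client / assign_client / _get_default_api_key. A hypothetical condensed view of that contract (the real base class lives in model_library/base and surely differs in detail):

    class ClientCacheSketch:
        _client = None

        def has_client(self) -> bool:
            return self._client is not None

        def assign_client(self, client) -> None:
            self._client = client

        def get_client(self, api_key=None):
            # Subclasses build and assign the client on first use, then
            # fall through to the cached instance via super().get_client().
            assert self._client is not None
            return self._client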
@@ -65,8 +70,6 @@ class AI21LabsModel(LLM):
              match item:
                  case TextInput():
                      new_input.append(ChatMessage(role="user", content=item.text))
-                 case AssistantMessage():
-                     new_input.append(item)
                  case ToolResult():
                      new_input.append(
                          ToolMessage(
@@ -74,7 +77,9 @@ class AI21LabsModel(LLM):
                              content=item.result,
                              tool_call_id=item.tool_call.id,
                          )
-                     )
+                     )  # TODO: tool calling metadata and test
+                 case RawResponse():
+                     new_input.append(item.response)
                  case _:
                      raise BadInputError("Unsupported input type")
          return new_input
@@ -133,14 +138,13 @@ class AI21LabsModel(LLM):
          raise NotImplementedError()
 
      @override
-     async def _query_impl(
+     async def build_body(
          self,
          input: Sequence[InputItem],
          *,
          tools: list[ToolDefinition],
-         query_logger: logging.Logger,
          **kwargs: object,
-     ) -> QueryResult:
+     ) -> dict[str, Any]:
          messages: list[ChatMessage] = []
          if "system_prompt" in kwargs:
              messages.append(
@@ -162,6 +166,18 @@ class AI21LabsModel(LLM):
              body["top_p"] = self.top_p
 
          body.update(kwargs)
+         return body
+
+     @override
+     async def _query_impl(
+         self,
+         input: Sequence[InputItem],
+         *,
+         tools: list[ToolDefinition],
+         query_logger: logging.Logger,
+         **kwargs: object,
+     ) -> QueryResult:
+         body = await self.build_body(input, tools=tools, **kwargs)
 
          response: ChatCompletionResponse = (
              await self.get_client().chat.completions.create(**body, stream=False)  # pyright: ignore[reportAny, reportUnknownMemberType]
@@ -186,7 +202,7 @@ class AI21LabsModel(LLM):
 
          output = QueryResult(
              output_text=choice.message.content,
-             history=[*input, choice.message],
+             history=[*input, RawResponse(response=choice.message)],
              metadata=QueryResultMetadata(
                  in_tokens=response.usage.prompt_tokens,
                  out_tokens=response.usage.completion_tokens,
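Factoring build_body out of _query_impl means the request payload can now be constructed and inspected without issuing a network call. A hedged usage sketch (the model name, constructor arguments, and TextInput kwargs here are illustrative, not taken from the diff):

    import asyncio
    from model_library.base import TextInput
    from model_library.providers.ai21labs import AI21LabsModel

    async def main() -> None:
        model = AI21LabsModel("jamba-large", "ai21labs")  # illustrative args
        body = await model.build_body([TextInput(text="Hello")], tools=[])
        print(body)  # nothing sent; just the dict handed to the SDK

    asyncio.run(main())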
model_library/providers/alibaba.py CHANGED
@@ -1,17 +1,17 @@
- from typing import Literal
+ from typing import Any, Literal
 
+ from pydantic import SecretStr
  from typing_extensions import override
 
  from model_library import model_library_settings
  from model_library.base import (
+     DelegateConfig,
      DelegateOnly,
      LLMConfig,
      QueryResultCost,
      QueryResultMetadata,
  )
- from model_library.providers.openai import OpenAIModel
  from model_library.register_models import register_provider
- from model_library.utils import create_openai_client_with_defaults
 
 
  @register_provider("alibaba")
@@ -26,17 +26,26 @@ class AlibabaModel(DelegateOnly):
          super().__init__(model_name, provider, config=config)
 
          # https://www.alibabacloud.com/help/en/model-studio/first-api-call-to-qwen
-         self.delegate = OpenAIModel(
-             model_name=self.model_name,
-             provider=self.provider,
+         self.init_delegate(
              config=config,
-             custom_client=create_openai_client_with_defaults(
-                 api_key=model_library_settings.DASHSCOPE_API_KEY,
+             delegate_config=DelegateConfig(
                  base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
+                 api_key=SecretStr(model_library_settings.DASHSCOPE_API_KEY),
              ),
              use_completions=True,
+             delegate_provider="openai",
          )
 
+     @override
+     def _get_extra_body(self) -> dict[str, Any]:
+         """Build extra body parameters for Qwen-specific features."""
+         extra: dict[str, Any] = {}
+         # Enable thinking mode for Qwen3 reasoning models
+         # https://www.alibabacloud.com/help/en/model-studio/use-qwen-by-calling-api
+         if self.reasoning:
+             extra["enable_thinking"] = True
+         return extra
+
      @override
      async def _calculate_cost(
          self,
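DashScope's OpenAI-compatible endpoint takes Qwen-specific switches through the request's extra body, which is what _get_extra_body feeds. A sketch of the equivalent raw call (the endpoint URL and enable_thinking flag come from the diff; the client usage and model name are assumptions):

    import asyncio
    from openai import AsyncOpenAI

    async def main() -> None:
        client = AsyncOpenAI(
            api_key="sk-...",  # DASHSCOPE_API_KEY
            base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
        )
        # enable_thinking rides along via extra_body; DashScope generally
        # requires streaming when thinking is enabled.
        stream = await client.chat.completions.create(
            model="qwen3-32b",  # illustrative Qwen3 reasoning model
            messages=[{"role": "user", "content": "Why is the sky blue?"}],
            extra_body={"enable_thinking": True},
            stream=True,
        )
        async for chunk in stream:
            print(chunk.choices[0].delta)

    asyncio.run(main())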
model_library/providers/amazon.py CHANGED
@@ -11,26 +11,29 @@ import botocore
  from botocore.client import BaseClient
  from typing_extensions import override
 
+ from model_library import model_library_settings
  from model_library.base import (
      LLM,
+     FileBase,
      FileInput,
      FileWithBase64,
      FileWithId,
-     FileWithUrl,
      InputItem,
      LLMConfig,
      QueryResult,
      QueryResultMetadata,
+     RawInput,
+     RawResponse,
      TextInput,
      ToolBody,
      ToolCall,
      ToolDefinition,
      ToolResult,
  )
- from model_library.base.input import FileBase
  from model_library.exceptions import (
      BadInputError,
      MaxOutputTokensExceededError,
+     NoMatchingToolCallError,
  )
  from model_library.model_utils import get_default_budget_tokens
  from model_library.register_models import register_provider
@@ -39,20 +42,46 @@ from model_library.register_models import register_provider
 
  @register_provider("amazon")
  @register_provider("bedrock")
  class AmazonModel(LLM):
-     _client: BaseClient | None = None
+     @override
+     def _get_default_api_key(self) -> str:
+         if getattr(model_library_settings, "AWS_ACCESS_KEY_ID", None):
+             return json.dumps(
+                 {
+                     "AWS_ACCESS_KEY_ID": model_library_settings.AWS_ACCESS_KEY_ID,
+                     "AWS_SECRET_ACCESS_KEY": model_library_settings.AWS_SECRET_ACCESS_KEY,
+                     "AWS_DEFAULT_REGION": model_library_settings.AWS_DEFAULT_REGION,
+                 }
+             )
+         return "using-environment"
 
      @override
-     def get_client(self) -> BaseClient:
-         if not AmazonModel._client:
-             AmazonModel._client = cast(
-                 BaseClient,
-                 boto3.client(
-                     "bedrock-runtime",
-                     # default connection pool is 10
-                     config=botocore.config.Config(max_pool_connections=1000),  # pyright: ignore[reportAttributeAccessIssue]
-                 ),
-             )  # pyright: ignore[reportUnknownMemberType]
-         return AmazonModel._client
+     def get_client(self, api_key: str | None = None) -> BaseClient:
+         if not self.has_client():
+             assert api_key
+             if api_key != "using-environment":
+                 creds = json.loads(api_key)
+                 client = cast(
+                     BaseClient,
+                     boto3.client(
+                         "bedrock-runtime",
+                         aws_access_key_id=creds["AWS_ACCESS_KEY_ID"],
+                         aws_secret_access_key=creds["AWS_SECRET_ACCESS_KEY"],
+                         region_name=creds["AWS_DEFAULT_REGION"],
+                         config=botocore.config.Config(max_pool_connections=1000),  # pyright: ignore[reportAttributeAccessIssue]
+                     ),
+                 )
+             else:
+                 client = cast(
+                     BaseClient,
+                     boto3.client(
+                         "bedrock-runtime",
+                         # default connection pool is 10
+                         config=botocore.config.Config(max_pool_connections=1000),  # pyright: ignore[reportAttributeAccessIssue]
+                     ),
+                 )
+
+             self.assign_client(client)
+         return super().get_client()
 
      def __init__(
          self,
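Since the base class passes credentials around as a single api_key string, the AWS triple is serialized to JSON in _get_default_api_key and parsed back in get_client; the "using-environment" sentinel instead defers to boto3's default credential chain. The round trip is plain json (values illustrative):

    import json

    creds_blob = json.dumps(
        {
            "AWS_ACCESS_KEY_ID": "AKIA...",
            "AWS_SECRET_ACCESS_KEY": "secret",
            "AWS_DEFAULT_REGION": "us-east-1",
        }
    )
    creds = json.loads(creds_blob)
    assert creds["AWS_DEFAULT_REGION"] == "us-east-1"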
@@ -68,8 +97,27 @@ class AmazonModel(LLM):
          )  # supported but no access yet
          self.supports_tool_cache = self.supports_cache and "claude" in self.model_name
 
+         if config and config.custom_api_key:
+             raise Exception(
+                 "custom_api_key is not currently supported for Amazon models"
+             )
+
      cache_control = {"type": "default"}
 
+     async def get_tool_call_ids(self, input: Sequence[InputItem]) -> list[str]:
+         raw_responses = [x for x in input if isinstance(x, RawResponse)]
+         tool_call_ids: list[str] = []
+
+         calls = [
+             y["toolUse"]
+             for x in raw_responses
+             if "content" in x.response
+             for y in x.response["content"]
+             if "toolUse" in y
+         ]
+         tool_call_ids.extend([x["toolUseId"] for x in calls])
+         return tool_call_ids
+
      @override
      async def parse_input(
          self,
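get_tool_call_ids scans earlier assistant turns (RawResponse items) for Converse-style toolUse blocks and collects their ids, which parse_input then checks tool results against. Against a sample message shaped like Bedrock's output (values made up):

    response = {
        "role": "assistant",
        "content": [
            {"text": "Let me check the weather."},
            {"toolUse": {"toolUseId": "tool-123", "name": "get_weather", "input": {}}},
        ],
    }

    # Same comprehension as the method, applied to a single message:
    calls = [y["toolUse"] for y in response["content"] if "toolUse" in y]
    assert [c["toolUseId"] for c in calls] == ["tool-123"]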
@@ -77,58 +125,63 @@ class AmazonModel(LLM):
          **kwargs: Any,
      ) -> list[dict[str, Any]]:
          new_input: list[dict[str, Any] | Any] = []
+
          content_user: list[dict[str, Any]] = []
 
+         def flush_content_user():
+             if content_user:
+                 # NOTE: must make new object as we clear()
+                 new_input.append({"role": "user", "content": content_user.copy()})
+                 content_user.clear()
+
+         tool_call_ids = await self.get_tool_call_ids(input)
+
          for item in input:
+             if isinstance(item, TextInput):
+                 content_user.append({"text": item.text})
+                 continue
+
+             if isinstance(item, FileBase):
+                 match item.type:
+                     case "image":
+                         parsed = await self.parse_image(item)
+                     case "file":
+                         parsed = await self.parse_file(item)
+                 content_user.append(parsed)
+                 continue
+
+             # non content user item
+             flush_content_user()
+
              match item:
-                 case TextInput():
-                     content_user.append({"text": item.text})
-                 case FileWithBase64() | FileWithUrl() | FileWithId():
-                     match item.type:
-                         case "image":
-                             content_user.append(await self.parse_image(item))
-                         case "file":
-                             content_user.append(await self.parse_file(item))
-                 case _:
-                     if content_user:
-                         new_input.append({"role": "user", "content": content_user})
-                         content_user = []
-                     match item:
-                         case ToolResult():
-                             if not (
-                                 isinstance(x, dict)
-                                 and "toolUse" in x
-                                 and x["toolUse"].get("toolUseId")
-                                 == item.tool_call.call_id
-                                 for x in new_input
-                             ):
-                                 raise Exception(
-                                     "Tool call result provided with no matching tool call"
-                                 )
-                             new_input.append(
+                 case ToolResult():
+                     if item.tool_call.id not in tool_call_ids:
+                         raise NoMatchingToolCallError()
+
+                     new_input.append(
+                         {
+                             "role": "user",
+                             "content": [
                                  {
-                                     "role": "user",
-                                     "content": [
-                                         {
-                                             "toolResult": {
-                                                 "toolUseId": item.tool_call.id,
-                                                 "content": [
-                                                     {"json": {"result": item.result}}
-                                                 ],
-                                             }
-                                         }
-                                     ],
+                                     "toolResult": {
+                                         "toolUseId": item.tool_call.id,
+                                         "content": [{"json": {"result": item.result}}],
+                                     }
                                  }
-                             )
-                         case dict():  # RawInputItem and RawResponse
-                             new_input.append(item)
+                             ],
+                         }
+                     )
+                 case RawResponse():
+                     new_input.append(item.response)
+                 case RawInput():
+                     new_input.append(item.input)
 
-         if content_user:
-             if self.supports_cache:
-                 if not isinstance(input[-1], FileBase):
-                     # last item cannot be file
-                     content_user.append({"cachePoint": self.cache_control})
-             new_input.append({"role": "user", "content": content_user})
+         if content_user and self.supports_cache:
+             if not isinstance(input[-1], FileBase):
+                 # last item cannot be file
+                 content_user.append({"cachePoint": self.cache_control})
+
+         flush_content_user()
 
          return new_input
 
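The .copy() inside flush_content_user is load-bearing: appending the buffer itself and then clear()-ing it would also empty the message that was just appended, because both names would alias one list. In miniature:

    buffer = [{"text": "hi"}]
    messages = []

    messages.append({"role": "user", "content": buffer})  # aliases buffer
    buffer.clear()
    assert messages[0]["content"] == []  # appended message emptied too

    buffer = [{"text": "hi"}]
    messages = [{"role": "user", "content": buffer.copy()}]  # snapshot instead
    buffer.clear()
    assert messages[0]["content"] == [{"text": "hi"}]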
@@ -196,6 +249,7 @@ class AmazonModel(LLM):
      ) -> FileWithId:
          raise NotImplementedError()
 
+     @override
      async def build_body(
          self,
          input: Sequence[InputItem],
@@ -216,7 +270,7 @@ class AmazonModel(LLM):
          if self.supports_cache:
              body["system"].append({"cachePoint": self.cache_control})
 
-         if self.reasoning:
+         if self.reasoning and self.max_tokens:
              if self.max_tokens < 1024:
                  self.max_tokens = 2048
              budget_tokens = kwargs.pop(
@@ -229,9 +283,10 @@ class AmazonModel(LLM):
              }
          }
 
-         inference: dict[str, Any] = {
-             "maxTokens": self.max_tokens,
-         }
+         inference: dict[str, Any] = {}
+
+         if self.max_tokens:
+             inference["maxTokens"] = self.max_tokens
 
          # Only set temperature for models where supports_temperature is True.
          # For example, "thinking" models don't support temperature: https://docs.claude.com/en/docs/build-with-claude/extended-thinking#feature-compatibility
@@ -383,5 +438,5 @@ class AmazonModel(LLM):
              reasoning=reasoning,
              metadata=metadata,
              tool_calls=tool_calls,
-             history=[*input, messages],
+             history=[*input, RawResponse(response=messages)],
          )