model-library 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (48)
  1. model_library/base/base.py +237 -62
  2. model_library/base/delegate_only.py +86 -9
  3. model_library/base/input.py +10 -7
  4. model_library/base/output.py +48 -0
  5. model_library/base/utils.py +56 -7
  6. model_library/config/alibaba_models.yaml +44 -57
  7. model_library/config/all_models.json +253 -126
  8. model_library/config/kimi_models.yaml +30 -3
  9. model_library/config/openai_models.yaml +15 -23
  10. model_library/config/zai_models.yaml +24 -3
  11. model_library/exceptions.py +14 -77
  12. model_library/logging.py +6 -2
  13. model_library/providers/ai21labs.py +30 -14
  14. model_library/providers/alibaba.py +17 -8
  15. model_library/providers/amazon.py +119 -64
  16. model_library/providers/anthropic.py +184 -104
  17. model_library/providers/azure.py +22 -10
  18. model_library/providers/cohere.py +7 -7
  19. model_library/providers/deepseek.py +8 -8
  20. model_library/providers/fireworks.py +7 -8
  21. model_library/providers/google/batch.py +17 -13
  22. model_library/providers/google/google.py +130 -73
  23. model_library/providers/inception.py +7 -7
  24. model_library/providers/kimi.py +18 -8
  25. model_library/providers/minimax.py +30 -13
  26. model_library/providers/mistral.py +61 -35
  27. model_library/providers/openai.py +219 -93
  28. model_library/providers/openrouter.py +34 -0
  29. model_library/providers/perplexity.py +7 -7
  30. model_library/providers/together.py +7 -8
  31. model_library/providers/vals.py +16 -9
  32. model_library/providers/xai.py +157 -144
  33. model_library/providers/zai.py +38 -8
  34. model_library/register_models.py +4 -2
  35. model_library/registry_utils.py +39 -15
  36. model_library/retriers/__init__.py +0 -0
  37. model_library/retriers/backoff.py +73 -0 (new; see the backoff sketch after this list)
  38. model_library/retriers/base.py +225 -0
  39. model_library/retriers/token.py +427 -0
  40. model_library/retriers/utils.py +11 -0
  41. model_library/settings.py +1 -1
  42. model_library/utils.py +13 -35
  43. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
  44. model_library-0.1.8.dist-info/RECORD +70 -0
  45. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
  46. model_library-0.1.6.dist-info/RECORD +0 -64
  47. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
  48. {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
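Items 36-40 add a new model_library/retriers package (a base retrier plus backoff- and token-based implementations). Its contents are not shown in the hunks below, which all come from model_library/providers/mistral.py. For orientation only, capped exponential backoff with jitter, the pattern such a module typically implements, looks roughly like the following sketch; the name, signature, and behavior here are illustrative, not the package's actual API:

import asyncio
import random

async def retry_with_backoff(fn, max_attempts=5, base_delay=1.0, max_delay=30.0):
    # Illustrative only: the real retriers/backoff.py is not shown in this
    # diff and its API almost certainly differs.
    for attempt in range(max_attempts):
        try:
            return await fn()  # fn is a zero-arg async callable
        except Exception:
            if attempt == max_attempts - 1:
                raise
            # capped exponential backoff with full jitter
            delay = min(max_delay, base_delay * 2**attempt)
            await asyncio.sleep(random.uniform(0, delay))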
model_library/providers/mistral.py
@@ -1,10 +1,15 @@
 import io
 import logging
-import time
 from collections.abc import Sequence
 from typing import Any, Literal
 
-from mistralai import AssistantMessage, ContentChunk, Mistral, TextChunk, ThinkChunk
+from mistralai import (
+    AssistantMessage,
+    ContentChunk,
+    Mistral,
+    TextChunk,
+    ThinkChunk,
+)
 from mistralai.models.completionevent import CompletionEvent
 from mistralai.models.toolcall import ToolCall as MistralToolCall
 from mistralai.utils.eventstreaming import EventStreamAsync
@@ -13,14 +18,16 @@ from typing_extensions import override
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
-    FileWithUrl,
     InputItem,
     LLMConfig,
     QueryResult,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
@@ -39,16 +46,20 @@ from model_library.utils import default_httpx_client
 
 @register_provider("mistralai")
 class MistralModel(LLM):
-    _client: Mistral | None = None
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.MISTRAL_API_KEY
 
     @override
-    def get_client(self) -> Mistral:
-        if not MistralModel._client:
-            MistralModel._client = Mistral(
-                api_key=model_library_settings.MISTRAL_API_KEY,
+    def get_client(self, api_key: str | None = None) -> Mistral:
+        if not self.has_client():
+            assert api_key
+            client = Mistral(
+                api_key=api_key,
                 async_client=default_httpx_client(),
             )
-        return MistralModel._client
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
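The client is no longer cached on the class, where one shared instance was built from the global MISTRAL_API_KEY; it now lives on each instance behind has_client()/assign_client()/super().get_client(), with the key resolved per instance via _get_default_api_key(). A minimal sketch of that base-class contract, inferred from this hunk alone (the real LLM base class in model_library/base is not shown here and may differ):

from typing import Any

class ClientCacheMixin:
    # Inferred sketch of the contract behind the new get_client().
    _client: Any = None

    def has_client(self) -> bool:
        return self._client is not None

    def assign_client(self, client: Any) -> None:
        self._client = client  # stored per instance, not on the class

    def get_client(self, api_key: str | None = None) -> Any:
        assert self._client is not None, "assign_client() must be called first"
        return self._client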
@@ -69,27 +80,30 @@ class MistralModel(LLM):
         content_user: list[dict[str, Any]] = []
 
         def flush_content_user():
-            nonlocal content_user
-
             if content_user:
-                new_input.append({"role": "user", "content": content_user})
-                content_user = []
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
 
         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"type": "text", "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case TextInput():
-                    content_user.append({"type": "text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case AssistantMessage():
-                    flush_content_user()
-                    new_input.append(item)
                 case ToolResult():
-                    flush_content_user()
                     new_input.append(
                         {
                             "role": "tool",
@@ -98,9 +112,12 @@ class MistralModel(LLM):
                             "tool_call_id": item.tool_call.id,
                         }
                     )
-                case _:
-                    raise BadInputError("Unsupported input type")
+                case RawResponse():
+                    new_input.append(item.response)
+                case RawInput():
+                    new_input.append(item.input)
 
+        # in case content user item is the last item
        flush_content_user()
 
        return new_input
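The NOTE in the hunk above is load-bearing: flush_content_user() now mutates content_user in place instead of rebinding it (which is why the nonlocal could be dropped), so the appended message must receive a copy before clear() empties the buffer. A standalone demonstration of the aliasing bug the .copy() avoids:

content: list[dict[str, str]] = [{"type": "text", "text": "hi"}]
messages: list[dict[str, object]] = []

# Without the copy, the message aliases the buffer and clear() guts it:
messages.append({"role": "user", "content": content})
content.clear()
assert messages[0]["content"] == []

# With .copy(), the message keeps a snapshot of the buffer:
content.append({"type": "text", "text": "hi"})
messages.append({"role": "user", "content": content.copy()})
content.clear()
assert messages[1]["content"] == [{"type": "text", "text": "hi"}]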
@@ -167,14 +184,13 @@ class MistralModel(LLM):
         raise NotImplementedError()
 
     @override
-    async def _query_impl(
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
-        query_logger: logging.Logger,
         **kwargs: object,
-    ) -> QueryResult:
+    ) -> dict[str, Any]:
         # mistral supports max 8 images, merge extra images into the 8th image
         input = trim_images(input, max_images=8)
 
@@ -192,12 +208,14 @@
 
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "messages": messages,
             "prompt_mode": "reasoning" if self.reasoning else None,
             "tools": tools,
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
@@ -205,8 +223,18 @@ class MistralModel(LLM):
                 body["top_p"] = self.top_p
 
         body.update(kwargs)
+        return body
 
-        start = time.time()
+    @override
+    async def _query_impl(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
+        **kwargs: object,
+    ) -> QueryResult:
+        body = await self.build_body(input, tools=tools, **kwargs)
 
         response: EventStreamAsync[
             CompletionEvent
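With body construction factored out of _query_impl(), the request payload can be adjusted or inspected without touching the streaming loop. A hypothetical subclass enabled by the split (the import path for MistralModel and the exact override shape are assumptions based on the hunks above):

from collections.abc import Sequence
from typing import Any

from typing_extensions import override

class PatchedMistralModel(MistralModel):
    @override
    async def build_body(
        self,
        input: Sequence[InputItem],
        *,
        tools: list[ToolDefinition],
        **kwargs: object,
    ) -> dict[str, Any]:
        body = await super().build_body(input, tools=tools, **kwargs)
        body["prompt_mode"] = None  # e.g. force non-reasoning mode for this run
        return body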
@@ -247,8 +275,6 @@ class MistralModel(LLM):
                     in_tokens += data.usage.prompt_tokens or 0
                     out_tokens += data.usage.completion_tokens or 0
 
-            self.logger.info(f"Finished in: {time.time() - start}")
-
         except Exception as e:
             self.logger.error(f"Error: {e}", exc_info=True)
             raise e
@@ -302,7 +328,7 @@ class MistralModel(LLM):
         return QueryResult(
             output_text=text,
             reasoning=reasoning or None,
-            history=[*input, message],
+            history=[*input, RawResponse(response=message)],
             tool_calls=tool_calls,
             metadata=QueryResultMetadata(
                 in_tokens=in_tokens,
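The history now carries the provider-native assistant message wrapped in RawResponse rather than a bare AssistantMessage; on the next turn, the input-conversion loop above unwraps it via case RawResponse(): new_input.append(item.response). A plausible shape for the wrappers, inferred only from how the hunks consume them (the real definitions live in model_library/base and may differ):

from dataclasses import dataclass
from typing import Any

@dataclass
class RawResponse:
    response: Any  # provider-native assistant message, e.g. AssistantMessage

@dataclass
class RawInput:
    input: Any  # provider-native request message, appended verbatim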