model-library 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
- model_library/base/base.py +237 -62
- model_library/base/delegate_only.py +86 -9
- model_library/base/input.py +10 -7
- model_library/base/output.py +48 -0
- model_library/base/utils.py +56 -7
- model_library/config/alibaba_models.yaml +44 -57
- model_library/config/all_models.json +253 -126
- model_library/config/kimi_models.yaml +30 -3
- model_library/config/openai_models.yaml +15 -23
- model_library/config/zai_models.yaml +24 -3
- model_library/exceptions.py +14 -77
- model_library/logging.py +6 -2
- model_library/providers/ai21labs.py +30 -14
- model_library/providers/alibaba.py +17 -8
- model_library/providers/amazon.py +119 -64
- model_library/providers/anthropic.py +184 -104
- model_library/providers/azure.py +22 -10
- model_library/providers/cohere.py +7 -7
- model_library/providers/deepseek.py +8 -8
- model_library/providers/fireworks.py +7 -8
- model_library/providers/google/batch.py +17 -13
- model_library/providers/google/google.py +130 -73
- model_library/providers/inception.py +7 -7
- model_library/providers/kimi.py +18 -8
- model_library/providers/minimax.py +30 -13
- model_library/providers/mistral.py +61 -35
- model_library/providers/openai.py +219 -93
- model_library/providers/openrouter.py +34 -0
- model_library/providers/perplexity.py +7 -7
- model_library/providers/together.py +7 -8
- model_library/providers/vals.py +16 -9
- model_library/providers/xai.py +157 -144
- model_library/providers/zai.py +38 -8
- model_library/register_models.py +4 -2
- model_library/registry_utils.py +39 -15
- model_library/retriers/__init__.py +0 -0
- model_library/retriers/backoff.py +73 -0
- model_library/retriers/base.py +225 -0
- model_library/retriers/token.py +427 -0
- model_library/retriers/utils.py +11 -0
- model_library/settings.py +1 -1
- model_library/utils.py +13 -35
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/METADATA +4 -3
- model_library-0.1.8.dist-info/RECORD +70 -0
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/WHEEL +1 -1
- model_library-0.1.6.dist-info/RECORD +0 -64
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {model_library-0.1.6.dist-info → model_library-0.1.8.dist-info}/top_level.txt +0 -0
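The most substantial addition is the new `model_library/retriers/` package (`backoff.py`, `base.py`, `token.py`), which is not expanded in the hunks below. As orientation only, here is a minimal sketch of the kind of exponential-backoff retrier a `backoff.py` module of this size typically holds; every name in it (`BackoffRetrier`, `max_retries`, `base_delay`) is an illustrative assumption, not the package's actual API:

```python
import asyncio
import random
from collections.abc import Awaitable, Callable
from typing import TypeVar

T = TypeVar("T")


class BackoffRetrier:
    """Hypothetical sketch: retry an async call with exponential backoff.

    The real model_library.retriers API is not shown in this diff and may
    differ substantially.
    """

    def __init__(self, max_retries: int = 5, base_delay: float = 1.0) -> None:
        self.max_retries = max_retries
        self.base_delay = base_delay

    async def run(self, fn: Callable[[], Awaitable[T]]) -> T:
        for attempt in range(self.max_retries + 1):
            try:
                return await fn()
            except Exception:
                if attempt == self.max_retries:
                    raise  # out of retries, surface the original error
                # Exponential delay with +/-50% jitter: base * 2^attempt
                delay = self.base_delay * (2**attempt)
                await asyncio.sleep(delay * random.uniform(0.5, 1.5))
        raise AssertionError("unreachable")
```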
The hunks below cover `model_library/providers/mistral.py` (+61 -35). Several removed lines are truncated in the rendered diff and are left as shown.

```diff
--- a/model_library/providers/mistral.py
+++ b/model_library/providers/mistral.py
@@ -1,10 +1,15 @@
 import io
 import logging
-import time
 from collections.abc import Sequence
 from typing import Any, Literal
 
-from mistralai import
+from mistralai import (
+    AssistantMessage,
+    ContentChunk,
+    Mistral,
+    TextChunk,
+    ThinkChunk,
+)
 from mistralai.models.completionevent import CompletionEvent
 from mistralai.models.toolcall import ToolCall as MistralToolCall
 from mistralai.utils.eventstreaming import EventStreamAsync
@@ -13,14 +18,16 @@ from typing_extensions import override
 from model_library import model_library_settings
 from model_library.base import (
     LLM,
+    FileBase,
     FileInput,
     FileWithBase64,
     FileWithId,
-    FileWithUrl,
     InputItem,
     LLMConfig,
     QueryResult,
     QueryResultMetadata,
+    RawInput,
+    RawResponse,
     TextInput,
     ToolBody,
     ToolCall,
```
```diff
@@ -39,16 +46,20 @@ from model_library.utils import default_httpx_client
 
 @register_provider("mistralai")
 class MistralModel(LLM):
-
+    @override
+    def _get_default_api_key(self) -> str:
+        return model_library_settings.MISTRAL_API_KEY
 
     @override
-    def get_client(self) -> Mistral:
-        if not
-
-
+    def get_client(self, api_key: str | None = None) -> Mistral:
+        if not self.has_client():
+            assert api_key
+            client = Mistral(
+                api_key=api_key,
                 async_client=default_httpx_client(),
             )
-
+            self.assign_client(client)
+        return super().get_client()
 
     def __init__(
         self,
```
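The `get_client` rework moves client caching into the base class: the provider builds the SDK client once, registers it with `assign_client`, and `super().get_client()` returns the cached instance on later calls. Only the hook names (`has_client`, `assign_client`, `_get_default_api_key`) are visible here; the base-class side is not in this diff, but the contract those calls imply looks roughly like this sketch:

```python
from typing import Any


class ClientCachingBase:
    """Sketch of the base-class contract implied by the hunk above; the
    actual model_library.base.LLM implementation is not shown in this diff."""

    def __init__(self) -> None:
        self._client: Any | None = None

    def has_client(self) -> bool:
        # True once a provider subclass has assigned its SDK client
        return self._client is not None

    def assign_client(self, client: Any) -> None:
        self._client = client

    def get_client(self, api_key: str | None = None) -> Any:
        # Providers construct and assign their client before delegating here
        assert self._client is not None, "client was never assigned"
        return self._client
```

On this reading, `_get_default_api_key` gives the base class a uniform way to resolve a per-provider key (`model_library_settings.MISTRAL_API_KEY` here) instead of each provider reading settings itself.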
```diff
@@ -69,27 +80,30 @@ class MistralModel(LLM):
         content_user: list[dict[str, Any]] = []
 
         def flush_content_user():
-            nonlocal content_user
-
             if content_user:
-
-
+                # NOTE: must make new object as we clear()
+                new_input.append({"role": "user", "content": content_user.copy()})
+                content_user.clear()
 
         for item in input:
+            if isinstance(item, TextInput):
+                content_user.append({"type": "text", "text": item.text})
+                continue
+
+            if isinstance(item, FileBase):
+                match item.type:
+                    case "image":
+                        parsed = await self.parse_image(item)
+                    case "file":
+                        parsed = await self.parse_file(item)
+                content_user.append(parsed)
+                continue
+
+            # non content user item
+            flush_content_user()
+
             match item:
-                case TextInput():
-                    content_user.append({"type": "text", "text": item.text})
-                case FileWithBase64() | FileWithUrl() | FileWithId():
-                    match item.type:
-                        case "image":
-                            content_user.append(await self.parse_image(item))
-                        case "file":
-                            content_user.append(await self.parse_file(item))
-                case AssistantMessage():
-                    flush_content_user()
-                    new_input.append(item)
                 case ToolResult():
-                    flush_content_user()
                     new_input.append(
                         {
                             "role": "tool",
@@ -98,9 +112,12 @@ class MistralModel(LLM):
                             "tool_call_id": item.tool_call.id,
                         }
                     )
-                case
-
+                case RawResponse():
+                    new_input.append(item.response)
+                case RawInput():
+                    new_input.append(item.input)
 
+        # in case content user item is the last item
         flush_content_user()
 
         return new_input
```
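The rewritten conversion loop replaces one large `match` with an accumulate-and-flush pattern: consecutive text and file items are buffered into a single user message, and the buffer is flushed whenever a non-user item (tool result, raw payload) appears, plus once after the loop for trailing content. Stripped of provider specifics, the pattern is (a self-contained sketch, not the library code):

```python
from typing import Any


def group_user_content(items: list[Any]) -> list[dict[str, Any]]:
    """Batch consecutive user-content items into single user messages."""
    messages: list[dict[str, Any]] = []
    buffer: list[dict[str, Any]] = []

    def flush() -> None:
        if buffer:
            # Copy before clearing: the same list object is reused each round
            messages.append({"role": "user", "content": buffer.copy()})
            buffer.clear()

    for item in items:
        if isinstance(item, str):  # stand-in for TextInput / FileBase items
            buffer.append({"type": "text", "text": item})
            continue
        flush()  # a non-user item closes the current user message
        messages.append(item)

    flush()  # trailing user content after the last non-user item
    return messages


# group_user_content(["a", "b", {"role": "tool"}, "c"]) yields
# [user(a, b), {"role": "tool"}, user(c)]
```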
```diff
@@ -167,14 +184,13 @@ class MistralModel(LLM):
         raise NotImplementedError()
 
     @override
-    async def
+    async def build_body(
         self,
         input: Sequence[InputItem],
         *,
         tools: list[ToolDefinition],
-        query_logger: logging.Logger,
         **kwargs: object,
-    ) ->
+    ) -> dict[str, Any]:
         # mistral supports max 8 images, merge extra images into the 8th image
         input = trim_images(input, max_images=8)
 
@@ -192,12 +208,14 @@ class MistralModel(LLM):
 
         body: dict[str, Any] = {
             "model": self.model_name,
-            "max_tokens": self.max_tokens,
             "messages": messages,
             "prompt_mode": "reasoning" if self.reasoning else None,
             "tools": tools,
         }
 
+        if self.max_tokens:
+            body["max_tokens"] = self.max_tokens
+
         if self.supports_temperature:
             if self.temperature is not None:
                 body["temperature"] = self.temperature
```
```diff
@@ -205,8 +223,18 @@ class MistralModel(LLM):
             body["top_p"] = self.top_p
 
         body.update(kwargs)
+        return body
 
-
+    @override
+    async def _query_impl(
+        self,
+        input: Sequence[InputItem],
+        *,
+        tools: list[ToolDefinition],
+        query_logger: logging.Logger,
+        **kwargs: object,
+    ) -> QueryResult:
+        body = await self.build_body(input, tools=tools, **kwargs)
 
         response: EventStreamAsync[
             CompletionEvent
```
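Together, the last two hunks split the old monolithic query method in two: `build_body` does pure request construction and `_query_impl` performs the call, so the request dict can be built and inspected without touching the network. They also switch `max_tokens` from an always-present key to a conditional one. The shape of that refactor, reduced to a standalone sketch (not the library code):

```python
from typing import Any


def build_body(
    model_name: str,
    messages: list[dict[str, Any]],
    max_tokens: int | None = None,
    temperature: float | None = None,
) -> dict[str, Any]:
    """Pure request construction: easy to unit-test in isolation."""
    body: dict[str, Any] = {"model": model_name, "messages": messages}
    # Optional fields are added only when set, mirroring the new
    # `if self.max_tokens:` guard instead of always sending a null value.
    if max_tokens:
        body["max_tokens"] = max_tokens
    if temperature is not None:
        body["temperature"] = temperature
    return body


assert "max_tokens" not in build_body("m", [])  # omitted when unset
```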
```diff
@@ -247,8 +275,6 @@ class MistralModel(LLM):
                     in_tokens += data.usage.prompt_tokens or 0
                     out_tokens += data.usage.completion_tokens or 0
 
-            self.logger.info(f"Finished in: {time.time() - start}")
-
         except Exception as e:
             self.logger.error(f"Error: {e}", exc_info=True)
             raise e
@@ -302,7 +328,7 @@ class MistralModel(LLM):
         return QueryResult(
             output_text=text,
             reasoning=reasoning or None,
-            history=[*input, message],
+            history=[*input, RawResponse(response=message)],
             tool_calls=tool_calls,
             metadata=QueryResultMetadata(
                 in_tokens=in_tokens,
```