mistralai 0.4.1__tar.gz → 0.4.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mistralai-0.4.1 → mistralai-0.4.2}/PKG-INFO +3 -1
- {mistralai-0.4.1 → mistralai-0.4.2}/pyproject.toml +2 -3
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/async_client.py +47 -8
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/client.py +51 -9
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/client_base.py +45 -20
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/constants.py +2 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/models/jobs.py +2 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/LICENSE +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/README.md +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/__init__.py +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/exceptions.py +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/files.py +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/jobs.py +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/models/__init__.py +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/models/chat_completion.py +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/models/common.py +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/models/embeddings.py +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/models/files.py +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/models/models.py +0 -0
- {mistralai-0.4.1 → mistralai-0.4.2}/src/mistralai/py.typed +0 -0
```diff
--- mistralai-0.4.1/PKG-INFO
+++ mistralai-0.4.2/PKG-INFO
@@ -1,10 +1,12 @@
 Metadata-Version: 2.1
 Name: mistralai
-Version: 0.4.1
+Version: 0.4.2
 Summary:
+License: Apache 2.0 License
 Author: Bam4d
 Author-email: bam4d@mistral.ai
 Requires-Python: >=3.9,<4.0
+Classifier: License :: Other/Proprietary License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
```
```diff
--- mistralai-0.4.1/pyproject.toml
+++ mistralai-0.4.2/pyproject.toml
@@ -1,9 +1,10 @@
 [tool.poetry]
 name = "mistralai"
-version = "0.4.1"
+version = "0.4.2"
 description = ""
 authors = ["Bam4d <bam4d@mistral.ai>"]
 readme = "README.md"
+license = "Apache 2.0 License"
 
 [tool.ruff]
 select = ["E", "F", "W", "Q", "I"]
@@ -39,5 +40,3 @@ pytest-asyncio = "^0.23.2"
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
-
-
```
```diff
--- mistralai-0.4.1/src/mistralai/async_client.py
+++ mistralai-0.4.2/src/mistralai/async_client.py
@@ -1,7 +1,7 @@
 import asyncio
 import posixpath
 from json import JSONDecodeError
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
 
 from httpx import (
     AsyncClient,
@@ -101,6 +101,7 @@ class MistralAsyncClient(ClientBase):
         stream: bool = False,
         attempt: int = 1,
         data: Optional[Dict[str, Any]] = None,
+        check_model_deprecation_headers_callback: Optional[Callable] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[Dict[str, Any], None]:
         accept_header = "text/event-stream" if stream else "application/json"
@@ -129,6 +130,8 @@
                     data=data,
                     **kwargs,
                 ) as response:
+                    if check_model_deprecation_headers_callback:
+                        check_model_deprecation_headers_callback(response.headers)
                     await self._check_streaming_response(response)
 
                     async for line in response.aiter_lines():
@@ -145,7 +148,8 @@
                     data=data,
                     **kwargs,
                 )
-
+                if check_model_deprecation_headers_callback:
+                    check_model_deprecation_headers_callback(response.headers)
                 yield await self._check_response(response)
 
         except ConnectError as e:
@@ -213,7 +217,12 @@
             response_format=response_format,
         )
 
-        single_response = self._request("post", request, "v1/chat/completions")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         async for response in single_response:
             return ChatCompletionResponse(**response)
@@ -267,7 +276,13 @@
             tool_choice=tool_choice,
             response_format=response_format,
         )
-        async_response = self._request("post", request, "v1/chat/completions", stream=True)
+        async_response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         async for json_response in async_response:
             yield ChatCompletionStreamResponse(**json_response)
@@ -284,7 +299,12 @@
             EmbeddingResponse: A response object containing the embeddings.
         """
         request = {"model": model, "input": input}
-        single_response = self._request("post", request, "v1/embeddings")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/embeddings",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         async for response in single_response:
             return EmbeddingResponse(**response)
@@ -341,7 +361,12 @@
         request = self._make_completion_request(
             prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
         )
-        single_response = self._request("post", request, "v1/fim/completions")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         async for response in single_response:
             return ChatCompletionResponse(**response)
@@ -376,9 +401,23 @@
             Dict[str, Any]: a response object containing the generated text.
         """
         request = self._make_completion_request(
-            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+            prompt,
+            model,
+            suffix,
+            temperature,
+            max_tokens,
+            top_p,
+            random_seed,
+            stop,
+            stream=True,
+        )
+        async_response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
         )
-        async_response = self._request("post", request, "v1/fim/completions", stream=True)
 
         async for json_response in async_response:
             yield ChatCompletionStreamResponse(**json_response)
```
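For orientation, here is a minimal consumption sketch of the streaming path touched above. The class, the `ChatMessage` model, and the `v1/chat/completions` endpoint appear in this diff; the `chat_stream` method name, the constructor argument, and the chunk attributes are assumed from the public 0.4.x API rather than shown in these hunks.

```python
# Hedged usage sketch, not part of the diff: consuming the async streaming chat
# path that now wires in the deprecation-header callback. Method name, constructor
# argument, and chunk attributes are assumed from the 0.4.x public API.
import asyncio

from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage


async def main() -> None:
    client = MistralAsyncClient(api_key="...")  # placeholder key
    # Each yielded item is a ChatCompletionStreamResponse, as in the hunks above.
    async for chunk in client.chat_stream(
        model="mistral-small-latest",  # placeholder model name
        messages=[ChatMessage(role="user", content="Hello")],
    ):
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())
```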
```diff
--- mistralai-0.4.1/src/mistralai/client.py
+++ mistralai-0.4.2/src/mistralai/client.py
@@ -1,7 +1,7 @@
 import posixpath
 import time
 from json import JSONDecodeError
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Union
 
 from httpx import Client, ConnectError, HTTPTransport, RequestError, Response
 
@@ -40,7 +40,9 @@ class MistralClient(ClientBase):
         super().__init__(endpoint, api_key, max_retries, timeout)
 
         self._client = Client(
-            follow_redirects=True, timeout=self._timeout, transport=HTTPTransport(retries=self._max_retries)
+            follow_redirects=True,
+            timeout=self._timeout,
+            transport=HTTPTransport(retries=self._max_retries),
        )
         self.files = FilesClient(self)
         self.jobs = JobsClient(self)
```
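The constructor hunk above reflows how the underlying httpx client is built. Note that `HTTPTransport(retries=...)` in httpx only re-attempts establishing a connection; it does not retry requests that already reached the server. A standalone httpx sketch of the same configuration follows; the URL and values are placeholders, not taken from the package.

```python
# Standalone httpx sketch of the configuration wired in above. Transport-level
# retries in httpx re-attempt connecting (e.g. DNS/connect errors); they do not
# retry requests that already reached the server.
import httpx

transport = httpx.HTTPTransport(retries=5)  # placeholder retry count
client = httpx.Client(
    follow_redirects=True,
    timeout=120.0,  # seconds; the SDK passes its own configured timeout here
    transport=transport,
)
response = client.get("https://api.mistral.ai/v1/models")  # placeholder request
print(response.status_code)
```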
```diff
--- mistralai-0.4.1/src/mistralai/client.py
+++ mistralai-0.4.2/src/mistralai/client.py
@@ -94,6 +96,7 @@
         stream: bool = False,
         attempt: int = 1,
         data: Optional[Dict[str, Any]] = None,
+        check_model_deprecation_headers_callback: Optional[Callable] = None,
         **kwargs: Any,
     ) -> Iterator[Dict[str, Any]]:
         accept_header = "text/event-stream" if stream else "application/json"
@@ -122,6 +125,8 @@
                     data=data,
                     **kwargs,
                 ) as response:
+                    if check_model_deprecation_headers_callback:
+                        check_model_deprecation_headers_callback(response.headers)
                     self._check_streaming_response(response)
 
                     for line in response.iter_lines():
@@ -138,7 +143,8 @@
                     data=data,
                     **kwargs,
                 )
-
+                if check_model_deprecation_headers_callback:
+                    check_model_deprecation_headers_callback(response.headers)
                 yield self._check_response(response)
 
         except ConnectError as e:
@@ -207,7 +213,12 @@
             response_format=response_format,
         )
 
-        single_response = self._request("post", request, "v1/chat/completions")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         for response in single_response:
             return ChatCompletionResponse(**response)
@@ -261,7 +272,13 @@
             response_format=response_format,
         )
 
-        response = self._request("post", request, "v1/chat/completions", stream=True)
+        response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         for json_streamed_response in response:
             yield ChatCompletionStreamResponse(**json_streamed_response)
@@ -278,7 +295,12 @@
             EmbeddingResponse: A response object containing the embeddings.
         """
         request = {"model": model, "input": input}
-        singleton_response = self._request("post", request, "v1/embeddings")
+        singleton_response = self._request(
+            "post",
+            request,
+            "v1/embeddings",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         for response in singleton_response:
             return EmbeddingResponse(**response)
@@ -337,7 +359,13 @@
             prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
         )
 
-        single_response = self._request("post", request, "v1/fim/completions", stream=False)
+        single_response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            stream=False,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         for response in single_response:
             return ChatCompletionResponse(**response)
@@ -372,10 +400,24 @@
             Iterable[Dict[str, Any]]: a generator that yields response objects containing the generated text.
         """
         request = self._make_completion_request(
-            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+            prompt,
+            model,
+            suffix,
+            temperature,
+            max_tokens,
+            top_p,
+            random_seed,
+            stop,
+            stream=True,
         )
 
-        response = self._request("post", request, "v1/fim/completions", stream=True)
+        response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         for json_streamed_response in response:
             yield ChatCompletionStreamResponse(**json_streamed_response)
```
```diff
--- mistralai-0.4.1/src/mistralai/client_base.py
+++ mistralai-0.4.2/src/mistralai/client_base.py
@@ -1,16 +1,21 @@
 import logging
 import os
 from abc import ABC
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import orjson
-
-from mistralai.exceptions import (
-    MistralException,
+from httpx import Headers
+
+from mistralai.constants import HEADER_MODEL_DEPRECATION_TIMESTAMP
+from mistralai.exceptions import MistralException
+from mistralai.models.chat_completion import (
+    ChatMessage,
+    Function,
+    ResponseFormat,
+    ToolChoice,
 )
-from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
 
-CLIENT_VERSION = "0.4.1"
+CLIENT_VERSION = "0.4.2"
 
 
 class ClientBase(ABC):
@@ -38,6 +43,14 @@ class ClientBase(ABC):
 
         self._version = CLIENT_VERSION
 
+    def _get_model(self, model: Optional[str] = None) -> str:
+        if model is not None:
+            return model
+        else:
+            if self._default_model is None:
+                raise MistralException(message="model must be provided")
+            return self._default_model
+
     def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         parsed_tools: List[Dict[str, Any]] = []
         for tool in tools:
```
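The new `_get_model` helper centralises the model-or-default resolution that the later `client_base.py` hunks remove from both request builders. A behavioural sketch follows; the method is private, the values are placeholders, and `_default_model` is only populated for certain endpoint configurations.

```python
# Behavioural sketch of the resolution order implemented by _get_model:
# an explicit model wins, otherwise the client's default model is used,
# otherwise MistralException("model must be provided") is raised.
from mistralai.client import MistralClient

client = MistralClient(api_key="...")        # placeholder; no default model configured here
print(client._get_model("open-mistral-7b"))  # -> "open-mistral-7b"
client._get_model(None)                      # raises MistralException when no default model exists
```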
```diff
--- mistralai-0.4.1/src/mistralai/client_base.py
+++ mistralai-0.4.2/src/mistralai/client_base.py
@@ -73,6 +86,22 @@
 
         return parsed_messages
 
+    def _check_model_deprecation_header_callback_factory(self, model: Optional[str] = None) -> Callable:
+        model = self._get_model(model)
+
+        def _check_model_deprecation_header_callback(
+            headers: Headers,
+        ) -> None:
+            if HEADER_MODEL_DEPRECATION_TIMESTAMP in headers:
+                self._logger.warning(
+                    f"WARNING: The model {model} is deprecated "
+                    f"and will be removed on {headers[HEADER_MODEL_DEPRECATION_TIMESTAMP]}. "
+                    "Please refer to https://docs.mistral.ai/getting-started/models/#api-versioning "
+                    "for more information."
+                )
+
+        return _check_model_deprecation_header_callback
+
     def _make_completion_request(
         self,
         prompt: str,
```
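A hedged illustration of the factory added above: it captures the resolved model name and returns a callback that `_request` later invokes with the response headers. Only the `HEADER_MODEL_DEPRECATION_TIMESTAMP` constant is used below, since its literal value lives in `constants.py`, whose two added lines are not rendered in this diff; the client construction and date are placeholders.

```python
# Hypothetical exercise of the callback, not part of the diff. When the headers
# contain the deprecation-timestamp header, the callback logs a warning through
# the client's logger; otherwise it does nothing and returns None.
from httpx import Headers

from mistralai.client import MistralClient
from mistralai.constants import HEADER_MODEL_DEPRECATION_TIMESTAMP

client = MistralClient(api_key="...")  # placeholder key
callback = client._check_model_deprecation_header_callback_factory("open-mistral-7b")

callback(Headers({HEADER_MODEL_DEPRECATION_TIMESTAMP: "2024-11-30"}))  # logs a deprecation warning
callback(Headers({}))  # no deprecation header, so no warning
```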
```diff
--- mistralai-0.4.1/src/mistralai/client_base.py
+++ mistralai-0.4.2/src/mistralai/client_base.py
@@ -95,16 +124,14 @@
         if stop is not None:
             request_data["stop"] = stop
 
-        if model is not None:
-            request_data["model"] = model
-        else:
-            if self._default_model is None:
-                raise MistralException(message="model must be provided")
-            request_data["model"] = self._default_model
+        request_data["model"] = self._get_model(model)
 
         request_data.update(
             self._build_sampling_params(
-                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                random_seed=random_seed,
             )
         )
 
@@ -148,16 +175,14 @@
             "messages": self._parse_messages(messages),
         }
 
-        if model is not None:
-            request_data["model"] = model
-        else:
-            if self._default_model is None:
-                raise MistralException(message="model must be provided")
-            request_data["model"] = self._default_model
+        request_data["model"] = self._get_model(model)
 
         request_data.update(
             self._build_sampling_params(
-                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                random_seed=random_seed,
             )
         )
 
```
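Taken together, the 0.4.2 changes mean a regular call against a model the API flags as deprecated now surfaces a logged warning. An end-to-end sketch assuming the public 0.4.x method names and response shape; the model name and key are placeholders.

```python
# End-to-end sketch, assuming the public 0.4.x API. With 0.4.2, if the response
# carries the model-deprecation header, the SDK logs a warning like
# "The model ... is deprecated and will be removed on ...".
import logging

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

logging.basicConfig(level=logging.WARNING)  # make the SDK's warning visible

client = MistralClient(api_key="...")  # placeholder key
response = client.chat(
    model="mistral-small-latest",  # placeholder model name
    messages=[ChatMessage(role="user", content="Hello")],
)
print(response.choices[0].message.content)
```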