mistralai 0.0.8__py3-none-any.whl → 0.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/async_client.py +93 -190
- mistralai/client.py +77 -65
- mistralai/client_base.py +54 -18
- mistralai/exceptions.py +3 -12
- mistralai/py.typed +0 -0
- {mistralai-0.0.8.dist-info → mistralai-0.0.9.dist-info}/METADATA +27 -8
- mistralai-0.0.9.dist-info/RECORD +16 -0
- mistralai-0.0.8.dist-info/RECORD +0 -15
- {mistralai-0.0.8.dist-info → mistralai-0.0.9.dist-info}/LICENSE +0 -0
- {mistralai-0.0.8.dist-info → mistralai-0.0.9.dist-info}/WHEEL +0 -0
mistralai/async_client.py
CHANGED

@@ -1,20 +1,23 @@
-import asyncio
-import logging
 import os
 import posixpath
 import time
-from collections import defaultdict
 from json import JSONDecodeError
-from typing import Any, AsyncGenerator, …
-
-import aiohttp
-import backoff
-import orjson
+from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+
+from httpx import (
+    AsyncClient,
+    AsyncHTTPTransport,
+    ConnectError,
+    Limits,
+    RequestError,
+    Response,
+)
 
 from mistralai.client_base import ClientBase
-from mistralai.constants import ENDPOINT, RETRY_STATUS_CODES
+from mistralai.constants import ENDPOINT
 from mistralai.exceptions import (
     MistralAPIException,
+    MistralAPIStatusException,
     MistralConnectionException,
     MistralException,
 )

@@ -27,136 +30,6 @@ from mistralai.models.embeddings import EmbeddingResponse
 from mistralai.models.models import ModelList
 
 
-class AIOHTTPBackend:
-    """HTTP backend which handles retries, concurrency limiting and logging"""
-
-    SLEEP_AFTER_FAILURE = defaultdict(lambda: 0.25, {429: 5.0})
-
-    _requester: Callable[..., Awaitable[aiohttp.ClientResponse]]
-    _semaphore: asyncio.Semaphore
-    _session: Optional[aiohttp.ClientSession]
-
-    def __init__(
-        self,
-        max_concurrent_requests: int = 64,
-        max_retries: int = 5,
-        timeout: int = 120,
-    ):
-        self._logger = logging.getLogger(__name__)
-        self._timeout = timeout
-        self._max_retries = max_retries
-        self._session = None
-        self._max_concurrent_requests = max_concurrent_requests
-
-    def build_aio_requester(
-        self,
-    ) -> Callable:  # returns a function for retryable requests
-        @backoff.on_exception(
-            backoff.expo,
-            (aiohttp.ClientError, aiohttp.ClientResponseError),
-            max_tries=self._max_retries + 1,
-            max_time=self._timeout,
-        )
-        async def make_request_fn(
-            session: aiohttp.ClientSession, *args: Any, **kwargs: Any
-        ) -> aiohttp.ClientResponse:
-            async with self._semaphore:  # this limits total concurrency by the client
-                response = await session.request(*args, **kwargs)
-                if (
-                    response.status in RETRY_STATUS_CODES
-                ):  # likely temporary, raise to retry
-                    self._logger.info(f"Received status {response.status}, retrying...")
-                    await asyncio.sleep(self.SLEEP_AFTER_FAILURE[response.status])
-                    response.raise_for_status()
-
-            return response
-
-        return make_request_fn
-
-    async def request(
-        self,
-        url: str,
-        json: Optional[Dict[str, Any]] = None,
-        method: str = "post",
-        headers: Optional[Dict[str, Any]] = None,
-        session: Optional[aiohttp.ClientSession] = None,
-        params: Optional[Dict[str, Any]] = None,
-        **kwargs: Any,
-    ) -> aiohttp.ClientResponse:
-        session = session or await self.session()
-        self._logger.debug(f"Making request to {url} with content {json}")
-
-        request_start = time.time()
-        try:
-            response = await self._requester(
-                session,
-                method,
-                url,
-                headers=headers,
-                json=json,
-                params=params,
-                **kwargs,
-            )
-        except (
-            aiohttp.ClientConnectionError
-        ) as e:  # ensure the SDK user does not have to deal with knowing aiohttp
-            self._logger.debug(
-                f"Fatal connection error after {time.time()-request_start:.1f}s: {e}"
-            )
-            raise MistralConnectionException(str(e)) from e
-        except (
-            aiohttp.ClientResponseError
-        ) as e:  # status 500 or something remains after retries
-            self._logger.debug(
-                f"Fatal ClientResponseError error after {time.time()-request_start:.1f}s: {e}"
-            )
-            raise MistralConnectionException(str(e)) from e
-        except asyncio.TimeoutError as e:
-            self._logger.debug(
-                f"Fatal timeout error after {time.time()-request_start:.1f}s: {e}"
-            )
-            raise MistralConnectionException("The request timed out") from e
-        except Exception as e:  # Anything caught here should be added above
-            self._logger.debug(
-                f"Unexpected fatal error after {time.time()-request_start:.1f}s: {e}"
-            )
-            raise MistralException(
-                f"Unexpected exception ({e.__class__.__name__}): {e}"
-            ) from e
-
-        self._logger.debug(
-            f"Received response with status {response.status} after {time.time()-request_start:.1f}s"
-        )
-        return response
-
-    async def session(self) -> aiohttp.ClientSession:
-        if self._session is None:
-            self._session = aiohttp.ClientSession(
-                timeout=aiohttp.ClientTimeout(self._timeout),
-                connector=aiohttp.TCPConnector(limit=0),
-            )
-            self._semaphore = asyncio.Semaphore(self._max_concurrent_requests)
-            self._requester = self.build_aio_requester()
-        return self._session
-
-    async def close(self) -> None:
-        if self._session is not None:
-            await self._session.close()
-            self._session = None
-
-    def __del__(self) -> None:
-        # https://stackoverflow.com/questions/54770360/how-can-i-wait-for-an-objects-del-to-finish-before-the-async-loop-closes
-        if self._session:
-            try:
-                loop = asyncio.get_event_loop()
-                if loop.is_running():
-                    loop.create_task(self.close())
-                else:
-                    loop.run_until_complete(self.close())
-            except Exception:
-                pass
-
-
 class MistralAsyncClient(ClientBase):
     def __init__(
         self,

@@ -168,14 +41,15 @@ class MistralAsyncClient(ClientBase):
     ):
         super().__init__(endpoint, api_key, max_retries, timeout)
 
-        self.…
-        …
-            max_retries=max_retries,
+        self._client = AsyncClient(
+            follow_redirects=True,
             timeout=timeout,
+            limits=Limits(max_connections=max_concurrent_requests),
+            transport=AsyncHTTPTransport(retries=max_retries),
         )
 
     async def close(self) -> None:
-        await self.…
+        await self._client.aclose()

@@ -183,9 +57,8 @@ class MistralAsyncClient(ClientBase):
         json: Dict[str, Any],
         path: str,
         stream: bool = False,
-        …
-    ) -> …
-        …
+        attempt: int = 1,
+    ) -> AsyncGenerator[Dict[str, Any], None]:
         headers = {
             "Authorization": f"Bearer {self._api_key}",
             "Content-Type": "application/json",

@@ -193,27 +66,60 @@ class MistralAsyncClient(ClientBase):
 
         url = posixpath.join(self._endpoint, path)
 
-        …
-        …
-        …
-        if stream:
-            return response
+        self._logger.debug(f"Sending request: {method} {url} {json}")
+
+        response: Response
 
         try:
-            …
-            …
-            …
-            …
-            …
-            …
-            …
+            if stream:
+                async with self._client.stream(
+                    method,
+                    url,
+                    headers=headers,
+                    json=json,
+                ) as response:
+                    self._check_streaming_response(response)
+
+                    async for line in response.aiter_lines():
+                        json_streamed_response = self._process_line(line)
+                        if json_streamed_response:
+                            yield json_streamed_response
+
+            else:
+                response = await self._client.request(
+                    method,
+                    url,
+                    headers=headers,
+                    json=json,
+                )
+
+                yield self._check_response(response)
+
+        except ConnectError as e:
+            raise MistralConnectionException(str(e)) from e
+        except RequestError as e:
+            raise MistralException(
+                f"Unexpected exception ({e.__class__.__name__}): {e}"
+            ) from e
+        except JSONDecodeError as e:
+            raise MistralAPIException.from_response(
                 response,
-                message=f"…
-            )
-            …
-            …
-            …
-            …
+                message=f"Failed to decode json body: {response.text}",
+            ) from e
+        except MistralAPIStatusException as e:
+            attempt += 1
+            if attempt > self._max_retries:
+                raise MistralAPIStatusException.from_response(
+                    response, message=str(e)
+                ) from e
+            backoff = 2.0**attempt  # exponential backoff
+            time.sleep(backoff)
+
+            # Retry as a generator
+            async for r in self._request(
+                method, json, path, stream=stream, attempt=attempt
+            ):
+                yield r
 
     async def chat(
         self,

@@ -225,7 +131,7 @@ class MistralAsyncClient(ClientBase):
         random_seed: Optional[int] = None,
         safe_mode: bool = False,
     ) -> ChatCompletionResponse:
-        """…
+        """A asynchronous chat endpoint that returns a single response.
 
         Args:
             model (str): model the name of the model to chat with, e.g. mistral-tiny

@@ -252,9 +158,12 @@ class MistralAsyncClient(ClientBase):
             safe_mode=safe_mode,
         )
 
-        …
-        …
-        …
+        single_response = self._request("post", request, "v1/chat/completions")
+
+        async for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
 
     async def chat_stream(
         self,

@@ -266,7 +175,7 @@ class MistralAsyncClient(ClientBase):
         random_seed: Optional[int] = None,
         safe_mode: bool = False,
     ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
-        """…
+        """An Asynchronous chat endpoint that streams responses.
 
         Args:
             model (str): model the name of the model to chat with, e.g. mistral-tiny

@@ -294,24 +203,12 @@ class MistralAsyncClient(ClientBase):
             stream=True,
             safe_mode=safe_mode,
         )
-        async_response = …
+        async_response = self._request(
             "post", request, "v1/chat/completions", stream=True
         )
 
-        …
-        …), "Bad response from _request"
-
-        async with async_response as response:
-            async for line in response.content:
-                if line == b"\n":
-                    continue
-
-                if line.startswith(b"data: "):
-                    line = line[6:].strip()
-                    if line != b"[DONE]":
-                        json_response = orjson.loads(line)
-                        yield ChatCompletionStreamResponse(**json_response)
+        async for json_response in async_response:
+            yield ChatCompletionStreamResponse(**json_response)
 
     async def embeddings(
         self, model: str, input: Union[str, List[str]]

@@ -327,9 +224,12 @@ class MistralAsyncClient(ClientBase):
             EmbeddingResponse: A response object containing the embeddings.
         """
         request = {"model": model, "input": input}
-        …
-        …
-        …
+        single_response = self._request("post", request, "v1/embeddings")
+
+        async for response in single_response:
+            return EmbeddingResponse(**response)
+
+        raise MistralException("No response received")

@@ -337,6 +237,9 @@ class MistralAsyncClient(ClientBase):
         Returns:
             ModelList: A response object containing the list of models.
         """
-        …
-        …
-        …
+        single_response = self._request("get", {}, "v1/models")
+
+        async for response in single_response:
+            return ModelList(**response)
+
+        raise MistralException("No response received")
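The net effect of this rewrite: the `AIOHTTPBackend` class is gone, `httpx.AsyncClient` handles transport, and `_request` is now an `AsyncGenerator`, so non-streaming endpoints (`chat`, `embeddings`, `list_models`) drain exactly one item from it while `chat_stream` forwards every item. A minimal usage sketch of the 0.0.9 async client follows; the class, method, and model names are taken from this diff, while the prompts and the response fields accessed (`choices`, `message.content`, `delta.content`) are assumptions based on the package's response models.

```python
import asyncio
import os

from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage


async def main() -> None:
    client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

    # Non-streaming: chat() drains a single item from the _request generator.
    response = await client.chat(
        model="mistral-tiny",
        messages=[ChatMessage(role="user", content="Say hello.")],
    )
    print(response.choices[0].message.content)

    # Streaming: chat_stream() yields one parsed SSE chunk at a time.
    async for chunk in client.chat_stream(
        model="mistral-tiny",
        messages=[ChatMessage(role="user", content="Count to three.")],
    ):
        print(chunk.choices[0].delta.content or "", end="")

    await client.close()  # closes the underlying httpx.AsyncClient


asyncio.run(main())
```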
mistralai/client.py
CHANGED

@@ -1,18 +1,16 @@
 import os
 import posixpath
+import time
 from json import JSONDecodeError
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
 
-import orjson
-import requests
-from requests import Response
-from requests.adapters import HTTPAdapter
-from urllib3.util.retry import Retry
+from httpx import Client, ConnectError, HTTPTransport, RequestError, Response
 
 from mistralai.client_base import ClientBase
-from mistralai.constants import ENDPOINT, RETRY_STATUS_CODES
+from mistralai.constants import ENDPOINT
 from mistralai.exceptions import (
     MistralAPIException,
+    MistralAPIStatusException,
     MistralConnectionException,
     MistralException,
 )

@@ -39,14 +37,22 @@ class MistralClient(ClientBase):
     ):
         super().__init__(endpoint, api_key, max_retries, timeout)
 
+        self._client = Client(
+            follow_redirects=True,
+            timeout=self._timeout,
+            transport=HTTPTransport(retries=self._max_retries))
+
+    def __del__(self) -> None:
+        self._client.close()
+
     def _request(
         self,
         method: str,
         json: Dict[str, Any],
         path: str,
         stream: bool = False,
-        …
-    ) -> …
+        attempt: int = 1,
+    ) -> Iterator[Dict[str, Any]]:
         headers = {
             "Authorization": f"Bearer {self._api_key}",
             "Content-Type": "application/json",

@@ -54,49 +60,58 @@ class MistralClient(ClientBase):
 
         url = posixpath.join(self._endpoint, path)
 
-        session = requests.Session()
-        retries = Retry(
-            total=self._max_retries,
-            backoff_factor=0.5,
-            allowed_methods=["POST", "GET"],
-            status_forcelist=RETRY_STATUS_CODES,
-            raise_on_status=False,
-        )
-        session.mount("https://", HTTPAdapter(max_retries=retries))
-        session.mount("http://", HTTPAdapter(max_retries=retries))
+        self._logger.debug(f"Sending request: {method} {url} {json}")
+
+        response: Response
 
+        try:
             if stream:
-                …
-                method,
-                …
+                with self._client.stream(
+                    method,
+                    url,
+                    headers=headers,
+                    json=json,
+                ) as response:
+                    self._check_streaming_response(response)
+
+                    for line in response.iter_lines():
+                        json_streamed_response = self._process_line(line)
+                        if json_streamed_response:
+                            yield json_streamed_response
 
-            …
-            response = …
+            else:
+                response = self._client.request(
                     method,
                     url,
                     headers=headers,
                     json=json,
-                    timeout=self._timeout,
-                    params=params,
                 )
-        except requests.exceptions.ConnectionError as e:
-            raise MistralConnectionException(str(e)) from e
-        except requests.exceptions.RequestException as e:
-            raise MistralException(
-                f"Unexpected exception ({e.__class__.__name__}): {e}"
-            ) from e
 
-        …
-        …
-        …
-        …
-        …
-        …
+                yield self._check_response(response)
+
+        except ConnectError as e:
+            raise MistralConnectionException(str(e)) from e
+        except RequestError as e:
+            raise MistralException(
+                f"Unexpected exception ({e.__class__.__name__}): {e}"
+            ) from e
+        except JSONDecodeError as e:
+            raise MistralAPIException.from_response(
+                response,
+                message=f"Failed to decode json body: {response.text}",
+            ) from e
+        except MistralAPIStatusException as e:
+            attempt += 1
+            if attempt > self._max_retries:
+                raise MistralAPIStatusException.from_response(
+                    response, message=str(e)
+                ) from e
+            backoff = 2.0**attempt  # exponential backoff
+            time.sleep(backoff)
 
-            …
-            …
-            …
-            return json_response
+            # Retry as a generator
+            for r in self._request(method, json, path, stream=stream, attempt=attempt):
+                yield r
 
     def chat(
         self,

@@ -108,7 +123,7 @@ class MistralClient(ClientBase):
         random_seed: Optional[int] = None,
         safe_mode: bool = False,
     ) -> ChatCompletionResponse:
-        """…
+        """A chat endpoint that returns a single response.
 
         Args:
             model (str): model the name of the model to chat with, e.g. mistral-tiny

@@ -135,11 +150,12 @@ class MistralClient(ClientBase):
             safe_mode=safe_mode,
         )
 
-        …
+        single_response = self._request("post", request, "v1/chat/completions")
 
-        …
+        for response in single_response:
+            return ChatCompletionResponse(**response)
 
-        …
+        raise MistralException("No response received")
 
     def chat_stream(
         self,

@@ -151,7 +167,7 @@ class MistralClient(ClientBase):
         random_seed: Optional[int] = None,
         safe_mode: bool = False,
     ) -> Iterable[ChatCompletionStreamResponse]:
-        """…
+        """A chat endpoint that streams responses.
 
         Args:
             model (str): model the name of the model to chat with, e.g. mistral-tiny

@@ -181,18 +197,8 @@ class MistralClient(ClientBase):
 
         response = self._request("post", request, "v1/chat/completions", stream=True)
 
-        …
-        …
-        for line in response.iter_lines():
-            self._logger.debug(f"Received line: {line}")
-            if line == b"\n":
-                continue
-
-            if line.startswith(b"data: "):
-                line = line[6:].strip()
-                if line != b"[DONE]":
-                    json_response = orjson.loads(line)
-                    yield ChatCompletionStreamResponse(**json_response)
+        for json_streamed_response in response:
+            yield ChatCompletionStreamResponse(**json_streamed_response)
 
     def embeddings(self, model: str, input: Union[str, List[str]]) -> EmbeddingResponse:
         """An embeddings endpoint that returns embeddings for a single, or batch of inputs

@@ -206,9 +212,12 @@ class MistralClient(ClientBase):
             EmbeddingResponse: A response object containing the embeddings.
         """
         request = {"model": model, "input": input}
-        …
-        …
-        …
+        singleton_response = self._request("post", request, "v1/embeddings")
+
+        for response in singleton_response:
+            return EmbeddingResponse(**response)
+
+        raise MistralException("No response received")
 
     def list_models(self) -> ModelList:
         """Returns a list of the available models

@@ -216,6 +225,9 @@ class MistralClient(ClientBase):
         Returns:
             ModelList: A response object containing the list of models.
         """
-        …
-        …
-        …
+        singleton_response = self._request("get", {}, "v1/models")
+
+        for response in singleton_response:
+            return ModelList(**response)
+
+        raise MistralException("No response received")
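One design choice both clients now share: retries no longer live in the transport layer (`urllib3.util.retry.Retry` here, `backoff` on the async side). Instead `_request` is a generator that catches `MistralAPIStatusException`, sleeps `2.0**attempt` seconds, and recurses into itself, forwarding whatever the retry yields. A self-contained toy of that pattern; every name in it is hypothetical, standing in for the SDK's internals.

```python
import time
from typing import Any, Dict, Iterator


class RetryableError(Exception):
    """Stands in for MistralAPIStatusException (a retryable status code)."""


def flaky_call(attempt: int) -> Dict[str, Any]:
    # Pretend the first two attempts hit a retryable status such as 429.
    if attempt < 3:
        raise RetryableError(f"status 429 on attempt {attempt}")
    return {"object": "chat.completion", "attempt": attempt}


def request(max_retries: int = 5, attempt: int = 1) -> Iterator[Dict[str, Any]]:
    try:
        yield flaky_call(attempt)
    except RetryableError:
        attempt += 1
        if attempt > max_retries:
            raise  # retries exhausted: let the exception escape, as the SDK does
        time.sleep(2.0**attempt / 1000)  # exponential backoff, scaled down for the demo
        # Retry as a generator: recurse and forward whatever the retry yields.
        for r in request(max_retries, attempt=attempt):
            yield r


for item in request():
    print(item)  # {'object': 'chat.completion', 'attempt': 3}
```

Note that in the async client the same except-branch calls the blocking `time.sleep` rather than `asyncio.sleep`, so a retry stalls the event loop for the whole backoff interval.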
mistralai/client_base.py
CHANGED

@@ -1,10 +1,24 @@
 import logging
+import os
 from abc import ABC
 from typing import Any, Dict, List, Optional
 
-…
+import orjson
+from httpx import Response
+
+from mistralai.constants import RETRY_STATUS_CODES
+from mistralai.exceptions import (
+    MistralAPIException,
+    MistralAPIStatusException,
+    MistralException,
+)
 from mistralai.models.chat_completion import ChatMessage
 
+logging.basicConfig(
+    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+    level=os.getenv("LOG_LEVEL", "ERROR"),
+)
+
 
 class ClientBase(ABC):
     def __init__(

@@ -21,8 +35,8 @@ class ClientBase(ABC):
         self._api_key = api_key
         self._logger = logging.getLogger(__name__)
 
-    @staticmethod
     def _make_chat_request(
+        self,
         model: str,
         messages: List[ChatMessage],
         temperature: Optional[float] = None,

@@ -48,26 +62,48 @@ class ClientBase(ABC):
         if stream is not None:
             request_data["stream"] = stream
 
+        self._logger.debug(f"Chat request: {request_data}")
+
         return request_data
 
-    def …
-        …
-        …
+    def _check_response_status_codes(self, response: Response) -> None:
+        if response.status_code in RETRY_STATUS_CODES:
+            raise MistralAPIStatusException.from_response(
+                response,
+                message=f"Cannot stream response. Status: {response.status_code}",
+            )
+        elif 400 <= response.status_code < 500:
+            raise MistralAPIException.from_response(
+                response,
+                message=f"Cannot stream response. Status: {response.status_code}",
+            )
+        elif response.status_code >= 500:
+            raise MistralException(
+                message=f"Unexpected server error (status {response.status_code})"
+            )
+
+    def _check_streaming_response(self, response: Response) -> None:
+        self._check_response_status_codes(response)
+
+    def _check_response(self, response: Response) -> Dict[str, Any]:
+        self._check_response_status_codes(response)
+
+        json_response: Dict[str, Any] = response.json()
+
         if "object" not in json_response:
             raise MistralException(message=f"Unexpected response: {json_response}")
         if "error" == json_response["object"]:  # has errors
-            raise MistralAPIException(
+            raise MistralAPIException.from_response(
+                response,
                 message=json_response["message"],
-                http_status=status,
-                headers=headers,
-            )
-        if 400 <= status < 500:
-            raise MistralAPIException(
-                message=f"Unexpected client error (status {status}): {json_response}",
-                http_status=status,
-                headers=headers,
-            )
-        if status >= 500:
-            raise MistralException(
-                message=f"Unexpected server error (status {status}): {json_response}"
             )
+
+        return json_response
+
+    def _process_line(self, line: str) -> Optional[Dict[str, Any]]:
+        if line.startswith("data: "):
+            line = line[6:].strip()
+            if line != "[DONE]":
+                json_streamed_response: Dict[str, Any] = orjson.loads(line)
+                return json_streamed_response
+        return None
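`_process_line` above is the piece both clients now share for server-sent events: strip the `data: ` prefix, skip the `[DONE]` sentinel (and any line without the prefix, such as blank keep-alives), and hand back the decoded chunk. A standalone check of that behaviour, with stdlib `json` standing in for the `orjson` dependency the SDK uses:

```python
import json
from typing import Any, Dict, Optional


def process_line(line: str) -> Optional[Dict[str, Any]]:
    # Mirrors ClientBase._process_line, with json standing in for orjson.
    if line.startswith("data: "):
        line = line[6:].strip()
        if line != "[DONE]":
            return json.loads(line)
    return None


assert process_line('data: {"object": "chat.completion.chunk", "id": "abc"}') == {
    "object": "chat.completion.chunk",
    "id": "abc",
}
assert process_line("data: [DONE]") is None  # end-of-stream sentinel
assert process_line("") is None  # blank keep-alive lines between events
```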
mistralai/exceptions.py
CHANGED

@@ -2,8 +2,7 @@ from __future__ import annotations
 
 from typing import Any, Dict, Optional
 
-import aiohttp
-from requests import Response
+from httpx import Response
 
 
 class MistralException(Exception):

@@ -45,19 +44,11 @@ class MistralAPIException(MistralException):
             headers=dict(response.headers),
         )
 
-    @classmethod
-    def from_aio_response(
-        cls, response: aiohttp.ClientResponse, message: Optional[str] = None
-    ) -> MistralAPIException:
-        return cls(
-            message=message,
-            http_status=response.status,
-            headers=dict(response.headers),
-        )
-
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})"
 
+
+class MistralAPIStatusException(MistralAPIException):
+    """Returned when we receive a non-200 response from the API that we should retry"""
+
 
 class MistralConnectionException(MistralException):
     """Returned when the SDK can not reach the API server for any reason"""
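Since `MistralAPIStatusException` subclasses `MistralAPIException`, the new retry signal only reaches callers once `_request` has exhausted its retries, and an existing `except MistralAPIException` still covers it. A caller-side sketch; the class names come from this diff, and `http_status` is the attribute the `__repr__` above exposes:

```python
from mistralai.client import MistralClient
from mistralai.exceptions import MistralAPIException, MistralConnectionException
from mistralai.models.chat_completion import ChatMessage

client = MistralClient(api_key="...")

try:
    client.chat(
        model="mistral-tiny",
        messages=[ChatMessage(role="user", content="hi")],
    )
except MistralConnectionException:
    pass  # httpx ConnectError: the API server was unreachable
except MistralAPIException as e:
    # Also catches MistralAPIStatusException once retries are exhausted.
    print(e.http_status, e)
```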
mistralai/py.typed
ADDED
File without changes
(`py.typed` is the empty PEP 561 marker file; its presence tells type checkers such as mypy that the package ships inline type annotations.)
{mistralai-0.0.8.dist-info → mistralai-0.0.9.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mistralai
-Version: 0.0.8
+Version: 0.0.9
 Summary:
 Author: Bam4d
 Author-email: bam4d@mistral.ai

@@ -11,19 +11,25 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-Requires-Dist: aiohttp (…)
-Requires-Dist: backoff (>=2.2.1,<3.0.0)
+Requires-Dist: httpx (>=0.25.2,<0.26.0)
 Requires-Dist: orjson (>=3.9.10,<4.0.0)
 Requires-Dist: pydantic (>=2.5.2,<3.0.0)
-Requires-Dist: requests (>=2.31.0,<3.0.0)
 Description-Content-Type: text/markdown
 
-This client is inspired from [cohere-python](https://github.com/cohere-ai/cohere-python)
-
 # Mistral Python Client
 
+This client is inspired from [cohere-python](https://github.com/cohere-ai/cohere-python)
+
 You can use the Mistral Python client to interact with the Mistral AI API.
 
+## Installing
+
+```bash
+pip install mistralai
+```
+
+### From Source
+
 This client uses `poetry` as a dependency and virtual environment manager.
 
 You can install poetry with

@@ -32,8 +38,6 @@ You can install poetry with
 pip install poetry
 ```
 
-## Installing
-
 `poetry` will set up a virtual environment and install dependencies with the following command:

@@ -44,6 +48,21 @@ poetry install
 
 You can run the examples in the `examples/` directory using `poetry run` or by entering the virtual environment using `poetry shell`.
 
+### API Key Setup
+
+Running the examples requires a Mistral AI API key.
+
+1. Get your own Mistral API Key: <https://docs.mistral.ai/#api-access>
+2. Set your Mistral API Key as an environment variable. You only need to do this once.
+
+```bash
+# set Mistral API Key (using zsh for example)
+$ echo 'export MISTRAL_API_KEY=[your_key_here]' >> ~/.zshenv
+
+# reload the environment (or just quit and open a new terminal)
+$ source ~/.zshenv
+```
+
 ### Using poetry run
 
 ```bash
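Combining the README's setup with the classes changed in this release, a quick-start against 0.0.9 might look like the sketch below. `MistralClient`, `ChatMessage`, `list_models`, and `chat` all appear in this diff; the response fields (`data`, `id`, `choices`) are assumptions based on the package's pydantic models.

```python
import os

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# List the models available to this key.
print([model.id for model in client.list_models().data])

# Send one non-streaming chat request.
response = client.chat(
    model="mistral-tiny",
    messages=[ChatMessage(role="user", content="What is the capital of France?")],
)
print(response.choices[0].message.content)
```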
mistralai-0.0.9.dist-info/RECORD
ADDED

@@ -0,0 +1,16 @@
+mistralai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mistralai/async_client.py,sha256=2tfOsASbMvs0jK4Yj4ePXuY4ZfULw439xFhJaSZ-98o,8768
+mistralai/client.py,sha256=LMA3PC67bNe-kHTI55WiZMzN7--w_RuJ3n1MQZHsp6E,8446
+mistralai/client_base.py,sha256=jzpclGy1o016RHGaRH9jZEQ4whL0ZR7lO9m_l9aDnKY,3609
+mistralai/constants.py,sha256=KK286HFjpoTxPKih8xdTp0lW4YMDPyYu2Shi3Nu5Vdw,86
+mistralai/exceptions.py,sha256=d1cli28ZaPEBQy2RKKcSh-dO2T20GXHNoaJrXXBmyIs,1664
+mistralai/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mistralai/models/chat_completion.py,sha256=MzAIFrmORSrlyj45_JUk8KP8NyNjbDA7HChbtTY4S5A,1038
+mistralai/models/common.py,sha256=zatP4aV_LIEpzj3_igsKkJBICwGhmXG0LX3CdO3kn-o,172
+mistralai/models/embeddings.py,sha256=-VthLQBj6wrq7HXJbGmnkQEEanSemA3MAlaMFh94VBg,331
+mistralai/models/models.py,sha256=I4I1kQwP0tJ2_rd4lbSUwOkCJE5E-dSTIExQg2ZNFcw,713
+mistralai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mistralai-0.0.9.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+mistralai-0.0.9.dist-info/METADATA,sha256=oWiSfOmHSlvwtqGsynFuIhwh1CPPsDoG4PI_GynvKco,1886
+mistralai-0.0.9.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
+mistralai-0.0.9.dist-info/RECORD,,
mistralai-0.0.8.dist-info/RECORD
DELETED

@@ -1,15 +0,0 @@
-mistralai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mistralai/async_client.py,sha256=XPeDjfhkGmwxtinq-KM7LxGPOzhjYz2whafQEcmDriA,12913
-mistralai/client.py,sha256=oE1kjNDKoCRLfrV8DHkaVTZ_YITZPH9z_iThFh01RkE,8317
-mistralai/client_base.py,sha256=URoVJqca_4Jcyji3iHNkVZN7PZBhlpGOoLGV4_hp80A,2472
-mistralai/constants.py,sha256=KK286HFjpoTxPKih8xdTp0lW4YMDPyYu2Shi3Nu5Vdw,86
-mistralai/exceptions.py,sha256=C4fWe4qYeYd8XmpcUmKugidokkfjkBgjYMTaCZqy4Wo,1836
-mistralai/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mistralai/models/chat_completion.py,sha256=MzAIFrmORSrlyj45_JUk8KP8NyNjbDA7HChbtTY4S5A,1038
-mistralai/models/common.py,sha256=zatP4aV_LIEpzj3_igsKkJBICwGhmXG0LX3CdO3kn-o,172
-mistralai/models/embeddings.py,sha256=-VthLQBj6wrq7HXJbGmnkQEEanSemA3MAlaMFh94VBg,331
-mistralai/models/models.py,sha256=I4I1kQwP0tJ2_rd4lbSUwOkCJE5E-dSTIExQg2ZNFcw,713
-mistralai-0.0.8.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-mistralai-0.0.8.dist-info/METADATA,sha256=ABTT6TfZoxM34HdHVDjZk71Rs2tbRvmBwbhEnA71Gkc,1481
-mistralai-0.0.8.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
-mistralai-0.0.8.dist-info/RECORD,,
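Each RECORD row is `path,sha256=<digest>,<size-in-bytes>` (PEP 376 / the wheel spec), where the digest is the unpadded urlsafe base64 of the file's SHA-256. The `47DEQpj8…` value shared by the empty `__init__.py` files and the new `py.typed` is simply the hash of zero bytes, which a few lines of Python confirm:

```python
import base64
import hashlib

# Digest of an empty file, encoded the way wheel RECORD files encode it.
digest = hashlib.sha256(b"").digest()
encoded = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
print(f"sha256={encoded}")
# -> sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
```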
{mistralai-0.0.8.dist-info → mistralai-0.0.9.dist-info}/LICENSE
File without changes

{mistralai-0.0.8.dist-info → mistralai-0.0.9.dist-info}/WHEEL
File without changes