mistralai 0.0.8-py3-none-any.whl → 0.0.9-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
mistralai/async_client.py CHANGED
@@ -1,20 +1,23 @@
- import asyncio
- import logging
  import os
  import posixpath
  import time
- from collections import defaultdict
  from json import JSONDecodeError
- from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Union
-
- import aiohttp
- import backoff
- import orjson
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+
+ from httpx import (
+     AsyncClient,
+     AsyncHTTPTransport,
+     ConnectError,
+     Limits,
+     RequestError,
+     Response,
+ )
 
  from mistralai.client_base import ClientBase
- from mistralai.constants import ENDPOINT, RETRY_STATUS_CODES
+ from mistralai.constants import ENDPOINT
  from mistralai.exceptions import (
      MistralAPIException,
+     MistralAPIStatusException,
      MistralConnectionException,
      MistralException,
  )
@@ -27,136 +30,6 @@ from mistralai.models.embeddings import EmbeddingResponse
  from mistralai.models.models import ModelList
 
 
- class AIOHTTPBackend:
-     """HTTP backend which handles retries, concurrency limiting and logging"""
-
-     SLEEP_AFTER_FAILURE = defaultdict(lambda: 0.25, {429: 5.0})
-
-     _requester: Callable[..., Awaitable[aiohttp.ClientResponse]]
-     _semaphore: asyncio.Semaphore
-     _session: Optional[aiohttp.ClientSession]
-
-     def __init__(
-         self,
-         max_concurrent_requests: int = 64,
-         max_retries: int = 5,
-         timeout: int = 120,
-     ):
-         self._logger = logging.getLogger(__name__)
-         self._timeout = timeout
-         self._max_retries = max_retries
-         self._session = None
-         self._max_concurrent_requests = max_concurrent_requests
-
-     def build_aio_requester(
-         self,
-     ) -> Callable:  # returns a function for retryable requests
-         @backoff.on_exception(
-             backoff.expo,
-             (aiohttp.ClientError, aiohttp.ClientResponseError),
-             max_tries=self._max_retries + 1,
-             max_time=self._timeout,
-         )
-         async def make_request_fn(
-             session: aiohttp.ClientSession, *args: Any, **kwargs: Any
-         ) -> aiohttp.ClientResponse:
-             async with self._semaphore:  # this limits total concurrency by the client
-                 response = await session.request(*args, **kwargs)
-             if (
-                 response.status in RETRY_STATUS_CODES
-             ):  # likely temporary, raise to retry
-                 self._logger.info(f"Received status {response.status}, retrying...")
-                 await asyncio.sleep(self.SLEEP_AFTER_FAILURE[response.status])
-                 response.raise_for_status()
-
-             return response
-
-         return make_request_fn
-
-     async def request(
-         self,
-         url: str,
-         json: Optional[Dict[str, Any]] = None,
-         method: str = "post",
-         headers: Optional[Dict[str, Any]] = None,
-         session: Optional[aiohttp.ClientSession] = None,
-         params: Optional[Dict[str, Any]] = None,
-         **kwargs: Any,
-     ) -> aiohttp.ClientResponse:
-         session = session or await self.session()
-         self._logger.debug(f"Making request to {url} with content {json}")
-
-         request_start = time.time()
-         try:
-             response = await self._requester(
-                 session,
-                 method,
-                 url,
-                 headers=headers,
-                 json=json,
-                 params=params,
-                 **kwargs,
-             )
-         except (
-             aiohttp.ClientConnectionError
-         ) as e:  # ensure the SDK user does not have to deal with knowing aiohttp
-             self._logger.debug(
-                 f"Fatal connection error after {time.time()-request_start:.1f}s: {e}"
-             )
-             raise MistralConnectionException(str(e)) from e
-         except (
-             aiohttp.ClientResponseError
-         ) as e:  # status 500 or something remains after retries
-             self._logger.debug(
-                 f"Fatal ClientResponseError error after {time.time()-request_start:.1f}s: {e}"
-             )
-             raise MistralConnectionException(str(e)) from e
-         except asyncio.TimeoutError as e:
-             self._logger.debug(
-                 f"Fatal timeout error after {time.time()-request_start:.1f}s: {e}"
-             )
-             raise MistralConnectionException("The request timed out") from e
-         except Exception as e:  # Anything caught here should be added above
-             self._logger.debug(
-                 f"Unexpected fatal error after {time.time()-request_start:.1f}s: {e}"
-             )
-             raise MistralException(
-                 f"Unexpected exception ({e.__class__.__name__}): {e}"
-             ) from e
-
-         self._logger.debug(
-             f"Received response with status {response.status} after {time.time()-request_start:.1f}s"
-         )
-         return response
-
-     async def session(self) -> aiohttp.ClientSession:
-         if self._session is None:
-             self._session = aiohttp.ClientSession(
-                 timeout=aiohttp.ClientTimeout(self._timeout),
-                 connector=aiohttp.TCPConnector(limit=0),
-             )
-             self._semaphore = asyncio.Semaphore(self._max_concurrent_requests)
-             self._requester = self.build_aio_requester()
-         return self._session
-
-     async def close(self) -> None:
-         if self._session is not None:
-             await self._session.close()
-             self._session = None
-
-     def __del__(self) -> None:
-         # https://stackoverflow.com/questions/54770360/how-can-i-wait-for-an-objects-del-to-finish-before-the-async-loop-closes
-         if self._session:
-             try:
-                 loop = asyncio.get_event_loop()
-                 if loop.is_running():
-                     loop.create_task(self.close())
-                 else:
-                     loop.run_until_complete(self.close())
-             except Exception:
-                 pass
-
-
  class MistralAsyncClient(ClientBase):
      def __init__(
          self,
@@ -168,14 +41,15 @@ class MistralAsyncClient(ClientBase):
      ):
          super().__init__(endpoint, api_key, max_retries, timeout)
 
-         self._backend = AIOHTTPBackend(
-             max_concurrent_requests=max_concurrent_requests,
-             max_retries=max_retries,
+         self._client = AsyncClient(
+             follow_redirects=True,
              timeout=timeout,
+             limits=Limits(max_connections=max_concurrent_requests),
+             transport=AsyncHTTPTransport(retries=max_retries),
          )
 
      async def close(self) -> None:
-         await self._backend.close()
+         await self._client.aclose()
 
      async def _request(
          self,
@@ -183,9 +57,8 @@ class MistralAsyncClient(ClientBase):
          json: Dict[str, Any],
          path: str,
          stream: bool = False,
-         params: Optional[Dict[str, Any]] = None,
-     ) -> Union[Dict[str, Any], aiohttp.ClientResponse]:
-
+         attempt: int = 1,
+     ) -> AsyncGenerator[Dict[str, Any], None]:
          headers = {
              "Authorization": f"Bearer {self._api_key}",
              "Content-Type": "application/json",
@@ -193,27 +66,60 @@ class MistralAsyncClient(ClientBase):
 
          url = posixpath.join(self._endpoint, path)
 
-         response = await self._backend.request(
-             url, json, method, headers, params=params
-         )
-         if stream:
-             return response
+         self._logger.debug(f"Sending request: {method} {url} {json}")
+
+         response: Response
 
          try:
-             json_response: Dict[str, Any] = await response.json()
-         except JSONDecodeError:
-             raise MistralAPIException.from_aio_response(
-                 response, message=f"Failed to decode json body: {await response.text()}"
-             )
-         except aiohttp.ClientPayloadError as e:
-             raise MistralAPIException.from_aio_response(
+             if stream:
+                 async with self._client.stream(
+                     method,
+                     url,
+                     headers=headers,
+                     json=json,
+                 ) as response:
+                     self._check_streaming_response(response)
+
+                     async for line in response.aiter_lines():
+                         json_streamed_response = self._process_line(line)
+                         if json_streamed_response:
+                             yield json_streamed_response
+
+             else:
+                 response = await self._client.request(
+                     method,
+                     url,
+                     headers=headers,
+                     json=json,
+                 )
+
+                 yield self._check_response(response)
+
+         except ConnectError as e:
+             raise MistralConnectionException(str(e)) from e
+         except RequestError as e:
+             raise MistralException(
+                 f"Unexpected exception ({e.__class__.__name__}): {e}"
+             ) from e
+         except JSONDecodeError as e:
+             raise MistralAPIException.from_response(
                  response,
-                 message=f"An unexpected error occurred while receiving the response: {e}",
-             )
-
-         self._logger.debug(f"JSON response: {json_response}")
-         self._check_response(json_response, dict(response.headers), response.status)
-         return json_response
+                 message=f"Failed to decode json body: {response.text}",
+             ) from e
+         except MistralAPIStatusException as e:
+             attempt += 1
+             if attempt > self._max_retries:
+                 raise MistralAPIStatusException.from_response(
+                     response, message=str(e)
+                 ) from e
+             backoff = 2.0**attempt  # exponential backoff
+             time.sleep(backoff)
+
+             # Retry as a generator
+             async for r in self._request(
+                 method, json, path, stream=stream, attempt=attempt
+             ):
+                 yield r
 
      async def chat(
          self,
@@ -225,7 +131,7 @@ class MistralAsyncClient(ClientBase):
          random_seed: Optional[int] = None,
          safe_mode: bool = False,
      ) -> ChatCompletionResponse:
-         """ A asynchronous chat endpoint that returns a single response.
+         """A asynchronous chat endpoint that returns a single response.
 
          Args:
              model (str): model the name of the model to chat with, e.g. mistral-tiny
@@ -252,9 +158,12 @@ class MistralAsyncClient(ClientBase):
              safe_mode=safe_mode,
          )
 
-         response = await self._request("post", request, "v1/chat/completions")
-         assert isinstance(response, dict), "Bad response from _request"
-         return ChatCompletionResponse(**response)
+         single_response = self._request("post", request, "v1/chat/completions")
+
+         async for response in single_response:
+             return ChatCompletionResponse(**response)
+
+         raise MistralException("No response received")
 
      async def chat_stream(
          self,
@@ -266,7 +175,7 @@ class MistralAsyncClient(ClientBase):
          random_seed: Optional[int] = None,
          safe_mode: bool = False,
      ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
-         """ An Asynchronous chat endpoint that streams responses.
+         """An Asynchronous chat endpoint that streams responses.
 
          Args:
              model (str): model the name of the model to chat with, e.g. mistral-tiny
@@ -294,24 +203,12 @@ class MistralAsyncClient(ClientBase):
              stream=True,
              safe_mode=safe_mode,
          )
-         async_response = await self._request(
+         async_response = self._request(
              "post", request, "v1/chat/completions", stream=True
          )
 
-         assert isinstance(
-             async_response, aiohttp.ClientResponse
-         ), "Bad response from _request"
-
-         async with async_response as response:
-             async for line in response.content:
-                 if line == b"\n":
-                     continue
-
-                 if line.startswith(b"data: "):
-                     line = line[6:].strip()
-                     if line != b"[DONE]":
-                         json_response = orjson.loads(line)
-                         yield ChatCompletionStreamResponse(**json_response)
+         async for json_response in async_response:
+             yield ChatCompletionStreamResponse(**json_response)
 
      async def embeddings(
          self, model: str, input: Union[str, List[str]]
@@ -327,9 +224,12 @@ class MistralAsyncClient(ClientBase):
              EmbeddingResponse: A response object containing the embeddings.
          """
          request = {"model": model, "input": input}
-         response = await self._request("post", request, "v1/embeddings")
-         assert isinstance(response, dict), "Bad response from _request"
-         return EmbeddingResponse(**response)
+         single_response = self._request("post", request, "v1/embeddings")
+
+         async for response in single_response:
+             return EmbeddingResponse(**response)
+
+         raise MistralException("No response received")
 
      async def list_models(self) -> ModelList:
          """Returns a list of the available models
@@ -337,6 +237,9 @@ class MistralAsyncClient(ClientBase):
          Returns:
              ModelList: A response object containing the list of models.
          """
-         response = await self._request("get", {}, "v1/models")
-         assert isinstance(response, dict), "Bad response from _request"
-         return ModelList(**response)
+         single_response = self._request("get", {}, "v1/models")
+
+         async for response in single_response:
+             return ModelList(**response)
+
+         raise MistralException("No response received")
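The heart of this file's change is `_request`: the hand-rolled `AIOHTTPBackend` (aiohttp plus backoff) is replaced by a single httpx `AsyncClient`, and `_request` becomes an async generator that, on a retryable status (`MistralAPIStatusException`), sleeps with exponential backoff and delegates to a fresh copy of itself. Below is a minimal, self-contained sketch of that retry-as-generator pattern; `fetch` and `RetryableStatus` are invented stand-ins for the httpx call and the SDK exception, and the sketch awaits `asyncio.sleep` where the diffed code calls the blocking `time.sleep` even on the async path.

```python
import asyncio
from typing import Any, AsyncGenerator, Dict


class RetryableStatus(Exception):
    """Stand-in for MistralAPIStatusException: a response worth retrying."""


async def fetch(attempt: int) -> Dict[str, Any]:
    # Stand-in for the httpx call; fails on the first attempt, then succeeds.
    if attempt < 2:
        raise RetryableStatus(f"status 503 on attempt {attempt}")
    return {"object": "chat.completion", "attempt": attempt}


async def request(
    max_retries: int = 5, attempt: int = 1
) -> AsyncGenerator[Dict[str, Any], None]:
    try:
        yield await fetch(attempt)
    except RetryableStatus:
        attempt += 1
        if attempt > max_retries:
            raise
        await asyncio.sleep(2.0**attempt)  # exponential backoff, as in the diff
        # Retry by handing off to a fresh generator, mirroring _request.
        async for item in request(max_retries, attempt):
            yield item


async def main() -> None:
    async for response in request():
        print(response)  # {'object': 'chat.completion', 'attempt': 2}


asyncio.run(main())
```

Since `_check_streaming_response` runs before any line is yielded, a retryable status is detected before the caller has consumed any output, so a retry never has to rewind already-yielded chunks.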
mistralai/client.py CHANGED
@@ -1,18 +1,16 @@
  import os
  import posixpath
+ import time
  from json import JSONDecodeError
- from typing import Any, Dict, Iterable, List, Optional, Union
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
 
- import orjson
- import requests
- from requests import Response
- from requests.adapters import HTTPAdapter
- from urllib3.util.retry import Retry
+ from httpx import Client, ConnectError, HTTPTransport, RequestError, Response
 
  from mistralai.client_base import ClientBase
- from mistralai.constants import ENDPOINT, RETRY_STATUS_CODES
+ from mistralai.constants import ENDPOINT
  from mistralai.exceptions import (
      MistralAPIException,
+     MistralAPIStatusException,
      MistralConnectionException,
      MistralException,
  )
@@ -39,14 +37,22 @@ class MistralClient(ClientBase):
      ):
          super().__init__(endpoint, api_key, max_retries, timeout)
 
+         self._client = Client(
+             follow_redirects=True,
+             timeout=self._timeout,
+             transport=HTTPTransport(retries=self._max_retries))
+
+     def __del__(self) -> None:
+         self._client.close()
+
      def _request(
          self,
          method: str,
          json: Dict[str, Any],
          path: str,
          stream: bool = False,
-         params: Optional[Dict[str, Any]] = None,
-     ) -> Union[Response, Dict[str, Any]]:
+         attempt: int = 1,
+     ) -> Iterator[Dict[str, Any]]:
          headers = {
              "Authorization": f"Bearer {self._api_key}",
              "Content-Type": "application/json",
@@ -54,49 +60,58 @@ class MistralClient(ClientBase):
 
          url = posixpath.join(self._endpoint, path)
 
-         with requests.Session() as session:
-             retries = Retry(
-                 total=self._max_retries,
-                 backoff_factor=0.5,
-                 allowed_methods=["POST", "GET"],
-                 status_forcelist=RETRY_STATUS_CODES,
-                 raise_on_status=False,
-             )
-             session.mount("https://", HTTPAdapter(max_retries=retries))
-             session.mount("http://", HTTPAdapter(max_retries=retries))
+         self._logger.debug(f"Sending request: {method} {url} {json}")
+
+         response: Response
 
+         try:
              if stream:
-                 return session.request(
-                     method, url, headers=headers, json=json, stream=True
-                 )
+                 with self._client.stream(
+                     method,
+                     url,
+                     headers=headers,
+                     json=json,
+                 ) as response:
+                     self._check_streaming_response(response)
+
+                     for line in response.iter_lines():
+                         json_streamed_response = self._process_line(line)
+                         if json_streamed_response:
+                             yield json_streamed_response
 
-             try:
-                 response = session.request(
+             else:
+                 response = self._client.request(
                      method,
                      url,
                      headers=headers,
                      json=json,
-                     timeout=self._timeout,
-                     params=params,
                  )
-             except requests.exceptions.ConnectionError as e:
-                 raise MistralConnectionException(str(e)) from e
-             except requests.exceptions.RequestException as e:
-                 raise MistralException(
-                     f"Unexpected exception ({e.__class__.__name__}): {e}"
-                 ) from e
 
-             try:
-                 json_response: Dict[str, Any] = response.json()
-             except JSONDecodeError:
-                 raise MistralAPIException.from_response(
-                     response, message=f"Failed to decode json body: {response.text}"
-                 )
+                 yield self._check_response(response)
+
+         except ConnectError as e:
+             raise MistralConnectionException(str(e)) from e
+         except RequestError as e:
+             raise MistralException(
+                 f"Unexpected exception ({e.__class__.__name__}): {e}"
+             ) from e
+         except JSONDecodeError as e:
+             raise MistralAPIException.from_response(
+                 response,
+                 message=f"Failed to decode json body: {response.text}",
+             ) from e
+         except MistralAPIStatusException as e:
+             attempt += 1
+             if attempt > self._max_retries:
+                 raise MistralAPIStatusException.from_response(
+                     response, message=str(e)
+                 ) from e
+             backoff = 2.0**attempt  # exponential backoff
+             time.sleep(backoff)
 
-             self._check_response(
-                 json_response, dict(response.headers), response.status_code
-             )
-             return json_response
+             # Retry as a generator
+             for r in self._request(method, json, path, stream=stream, attempt=attempt):
+                 yield r
 
      def chat(
          self,
@@ -108,7 +123,7 @@ class MistralClient(ClientBase):
          random_seed: Optional[int] = None,
          safe_mode: bool = False,
      ) -> ChatCompletionResponse:
-         """ A chat endpoint that returns a single response.
+         """A chat endpoint that returns a single response.
 
          Args:
              model (str): model the name of the model to chat with, e.g. mistral-tiny
@@ -135,11 +150,12 @@ class MistralClient(ClientBase):
              safe_mode=safe_mode,
          )
 
-         response = self._request("post", request, "v1/chat/completions")
+         single_response = self._request("post", request, "v1/chat/completions")
 
-         assert isinstance(response, dict), "Bad response from _request"
+         for response in single_response:
+             return ChatCompletionResponse(**response)
 
-         return ChatCompletionResponse(**response)
+         raise MistralException("No response received")
 
      def chat_stream(
          self,
@@ -151,7 +167,7 @@ class MistralClient(ClientBase):
          random_seed: Optional[int] = None,
          safe_mode: bool = False,
      ) -> Iterable[ChatCompletionStreamResponse]:
-         """ A chat endpoint that streams responses.
+         """A chat endpoint that streams responses.
 
          Args:
              model (str): model the name of the model to chat with, e.g. mistral-tiny
@@ -181,18 +197,8 @@ class MistralClient(ClientBase):
 
          response = self._request("post", request, "v1/chat/completions", stream=True)
 
-         assert isinstance(response, Response), "Bad response from _request"
-
-         for line in response.iter_lines():
-             self._logger.debug(f"Received line: {line}")
-             if line == b"\n":
-                 continue
-
-             if line.startswith(b"data: "):
-                 line = line[6:].strip()
-                 if line != b"[DONE]":
-                     json_response = orjson.loads(line)
-                     yield ChatCompletionStreamResponse(**json_response)
+         for json_streamed_response in response:
+             yield ChatCompletionStreamResponse(**json_streamed_response)
 
      def embeddings(self, model: str, input: Union[str, List[str]]) -> EmbeddingResponse:
          """An embeddings endpoint that returns embeddings for a single, or batch of inputs
@@ -206,9 +212,12 @@ class MistralClient(ClientBase):
              EmbeddingResponse: A response object containing the embeddings.
          """
          request = {"model": model, "input": input}
-         response = self._request("post", request, "v1/embeddings")
-         assert isinstance(response, dict), "Bad response from _request"
-         return EmbeddingResponse(**response)
+         singleton_response = self._request("post", request, "v1/embeddings")
+
+         for response in singleton_response:
+             return EmbeddingResponse(**response)
+
+         raise MistralException("No response received")
 
      def list_models(self) -> ModelList:
          """Returns a list of the available models
@@ -216,6 +225,9 @@ class MistralClient(ClientBase):
          Returns:
              ModelList: A response object containing the list of models.
          """
-         response = self._request("get", {}, "v1/models")
-         assert isinstance(response, dict), "Bad response from _request"
-         return ModelList(**response)
+         singleton_response = self._request("get", {}, "v1/models")
+
+         for response in singleton_response:
+             return ModelList(**response)
+
+         raise MistralException("No response received")
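For orientation, here is how the reworked surface reads from the caller's side. This is a sketch based only on the signatures visible in this diff; the model name and the response attributes (`choices`, `message.content`, `delta.content`) are assumptions, not guarantees.

```python
import os

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# chat() drains a one-item generator internally and returns a single
# ChatCompletionResponse, or raises MistralException("No response received").
response = client.chat(
    model="mistral-tiny",
    messages=[ChatMessage(role="user", content="Say hello.")],
)
print(response.choices[0].message.content)  # attribute names assumed

# chat_stream() re-yields the dicts produced by _request(stream=True).
for chunk in client.chat_stream(
    model="mistral-tiny",
    messages=[ChatMessage(role="user", content="Count to three.")],
):
    print(chunk.choices[0].delta.content or "", end="")
```

Retries now live in two layers: `HTTPTransport(retries=...)` handles connection failures inside httpx, while the generator-level loop in `_request` handles the `RETRY_STATUS_CODES` responses.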
mistralai/client_base.py CHANGED
@@ -1,10 +1,24 @@
  import logging
+ import os
  from abc import ABC
  from typing import Any, Dict, List, Optional
 
- from mistralai.exceptions import MistralAPIException, MistralException
+ import orjson
+ from httpx import Response
+
+ from mistralai.constants import RETRY_STATUS_CODES
+ from mistralai.exceptions import (
+     MistralAPIException,
+     MistralAPIStatusException,
+     MistralException,
+ )
  from mistralai.models.chat_completion import ChatMessage
 
+ logging.basicConfig(
+     format="%(asctime)s %(levelname)s %(name)s: %(message)s",
+     level=os.getenv("LOG_LEVEL", "ERROR"),
+ )
+
 
  class ClientBase(ABC):
      def __init__(
@@ -21,8 +35,8 @@ class ClientBase(ABC):
          self._api_key = api_key
          self._logger = logging.getLogger(__name__)
 
-     @staticmethod
      def _make_chat_request(
+         self,
          model: str,
          messages: List[ChatMessage],
          temperature: Optional[float] = None,
@@ -48,26 +62,48 @@ class ClientBase(ABC):
          if stream is not None:
              request_data["stream"] = stream
 
+         self._logger.debug(f"Chat request: {request_data}")
+
          return request_data
 
-     def _check_response(
-         self, json_response: Dict[str, Any], headers: Dict[str, Any], status: int
-     ) -> None:
+     def _check_response_status_codes(self, response: Response) -> None:
+         if response.status_code in RETRY_STATUS_CODES:
+             raise MistralAPIStatusException.from_response(
+                 response,
+                 message=f"Cannot stream response. Status: {response.status_code}",
+             )
+         elif 400 <= response.status_code < 500:
+             raise MistralAPIException.from_response(
+                 response,
+                 message=f"Cannot stream response. Status: {response.status_code}",
+             )
+         elif response.status_code >= 500:
+             raise MistralException(
+                 message=f"Unexpected server error (status {response.status_code})"
+             )
+
+     def _check_streaming_response(self, response: Response) -> None:
+         self._check_response_status_codes(response)
+
+     def _check_response(self, response: Response) -> Dict[str, Any]:
+         self._check_response_status_codes(response)
+
+         json_response: Dict[str, Any] = response.json()
+
          if "object" not in json_response:
              raise MistralException(message=f"Unexpected response: {json_response}")
          if "error" == json_response["object"]:  # has errors
-             raise MistralAPIException(
+             raise MistralAPIException.from_response(
+                 response,
                  message=json_response["message"],
-                 http_status=status,
-                 headers=headers,
-             )
-         if 400 <= status < 500:
-             raise MistralAPIException(
-                 message=f"Unexpected client error (status {status}): {json_response}",
-                 http_status=status,
-                 headers=headers,
-             )
-         if status >= 500:
-             raise MistralException(
-                 message=f"Unexpected server error (status {status}): {json_response}"
              )
+
+         return json_response
+
+     def _process_line(self, line: str) -> Optional[Dict[str, Any]]:
+         if line.startswith("data: "):
+             line = line[6:].strip()
+             if line != "[DONE]":
+                 json_streamed_response: Dict[str, Any] = orjson.loads(line)
+                 return json_streamed_response
+         return None
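`_process_line` centralizes the server-sent-events parsing that both clients previously duplicated, and it now operates on `str` lines (httpx's `iter_lines`/`aiter_lines` yield text) rather than the old `bytes`. A standalone restatement of the logic, with invented sample input:

```python
from typing import Any, Dict, Optional

import orjson


def process_line(line: str) -> Optional[Dict[str, Any]]:
    # Same logic as ClientBase._process_line above.
    if line.startswith("data: "):
        payload = line[6:].strip()  # drop the "data: " SSE prefix
        if payload != "[DONE]":
            parsed: Dict[str, Any] = orjson.loads(payload)
            return parsed
    return None  # blank keep-alives and the [DONE] sentinel yield nothing


sample_lines = [
    'data: {"object": "chat.completion.chunk", "id": "0"}',
    "",  # separator between SSE events
    "data: [DONE]",
]
print([process_line(line) for line in sample_lines])
# [{'object': 'chat.completion.chunk', 'id': '0'}, None, None]
```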
mistralai/exceptions.py CHANGED
@@ -2,8 +2,7 @@ from __future__ import annotations
 
  from typing import Any, Dict, Optional
 
- import aiohttp
- from requests import Response
+ from httpx import Response
 
 
  class MistralException(Exception):
@@ -45,19 +44,11 @@ class MistralAPIException(MistralException):
              headers=dict(response.headers),
          )
 
-     @classmethod
-     def from_aio_response(
-         cls, response: aiohttp.ClientResponse, message: Optional[str] = None
-     ) -> MistralAPIException:
-         return cls(
-             message=message,
-             http_status=response.status,
-             headers=dict(response.headers),
-         )
-
      def __repr__(self) -> str:
          return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})"
 
+ class MistralAPIStatusException(MistralAPIException):
+     """Returned when we receive a non-200 response from the API that we should retry"""
 
  class MistralConnectionException(MistralException):
      """Returned when the SDK can not reach the API server for any reason"""
mistralai/py.typed ADDED
(empty file: py.typed is the PEP 561 marker that tells type checkers the package ships inline type annotations)
{mistralai-0.0.8.dist-info → mistralai-0.0.9.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mistralai
- Version: 0.0.8
+ Version: 0.0.9
  Summary:
  Author: Bam4d
  Author-email: bam4d@mistral.ai
@@ -11,19 +11,25 @@ Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
- Requires-Dist: aiohttp (>=3.9.1,<4.0.0)
- Requires-Dist: backoff (>=2.2.1,<3.0.0)
+ Requires-Dist: httpx (>=0.25.2,<0.26.0)
  Requires-Dist: orjson (>=3.9.10,<4.0.0)
  Requires-Dist: pydantic (>=2.5.2,<3.0.0)
- Requires-Dist: requests (>=2.31.0,<3.0.0)
  Description-Content-Type: text/markdown
 
- This client is inspired from [cohere-python](https://github.com/cohere-ai/cohere-python)
-
  # Mistral Python Client
 
+ This client is inspired from [cohere-python](https://github.com/cohere-ai/cohere-python)
+
  You can use the Mistral Python client to interact with the Mistral AI API.
 
+ ## Installing
+
+ ```bash
+ pip install mistralai
+ ```
+
+ ### From Source
+
  This client uses `poetry` as a dependency and virtual environment manager.
 
  You can install poetry with
@@ -32,8 +38,6 @@ You can install poetry with
  pip install poetry
  ```
 
- ## Installing
-
  `poetry` will set up a virtual environment and install dependencies with the following command:
 
  ```bash
@@ -44,6 +48,21 @@ poetry install
 
  You can run the examples in the `examples/` directory using `poetry run` or by entering the virtual environment using `poetry shell`.
 
+ ### API Key Setup
+
+ Running the examples requires a Mistral AI API key.
+
+ 1. Get your own Mistral API Key: <https://docs.mistral.ai/#api-access>
+ 2. Set your Mistral API Key as an environment variable. You only need to do this once.
+
+ ```bash
+ # set Mistral API Key (using zsh for example)
+ $ echo 'export MISTRAL_API_KEY=[your_key_here]' >> ~/.zshenv
+
+ # reload the environment (or just quit and open a new terminal)
+ $ source ~/.zshenv
+ ```
+
  ### Using poetry run
 
  ```bash
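A minimal smoke test tying the README steps together once `MISTRAL_API_KEY` is exported; the `data` and `id` attributes of the model list are an assumption about its shape:

```python
import os

from mistralai.client import MistralClient

# Assumes MISTRAL_API_KEY was exported as shown in the README above.
client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])
print([model.id for model in client.list_models().data])
```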
mistralai-0.0.9.dist-info/RECORD ADDED
@@ -0,0 +1,16 @@
+ mistralai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ mistralai/async_client.py,sha256=2tfOsASbMvs0jK4Yj4ePXuY4ZfULw439xFhJaSZ-98o,8768
+ mistralai/client.py,sha256=LMA3PC67bNe-kHTI55WiZMzN7--w_RuJ3n1MQZHsp6E,8446
+ mistralai/client_base.py,sha256=jzpclGy1o016RHGaRH9jZEQ4whL0ZR7lO9m_l9aDnKY,3609
+ mistralai/constants.py,sha256=KK286HFjpoTxPKih8xdTp0lW4YMDPyYu2Shi3Nu5Vdw,86
+ mistralai/exceptions.py,sha256=d1cli28ZaPEBQy2RKKcSh-dO2T20GXHNoaJrXXBmyIs,1664
+ mistralai/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ mistralai/models/chat_completion.py,sha256=MzAIFrmORSrlyj45_JUk8KP8NyNjbDA7HChbtTY4S5A,1038
+ mistralai/models/common.py,sha256=zatP4aV_LIEpzj3_igsKkJBICwGhmXG0LX3CdO3kn-o,172
+ mistralai/models/embeddings.py,sha256=-VthLQBj6wrq7HXJbGmnkQEEanSemA3MAlaMFh94VBg,331
+ mistralai/models/models.py,sha256=I4I1kQwP0tJ2_rd4lbSUwOkCJE5E-dSTIExQg2ZNFcw,713
+ mistralai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ mistralai-0.0.9.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ mistralai-0.0.9.dist-info/METADATA,sha256=oWiSfOmHSlvwtqGsynFuIhwh1CPPsDoG4PI_GynvKco,1886
+ mistralai-0.0.9.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
+ mistralai-0.0.9.dist-info/RECORD,,
mistralai-0.0.8.dist-info/RECORD DELETED
@@ -1,15 +0,0 @@
- mistralai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- mistralai/async_client.py,sha256=XPeDjfhkGmwxtinq-KM7LxGPOzhjYz2whafQEcmDriA,12913
- mistralai/client.py,sha256=oE1kjNDKoCRLfrV8DHkaVTZ_YITZPH9z_iThFh01RkE,8317
- mistralai/client_base.py,sha256=URoVJqca_4Jcyji3iHNkVZN7PZBhlpGOoLGV4_hp80A,2472
- mistralai/constants.py,sha256=KK286HFjpoTxPKih8xdTp0lW4YMDPyYu2Shi3Nu5Vdw,86
- mistralai/exceptions.py,sha256=C4fWe4qYeYd8XmpcUmKugidokkfjkBgjYMTaCZqy4Wo,1836
- mistralai/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- mistralai/models/chat_completion.py,sha256=MzAIFrmORSrlyj45_JUk8KP8NyNjbDA7HChbtTY4S5A,1038
- mistralai/models/common.py,sha256=zatP4aV_LIEpzj3_igsKkJBICwGhmXG0LX3CdO3kn-o,172
- mistralai/models/embeddings.py,sha256=-VthLQBj6wrq7HXJbGmnkQEEanSemA3MAlaMFh94VBg,331
- mistralai/models/models.py,sha256=I4I1kQwP0tJ2_rd4lbSUwOkCJE5E-dSTIExQg2ZNFcw,713
- mistralai-0.0.8.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- mistralai-0.0.8.dist-info/METADATA,sha256=ABTT6TfZoxM34HdHVDjZk71Rs2tbRvmBwbhEnA71Gkc,1481
- mistralai-0.0.8.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
- mistralai-0.0.8.dist-info/RECORD,,
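The `47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU` digest attached to every zero-byte file above (the `__init__.py` files and `py.typed`) is simply the SHA-256 of empty input in RECORD's unpadded URL-safe base64 encoding, which is easy to verify:

```python
import base64
import hashlib

# RECORD rows are "path,sha256=<digest>,size", where <digest> is the
# unpadded URL-safe base64 encoding of the file's SHA-256 hash.
digest = base64.urlsafe_b64encode(hashlib.sha256(b"").digest()).rstrip(b"=")
print(digest.decode())  # 47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
```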