mistralai 0.0.7__tar.gz → 0.0.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mistralai-0.0.7 → mistralai-0.0.9}/PKG-INFO +27 -6
- {mistralai-0.0.7 → mistralai-0.0.9}/README.md +25 -2
- {mistralai-0.0.7 → mistralai-0.0.9}/pyproject.toml +2 -4
- mistralai-0.0.9/src/mistralai/async_client.py +245 -0
- {mistralai-0.0.7 → mistralai-0.0.9}/src/mistralai/client.py +79 -67
- mistralai-0.0.9/src/mistralai/client_base.py +109 -0
- {mistralai-0.0.7 → mistralai-0.0.9}/src/mistralai/exceptions.py +3 -12
- mistralai-0.0.9/src/mistralai/py.typed +0 -0
- mistralai-0.0.7/src/mistralai/async_client.py +0 -342
- mistralai-0.0.7/src/mistralai/client_base.py +0 -73
- {mistralai-0.0.7 → mistralai-0.0.9}/LICENSE +0 -0
- {mistralai-0.0.7 → mistralai-0.0.9}/src/mistralai/__init__.py +0 -0
- {mistralai-0.0.7 → mistralai-0.0.9}/src/mistralai/constants.py +0 -0
- {mistralai-0.0.7 → mistralai-0.0.9}/src/mistralai/models/__init__.py +0 -0
- {mistralai-0.0.7 → mistralai-0.0.9}/src/mistralai/models/chat_completion.py +0 -0
- {mistralai-0.0.7 → mistralai-0.0.9}/src/mistralai/models/common.py +0 -0
- {mistralai-0.0.7 → mistralai-0.0.9}/src/mistralai/models/embeddings.py +0 -0
- {mistralai-0.0.7 → mistralai-0.0.9}/src/mistralai/models/models.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: mistralai
|
|
3
|
-
Version: 0.0.7
|
|
3
|
+
Version: 0.0.9
|
|
4
4
|
Summary:
|
|
5
5
|
Author: Bam4d
|
|
6
6
|
Author-email: bam4d@mistral.ai
|
|
@@ -11,17 +11,25 @@ Classifier: Programming Language :: Python :: 3.9
|
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.10
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
-
Requires-Dist: aiohttp (>=3.9.1,<4.0.0)
|
|
15
|
-
Requires-Dist: backoff (>=2.2.1,<3.0.0)
|
|
14
|
+
Requires-Dist: httpx (>=0.25.2,<0.26.0)
|
|
16
15
|
Requires-Dist: orjson (>=3.9.10,<4.0.0)
|
|
17
16
|
Requires-Dist: pydantic (>=2.5.2,<3.0.0)
|
|
18
|
-
Requires-Dist: requests (>=2.31.0,<3.0.0)
|
|
19
17
|
Description-Content-Type: text/markdown
|
|
20
18
|
|
|
21
19
|
# Mistral Python Client
|
|
22
20
|
|
|
21
|
+
This client is inspired by [cohere-python](https://github.com/cohere-ai/cohere-python)
|
|
22
|
+
|
|
23
23
|
You can use the Mistral Python client to interact with the Mistral AI API.
|
|
24
24
|
|
|
25
|
+
## Installing
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
pip install mistralai
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
### From Source
|
|
32
|
+
|
|
25
33
|
This client uses `poetry` as a dependency and virtual environment manager.
|
|
26
34
|
|
|
27
35
|
You can install poetry with
|
|
@@ -30,8 +38,6 @@ You can install poetry with
|
|
|
30
38
|
pip install poetry
|
|
31
39
|
```
|
|
32
40
|
|
|
33
|
-
## Installing
|
|
34
|
-
|
|
35
41
|
`poetry` will set up a virtual environment and install dependencies with the following command:
|
|
36
42
|
|
|
37
43
|
```bash
|
|
@@ -42,6 +48,21 @@ poetry install
|
|
|
42
48
|
|
|
43
49
|
You can run the examples in the `examples/` directory using `poetry run` or by entering the virtual environment using `poetry shell`.
|
|
44
50
|
|
|
51
|
+
### API Key Setup
|
|
52
|
+
|
|
53
|
+
Running the examples requires a Mistral AI API key.
|
|
54
|
+
|
|
55
|
+
1. Get your own Mistral API Key: <https://docs.mistral.ai/#api-access>
|
|
56
|
+
2. Set your Mistral API Key as an environment variable. You only need to do this once.
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
# set Mistral API Key (using zsh for example)
|
|
60
|
+
$ echo 'export MISTRAL_API_KEY=[your_key_here]' >> ~/.zshenv
|
|
61
|
+
|
|
62
|
+
# reload the environment (or just quit and open a new terminal)
|
|
63
|
+
$ source ~/.zshenv
|
|
64
|
+
```
|
|
65
|
+
|
|
45
66
|
### Using poetry run
|
|
46
67
|
|
|
47
68
|
```bash
|
|
@@ -1,7 +1,17 @@
|
|
|
1
1
|
# Mistral Python Client
|
|
2
2
|
|
|
3
|
+
This client is inspired by [cohere-python](https://github.com/cohere-ai/cohere-python)
|
|
4
|
+
|
|
3
5
|
You can use the Mistral Python client to interact with the Mistral AI API.
|
|
4
6
|
|
|
7
|
+
## Installing
|
|
8
|
+
|
|
9
|
+
```bash
|
|
10
|
+
pip install mistralai
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
### From Source
|
|
14
|
+
|
|
5
15
|
This client uses `poetry` as a dependency and virtual environment manager.
|
|
6
16
|
|
|
7
17
|
You can install poetry with
|
|
@@ -10,8 +20,6 @@ You can install poetry with
|
|
|
10
20
|
pip install poetry
|
|
11
21
|
```
|
|
12
22
|
|
|
13
|
-
## Installing
|
|
14
|
-
|
|
15
23
|
`poetry` will set up a virtual environment and install dependencies with the following command:
|
|
16
24
|
|
|
17
25
|
```bash
|
|
@@ -22,6 +30,21 @@ poetry install
|
|
|
22
30
|
|
|
23
31
|
You can run the examples in the `examples/` directory using `poetry run` or by entering the virtual environment using `poetry shell`.
|
|
24
32
|
|
|
33
|
+
### API Key Setup
|
|
34
|
+
|
|
35
|
+
Running the examples requires a Mistral AI API key.
|
|
36
|
+
|
|
37
|
+
1. Get your own Mistral API Key: <https://docs.mistral.ai/#api-access>
|
|
38
|
+
2. Set your Mistral API Key as an environment variable. You only need to do this once.
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
# set Mistral API Key (using zsh for example)
|
|
42
|
+
$ echo 'export MISTRAL_API_KEY=[your_key_here]' >> ~/.zshenv
|
|
43
|
+
|
|
44
|
+
# reload the environment (or just quit and open a new terminal)
|
|
45
|
+
$ source ~/.zshenv
|
|
46
|
+
```
|
|
47
|
+
|
|
25
48
|
### Using poetry run
|
|
26
49
|
|
|
27
50
|
```bash
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[tool.poetry]
|
|
2
2
|
name = "mistralai"
|
|
3
|
-
version = "0.0.7"
|
|
3
|
+
version = "0.0.9"
|
|
4
4
|
description = ""
|
|
5
5
|
authors = ["Bam4d <bam4d@mistral.ai>"]
|
|
6
6
|
readme = "README.md"
|
|
@@ -24,11 +24,9 @@ exclude = ["docs", "tests", "examples", "tools", "build"]
|
|
|
24
24
|
|
|
25
25
|
[tool.poetry.dependencies]
|
|
26
26
|
python = "^3.8"
|
|
27
|
-
aiohttp = "^3.9.1"
|
|
28
|
-
backoff = "^2.2.1"
|
|
29
27
|
orjson = "^3.9.10"
|
|
30
|
-
requests = "^2.31.0"
|
|
31
28
|
pydantic = "^2.5.2"
|
|
29
|
+
httpx = "^0.25.2"
|
|
32
30
|
|
|
33
31
|
|
|
34
32
|
[tool.poetry.group.dev.dependencies]
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import posixpath
|
|
3
|
+
import time
|
|
4
|
+
from json import JSONDecodeError
|
|
5
|
+
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
|
6
|
+
|
|
7
|
+
from httpx import (
|
|
8
|
+
AsyncClient,
|
|
9
|
+
AsyncHTTPTransport,
|
|
10
|
+
ConnectError,
|
|
11
|
+
Limits,
|
|
12
|
+
RequestError,
|
|
13
|
+
Response,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from mistralai.client_base import ClientBase
|
|
17
|
+
from mistralai.constants import ENDPOINT
|
|
18
|
+
from mistralai.exceptions import (
|
|
19
|
+
MistralAPIException,
|
|
20
|
+
MistralAPIStatusException,
|
|
21
|
+
MistralConnectionException,
|
|
22
|
+
MistralException,
|
|
23
|
+
)
|
|
24
|
+
from mistralai.models.chat_completion import (
|
|
25
|
+
ChatCompletionResponse,
|
|
26
|
+
ChatCompletionStreamResponse,
|
|
27
|
+
ChatMessage,
|
|
28
|
+
)
|
|
29
|
+
from mistralai.models.embeddings import EmbeddingResponse
|
|
30
|
+
from mistralai.models.models import ModelList
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class MistralAsyncClient(ClientBase):
|
|
34
|
+
def __init__(
|
|
35
|
+
self,
|
|
36
|
+
api_key: Optional[str] = os.environ.get("MISTRAL_API_KEY", None),
|
|
37
|
+
endpoint: str = ENDPOINT,
|
|
38
|
+
max_retries: int = 5,
|
|
39
|
+
timeout: int = 120,
|
|
40
|
+
max_concurrent_requests: int = 64,
|
|
41
|
+
):
|
|
42
|
+
super().__init__(endpoint, api_key, max_retries, timeout)
|
|
43
|
+
|
|
44
|
+
self._client = AsyncClient(
|
|
45
|
+
follow_redirects=True,
|
|
46
|
+
timeout=timeout,
|
|
47
|
+
limits=Limits(max_connections=max_concurrent_requests),
|
|
48
|
+
transport=AsyncHTTPTransport(retries=max_retries),
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
async def close(self) -> None:
|
|
52
|
+
await self._client.aclose()
|
|
53
|
+
|
|
54
|
+
async def _request(
|
|
55
|
+
self,
|
|
56
|
+
method: str,
|
|
57
|
+
json: Dict[str, Any],
|
|
58
|
+
path: str,
|
|
59
|
+
stream: bool = False,
|
|
60
|
+
attempt: int = 1,
|
|
61
|
+
) -> AsyncGenerator[Dict[str, Any], None]:
|
|
62
|
+
headers = {
|
|
63
|
+
"Authorization": f"Bearer {self._api_key}",
|
|
64
|
+
"Content-Type": "application/json",
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
url = posixpath.join(self._endpoint, path)
|
|
68
|
+
|
|
69
|
+
self._logger.debug(f"Sending request: {method} {url} {json}")
|
|
70
|
+
|
|
71
|
+
response: Response
|
|
72
|
+
|
|
73
|
+
try:
|
|
74
|
+
if stream:
|
|
75
|
+
async with self._client.stream(
|
|
76
|
+
method,
|
|
77
|
+
url,
|
|
78
|
+
headers=headers,
|
|
79
|
+
json=json,
|
|
80
|
+
) as response:
|
|
81
|
+
self._check_streaming_response(response)
|
|
82
|
+
|
|
83
|
+
async for line in response.aiter_lines():
|
|
84
|
+
json_streamed_response = self._process_line(line)
|
|
85
|
+
if json_streamed_response:
|
|
86
|
+
yield json_streamed_response
|
|
87
|
+
|
|
88
|
+
else:
|
|
89
|
+
response = await self._client.request(
|
|
90
|
+
method,
|
|
91
|
+
url,
|
|
92
|
+
headers=headers,
|
|
93
|
+
json=json,
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
yield self._check_response(response)
|
|
97
|
+
|
|
98
|
+
except ConnectError as e:
|
|
99
|
+
raise MistralConnectionException(str(e)) from e
|
|
100
|
+
except RequestError as e:
|
|
101
|
+
raise MistralException(
|
|
102
|
+
f"Unexpected exception ({e.__class__.__name__}): {e}"
|
|
103
|
+
) from e
|
|
104
|
+
except JSONDecodeError as e:
|
|
105
|
+
raise MistralAPIException.from_response(
|
|
106
|
+
response,
|
|
107
|
+
message=f"Failed to decode json body: {response.text}",
|
|
108
|
+
) from e
|
|
109
|
+
except MistralAPIStatusException as e:
|
|
110
|
+
attempt += 1
|
|
111
|
+
if attempt > self._max_retries:
|
|
112
|
+
raise MistralAPIStatusException.from_response(
|
|
113
|
+
response, message=str(e)
|
|
114
|
+
) from e
|
|
115
|
+
backoff = 2.0**attempt # exponential backoff
|
|
116
|
+
time.sleep(backoff)
|
|
117
|
+
|
|
118
|
+
# Retry as a generator
|
|
119
|
+
async for r in self._request(
|
|
120
|
+
method, json, path, stream=stream, attempt=attempt
|
|
121
|
+
):
|
|
122
|
+
yield r
|
|
123
|
+
|
|
124
|
+
async def chat(
|
|
125
|
+
self,
|
|
126
|
+
model: str,
|
|
127
|
+
messages: List[ChatMessage],
|
|
128
|
+
temperature: Optional[float] = None,
|
|
129
|
+
max_tokens: Optional[int] = None,
|
|
130
|
+
top_p: Optional[float] = None,
|
|
131
|
+
random_seed: Optional[int] = None,
|
|
132
|
+
safe_mode: bool = False,
|
|
133
|
+
) -> ChatCompletionResponse:
|
|
134
|
+
"""A asynchronous chat endpoint that returns a single response.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
model (str): model the name of the model to chat with, e.g. mistral-tiny
|
|
138
|
+
messages (List[ChatMessage]): messages an array of messages to chat with, e.g.
|
|
139
|
+
[{role: 'user', content: 'What is the best French cheese?'}]
|
|
140
|
+
temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
|
|
141
|
+
max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
|
|
142
|
+
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
|
|
143
|
+
Defaults to None.
|
|
144
|
+
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
|
|
145
|
+
safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
|
|
146
|
+
|
|
147
|
+
Returns:
|
|
148
|
+
ChatCompletionResponse: a response object containing the generated text.
|
|
149
|
+
"""
|
|
150
|
+
request = self._make_chat_request(
|
|
151
|
+
model,
|
|
152
|
+
messages,
|
|
153
|
+
temperature=temperature,
|
|
154
|
+
max_tokens=max_tokens,
|
|
155
|
+
top_p=top_p,
|
|
156
|
+
random_seed=random_seed,
|
|
157
|
+
stream=False,
|
|
158
|
+
safe_mode=safe_mode,
|
|
159
|
+
)
|
|
160
|
+
|
|
161
|
+
single_response = self._request("post", request, "v1/chat/completions")
|
|
162
|
+
|
|
163
|
+
async for response in single_response:
|
|
164
|
+
return ChatCompletionResponse(**response)
|
|
165
|
+
|
|
166
|
+
raise MistralException("No response received")
|
|
167
|
+
|
|
168
|
+
async def chat_stream(
|
|
169
|
+
self,
|
|
170
|
+
model: str,
|
|
171
|
+
messages: List[ChatMessage],
|
|
172
|
+
temperature: Optional[float] = None,
|
|
173
|
+
max_tokens: Optional[int] = None,
|
|
174
|
+
top_p: Optional[float] = None,
|
|
175
|
+
random_seed: Optional[int] = None,
|
|
176
|
+
safe_mode: bool = False,
|
|
177
|
+
) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
|
|
178
|
+
"""An Asynchronous chat endpoint that streams responses.
|
|
179
|
+
|
|
180
|
+
Args:
|
|
181
|
+
model (str): model the name of the model to chat with, e.g. mistral-tiny
|
|
182
|
+
messages (List[ChatMessage]): messages an array of messages to chat with, e.g.
|
|
183
|
+
[{role: 'user', content: 'What is the best French cheese?'}]
|
|
184
|
+
temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
|
|
185
|
+
max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
|
|
186
|
+
top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
|
|
187
|
+
Defaults to None.
|
|
188
|
+
random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
|
|
189
|
+
safe_mode (bool, optional): whether to use safe mode, e.g. true. Defaults to False.
|
|
190
|
+
|
|
191
|
+
Returns:
|
|
192
|
+
AsyncGenerator[ChatCompletionStreamResponse, None]:
|
|
193
|
+
An async generator that yields ChatCompletionStreamResponse objects.
|
|
194
|
+
"""
|
|
195
|
+
|
|
196
|
+
request = self._make_chat_request(
|
|
197
|
+
model,
|
|
198
|
+
messages,
|
|
199
|
+
temperature=temperature,
|
|
200
|
+
max_tokens=max_tokens,
|
|
201
|
+
top_p=top_p,
|
|
202
|
+
random_seed=random_seed,
|
|
203
|
+
stream=True,
|
|
204
|
+
safe_mode=safe_mode,
|
|
205
|
+
)
|
|
206
|
+
async_response = self._request(
|
|
207
|
+
"post", request, "v1/chat/completions", stream=True
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
async for json_response in async_response:
|
|
211
|
+
yield ChatCompletionStreamResponse(**json_response)
|
|
212
|
+
|
|
213
|
+
async def embeddings(
|
|
214
|
+
self, model: str, input: Union[str, List[str]]
|
|
215
|
+
) -> EmbeddingResponse:
|
|
216
|
+
"""An asynchronous embeddings endpoint that returns embeddings for a single, or batch of inputs
|
|
217
|
+
|
|
218
|
+
Args:
|
|
219
|
+
model (str): The embedding model to use, e.g. mistral-embed
|
|
220
|
+
input (Union[str, List[str]]): The input to embed,
|
|
221
|
+
e.g. ['What is the best French cheese?']
|
|
222
|
+
|
|
223
|
+
Returns:
|
|
224
|
+
EmbeddingResponse: A response object containing the embeddings.
|
|
225
|
+
"""
|
|
226
|
+
request = {"model": model, "input": input}
|
|
227
|
+
single_response = self._request("post", request, "v1/embeddings")
|
|
228
|
+
|
|
229
|
+
async for response in single_response:
|
|
230
|
+
return EmbeddingResponse(**response)
|
|
231
|
+
|
|
232
|
+
raise MistralException("No response received")
|
|
233
|
+
|
|
234
|
+
async def list_models(self) -> ModelList:
|
|
235
|
+
"""Returns a list of the available models
|
|
236
|
+
|
|
237
|
+
Returns:
|
|
238
|
+
ModelList: A response object containing the list of models.
|
|
239
|
+
"""
|
|
240
|
+
single_response = self._request("get", {}, "v1/models")
|
|
241
|
+
|
|
242
|
+
async for response in single_response:
|
|
243
|
+
return ModelList(**response)
|
|
244
|
+
|
|
245
|
+
raise MistralException("No response received")
|
|
@@ -1,18 +1,16 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import posixpath
|
|
3
|
+
import time
|
|
3
4
|
from json import JSONDecodeError
|
|
4
|
-
from typing import Any, Dict, Iterable, List, Optional, Union
|
|
5
|
+
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
|
|
5
6
|
|
|
6
|
-
import orjson
|
|
7
|
-
import requests
|
|
8
|
-
from requests import Response
|
|
9
|
-
from requests.adapters import HTTPAdapter
|
|
10
|
-
from urllib3.util.retry import Retry
|
|
7
|
+
from httpx import Client, ConnectError, HTTPTransport, RequestError, Response
|
|
11
8
|
|
|
12
9
|
from mistralai.client_base import ClientBase
|
|
13
|
-
from mistralai.constants import ENDPOINT
|
|
10
|
+
from mistralai.constants import ENDPOINT
|
|
14
11
|
from mistralai.exceptions import (
|
|
15
12
|
MistralAPIException,
|
|
13
|
+
MistralAPIStatusException,
|
|
16
14
|
MistralConnectionException,
|
|
17
15
|
MistralException,
|
|
18
16
|
)
|
|
@@ -39,14 +37,22 @@ class MistralClient(ClientBase):
|
|
|
39
37
|
):
|
|
40
38
|
super().__init__(endpoint, api_key, max_retries, timeout)
|
|
41
39
|
|
|
40
|
+
self._client = Client(
|
|
41
|
+
follow_redirects=True,
|
|
42
|
+
timeout=self._timeout,
|
|
43
|
+
transport=HTTPTransport(retries=self._max_retries))
|
|
44
|
+
|
|
45
|
+
def __del__(self) -> None:
|
|
46
|
+
self._client.close()
|
|
47
|
+
|
|
42
48
|
def _request(
|
|
43
49
|
self,
|
|
44
50
|
method: str,
|
|
45
51
|
json: Dict[str, Any],
|
|
46
52
|
path: str,
|
|
47
53
|
stream: bool = False,
|
|
48
|
-
|
|
49
|
-
) ->
|
|
54
|
+
attempt: int = 1,
|
|
55
|
+
) -> Iterator[Dict[str, Any]]:
|
|
50
56
|
headers = {
|
|
51
57
|
"Authorization": f"Bearer {self._api_key}",
|
|
52
58
|
"Content-Type": "application/json",
|
|
@@ -54,49 +60,58 @@ class MistralClient(ClientBase):
|
|
|
54
60
|
|
|
55
61
|
url = posixpath.join(self._endpoint, path)
|
|
56
62
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
backoff_factor=0.5,
|
|
61
|
-
allowed_methods=["POST", "GET"],
|
|
62
|
-
status_forcelist=RETRY_STATUS_CODES,
|
|
63
|
-
raise_on_status=False,
|
|
64
|
-
)
|
|
65
|
-
session.mount("https://", HTTPAdapter(max_retries=retries))
|
|
66
|
-
session.mount("http://", HTTPAdapter(max_retries=retries))
|
|
63
|
+
self._logger.debug(f"Sending request: {method} {url} {json}")
|
|
64
|
+
|
|
65
|
+
response: Response
|
|
67
66
|
|
|
67
|
+
try:
|
|
68
68
|
if stream:
|
|
69
|
-
|
|
70
|
-
method,
|
|
71
|
-
|
|
69
|
+
with self._client.stream(
|
|
70
|
+
method,
|
|
71
|
+
url,
|
|
72
|
+
headers=headers,
|
|
73
|
+
json=json,
|
|
74
|
+
) as response:
|
|
75
|
+
self._check_streaming_response(response)
|
|
76
|
+
|
|
77
|
+
for line in response.iter_lines():
|
|
78
|
+
json_streamed_response = self._process_line(line)
|
|
79
|
+
if json_streamed_response:
|
|
80
|
+
yield json_streamed_response
|
|
72
81
|
|
|
73
|
-
|
|
74
|
-
response =
|
|
82
|
+
else:
|
|
83
|
+
response = self._client.request(
|
|
75
84
|
method,
|
|
76
85
|
url,
|
|
77
86
|
headers=headers,
|
|
78
87
|
json=json,
|
|
79
|
-
timeout=self._timeout,
|
|
80
|
-
params=params,
|
|
81
88
|
)
|
|
82
|
-
except requests.exceptions.ConnectionError as e:
|
|
83
|
-
raise MistralConnectionException(str(e)) from e
|
|
84
|
-
except requests.exceptions.RequestException as e:
|
|
85
|
-
raise MistralException(
|
|
86
|
-
f"Unexpected exception ({e.__class__.__name__}): {e}"
|
|
87
|
-
) from e
|
|
88
89
|
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
90
|
+
yield self._check_response(response)
|
|
91
|
+
|
|
92
|
+
except ConnectError as e:
|
|
93
|
+
raise MistralConnectionException(str(e)) from e
|
|
94
|
+
except RequestError as e:
|
|
95
|
+
raise MistralException(
|
|
96
|
+
f"Unexpected exception ({e.__class__.__name__}): {e}"
|
|
97
|
+
) from e
|
|
98
|
+
except JSONDecodeError as e:
|
|
99
|
+
raise MistralAPIException.from_response(
|
|
100
|
+
response,
|
|
101
|
+
message=f"Failed to decode json body: {response.text}",
|
|
102
|
+
) from e
|
|
103
|
+
except MistralAPIStatusException as e:
|
|
104
|
+
attempt += 1
|
|
105
|
+
if attempt > self._max_retries:
|
|
106
|
+
raise MistralAPIStatusException.from_response(
|
|
107
|
+
response, message=str(e)
|
|
108
|
+
) from e
|
|
109
|
+
backoff = 2.0**attempt # exponential backoff
|
|
110
|
+
time.sleep(backoff)
|
|
95
111
|
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
return json_response
|
|
112
|
+
# Retry as a generator
|
|
113
|
+
for r in self._request(method, json, path, stream=stream, attempt=attempt):
|
|
114
|
+
yield r
|
|
100
115
|
|
|
101
116
|
def chat(
|
|
102
117
|
self,
|
|
@@ -106,9 +121,9 @@ class MistralClient(ClientBase):
|
|
|
106
121
|
max_tokens: Optional[int] = None,
|
|
107
122
|
top_p: Optional[float] = None,
|
|
108
123
|
random_seed: Optional[int] = None,
|
|
109
|
-
safe_mode: bool =
|
|
124
|
+
safe_mode: bool = False,
|
|
110
125
|
) -> ChatCompletionResponse:
|
|
111
|
-
"""
|
|
126
|
+
"""A chat endpoint that returns a single response.
|
|
112
127
|
|
|
113
128
|
Args:
|
|
114
129
|
model (str): model the name of the model to chat with, e.g. mistral-tiny
|
|
@@ -135,11 +150,12 @@ class MistralClient(ClientBase):
|
|
|
135
150
|
safe_mode=safe_mode,
|
|
136
151
|
)
|
|
137
152
|
|
|
138
|
-
|
|
153
|
+
single_response = self._request("post", request, "v1/chat/completions")
|
|
139
154
|
|
|
140
|
-
|
|
155
|
+
for response in single_response:
|
|
156
|
+
return ChatCompletionResponse(**response)
|
|
141
157
|
|
|
142
|
-
|
|
158
|
+
raise MistralException("No response received")
|
|
143
159
|
|
|
144
160
|
def chat_stream(
|
|
145
161
|
self,
|
|
@@ -149,9 +165,9 @@ class MistralClient(ClientBase):
|
|
|
149
165
|
max_tokens: Optional[int] = None,
|
|
150
166
|
top_p: Optional[float] = None,
|
|
151
167
|
random_seed: Optional[int] = None,
|
|
152
|
-
safe_mode: bool =
|
|
168
|
+
safe_mode: bool = False,
|
|
153
169
|
) -> Iterable[ChatCompletionStreamResponse]:
|
|
154
|
-
"""
|
|
170
|
+
"""A chat endpoint that streams responses.
|
|
155
171
|
|
|
156
172
|
Args:
|
|
157
173
|
model (str): model the name of the model to chat with, e.g. mistral-tiny
|
|
@@ -181,18 +197,8 @@ class MistralClient(ClientBase):
|
|
|
181
197
|
|
|
182
198
|
response = self._request("post", request, "v1/chat/completions", stream=True)
|
|
183
199
|
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
for line in response.iter_lines():
|
|
187
|
-
self._logger.debug(f"Received line: {line}")
|
|
188
|
-
if line == b"\n":
|
|
189
|
-
continue
|
|
190
|
-
|
|
191
|
-
if line.startswith(b"data: "):
|
|
192
|
-
line = line[6:].strip()
|
|
193
|
-
if line != b"[DONE]":
|
|
194
|
-
json_response = orjson.loads(line)
|
|
195
|
-
yield ChatCompletionStreamResponse(**json_response)
|
|
200
|
+
for json_streamed_response in response:
|
|
201
|
+
yield ChatCompletionStreamResponse(**json_streamed_response)
|
|
196
202
|
|
|
197
203
|
def embeddings(self, model: str, input: Union[str, List[str]]) -> EmbeddingResponse:
|
|
198
204
|
"""An embeddings endpoint that returns embeddings for a single, or batch of inputs
|
|
@@ -206,9 +212,12 @@ class MistralClient(ClientBase):
|
|
|
206
212
|
EmbeddingResponse: A response object containing the embeddings.
|
|
207
213
|
"""
|
|
208
214
|
request = {"model": model, "input": input}
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
215
|
+
singleton_response = self._request("post", request, "v1/embeddings")
|
|
216
|
+
|
|
217
|
+
for response in singleton_response:
|
|
218
|
+
return EmbeddingResponse(**response)
|
|
219
|
+
|
|
220
|
+
raise MistralException("No response received")
|
|
212
221
|
|
|
213
222
|
def list_models(self) -> ModelList:
|
|
214
223
|
"""Returns a list of the available models
|
|
@@ -216,6 +225,9 @@ class MistralClient(ClientBase):
|
|
|
216
225
|
Returns:
|
|
217
226
|
ModelList: A response object containing the list of models.
|
|
218
227
|
"""
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
228
|
+
singleton_response = self._request("get", {}, "v1/models")
|
|
229
|
+
|
|
230
|
+
for response in singleton_response:
|
|
231
|
+
return ModelList(**response)
|
|
232
|
+
|
|
233
|
+
raise MistralException("No response received")
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from abc import ABC
|
|
4
|
+
from typing import Any, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
import orjson
|
|
7
|
+
from httpx import Response
|
|
8
|
+
|
|
9
|
+
from mistralai.constants import RETRY_STATUS_CODES
|
|
10
|
+
from mistralai.exceptions import (
|
|
11
|
+
MistralAPIException,
|
|
12
|
+
MistralAPIStatusException,
|
|
13
|
+
MistralException,
|
|
14
|
+
)
|
|
15
|
+
from mistralai.models.chat_completion import ChatMessage
|
|
16
|
+
|
|
17
|
+
logging.basicConfig(
|
|
18
|
+
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
|
19
|
+
level=os.getenv("LOG_LEVEL", "ERROR"),
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ClientBase(ABC):
|
|
24
|
+
def __init__(
|
|
25
|
+
self,
|
|
26
|
+
endpoint: str,
|
|
27
|
+
api_key: Optional[str] = None,
|
|
28
|
+
max_retries: int = 5,
|
|
29
|
+
timeout: int = 120,
|
|
30
|
+
):
|
|
31
|
+
self._max_retries = max_retries
|
|
32
|
+
self._timeout = timeout
|
|
33
|
+
|
|
34
|
+
self._endpoint = endpoint
|
|
35
|
+
self._api_key = api_key
|
|
36
|
+
self._logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
def _make_chat_request(
|
|
39
|
+
self,
|
|
40
|
+
model: str,
|
|
41
|
+
messages: List[ChatMessage],
|
|
42
|
+
temperature: Optional[float] = None,
|
|
43
|
+
max_tokens: Optional[int] = None,
|
|
44
|
+
top_p: Optional[float] = None,
|
|
45
|
+
random_seed: Optional[int] = None,
|
|
46
|
+
stream: Optional[bool] = None,
|
|
47
|
+
safe_mode: Optional[bool] = False,
|
|
48
|
+
) -> Dict[str, Any]:
|
|
49
|
+
request_data: Dict[str, Any] = {
|
|
50
|
+
"model": model,
|
|
51
|
+
"messages": [msg.model_dump() for msg in messages],
|
|
52
|
+
"safe_prompt": safe_mode,
|
|
53
|
+
}
|
|
54
|
+
if temperature is not None:
|
|
55
|
+
request_data["temperature"] = temperature
|
|
56
|
+
if max_tokens is not None:
|
|
57
|
+
request_data["max_tokens"] = max_tokens
|
|
58
|
+
if top_p is not None:
|
|
59
|
+
request_data["top_p"] = top_p
|
|
60
|
+
if random_seed is not None:
|
|
61
|
+
request_data["random_seed"] = random_seed
|
|
62
|
+
if stream is not None:
|
|
63
|
+
request_data["stream"] = stream
|
|
64
|
+
|
|
65
|
+
self._logger.debug(f"Chat request: {request_data}")
|
|
66
|
+
|
|
67
|
+
return request_data
|
|
68
|
+
|
|
69
|
+
def _check_response_status_codes(self, response: Response) -> None:
|
|
70
|
+
if response.status_code in RETRY_STATUS_CODES:
|
|
71
|
+
raise MistralAPIStatusException.from_response(
|
|
72
|
+
response,
|
|
73
|
+
message=f"Cannot stream response. Status: {response.status_code}",
|
|
74
|
+
)
|
|
75
|
+
elif 400 <= response.status_code < 500:
|
|
76
|
+
raise MistralAPIException.from_response(
|
|
77
|
+
response,
|
|
78
|
+
message=f"Cannot stream response. Status: {response.status_code}",
|
|
79
|
+
)
|
|
80
|
+
elif response.status_code >= 500:
|
|
81
|
+
raise MistralException(
|
|
82
|
+
message=f"Unexpected server error (status {response.status_code})"
|
|
83
|
+
)
|
|
84
|
+
|
|
85
|
+
def _check_streaming_response(self, response: Response) -> None:
|
|
86
|
+
self._check_response_status_codes(response)
|
|
87
|
+
|
|
88
|
+
def _check_response(self, response: Response) -> Dict[str, Any]:
|
|
89
|
+
self._check_response_status_codes(response)
|
|
90
|
+
|
|
91
|
+
json_response: Dict[str, Any] = response.json()
|
|
92
|
+
|
|
93
|
+
if "object" not in json_response:
|
|
94
|
+
raise MistralException(message=f"Unexpected response: {json_response}")
|
|
95
|
+
if "error" == json_response["object"]: # has errors
|
|
96
|
+
raise MistralAPIException.from_response(
|
|
97
|
+
response,
|
|
98
|
+
message=json_response["message"],
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
return json_response
|
|
102
|
+
|
|
103
|
+
def _process_line(self, line: str) -> Optional[Dict[str, Any]]:
|
|
104
|
+
if line.startswith("data: "):
|
|
105
|
+
line = line[6:].strip()
|
|
106
|
+
if line != "[DONE]":
|
|
107
|
+
json_streamed_response: Dict[str, Any] = orjson.loads(line)
|
|
108
|
+
return json_streamed_response
|
|
109
|
+
return None
|
|
@@ -2,8 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
from typing import Any, Dict, Optional
|
|
4
4
|
|
|
5
|
-
import aiohttp
|
|
6
|
-
from requests import Response
|
|
5
|
+
from httpx import Response
|
|
7
6
|
|
|
8
7
|
|
|
9
8
|
class MistralException(Exception):
|
|
@@ -45,19 +44,11 @@ class MistralAPIException(MistralException):
|
|
|
45
44
|
headers=dict(response.headers),
|
|
46
45
|
)
|
|
47
46
|
|
|
48
|
-
@classmethod
|
|
49
|
-
def from_aio_response(
|
|
50
|
-
cls, response: aiohttp.ClientResponse, message: Optional[str] = None
|
|
51
|
-
) -> MistralAPIException:
|
|
52
|
-
return cls(
|
|
53
|
-
message=message,
|
|
54
|
-
http_status=response.status,
|
|
55
|
-
headers=dict(response.headers),
|
|
56
|
-
)
|
|
57
|
-
|
|
58
47
|
def __repr__(self) -> str:
|
|
59
48
|
return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})"
|
|
60
49
|
|
|
50
|
+
class MistralAPIStatusException(MistralAPIException):
|
|
51
|
+
"""Returned when we receive a non-200 response from the API that we should retry"""
|
|
61
52
|
|
|
62
53
|
class MistralConnectionException(MistralException):
|
|
63
54
|
"""Returned when the SDK can not reach the API server for any reason"""
|
|
File without changes
|
|
@@ -1,342 +0,0 @@
|
|
|
1
|
-
import asyncio
|
|
2
|
-
import logging
|
|
3
|
-
import os
|
|
4
|
-
import posixpath
|
|
5
|
-
import time
|
|
6
|
-
from collections import defaultdict
|
|
7
|
-
from json import JSONDecodeError
|
|
8
|
-
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Union
|
|
9
|
-
|
|
10
|
-
import aiohttp
|
|
11
|
-
import backoff
|
|
12
|
-
import orjson
|
|
13
|
-
|
|
14
|
-
from mistralai.client_base import ClientBase
|
|
15
|
-
from mistralai.constants import ENDPOINT, RETRY_STATUS_CODES
|
|
16
|
-
from mistralai.exceptions import (
|
|
17
|
-
MistralAPIException,
|
|
18
|
-
MistralConnectionException,
|
|
19
|
-
MistralException,
|
|
20
|
-
)
|
|
21
|
-
from mistralai.models.chat_completion import (
|
|
22
|
-
ChatCompletionResponse,
|
|
23
|
-
ChatCompletionStreamResponse,
|
|
24
|
-
ChatMessage,
|
|
25
|
-
)
|
|
26
|
-
from mistralai.models.embeddings import EmbeddingResponse
|
|
27
|
-
from mistralai.models.models import ModelList
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
class AIOHTTPBackend:
    """HTTP backend which handles retries, concurrency limiting and logging.

    A single ``aiohttp.ClientSession`` is created lazily by :meth:`session` and
    reused for every request until :meth:`close` is called.
    """

    # Seconds to wait after receiving a retryable status before raising to
    # trigger a retry; rate limiting (429) waits longer than other statuses.
    SLEEP_AFTER_FAILURE = defaultdict(lambda: 0.25, {429: 5.0})

    _requester: Optional[Callable[..., Awaitable[aiohttp.ClientResponse]]]
    _semaphore: Optional[asyncio.Semaphore]
    _session: Optional[aiohttp.ClientSession]

    def __init__(
        self,
        max_concurrent_requests: int = 64,
        max_retries: int = 5,
        timeout: int = 120,
    ):
        """Store the backend configuration; no network resources are allocated.

        Args:
            max_concurrent_requests: upper bound on concurrent
                ``session.request`` calls made through this backend.
            max_retries: number of retries after the first attempt.
            timeout: seconds; used as the session's total timeout and as the
                overall time budget for retries.
        """
        self._logger = logging.getLogger(__name__)
        self._timeout = timeout
        self._max_retries = max_retries
        self._session = None
        # Built lazily from inside a coroutine (see _build_requester) so they
        # are created while an event loop is running.
        self._semaphore = None
        self._requester = None
        self._max_concurrent_requests = max_concurrent_requests

    def build_aio_requester(
        self,
    ) -> Callable[..., Awaitable[aiohttp.ClientResponse]]:
        """Return a coroutine function that performs one request with retries.

        The returned function retries with exponential backoff on any
        ``aiohttp.ClientError``. That includes the ``ClientResponseError`` raised
        below for statuses in ``RETRY_STATUS_CODES``. Retries stop after
        ``max_retries`` retries or ``timeout`` seconds, whichever comes first.
        """

        # ClientResponseError is a subclass of ClientError, so one type suffices.
        @backoff.on_exception(
            backoff.expo,
            aiohttp.ClientError,
            max_tries=self._max_retries + 1,
            max_time=self._timeout,
        )
        async def make_request_fn(
            session: aiohttp.ClientSession, *args: Any, **kwargs: Any
        ) -> aiohttp.ClientResponse:
            semaphore = self._semaphore
            assert semaphore is not None, "requester used before being built"
            # The semaphore limits total concurrency by the client. It is
            # released before any retry sleep so that a request waiting to
            # retry does not hold a slot other requests could use.
            async with semaphore:
                response = await session.request(*args, **kwargs)

            if response.status in RETRY_STATUS_CODES:  # likely temporary
                self._logger.info("Received status %s, retrying...", response.status)
                await asyncio.sleep(self.SLEEP_AFTER_FAILURE[response.status])
                # Raising lets the backoff decorator schedule the retry.
                response.raise_for_status()

            return response

        return make_request_fn

    def _build_requester(self) -> None:
        """(Re)create the concurrency semaphore and the retrying requester."""
        self._semaphore = asyncio.Semaphore(self._max_concurrent_requests)
        self._requester = self.build_aio_requester()

    def _ensure_requester(self) -> Callable[..., Awaitable[aiohttp.ClientResponse]]:
        """Return the retrying requester, building it first if necessary.

        Needed because callers may pass their own session to :meth:`request`,
        in which case :meth:`session` (which normally builds it) is never called.
        """
        if self._requester is None:
            self._build_requester()
        assert self._requester is not None
        return self._requester

    async def request(
        self,
        url: str,
        json: Optional[Dict[str, Any]] = None,
        method: str = "post",
        headers: Optional[Dict[str, Any]] = None,
        session: Optional[aiohttp.ClientSession] = None,
        params: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> aiohttp.ClientResponse:
        """Send an HTTP request with retries, translating aiohttp errors.

        Args:
            url: absolute URL to request.
            json: JSON-serialisable request body, if any.
            method: HTTP method name.
            headers: request headers.
            session: session to use; defaults to this backend's shared session.
            params: query-string parameters.
            **kwargs: forwarded to ``aiohttp.ClientSession.request``.

        Returns:
            The aiohttp response; the caller is responsible for reading it.

        Raises:
            MistralConnectionException: on connection errors, timeouts, or a
                response error that persists after all retries.
            MistralException: on any other unexpected error.
        """
        if session is None:
            session = await self.session()
        requester = self._ensure_requester()
        self._logger.debug("Making request to %s with content %s", url, json)

        request_start = time.monotonic()
        try:
            response = await requester(
                session,
                method,
                url,
                headers=headers,
                json=json,
                params=params,
                **kwargs,
            )
        except aiohttp.ClientConnectionError as e:
            # Ensure the SDK user does not have to deal with knowing aiohttp.
            self._logger.debug(
                "Fatal connection error after %.1fs: %s",
                time.monotonic() - request_start,
                e,
            )
            raise MistralConnectionException(str(e)) from e
        except aiohttp.ClientResponseError as e:
            # A retryable status (e.g. 5xx) remained after all retries.
            self._logger.debug(
                "Fatal ClientResponseError error after %.1fs: %s",
                time.monotonic() - request_start,
                e,
            )
            raise MistralConnectionException(str(e)) from e
        except asyncio.TimeoutError as e:
            self._logger.debug(
                "Fatal timeout error after %.1fs: %s",
                time.monotonic() - request_start,
                e,
            )
            raise MistralConnectionException("The request timed out") from e
        except Exception as e:  # Anything caught here should be added above
            self._logger.debug(
                "Unexpected fatal error after %.1fs: %s",
                time.monotonic() - request_start,
                e,
            )
            raise MistralException(
                f"Unexpected exception ({e.__class__.__name__}): {e}"
            ) from e

        self._logger.debug(
            "Received response with status %s after %.1fs",
            response.status,
            time.monotonic() - request_start,
        )
        return response

    async def session(self) -> aiohttp.ClientSession:
        """Return the shared session, creating it on first use.

        A fresh semaphore and requester are built whenever a new session is
        created (e.g. after :meth:`close`).
        """
        if self._session is None:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self._timeout),
                # limit=0: no connector-level cap; concurrency is limited by
                # the semaphore instead.
                connector=aiohttp.TCPConnector(limit=0),
            )
            self._build_requester()
        return self._session

    async def close(self) -> None:
        """Close the shared session, if one was created."""
        if self._session is not None:
            await self._session.close()
            self._session = None

    def __del__(self) -> None:
        # Best-effort cleanup of an unclosed session on garbage collection.
        # https://stackoverflow.com/questions/54770360/how-can-i-wait-for-an-objects-del-to-finish-before-the-async-loop-closes
        # getattr: __init__ may not have completed if construction failed.
        if getattr(self, "_session", None):
            try:
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    loop.create_task(self.close())
                else:
                    loop.run_until_complete(self.close())
            except Exception:
                # Deliberately ignored: e.g. the loop is already closed during
                # interpreter shutdown and there is nothing left to clean up.
                pass
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
class MistralAsyncClient(ClientBase):
    """Asynchronous client for the Mistral API.

    Requests go through an :class:`AIOHTTPBackend`, which provides retries and
    concurrency limiting. Call :meth:`close` when done to release the HTTP session.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        endpoint: str = ENDPOINT,
        max_retries: int = 5,
        timeout: int = 120,
        max_concurrent_requests: int = 64,
    ):
        """Create an asynchronous client.

        Args:
            api_key: API key sent as a Bearer token. When None, the
                ``MISTRAL_API_KEY`` environment variable is read when the client
                is constructed (not when the module is imported).
            endpoint: base URL of the API.
            max_retries: retries per request, forwarded to the backend.
            timeout: timeout in seconds, forwarded to the backend.
            max_concurrent_requests: concurrency limit, forwarded to the backend.
        """
        if api_key is None:
            api_key = os.environ.get("MISTRAL_API_KEY")
        super().__init__(endpoint, api_key, max_retries, timeout)

        self._backend = AIOHTTPBackend(
            max_concurrent_requests=max_concurrent_requests,
            max_retries=max_retries,
            timeout=timeout,
        )

    async def close(self) -> None:
        """Close the underlying HTTP session."""
        await self._backend.close()

    async def _request(
        self,
        method: str,
        json: Dict[str, Any],
        path: str,
        stream: bool = False,
        params: Optional[Dict[str, Any]] = None,
    ) -> Union[Dict[str, Any], aiohttp.ClientResponse]:
        """Send a request to ``path`` relative to the configured endpoint.

        Returns:
            The raw response when ``stream`` is True and the status is below 400
            (the caller must consume it); otherwise the decoded JSON object.

        Raises:
            MistralAPIException: if the body cannot be decoded as a JSON object,
                or ``_check_response`` reports an API/client error.
            MistralException: if ``_check_response`` reports another failure.
        """
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

        url = posixpath.join(self._endpoint, path)

        response = await self._backend.request(
            url, json, method, headers, params=params
        )
        if stream and response.status < 400:
            return response

        # Non-streaming responses, and streaming responses with an error status,
        # are read in full so that API errors surface as exceptions instead of
        # an empty stream.
        try:
            json_response = await response.json()
        except (JSONDecodeError, aiohttp.ContentTypeError) as e:
            # ContentTypeError: the body is not declared as JSON, e.g. an HTML
            # error page from a proxy.
            raise MistralAPIException.from_aio_response(
                response, message=f"Failed to decode json body: {await response.text()}"
            ) from e
        except aiohttp.ClientPayloadError as e:
            raise MistralAPIException.from_aio_response(
                response,
                message=f"An unexpected error occurred while receiving the response: {e}",
            ) from e

        if not isinstance(json_response, dict):
            # e.g. an empty body, for which aiohttp's json() may return None.
            raise MistralAPIException.from_aio_response(
                response, message=f"Unexpected response body: {json_response!r}"
            )

        self._logger.debug(f"JSON response: {json_response}")
        self._check_response(json_response, dict(response.headers), response.status)
        return json_response

    async def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        safe_mode: bool = True,
    ) -> ChatCompletionResponse:
        """An asynchronous chat endpoint that returns a single response.

        Args:
            model (str): the name of the model to chat with, e.g. mistral-tiny
            messages (List[ChatMessage]): an array of messages to chat with, e.g.
                [{role: 'user', content: 'What is the best French cheese?'}]
            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5.
            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
                Defaults to None.
            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
            safe_mode (bool, optional): whether to use safe mode (sent as ``safe_prompt``). Defaults to True.

        Returns:
            ChatCompletionResponse: a response object containing the generated text.
        """
        request = self._make_chat_request(
            model,
            messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            random_seed=random_seed,
            stream=False,
            safe_mode=safe_mode,
        )

        response = await self._request("post", request, "v1/chat/completions")
        assert isinstance(response, dict), "Bad response from _request"
        return ChatCompletionResponse(**response)

    async def chat_stream(
        self,
        model: str,
        messages: List[ChatMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        safe_mode: bool = False,
    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
        """An asynchronous chat endpoint that streams responses.

        Args:
            model (str): the name of the model to chat with, e.g. mistral-tiny
            messages (List[ChatMessage]): an array of messages to chat with, e.g.
                [{role: 'user', content: 'What is the best French cheese?'}]
            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5.
            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
                Defaults to None.
            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
            safe_mode (bool, optional): whether to use safe mode (sent as ``safe_prompt``). Defaults to False.

        Returns:
            AsyncGenerator[ChatCompletionStreamResponse, None]:
                An async generator that yields ChatCompletionStreamResponse objects.

        Raises:
            MistralAPIException / MistralException: if the server responds with
                an error status before streaming starts.
        """
        request = self._make_chat_request(
            model,
            messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            random_seed=random_seed,
            stream=True,
            safe_mode=safe_mode,
        )
        async_response = await self._request(
            "post", request, "v1/chat/completions", stream=True
        )

        assert isinstance(
            async_response, aiohttp.ClientResponse
        ), "Bad response from _request"

        # Lines look like ``data: {...}``; a final ``data: [DONE]`` ends the stream.
        async with async_response as response:
            async for line in response.content:
                if line == b"\n":
                    continue

                if line.startswith(b"data: "):
                    payload = line[6:].strip()
                    if payload != b"[DONE]":
                        json_response = orjson.loads(payload)
                        yield ChatCompletionStreamResponse(**json_response)

    async def embeddings(
        self, model: str, input: Union[str, List[str]]
    ) -> EmbeddingResponse:
        """An asynchronous embeddings endpoint that returns embeddings for a single, or batch of inputs

        Args:
            model (str): The embedding model to use, e.g. mistral-embed
            input (Union[str, List[str]]): The input to embed,
                e.g. ['What is the best French cheese?']

        Returns:
            EmbeddingResponse: A response object containing the embeddings.
        """
        request = {"model": model, "input": input}
        response = await self._request("post", request, "v1/embeddings")
        assert isinstance(response, dict), "Bad response from _request"
        return EmbeddingResponse(**response)

    async def list_models(self) -> ModelList:
        """Returns a list of the available models

        Returns:
            ModelList: A response object containing the list of models.
        """
        response = await self._request("get", {}, "v1/models")
        assert isinstance(response, dict), "Bad response from _request"
        return ModelList(**response)
|
|
@@ -1,73 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
from abc import ABC
|
|
3
|
-
from typing import Any, Dict, List, Optional
|
|
4
|
-
|
|
5
|
-
from mistralai.exceptions import MistralAPIException, MistralException
|
|
6
|
-
from mistralai.models.chat_completion import ChatMessage
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class ClientBase(ABC):
    """Configuration and request/response helpers shared by the Mistral clients."""

    def __init__(
        self,
        endpoint: str,
        api_key: Optional[str] = None,
        max_retries: int = 5,
        timeout: int = 120,
    ):
        """Store client configuration.

        Args:
            endpoint: base URL of the API.
            api_key: API key; may be None.
            max_retries: stored as ``_max_retries`` for subclasses.
            timeout: stored as ``_timeout`` for subclasses.
        """
        self._max_retries = max_retries
        self._timeout = timeout

        self._endpoint = endpoint
        self._api_key = api_key
        self._logger = logging.getLogger(__name__)

    @staticmethod
    def _make_chat_request(
        model: str,
        messages: List[ChatMessage],
        temperature: Optional[float] = None,
        max_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        stream: Optional[bool] = None,
        safe_mode: Optional[bool] = True,
    ) -> Dict[str, Any]:
        """Build the JSON body for a chat completion request.

        ``model``, ``messages`` (each serialised with ``model_dump()``) and
        ``safe_prompt`` (from ``safe_mode``) are always present. The remaining
        parameters are included only when not None (presumably so server-side
        defaults apply otherwise).
        """
        request_data: Dict[str, Any] = {
            "model": model,
            "messages": [msg.model_dump() for msg in messages],
            "safe_prompt": safe_mode,
        }
        optional_fields = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "random_seed": random_seed,
            "stream": stream,
        }
        request_data.update(
            {key: value for key, value in optional_fields.items() if value is not None}
        )
        return request_data

    def _check_response(
        self, json_response: Dict[str, Any], headers: Dict[str, Any], status: int
    ) -> None:
        """Raise if a decoded response body or its status code indicates failure.

        Raises:
            MistralException: if the body has no ``object`` field, or on a
                status of 500 or above.
            MistralAPIException: if ``object`` is ``"error"``, or on a 4xx status.
        """
        if "object" not in json_response:
            raise MistralException(message=f"Unexpected response: {json_response}")
        if json_response["object"] == "error":  # has errors
            raise MistralAPIException(
                # Fall back to the full body if the error carries no message,
                # rather than raising a KeyError at the caller.
                message=json_response.get("message", f"Unknown API error: {json_response}"),
                http_status=status,
                headers=headers,
            )
        if 400 <= status < 500:
            raise MistralAPIException(
                message=f"Unexpected client error (status {status}): {json_response}",
                http_status=status,
                headers=headers,
            )
        if status >= 500:
            raise MistralException(
                message=f"Unexpected server error (status {status}): {json_response}"
            )
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|