mistralai 0.1.8__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mistralai-0.1.8 → mistralai-0.3.0}/PKG-INFO +2 -2
- {mistralai-0.1.8 → mistralai-0.3.0}/pyproject.toml +2 -2
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/async_client.py +75 -5
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/client.py +76 -3
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/client_base.py +70 -15
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/constants.py +0 -2
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/exceptions.py +3 -3
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/models/chat_completion.py +1 -0
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/models/models.py +1 -0
- {mistralai-0.1.8 → mistralai-0.3.0}/LICENSE +0 -0
- {mistralai-0.1.8 → mistralai-0.3.0}/README.md +0 -0
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/__init__.py +0 -0
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/models/__init__.py +0 -0
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/models/common.py +0 -0
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/models/embeddings.py +0 -0
- {mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/py.typed +0 -0
{mistralai-0.1.8 → mistralai-0.3.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mistralai
-Version: 0.1.8
+Version: 0.3.0
 Summary:
 Author: Bam4d
 Author-email: bam4d@mistral.ai
@@ -10,7 +10,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-Requires-Dist: httpx (>=0.25.2,<
+Requires-Dist: httpx (>=0.25.2,<1)
 Requires-Dist: orjson (>=3.9.10,<4.0.0)
 Requires-Dist: pydantic (>=2.5.2,<3.0.0)
 Description-Content-Type: text/markdown
{mistralai-0.1.8 → mistralai-0.3.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mistralai"
-version = "0.1.8"
+version = "0.3.0"
 description = ""
 authors = ["Bam4d <bam4d@mistral.ai>"]
 readme = "README.md"
@@ -26,7 +26,7 @@ exclude = ["docs", "tests", "examples", "tools", "build"]
 python = "^3.9"
 orjson = "^3.9.10"
 pydantic = "^2.5.2"
-httpx = "
+httpx = ">= 0.25.2, < 1"
 
 
 [tool.poetry.group.dev.dependencies]
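The notable dependency change is the httpx pin: 0.3.0 accepts any httpx release from 0.25.2 up to, but not including, 1.0. A minimal sketch, assuming the third-party packaging library is installed, that checks an environment against the new specifier (the constant name is ours, for illustration only):

```python
# Check the installed httpx version against the relaxed requirement shipped in mistralai 0.3.0.
from importlib.metadata import version

from packaging.specifiers import SpecifierSet

HTTPX_SPEC_0_3_0 = SpecifierSet(">=0.25.2,<1")  # the constraint from PKG-INFO above

installed = version("httpx")
if installed in HTTPX_SPEC_0_3_0:
    print(f"httpx {installed} satisfies mistralai 0.3.0's requirement")
else:
    print(f"httpx {installed} falls outside {HTTPX_SPEC_0_3_0}")
```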
{mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/async_client.py

@@ -1,6 +1,5 @@
-import
+import asyncio
 import posixpath
-import time
 from json import JSONDecodeError
 from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 
@@ -34,7 +33,7 @@ from mistralai.models.models import ModelList
 class MistralAsyncClient(ClientBase):
     def __init__(
         self,
-        api_key: Optional[str] =
+        api_key: Optional[str] = None,
         endpoint: str = ENDPOINT,
         max_retries: int = 5,
         timeout: int = 120,
@@ -93,7 +92,7 @@ class MistralAsyncClient(ClientBase):
     async def _request(
         self,
         method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
         path: str,
         stream: bool = False,
         attempt: int = 1,
@@ -151,7 +150,7 @@ class MistralAsyncClient(ClientBase):
             if attempt > self._max_retries:
                 raise MistralAPIStatusException.from_response(response, message=str(e)) from e
             backoff = 2.0**attempt  # exponential backoff
-
+            await asyncio.sleep(backoff)
 
             # Retry as a generator
             async for r in self._request(method, json, path, stream=stream, attempt=attempt):
@@ -292,3 +291,74 @@ class MistralAsyncClient(ClientBase):
             return ModelList(**response)
 
         raise MistralException("No response received")
+
+    async def completion(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> ChatCompletionResponse:
+        """An asynchronous completion endpoint that returns a single response.
+
+        Args:
+            model (str): model the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+        Returns:
+            Dict[str, Any]: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+        )
+        single_response = self._request("post", request, "v1/fim/completions")
+
+        async for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
+
+    async def completion_stream(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
+        """An asynchronous completion endpoint that returns a streaming response.
+
+        Args:
+            model (str): model the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+
+        Returns:
+            Dict[str, Any]: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+        )
+        async_response = self._request("post", request, "v1/fim/completions", stream=True)
+
+        async for json_response in async_response:
+            yield ChatCompletionStreamResponse(**json_response)
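Both new methods target the v1/fim/completions (fill-in-the-middle) endpoint. A minimal usage sketch, not taken from the package itself, assuming MISTRAL_API_KEY is set in the environment and that the codestral-latest model mentioned in the docstrings is available to the account:

```python
import asyncio

from mistralai.async_client import MistralAsyncClient


async def main() -> None:
    # With no explicit api_key, the client now falls back to MISTRAL_API_KEY (see client_base.py below).
    client = MistralAsyncClient()

    # Single-shot fill-in-the-middle completion.
    response = await client.completion(
        model="codestral-latest",  # model name taken from the docstring example; availability is an assumption
        prompt="def fibonacci(n: int):",
        suffix="    return result",
    )
    print(response.choices[0].message.content)

    # Streaming variant: completion_stream is an async generator of ChatCompletionStreamResponse chunks.
    async for chunk in client.completion_stream(
        model="codestral-latest",
        prompt="def fibonacci(n: int):",
        suffix="    return result",
    ):
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())
```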
{mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/client.py

@@ -1,4 +1,3 @@
-import os
 import posixpath
 import time
 from json import JSONDecodeError
@@ -31,7 +30,7 @@ class MistralClient(ClientBase):
 
     def __init__(
         self,
-        api_key: Optional[str] =
+        api_key: Optional[str] = None,
         endpoint: str = ENDPOINT,
         max_retries: int = 5,
         timeout: int = 120,
@@ -86,7 +85,7 @@ class MistralClient(ClientBase):
     def _request(
         self,
         method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
         path: str,
         stream: bool = False,
         attempt: int = 1,
@@ -286,3 +285,77 @@ class MistralClient(ClientBase):
             return ModelList(**response)
 
         raise MistralException("No response received")
+
+    def completion(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> ChatCompletionResponse:
+        """A completion endpoint that returns a single response.
+
+        Args:
+            model (str): model the name of the model to get completion with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+
+        Returns:
+            Dict[str, Any]: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+        )
+
+        single_response = self._request("post", request, "v1/fim/completions", stream=False)
+
+        for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
+
+    def completion_stream(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> Iterable[ChatCompletionStreamResponse]:
+        """An asynchronous completion endpoint that streams responses.
+
+        Args:
+            model (str): model the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+
+        Returns:
+            Iterable[Dict[str, Any]]: a generator that yields response objects containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+        )
+
+        response = self._request("post", request, "v1/fim/completions", stream=True)
+
+        for json_streamed_response in response:
+            yield ChatCompletionStreamResponse(**json_streamed_response)
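The synchronous client gains the same pair of methods. A corresponding sketch under the same assumptions (API key in the environment, model available):

```python
from mistralai.client import MistralClient

client = MistralClient()  # resolves the key from MISTRAL_API_KEY when not passed explicitly

# Single response from the fill-in-the-middle endpoint.
response = client.completion(
    model="codestral-latest",
    prompt="def add(a: int, b: int) -> int:",
    suffix="    return result",
)
print(response.choices[0].message.content)

# Streaming variant: a plain generator of ChatCompletionStreamResponse chunks.
for chunk in client.completion_stream(
    model="codestral-latest",
    prompt="def add(a: int, b: int) -> int:",
    suffix="    return result",
):
    print(chunk.choices[0].delta.content or "", end="")
```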
{mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/client_base.py

@@ -10,10 +10,7 @@ from mistralai.exceptions import (
 )
 from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
 
-
-    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
-    level=os.getenv("LOG_LEVEL", "ERROR"),
-)
+CLIENT_VERSION = "0.2.0"
 
 
 class ClientBase(ABC):
@@ -27,16 +24,19 @@ class ClientBase(ABC):
         self._max_retries = max_retries
         self._timeout = timeout
 
-
+        if api_key is None:
+            api_key = os.environ.get("MISTRAL_API_KEY")
+        if api_key is None:
+            raise MistralException(message="API key not provided. Please set MISTRAL_API_KEY environment variable.")
         self._api_key = api_key
+        self._endpoint = endpoint
         self._logger = logging.getLogger(__name__)
 
         # For azure endpoints, we default to the mistral model
         if "inference.azure.com" in self._endpoint:
             self._default_model = "mistral"
 
-
-        self._version = "0.1.8"
+        self._version = CLIENT_VERSION
 
     def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         parsed_tools: List[Dict[str, Any]] = []
@@ -73,6 +73,63 @@ class ClientBase(ABC):
 
         return parsed_messages
 
+    def _make_completion_request(
+        self,
+        prompt: str,
+        model: Optional[str] = None,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        stream: Optional[bool] = False,
+    ) -> Dict[str, Any]:
+        request_data: Dict[str, Any] = {
+            "prompt": prompt,
+            "suffix": suffix,
+            "model": model,
+            "stream": stream,
+        }
+
+        if stop is not None:
+            request_data["stop"] = stop
+
+        if model is not None:
+            request_data["model"] = model
+        else:
+            if self._default_model is None:
+                raise MistralException(message="model must be provided")
+            request_data["model"] = self._default_model
+
+        request_data.update(
+            self._build_sampling_params(
+                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+            )
+        )
+
+        self._logger.debug(f"Completion request: {request_data}")
+
+        return request_data
+
+    def _build_sampling_params(
+        self,
+        max_tokens: Optional[int],
+        random_seed: Optional[int],
+        temperature: Optional[float],
+        top_p: Optional[float],
+    ) -> Dict[str, Any]:
+        params = {}
+        if temperature is not None:
+            params["temperature"] = temperature
+        if max_tokens is not None:
+            params["max_tokens"] = max_tokens
+        if top_p is not None:
+            params["top_p"] = top_p
+        if random_seed is not None:
+            params["random_seed"] = random_seed
+        return params
+
     def _make_chat_request(
         self,
         messages: List[Any],
@@ -99,16 +156,14 @@ class ClientBase(ABC):
                 raise MistralException(message="model must be provided")
             request_data["model"] = self._default_model
 
+        request_data.update(
+            self._build_sampling_params(
+                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+            )
+        )
+
         if tools is not None:
            request_data["tools"] = self._parse_tools(tools)
-        if temperature is not None:
-            request_data["temperature"] = temperature
-        if max_tokens is not None:
-            request_data["max_tokens"] = max_tokens
-        if top_p is not None:
-            request_data["top_p"] = top_p
-        if random_seed is not None:
-            request_data["random_seed"] = random_seed
         if stream is not None:
             request_data["stream"] = stream
 
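The key-resolution change means construction now fails fast when no key can be found, instead of deferring the error to the first request. A short sketch of the behaviour introduced above; the key value is a placeholder:

```python
import os

from mistralai.client import MistralClient
from mistralai.exceptions import MistralException

# An explicitly passed key is used as-is.
client = MistralClient(api_key="dummy-key-for-illustration")

# Without an explicit key the client falls back to MISTRAL_API_KEY;
# if that is also missing, the constructor raises immediately.
os.environ.pop("MISTRAL_API_KEY", None)
try:
    MistralClient()
except MistralException as exc:
    print(f"expected failure: {exc}")
```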
{mistralai-0.1.8 → mistralai-0.3.0}/src/mistralai/exceptions.py

@@ -35,9 +35,7 @@ class MistralAPIException(MistralException):
         self.headers = headers or {}
 
     @classmethod
-    def from_response(
-        cls, response: Response, message: Optional[str] = None
-    ) -> MistralAPIException:
+    def from_response(cls, response: Response, message: Optional[str] = None) -> MistralAPIException:
         return cls(
             message=message or response.text,
             http_status=response.status_code,
@@ -47,8 +45,10 @@ class MistralAPIException(MistralException):
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})"
 
+
 class MistralAPIStatusException(MistralAPIException):
     """Returned when we receive a non-200 response from the API that we should retry"""
 
+
 class MistralConnectionException(MistralException):
     """Returned when the SDK can not reach the API server for any reason"""
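Callers that want their own fallback behaviour can catch these classes directly; the clients raise MistralAPIStatusException once their internal retries are exhausted and MistralConnectionException when the server cannot be reached. A hedged sketch (key and model are placeholders):

```python
from mistralai.client import MistralClient
from mistralai.exceptions import (
    MistralAPIStatusException,
    MistralConnectionException,
    MistralException,
)

client = MistralClient(api_key="dummy-key-for-illustration")

try:
    result = client.completion(model="codestral-latest", prompt="print(")
except MistralAPIStatusException as exc:
    # Retryable non-200 responses that still failed after the client's own retries.
    print(f"API kept returning an error status: {exc}")
except MistralConnectionException as exc:
    # The API server could not be reached at all.
    print(f"connection problem: {exc}")
except MistralException as exc:
    # Any other SDK error, e.g. "No response received".
    print(f"SDK error: {exc}")
else:
    print(result.choices[0].message.content)
```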
The remaining files are unchanged between the two versions: LICENSE, README.md, src/mistralai/__init__.py, src/mistralai/models/__init__.py, src/mistralai/models/common.py, src/mistralai/models/embeddings.py, and src/mistralai/py.typed.