mistralai 0.2.0__tar.gz → 0.4.0__tar.gz
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- {mistralai-0.2.0 → mistralai-0.4.0}/PKG-INFO +3 -3
- {mistralai-0.2.0 → mistralai-0.4.0}/pyproject.toml +5 -5
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/async_client.py +85 -2
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/client.py +88 -2
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/client_base.py +66 -10
- mistralai-0.4.0/src/mistralai/files.py +84 -0
- mistralai-0.4.0/src/mistralai/jobs.py +172 -0
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/models/chat_completion.py +2 -2
- mistralai-0.4.0/src/mistralai/models/files.py +23 -0
- mistralai-0.4.0/src/mistralai/models/jobs.py +98 -0
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/models/models.py +3 -3
- {mistralai-0.2.0 → mistralai-0.4.0}/LICENSE +0 -0
- {mistralai-0.2.0 → mistralai-0.4.0}/README.md +0 -0
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/__init__.py +0 -0
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/constants.py +0 -0
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/exceptions.py +0 -0
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/models/__init__.py +0 -0
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/models/common.py +0 -0
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/models/embeddings.py +0 -0
- {mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/py.typed +0 -0

{mistralai-0.2.0 → mistralai-0.4.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mistralai
-Version: 0.2.0
+Version: 0.4.0
 Summary:
 Author: Bam4d
 Author-email: bam4d@mistral.ai
@@ -10,8 +10,8 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-Requires-Dist: httpx (>=0.25.
-Requires-Dist: orjson (>=3.9.10,<
+Requires-Dist: httpx (>=0.25,<0.26)
+Requires-Dist: orjson (>=3.9.10,<3.11)
 Requires-Dist: pydantic (>=2.5.2,<3.0.0)
 Description-Content-Type: text/markdown

{mistralai-0.2.0 → mistralai-0.4.0}/pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mistralai"
-version = "0.2.0"
+version = "0.4.0"
 description = ""
 authors = ["Bam4d <bam4d@mistral.ai>"]
 readme = "README.md"
@@ -23,10 +23,10 @@ exclude = ["docs", "tests", "examples", "tools", "build"]

 [tool.poetry.dependencies]
-python = "^3.9"
-orjson = "^3.9.10"
-pydantic = "^2.5.2"
-httpx = "
+python = "^3.9,<4.0"
+orjson = "^3.9.10,<3.11"
+pydantic = "^2.5.2,<3"
+httpx = "^0.25,<1"

 [tool.poetry.group.dev.dependencies]

{mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/async_client.py

@@ -20,6 +20,8 @@ from mistralai.exceptions import (
     MistralConnectionException,
     MistralException,
 )
+from mistralai.files import FilesAsyncClient
+from mistralai.jobs import JobsAsyncClient
 from mistralai.models.chat_completion import (
     ChatCompletionResponse,
     ChatCompletionStreamResponse,
@@ -47,6 +49,8 @@ class MistralAsyncClient(ClientBase):
             limits=Limits(max_connections=max_concurrent_requests),
             transport=AsyncHTTPTransport(retries=max_retries),
         )
+        self.files = FilesAsyncClient(self)
+        self.jobs = JobsAsyncClient(self)

     async def close(self) -> None:
         await self._client.aclose()
@@ -92,19 +96,23 @@ class MistralAsyncClient(ClientBase):
     async def _request(
         self,
         method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
         path: str,
         stream: bool = False,
         attempt: int = 1,
+        data: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> AsyncGenerator[Dict[str, Any], None]:
         accept_header = "text/event-stream" if stream else "application/json"
         headers = {
             "Accept": accept_header,
             "User-Agent": f"mistral-client-python/{self._version}",
             "Authorization": f"Bearer {self._api_key}",
-            "Content-Type": "application/json",
         }

+        if json is not None:
+            headers["Content-Type"] = "application/json"
+
         url = posixpath.join(self._endpoint, path)

         self._logger.debug(f"Sending request: {method} {url} {json}")
@@ -118,6 +126,8 @@ class MistralAsyncClient(ClientBase):
             url,
             headers=headers,
             json=json,
+            data=data,
+            **kwargs,
         ) as response:
             await self._check_streaming_response(response)

@@ -132,6 +142,8 @@ class MistralAsyncClient(ClientBase):
             url,
             headers=headers,
             json=json,
+            data=data,
+            **kwargs,
         )

         yield await self._check_response(response)
@@ -291,3 +303,74 @@ class MistralAsyncClient(ClientBase):
             return ModelList(**response)

         raise MistralException("No response received")
+
+    async def completion(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> ChatCompletionResponse:
+        """An asynchronous completion endpoint that returns a single response.
+
+        Args:
+            model (str): model the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+        Returns:
+            Dict[str, Any]: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+        )
+        single_response = self._request("post", request, "v1/fim/completions")
+
+        async for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
+
+    async def completion_stream(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
+        """An asynchronous completion endpoint that returns a streaming response.
+
+        Args:
+            model (str): model the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+
+        Returns:
+            Dict[str, Any]: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+        )
+        async_response = self._request("post", request, "v1/fim/completions", stream=True)
+
+        async for json_response in async_response:
+            yield ChatCompletionStreamResponse(**json_response)
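
As orientation (not part of the package diff), a minimal sketch of calling the new asynchronous fill-in-the-middle endpoint. It assumes a MISTRAL_API_KEY environment variable, that codestral-latest is available to the account, and that the response exposes choices[0].message.content as defined in models/chat_completion.py.

import asyncio
import os

from mistralai.async_client import MistralAsyncClient


async def main() -> None:
    client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

    # completion() posts to v1/fim/completions; the model fills the gap between prompt and suffix.
    response = await client.completion(
        model="codestral-latest",
        prompt="def fibonacci(n: int) -> int:\n",
        suffix="\nprint(fibonacci(10))",
        max_tokens=64,
        temperature=0.0,
    )
    print(response.choices[0].message.content)

    # The new sub-clients added above are reachable as client.files and client.jobs.
    await client.close()


asyncio.run(main())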

{mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/client.py

@@ -13,6 +13,8 @@ from mistralai.exceptions import (
     MistralConnectionException,
     MistralException,
 )
+from mistralai.files import FilesClient
+from mistralai.jobs import JobsClient
 from mistralai.models.chat_completion import (
     ChatCompletionResponse,
     ChatCompletionStreamResponse,
@@ -40,6 +42,8 @@ class MistralClient(ClientBase):
         self._client = Client(
             follow_redirects=True, timeout=self._timeout, transport=HTTPTransport(retries=self._max_retries)
         )
+        self.files = FilesClient(self)
+        self.jobs = JobsClient(self)

     def __del__(self) -> None:
         self._client.close()
@@ -85,19 +89,23 @@ class MistralClient(ClientBase):
     def _request(
         self,
         method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
         path: str,
         stream: bool = False,
         attempt: int = 1,
+        data: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> Iterator[Dict[str, Any]]:
         accept_header = "text/event-stream" if stream else "application/json"
         headers = {
             "Accept": accept_header,
             "User-Agent": f"mistral-client-python/{self._version}",
             "Authorization": f"Bearer {self._api_key}",
-            "Content-Type": "application/json",
         }

+        if json is not None:
+            headers["Content-Type"] = "application/json"
+
         url = posixpath.join(self._endpoint, path)

         self._logger.debug(f"Sending request: {method} {url} {json}")
@@ -111,6 +119,8 @@ class MistralClient(ClientBase):
             url,
             headers=headers,
             json=json,
+            data=data,
+            **kwargs,
         ) as response:
             self._check_streaming_response(response)

@@ -125,6 +135,8 @@ class MistralClient(ClientBase):
             url,
             headers=headers,
             json=json,
+            data=data,
+            **kwargs,
         )

         yield self._check_response(response)
@@ -285,3 +297,77 @@ class MistralClient(ClientBase):
             return ModelList(**response)

         raise MistralException("No response received")
+
+    def completion(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> ChatCompletionResponse:
+        """A completion endpoint that returns a single response.
+
+        Args:
+            model (str): model the name of the model to get completion with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+
+        Returns:
+            Dict[str, Any]: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+        )
+
+        single_response = self._request("post", request, "v1/fim/completions", stream=False)
+
+        for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
+
+    def completion_stream(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> Iterable[ChatCompletionStreamResponse]:
+        """An asynchronous completion endpoint that streams responses.
+
+        Args:
+            model (str): model the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+
+        Returns:
+            Iterable[Dict[str, Any]]: a generator that yields response objects containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+        )
+
+        response = self._request("post", request, "v1/fim/completions", stream=True)
+
+        for json_streamed_response in response:
+            yield ChatCompletionStreamResponse(**json_streamed_response)
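
The synchronous client gains the same pair of methods. A short streaming sketch under the same assumptions (API key, model availability, and chunk shape choices[0].delta.content):

import os

from mistralai.client import MistralClient

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# completion_stream() yields ChatCompletionStreamResponse chunks as they arrive.
for chunk in client.completion_stream(
    model="codestral-latest",
    prompt="def quicksort(xs: list) -> list:\n",
    suffix="\nprint(quicksort([3, 1, 2]))",
    temperature=0.0,
    max_tokens=128,
):
    if chunk.choices[0].delta.content is not None:
        print(chunk.choices[0].delta.content, end="", flush=True)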

{mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/client_base.py

@@ -10,7 +10,7 @@ from mistralai.exceptions import (
 )
 from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice

-CLIENT_VERSION = "0.2.0"
+CLIENT_VERSION = "0.4.0"


 class ClientBase(ABC):
@@ -73,6 +73,63 @@ class ClientBase(ABC):

         return parsed_messages

+    def _make_completion_request(
+        self,
+        prompt: str,
+        model: Optional[str] = None,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        stream: Optional[bool] = False,
+    ) -> Dict[str, Any]:
+        request_data: Dict[str, Any] = {
+            "prompt": prompt,
+            "suffix": suffix,
+            "model": model,
+            "stream": stream,
+        }
+
+        if stop is not None:
+            request_data["stop"] = stop
+
+        if model is not None:
+            request_data["model"] = model
+        else:
+            if self._default_model is None:
+                raise MistralException(message="model must be provided")
+            request_data["model"] = self._default_model
+
+        request_data.update(
+            self._build_sampling_params(
+                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+            )
+        )
+
+        self._logger.debug(f"Completion request: {request_data}")
+
+        return request_data
+
+    def _build_sampling_params(
+        self,
+        max_tokens: Optional[int],
+        random_seed: Optional[int],
+        temperature: Optional[float],
+        top_p: Optional[float],
+    ) -> Dict[str, Any]:
+        params = {}
+        if temperature is not None:
+            params["temperature"] = temperature
+        if max_tokens is not None:
+            params["max_tokens"] = max_tokens
+        if top_p is not None:
+            params["top_p"] = top_p
+        if random_seed is not None:
+            params["random_seed"] = random_seed
+        return params
+
     def _make_chat_request(
         self,
         messages: List[Any],
@@ -89,7 +146,6 @@ class ClientBase(ABC):
     ) -> Dict[str, Any]:
         request_data: Dict[str, Any] = {
             "messages": self._parse_messages(messages),
-            "safe_prompt": safe_prompt,
         }

         if model is not None:
@@ -99,16 +155,16 @@ class ClientBase(ABC):
                 raise MistralException(message="model must be provided")
             request_data["model"] = self._default_model

+        request_data.update(
+            self._build_sampling_params(
+                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+            )
+        )
+
+        if safe_prompt:
+            request_data["safe_prompt"] = safe_prompt
         if tools is not None:
             request_data["tools"] = self._parse_tools(tools)
-        if temperature is not None:
-            request_data["temperature"] = temperature
-        if max_tokens is not None:
-            request_data["max_tokens"] = max_tokens
-        if top_p is not None:
-            request_data["top_p"] = top_p
-        if random_seed is not None:
-            request_data["random_seed"] = random_seed
         if stream is not None:
             request_data["stream"] = stream
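
To make the refactor concrete, a small sketch of the payload _make_completion_request assembles; the helper is private, so this is illustrative only, and nothing is sent over the network. The chat path now routes its sampling options through the same _build_sampling_params helper and only includes safe_prompt when it is set.

import os

from mistralai.client import MistralClient

# A dummy key is fine here: we only inspect the request body, no HTTP call is made.
client = MistralClient(api_key=os.environ.get("MISTRAL_API_KEY", "dummy"))

body = client._make_completion_request(
    prompt="def add(a, b):",
    model="codestral-latest",
    suffix="    return result",
    temperature=0.2,
)

# Unset sampling options are omitted entirely rather than being sent as null, e.g.:
# {'prompt': 'def add(a, b):', 'suffix': '    return result',
#  'model': 'codestral-latest', 'stream': False, 'temperature': 0.2}
print(body)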

mistralai-0.4.0/src/mistralai/files.py (new file)

@@ -0,0 +1,84 @@
+from typing import Any
+
+from mistralai.exceptions import (
+    MistralException,
+)
+from mistralai.models.files import FileDeleted, FileObject, Files
+
+
+class FilesClient:
+    def __init__(self, client: Any):
+        self.client = client
+
+    def create(
+        self,
+        file: bytes,
+        purpose: str = "fine-tune",
+    ) -> FileObject:
+        single_response = self.client._request(
+            "post",
+            None,
+            "v1/files",
+            files={"file": file},
+            data={"purpose": purpose},
+        )
+        for response in single_response:
+            return FileObject(**response)
+        raise MistralException("No response received")
+
+    def retrieve(self, file_id: str) -> FileObject:
+        single_response = self.client._request("get", {}, f"v1/files/{file_id}")
+        for response in single_response:
+            return FileObject(**response)
+        raise MistralException("No response received")
+
+    def list(self) -> Files:
+        single_response = self.client._request("get", {}, "v1/files")
+        for response in single_response:
+            return Files(**response)
+        raise MistralException("No response received")
+
+    def delete(self, file_id: str) -> FileDeleted:
+        single_response = self.client._request("delete", {}, f"v1/files/{file_id}")
+        for response in single_response:
+            return FileDeleted(**response)
+        raise MistralException("No response received")
+
+
+class FilesAsyncClient:
+    def __init__(self, client: Any):
+        self.client = client
+
+    async def create(
+        self,
+        file: bytes,
+        purpose: str = "fine-tune",
+    ) -> FileObject:
+        single_response = self.client._request(
+            "post",
+            None,
+            "v1/files",
+            files={"file": file},
+            data={"purpose": purpose},
+        )
+        async for response in single_response:
+            return FileObject(**response)
+        raise MistralException("No response received")
+
+    async def retrieve(self, file_id: str) -> FileObject:
+        single_response = self.client._request("get", {}, f"v1/files/{file_id}")
+        async for response in single_response:
+            return FileObject(**response)
+        raise MistralException("No response received")
+
+    async def list(self) -> Files:
+        single_response = self.client._request("get", {}, "v1/files")
+        async for response in single_response:
+            return Files(**response)
+        raise MistralException("No response received")
+
+    async def delete(self, file_id: str) -> FileDeleted:
+        single_response = self.client._request("delete", {}, f"v1/files/{file_id}")
+        async for response in single_response:
+            return FileDeleted(**response)
+        raise MistralException("No response received")
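
A short usage sketch of the new files sub-client (attached to the clients as client.files in the hunks above). The JSONL path is a placeholder; the returned fields follow models/files.py later in this diff.

import os

from mistralai.client import MistralClient

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# Upload raw bytes; purpose defaults to "fine-tune".
with open("training_data.jsonl", "rb") as fh:
    uploaded = client.files.create(file=fh.read())
print(uploaded.id, uploaded.filename, uploaded.bytes)

# List, retrieve and delete round-trip.
for f in client.files.list().data:
    print(f.id, f.purpose)
retrieved = client.files.retrieve(file_id=uploaded.id)
print(retrieved.created_at)
print(client.files.delete(file_id=uploaded.id).deleted)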

mistralai-0.4.0/src/mistralai/jobs.py (new file)

@@ -0,0 +1,172 @@
+from datetime import datetime
+from typing import Any, Optional, Union
+
+from mistralai.exceptions import (
+    MistralException,
+)
+from mistralai.models.jobs import DetailedJob, IntegrationIn, Job, JobMetadata, JobQueryFilter, Jobs, TrainingParameters
+
+
+class JobsClient:
+    def __init__(self, client: Any):
+        self.client = client
+
+    def create(
+        self,
+        model: str,
+        training_files: Union[list[str], None] = None,
+        validation_files: Union[list[str], None] = None,
+        hyperparameters: TrainingParameters = TrainingParameters(
+            training_steps=1800,
+            learning_rate=1.0e-4,
+        ),
+        suffix: Union[str, None] = None,
+        integrations: Union[set[IntegrationIn], None] = None,
+        training_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
+        validation_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
+        dry_run: bool = False,
+    ) -> Union[Job, JobMetadata]:
+        # Handle deprecated arguments
+        if not training_files and training_file:
+            training_files = [training_file]
+        if not validation_files and validation_file:
+            validation_files = [validation_file]
+        single_response = self.client._request(
+            method="post",
+            json={
+                "model": model,
+                "training_files": training_files,
+                "validation_files": validation_files,
+                "hyperparameters": hyperparameters.dict(),
+                "suffix": suffix,
+                "integrations": integrations,
+            },
+            path="v1/fine_tuning/jobs",
+            params={"dry_run": dry_run},
+        )
+        for response in single_response:
+            return Job(**response) if not dry_run else JobMetadata(**response)
+        raise MistralException("No response received")
+
+    def retrieve(self, job_id: str) -> DetailedJob:
+        single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
+        for response in single_response:
+            return DetailedJob(**response)
+        raise MistralException("No response received")
+
+    def list(
+        self,
+        page: int = 0,
+        page_size: int = 10,
+        model: Optional[str] = None,
+        created_after: Optional[datetime] = None,
+        created_by_me: Optional[bool] = None,
+        status: Optional[str] = None,
+        wandb_project: Optional[str] = None,
+        wandb_name: Optional[str] = None,
+        suffix: Optional[str] = None,
+    ) -> Jobs:
+        query_params = JobQueryFilter(
+            page=page,
+            page_size=page_size,
+            model=model,
+            created_after=created_after,
+            created_by_me=created_by_me,
+            status=status,
+            wandb_project=wandb_project,
+            wandb_name=wandb_name,
+            suffix=suffix,
+        ).model_dump(exclude_none=True)
+        single_response = self.client._request(method="get", params=query_params, path="v1/fine_tuning/jobs", json={})
+        for response in single_response:
+            return Jobs(**response)
+        raise MistralException("No response received")
+
+    def cancel(self, job_id: str) -> DetailedJob:
+        single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
+        for response in single_response:
+            return DetailedJob(**response)
+        raise MistralException("No response received")
+
+
+class JobsAsyncClient:
+    def __init__(self, client: Any):
+        self.client = client
+
+    async def create(
+        self,
+        model: str,
+        training_files: Union[list[str], None] = None,
+        validation_files: Union[list[str], None] = None,
+        hyperparameters: TrainingParameters = TrainingParameters(
+            training_steps=1800,
+            learning_rate=1.0e-4,
+        ),
+        suffix: Union[str, None] = None,
+        integrations: Union[set[IntegrationIn], None] = None,
+        training_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
+        validation_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
+        dry_run: bool = False,
+    ) -> Union[Job, JobMetadata]:
+        # Handle deprecated arguments
+        if not training_files and training_file:
+            training_files = [training_file]
+        if not validation_files and validation_file:
+            validation_files = [validation_file]
+
+        single_response = self.client._request(
+            method="post",
+            json={
+                "model": model,
+                "training_files": training_files,
+                "validation_files": validation_files,
+                "hyperparameters": hyperparameters.dict(),
+                "suffix": suffix,
+                "integrations": integrations,
+            },
+            path="v1/fine_tuning/jobs",
+            params={"dry_run": dry_run},
+        )
+        async for response in single_response:
+            return Job(**response) if not dry_run else JobMetadata(**response)
+        raise MistralException("No response received")
+
+    async def retrieve(self, job_id: str) -> DetailedJob:
+        single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
+        async for response in single_response:
+            return DetailedJob(**response)
+        raise MistralException("No response received")
+
+    async def list(
+        self,
+        page: int = 0,
+        page_size: int = 10,
+        model: Optional[str] = None,
+        created_after: Optional[datetime] = None,
+        created_by_me: Optional[bool] = None,
+        status: Optional[str] = None,
+        wandb_project: Optional[str] = None,
+        wandb_name: Optional[str] = None,
+        suffix: Optional[str] = None,
+    ) -> Jobs:
+        query_params = JobQueryFilter(
+            page=page,
+            page_size=page_size,
+            model=model,
+            created_after=created_after,
+            created_by_me=created_by_me,
+            status=status,
+            wandb_project=wandb_project,
+            wandb_name=wandb_name,
+            suffix=suffix,
+        ).model_dump(exclude_none=True)
+        single_response = self.client._request(method="get", path="v1/fine_tuning/jobs", params=query_params, json={})
+        async for response in single_response:
+            return Jobs(**response)
+        raise MistralException("No response received")
+
+    async def cancel(self, job_id: str) -> DetailedJob:
+        single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
+        async for response in single_response:
+            return DetailedJob(**response)
+        raise MistralException("No response received")
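
A usage sketch of the new fine-tuning jobs sub-client. The file ID and model name are placeholders; per the code above, dry_run=True returns a JobMetadata estimate instead of creating a job.

import os

from mistralai.client import MistralClient
from mistralai.models.jobs import TrainingParameters

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

hyperparameters = TrainingParameters(training_steps=100, learning_rate=1.0e-4)

# Dry run first: returns JobMetadata with token counts and an expected duration.
estimate = client.jobs.create(
    model="open-mistral-7b",
    training_files=["<uploaded-file-id>"],
    hyperparameters=hyperparameters,
    dry_run=True,
)
print(estimate.train_tokens, estimate.expected_duration_seconds)

# Real job creation, then poll / list / cancel.
job = client.jobs.create(
    model="open-mistral-7b",
    training_files=["<uploaded-file-id>"],
    hyperparameters=hyperparameters,
    suffix="my-finetune",
)
print(job.id, job.status)
print(client.jobs.retrieve(job_id=job.id).status)
print([j.id for j in client.jobs.list(page_size=5).data])
client.jobs.cancel(job_id=job.id)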

{mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/models/chat_completion.py

@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Optional
+from typing import List, Optional

 from pydantic import BaseModel

@@ -44,7 +44,7 @@ class ResponseFormat(BaseModel):

 class ChatMessage(BaseModel):
     role: str
-    content:
+    content: str
     name: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = None
     tool_call_id: Optional[str] = None

mistralai-0.4.0/src/mistralai/models/files.py (new file)

@@ -0,0 +1,23 @@
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+
+class FileObject(BaseModel):
+    id: str
+    object: str
+    bytes: int
+    created_at: int
+    filename: str
+    purpose: Optional[Literal["fine-tune"]] = "fine-tune"
+
+
+class FileDeleted(BaseModel):
+    id: str
+    object: str
+    deleted: bool
+
+
+class Files(BaseModel):
+    data: list[FileObject]
+    object: Literal["list"]

mistralai-0.4.0/src/mistralai/models/jobs.py (new file)

@@ -0,0 +1,98 @@
+from datetime import datetime
+from typing import Annotated, List, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+class TrainingParameters(BaseModel):
+    training_steps: int = Field(1800, le=10000, ge=1)
+    learning_rate: float = Field(1.0e-4, le=1, ge=1.0e-8)
+
+
+class WandbIntegration(BaseModel):
+    type: Literal["wandb"] = "wandb"
+    project: str
+    name: Union[str, None] = None
+    run_name: Union[str, None] = None
+
+
+class WandbIntegrationIn(WandbIntegration):
+    api_key: str
+
+
+Integration = Annotated[Union[WandbIntegration], Field(discriminator="type")]
+IntegrationIn = Annotated[Union[WandbIntegrationIn], Field(discriminator="type")]
+
+
+class JobMetadata(BaseModel):
+    object: Literal["job.metadata"] = "job.metadata"
+    training_steps: int
+    train_tokens_per_step: int
+    data_tokens: int
+    train_tokens: int
+    epochs: float
+    expected_duration_seconds: Optional[int]
+
+
+class Job(BaseModel):
+    id: str
+    hyperparameters: TrainingParameters
+    fine_tuned_model: Union[str, None]
+    model: str
+    status: Literal[
+        "QUEUED",
+        "STARTED",
+        "RUNNING",
+        "FAILED",
+        "SUCCESS",
+        "CANCELLED",
+        "CANCELLATION_REQUESTED",
+    ]
+    job_type: str
+    created_at: int
+    modified_at: int
+    training_files: list[str]
+    validation_files: Union[list[str], None] = []
+    object: Literal["job"]
+    integrations: List[Integration] = []
+
+
+class Event(BaseModel):
+    name: str
+    data: Union[dict, None] = None
+    created_at: int
+
+
+class Metric(BaseModel):
+    train_loss: Union[float, None] = None
+    valid_loss: Union[float, None] = None
+    valid_mean_token_accuracy: Union[float, None] = None
+
+
+class Checkpoint(BaseModel):
+    metrics: Metric
+    step_number: int
+    created_at: int
+
+
+class JobQueryFilter(BaseModel):
+    page: int = 0
+    page_size: int = 100
+    model: Optional[str] = None
+    created_after: Optional[datetime] = None
+    created_by_me: Optional[bool] = None
+    status: Optional[str] = None
+    wandb_project: Optional[str] = None
+    wandb_name: Optional[str] = None
+    suffix: Optional[str] = None
+
+
+class DetailedJob(Job):
+    events: list[Event] = []
+    checkpoints: list[Checkpoint] = []
+    estimated_start_time: Optional[int] = None
+
+
+class Jobs(BaseModel):
+    data: list[Job] = []
+    object: Literal["list"]
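
The new job models are plain pydantic v2 classes, so their bounds and discriminators can be exercised locally; a small sketch (the W&B key is a placeholder):

from pydantic import ValidationError

from mistralai.models.jobs import TrainingParameters, WandbIntegrationIn

# Field bounds from the model: 1 <= training_steps <= 10000, 1e-8 <= learning_rate <= 1.
params = TrainingParameters(training_steps=500, learning_rate=5.0e-5)
print(params.model_dump())

try:
    TrainingParameters(training_steps=50000)
except ValidationError as err:
    print(err.errors()[0]["type"])  # "less_than_equal"

# WandbIntegrationIn carries the api_key and is discriminated by its Literal "wandb" type.
wandb = WandbIntegrationIn(project="finetune-runs", name="run-1", api_key="<wandb-api-key>")
print(wandb.type)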

{mistralai-0.2.0 → mistralai-0.4.0}/src/mistralai/models/models.py

@@ -7,15 +7,15 @@ class ModelPermission(BaseModel):
     id: str
     object: str
     created: int
-    allow_create_engine: bool = False
+    allow_create_engine: Optional[bool] = False
     allow_sampling: bool = True
     allow_logprobs: bool = True
-    allow_search_indices: bool = False
+    allow_search_indices: Optional[bool] = False
     allow_view: bool = True
     allow_fine_tuning: bool = False
     organization: str = "*"
     group: Optional[str] = None
-    is_blocking: bool = False
+    is_blocking: Optional[bool] = False


 class ModelCard(BaseModel):

The remaining files listed above with +0 -0 (LICENSE, README.md, src/mistralai/__init__.py, constants.py, exceptions.py, models/__init__.py, models/common.py, models/embeddings.py and py.typed) are unchanged between 0.2.0 and 0.4.0.