mistralai 0.3.0__tar.gz → 0.4.0__tar.gz
This diff compares the published contents of two package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- {mistralai-0.3.0 → mistralai-0.4.0}/PKG-INFO +3 -3
- {mistralai-0.3.0 → mistralai-0.4.0}/pyproject.toml +5 -5
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/async_client.py +13 -1
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/client.py +13 -1
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/client_base.py +3 -2
- mistralai-0.4.0/src/mistralai/files.py +84 -0
- mistralai-0.4.0/src/mistralai/jobs.py +172 -0
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/models/chat_completion.py +2 -2
- mistralai-0.4.0/src/mistralai/models/files.py +23 -0
- mistralai-0.4.0/src/mistralai/models/jobs.py +98 -0
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/models/models.py +3 -3
- {mistralai-0.3.0 → mistralai-0.4.0}/LICENSE +0 -0
- {mistralai-0.3.0 → mistralai-0.4.0}/README.md +0 -0
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/__init__.py +0 -0
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/constants.py +0 -0
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/exceptions.py +0 -0
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/models/__init__.py +0 -0
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/models/common.py +0 -0
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/models/embeddings.py +0 -0
- {mistralai-0.3.0 → mistralai-0.4.0}/src/mistralai/py.typed +0 -0
--- mistralai-0.3.0/PKG-INFO
+++ mistralai-0.4.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mistralai
-Version: 0.3.0
+Version: 0.4.0
 Summary: 
 Author: Bam4d
 Author-email: bam4d@mistral.ai
@@ -10,8 +10,8 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-Requires-Dist: httpx (>=0.25.2,<0.26.0)
-Requires-Dist: orjson (>=3.9.10,<4.0.0)
+Requires-Dist: httpx (>=0.25,<0.26)
+Requires-Dist: orjson (>=3.9.10,<3.11)
 Requires-Dist: pydantic (>=2.5.2,<3.0.0)
 Description-Content-Type: text/markdown
 
--- mistralai-0.3.0/pyproject.toml
+++ mistralai-0.4.0/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mistralai"
-version = "0.3.0"
+version = "0.4.0"
 description = ""
 authors = ["Bam4d <bam4d@mistral.ai>"]
 readme = "README.md"
@@ -23,10 +23,10 @@ exclude = ["docs", "tests", "examples", "tools", "build"]
 
 
 [tool.poetry.dependencies]
-python = "^3.9"
-orjson = "^3.9.10"
-pydantic = "^2.5.2"
-httpx = "^0.25.2"
+python = "^3.9,<4.0"
+orjson = "^3.9.10,<3.11"
+pydantic = "^2.5.2,<3"
+httpx = "^0.25,<1"
 
 
 [tool.poetry.group.dev.dependencies]
--- mistralai-0.3.0/src/mistralai/async_client.py
+++ mistralai-0.4.0/src/mistralai/async_client.py
@@ -20,6 +20,8 @@ from mistralai.exceptions import (
     MistralConnectionException,
     MistralException,
 )
+from mistralai.files import FilesAsyncClient
+from mistralai.jobs import JobsAsyncClient
 from mistralai.models.chat_completion import (
     ChatCompletionResponse,
     ChatCompletionStreamResponse,
@@ -47,6 +49,8 @@ class MistralAsyncClient(ClientBase):
             limits=Limits(max_connections=max_concurrent_requests),
             transport=AsyncHTTPTransport(retries=max_retries),
         )
+        self.files = FilesAsyncClient(self)
+        self.jobs = JobsAsyncClient(self)
 
     async def close(self) -> None:
         await self._client.aclose()
@@ -96,15 +100,19 @@ class MistralAsyncClient(ClientBase):
         path: str,
         stream: bool = False,
         attempt: int = 1,
+        data: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> AsyncGenerator[Dict[str, Any], None]:
         accept_header = "text/event-stream" if stream else "application/json"
         headers = {
             "Accept": accept_header,
             "User-Agent": f"mistral-client-python/{self._version}",
             "Authorization": f"Bearer {self._api_key}",
-            "Content-Type": "application/json",
         }
 
+        if json is not None:
+            headers["Content-Type"] = "application/json"
+
         url = posixpath.join(self._endpoint, path)
 
         self._logger.debug(f"Sending request: {method} {url} {json}")
@@ -118,6 +126,8 @@ class MistralAsyncClient(ClientBase):
                 url,
                 headers=headers,
                 json=json,
+                data=data,
+                **kwargs,
             ) as response:
                 await self._check_streaming_response(response)
 
@@ -132,6 +142,8 @@ class MistralAsyncClient(ClientBase):
                 url,
                 headers=headers,
                 json=json,
+                data=data,
+                **kwargs,
             )
 
             yield await self._check_response(response)
--- mistralai-0.3.0/src/mistralai/client.py
+++ mistralai-0.4.0/src/mistralai/client.py
@@ -13,6 +13,8 @@ from mistralai.exceptions import (
     MistralConnectionException,
     MistralException,
 )
+from mistralai.files import FilesClient
+from mistralai.jobs import JobsClient
 from mistralai.models.chat_completion import (
     ChatCompletionResponse,
     ChatCompletionStreamResponse,
@@ -40,6 +42,8 @@ class MistralClient(ClientBase):
         self._client = Client(
             follow_redirects=True, timeout=self._timeout, transport=HTTPTransport(retries=self._max_retries)
         )
+        self.files = FilesClient(self)
+        self.jobs = JobsClient(self)
 
     def __del__(self) -> None:
         self._client.close()
@@ -89,15 +93,19 @@ class MistralClient(ClientBase):
         path: str,
         stream: bool = False,
         attempt: int = 1,
+        data: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
    ) -> Iterator[Dict[str, Any]]:
         accept_header = "text/event-stream" if stream else "application/json"
         headers = {
             "Accept": accept_header,
             "User-Agent": f"mistral-client-python/{self._version}",
             "Authorization": f"Bearer {self._api_key}",
-            "Content-Type": "application/json",
         }
 
+        if json is not None:
+            headers["Content-Type"] = "application/json"
+
         url = posixpath.join(self._endpoint, path)
 
         self._logger.debug(f"Sending request: {method} {url} {json}")
@@ -111,6 +119,8 @@ class MistralClient(ClientBase):
                 url,
                 headers=headers,
                 json=json,
+                data=data,
+                **kwargs,
             ) as response:
                 self._check_streaming_response(response)
 
@@ -125,6 +135,8 @@ class MistralClient(ClientBase):
                 url,
                 headers=headers,
                 json=json,
+                data=data,
+                **kwargs,
             )
 
             yield self._check_response(response)
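
Note on the shared `_request` changes above: the `Content-Type: application/json` header is now set only when a JSON body is present, and the new `data=` parameter plus `**kwargs` (e.g. `files=`, `params=`) are forwarded to httpx. That is what lets the new files endpoint send multipart form data, where httpx must generate the multipart boundary itself. A minimal sketch of the distinction in plain httpx (the token and payloads are illustrative):

    import httpx

    with httpx.Client(base_url="https://api.mistral.ai") as http:
        auth = {"Authorization": "Bearer <token>"}  # illustrative token

        # JSON body: an explicit application/json Content-Type is correct.
        http.post(
            "/v1/chat/completions",
            headers={**auth, "Content-Type": "application/json"},
            json={"model": "mistral-small-latest", "messages": []},
        )

        # Multipart body: httpx must emit its own
        # "Content-Type: multipart/form-data; boundary=..." header, so a
        # hard-coded application/json would corrupt the upload.
        http.post(
            "/v1/files",
            headers=auth,
            files={"file": b"<jsonl bytes>"},
            data={"purpose": "fine-tune"},
        )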
--- mistralai-0.3.0/src/mistralai/client_base.py
+++ mistralai-0.4.0/src/mistralai/client_base.py
@@ -10,7 +10,7 @@ from mistralai.exceptions import (
 )
 from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
 
-CLIENT_VERSION = "0.3.0"
+CLIENT_VERSION = "0.4.0"
 
 
 class ClientBase(ABC):
@@ -146,7 +146,6 @@ class ClientBase(ABC):
     ) -> Dict[str, Any]:
         request_data: Dict[str, Any] = {
             "messages": self._parse_messages(messages),
-            "safe_prompt": safe_prompt,
         }
 
         if model is not None:
@@ -162,6 +161,8 @@ class ClientBase(ABC):
                 )
             )
 
+        if safe_prompt:
+            request_data["safe_prompt"] = safe_prompt
         if tools is not None:
            request_data["tools"] = self._parse_tools(tools)
         if stream is not None:
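
Behavioral note on the `safe_prompt` change above: previously every chat request serialized `"safe_prompt": false` even at the default, whereas now the field is only sent when it is truthy. A small sketch of the resulting payload:

    # Mirrors the new logic in ClientBase._make_chat_request:
    request_data = {"messages": [{"role": "user", "content": "Hi"}]}
    safe_prompt = False
    if safe_prompt:  # False by default, so the key is omitted entirely
        request_data["safe_prompt"] = safe_prompt
    assert "safe_prompt" not in request_data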
--- /dev/null
+++ mistralai-0.4.0/src/mistralai/files.py
@@ -0,0 +1,84 @@
+from typing import Any
+
+from mistralai.exceptions import (
+    MistralException,
+)
+from mistralai.models.files import FileDeleted, FileObject, Files
+
+
+class FilesClient:
+    def __init__(self, client: Any):
+        self.client = client
+
+    def create(
+        self,
+        file: bytes,
+        purpose: str = "fine-tune",
+    ) -> FileObject:
+        single_response = self.client._request(
+            "post",
+            None,
+            "v1/files",
+            files={"file": file},
+            data={"purpose": purpose},
+        )
+        for response in single_response:
+            return FileObject(**response)
+        raise MistralException("No response received")
+
+    def retrieve(self, file_id: str) -> FileObject:
+        single_response = self.client._request("get", {}, f"v1/files/{file_id}")
+        for response in single_response:
+            return FileObject(**response)
+        raise MistralException("No response received")
+
+    def list(self) -> Files:
+        single_response = self.client._request("get", {}, "v1/files")
+        for response in single_response:
+            return Files(**response)
+        raise MistralException("No response received")
+
+    def delete(self, file_id: str) -> FileDeleted:
+        single_response = self.client._request("delete", {}, f"v1/files/{file_id}")
+        for response in single_response:
+            return FileDeleted(**response)
+        raise MistralException("No response received")
+
+
+class FilesAsyncClient:
+    def __init__(self, client: Any):
+        self.client = client
+
+    async def create(
+        self,
+        file: bytes,
+        purpose: str = "fine-tune",
+    ) -> FileObject:
+        single_response = self.client._request(
+            "post",
+            None,
+            "v1/files",
+            files={"file": file},
+            data={"purpose": purpose},
+        )
+        async for response in single_response:
+            return FileObject(**response)
+        raise MistralException("No response received")
+
+    async def retrieve(self, file_id: str) -> FileObject:
+        single_response = self.client._request("get", {}, f"v1/files/{file_id}")
+        async for response in single_response:
+            return FileObject(**response)
+        raise MistralException("No response received")
+
+    async def list(self) -> Files:
+        single_response = self.client._request("get", {}, "v1/files")
+        async for response in single_response:
+            return Files(**response)
+        raise MistralException("No response received")
+
+    async def delete(self, file_id: str) -> FileDeleted:
+        single_response = self.client._request("delete", {}, f"v1/files/{file_id}")
+        async for response in single_response:
+            return FileDeleted(**response)
+        raise MistralException("No response received")
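
A minimal usage sketch for the new files API, assuming an API key in the MISTRAL_API_KEY environment variable and a local training.jsonl (both illustrative). `files.create` sends the bytes as multipart form data through the `files=`/`data=` plumbing added to `_request` above:

    import os

    from mistralai.client import MistralClient

    client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

    # Upload a JSONL dataset; "fine-tune" is the only purpose modeled in
    # models/files.py below.
    with open("training.jsonl", "rb") as f:
        uploaded = client.files.create(file=f.read())
    print(uploaded.id, uploaded.filename, uploaded.bytes)

    print([fo.filename for fo in client.files.list().data])  # list uploads
    client.files.delete(uploaded.id)                         # then clean up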
--- /dev/null
+++ mistralai-0.4.0/src/mistralai/jobs.py
@@ -0,0 +1,172 @@
+from datetime import datetime
+from typing import Any, Optional, Union
+
+from mistralai.exceptions import (
+    MistralException,
+)
+from mistralai.models.jobs import DetailedJob, IntegrationIn, Job, JobMetadata, JobQueryFilter, Jobs, TrainingParameters
+
+
+class JobsClient:
+    def __init__(self, client: Any):
+        self.client = client
+
+    def create(
+        self,
+        model: str,
+        training_files: Union[list[str], None] = None,
+        validation_files: Union[list[str], None] = None,
+        hyperparameters: TrainingParameters = TrainingParameters(
+            training_steps=1800,
+            learning_rate=1.0e-4,
+        ),
+        suffix: Union[str, None] = None,
+        integrations: Union[set[IntegrationIn], None] = None,
+        training_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
+        validation_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
+        dry_run: bool = False,
+    ) -> Union[Job, JobMetadata]:
+        # Handle deprecated arguments
+        if not training_files and training_file:
+            training_files = [training_file]
+        if not validation_files and validation_file:
+            validation_files = [validation_file]
+        single_response = self.client._request(
+            method="post",
+            json={
+                "model": model,
+                "training_files": training_files,
+                "validation_files": validation_files,
+                "hyperparameters": hyperparameters.dict(),
+                "suffix": suffix,
+                "integrations": integrations,
+            },
+            path="v1/fine_tuning/jobs",
+            params={"dry_run": dry_run},
+        )
+        for response in single_response:
+            return Job(**response) if not dry_run else JobMetadata(**response)
+        raise MistralException("No response received")
+
+    def retrieve(self, job_id: str) -> DetailedJob:
+        single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
+        for response in single_response:
+            return DetailedJob(**response)
+        raise MistralException("No response received")
+
+    def list(
+        self,
+        page: int = 0,
+        page_size: int = 10,
+        model: Optional[str] = None,
+        created_after: Optional[datetime] = None,
+        created_by_me: Optional[bool] = None,
+        status: Optional[str] = None,
+        wandb_project: Optional[str] = None,
+        wandb_name: Optional[str] = None,
+        suffix: Optional[str] = None,
+    ) -> Jobs:
+        query_params = JobQueryFilter(
+            page=page,
+            page_size=page_size,
+            model=model,
+            created_after=created_after,
+            created_by_me=created_by_me,
+            status=status,
+            wandb_project=wandb_project,
+            wandb_name=wandb_name,
+            suffix=suffix,
+        ).model_dump(exclude_none=True)
+        single_response = self.client._request(method="get", params=query_params, path="v1/fine_tuning/jobs", json={})
+        for response in single_response:
+            return Jobs(**response)
+        raise MistralException("No response received")
+
+    def cancel(self, job_id: str) -> DetailedJob:
+        single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
+        for response in single_response:
+            return DetailedJob(**response)
+        raise MistralException("No response received")
+
+
+class JobsAsyncClient:
+    def __init__(self, client: Any):
+        self.client = client
+
+    async def create(
+        self,
+        model: str,
+        training_files: Union[list[str], None] = None,
+        validation_files: Union[list[str], None] = None,
+        hyperparameters: TrainingParameters = TrainingParameters(
+            training_steps=1800,
+            learning_rate=1.0e-4,
+        ),
+        suffix: Union[str, None] = None,
+        integrations: Union[set[IntegrationIn], None] = None,
+        training_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
+        validation_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
+        dry_run: bool = False,
+    ) -> Union[Job, JobMetadata]:
+        # Handle deprecated arguments
+        if not training_files and training_file:
+            training_files = [training_file]
+        if not validation_files and validation_file:
+            validation_files = [validation_file]
+
+        single_response = self.client._request(
+            method="post",
+            json={
+                "model": model,
+                "training_files": training_files,
+                "validation_files": validation_files,
+                "hyperparameters": hyperparameters.dict(),
+                "suffix": suffix,
+                "integrations": integrations,
+            },
+            path="v1/fine_tuning/jobs",
+            params={"dry_run": dry_run},
+        )
+        async for response in single_response:
+            return Job(**response) if not dry_run else JobMetadata(**response)
+        raise MistralException("No response received")
+
+    async def retrieve(self, job_id: str) -> DetailedJob:
+        single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
+        async for response in single_response:
+            return DetailedJob(**response)
+        raise MistralException("No response received")
+
+    async def list(
+        self,
+        page: int = 0,
+        page_size: int = 10,
+        model: Optional[str] = None,
+        created_after: Optional[datetime] = None,
+        created_by_me: Optional[bool] = None,
+        status: Optional[str] = None,
+        wandb_project: Optional[str] = None,
+        wandb_name: Optional[str] = None,
+        suffix: Optional[str] = None,
+    ) -> Jobs:
+        query_params = JobQueryFilter(
+            page=page,
+            page_size=page_size,
+            model=model,
+            created_after=created_after,
+            created_by_me=created_by_me,
+            status=status,
+            wandb_project=wandb_project,
+            wandb_name=wandb_name,
+            suffix=suffix,
+        ).model_dump(exclude_none=True)
+        single_response = self.client._request(method="get", path="v1/fine_tuning/jobs", params=query_params, json={})
+        async for response in single_response:
+            return Jobs(**response)
+        raise MistralException("No response received")
+
+    async def cancel(self, job_id: str) -> DetailedJob:
+        single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
+        async for response in single_response:
+            return DetailedJob(**response)
+        raise MistralException("No response received")
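
A matching sketch for the new fine-tuning jobs API (the base model name and file id are illustrative). With dry_run=True the endpoint returns a JobMetadata estimate instead of queuing a Job:

    import os

    from mistralai.client import MistralClient
    from mistralai.models.jobs import TrainingParameters

    client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

    estimate = client.jobs.create(
        model="open-mistral-7b",        # illustrative base model
        training_files=["<file-id>"],   # id returned by client.files.create
        hyperparameters=TrainingParameters(training_steps=100, learning_rate=1.0e-5),
        dry_run=True,                   # returns JobMetadata, not Job
    )
    print(estimate.expected_duration_seconds)

    jobs = client.jobs.list(page=0, page_size=5)  # paginated Jobs listing
    # client.jobs.retrieve(job_id)  -> DetailedJob with events and checkpoints
    # client.jobs.cancel(job_id)    -> requests cancellation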
--- mistralai-0.3.0/src/mistralai/models/chat_completion.py
+++ mistralai-0.4.0/src/mistralai/models/chat_completion.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Optional, Union
+from typing import List, Optional
 
 from pydantic import BaseModel
 
@@ -44,7 +44,7 @@ class ResponseFormat(BaseModel):
 
 class ChatMessage(BaseModel):
     role: str
-    content: Union[str, List[str]]
+    content: str
     name: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = None
     tool_call_id: Optional[str] = None
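
With `ChatMessage.content` narrowed to plain `str`, list-valued content no longer validates:

    from mistralai.models.chat_completion import ChatMessage

    ChatMessage(role="user", content="hello")      # ok
    # ChatMessage(role="user", content=["hello"])  # now raises ValidationError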
--- /dev/null
+++ mistralai-0.4.0/src/mistralai/models/files.py
@@ -0,0 +1,23 @@
+from typing import Literal, Optional
+
+from pydantic import BaseModel
+
+
+class FileObject(BaseModel):
+    id: str
+    object: str
+    bytes: int
+    created_at: int
+    filename: str
+    purpose: Optional[Literal["fine-tune"]] = "fine-tune"
+
+
+class FileDeleted(BaseModel):
+    id: str
+    object: str
+    deleted: bool
+
+
+class Files(BaseModel):
+    data: list[FileObject]
+    object: Literal["list"]
--- /dev/null
+++ mistralai-0.4.0/src/mistralai/models/jobs.py
@@ -0,0 +1,98 @@
+from datetime import datetime
+from typing import Annotated, List, Literal, Optional, Union
+
+from pydantic import BaseModel, Field
+
+
+class TrainingParameters(BaseModel):
+    training_steps: int = Field(1800, le=10000, ge=1)
+    learning_rate: float = Field(1.0e-4, le=1, ge=1.0e-8)
+
+
+class WandbIntegration(BaseModel):
+    type: Literal["wandb"] = "wandb"
+    project: str
+    name: Union[str, None] = None
+    run_name: Union[str, None] = None
+
+
+class WandbIntegrationIn(WandbIntegration):
+    api_key: str
+
+
+Integration = Annotated[Union[WandbIntegration], Field(discriminator="type")]
+IntegrationIn = Annotated[Union[WandbIntegrationIn], Field(discriminator="type")]
+
+
+class JobMetadata(BaseModel):
+    object: Literal["job.metadata"] = "job.metadata"
+    training_steps: int
+    train_tokens_per_step: int
+    data_tokens: int
+    train_tokens: int
+    epochs: float
+    expected_duration_seconds: Optional[int]
+
+
+class Job(BaseModel):
+    id: str
+    hyperparameters: TrainingParameters
+    fine_tuned_model: Union[str, None]
+    model: str
+    status: Literal[
+        "QUEUED",
+        "STARTED",
+        "RUNNING",
+        "FAILED",
+        "SUCCESS",
+        "CANCELLED",
+        "CANCELLATION_REQUESTED",
+    ]
+    job_type: str
+    created_at: int
+    modified_at: int
+    training_files: list[str]
+    validation_files: Union[list[str], None] = []
+    object: Literal["job"]
+    integrations: List[Integration] = []
+
+
+class Event(BaseModel):
+    name: str
+    data: Union[dict, None] = None
+    created_at: int
+
+
+class Metric(BaseModel):
+    train_loss: Union[float, None] = None
+    valid_loss: Union[float, None] = None
+    valid_mean_token_accuracy: Union[float, None] = None
+
+
+class Checkpoint(BaseModel):
+    metrics: Metric
+    step_number: int
+    created_at: int
+
+
+class JobQueryFilter(BaseModel):
+    page: int = 0
+    page_size: int = 100
+    model: Optional[str] = None
+    created_after: Optional[datetime] = None
+    created_by_me: Optional[bool] = None
+    status: Optional[str] = None
+    wandb_project: Optional[str] = None
+    wandb_name: Optional[str] = None
+    suffix: Optional[str] = None
+
+
+class DetailedJob(Job):
+    events: list[Event] = []
+    checkpoints: list[Checkpoint] = []
+    estimated_start_time: Optional[int] = None
+
+
+class Jobs(BaseModel):
+    data: list[Job] = []
+    object: Literal["list"]
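
The pydantic models above also give client-side validation: the `Field` bounds on TrainingParameters are enforced at construction time, and Integration/IntegrationIn are discriminated unions on `type`, currently with "wandb" as the only member (the API key appears only on the input model). A small sketch:

    from pydantic import ValidationError

    from mistralai.models.jobs import TrainingParameters, WandbIntegrationIn

    TrainingParameters(training_steps=100, learning_rate=1.0e-5)  # within bounds

    try:
        TrainingParameters(training_steps=0)  # violates ge=1
    except ValidationError as err:
        print(err.error_count(), "validation error")

    # The input-side integration carries the W&B API key; response models use
    # WandbIntegration, which omits it.
    wandb = WandbIntegrationIn(project="my-project", api_key="<wandb-api-key>")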
--- mistralai-0.3.0/src/mistralai/models/models.py
+++ mistralai-0.4.0/src/mistralai/models/models.py
@@ -7,15 +7,15 @@ class ModelPermission(BaseModel):
     id: str
     object: str
     created: int
-    allow_create_engine: bool = False
+    allow_create_engine: Optional[bool] = False
     allow_sampling: bool = True
     allow_logprobs: bool = True
-    allow_search_indices: bool = False
+    allow_search_indices: Optional[bool] = False
     allow_view: bool = True
     allow_fine_tuning: bool = False
     organization: str = "*"
     group: Optional[str] = None
-    is_blocking: bool = False
+    is_blocking: Optional[bool] = False
 
 
 class ModelCard(BaseModel):

The remaining files (LICENSE, README.md, src/mistralai/__init__.py, src/mistralai/constants.py, src/mistralai/exceptions.py, src/mistralai/models/__init__.py, src/mistralai/models/common.py, src/mistralai/models/embeddings.py, src/mistralai/py.typed) are unchanged between 0.3.0 and 0.4.0.