mistralai 0.3.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mistralai
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary:
5
5
  Author: Bam4d
6
6
  Author-email: bam4d@mistral.ai
@@ -10,8 +10,8 @@ Classifier: Programming Language :: Python :: 3.9
10
10
  Classifier: Programming Language :: Python :: 3.10
11
11
  Classifier: Programming Language :: Python :: 3.11
12
12
  Classifier: Programming Language :: Python :: 3.12
13
- Requires-Dist: httpx (>=0.25.2,<1)
14
- Requires-Dist: orjson (>=3.9.10,<4.0.0)
13
+ Requires-Dist: httpx (>=0.25,<0.26)
14
+ Requires-Dist: orjson (>=3.9.10,<3.11)
15
15
  Requires-Dist: pydantic (>=2.5.2,<3.0.0)
16
16
  Description-Content-Type: text/markdown
17
17
 
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "mistralai"
3
- version = "0.3.0"
3
+ version = "0.4.0"
4
4
  description = ""
5
5
  authors = ["Bam4d <bam4d@mistral.ai>"]
6
6
  readme = "README.md"
@@ -23,10 +23,10 @@ exclude = ["docs", "tests", "examples", "tools", "build"]
23
23
 
24
24
 
25
25
  [tool.poetry.dependencies]
26
- python = "^3.9"
27
- orjson = "^3.9.10"
28
- pydantic = "^2.5.2"
29
- httpx = ">= 0.25.2, < 1"
26
+ python = "^3.9,<4.0"
27
+ orjson = "^3.9.10,<3.11"
28
+ pydantic = "^2.5.2,<3"
29
+ httpx = "^0.25,<1"
30
30
 
31
31
 
32
32
  [tool.poetry.group.dev.dependencies]
@@ -20,6 +20,8 @@ from mistralai.exceptions import (
20
20
  MistralConnectionException,
21
21
  MistralException,
22
22
  )
23
+ from mistralai.files import FilesAsyncClient
24
+ from mistralai.jobs import JobsAsyncClient
23
25
  from mistralai.models.chat_completion import (
24
26
  ChatCompletionResponse,
25
27
  ChatCompletionStreamResponse,
@@ -47,6 +49,8 @@ class MistralAsyncClient(ClientBase):
47
49
  limits=Limits(max_connections=max_concurrent_requests),
48
50
  transport=AsyncHTTPTransport(retries=max_retries),
49
51
  )
52
+ self.files = FilesAsyncClient(self)
53
+ self.jobs = JobsAsyncClient(self)
50
54
 
51
55
  async def close(self) -> None:
52
56
  await self._client.aclose()
@@ -96,15 +100,19 @@ class MistralAsyncClient(ClientBase):
96
100
  path: str,
97
101
  stream: bool = False,
98
102
  attempt: int = 1,
103
+ data: Optional[Dict[str, Any]] = None,
104
+ **kwargs: Any,
99
105
  ) -> AsyncGenerator[Dict[str, Any], None]:
100
106
  accept_header = "text/event-stream" if stream else "application/json"
101
107
  headers = {
102
108
  "Accept": accept_header,
103
109
  "User-Agent": f"mistral-client-python/{self._version}",
104
110
  "Authorization": f"Bearer {self._api_key}",
105
- "Content-Type": "application/json",
106
111
  }
107
112
 
113
+ if json is not None:
114
+ headers["Content-Type"] = "application/json"
115
+
108
116
  url = posixpath.join(self._endpoint, path)
109
117
 
110
118
  self._logger.debug(f"Sending request: {method} {url} {json}")
@@ -118,6 +126,8 @@ class MistralAsyncClient(ClientBase):
118
126
  url,
119
127
  headers=headers,
120
128
  json=json,
129
+ data=data,
130
+ **kwargs,
121
131
  ) as response:
122
132
  await self._check_streaming_response(response)
123
133
 
@@ -132,6 +142,8 @@ class MistralAsyncClient(ClientBase):
132
142
  url,
133
143
  headers=headers,
134
144
  json=json,
145
+ data=data,
146
+ **kwargs,
135
147
  )
136
148
 
137
149
  yield await self._check_response(response)
@@ -13,6 +13,8 @@ from mistralai.exceptions import (
13
13
  MistralConnectionException,
14
14
  MistralException,
15
15
  )
16
+ from mistralai.files import FilesClient
17
+ from mistralai.jobs import JobsClient
16
18
  from mistralai.models.chat_completion import (
17
19
  ChatCompletionResponse,
18
20
  ChatCompletionStreamResponse,
@@ -40,6 +42,8 @@ class MistralClient(ClientBase):
40
42
  self._client = Client(
41
43
  follow_redirects=True, timeout=self._timeout, transport=HTTPTransport(retries=self._max_retries)
42
44
  )
45
+ self.files = FilesClient(self)
46
+ self.jobs = JobsClient(self)
43
47
 
44
48
  def __del__(self) -> None:
45
49
  self._client.close()
@@ -89,15 +93,19 @@ class MistralClient(ClientBase):
89
93
  path: str,
90
94
  stream: bool = False,
91
95
  attempt: int = 1,
96
+ data: Optional[Dict[str, Any]] = None,
97
+ **kwargs: Any,
92
98
  ) -> Iterator[Dict[str, Any]]:
93
99
  accept_header = "text/event-stream" if stream else "application/json"
94
100
  headers = {
95
101
  "Accept": accept_header,
96
102
  "User-Agent": f"mistral-client-python/{self._version}",
97
103
  "Authorization": f"Bearer {self._api_key}",
98
- "Content-Type": "application/json",
99
104
  }
100
105
 
106
+ if json is not None:
107
+ headers["Content-Type"] = "application/json"
108
+
101
109
  url = posixpath.join(self._endpoint, path)
102
110
 
103
111
  self._logger.debug(f"Sending request: {method} {url} {json}")
@@ -111,6 +119,8 @@ class MistralClient(ClientBase):
111
119
  url,
112
120
  headers=headers,
113
121
  json=json,
122
+ data=data,
123
+ **kwargs,
114
124
  ) as response:
115
125
  self._check_streaming_response(response)
116
126
 
@@ -125,6 +135,8 @@ class MistralClient(ClientBase):
125
135
  url,
126
136
  headers=headers,
127
137
  json=json,
138
+ data=data,
139
+ **kwargs,
128
140
  )
129
141
 
130
142
  yield self._check_response(response)
@@ -10,7 +10,7 @@ from mistralai.exceptions import (
10
10
  )
11
11
  from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
12
12
 
13
- CLIENT_VERSION = "0.2.0"
13
+ CLIENT_VERSION = "0.4.0"
14
14
 
15
15
 
16
16
  class ClientBase(ABC):
@@ -146,7 +146,6 @@ class ClientBase(ABC):
146
146
  ) -> Dict[str, Any]:
147
147
  request_data: Dict[str, Any] = {
148
148
  "messages": self._parse_messages(messages),
149
- "safe_prompt": safe_prompt,
150
149
  }
151
150
 
152
151
  if model is not None:
@@ -162,6 +161,8 @@ class ClientBase(ABC):
162
161
  )
163
162
  )
164
163
 
164
+ if safe_prompt:
165
+ request_data["safe_prompt"] = safe_prompt
165
166
  if tools is not None:
166
167
  request_data["tools"] = self._parse_tools(tools)
167
168
  if stream is not None:
@@ -0,0 +1,84 @@
1
+ from typing import Any
2
+
3
+ from mistralai.exceptions import (
4
+ MistralException,
5
+ )
6
+ from mistralai.models.files import FileDeleted, FileObject, Files
7
+
8
+
9
class FilesClient:
    """Synchronous wrapper around the ``v1/files`` endpoints.

    Every call goes through ``client._request``, which is consumed as an
    iterator; the first payload it yields is parsed into a response model.
    """

    def __init__(self, client: Any):
        # NOTE(review): typed as Any; presumably the owning MistralClient, which
        # must provide ``_request`` — confirm.
        self.client = client

    @staticmethod
    def _first_payload(responses: Any) -> Any:
        """Return the first payload yielded by *responses*.

        Raises:
            MistralException: if *responses* yields nothing.
        """
        for payload in responses:
            return payload
        raise MistralException("No response received")

    def create(
        self,
        file: bytes,
        purpose: str = "fine-tune",
    ) -> FileObject:
        """Upload *file* with the given *purpose* and return the stored record.

        The content is sent through ``files=``/``data=`` rather than a JSON body,
        which is why ``None`` is passed as the JSON argument.

        Raises:
            MistralException: if the request yields no response.
        """
        upload = self.client._request(
            "post",
            None,
            "v1/files",
            files={"file": file},
            data={"purpose": purpose},
        )
        return FileObject(**self._first_payload(upload))

    def retrieve(self, file_id: str) -> FileObject:
        """Fetch the record of the file identified by *file_id*.

        Raises:
            MistralException: if the request yields no response.
        """
        lookup = self.client._request("get", {}, f"v1/files/{file_id}")
        return FileObject(**self._first_payload(lookup))

    def list(self) -> Files:
        """Return every file visible to the client as a ``Files`` listing.

        Raises:
            MistralException: if the request yields no response.
        """
        listing = self.client._request("get", {}, "v1/files")
        return Files(**self._first_payload(listing))

    def delete(self, file_id: str) -> FileDeleted:
        """Delete the file identified by *file_id* and return the API's acknowledgement.

        Raises:
            MistralException: if the request yields no response.
        """
        removal = self.client._request("delete", {}, f"v1/files/{file_id}")
        return FileDeleted(**self._first_payload(removal))
46
+
47
+
48
class FilesAsyncClient:
    """Asynchronous wrapper around the ``v1/files`` endpoints.

    Every call goes through ``client._request``, which is consumed as an async
    iterator; the first payload it yields is parsed into a response model.
    """

    def __init__(self, client: Any):
        # NOTE(review): typed as Any; presumably the owning MistralAsyncClient,
        # whose ``_request`` is an async generator — confirm.
        self.client = client

    @staticmethod
    async def _first_payload(responses: Any) -> Any:
        """Return the first payload yielded by the async iterator *responses*.

        Raises:
            MistralException: if *responses* yields nothing.
        """
        async for payload in responses:
            return payload
        raise MistralException("No response received")

    async def create(
        self,
        file: bytes,
        purpose: str = "fine-tune",
    ) -> FileObject:
        """Upload *file* with the given *purpose* and return the stored record.

        The content is sent through ``files=``/``data=`` rather than a JSON body,
        which is why ``None`` is passed as the JSON argument.

        Raises:
            MistralException: if the request yields no response.
        """
        upload = self.client._request(
            "post",
            None,
            "v1/files",
            files={"file": file},
            data={"purpose": purpose},
        )
        return FileObject(**await self._first_payload(upload))

    async def retrieve(self, file_id: str) -> FileObject:
        """Fetch the record of the file identified by *file_id*.

        Raises:
            MistralException: if the request yields no response.
        """
        lookup = self.client._request("get", {}, f"v1/files/{file_id}")
        return FileObject(**await self._first_payload(lookup))

    async def list(self) -> Files:
        """Return every file visible to the client as a ``Files`` listing.

        Raises:
            MistralException: if the request yields no response.
        """
        listing = self.client._request("get", {}, "v1/files")
        return Files(**await self._first_payload(listing))

    async def delete(self, file_id: str) -> FileDeleted:
        """Delete the file identified by *file_id* and return the API's acknowledgement.

        Raises:
            MistralException: if the request yields no response.
        """
        removal = self.client._request("delete", {}, f"v1/files/{file_id}")
        return FileDeleted(**await self._first_payload(removal))
@@ -0,0 +1,172 @@
1
+ from datetime import datetime
2
+ from typing import Any, Optional, Union
3
+
4
+ from mistralai.exceptions import (
5
+ MistralException,
6
+ )
7
+ from mistralai.models.jobs import DetailedJob, IntegrationIn, Job, JobMetadata, JobQueryFilter, Jobs, TrainingParameters
8
+
9
+
10
class JobsClient:
    """Synchronous wrapper around the ``v1/fine_tuning/jobs`` endpoints.

    Every call goes through ``client._request``, which is consumed as an
    iterator; the first payload it yields is parsed into a response model.
    """

    def __init__(self, client: Any):
        # NOTE(review): typed as Any; presumably the owning MistralClient, which
        # must provide ``_request`` — confirm.
        self.client = client

    @staticmethod
    def _first_payload(responses: Any) -> Any:
        """Return the first payload yielded by *responses*.

        Raises:
            MistralException: if *responses* yields nothing.
        """
        for payload in responses:
            return payload
        raise MistralException("No response received")

    @staticmethod
    def _dump_integrations(integrations: Any) -> Any:
        """Turn integration settings into JSON-serializable plain values.

        httpx encodes ``json=`` bodies with the standard ``json`` module, which
        cannot serialize pydantic models or sets. Model items are therefore
        converted with ``model_dump()``, other items (e.g. dicts) pass through
        unchanged, and the container becomes a list. ``None`` stays ``None``.
        """
        if integrations is None:
            return None
        return [item.model_dump() if hasattr(item, "model_dump") else item for item in integrations]

    def create(
        self,
        model: str,
        training_files: Union[list[str], None] = None,
        validation_files: Union[list[str], None] = None,
        hyperparameters: Optional[TrainingParameters] = None,
        suffix: Union[str, None] = None,
        integrations: Union[list[IntegrationIn], None] = None,
        training_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
        validation_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
        dry_run: bool = False,
    ) -> Union[Job, JobMetadata]:
        """Create a fine-tuning job, or request only its metadata when ``dry_run`` is set.

        Args:
            model: name of the model to fine-tune.
            training_files: training file identifiers (presumably IDs returned by
                the files API — confirm).
            validation_files: validation file identifiers, same format.
            hyperparameters: training parameters. When omitted, a new
                ``TrainingParameters(training_steps=1800, learning_rate=1.0e-4)``
                is built for this call.
            suffix: optional suffix forwarded to the API.
            integrations: integration settings such as ``WandbIntegrationIn``.
                Any iterable is accepted and is sent as a JSON list.
            training_file: deprecated single-file alias, used only when
                ``training_files`` is empty.
            validation_file: deprecated single-file alias, used only when
                ``validation_files`` is empty.
            dry_run: forwarded as the ``dry_run`` query parameter.

        Returns:
            ``Job`` normally, or ``JobMetadata`` when ``dry_run`` is true.

        Raises:
            MistralException: if the request yields no response.
        """
        # Handle deprecated arguments
        if not training_files and training_file:
            training_files = [training_file]
        if not validation_files and validation_file:
            validation_files = [validation_file]
        # Build the default per call, so one mutable model instance is not
        # shared by every call through the default argument.
        if hyperparameters is None:
            hyperparameters = TrainingParameters(training_steps=1800, learning_rate=1.0e-4)

        response = self.client._request(
            method="post",
            json={
                "model": model,
                "training_files": training_files,
                "validation_files": validation_files,
                # model_dump() is the pydantic v2 API; .dict() is deprecated.
                "hyperparameters": hyperparameters.model_dump(),
                "suffix": suffix,
                "integrations": self._dump_integrations(integrations),
            },
            path="v1/fine_tuning/jobs",
            params={"dry_run": dry_run},
        )
        payload = self._first_payload(response)
        return JobMetadata(**payload) if dry_run else Job(**payload)

    def retrieve(self, job_id: str) -> DetailedJob:
        """Fetch the job identified by *job_id* as a ``DetailedJob``.

        Raises:
            MistralException: if the request yields no response.
        """
        response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
        return DetailedJob(**self._first_payload(response))

    def list(
        self,
        page: int = 0,
        page_size: int = 10,
        model: Optional[str] = None,
        created_after: Optional[datetime] = None,
        created_by_me: Optional[bool] = None,
        status: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> Jobs:
        """List fine-tuning jobs, one page at a time.

        Filters left as ``None`` are dropped from the query string, because the
        parameters are built with ``model_dump(exclude_none=True)``.

        Raises:
            MistralException: if the request yields no response.
        """
        query_params = JobQueryFilter(
            page=page,
            page_size=page_size,
            model=model,
            created_after=created_after,
            created_by_me=created_by_me,
            status=status,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            suffix=suffix,
        ).model_dump(exclude_none=True)
        response = self.client._request(method="get", params=query_params, path="v1/fine_tuning/jobs", json={})
        return Jobs(**self._first_payload(response))

    def cancel(self, job_id: str) -> DetailedJob:
        """Request cancellation of the job *job_id*; the response is parsed as a ``DetailedJob``.

        Raises:
            MistralException: if the request yields no response.
        """
        response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
        return DetailedJob(**self._first_payload(response))
90
+
91
+
92
class JobsAsyncClient:
    """Asynchronous wrapper around the ``v1/fine_tuning/jobs`` endpoints.

    Every call goes through ``client._request``, which is consumed as an async
    iterator; the first payload it yields is parsed into a response model.
    """

    def __init__(self, client: Any):
        # NOTE(review): typed as Any; presumably the owning MistralAsyncClient,
        # whose ``_request`` is an async generator — confirm.
        self.client = client

    @staticmethod
    async def _first_payload(responses: Any) -> Any:
        """Return the first payload yielded by the async iterator *responses*.

        Raises:
            MistralException: if *responses* yields nothing.
        """
        async for payload in responses:
            return payload
        raise MistralException("No response received")

    @staticmethod
    def _dump_integrations(integrations: Any) -> Any:
        """Turn integration settings into JSON-serializable plain values.

        httpx encodes ``json=`` bodies with the standard ``json`` module, which
        cannot serialize pydantic models or sets. Model items are therefore
        converted with ``model_dump()``, other items (e.g. dicts) pass through
        unchanged, and the container becomes a list. ``None`` stays ``None``.
        """
        if integrations is None:
            return None
        return [item.model_dump() if hasattr(item, "model_dump") else item for item in integrations]

    async def create(
        self,
        model: str,
        training_files: Union[list[str], None] = None,
        validation_files: Union[list[str], None] = None,
        hyperparameters: Optional[TrainingParameters] = None,
        suffix: Union[str, None] = None,
        integrations: Union[list[IntegrationIn], None] = None,
        training_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
        validation_file: Union[str, None] = None,  # Deprecated: Added for compatibility with OpenAI API
        dry_run: bool = False,
    ) -> Union[Job, JobMetadata]:
        """Create a fine-tuning job, or request only its metadata when ``dry_run`` is set.

        Args:
            model: name of the model to fine-tune.
            training_files: training file identifiers (presumably IDs returned by
                the files API — confirm).
            validation_files: validation file identifiers, same format.
            hyperparameters: training parameters. When omitted, a new
                ``TrainingParameters(training_steps=1800, learning_rate=1.0e-4)``
                is built for this call.
            suffix: optional suffix forwarded to the API.
            integrations: integration settings such as ``WandbIntegrationIn``.
                Any iterable is accepted and is sent as a JSON list.
            training_file: deprecated single-file alias, used only when
                ``training_files`` is empty.
            validation_file: deprecated single-file alias, used only when
                ``validation_files`` is empty.
            dry_run: forwarded as the ``dry_run`` query parameter.

        Returns:
            ``Job`` normally, or ``JobMetadata`` when ``dry_run`` is true.

        Raises:
            MistralException: if the request yields no response.
        """
        # Handle deprecated arguments
        if not training_files and training_file:
            training_files = [training_file]
        if not validation_files and validation_file:
            validation_files = [validation_file]
        # Build the default per call, so one mutable model instance is not
        # shared by every call through the default argument.
        if hyperparameters is None:
            hyperparameters = TrainingParameters(training_steps=1800, learning_rate=1.0e-4)

        response = self.client._request(
            method="post",
            json={
                "model": model,
                "training_files": training_files,
                "validation_files": validation_files,
                # model_dump() is the pydantic v2 API; .dict() is deprecated.
                "hyperparameters": hyperparameters.model_dump(),
                "suffix": suffix,
                "integrations": self._dump_integrations(integrations),
            },
            path="v1/fine_tuning/jobs",
            params={"dry_run": dry_run},
        )
        payload = await self._first_payload(response)
        return JobMetadata(**payload) if dry_run else Job(**payload)

    async def retrieve(self, job_id: str) -> DetailedJob:
        """Fetch the job identified by *job_id* as a ``DetailedJob``.

        Raises:
            MistralException: if the request yields no response.
        """
        response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
        return DetailedJob(**await self._first_payload(response))

    async def list(
        self,
        page: int = 0,
        page_size: int = 10,
        model: Optional[str] = None,
        created_after: Optional[datetime] = None,
        created_by_me: Optional[bool] = None,
        status: Optional[str] = None,
        wandb_project: Optional[str] = None,
        wandb_name: Optional[str] = None,
        suffix: Optional[str] = None,
    ) -> Jobs:
        """List fine-tuning jobs, one page at a time.

        Filters left as ``None`` are dropped from the query string, because the
        parameters are built with ``model_dump(exclude_none=True)``.

        Raises:
            MistralException: if the request yields no response.
        """
        query_params = JobQueryFilter(
            page=page,
            page_size=page_size,
            model=model,
            created_after=created_after,
            created_by_me=created_by_me,
            status=status,
            wandb_project=wandb_project,
            wandb_name=wandb_name,
            suffix=suffix,
        ).model_dump(exclude_none=True)
        response = self.client._request(method="get", path="v1/fine_tuning/jobs", params=query_params, json={})
        return Jobs(**await self._first_payload(response))

    async def cancel(self, job_id: str) -> DetailedJob:
        """Request cancellation of the job *job_id*; the response is parsed as a ``DetailedJob``.

        Raises:
            MistralException: if the request yields no response.
        """
        response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
        return DetailedJob(**await self._first_payload(response))
@@ -1,5 +1,5 @@
1
1
  from enum import Enum
2
- from typing import List, Optional, Union
2
+ from typing import List, Optional
3
3
 
4
4
  from pydantic import BaseModel
5
5
 
@@ -44,7 +44,7 @@ class ResponseFormat(BaseModel):
44
44
 
45
45
  class ChatMessage(BaseModel):
46
46
  role: str
47
- content: Union[str, List[str]]
47
+ content: str
48
48
  name: Optional[str] = None
49
49
  tool_calls: Optional[List[ToolCall]] = None
50
50
  tool_call_id: Optional[str] = None
@@ -0,0 +1,23 @@
1
+ from typing import Literal, Optional
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class FileObject(BaseModel):
    """File record returned by the files client's create and retrieve calls."""

    id: str
    object: str
    bytes: int  # NOTE(review): presumably the file size in bytes — confirm
    created_at: int  # NOTE(review): looks like a Unix timestamp — confirm
    filename: str
    # Either "fine-tune" or None; defaults to "fine-tune".
    purpose: Optional[Literal["fine-tune"]] = "fine-tune"
13
+
14
+
15
class FileDeleted(BaseModel):
    """Acknowledgement returned by the files client's delete call."""

    id: str  # identifier of the file the acknowledgement refers to
    object: str
    deleted: bool  # deletion flag as reported by the API
19
+
20
+
21
class Files(BaseModel):
    """Listing returned by the files client's list call."""

    data: list[FileObject]  # one record per file
    object: Literal["list"]  # the API tags listings with the literal "list"
@@ -0,0 +1,98 @@
1
+ from datetime import datetime
2
+ from typing import Annotated, List, Literal, Optional, Union
3
+
4
+ from pydantic import BaseModel, Field
5
+
6
+
7
class TrainingParameters(BaseModel):
    """Hyperparameters attached to a fine-tuning job."""

    # Validated to lie in the closed range [1, 10000]; defaults to 1800.
    training_steps: int = Field(default=1800, ge=1, le=10000)
    # Validated to lie in the closed range [1e-8, 1]; defaults to 1e-4.
    learning_rate: float = Field(default=1.0e-4, ge=1.0e-8, le=1)
10
+
11
+
12
class WandbIntegration(BaseModel):
    """Settings for the ``wandb`` integration of a fine-tuning job."""

    # Tag value read by the discriminated ``Integration`` union.
    type: Literal["wandb"] = "wandb"
    project: str
    name: Optional[str] = None
    run_name: Optional[str] = None
17
+
18
+
19
class WandbIntegrationIn(WandbIntegration):
    """Request-side ``wandb`` integration: the base settings plus an ``api_key``."""

    # Required on input. Note that ``model_dump()`` includes it, so it ends up in request bodies.
    api_key: str
21
+
22
+
23
# Discriminated unions keyed on the ``type`` field. Each currently holds a single
# member; the Union wrapper leaves room for more integration kinds.
Integration = Annotated[
    Union[WandbIntegration],
    Field(discriminator="type"),
]
IntegrationIn = Annotated[
    Union[WandbIntegrationIn],
    Field(discriminator="type"),
]
25
+
26
+
27
class JobMetadata(BaseModel):
    """Metadata returned instead of a ``Job`` when a job is created with ``dry_run``."""

    object: Literal["job.metadata"] = "job.metadata"
    training_steps: int
    train_tokens_per_step: int
    data_tokens: int
    train_tokens: int
    epochs: float
    # In pydantic v2, an Optional field without a default is still *required*.
    # The ``None`` default lets a response that omits the key validate.
    expected_duration_seconds: Optional[int] = None
35
+
36
+
37
class Job(BaseModel):
    """Fine-tuning job record, as parsed by the jobs client and the ``Jobs`` listing."""

    id: str
    hyperparameters: TrainingParameters
    fine_tuned_model: Optional[str]  # key is required, but its value may be null
    model: str
    status: Literal[
        "QUEUED",
        "STARTED",
        "RUNNING",
        "FAILED",
        "SUCCESS",
        "CANCELLED",
        "CANCELLATION_REQUESTED",
    ]
    job_type: str
    created_at: int  # NOTE(review): looks like a Unix timestamp — confirm
    modified_at: int  # NOTE(review): looks like a Unix timestamp — confirm
    training_files: list[str]
    validation_files: Optional[list[str]] = []
    object: Literal["job"]
    integrations: List[Integration] = []
58
+
59
+
60
class Event(BaseModel):
    """A named job event, listed on ``DetailedJob.events``."""

    name: str
    data: Optional[dict] = None  # free-form payload; its schema is not modeled here
    created_at: int
64
+
65
+
66
class Metric(BaseModel):
    """Loss and accuracy values attached to a ``Checkpoint``; each may be absent."""

    train_loss: Optional[float] = None
    valid_loss: Optional[float] = None
    valid_mean_token_accuracy: Optional[float] = None
70
+
71
+
72
class Checkpoint(BaseModel):
    """A training checkpoint, listed on ``DetailedJob.checkpoints``."""

    metrics: Metric
    step_number: int
    created_at: int
76
+
77
+
78
class JobQueryFilter(BaseModel):
    """Query-string filters for listing fine-tuning jobs.

    The jobs clients dump this model with ``exclude_none=True``, so only the
    fields that are set become query parameters.
    """

    page: int = 0
    # NOTE(review): this defaults to 100, while the jobs clients' ``list`` methods
    # default to 10 and always pass the value explicitly.
    page_size: int = 100
    model: Optional[str] = None
    created_after: Optional[datetime] = None
    created_by_me: Optional[bool] = None
    status: Optional[str] = None
    wandb_project: Optional[str] = None
    wandb_name: Optional[str] = None
    suffix: Optional[str] = None
88
+
89
+
90
class DetailedJob(Job):
    """A ``Job`` plus its events, checkpoints and estimated start time.

    Returned by the jobs clients' ``retrieve`` and ``cancel`` calls.
    """

    events: list[Event] = []
    checkpoints: list[Checkpoint] = []
    estimated_start_time: Optional[int] = None  # NOTE(review): presumably a Unix timestamp — confirm
94
+
95
+
96
class Jobs(BaseModel):
    """One page of jobs, as returned by the jobs clients' ``list`` call."""

    data: list[Job] = []  # empty when the page has no jobs
    object: Literal["list"]
@@ -7,15 +7,15 @@ class ModelPermission(BaseModel):
7
7
  id: str
8
8
  object: str
9
9
  created: int
10
- allow_create_engine: bool = False
10
+ allow_create_engine: Optional[bool] = False
11
11
  allow_sampling: bool = True
12
12
  allow_logprobs: bool = True
13
- allow_search_indices: bool = False
13
+ allow_search_indices: Optional[bool] = False
14
14
  allow_view: bool = True
15
15
  allow_fine_tuning: bool = False
16
16
  organization: str = "*"
17
17
  group: Optional[str] = None
18
- is_blocking: bool = False
18
+ is_blocking: Optional[bool] = False
19
19
 
20
20
 
21
21
  class ModelCard(BaseModel):
File without changes
File without changes