mistralai 0.2.0.tar.gz → 0.4.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: mistralai
- Version: 0.2.0
+ Version: 0.4.0
  Summary:
  Author: Bam4d
  Author-email: bam4d@mistral.ai
@@ -10,8 +10,8 @@ Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
- Requires-Dist: httpx (>=0.25.2,<1)
- Requires-Dist: orjson (>=3.9.10,<4.0.0)
+ Requires-Dist: httpx (>=0.25,<0.26)
+ Requires-Dist: orjson (>=3.9.10,<3.11)
  Requires-Dist: pydantic (>=2.5.2,<3.0.0)
  Description-Content-Type: text/markdown

@@ -1,6 +1,6 @@
  [tool.poetry]
  name = "mistralai"
- version = "0.2.0"
+ version = "0.4.0"
  description = ""
  authors = ["Bam4d <bam4d@mistral.ai>"]
  readme = "README.md"
@@ -23,10 +23,10 @@ exclude = ["docs", "tests", "examples", "tools", "build"]


  [tool.poetry.dependencies]
- python = "^3.9"
- orjson = "^3.9.10"
- pydantic = "^2.5.2"
- httpx = ">= 0.25.2, < 1"
+ python = "^3.9,<4.0"
+ orjson = "^3.9.10,<3.11"
+ pydantic = "^2.5.2,<3"
+ httpx = "^0.25,<1"


  [tool.poetry.group.dev.dependencies]
@@ -20,6 +20,8 @@ from mistralai.exceptions import (
  MistralConnectionException,
  MistralException,
  )
+ from mistralai.files import FilesAsyncClient
+ from mistralai.jobs import JobsAsyncClient
  from mistralai.models.chat_completion import (
  ChatCompletionResponse,
  ChatCompletionStreamResponse,
@@ -47,6 +49,8 @@ class MistralAsyncClient(ClientBase):
  limits=Limits(max_connections=max_concurrent_requests),
  transport=AsyncHTTPTransport(retries=max_retries),
  )
+ self.files = FilesAsyncClient(self)
+ self.jobs = JobsAsyncClient(self)

  async def close(self) -> None:
  await self._client.aclose()
@@ -92,19 +96,23 @@ class MistralAsyncClient(ClientBase):
  async def _request(
  self,
  method: str,
- json: Dict[str, Any],
+ json: Optional[Dict[str, Any]],
  path: str,
  stream: bool = False,
  attempt: int = 1,
+ data: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
  ) -> AsyncGenerator[Dict[str, Any], None]:
  accept_header = "text/event-stream" if stream else "application/json"
  headers = {
  "Accept": accept_header,
  "User-Agent": f"mistral-client-python/{self._version}",
  "Authorization": f"Bearer {self._api_key}",
- "Content-Type": "application/json",
  }

+ if json is not None:
+ headers["Content-Type"] = "application/json"
+
  url = posixpath.join(self._endpoint, path)

  self._logger.debug(f"Sending request: {method} {url} {json}")
@@ -118,6 +126,8 @@ class MistralAsyncClient(ClientBase):
  url,
  headers=headers,
  json=json,
+ data=data,
+ **kwargs,
  ) as response:
  await self._check_streaming_response(response)

@@ -132,6 +142,8 @@ class MistralAsyncClient(ClientBase):
  url,
  headers=headers,
  json=json,
+ data=data,
+ **kwargs,
  )

  yield await self._check_response(response)
@@ -291,3 +303,74 @@ class MistralAsyncClient(ClientBase):
  return ModelList(**response)

  raise MistralException("No response received")
+
+ async def completion(
+ self,
+ model: str,
+ prompt: str,
+ suffix: Optional[str] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ top_p: Optional[float] = None,
+ random_seed: Optional[int] = None,
+ stop: Optional[List[str]] = None,
+ ) -> ChatCompletionResponse:
+ """An asynchronous completion endpoint that returns a single response.
+
+ Args:
+ model (str): model the name of the model to get completions with, e.g. codestral-latest
+ prompt (str): the prompt to complete
+ suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+ temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+ max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+ top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+ Defaults to None.
+ random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+ stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+ Returns:
+ Dict[str, Any]: a response object containing the generated text.
+ """
+ request = self._make_completion_request(
+ prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+ )
+ single_response = self._request("post", request, "v1/fim/completions")
+
+ async for response in single_response:
+ return ChatCompletionResponse(**response)
+
+ raise MistralException("No response received")
+
+ async def completion_stream(
+ self,
+ model: str,
+ prompt: str,
+ suffix: Optional[str] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ top_p: Optional[float] = None,
+ random_seed: Optional[int] = None,
+ stop: Optional[List[str]] = None,
+ ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
+ """An asynchronous completion endpoint that returns a streaming response.
+
+ Args:
+ model (str): model the name of the model to get completions with, e.g. codestral-latest
+ prompt (str): the prompt to complete
+ suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+ temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+ max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+ top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+ Defaults to None.
+ random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+ stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+
+ Returns:
+ Dict[str, Any]: a response object containing the generated text.
+ """
+ request = self._make_completion_request(
+ prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+ )
+ async_response = self._request("post", request, "v1/fim/completions", stream=True)
+
+ async for json_response in async_response:
+ yield ChatCompletionStreamResponse(**json_response)
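This hunk adds fill-in-the-middle (FIM) completion support to MistralAsyncClient, backed by the new v1/fim/completions endpoint. A minimal usage sketch, not part of the diff (the model name, prompt, and suffix are illustrative, and MISTRAL_API_KEY is assumed to be set in the environment):

import asyncio
import os

from mistralai.async_client import MistralAsyncClient


async def main() -> None:
    client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])
    # Single-shot FIM completion: the model fills in code between prompt and suffix.
    response = await client.completion(
        model="codestral-latest",
        prompt="def fibonacci(n: int):",
        suffix="    return fib(n)",
        max_tokens=64,
    )
    print(response.choices[0].message.content)
    await client.close()


asyncio.run(main())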
@@ -13,6 +13,8 @@ from mistralai.exceptions import (
  MistralConnectionException,
  MistralException,
  )
+ from mistralai.files import FilesClient
+ from mistralai.jobs import JobsClient
  from mistralai.models.chat_completion import (
  ChatCompletionResponse,
  ChatCompletionStreamResponse,
@@ -40,6 +42,8 @@ class MistralClient(ClientBase):
  self._client = Client(
  follow_redirects=True, timeout=self._timeout, transport=HTTPTransport(retries=self._max_retries)
  )
+ self.files = FilesClient(self)
+ self.jobs = JobsClient(self)

  def __del__(self) -> None:
  self._client.close()
@@ -85,19 +89,23 @@ class MistralClient(ClientBase):
  def _request(
  self,
  method: str,
- json: Dict[str, Any],
+ json: Optional[Dict[str, Any]],
  path: str,
  stream: bool = False,
  attempt: int = 1,
+ data: Optional[Dict[str, Any]] = None,
+ **kwargs: Any,
  ) -> Iterator[Dict[str, Any]]:
  accept_header = "text/event-stream" if stream else "application/json"
  headers = {
  "Accept": accept_header,
  "User-Agent": f"mistral-client-python/{self._version}",
  "Authorization": f"Bearer {self._api_key}",
- "Content-Type": "application/json",
  }

+ if json is not None:
+ headers["Content-Type"] = "application/json"
+
  url = posixpath.join(self._endpoint, path)

  self._logger.debug(f"Sending request: {method} {url} {json}")
@@ -111,6 +119,8 @@ class MistralClient(ClientBase):
  url,
  headers=headers,
  json=json,
+ data=data,
+ **kwargs,
  ) as response:
  self._check_streaming_response(response)

@@ -125,6 +135,8 @@ class MistralClient(ClientBase):
  url,
  headers=headers,
  json=json,
+ data=data,
+ **kwargs,
  )

  yield self._check_response(response)
@@ -285,3 +297,77 @@ class MistralClient(ClientBase):
  return ModelList(**response)

  raise MistralException("No response received")
+
+ def completion(
+ self,
+ model: str,
+ prompt: str,
+ suffix: Optional[str] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ top_p: Optional[float] = None,
+ random_seed: Optional[int] = None,
+ stop: Optional[List[str]] = None,
+ ) -> ChatCompletionResponse:
+ """A completion endpoint that returns a single response.
+
+ Args:
+ model (str): model the name of the model to get completion with, e.g. codestral-latest
+ prompt (str): the prompt to complete
+ suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+ temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+ max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+ top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+ Defaults to None.
+ random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+ stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+
+ Returns:
+ Dict[str, Any]: a response object containing the generated text.
+ """
+ request = self._make_completion_request(
+ prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+ )
+
+ single_response = self._request("post", request, "v1/fim/completions", stream=False)
+
+ for response in single_response:
+ return ChatCompletionResponse(**response)
+
+ raise MistralException("No response received")
+
+ def completion_stream(
+ self,
+ model: str,
+ prompt: str,
+ suffix: Optional[str] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ top_p: Optional[float] = None,
+ random_seed: Optional[int] = None,
+ stop: Optional[List[str]] = None,
+ ) -> Iterable[ChatCompletionStreamResponse]:
+ """An asynchronous completion endpoint that streams responses.
+
+ Args:
+ model (str): model the name of the model to get completions with, e.g. codestral-latest
+ prompt (str): the prompt to complete
+ suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+ temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
+ max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+ top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+ Defaults to None.
+ random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+ stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['/n/n']
+
+ Returns:
+ Iterable[Dict[str, Any]]: a generator that yields response objects containing the generated text.
+ """
+ request = self._make_completion_request(
+ prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+ )
+
+ response = self._request("post", request, "v1/fim/completions", stream=True)
+
+ for json_streamed_response in response:
+ yield ChatCompletionStreamResponse(**json_streamed_response)
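The synchronous MistralClient gains the same completion and completion_stream methods. A minimal streaming sketch under the same assumptions (illustrative model name and prompt, API key from the environment):

import os

from mistralai.client import MistralClient

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])
# Stream the FIM completion; each chunk is a ChatCompletionStreamResponse
# whose delta carries the newly generated text.
for chunk in client.completion_stream(
    model="codestral-latest",
    prompt="def add(a: int, b: int) -> int:",
    max_tokens=32,
):
    delta = chunk.choices[0].delta.content
    if delta is not None:
        print(delta, end="")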
@@ -10,7 +10,7 @@ from mistralai.exceptions import (
  )
  from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice

- CLIENT_VERSION = "0.2.0"
+ CLIENT_VERSION = "0.4.0"


  class ClientBase(ABC):
@@ -73,6 +73,63 @@ class ClientBase(ABC):

  return parsed_messages

+ def _make_completion_request(
+ self,
+ prompt: str,
+ model: Optional[str] = None,
+ suffix: Optional[str] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ top_p: Optional[float] = None,
+ random_seed: Optional[int] = None,
+ stop: Optional[List[str]] = None,
+ stream: Optional[bool] = False,
+ ) -> Dict[str, Any]:
+ request_data: Dict[str, Any] = {
+ "prompt": prompt,
+ "suffix": suffix,
+ "model": model,
+ "stream": stream,
+ }
+
+ if stop is not None:
+ request_data["stop"] = stop
+
+ if model is not None:
+ request_data["model"] = model
+ else:
+ if self._default_model is None:
+ raise MistralException(message="model must be provided")
+ request_data["model"] = self._default_model
+
+ request_data.update(
+ self._build_sampling_params(
+ temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+ )
+ )
+
+ self._logger.debug(f"Completion request: {request_data}")
+
+ return request_data
+
+ def _build_sampling_params(
+ self,
+ max_tokens: Optional[int],
+ random_seed: Optional[int],
+ temperature: Optional[float],
+ top_p: Optional[float],
+ ) -> Dict[str, Any]:
+ params = {}
+ if temperature is not None:
+ params["temperature"] = temperature
+ if max_tokens is not None:
+ params["max_tokens"] = max_tokens
+ if top_p is not None:
+ params["top_p"] = top_p
+ if random_seed is not None:
+ params["random_seed"] = random_seed
+ return params
+
  def _make_chat_request(
  self,
  messages: List[Any],
@@ -89,7 +146,6 @@ class ClientBase(ABC):
  ) -> Dict[str, Any]:
  request_data: Dict[str, Any] = {
  "messages": self._parse_messages(messages),
- "safe_prompt": safe_prompt,
  }

  if model is not None:
@@ -99,16 +155,16 @@ class ClientBase(ABC):
  raise MistralException(message="model must be provided")
  request_data["model"] = self._default_model

+ request_data.update(
+ self._build_sampling_params(
+ temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+ )
+ )
+
+ if safe_prompt:
+ request_data["safe_prompt"] = safe_prompt
  if tools is not None:
  request_data["tools"] = self._parse_tools(tools)
- if temperature is not None:
- request_data["temperature"] = temperature
- if max_tokens is not None:
- request_data["max_tokens"] = max_tokens
- if top_p is not None:
- request_data["top_p"] = top_p
- if random_seed is not None:
- request_data["random_seed"] = random_seed
  if stream is not None:
  request_data["stream"] = stream

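Taken together, these client_base hunks factor the optional sampling parameters out into _build_sampling_params, reuse that helper from both _make_chat_request and the new _make_completion_request, and stop sending safe_prompt unless it is truthy. As a sketch (values are illustrative, not part of the diff), a payload assembled by _make_completion_request looks roughly like:

request_data = {
    "prompt": "def add(a, b):",
    "suffix": None,
    "model": "codestral-latest",
    "stream": False,
    # Added only when provided, via _build_sampling_params:
    "temperature": 0.2,
    "max_tokens": 64,
}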
@@ -0,0 +1,84 @@
+ from typing import Any
+
+ from mistralai.exceptions import (
+ MistralException,
+ )
+ from mistralai.models.files import FileDeleted, FileObject, Files
+
+
+ class FilesClient:
+ def __init__(self, client: Any):
+ self.client = client
+
+ def create(
+ self,
+ file: bytes,
+ purpose: str = "fine-tune",
+ ) -> FileObject:
+ single_response = self.client._request(
+ "post",
+ None,
+ "v1/files",
+ files={"file": file},
+ data={"purpose": purpose},
+ )
+ for response in single_response:
+ return FileObject(**response)
+ raise MistralException("No response received")
+
+ def retrieve(self, file_id: str) -> FileObject:
+ single_response = self.client._request("get", {}, f"v1/files/{file_id}")
+ for response in single_response:
+ return FileObject(**response)
+ raise MistralException("No response received")
+
+ def list(self) -> Files:
+ single_response = self.client._request("get", {}, "v1/files")
+ for response in single_response:
+ return Files(**response)
+ raise MistralException("No response received")
+
+ def delete(self, file_id: str) -> FileDeleted:
+ single_response = self.client._request("delete", {}, f"v1/files/{file_id}")
+ for response in single_response:
+ return FileDeleted(**response)
+ raise MistralException("No response received")
+
+
+ class FilesAsyncClient:
+ def __init__(self, client: Any):
+ self.client = client
+
+ async def create(
+ self,
+ file: bytes,
+ purpose: str = "fine-tune",
+ ) -> FileObject:
+ single_response = self.client._request(
+ "post",
+ None,
+ "v1/files",
+ files={"file": file},
+ data={"purpose": purpose},
+ )
+ async for response in single_response:
+ return FileObject(**response)
+ raise MistralException("No response received")
+
+ async def retrieve(self, file_id: str) -> FileObject:
+ single_response = self.client._request("get", {}, f"v1/files/{file_id}")
+ async for response in single_response:
+ return FileObject(**response)
+ raise MistralException("No response received")
+
+ async def list(self) -> Files:
+ single_response = self.client._request("get", {}, "v1/files")
+ async for response in single_response:
+ return Files(**response)
+ raise MistralException("No response received")
+
+ async def delete(self, file_id: str) -> FileDeleted:
+ single_response = self.client._request("delete", {}, f"v1/files/{file_id}")
+ async for response in single_response:
+ return FileDeleted(**response)
+ raise MistralException("No response received")
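The new FilesClient and FilesAsyncClient are attached to the existing clients as client.files and wrap the v1/files endpoints; uploads go out as multipart form data, which is why _request now accepts data= and extra keyword arguments such as files=. A minimal sketch of uploading and listing training data, not part of the diff (the file path is hypothetical):

import os

from mistralai.client import MistralClient

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# Upload a JSONL training file; purpose defaults to "fine-tune".
with open("training_data.jsonl", "rb") as fh:
    uploaded = client.files.create(file=fh.read())

print(uploaded.id, uploaded.filename)
print([f.id for f in client.files.list().data])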
@@ -0,0 +1,172 @@
+ from datetime import datetime
+ from typing import Any, Optional, Union
+
+ from mistralai.exceptions import (
+ MistralException,
+ )
+ from mistralai.models.jobs import DetailedJob, IntegrationIn, Job, JobMetadata, JobQueryFilter, Jobs, TrainingParameters
+
+
+ class JobsClient:
+ def __init__(self, client: Any):
+ self.client = client
+
+ def create(
+ self,
+ model: str,
+ training_files: Union[list[str], None] = None,
+ validation_files: Union[list[str], None] = None,
+ hyperparameters: TrainingParameters = TrainingParameters(
+ training_steps=1800,
+ learning_rate=1.0e-4,
+ ),
+ suffix: Union[str, None] = None,
+ integrations: Union[set[IntegrationIn], None] = None,
+ training_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
+ validation_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
+ dry_run: bool = False,
+ ) -> Union[Job, JobMetadata]:
+ # Handle deprecated arguments
+ if not training_files and training_file:
+ training_files = [training_file]
+ if not validation_files and validation_file:
+ validation_files = [validation_file]
+ single_response = self.client._request(
+ method="post",
+ json={
+ "model": model,
+ "training_files": training_files,
+ "validation_files": validation_files,
+ "hyperparameters": hyperparameters.dict(),
+ "suffix": suffix,
+ "integrations": integrations,
+ },
+ path="v1/fine_tuning/jobs",
+ params={"dry_run": dry_run},
+ )
+ for response in single_response:
+ return Job(**response) if not dry_run else JobMetadata(**response)
+ raise MistralException("No response received")
+
+ def retrieve(self, job_id: str) -> DetailedJob:
+ single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
+ for response in single_response:
+ return DetailedJob(**response)
+ raise MistralException("No response received")
+
+ def list(
+ self,
+ page: int = 0,
+ page_size: int = 10,
+ model: Optional[str] = None,
+ created_after: Optional[datetime] = None,
+ created_by_me: Optional[bool] = None,
+ status: Optional[str] = None,
+ wandb_project: Optional[str] = None,
+ wandb_name: Optional[str] = None,
+ suffix: Optional[str] = None,
+ ) -> Jobs:
+ query_params = JobQueryFilter(
+ page=page,
+ page_size=page_size,
+ model=model,
+ created_after=created_after,
+ created_by_me=created_by_me,
+ status=status,
+ wandb_project=wandb_project,
+ wandb_name=wandb_name,
+ suffix=suffix,
+ ).model_dump(exclude_none=True)
+ single_response = self.client._request(method="get", params=query_params, path="v1/fine_tuning/jobs", json={})
+ for response in single_response:
+ return Jobs(**response)
+ raise MistralException("No response received")
+
+ def cancel(self, job_id: str) -> DetailedJob:
+ single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
+ for response in single_response:
+ return DetailedJob(**response)
+ raise MistralException("No response received")
+
+
+ class JobsAsyncClient:
+ def __init__(self, client: Any):
+ self.client = client
+
+ async def create(
+ self,
+ model: str,
+ training_files: Union[list[str], None] = None,
+ validation_files: Union[list[str], None] = None,
+ hyperparameters: TrainingParameters = TrainingParameters(
+ training_steps=1800,
+ learning_rate=1.0e-4,
+ ),
+ suffix: Union[str, None] = None,
+ integrations: Union[set[IntegrationIn], None] = None,
+ training_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
+ validation_file: Union[str, None] = None, # Deprecated: Added for compatibility with OpenAI API
+ dry_run: bool = False,
+ ) -> Union[Job, JobMetadata]:
+ # Handle deprecated arguments
+ if not training_files and training_file:
+ training_files = [training_file]
+ if not validation_files and validation_file:
+ validation_files = [validation_file]
+
+ single_response = self.client._request(
+ method="post",
+ json={
+ "model": model,
+ "training_files": training_files,
+ "validation_files": validation_files,
+ "hyperparameters": hyperparameters.dict(),
+ "suffix": suffix,
+ "integrations": integrations,
+ },
+ path="v1/fine_tuning/jobs",
+ params={"dry_run": dry_run},
+ )
+ async for response in single_response:
+ return Job(**response) if not dry_run else JobMetadata(**response)
+ raise MistralException("No response received")
+
+ async def retrieve(self, job_id: str) -> DetailedJob:
+ single_response = self.client._request(method="get", path=f"v1/fine_tuning/jobs/{job_id}", json={})
+ async for response in single_response:
+ return DetailedJob(**response)
+ raise MistralException("No response received")
+
+ async def list(
+ self,
+ page: int = 0,
+ page_size: int = 10,
+ model: Optional[str] = None,
+ created_after: Optional[datetime] = None,
+ created_by_me: Optional[bool] = None,
+ status: Optional[str] = None,
+ wandb_project: Optional[str] = None,
+ wandb_name: Optional[str] = None,
+ suffix: Optional[str] = None,
+ ) -> Jobs:
+ query_params = JobQueryFilter(
+ page=page,
+ page_size=page_size,
+ model=model,
+ created_after=created_after,
+ created_by_me=created_by_me,
+ status=status,
+ wandb_project=wandb_project,
+ wandb_name=wandb_name,
+ suffix=suffix,
+ ).model_dump(exclude_none=True)
+ single_response = self.client._request(method="get", path="v1/fine_tuning/jobs", params=query_params, json={})
+ async for response in single_response:
+ return Jobs(**response)
+ raise MistralException("No response received")
+
+ async def cancel(self, job_id: str) -> DetailedJob:
+ single_response = self.client._request(method="post", path=f"v1/fine_tuning/jobs/{job_id}/cancel", json={})
+ async for response in single_response:
+ return DetailedJob(**response)
+ raise MistralException("No response received")
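JobsClient and JobsAsyncClient are exposed as client.jobs and wrap the v1/fine_tuning/jobs endpoints. A minimal sketch of creating and monitoring a fine-tuning job, not part of the diff (the file ID and model name are illustrative; the ID would come from a prior client.files.create call):

import os

from mistralai.client import MistralClient
from mistralai.models.jobs import TrainingParameters

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

training_file_id = "file-abc123"  # illustrative ID returned by client.files.create(...)

# Create a fine-tuning job; with dry_run=False (the default) a Job is returned.
job = client.jobs.create(
    model="open-mistral-7b",
    training_files=[training_file_id],
    hyperparameters=TrainingParameters(training_steps=100, learning_rate=1.0e-4),
)

# Poll its status and list recent jobs.
print(client.jobs.retrieve(job.id).status)
print([j.id for j in client.jobs.list(page_size=5).data])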
@@ -1,5 +1,5 @@
  from enum import Enum
- from typing import List, Optional, Union
+ from typing import List, Optional

  from pydantic import BaseModel

@@ -44,7 +44,7 @@ class ResponseFormat(BaseModel):

  class ChatMessage(BaseModel):
  role: str
- content: Union[str, List[str]]
+ content: str
  name: Optional[str] = None
  tool_calls: Optional[List[ToolCall]] = None
  tool_call_id: Optional[str] = None
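ChatMessage.content is now typed as a plain str rather than Union[str, List[str]], so callers that previously passed a list of strings need to join the parts into a single string first. A minimal sketch (illustrative content):

from mistralai.models.chat_completion import ChatMessage

# Valid in 0.4.0: content must be a single string.
msg = ChatMessage(role="user", content="Write a haiku about type checking.")

# Content previously passed as a list should be joined before constructing the message.
parts = ["Write a haiku", "about type checking."]
msg = ChatMessage(role="user", content=" ".join(parts))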
@@ -0,0 +1,23 @@
+ from typing import Literal, Optional
+
+ from pydantic import BaseModel
+
+
+ class FileObject(BaseModel):
+ id: str
+ object: str
+ bytes: int
+ created_at: int
+ filename: str
+ purpose: Optional[Literal["fine-tune"]] = "fine-tune"
+
+
+ class FileDeleted(BaseModel):
+ id: str
+ object: str
+ deleted: bool
+
+
+ class Files(BaseModel):
+ data: list[FileObject]
+ object: Literal["list"]
@@ -0,0 +1,98 @@
+ from datetime import datetime
+ from typing import Annotated, List, Literal, Optional, Union
+
+ from pydantic import BaseModel, Field
+
+
+ class TrainingParameters(BaseModel):
+ training_steps: int = Field(1800, le=10000, ge=1)
+ learning_rate: float = Field(1.0e-4, le=1, ge=1.0e-8)
+
+
+ class WandbIntegration(BaseModel):
+ type: Literal["wandb"] = "wandb"
+ project: str
+ name: Union[str, None] = None
+ run_name: Union[str, None] = None
+
+
+ class WandbIntegrationIn(WandbIntegration):
+ api_key: str
+
+
+ Integration = Annotated[Union[WandbIntegration], Field(discriminator="type")]
+ IntegrationIn = Annotated[Union[WandbIntegrationIn], Field(discriminator="type")]
+
+
+ class JobMetadata(BaseModel):
+ object: Literal["job.metadata"] = "job.metadata"
+ training_steps: int
+ train_tokens_per_step: int
+ data_tokens: int
+ train_tokens: int
+ epochs: float
+ expected_duration_seconds: Optional[int]
+
+
+ class Job(BaseModel):
+ id: str
+ hyperparameters: TrainingParameters
+ fine_tuned_model: Union[str, None]
+ model: str
+ status: Literal[
+ "QUEUED",
+ "STARTED",
+ "RUNNING",
+ "FAILED",
+ "SUCCESS",
+ "CANCELLED",
+ "CANCELLATION_REQUESTED",
+ ]
+ job_type: str
+ created_at: int
+ modified_at: int
+ training_files: list[str]
+ validation_files: Union[list[str], None] = []
+ object: Literal["job"]
+ integrations: List[Integration] = []
+
+
+ class Event(BaseModel):
+ name: str
+ data: Union[dict, None] = None
+ created_at: int
+
+
+ class Metric(BaseModel):
+ train_loss: Union[float, None] = None
+ valid_loss: Union[float, None] = None
+ valid_mean_token_accuracy: Union[float, None] = None
+
+
+ class Checkpoint(BaseModel):
+ metrics: Metric
+ step_number: int
+ created_at: int
+
+
+ class JobQueryFilter(BaseModel):
+ page: int = 0
+ page_size: int = 100
+ model: Optional[str] = None
+ created_after: Optional[datetime] = None
+ created_by_me: Optional[bool] = None
+ status: Optional[str] = None
+ wandb_project: Optional[str] = None
+ wandb_name: Optional[str] = None
+ suffix: Optional[str] = None
+
+
+ class DetailedJob(Job):
+ events: list[Event] = []
+ checkpoints: list[Checkpoint] = []
+ estimated_start_time: Optional[int] = None
+
+
+ class Jobs(BaseModel):
+ data: list[Job] = []
+ object: Literal["list"]
@@ -7,15 +7,15 @@ class ModelPermission(BaseModel):
  id: str
  object: str
  created: int
- allow_create_engine: bool = False
+ allow_create_engine: Optional[bool] = False
  allow_sampling: bool = True
  allow_logprobs: bool = True
- allow_search_indices: bool = False
+ allow_search_indices: Optional[bool] = False
  allow_view: bool = True
  allow_fine_tuning: bool = False
  organization: str = "*"
  group: Optional[str] = None
- is_blocking: bool = False
+ is_blocking: Optional[bool] = False


  class ModelCard(BaseModel):