mistralai 0.4.0__tar.gz → 0.4.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- PKG-INFO
+++ PKG-INFO
@@ -1,18 +1,20 @@
  Metadata-Version: 2.1
  Name: mistralai
- Version: 0.4.0
+ Version: 0.4.2
  Summary:
+ License: Apache 2.0 License
  Author: Bam4d
  Author-email: bam4d@mistral.ai
  Requires-Python: >=3.9,<4.0
+ Classifier: License :: Other/Proprietary License
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
- Requires-Dist: httpx (>=0.25,<0.26)
+ Requires-Dist: httpx (>=0.25,<1)
  Requires-Dist: orjson (>=3.9.10,<3.11)
- Requires-Dist: pydantic (>=2.5.2,<3.0.0)
+ Requires-Dist: pydantic (>=2.5.2,<3)
  Description-Content-Type: text/markdown
 
  # Mistral Python Client
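
A quick way to sanity-check these metadata changes is to read them back from an installed copy. A minimal sketch using only the standard library (it assumes mistralai 0.4.2 is installed in the current environment):

    from importlib.metadata import metadata, version

    print(version("mistralai"))               # 0.4.2
    pkg_meta = metadata("mistralai")
    print(pkg_meta["License"])                # Apache 2.0 License
    print(pkg_meta.get_all("Requires-Dist"))  # should include 'httpx (>=0.25,<1)'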

--- pyproject.toml
+++ pyproject.toml
@@ -1,9 +1,10 @@
  [tool.poetry]
  name = "mistralai"
- version = "0.4.0"
+ version = "0.4.2"
  description = ""
  authors = ["Bam4d <bam4d@mistral.ai>"]
  readme = "README.md"
+ license = "Apache 2.0 License"
 
  [tool.ruff]
  select = ["E", "F", "W", "Q", "I"]
@@ -23,10 +24,10 @@ exclude = ["docs", "tests", "examples", "tools", "build"]
 
 
  [tool.poetry.dependencies]
- python = "^3.9,<4.0"
- orjson = "^3.9.10,<3.11"
- pydantic = "^2.5.2,<3"
- httpx = "^0.25,<1"
+ python = ">=3.9,<4.0"
+ orjson = ">=3.9.10,<3.11"
+ pydantic = ">=2.5.2,<3"
+ httpx = ">=0.25,<1"
 
 
  [tool.poetry.group.dev.dependencies]
@@ -39,5 +40,3 @@ pytest-asyncio = "^0.23.2"
  [build-system]
  requires = ["poetry-core"]
  build-backend = "poetry.core.masonry.api"
-
-
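
The dependency block is the substantive change here. 0.4.0 mixed Poetry's caret operator with explicit caps: `^0.25,<1` resolves to `>=0.25,<0.26`, as the old PKG-INFO above shows. 0.4.2 switches to plain range constraints, which effectively lifts the httpx ceiling from `<0.26` to `<1`. A sketch of what that widening admits, using the third-party `packaging` library purely for illustration (it is not a dependency of mistralai):

    from packaging.specifiers import SpecifierSet

    old_httpx = SpecifierSet(">=0.25,<0.26")  # effective constraint in 0.4.0
    new_httpx = SpecifierSet(">=0.25,<1")     # constraint in 0.4.2

    for candidate in ("0.25.2", "0.27.0", "1.0.0"):
        print(candidate, candidate in old_httpx, candidate in new_httpx)
    # 0.25.2 True True
    # 0.27.0 False True
    # 1.0.0 False False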

--- mistralai/async_client.py
+++ mistralai/async_client.py
@@ -1,7 +1,7 @@
  import asyncio
  import posixpath
  from json import JSONDecodeError
- from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+ from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
 
  from httpx import (
      AsyncClient,
@@ -29,7 +29,7 @@ from mistralai.models.chat_completion import (
      ToolChoice,
  )
  from mistralai.models.embeddings import EmbeddingResponse
- from mistralai.models.models import ModelList
+ from mistralai.models.models import ModelDeleted, ModelList
 
 
  class MistralAsyncClient(ClientBase):
@@ -101,6 +101,7 @@ class MistralAsyncClient(ClientBase):
          stream: bool = False,
          attempt: int = 1,
          data: Optional[Dict[str, Any]] = None,
+         check_model_deprecation_headers_callback: Optional[Callable] = None,
          **kwargs: Any,
      ) -> AsyncGenerator[Dict[str, Any], None]:
          accept_header = "text/event-stream" if stream else "application/json"
@@ -129,6 +130,8 @@ class MistralAsyncClient(ClientBase):
                      data=data,
                      **kwargs,
                  ) as response:
+                     if check_model_deprecation_headers_callback:
+                         check_model_deprecation_headers_callback(response.headers)
                      await self._check_streaming_response(response)
 
                      async for line in response.aiter_lines():
@@ -145,7 +148,8 @@ class MistralAsyncClient(ClientBase):
                      data=data,
                      **kwargs,
                  )
-
+                 if check_model_deprecation_headers_callback:
+                     check_model_deprecation_headers_callback(response.headers)
                  yield await self._check_response(response)
 
          except ConnectError as e:
@@ -213,7 +217,12 @@ class MistralAsyncClient(ClientBase):
              response_format=response_format,
          )
 
-         single_response = self._request("post", request, "v1/chat/completions")
+         single_response = self._request(
+             "post",
+             request,
+             "v1/chat/completions",
+             check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+         )
 
          async for response in single_response:
              return ChatCompletionResponse(**response)
@@ -267,7 +276,13 @@ class MistralAsyncClient(ClientBase):
              tool_choice=tool_choice,
              response_format=response_format,
          )
-         async_response = self._request("post", request, "v1/chat/completions", stream=True)
+         async_response = self._request(
+             "post",
+             request,
+             "v1/chat/completions",
+             stream=True,
+             check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+         )
 
          async for json_response in async_response:
              yield ChatCompletionStreamResponse(**json_response)
@@ -284,7 +299,12 @@ class MistralAsyncClient(ClientBase):
              EmbeddingResponse: A response object containing the embeddings.
          """
          request = {"model": model, "input": input}
-         single_response = self._request("post", request, "v1/embeddings")
+         single_response = self._request(
+             "post",
+             request,
+             "v1/embeddings",
+             check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+         )
 
          async for response in single_response:
              return EmbeddingResponse(**response)
@@ -304,6 +324,14 @@ class MistralAsyncClient(ClientBase):
 
          raise MistralException("No response received")
 
+     async def delete_model(self, model_id: str) -> ModelDeleted:
+         single_response = self._request("delete", {}, f"v1/models/{model_id}")
+
+         async for response in single_response:
+             return ModelDeleted(**response)
+
+         raise MistralException("No response received")
+
      async def completion(
          self,
          model: str,
@@ -333,7 +361,12 @@ class MistralAsyncClient(ClientBase):
          request = self._make_completion_request(
              prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
          )
-         single_response = self._request("post", request, "v1/fim/completions")
+         single_response = self._request(
+             "post",
+             request,
+             "v1/fim/completions",
+             check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+         )
 
          async for response in single_response:
              return ChatCompletionResponse(**response)
@@ -368,9 +401,23 @@ class MistralAsyncClient(ClientBase):
          Dict[str, Any]: a response object containing the generated text.
          """
          request = self._make_completion_request(
-             prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+             prompt,
+             model,
+             suffix,
+             temperature,
+             max_tokens,
+             top_p,
+             random_seed,
+             stop,
+             stream=True,
+         )
+         async_response = self._request(
+             "post",
+             request,
+             "v1/fim/completions",
+             stream=True,
+             check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
          )
-         async_response = self._request("post", request, "v1/fim/completions", stream=True)
 
          async for json_response in async_response:
              yield ChatCompletionStreamResponse(**json_response)
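
The async changes fall into two groups: every model-scoped endpoint (chat, chat_stream, embeddings, completion, completion_stream) now threads a deprecation-check callback into `_request`, and a new `delete_model` method wraps DELETE on `v1/models/{model_id}`. A hedged usage sketch of the new endpoint; the API key and fine-tuned model id are placeholders, not values from this diff:

    import asyncio

    from mistralai.async_client import MistralAsyncClient

    async def main() -> None:
        client = MistralAsyncClient(api_key="YOUR_API_KEY")  # placeholder key
        # delete_model returns a typed ModelDeleted (defined further down in this diff)
        deleted = await client.delete_model("ft:open-mistral-7b:placeholder")
        print(deleted.id, deleted.object, deleted.deleted)

    asyncio.run(main())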

--- mistralai/client.py
+++ mistralai/client.py
@@ -1,7 +1,7 @@
  import posixpath
  import time
  from json import JSONDecodeError
- from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
+ from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Union
 
  from httpx import Client, ConnectError, HTTPTransport, RequestError, Response
 
@@ -22,7 +22,7 @@ from mistralai.models.chat_completion import (
      ToolChoice,
  )
  from mistralai.models.embeddings import EmbeddingResponse
- from mistralai.models.models import ModelList
+ from mistralai.models.models import ModelDeleted, ModelList
 
 
  class MistralClient(ClientBase):
@@ -40,7 +40,9 @@ class MistralClient(ClientBase):
          super().__init__(endpoint, api_key, max_retries, timeout)
 
          self._client = Client(
-             follow_redirects=True, timeout=self._timeout, transport=HTTPTransport(retries=self._max_retries)
+             follow_redirects=True,
+             timeout=self._timeout,
+             transport=HTTPTransport(retries=self._max_retries),
          )
          self.files = FilesClient(self)
          self.jobs = JobsClient(self)
@@ -94,6 +96,7 @@ class MistralClient(ClientBase):
          stream: bool = False,
          attempt: int = 1,
          data: Optional[Dict[str, Any]] = None,
+         check_model_deprecation_headers_callback: Optional[Callable] = None,
          **kwargs: Any,
      ) -> Iterator[Dict[str, Any]]:
          accept_header = "text/event-stream" if stream else "application/json"
@@ -122,6 +125,8 @@ class MistralClient(ClientBase):
                      data=data,
                      **kwargs,
                  ) as response:
+                     if check_model_deprecation_headers_callback:
+                         check_model_deprecation_headers_callback(response.headers)
                      self._check_streaming_response(response)
 
                      for line in response.iter_lines():
@@ -138,7 +143,8 @@ class MistralClient(ClientBase):
                      data=data,
                      **kwargs,
                  )
-
+                 if check_model_deprecation_headers_callback:
+                     check_model_deprecation_headers_callback(response.headers)
                  yield self._check_response(response)
 
          except ConnectError as e:
@@ -207,7 +213,12 @@ class MistralClient(ClientBase):
              response_format=response_format,
          )
 
-         single_response = self._request("post", request, "v1/chat/completions")
+         single_response = self._request(
+             "post",
+             request,
+             "v1/chat/completions",
+             check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+         )
 
          for response in single_response:
              return ChatCompletionResponse(**response)
@@ -261,7 +272,13 @@ class MistralClient(ClientBase):
              response_format=response_format,
          )
 
-         response = self._request("post", request, "v1/chat/completions", stream=True)
+         response = self._request(
+             "post",
+             request,
+             "v1/chat/completions",
+             stream=True,
+             check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+         )
 
          for json_streamed_response in response:
              yield ChatCompletionStreamResponse(**json_streamed_response)
@@ -278,7 +295,12 @@ class MistralClient(ClientBase):
              EmbeddingResponse: A response object containing the embeddings.
          """
          request = {"model": model, "input": input}
-         singleton_response = self._request("post", request, "v1/embeddings")
+         singleton_response = self._request(
+             "post",
+             request,
+             "v1/embeddings",
+             check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+         )
 
          for response in singleton_response:
              return EmbeddingResponse(**response)
@@ -298,6 +320,14 @@ class MistralClient(ClientBase):
 
          raise MistralException("No response received")
 
+     def delete_model(self, model_id: str) -> ModelDeleted:
+         single_response = self._request("delete", {}, f"v1/models/{model_id}")
+
+         for response in single_response:
+             return ModelDeleted(**response)
+
+         raise MistralException("No response received")
+
      def completion(
          self,
          model: str,
@@ -329,7 +359,13 @@ class MistralClient(ClientBase):
              prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
          )
 
-         single_response = self._request("post", request, "v1/fim/completions", stream=False)
+         single_response = self._request(
+             "post",
+             request,
+             "v1/fim/completions",
+             stream=False,
+             check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+         )
 
          for response in single_response:
              return ChatCompletionResponse(**response)
@@ -364,10 +400,24 @@ class MistralClient(ClientBase):
          Iterable[Dict[str, Any]]: a generator that yields response objects containing the generated text.
          """
          request = self._make_completion_request(
-             prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+             prompt,
+             model,
+             suffix,
+             temperature,
+             max_tokens,
+             top_p,
+             random_seed,
+             stop,
+             stream=True,
          )
 
-         response = self._request("post", request, "v1/fim/completions", stream=True)
+         response = self._request(
+             "post",
+             request,
+             "v1/fim/completions",
+             stream=True,
+             check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+         )
 
          for json_streamed_response in response:
              yield ChatCompletionStreamResponse(**json_streamed_response)
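
The sync client mirrors the async changes one for one. Note that call sites do not change: the deprecation warning is a pure side effect of the response headers, as in this sketch (the key and model name are illustrative placeholders):

    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import ChatMessage

    client = MistralClient(api_key="YOUR_API_KEY")  # placeholder key
    response = client.chat(
        model="mistral-tiny",  # illustrative model name
        messages=[ChatMessage(role="user", content="Hello")],
    )
    # If the server marks the model as deprecated, the client logger emits:
    # "WARNING: The model mistral-tiny is deprecated and will be removed on <date> ..."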

--- mistralai/client_base.py
+++ mistralai/client_base.py
@@ -1,16 +1,21 @@
  import logging
  import os
  from abc import ABC
- from typing import Any, Dict, List, Optional, Union
+ from typing import Any, Callable, Dict, List, Optional, Union
 
  import orjson
-
- from mistralai.exceptions import (
-     MistralException,
+ from httpx import Headers
+
+ from mistralai.constants import HEADER_MODEL_DEPRECATION_TIMESTAMP
+ from mistralai.exceptions import MistralException
+ from mistralai.models.chat_completion import (
+     ChatMessage,
+     Function,
+     ResponseFormat,
+     ToolChoice,
  )
- from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
 
- CLIENT_VERSION = "0.4.0"
+ CLIENT_VERSION = "0.4.2"
 
 
  class ClientBase(ABC):
@@ -38,6 +43,14 @@ class ClientBase(ABC):
 
          self._version = CLIENT_VERSION
 
+     def _get_model(self, model: Optional[str] = None) -> str:
+         if model is not None:
+             return model
+         else:
+             if self._default_model is None:
+                 raise MistralException(message="model must be provided")
+             return self._default_model
+
      def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
          parsed_tools: List[Dict[str, Any]] = []
          for tool in tools:
@@ -73,6 +86,22 @@ class ClientBase(ABC):
 
          return parsed_messages
 
+     def _check_model_deprecation_header_callback_factory(self, model: Optional[str] = None) -> Callable:
+         model = self._get_model(model)
+
+         def _check_model_deprecation_header_callback(
+             headers: Headers,
+         ) -> None:
+             if HEADER_MODEL_DEPRECATION_TIMESTAMP in headers:
+                 self._logger.warning(
+                     f"WARNING: The model {model} is deprecated "
+                     f"and will be removed on {headers[HEADER_MODEL_DEPRECATION_TIMESTAMP]}. "
+                     "Please refer to https://docs.mistral.ai/getting-started/models/#api-versioning "
+                     "for more information."
+                 )
+
+         return _check_model_deprecation_header_callback
+
      def _make_completion_request(
          self,
          prompt: str,
@@ -95,16 +124,14 @@ class ClientBase(ABC):
          if stop is not None:
              request_data["stop"] = stop
 
-         if model is not None:
-             request_data["model"] = model
-         else:
-             if self._default_model is None:
-                 raise MistralException(message="model must be provided")
-             request_data["model"] = self._default_model
+         request_data["model"] = self._get_model(model)
 
          request_data.update(
              self._build_sampling_params(
-                 temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 top_p=top_p,
+                 random_seed=random_seed,
              )
          )
 
@@ -148,16 +175,14 @@ class ClientBase(ABC):
              "messages": self._parse_messages(messages),
          }
 
-         if model is not None:
-             request_data["model"] = model
-         else:
-             if self._default_model is None:
-                 raise MistralException(message="model must be provided")
-             request_data["model"] = self._default_model
+         request_data["model"] = self._get_model(model)
 
          request_data.update(
              self._build_sampling_params(
-                 temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+                 temperature=temperature,
+                 max_tokens=max_tokens,
+                 top_p=top_p,
+                 random_seed=random_seed,
              )
          )
 
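
The factory resolves the model name once, via the new `_get_model` helper (which also deduplicates the old inline default-model fallback in the two request builders), and returns a closure that `_request` invokes with the response headers. A sketch that exercises the closure directly with fabricated headers; this pokes at a private method, so it is an illustration of the internals, not supported API:

    import logging

    from httpx import Headers

    from mistralai.client import MistralClient

    logging.basicConfig(level=logging.WARNING)

    client = MistralClient(api_key="YOUR_API_KEY")  # placeholder key
    check = client._check_model_deprecation_header_callback_factory("mistral-tiny")
    check(Headers({"x-model-deprecation-timestamp": "2024-09-30"}))  # fabricated date
    # -> WARNING: The model mistral-tiny is deprecated and will be removed on 2024-09-30. ...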

--- mistralai/constants.py
+++ mistralai/constants.py
@@ -1,3 +1,5 @@
  RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
 
  ENDPOINT = "https://api.mistral.ai"
+
+ HEADER_MODEL_DEPRECATION_TIMESTAMP = "x-model-deprecation-timestamp"

--- mistralai/models/jobs.py
+++ mistralai/models/jobs.py
@@ -32,6 +32,8 @@ class JobMetadata(BaseModel):
      train_tokens: int
      epochs: float
      expected_duration_seconds: Optional[int]
+     cost: Optional[float] = None
+     cost_currency: Optional[str] = None
 
 
  class Job(BaseModel):
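
Both new fields are `Optional` with a `None` default, so job payloads produced before 0.4.2, which lack `cost` and `cost_currency`, still validate. A standalone stand-in model (not the real `JobMetadata`, which has more fields) illustrating the pattern:

    from typing import Optional

    from pydantic import BaseModel

    class Meta(BaseModel):
        train_tokens: int
        cost: Optional[float] = None          # new-style optional field
        cost_currency: Optional[str] = None

    print(Meta(train_tokens=1000).cost)  # None: an old payload still parses
    print(Meta(train_tokens=1000, cost=2.5, cost_currency="EUR").cost_currency)  # EUR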

--- mistralai/models/models.py
+++ mistralai/models/models.py
@@ -31,3 +31,9 @@ class ModelCard(BaseModel):
  class ModelList(BaseModel):
      object: str
      data: List[ModelCard]
+
+
+ class ModelDeleted(BaseModel):
+     id: str
+     object: str
+     deleted: bool
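
`ModelDeleted` is the typed value both clients return from the new `delete_model`. Constructed here from a fabricated payload to show its shape:

    from mistralai.models.models import ModelDeleted

    payload = {"id": "ft:open-mistral-7b:placeholder", "object": "model", "deleted": True}
    print(ModelDeleted(**payload))
    # id='ft:open-mistral-7b:placeholder' object='model' deleted=True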

2 files without changes