mistralai-0.4.1-py3-none-any.whl → mistralai-0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
mistralai/async_client.py CHANGED
@@ -1,7 +1,7 @@
 import asyncio
 import posixpath
 from json import JSONDecodeError
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union
 
 from httpx import (
     AsyncClient,
@@ -101,6 +101,7 @@ class MistralAsyncClient(ClientBase):
         stream: bool = False,
         attempt: int = 1,
         data: Optional[Dict[str, Any]] = None,
+        check_model_deprecation_headers_callback: Optional[Callable] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[Dict[str, Any], None]:
         accept_header = "text/event-stream" if stream else "application/json"
@@ -129,6 +130,8 @@ class MistralAsyncClient(ClientBase):
                     data=data,
                     **kwargs,
                 ) as response:
+                    if check_model_deprecation_headers_callback:
+                        check_model_deprecation_headers_callback(response.headers)
                     await self._check_streaming_response(response)
 
                     async for line in response.aiter_lines():
@@ -145,7 +148,8 @@ class MistralAsyncClient(ClientBase):
                     data=data,
                     **kwargs,
                 )
-
+                if check_model_deprecation_headers_callback:
+                    check_model_deprecation_headers_callback(response.headers)
                 yield await self._check_response(response)
 
         except ConnectError as e:
@@ -213,7 +217,12 @@ class MistralAsyncClient(ClientBase):
             response_format=response_format,
         )
 
-        single_response = self._request("post", request, "v1/chat/completions")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         async for response in single_response:
             return ChatCompletionResponse(**response)
@@ -267,7 +276,13 @@ class MistralAsyncClient(ClientBase):
             tool_choice=tool_choice,
             response_format=response_format,
         )
-        async_response = self._request("post", request, "v1/chat/completions", stream=True)
+        async_response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         async for json_response in async_response:
             yield ChatCompletionStreamResponse(**json_response)
@@ -284,7 +299,12 @@ class MistralAsyncClient(ClientBase):
             EmbeddingResponse: A response object containing the embeddings.
         """
         request = {"model": model, "input": input}
-        single_response = self._request("post", request, "v1/embeddings")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/embeddings",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         async for response in single_response:
             return EmbeddingResponse(**response)
@@ -341,7 +361,12 @@ class MistralAsyncClient(ClientBase):
         request = self._make_completion_request(
             prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
         )
-        single_response = self._request("post", request, "v1/fim/completions")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         async for response in single_response:
             return ChatCompletionResponse(**response)
@@ -376,9 +401,23 @@
             Dict[str, Any]: a response object containing the generated text.
         """
         request = self._make_completion_request(
-            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+            prompt,
+            model,
+            suffix,
+            temperature,
+            max_tokens,
+            top_p,
+            random_seed,
+            stop,
+            stream=True,
+        )
+        async_response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
         )
-        async_response = self._request("post", request, "v1/fim/completions", stream=True)
 
         async for json_response in async_response:
             yield ChatCompletionStreamResponse(**json_response)
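
Every request path above gains the same hook: when check_model_deprecation_headers_callback is provided, it is invoked with the response headers before the body is consumed, for streaming and non-streaming requests alike. A minimal, self-contained sketch of that injection pattern (FakeResponse and request_like are illustrative stand-ins, not part of the package):

# Sketch of the optional-callback pattern used by _request above.
from typing import Any, Callable, Dict, Iterator, Optional


class FakeResponse:
    """Illustrative stand-in for an httpx response."""

    def __init__(self, headers: Dict[str, str], payload: Dict[str, Any]):
        self.headers = headers
        self.payload = payload


def request_like(
    response: FakeResponse,
    check_headers_callback: Optional[Callable] = None,
) -> Iterator[Dict[str, Any]]:
    # When provided, the callback sees the headers before the body
    # is processed, mirroring how _request calls it ahead of
    # _check_response / _check_streaming_response.
    if check_headers_callback:
        check_headers_callback(response.headers)
    yield response.payload


headers_seen = []
resp = FakeResponse(
    {"x-model-deprecation-timestamp": "2024-09-30"},
    {"id": "cmpl-123"},
)
for body in request_like(resp, headers_seen.append):
    print(body)       # {'id': 'cmpl-123'}
print(headers_seen)   # [{'x-model-deprecation-timestamp': '2024-09-30'}]

Passing the callback as a keyword argument keeps the hook optional, so existing callers of _request are unaffected.
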
mistralai/client.py CHANGED
@@ -1,7 +1,7 @@
 import posixpath
 import time
 from json import JSONDecodeError
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Union
 
 from httpx import Client, ConnectError, HTTPTransport, RequestError, Response
 
@@ -40,7 +40,9 @@ class MistralClient(ClientBase):
         super().__init__(endpoint, api_key, max_retries, timeout)
 
         self._client = Client(
-            follow_redirects=True, timeout=self._timeout, transport=HTTPTransport(retries=self._max_retries)
+            follow_redirects=True,
+            timeout=self._timeout,
+            transport=HTTPTransport(retries=self._max_retries),
         )
         self.files = FilesClient(self)
         self.jobs = JobsClient(self)
@@ -94,6 +96,7 @@ class MistralClient(ClientBase):
         stream: bool = False,
         attempt: int = 1,
         data: Optional[Dict[str, Any]] = None,
+        check_model_deprecation_headers_callback: Optional[Callable] = None,
         **kwargs: Any,
     ) -> Iterator[Dict[str, Any]]:
         accept_header = "text/event-stream" if stream else "application/json"
@@ -122,6 +125,8 @@ class MistralClient(ClientBase):
                     data=data,
                     **kwargs,
                 ) as response:
+                    if check_model_deprecation_headers_callback:
+                        check_model_deprecation_headers_callback(response.headers)
                     self._check_streaming_response(response)
 
                     for line in response.iter_lines():
@@ -138,7 +143,8 @@ class MistralClient(ClientBase):
                     data=data,
                     **kwargs,
                 )
-
+                if check_model_deprecation_headers_callback:
+                    check_model_deprecation_headers_callback(response.headers)
                 yield self._check_response(response)
 
         except ConnectError as e:
@@ -207,7 +213,12 @@ class MistralClient(ClientBase):
             response_format=response_format,
         )
 
-        single_response = self._request("post", request, "v1/chat/completions")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         for response in single_response:
             return ChatCompletionResponse(**response)
@@ -261,7 +272,13 @@ class MistralClient(ClientBase):
             response_format=response_format,
         )
 
-        response = self._request("post", request, "v1/chat/completions", stream=True)
+        response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         for json_streamed_response in response:
             yield ChatCompletionStreamResponse(**json_streamed_response)
@@ -278,7 +295,12 @@ class MistralClient(ClientBase):
             EmbeddingResponse: A response object containing the embeddings.
         """
         request = {"model": model, "input": input}
-        singleton_response = self._request("post", request, "v1/embeddings")
+        singleton_response = self._request(
+            "post",
+            request,
+            "v1/embeddings",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         for response in singleton_response:
             return EmbeddingResponse(**response)
@@ -337,7 +359,13 @@ class MistralClient(ClientBase):
             prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
         )
 
-        single_response = self._request("post", request, "v1/fim/completions", stream=False)
+        single_response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            stream=False,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         for response in single_response:
             return ChatCompletionResponse(**response)
@@ -372,10 +400,24 @@ class MistralClient(ClientBase):
             Iterable[Dict[str, Any]]: a generator that yields response objects containing the generated text.
         """
         request = self._make_completion_request(
-            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+            prompt,
+            model,
+            suffix,
+            temperature,
+            max_tokens,
+            top_p,
+            random_seed,
+            stop,
+            stream=True,
        )
 
-        response = self._request("post", request, "v1/fim/completions", stream=True)
+        response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )
 
         for json_streamed_response in response:
             yield ChatCompletionStreamResponse(**json_streamed_response)
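
From the caller's perspective nothing changes: chat, chat_stream, embeddings, completion, and completion_stream all wire the deprecation check in internally. A usage sketch (assumes a valid MISTRAL_API_KEY in the environment; mistral-tiny is used purely as an example model name):

import logging
import os

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

# The deprecation notice is emitted via the client's logger at
# WARNING level, so it is visible with default logging config.
logging.basicConfig(level=logging.WARNING)

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# If the API answers with an x-model-deprecation-timestamp header,
# the client logs a warning of the form added in client_base.py:
#   WARNING: The model mistral-tiny is deprecated and will be
#   removed on <timestamp>. Please refer to
#   https://docs.mistral.ai/getting-started/models/#api-versioning
#   for more information.
response = client.chat(
    model="mistral-tiny",
    messages=[ChatMessage(role="user", content="Say hello.")],
)
print(response.choices[0].message.content)
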
mistralai/client_base.py CHANGED
@@ -1,16 +1,21 @@
 import logging
 import os
 from abc import ABC
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import orjson
-
-from mistralai.exceptions import (
-    MistralException,
+from httpx import Headers
+
+from mistralai.constants import HEADER_MODEL_DEPRECATION_TIMESTAMP
+from mistralai.exceptions import MistralException
+from mistralai.models.chat_completion import (
+    ChatMessage,
+    Function,
+    ResponseFormat,
+    ToolChoice,
 )
-from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
 
-CLIENT_VERSION = "0.4.1"
+CLIENT_VERSION = "0.4.2"
 
 
 class ClientBase(ABC):
@@ -38,6 +43,14 @@ class ClientBase(ABC):
 
         self._version = CLIENT_VERSION
 
+    def _get_model(self, model: Optional[str] = None) -> str:
+        if model is not None:
+            return model
+        else:
+            if self._default_model is None:
+                raise MistralException(message="model must be provided")
+            return self._default_model
+
     def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         parsed_tools: List[Dict[str, Any]] = []
         for tool in tools:
@@ -73,6 +86,22 @@
 
         return parsed_messages
 
+    def _check_model_deprecation_header_callback_factory(self, model: Optional[str] = None) -> Callable:
+        model = self._get_model(model)
+
+        def _check_model_deprecation_header_callback(
+            headers: Headers,
+        ) -> None:
+            if HEADER_MODEL_DEPRECATION_TIMESTAMP in headers:
+                self._logger.warning(
+                    f"WARNING: The model {model} is deprecated "
+                    f"and will be removed on {headers[HEADER_MODEL_DEPRECATION_TIMESTAMP]}. "
+                    "Please refer to https://docs.mistral.ai/getting-started/models/#api-versioning "
+                    "for more information."
+                )
+
+        return _check_model_deprecation_header_callback
+
     def _make_completion_request(
         self,
         prompt: str,
@@ -95,16 +124,14 @@
         if stop is not None:
             request_data["stop"] = stop
 
-        if model is not None:
-            request_data["model"] = model
-        else:
-            if self._default_model is None:
-                raise MistralException(message="model must be provided")
-            request_data["model"] = self._default_model
+        request_data["model"] = self._get_model(model)
 
         request_data.update(
             self._build_sampling_params(
-                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                random_seed=random_seed,
             )
         )
 
@@ -148,16 +175,14 @@
             "messages": self._parse_messages(messages),
         }
 
-        if model is not None:
-            request_data["model"] = model
-        else:
-            if self._default_model is None:
-                raise MistralException(message="model must be provided")
-            request_data["model"] = self._default_model
+        request_data["model"] = self._get_model(model)
 
         request_data.update(
             self._build_sampling_params(
-                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                random_seed=random_seed,
            )
         )
 
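
The factory resolves the model name once via _get_model, then returns a closure that needs only the response headers, which is what lets _request stay agnostic about why the headers are being inspected. A standalone sketch of the same closure shape (the logger name and the timestamp value are illustrative):

import logging

from httpx import Headers

HEADER_MODEL_DEPRECATION_TIMESTAMP = "x-model-deprecation-timestamp"

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("deprecation-demo")


def callback_factory(model: str):
    # Capture the resolved model name once, the way
    # _check_model_deprecation_header_callback_factory does.
    def check(headers: Headers) -> None:
        if HEADER_MODEL_DEPRECATION_TIMESTAMP in headers:
            logger.warning(
                f"The model {model} is deprecated and will be removed "
                f"on {headers[HEADER_MODEL_DEPRECATION_TIMESTAMP]}."
            )

    return check


check = callback_factory("mistral-tiny")
# httpx.Headers lookups are case-insensitive, so a server sending
# X-Model-Deprecation-Timestamp still matches the lowercase constant.
check(Headers({"X-Model-Deprecation-Timestamp": "2024-09-30"}))
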
mistralai/constants.py CHANGED
@@ -1,3 +1,5 @@
 RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
 
 ENDPOINT = "https://api.mistral.ai"
+
+HEADER_MODEL_DEPRECATION_TIMESTAMP = "x-model-deprecation-timestamp"
mistralai/models/jobs.py CHANGED
@@ -32,6 +32,8 @@ class JobMetadata(BaseModel):
     train_tokens: int
     epochs: float
     expected_duration_seconds: Optional[int]
+    cost: Optional[float] = None
+    cost_currency: Optional[str] = None
 
 
 class Job(BaseModel):
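
Both new fields default to None, so fine-tuning job payloads from servers that do not yet report cost information still validate. A trimmed stand-in for JobMetadata demonstrating that (the field set is reduced to the lines visible in this hunk, not the full model):

from typing import Optional

from pydantic import BaseModel


class JobMetadataDemo(BaseModel):
    # Reduced stand-in for mistralai.models.jobs.JobMetadata.
    train_tokens: int
    epochs: float
    expected_duration_seconds: Optional[int]
    cost: Optional[float] = None           # new in 0.4.2
    cost_currency: Optional[str] = None    # new in 0.4.2


# An old-style payload without cost information still parses...
old = JobMetadataDemo(
    train_tokens=1_000_000, epochs=3.0, expected_duration_seconds=600
)
print(old.cost, old.cost_currency)  # None None

# ...while new payloads carry the extra fields.
new = JobMetadataDemo(
    train_tokens=1_000_000,
    epochs=3.0,
    expected_duration_seconds=600,
    cost=12.5,
    cost_currency="EUR",
)
print(new.cost, new.cost_currency)  # 12.5 EUR
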
{mistralai-0.4.1.dist-info → mistralai-0.4.2.dist-info}/METADATA RENAMED
@@ -1,10 +1,12 @@
 Metadata-Version: 2.1
 Name: mistralai
-Version: 0.4.1
+Version: 0.4.2
 Summary:
+License: Apache 2.0 License
 Author: Bam4d
 Author-email: bam4d@mistral.ai
 Requires-Python: >=3.9,<4.0
+Classifier: License :: Other/Proprietary License
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
{mistralai-0.4.1.dist-info → mistralai-0.4.2.dist-info}/RECORD RENAMED
@@ -1,8 +1,8 @@
 mistralai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mistralai/async_client.py,sha256=nqD8iYAttILXEL0DYqps8-83VixeuQN7PQ5N1_ZfAqo,15696
-mistralai/client.py,sha256=jrb_ucBJGPnyj1HPUBQjkUQuTTvohokjxSNPP4xnCqg,15399
-mistralai/client_base.py,sha256=oAPqGMBLxQUGd_gTOxsFpNdpzlYNCgJ84EouZMSpdgI,6387
-mistralai/constants.py,sha256=FvokZPfTBC-DC6-HfiV83pD3FP6huHk3SUIl0yx5jx8,84
+mistralai/async_client.py,sha256=YgS67borTql5f9VcuMxdDbxxc2v5f74peKu9W-ktS_s,17001
+mistralai/client.py,sha256=kumm79_-9ljrUEpHRQEfLr-sRP5r0bwLnySSsVVdk5M,16741
+mistralai/client_base.py,sha256=obzi3F1VBHN9JwSIJPp_hYlryeVHOGxO3-osJVRRelM,7209
+mistralai/constants.py,sha256=AtACPXuky6_0srsDP4AcpeFa2PyH96RSQxPSUM3rORA,154
 mistralai/exceptions.py,sha256=R3pswvZyY5CuSbqhVklgfGPVJoz7T7l2VQKMOXK229A,1652
 mistralai/files.py,sha256=5ZyQaozyOCGE2H2OToMNoNETPi9LIWSACJ4lOrx-DZU,2816
 mistralai/jobs.py,sha256=I1Ko6fS-6l_9RqEtpB_z7p0bw84aJf4CQ_ZrCjoV2lg,6853
@@ -11,10 +11,10 @@ mistralai/models/chat_completion.py,sha256=Jn8A9SzH7F1rKVjryUKabtYksEeWWegwHreti
 mistralai/models/common.py,sha256=zatP4aV_LIEpzj3_igsKkJBICwGhmXG0LX3CdO3kn-o,172
 mistralai/models/embeddings.py,sha256=-VthLQBj6wrq7HXJbGmnkQEEanSemA3MAlaMFh94VBg,331
 mistralai/models/files.py,sha256=74HzLUqrSoGj13TmC0u5sYyE7DUyDPxtL_LSDdKONl8,398
-mistralai/models/jobs.py,sha256=hljCBiHtpMW7pLIq_GjCH8X1lpd__NoZ5st9aouMYiA,2340
+mistralai/models/jobs.py,sha256=peUAsljxJZfRSFh4vgQ5Scub3a1ma68DpkqCp8qrmks,2413
 mistralai/models/models.py,sha256=q7sZvEYKtaP4x_GGEHE4GUQDsLV_rOt_M_MQkgNems4,823
 mistralai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-mistralai-0.4.1.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-mistralai-0.4.1.dist-info/METADATA,sha256=NhUcYQKSsxjbT-ynrBq3lyBnlv3fOIpAQm9oXlG1rFI,1824
-mistralai-0.4.1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
-mistralai-0.4.1.dist-info/RECORD,,
+mistralai-0.4.2.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+mistralai-0.4.2.dist-info/METADATA,sha256=HC_8vxzUDo0J6F-Iwmwi5kSP5wWuGFbnsgvxSgPSl-I,1901
+mistralai-0.4.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+mistralai-0.4.2.dist-info/RECORD,,