mistralai 0.1.8__tar.gz → 0.3.0__tar.gz

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: mistralai
-Version: 0.1.8
+Version: 0.3.0
 Summary:
 Author: Bam4d
 Author-email: bam4d@mistral.ai
@@ -10,7 +10,7 @@ Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
-Requires-Dist: httpx (>=0.25.2,<0.26.0)
+Requires-Dist: httpx (>=0.25.2,<1)
 Requires-Dist: orjson (>=3.9.10,<4.0.0)
 Requires-Dist: pydantic (>=2.5.2,<3.0.0)
 Description-Content-Type: text/markdown
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mistralai"
-version = "0.1.8"
+version = "0.3.0"
 description = ""
 authors = ["Bam4d <bam4d@mistral.ai>"]
 readme = "README.md"
@@ -26,7 +26,7 @@ exclude = ["docs", "tests", "examples", "tools", "build"]
 python = "^3.9"
 orjson = "^3.9.10"
 pydantic = "^2.5.2"
-httpx = "^0.25.2"
+httpx = ">= 0.25.2, < 1"
 
 
 [tool.poetry.group.dev.dependencies]
@@ -1,6 +1,5 @@
-import os
+import asyncio
 import posixpath
-import time
 from json import JSONDecodeError
 from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 
@@ -34,7 +33,7 @@ from mistralai.models.models import ModelList
 class MistralAsyncClient(ClientBase):
     def __init__(
         self,
-        api_key: Optional[str] = os.environ.get("MISTRAL_API_KEY", None),
+        api_key: Optional[str] = None,
         endpoint: str = ENDPOINT,
         max_retries: int = 5,
         timeout: int = 120,
@@ -93,7 +92,7 @@ class MistralAsyncClient(ClientBase):
     async def _request(
         self,
         method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
         path: str,
         stream: bool = False,
         attempt: int = 1,
@@ -151,7 +150,7 @@ class MistralAsyncClient(ClientBase):
             if attempt > self._max_retries:
                 raise MistralAPIStatusException.from_response(response, message=str(e)) from e
             backoff = 2.0**attempt  # exponential backoff
-            time.sleep(backoff)
+            await asyncio.sleep(backoff)
 
             # Retry as a generator
             async for r in self._request(method, json, path, stream=stream, attempt=attempt):
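
The switch from `time.sleep` to `await asyncio.sleep` matters because a blocking sleep inside an async retry path would stall the whole event loop between attempts. Below is a minimal standalone sketch of the same exponential-backoff pattern, not the SDK's actual `_request` method; the function and failure simulation are illustrative only.

```python
import asyncio
import random


async def fetch_with_backoff(attempt: int = 1, max_retries: int = 5) -> str:
    """Illustrative retry loop: back off exponentially without blocking the event loop."""
    while True:
        try:
            # Placeholder for a real request; fails randomly to exercise the retry path.
            if random.random() < 0.5:
                raise ConnectionError("transient failure")
            return "ok"
        except ConnectionError:
            if attempt > max_retries:
                raise
            backoff = 2.0**attempt  # exponential backoff, as in the hunk above
            await asyncio.sleep(backoff)  # yields to the loop; time.sleep() would block every coroutine
            attempt += 1


if __name__ == "__main__":
    print(asyncio.run(fetch_with_backoff()))
```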
@@ -292,3 +291,74 @@ class MistralAsyncClient(ClientBase):
             return ModelList(**response)
 
         raise MistralException("No response received")
+
+    async def completion(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> ChatCompletionResponse:
+        """An asynchronous completion endpoint that returns a single response.
+
+        Args:
+            model (str): the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']
+        Returns:
+            ChatCompletionResponse: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+        )
+        single_response = self._request("post", request, "v1/fim/completions")
+
+        async for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
+
+    async def completion_stream(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
+        """An asynchronous completion endpoint that returns a streaming response.
+
+        Args:
+            model (str): the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']
+
+        Returns:
+            AsyncGenerator[ChatCompletionStreamResponse, None]: a generator that yields response objects containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+        )
+        async_response = self._request("post", request, "v1/fim/completions", stream=True)
+
+        async for json_response in async_response:
+            yield ChatCompletionStreamResponse(**json_response)
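
For context, here is a hedged usage sketch of the new fill-in-the-middle endpoint added above. It assumes the 0.x package layout (`mistralai.async_client.MistralAsyncClient`), that `MISTRAL_API_KEY` is set in the environment, and that `codestral-latest` is an available model; the prompt and suffix values are illustrative.

```python
import asyncio

from mistralai.async_client import MistralAsyncClient


async def main() -> None:
    # The API key is resolved from MISTRAL_API_KEY when not passed explicitly
    # (see the client_base changes later in this diff).
    client = MistralAsyncClient()

    # Single-shot fill-in-the-middle completion: the model fills the gap
    # between prompt and suffix.
    response = await client.completion(
        model="codestral-latest",
        prompt="def fibonacci(n: int) -> int:\n",
        suffix="\nprint(fibonacci(10))",
        max_tokens=64,
        temperature=0.2,
    )
    print(response.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())
```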
@@ -1,4 +1,3 @@
-import os
 import posixpath
 import time
 from json import JSONDecodeError
@@ -31,7 +30,7 @@ class MistralClient(ClientBase):
 
     def __init__(
         self,
-        api_key: Optional[str] = os.environ.get("MISTRAL_API_KEY", None),
+        api_key: Optional[str] = None,
         endpoint: str = ENDPOINT,
         max_retries: int = 5,
         timeout: int = 120,
@@ -86,7 +85,7 @@ class MistralClient(ClientBase):
     def _request(
         self,
         method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
         path: str,
         stream: bool = False,
         attempt: int = 1,
@@ -286,3 +285,77 @@ class MistralClient(ClientBase):
             return ModelList(**response)
 
         raise MistralException("No response received")
+
+    def completion(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> ChatCompletionResponse:
+        """A completion endpoint that returns a single response.
+
+        Args:
+            model (str): the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']
+
+        Returns:
+            ChatCompletionResponse: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+        )
+
+        single_response = self._request("post", request, "v1/fim/completions", stream=False)
+
+        for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
+
+    def completion_stream(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> Iterable[ChatCompletionStreamResponse]:
+        """A synchronous completion endpoint that streams responses.
+
+        Args:
+            model (str): the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
+                Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']
+
+        Returns:
+            Iterable[ChatCompletionStreamResponse]: a generator that yields response objects containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+        )
+
+        response = self._request("post", request, "v1/fim/completions", stream=True)
+
+        for json_streamed_response in response:
+            yield ChatCompletionStreamResponse(**json_streamed_response)
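
Similarly, a hedged sketch of the synchronous streaming variant added above, assuming the same 0.x package layout (`mistralai.client.MistralClient`) and model name; each streamed chunk is a `ChatCompletionStreamResponse` whose delta carries the incremental text.

```python
from mistralai.client import MistralClient

# Reads MISTRAL_API_KEY from the environment when api_key is not given.
client = MistralClient()

# Stream a fill-in-the-middle completion chunk by chunk.
for chunk in client.completion_stream(
    model="codestral-latest",
    prompt="def is_prime(n: int) -> bool:\n",
    suffix="\nassert is_prime(7)",
    max_tokens=64,
):
    delta = chunk.choices[0].delta.content
    if delta is not None:
        print(delta, end="", flush=True)
print()
```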
@@ -10,10 +10,7 @@ from mistralai.exceptions import (
 )
 from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
 
-logging.basicConfig(
-    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
-    level=os.getenv("LOG_LEVEL", "ERROR"),
-)
+CLIENT_VERSION = "0.2.0"
 
 
 class ClientBase(ABC):
@@ -27,16 +24,19 @@ class ClientBase(ABC):
         self._max_retries = max_retries
         self._timeout = timeout
 
-        self._endpoint = endpoint
+        if api_key is None:
+            api_key = os.environ.get("MISTRAL_API_KEY")
+        if api_key is None:
+            raise MistralException(message="API key not provided. Please set MISTRAL_API_KEY environment variable.")
         self._api_key = api_key
+        self._endpoint = endpoint
         self._logger = logging.getLogger(__name__)
 
         # For azure endpoints, we default to the mistral model
         if "inference.azure.com" in self._endpoint:
             self._default_model = "mistral"
 
-        # This should be automatically updated by the deploy script
-        self._version = "0.1.8"
+        self._version = CLIENT_VERSION
 
     def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         parsed_tools: List[Dict[str, Any]] = []
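
The constructor change above moves the `MISTRAL_API_KEY` lookup from an import-time default argument into `ClientBase.__init__`, so the environment variable is read when the client is created and a missing key fails fast with a `MistralException`. A small hedged sketch of the resulting behaviour, using the names shown in the diff (the key values are placeholders):

```python
import os

from mistralai.client import MistralClient
from mistralai.exceptions import MistralException

# An explicit key always wins.
client = MistralClient(api_key="my-secret-key")

# Otherwise the key is read from the environment at construction time.
os.environ["MISTRAL_API_KEY"] = "my-secret-key"
client = MistralClient()

# With neither, construction raises instead of silently creating an unusable client.
del os.environ["MISTRAL_API_KEY"]
try:
    MistralClient()
except MistralException as exc:
    print(f"expected failure: {exc}")
```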
@@ -73,6 +73,63 @@ class ClientBase(ABC):
 
         return parsed_messages
 
+    def _make_completion_request(
+        self,
+        prompt: str,
+        model: Optional[str] = None,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        stream: Optional[bool] = False,
+    ) -> Dict[str, Any]:
+        request_data: Dict[str, Any] = {
+            "prompt": prompt,
+            "suffix": suffix,
+            "model": model,
+            "stream": stream,
+        }
+
+        if stop is not None:
+            request_data["stop"] = stop
+
+        if model is not None:
+            request_data["model"] = model
+        else:
+            if self._default_model is None:
+                raise MistralException(message="model must be provided")
+            request_data["model"] = self._default_model
+
+        request_data.update(
+            self._build_sampling_params(
+                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+            )
+        )
+
+        self._logger.debug(f"Completion request: {request_data}")
+
+        return request_data
+
+    def _build_sampling_params(
+        self,
+        max_tokens: Optional[int],
+        random_seed: Optional[int],
+        temperature: Optional[float],
+        top_p: Optional[float],
+    ) -> Dict[str, Any]:
+        params = {}
+        if temperature is not None:
+            params["temperature"] = temperature
+        if max_tokens is not None:
+            params["max_tokens"] = max_tokens
+        if top_p is not None:
+            params["top_p"] = top_p
+        if random_seed is not None:
+            params["random_seed"] = random_seed
+        return params
+
     def _make_chat_request(
         self,
         messages: List[Any],
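
The new `_build_sampling_params` helper above only emits the sampling fields that were explicitly set, and the next hunk shows `_make_chat_request` switching to it as well. A hedged standalone sketch of that pattern, written as a plain function rather than the SDK's method, to show the resulting payload:

```python
from typing import Any, Dict, Optional


def build_sampling_params(
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
    top_p: Optional[float] = None,
    random_seed: Optional[int] = None,
) -> Dict[str, Any]:
    """Return only the sampling fields that were explicitly provided."""
    params: Dict[str, Any] = {}
    if temperature is not None:
        params["temperature"] = temperature
    if max_tokens is not None:
        params["max_tokens"] = max_tokens
    if top_p is not None:
        params["top_p"] = top_p
    if random_seed is not None:
        params["random_seed"] = random_seed
    return params


request_data = {"model": "codestral-latest", "prompt": "def add(a, b):", "stream": False}
request_data.update(build_sampling_params(temperature=0.2, max_tokens=64))
print(request_data)
# {'model': 'codestral-latest', 'prompt': 'def add(a, b):', 'stream': False, 'temperature': 0.2, 'max_tokens': 64}
```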
@@ -99,16 +156,14 @@ class ClientBase(ABC):
                 raise MistralException(message="model must be provided")
             request_data["model"] = self._default_model
 
+        request_data.update(
+            self._build_sampling_params(
+                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+            )
+        )
+
         if tools is not None:
             request_data["tools"] = self._parse_tools(tools)
-        if temperature is not None:
-            request_data["temperature"] = temperature
-        if max_tokens is not None:
-            request_data["max_tokens"] = max_tokens
-        if top_p is not None:
-            request_data["top_p"] = top_p
-        if random_seed is not None:
-            request_data["random_seed"] = random_seed
         if stream is not None:
             request_data["stream"] = stream
 
@@ -1,5 +1,3 @@
-
-
 RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
 
 ENDPOINT = "https://api.mistral.ai"
@@ -35,9 +35,7 @@ class MistralAPIException(MistralException):
         self.headers = headers or {}
 
     @classmethod
-    def from_response(
-        cls, response: Response, message: Optional[str] = None
-    ) -> MistralAPIException:
+    def from_response(cls, response: Response, message: Optional[str] = None) -> MistralAPIException:
         return cls(
             message=message or response.text,
             http_status=response.status_code,
@@ -47,8 +45,10 @@ class MistralAPIException(MistralException):
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})"
 
+
 class MistralAPIStatusException(MistralAPIException):
     """Returned when we receive a non-200 response from the API that we should retry"""
 
+
 class MistralConnectionException(MistralException):
     """Returned when the SDK can not reach the API server for any reason"""
@@ -47,6 +47,7 @@ class ChatMessage(BaseModel):
     content: Union[str, List[str]]
     name: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = None
+    tool_call_id: Optional[str] = None
 
 
 class DeltaMessage(BaseModel):
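
The new `tool_call_id` field lets a tool-result message reference the tool call it answers, which is what the chat completions API expects when function output is sent back to the model. A hedged sketch of constructing such a message with the model shown above; the id and tool output values are illustrative:

```python
from mistralai.models.chat_completion import ChatMessage

# A tool-result message now carries the id of the tool call it responds to.
tool_result = ChatMessage(
    role="tool",
    name="get_weather",
    content='{"temperature_c": 21}',
    tool_call_id="call_abc123",  # illustrative id taken from the assistant's tool call
)
print(tool_result.model_dump(exclude_none=True))
```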
@@ -17,6 +17,7 @@ class ModelPermission(BaseModel):
     group: Optional[str] = None
     is_blocking: bool = False
 
+
 class ModelCard(BaseModel):
     id: str
     object: str