mistralai 0.0.12__tar.gz → 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,18 +1,19 @@
 Metadata-Version: 2.1
 Name: mistralai
-Version: 0.0.12
+Version: 0.1.2
 Summary:
 Author: Bam4d
 Author-email: bam4d@mistral.ai
-Requires-Python: >=3.8,<4.0
+Requires-Python: >=3.9,<4.0
 Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Requires-Dist: httpx (>=0.25.2,<0.26.0)
 Requires-Dist: orjson (>=3.9.10,<4.0.0)
+Requires-Dist: pandas (>=2.2.0,<3.0.0)
+Requires-Dist: pyarrow (>=15.0.0,<16.0.0)
 Requires-Dist: pydantic (>=2.5.2,<3.0.0)
 Description-Content-Type: text/markdown
 
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mistralai"
-version = "0.0.12"
+version = "0.1.2"
 description = ""
 authors = ["Bam4d <bam4d@mistral.ai>"]
 readme = "README.md"
@@ -23,10 +23,12 @@ exclude = ["docs", "tests", "examples", "tools", "build"]
 
 
 [tool.poetry.dependencies]
-python = "^3.8"
+python = "^3.9"
 orjson = "^3.9.10"
 pydantic = "^2.5.2"
 httpx = "^0.25.2"
+pandas = "^2.2.0"
+pyarrow = "^15.0.0"
 
 
 [tool.poetry.group.dev.dependencies]
@@ -24,7 +24,8 @@ from mistralai.exceptions import (
 from mistralai.models.chat_completion import (
     ChatCompletionResponse,
     ChatCompletionStreamResponse,
-    ChatMessage,
+    ResponseFormat,
+    ToolChoice,
 )
 from mistralai.models.embeddings import EmbeddingResponse
 from mistralai.models.models import ModelList
@@ -101,9 +102,7 @@ class MistralAsyncClient(ClientBase):
         except ConnectError as e:
             raise MistralConnectionException(str(e)) from e
         except RequestError as e:
-            raise MistralException(
-                f"Unexpected exception ({e.__class__.__name__}): {e}"
-            ) from e
+            raise MistralException(f"Unexpected exception ({e.__class__.__name__}): {e}") from e
         except JSONDecodeError as e:
             raise MistralAPIException.from_response(
                 response,
@@ -112,34 +111,33 @@ class MistralAsyncClient(ClientBase):
         except MistralAPIStatusException as e:
             attempt += 1
             if attempt > self._max_retries:
-                raise MistralAPIStatusException.from_response(
-                    response, message=str(e)
-                ) from e
+                raise MistralAPIStatusException.from_response(response, message=str(e)) from e
             backoff = 2.0**attempt  # exponential backoff
             time.sleep(backoff)
 
             # Retry as a generator
-            async for r in self._request(
-                method, json, path, stream=stream, attempt=attempt
-            ):
+            async for r in self._request(method, json, path, stream=stream, attempt=attempt):
                 yield r
 
     async def chat(
         self,
-        model: str,
-        messages: List[ChatMessage],
+        messages: List[Any],
+        model: Optional[str] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         random_seed: Optional[int] = None,
         safe_mode: bool = False,
         safe_prompt: bool = False,
+        tool_choice: Optional[Union[str, ToolChoice]] = None,
+        response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
     ) -> ChatCompletionResponse:
         """A asynchronous chat endpoint that returns a single response.
 
         Args:
             model (str): model the name of the model to chat with, e.g. mistral-tiny
-            messages (List[ChatMessage]): messages an array of messages to chat with, e.g.
+            messages (List[Any]): messages an array of messages to chat with, e.g.
                 [{role: 'user', content: 'What is the best French cheese?'}]
             temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
             max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
@@ -153,14 +151,17 @@ class MistralAsyncClient(ClientBase):
             ChatCompletionResponse: a response object containing the generated text.
         """
         request = self._make_chat_request(
-            model,
             messages,
+            model,
+            tools=tools,
             temperature=temperature,
             max_tokens=max_tokens,
             top_p=top_p,
             random_seed=random_seed,
             stream=False,
             safe_prompt=safe_mode or safe_prompt,
+            tool_choice=tool_choice,
+            response_format=response_format,
         )
 
         single_response = self._request("post", request, "v1/chat/completions")
@@ -172,21 +173,25 @@ class MistralAsyncClient(ClientBase):
 
     async def chat_stream(
         self,
-        model: str,
-        messages: List[ChatMessage],
+        messages: List[Any],
+        model: Optional[str] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         random_seed: Optional[int] = None,
         safe_mode: bool = False,
         safe_prompt: bool = False,
+        tool_choice: Optional[Union[str, ToolChoice]] = None,
+        response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
     ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
         """An Asynchronous chat endpoint that streams responses.
 
         Args:
             model (str): model the name of the model to chat with, e.g. mistral-tiny
-            messages (List[ChatMessage]): messages an array of messages to chat with, e.g.
+            messages (List[Any]): messages an array of messages to chat with, e.g.
                 [{role: 'user', content: 'What is the best French cheese?'}]
+            tools (Optional[List[Function]], optional): a list of tools to use.
             temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
             max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
             top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
@@ -201,25 +206,24 @@ class MistralAsyncClient(ClientBase):
         """
 
         request = self._make_chat_request(
-            model,
             messages,
+            model,
+            tools=tools,
            temperature=temperature,
             max_tokens=max_tokens,
             top_p=top_p,
             random_seed=random_seed,
             stream=True,
             safe_prompt=safe_mode or safe_prompt,
+            tool_choice=tool_choice,
+            response_format=response_format,
         )
-        async_response = self._request(
-            "post", request, "v1/chat/completions", stream=True
-        )
+        async_response = self._request("post", request, "v1/chat/completions", stream=True)
 
         async for json_response in async_response:
             yield ChatCompletionStreamResponse(**json_response)
 
-    async def embeddings(
-        self, model: str, input: Union[str, List[str]]
-    ) -> EmbeddingResponse:
+    async def embeddings(self, model: str, input: Union[str, List[str]]) -> EmbeddingResponse:
         """An asynchronous embeddings endpoint that returns embeddings for a single, or batch of inputs
 
         Args:
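
Reading the async hunks together: in 0.1.2, `messages` becomes the first positional parameter, `model` becomes an optional keyword, and the new `tools`, `tool_choice`, and `response_format` arguments are passed straight through to `_make_chat_request`. A minimal usage sketch follows; the import paths and the `MISTRAL_API_KEY` environment variable are assumptions about the package layout, and `mistral-tiny` is the example model name taken from the docstrings.

```python
# Hedged sketch (not from the diff): calling the reordered async chat API in 0.1.2.
# Assumes MISTRAL_API_KEY is set and that "mistral-tiny" is an available model.
import asyncio
import os

from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage


async def main() -> None:
    client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

    # messages is now the first positional parameter; model moved to a keyword.
    # Plain dicts and ChatMessage instances are both accepted (see _parse_messages below).
    response = await client.chat(
        messages=[ChatMessage(role="user", content="What is the best French cheese?")],
        model="mistral-tiny",
    )
    print(response.choices[0].message.content)

    # chat_stream yields ChatCompletionStreamResponse objects.
    async for chunk in client.chat_stream(
        messages=[{"role": "user", "content": "Write a haiku about cheese."}],
        model="mistral-tiny",
    ):
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())
```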
@@ -17,7 +17,8 @@ from mistralai.exceptions import (
 from mistralai.models.chat_completion import (
     ChatCompletionResponse,
     ChatCompletionStreamResponse,
-    ChatMessage,
+    ResponseFormat,
+    ToolChoice,
 )
 from mistralai.models.embeddings import EmbeddingResponse
 from mistralai.models.models import ModelList
@@ -38,9 +39,8 @@ class MistralClient(ClientBase):
         super().__init__(endpoint, api_key, max_retries, timeout)
 
         self._client = Client(
-            follow_redirects=True,
-            timeout=self._timeout,
-            transport=HTTPTransport(retries=self._max_retries))
+            follow_redirects=True, timeout=self._timeout, transport=HTTPTransport(retries=self._max_retries)
+        )
 
     def __del__(self) -> None:
         self._client.close()
@@ -95,9 +95,7 @@ class MistralClient(ClientBase):
         except ConnectError as e:
             raise MistralConnectionException(str(e)) from e
         except RequestError as e:
-            raise MistralException(
-                f"Unexpected exception ({e.__class__.__name__}): {e}"
-            ) from e
+            raise MistralException(f"Unexpected exception ({e.__class__.__name__}): {e}") from e
         except JSONDecodeError as e:
             raise MistralAPIException.from_response(
                 response,
@@ -106,9 +104,7 @@ class MistralClient(ClientBase):
         except MistralAPIStatusException as e:
             attempt += 1
             if attempt > self._max_retries:
-                raise MistralAPIStatusException.from_response(
-                    response, message=str(e)
-                ) from e
+                raise MistralAPIStatusException.from_response(response, message=str(e)) from e
             backoff = 2.0**attempt  # exponential backoff
             time.sleep(backoff)
 
@@ -118,21 +114,25 @@ class MistralClient(ClientBase):
 
     def chat(
         self,
-        model: str,
-        messages: List[ChatMessage],
+        messages: List[Any],
+        model: Optional[str] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         random_seed: Optional[int] = None,
         safe_mode: bool = False,
         safe_prompt: bool = False,
+        tool_choice: Optional[Union[str, ToolChoice]] = None,
+        response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
     ) -> ChatCompletionResponse:
         """A chat endpoint that returns a single response.
 
         Args:
             model (str): model the name of the model to chat with, e.g. mistral-tiny
-            messages (List[ChatMessage]): messages an array of messages to chat with, e.g.
+            messages (List[Any]): messages an array of messages to chat with, e.g.
                 [{role: 'user', content: 'What is the best French cheese?'}]
+            tools (Optional[List[Function]], optional): a list of tools to use.
             temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
             max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
             top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
@@ -145,14 +145,17 @@ class MistralClient(ClientBase):
             ChatCompletionResponse: a response object containing the generated text.
         """
         request = self._make_chat_request(
-            model,
             messages,
+            model,
+            tools=tools,
             temperature=temperature,
             max_tokens=max_tokens,
             top_p=top_p,
             random_seed=random_seed,
             stream=False,
             safe_prompt=safe_mode or safe_prompt,
+            tool_choice=tool_choice,
+            response_format=response_format,
         )
 
         single_response = self._request("post", request, "v1/chat/completions")
@@ -164,21 +167,25 @@ class MistralClient(ClientBase):
 
     def chat_stream(
         self,
-        model: str,
-        messages: List[ChatMessage],
+        messages: List[Any],
+        model: Optional[str] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         random_seed: Optional[int] = None,
         safe_mode: bool = False,
         safe_prompt: bool = False,
+        tool_choice: Optional[Union[str, ToolChoice]] = None,
+        response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
     ) -> Iterable[ChatCompletionStreamResponse]:
         """A chat endpoint that streams responses.
 
         Args:
             model (str): model the name of the model to chat with, e.g. mistral-tiny
-            messages (List[ChatMessage]): messages an array of messages to chat with, e.g.
+            messages (List[Any]): messages an array of messages to chat with, e.g.
                 [{role: 'user', content: 'What is the best French cheese?'}]
+            tools (Optional[List[Function]], optional): a list of tools to use.
             temperature (Optional[float], optional): temperature the temperature to use for sampling, e.g. 0.5.
             max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
             top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9.
@@ -192,14 +199,17 @@ class MistralClient(ClientBase):
             A generator that yields ChatCompletionStreamResponse objects.
         """
         request = self._make_chat_request(
-            model,
             messages,
+            model,
+            tools=tools,
             temperature=temperature,
             max_tokens=max_tokens,
             top_p=top_p,
             random_seed=random_seed,
             stream=True,
             safe_prompt=safe_mode or safe_prompt,
+            tool_choice=tool_choice,
+            response_format=response_format,
        )
 
         response = self._request("post", request, "v1/chat/completions", stream=True)
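
The synchronous client mirrors the async changes. Below is a hedged sketch of how the new `tools`, `tool_choice`, and `response_format` parameters fit together; the weather-function schema, the `mistral-small` model name, and the import paths are illustrative assumptions, not taken from the diff.

```python
# Hedged sketch (not part of the package): function calling with the 0.1.2 sync client.
import os

from mistralai.client import MistralClient
from mistralai.models.chat_completion import Function, ResponseFormat, ResponseFormats, ToolChoice

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

tools = [
    {
        "type": "function",
        # A Function model or a plain dict both work; _parse_tools handles either.
        "function": Function(
            name="get_weather",  # hypothetical tool for illustration
            description="Get the current weather for a city",
            parameters={
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        ),
    }
]

response = client.chat(
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    model="mistral-small",
    tools=tools,
    tool_choice=ToolChoice.auto,  # or the plain string "auto"
)
print(response.choices[0].message.tool_calls)

# response_format accepts a ResponseFormat model or a {"type": "json_object"} dict.
json_response = client.chat(
    messages=[{"role": "user", "content": "Reply with a JSON object listing three cheeses."}],
    model="mistral-small",
    response_format=ResponseFormat(type=ResponseFormats.json_object),
)
print(json_response.choices[0].message.content)
```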
@@ -1,7 +1,7 @@
 import logging
 import os
 from abc import ABC
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import orjson
 from httpx import Response
@@ -12,7 +12,7 @@ from mistralai.exceptions import (
     MistralAPIStatusException,
     MistralException,
 )
-from mistralai.models.chat_completion import ChatMessage
+from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
 
 logging.basicConfig(
     format="%(asctime)s %(levelname)s %(name)s: %(message)s",
@@ -35,25 +35,76 @@ class ClientBase(ABC):
         self._api_key = api_key
         self._logger = logging.getLogger(__name__)
 
+        # For azure endpoints, we default to the mistral model
+        if "inference.azure.com" in self._endpoint:
+            self._default_model = "mistral"
+
         # This should be automatically updated by the deploy script
-        self._version = "0.0.12"
+        self._version = "0.1.2"
+
+    def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        parsed_tools: List[Dict[str, Any]] = []
+        for tool in tools:
+            if tool["type"] == "function":
+                parsed_function = {}
+                parsed_function["type"] = tool["type"]
+                if isinstance(tool["function"], Function):
+                    parsed_function["function"] = tool["function"].model_dump(exclude_none=True)
+                else:
+                    parsed_function["function"] = tool["function"]
+
+                parsed_tools.append(parsed_function)
+
+        return parsed_tools
+
+    def _parse_tool_choice(self, tool_choice: Union[str, ToolChoice]) -> str:
+        if isinstance(tool_choice, ToolChoice):
+            return tool_choice.value
+        return tool_choice
+
+    def _parse_response_format(self, response_format: Union[Dict[str, Any], ResponseFormat]) -> Dict[str, Any]:
+        if isinstance(response_format, ResponseFormat):
+            return response_format.model_dump(exclude_none=True)
+        return response_format
+
+    def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
+        parsed_messages: List[Dict[str, Any]] = []
+        for message in messages:
+            if isinstance(message, ChatMessage):
+                parsed_messages.append(message.model_dump(exclude_none=True))
+            else:
+                parsed_messages.append(message)
+
+        return parsed_messages
 
     def _make_chat_request(
         self,
-        model: str,
-        messages: List[ChatMessage],
+        messages: List[Any],
+        model: Optional[str] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
         temperature: Optional[float] = None,
         max_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         random_seed: Optional[int] = None,
         stream: Optional[bool] = None,
         safe_prompt: Optional[bool] = False,
+        tool_choice: Optional[Union[str, ToolChoice]] = None,
+        response_format: Optional[Union[Dict[str, str], ResponseFormat]] = None,
     ) -> Dict[str, Any]:
         request_data: Dict[str, Any] = {
-            "model": model,
-            "messages": [msg.model_dump() for msg in messages],
+            "messages": self._parse_messages(messages),
             "safe_prompt": safe_prompt,
         }
+
+        if model is not None:
+            request_data["model"] = model
+        else:
+            if self._default_model is None:
+                raise MistralException(message="model must be provided")
+            request_data["model"] = self._default_model
+
+        if tools is not None:
+            request_data["tools"] = self._parse_tools(tools)
         if temperature is not None:
             request_data["temperature"] = temperature
         if max_tokens is not None:
@@ -65,6 +116,11 @@ class ClientBase(ABC):
         if stream is not None:
             request_data["stream"] = stream
 
+        if tool_choice is not None:
+            request_data["tool_choice"] = self._parse_tool_choice(tool_choice)
+        if response_format is not None:
+            request_data["response_format"] = self._parse_response_format(response_format)
+
         self._logger.debug(f"Chat request: {request_data}")
 
         return request_data
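
For orientation, this is roughly the request payload `_make_chat_request` assembles in 0.1.2 when the optional parameters are supplied. The keys follow from the code above; the values are illustrative assumptions.

```python
# Hedged sketch: the shape of the dict _make_chat_request returns in 0.1.2.
# Only the keys follow from the diff; the values are made up for illustration.
example_request = {
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],  # via _parse_messages
    "safe_prompt": False,
    "model": "mistral-small",  # explicit model, or the Azure default "mistral"
    "tools": [  # via _parse_tools; Function models are dumped to plain dicts
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
            },
        }
    ],
    "temperature": 0.7,
    "stream": False,
    "tool_choice": "auto",  # via _parse_tool_choice; ToolChoice.auto becomes "auto"
    "response_format": {"type": "json_object"},  # via _parse_response_format
}
```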
@@ -1,22 +1,61 @@
 from enum import Enum
-from typing import List, Optional
+from typing import List, Optional, Union
 
 from pydantic import BaseModel
 
 from mistralai.models.common import UsageInfo
 
 
+class Function(BaseModel):
+    name: str
+    description: str
+    parameters: dict
+
+
+class ToolType(str, Enum):
+    function = "function"
+
+
+class FunctionCall(BaseModel):
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    id: str = "null"
+    type: ToolType = ToolType.function
+    function: FunctionCall
+
+
+class ResponseFormats(str, Enum):
+    text: str = "text"
+    json_object: str = "json_object"
+
+
+class ToolChoice(str, Enum):
+    auto: str = "auto"
+    any: str = "any"
+    none: str = "none"
+
+
+class ResponseFormat(BaseModel):
+    type: ResponseFormats = ResponseFormats.text
+
+
 class ChatMessage(BaseModel):
     role: str
-    content: str
+    content: Union[str, List[str]]
+    name: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = None
 
 
 class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = None
 
 
-class FinishReason(Enum):
+class FinishReason(str, Enum):
     stop = "stop"
     length = "length"
     error = "error"
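
The new chat-completion models compose as follows. A hedged sketch constructing them directly; the class and field names come from the hunk above, while the field values are illustrative.

```python
# Hedged sketch: building the new 0.1.2 chat_completion models by hand.
from mistralai.models.chat_completion import (
    ChatMessage,
    Function,
    FunctionCall,
    ResponseFormat,
    ResponseFormats,
    ToolCall,
    ToolType,
)

# A tool definition that can be passed in `tools` as {"type": "function", "function": ...}.
weather_fn = Function(
    name="get_weather",  # hypothetical function for illustration
    description="Get the current weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}},
)

# What an assistant message carrying a tool call looks like: arguments arrive as a JSON string.
assistant_msg = ChatMessage(
    role="assistant",
    content="",
    tool_calls=[
        ToolCall(
            id="null",  # default id in 0.1.2
            type=ToolType.function,
            function=FunctionCall(name="get_weather", arguments='{"city": "Paris"}'),
        )
    ],
)

# Opting into JSON output via the typed model rather than a raw dict.
json_format = ResponseFormat(type=ResponseFormats.json_object)

print(assistant_msg.model_dump(exclude_none=True))
print(json_format.model_dump())
```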