mistralai 1.7.1__py3-none-any.whl → 1.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_version.py +2 -2
- mistralai/beta.py +22 -0
- mistralai/conversations.py +2660 -0
- mistralai/embeddings.py +12 -0
- mistralai/extra/__init__.py +10 -2
- mistralai/extra/exceptions.py +14 -0
- mistralai/extra/mcp/__init__.py +0 -0
- mistralai/extra/mcp/auth.py +166 -0
- mistralai/extra/mcp/base.py +155 -0
- mistralai/extra/mcp/sse.py +165 -0
- mistralai/extra/mcp/stdio.py +22 -0
- mistralai/extra/run/__init__.py +0 -0
- mistralai/extra/run/context.py +295 -0
- mistralai/extra/run/result.py +212 -0
- mistralai/extra/run/tools.py +225 -0
- mistralai/extra/run/utils.py +36 -0
- mistralai/extra/tests/test_struct_chat.py +1 -1
- mistralai/mistral_agents.py +1160 -0
- mistralai/models/__init__.py +472 -1
- mistralai/models/agent.py +129 -0
- mistralai/models/agentconversation.py +71 -0
- mistralai/models/agentcreationrequest.py +109 -0
- mistralai/models/agenthandoffdoneevent.py +33 -0
- mistralai/models/agenthandoffentry.py +75 -0
- mistralai/models/agenthandoffstartedevent.py +33 -0
- mistralai/models/agents_api_v1_agents_getop.py +16 -0
- mistralai/models/agents_api_v1_agents_listop.py +24 -0
- mistralai/models/agents_api_v1_agents_update_versionop.py +21 -0
- mistralai/models/agents_api_v1_agents_updateop.py +23 -0
- mistralai/models/agents_api_v1_conversations_append_streamop.py +28 -0
- mistralai/models/agents_api_v1_conversations_appendop.py +28 -0
- mistralai/models/agents_api_v1_conversations_getop.py +33 -0
- mistralai/models/agents_api_v1_conversations_historyop.py +16 -0
- mistralai/models/agents_api_v1_conversations_listop.py +37 -0
- mistralai/models/agents_api_v1_conversations_messagesop.py +16 -0
- mistralai/models/agents_api_v1_conversations_restart_streamop.py +26 -0
- mistralai/models/agents_api_v1_conversations_restartop.py +26 -0
- mistralai/models/agentupdaterequest.py +111 -0
- mistralai/models/builtinconnectors.py +13 -0
- mistralai/models/chatcompletionresponse.py +6 -6
- mistralai/models/codeinterpretertool.py +17 -0
- mistralai/models/completionargs.py +100 -0
- mistralai/models/completionargsstop.py +13 -0
- mistralai/models/completionjobout.py +3 -3
- mistralai/models/conversationappendrequest.py +35 -0
- mistralai/models/conversationappendstreamrequest.py +37 -0
- mistralai/models/conversationevents.py +72 -0
- mistralai/models/conversationhistory.py +58 -0
- mistralai/models/conversationinputs.py +14 -0
- mistralai/models/conversationmessages.py +28 -0
- mistralai/models/conversationrequest.py +133 -0
- mistralai/models/conversationresponse.py +51 -0
- mistralai/models/conversationrestartrequest.py +42 -0
- mistralai/models/conversationrestartstreamrequest.py +44 -0
- mistralai/models/conversationstreamrequest.py +135 -0
- mistralai/models/conversationusageinfo.py +63 -0
- mistralai/models/documentlibrarytool.py +22 -0
- mistralai/models/embeddingdtype.py +7 -0
- mistralai/models/embeddingrequest.py +43 -3
- mistralai/models/fimcompletionresponse.py +6 -6
- mistralai/models/functioncallentry.py +76 -0
- mistralai/models/functioncallentryarguments.py +15 -0
- mistralai/models/functioncallevent.py +36 -0
- mistralai/models/functionresultentry.py +69 -0
- mistralai/models/functiontool.py +21 -0
- mistralai/models/imagegenerationtool.py +17 -0
- mistralai/models/inputentries.py +18 -0
- mistralai/models/messageentries.py +18 -0
- mistralai/models/messageinputcontentchunks.py +26 -0
- mistralai/models/messageinputentry.py +89 -0
- mistralai/models/messageoutputcontentchunks.py +30 -0
- mistralai/models/messageoutputentry.py +100 -0
- mistralai/models/messageoutputevent.py +93 -0
- mistralai/models/modelconversation.py +127 -0
- mistralai/models/outputcontentchunks.py +30 -0
- mistralai/models/responsedoneevent.py +25 -0
- mistralai/models/responseerrorevent.py +27 -0
- mistralai/models/responsestartedevent.py +24 -0
- mistralai/models/ssetypes.py +18 -0
- mistralai/models/toolexecutiondoneevent.py +34 -0
- mistralai/models/toolexecutionentry.py +70 -0
- mistralai/models/toolexecutionstartedevent.py +31 -0
- mistralai/models/toolfilechunk.py +61 -0
- mistralai/models/toolreferencechunk.py +61 -0
- mistralai/models/websearchpremiumtool.py +17 -0
- mistralai/models/websearchtool.py +17 -0
- mistralai/sdk.py +3 -0
- {mistralai-1.7.1.dist-info → mistralai-1.8.1.dist-info}/METADATA +42 -7
- {mistralai-1.7.1.dist-info → mistralai-1.8.1.dist-info}/RECORD +91 -14
- {mistralai-1.7.1.dist-info → mistralai-1.8.1.dist-info}/LICENSE +0 -0
- {mistralai-1.7.1.dist-info → mistralai-1.8.1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,225 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import logging
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
import inspect
|
|
5
|
+
|
|
6
|
+
from pydantic import Field, create_model
|
|
7
|
+
from pydantic.fields import FieldInfo
|
|
8
|
+
import json
|
|
9
|
+
from typing import cast, Callable, Sequence, Any, ForwardRef, get_type_hints, Union
|
|
10
|
+
|
|
11
|
+
from griffe import (
|
|
12
|
+
Docstring,
|
|
13
|
+
DocstringSectionKind,
|
|
14
|
+
DocstringSectionText,
|
|
15
|
+
DocstringParameter,
|
|
16
|
+
DocstringSection,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from mistralai.extra.exceptions import RunException
|
|
20
|
+
from mistralai.extra.mcp.base import MCPClientProtocol
|
|
21
|
+
from mistralai.extra.run.result import RunOutputEntries
|
|
22
|
+
from mistralai.models import (
|
|
23
|
+
FunctionResultEntry,
|
|
24
|
+
FunctionTool,
|
|
25
|
+
Function,
|
|
26
|
+
FunctionCallEntry,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
logger = logging.getLogger(__name__)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class RunFunction:
    """A tool backed by a regular (synchronous) Python callable.

    ``create_function_result`` executes it as ``callable(**arguments)`` without
    awaiting the result.
    """

    # Identifier of the tool; presumably the same name exposed in ``tool`` -- TODO confirm against caller.
    name: str
    # Callable invoked with the decoded call arguments passed as keyword arguments.
    callable: Callable
    # FunctionTool definition associated with this callable.
    tool: FunctionTool
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class RunCoroutine:
    """A tool backed by an awaitable-returning callable.

    ``create_function_result`` executes it as ``await awaitable(**arguments)``.
    """

    # Identifier of the tool; presumably the same name exposed in ``tool`` -- TODO confirm against caller.
    name: str
    # Callable whose return value is awaited; called with the decoded arguments as keyword arguments.
    awaitable: Callable
    # FunctionTool definition associated with this coroutine function.
    tool: FunctionTool
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass
class RunMCPTool:
    """A tool whose execution is delegated to an MCP client.

    ``create_function_result`` executes it as
    ``await mcp_client.execute_tool(function_call.name, arguments)``.
    """

    # Identifier of the tool; presumably the same name exposed in ``tool`` -- TODO confirm against caller.
    name: str
    # FunctionTool definition associated with the remote tool.
    tool: FunctionTool
    # Client object used to run the tool through its ``execute_tool`` coroutine.
    mcp_client: MCPClientProtocol
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Any of the tool wrappers that create_function_result knows how to execute.
RunTool = Union[RunFunction, RunCoroutine, RunMCPTool]
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _get_function_description(docstring_sections: list[DocstringSection]) -> str:
    """Build a function description from the free-text parts of a parsed docstring.

    Only sections whose kind is ``DocstringSectionKind.text`` contribute. Their
    values are joined with newlines, in the order they appear. An input with no
    text section yields an empty string.
    """
    return "\n".join(
        cast(DocstringSectionText, section).value
        for section in docstring_sections
        if section.kind == DocstringSectionKind.text
    )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _get_function_parameters(
    docstring_sections: list[DocstringSection],
    params_from_sig: list[inspect.Parameter],
    type_hints: dict[str, Any],
) -> dict[str, Any]:
    """Build the JSON schema describing a function's parameters.

    The schema is produced by a throwaway pydantic model with one field per
    signature parameter. For every parameter the following sources are combined:

    * the type: the signature annotation, resolved through ``type_hints`` when
      it is a string or a ``ForwardRef`` (``Any`` when there is no annotation);
    * the field definition: a ``Field(...)`` used as the default value, or the
      first ``Field(...)`` found in ``Annotated[...]`` metadata, or a fresh one;
    * the default: the signature default, which is also applied when the
      ``Field`` comes from ``Annotated`` metadata so the parameter stays optional;
    * the description: the docstring entry when there is one, otherwise the
      description already declared on the ``Field``, otherwise ``""``.

    Args:
        docstring_sections: Parsed docstring sections of the function.
        params_from_sig: Parameters taken from the function signature.
        type_hints: Resolved type hints of the function, keyed by parameter name.

    Returns:
        The model JSON schema with the top-level and per-property ``title``
        entries removed.
    """
    # Local import keeps this block self-contained; only the stdlib is involved.
    from copy import copy

    documented_params: list[DocstringParameter] = list(
        itertools.chain.from_iterable(
            section.value
            for section in docstring_sections
            if section.kind
            in (DocstringSectionKind.parameters, DocstringSectionKind.other_parameters)
        )
    )
    param_descriptions = {doc.name: doc.description for doc in documented_params}

    # Resolve every parameter into an (annotation, FieldInfo) pair.
    fields: dict[str, tuple[type, FieldInfo]] = {}
    for p in params_from_sig:
        default = p.default if p.default is not inspect.Parameter.empty else ...
        annotation = (
            p.annotation if p.annotation is not inspect.Parameter.empty else Any
        )
        # String annotations are resolved with the help of get_type_hints.
        if isinstance(annotation, str):
            annotation = type_hints[p.name]

        field_info = None
        if isinstance(default, FieldInfo):
            # Work on a copy: the Field object belongs to the decorated function
            # and must not be modified as a side effect of building the schema.
            field_info = copy(default)
        elif hasattr(annotation, "__metadata__") and hasattr(annotation, "__args__"):
            # Annotated[T, ...]: keep the first Field(...) marker, if any.
            declared = next(
                (meta for meta in annotation.__metadata__ if isinstance(meta, FieldInfo)),  # type: ignore
                None,
            )
            # The actual annotation is the first part of Annotated.
            annotation = annotation.__args__[0]  # type: ignore
            if declared is not None:
                field_info = copy(declared)
                if default is not ...:
                    # The signature default is what Python applies at call time,
                    # so it must be reflected in the schema; without it the
                    # parameter would wrongly be reported as required.
                    field_info.default = default
                    field_info.default_factory = None

        # Forward references are resolved with the help of get_type_hints.
        if isinstance(annotation, ForwardRef):
            annotation = type_hints.get(p.name)

        # No Field declared anywhere: derive one from the signature default.
        if field_info is None:
            field_info = Field() if default is ... else Field(default=default)

        # The docstring wins when it documents the parameter; otherwise keep a
        # description declared on the Field instead of blanking it.
        doc_description = param_descriptions.get(p.name, "")
        if doc_description or field_info.description is None:
            field_info.description = doc_description
        fields[p.name] = (cast(type, annotation), field_info)

    schema = create_model("_", **fields).model_json_schema()  # type: ignore[call-overload]
    # Titles are generated by pydantic and are stripped from the schema.
    schema.pop("title", None)
    for prop in schema.get("properties", {}).values():
        prop.pop("title", None)
    return schema
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def create_tool_call(func: Callable) -> FunctionTool:
    """Derive a FunctionTool from a function's docstring and type annotations.

    The docstring is parsed with the Google-style parser. A warning is logged
    when the function has no docstring or when parsing yields no section. The
    tool name is the function's ``__name__`` and the tool is created with
    ``strict=True``.
    """
    name = func.__name__

    # Parse the docstring, if there is one.
    sections: list[DocstringSection] = []
    raw_doc = inspect.getdoc(func)
    if raw_doc:
        sections = Docstring(raw_doc, parser="google").parse(warnings=False)
        if not sections:
            logger.warning(
                f"Function '{name}' has no relevant docstring sections, add docstring for more accurate result."
            )
    else:
        logger.warning(
            f"Function '{name}' without a docstring is being parsed, add docstring for more accurate result."
        )

    # Collect the signature parameters and the resolved type hints.
    signature_params = list(inspect.signature(func).parameters.values())
    hints = get_type_hints(func, include_extras=True, localns=None, globalns=None)

    description = _get_function_description(sections)
    parameters = _get_function_parameters(
        docstring_sections=sections,
        params_from_sig=signature_params,
        type_hints=hints,
    )
    function = Function(
        name=name,
        description=description,
        parameters=parameters,
        strict=True,
    )
    return FunctionTool(type="function", function=function)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
async def create_function_result(
    function_call: FunctionCallEntry,
    run_tool: RunTool,
    continue_on_fn_error: bool = False,
) -> FunctionResultEntry:
    """Execute the tool targeted by a FunctionCallEntry and wrap its output.

    The call arguments are JSON-decoded when they are given as a string. The
    tool is then run according to its kind: called directly (RunFunction),
    awaited (RunCoroutine) or executed through its MCP client (RunMCPTool).

    When the tool raises and ``continue_on_fn_error`` is exactly ``True``, the
    error text is returned as the result; otherwise a RunException chained to
    the original error is raised. A non-string result is serialized with
    ``json.dumps``.
    """
    raw_arguments = function_call.arguments
    if isinstance(raw_arguments, str):
        arguments = json.loads(raw_arguments)
    else:
        arguments = raw_arguments

    try:
        if isinstance(run_tool, RunFunction):
            res = run_tool.callable(**arguments)
        elif isinstance(run_tool, RunCoroutine):
            res = await run_tool.awaitable(**arguments)
        elif isinstance(run_tool, RunMCPTool):
            res = await run_tool.mcp_client.execute_tool(function_call.name, arguments)
    except Exception as e:
        if continue_on_fn_error is not True:
            raise RunException(
                f"Failed to execute tool {function_call.name} with arguments '{function_call.arguments}'"
            ) from e
        # Report the failure as the tool result instead of aborting.
        return FunctionResultEntry(
            tool_call_id=function_call.tool_call_id,
            result=f"Error while executing {function_call.name}: {e!s}",
        )

    if isinstance(res, str):
        payload = res
    else:
        payload = json.dumps(res)
    return FunctionResultEntry(
        tool_call_id=function_call.tool_call_id,
        result=payload,
    )
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def get_function_calls(
    output_entries: Sequence[RunOutputEntries],
) -> list[FunctionCallEntry]:
    """Return the FunctionCallEntry items found in ``output_entries``, in order."""
    return [entry for entry in output_entries if isinstance(entry, FunctionCallEntry)]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
import importlib.util
|
|
2
|
+
import sys
|
|
3
|
+
from typing import Callable, TypeVar, Any, cast
|
|
4
|
+
from functools import wraps
|
|
5
|
+
|
|
6
|
+
from mistralai.extra.exceptions import MistralClientException
|
|
7
|
+
|
|
8
|
+
F = TypeVar("F", bound=Callable[..., Any])
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
REQUIRED_PYTHON_VERSION = (3, 10)
|
|
12
|
+
REQUIRED_PYTHON_VERSION_STR = "3.10"
|
|
13
|
+
REQUIRED_PACKAGES = ["mcp"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def is_module_installed(module_name: str) -> bool:
    """Tell whether ``module_name`` can be located by the import system.

    Args:
        module_name: Absolute module name, possibly dotted (``"pkg.sub"``).

    Returns:
        ``True`` when a module spec is found, ``False`` otherwise. A dotted
        name whose parent package is missing also yields ``False``.
    """
    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        # For dotted names find_spec imports the parent package first and
        # raises when that parent does not exist.
        return False
    return spec is not None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def run_requirements(func: F) -> F:
    """Decorator that checks the runtime requirements before calling ``func``.

    On every call the wrapper verifies, in this order, that:

    * the interpreter is at least ``REQUIRED_PYTHON_VERSION``;
    * every module listed in ``REQUIRED_PACKAGES`` can be found.

    Raises:
        MistralClientException: When one of the requirements is not met.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if sys.version_info < REQUIRED_PYTHON_VERSION:
            # The required version itself is accepted, hence "or higher".
            raise MistralClientException(
                f"{func.__name__} requires Python {REQUIRED_PYTHON_VERSION_STR} or higher. "
                f"You are using Python {sys.version_info.major}.{sys.version_info.minor}."
            )
        for package in REQUIRED_PACKAGES:
            if not is_module_installed(package):
                raise MistralClientException(
                    f"{func.__name__} requires the sdk to be installed with 'agents' extra dependencies."
                )
        return func(*args, **kwargs)

    return cast(F, wrapper)
|