mistralai-1.7.1-py3-none-any.whl → mistralai-1.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. mistralai/_version.py +2 -2
  2. mistralai/beta.py +20 -0
  3. mistralai/conversations.py +2657 -0
  4. mistralai/extra/__init__.py +10 -2
  5. mistralai/extra/exceptions.py +14 -0
  6. mistralai/extra/mcp/__init__.py +0 -0
  7. mistralai/extra/mcp/auth.py +166 -0
  8. mistralai/extra/mcp/base.py +155 -0
  9. mistralai/extra/mcp/sse.py +165 -0
  10. mistralai/extra/mcp/stdio.py +22 -0
  11. mistralai/extra/run/__init__.py +0 -0
  12. mistralai/extra/run/context.py +295 -0
  13. mistralai/extra/run/result.py +212 -0
  14. mistralai/extra/run/tools.py +225 -0
  15. mistralai/extra/run/utils.py +36 -0
  16. mistralai/extra/tests/test_struct_chat.py +1 -1
  17. mistralai/mistral_agents.py +1158 -0
  18. mistralai/models/__init__.py +470 -1
  19. mistralai/models/agent.py +129 -0
  20. mistralai/models/agentconversation.py +71 -0
  21. mistralai/models/agentcreationrequest.py +109 -0
  22. mistralai/models/agenthandoffdoneevent.py +33 -0
  23. mistralai/models/agenthandoffentry.py +75 -0
  24. mistralai/models/agenthandoffstartedevent.py +33 -0
  25. mistralai/models/agents_api_v1_agents_getop.py +16 -0
  26. mistralai/models/agents_api_v1_agents_listop.py +24 -0
  27. mistralai/models/agents_api_v1_agents_update_versionop.py +21 -0
  28. mistralai/models/agents_api_v1_agents_updateop.py +23 -0
  29. mistralai/models/agents_api_v1_conversations_append_streamop.py +28 -0
  30. mistralai/models/agents_api_v1_conversations_appendop.py +28 -0
  31. mistralai/models/agents_api_v1_conversations_getop.py +33 -0
  32. mistralai/models/agents_api_v1_conversations_historyop.py +16 -0
  33. mistralai/models/agents_api_v1_conversations_listop.py +37 -0
  34. mistralai/models/agents_api_v1_conversations_messagesop.py +16 -0
  35. mistralai/models/agents_api_v1_conversations_restart_streamop.py +26 -0
  36. mistralai/models/agents_api_v1_conversations_restartop.py +26 -0
  37. mistralai/models/agentupdaterequest.py +111 -0
  38. mistralai/models/builtinconnectors.py +13 -0
  39. mistralai/models/codeinterpretertool.py +17 -0
  40. mistralai/models/completionargs.py +100 -0
  41. mistralai/models/completionargsstop.py +13 -0
  42. mistralai/models/completionjobout.py +3 -3
  43. mistralai/models/conversationappendrequest.py +35 -0
  44. mistralai/models/conversationappendstreamrequest.py +37 -0
  45. mistralai/models/conversationevents.py +72 -0
  46. mistralai/models/conversationhistory.py +58 -0
  47. mistralai/models/conversationinputs.py +14 -0
  48. mistralai/models/conversationmessages.py +28 -0
  49. mistralai/models/conversationrequest.py +133 -0
  50. mistralai/models/conversationresponse.py +51 -0
  51. mistralai/models/conversationrestartrequest.py +42 -0
  52. mistralai/models/conversationrestartstreamrequest.py +44 -0
  53. mistralai/models/conversationstreamrequest.py +135 -0
  54. mistralai/models/conversationusageinfo.py +63 -0
  55. mistralai/models/documentlibrarytool.py +22 -0
  56. mistralai/models/functioncallentry.py +76 -0
  57. mistralai/models/functioncallentryarguments.py +15 -0
  58. mistralai/models/functioncallevent.py +36 -0
  59. mistralai/models/functionresultentry.py +69 -0
  60. mistralai/models/functiontool.py +21 -0
  61. mistralai/models/imagegenerationtool.py +17 -0
  62. mistralai/models/inputentries.py +18 -0
  63. mistralai/models/messageentries.py +18 -0
  64. mistralai/models/messageinputcontentchunks.py +26 -0
  65. mistralai/models/messageinputentry.py +89 -0
  66. mistralai/models/messageoutputcontentchunks.py +30 -0
  67. mistralai/models/messageoutputentry.py +100 -0
  68. mistralai/models/messageoutputevent.py +93 -0
  69. mistralai/models/modelconversation.py +127 -0
  70. mistralai/models/outputcontentchunks.py +30 -0
  71. mistralai/models/responsedoneevent.py +25 -0
  72. mistralai/models/responseerrorevent.py +27 -0
  73. mistralai/models/responsestartedevent.py +24 -0
  74. mistralai/models/ssetypes.py +18 -0
  75. mistralai/models/toolexecutiondoneevent.py +34 -0
  76. mistralai/models/toolexecutionentry.py +70 -0
  77. mistralai/models/toolexecutionstartedevent.py +31 -0
  78. mistralai/models/toolfilechunk.py +61 -0
  79. mistralai/models/toolreferencechunk.py +61 -0
  80. mistralai/models/websearchpremiumtool.py +17 -0
  81. mistralai/models/websearchtool.py +17 -0
  82. mistralai/sdk.py +3 -0
  83. {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/METADATA +42 -7
  84. {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/RECORD +86 -10
  85. {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/LICENSE +0 -0
  86. {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,225 @@
1
+ import itertools
2
+ import logging
3
+ from dataclasses import dataclass
4
+ import inspect
5
+
6
+ from pydantic import Field, create_model
7
+ from pydantic.fields import FieldInfo
8
+ import json
9
+ from typing import cast, Callable, Sequence, Any, ForwardRef, get_type_hints, Union
10
+
11
+ from griffe import (
12
+ Docstring,
13
+ DocstringSectionKind,
14
+ DocstringSectionText,
15
+ DocstringParameter,
16
+ DocstringSection,
17
+ )
18
+
19
+ from mistralai.extra.exceptions import RunException
20
+ from mistralai.extra.mcp.base import MCPClientProtocol
21
+ from mistralai.extra.run.result import RunOutputEntries
22
+ from mistralai.models import (
23
+ FunctionResultEntry,
24
+ FunctionTool,
25
+ Function,
26
+ FunctionCallEntry,
27
+ )
28
+
29
+
30
# Logger for this module, named after the module's dotted import path.
logger = logging.getLogger(name=__name__)
31
+
32
+
33
@dataclass()
class RunFunction:
    """A plain (synchronous) Python callable exposed to the model as a tool.

    Fields, in constructor order: ``name``, ``callable``, ``tool``.
    """

    # Tool name.
    name: str
    # Synchronous callable; the run helpers call it with the tool call's
    # decoded arguments as keyword arguments.
    callable: Callable
    # Tool definition describing this function to the API.
    tool: FunctionTool
38
+
39
+
40
@dataclass()
class RunCoroutine:
    """An async Python callable exposed to the model as a tool.

    Fields, in constructor order: ``name``, ``awaitable``, ``tool``.
    """

    # Tool name.
    name: str
    # Async callable; the run helpers await it with the tool call's decoded
    # arguments as keyword arguments.
    awaitable: Callable
    # Tool definition describing this coroutine function to the API.
    tool: FunctionTool
45
+
46
+
47
@dataclass()
class RunMCPTool:
    """A tool served by an MCP client rather than by local Python code.

    Fields, in constructor order: ``name``, ``tool``, ``mcp_client``.
    """

    # Tool name.
    name: str
    # Tool definition describing the MCP tool to the API.
    tool: FunctionTool
    # Client through which the tool is executed
    # (via ``execute_tool(name, arguments)``).
    mcp_client: MCPClientProtocol
52
+
53
+
54
# Any runnable tool wrapper: local sync function, local coroutine, or MCP tool.
RunTool = Union[
    RunFunction,
    RunCoroutine,
    RunMCPTool,
]
55
+
56
+
57
def _get_function_description(docstring_sections: list[DocstringSection]) -> str:
    """Build a tool description from the free-text sections of a parsed docstring.

    Args:
        docstring_sections: Sections produced by griffe's docstring parser.

    Returns:
        The ``value`` of every text section, in order, joined with newlines;
        an empty string when there are no text sections.
    """
    texts = [
        cast(DocstringSectionText, sec).value
        for sec in docstring_sections
        if sec.kind == DocstringSectionKind.text
    ]
    return "\n".join(texts)
64
+
65
+
66
def _get_function_parameters(
    docstring_sections: list[DocstringSection],
    params_from_sig: list[inspect.Parameter],
    type_hints: dict[str, Any],
) -> dict[str, Any]:
    """Build the JSON schema describing a function's parameters.

    Per-parameter descriptions come from the docstring's (other-)parameters
    sections. Types, defaults and extra constraints come from the signature,
    including ``param = Field(...)`` defaults and ``Annotated[T, Field(...)]``
    metadata. A description given explicitly through ``Field(description=...)``
    is kept unless the docstring documents the parameter. ``FieldInfo``
    objects supplied by the caller are copied, never mutated.

    Args:
        docstring_sections: Parsed docstring sections (may be empty).
        params_from_sig: Parameters taken from ``inspect.signature(func)``.
        type_hints: Hints from ``get_type_hints(func, include_extras=True)``,
            used to resolve string and forward-reference annotations.

    Returns:
        The ``model_json_schema()`` of a pydantic model with one field per
        parameter, with the model-level and per-property ``title`` keys removed.
    """
    import copy

    params_from_docstrings: list[DocstringParameter] = list(
        itertools.chain.from_iterable(
            section.value
            for section in docstring_sections
            if section.kind
            in (DocstringSectionKind.parameters, DocstringSectionKind.other_parameters)
        )
    )

    # Docstring description per parameter name ("" when undocumented); if a
    # name is documented twice, the last entry wins.
    param_descriptions: dict[str, str] = {
        param_doc.name: param_doc.description for param_doc in params_from_docstrings
    }
    for param in params_from_sig:
        param_descriptions.setdefault(param.name, "")

    fields: dict[str, tuple[type, FieldInfo]] = {}
    for p in params_from_sig:
        default = p.default if p.default is not inspect.Parameter.empty else ...
        annotation = (
            p.annotation if p.annotation is not inspect.Parameter.empty else Any
        )
        # String annotations (e.g. under `from __future__ import annotations`)
        # are resolved through get_type_hints.
        if isinstance(annotation, str):
            annotation = type_hints[p.name]

        if isinstance(default, FieldInfo):
            # `param: T = Field(...)` style.
            field_info = default
        else:
            field_info = None
            # Annotated[T, Field(...)]: use the first FieldInfo found in the
            # metadata, and T itself as the field type.
            if hasattr(annotation, "__metadata__") and hasattr(annotation, "__args__"):
                for meta in annotation.__metadata__:  # type: ignore
                    if isinstance(meta, FieldInfo):
                        field_info = meta
                        break
                annotation = annotation.__args__[0]  # type: ignore

        # e.g. Annotated["Foo", ...]: let get_type_hints resolve the reference.
        if isinstance(annotation, ForwardRef):
            annotation = type_hints.get(p.name)

        if field_info is None:
            field_info = Field() if default is ... else Field(default=default)
        else:
            # The FieldInfo belongs to the caller (a function default or a
            # possibly shared Annotated alias); mutating it would leak this
            # tool's description into every other use of it.
            field_info = copy.copy(field_info)

        doc_description = param_descriptions[p.name]
        # Prefer the docstring description, but never replace an explicit
        # Field(description=...) with an empty string.
        if doc_description or field_info.description is None:
            field_info.description = doc_description
        fields[p.name] = (cast(type, annotation), field_info)

    schema = create_model("_", **fields).model_json_schema()  # type: ignore[call-overload]
    schema.pop("title", None)
    for prop in schema.get("properties", {}).values():
        prop.pop("title", None)
    return schema
141
+
142
+
143
def create_tool_call(func: Callable) -> FunctionTool:
    """Describe a Python function as a ``FunctionTool``.

    The tool is named after ``func.__name__``. Its description and parameter
    descriptions are taken from the function's Google-style docstring, and its
    parameter schema from the signature and type hints. A warning is logged
    when the docstring is missing or parses to no sections.

    Args:
        func: The function to describe.

    Returns:
        A ``FunctionTool`` of type ``"function"`` whose ``Function`` has
        ``strict=True``.
    """
    name = func.__name__

    doc = inspect.getdoc(func)
    sections: list[DocstringSection] = []
    if doc:
        sections = Docstring(doc, parser="google").parse(warnings=False)
        if not sections:
            logger.warning(
                f"Function '{name}' has no relevant docstring sections, add docstring for more accurate result."
            )
    else:
        logger.warning(
            f"Function '{name}' without a docstring is being parsed, add docstring for more accurate result."
        )

    signature = inspect.signature(func)
    hints = get_type_hints(func, include_extras=True, localns=None, globalns=None)

    description = _get_function_description(sections)
    parameters = _get_function_parameters(
        docstring_sections=sections,
        params_from_sig=list(signature.parameters.values()),
        type_hints=hints,
    )
    function = Function(
        name=name,
        description=description,
        parameters=parameters,
        strict=True,
    )
    return FunctionTool(type="function", function=function)
181
+
182
+
183
async def create_function_result(
    function_call: FunctionCallEntry,
    run_tool: RunTool,
    continue_on_fn_error: bool = False,
) -> FunctionResultEntry:
    """Run ``run_tool`` with the arguments of ``function_call`` and wrap the output.

    Args:
        function_call: The model's function call. ``arguments`` may be a JSON
            string or an already-decoded mapping.
        run_tool: The local function, local coroutine or MCP tool to execute.
        continue_on_fn_error: When True, a failure (including arguments that
            are not valid JSON) is returned as a ``FunctionResultEntry`` carrying
            an error message, so it can be fed back to the model, instead of
            raising.

    Returns:
        A ``FunctionResultEntry`` for ``function_call.tool_call_id``. Its
        ``result`` is the tool output as-is if it is a ``str``, otherwise
        ``json.dumps`` of it.

    Raises:
        TypeError: If ``run_tool`` is not a RunFunction, RunCoroutine or
            RunMCPTool.
        json.JSONDecodeError: If the arguments string is not valid JSON and
            ``continue_on_fn_error`` is False.
        RunException: If the tool raises and ``continue_on_fn_error`` is False.
    """

    def _error_result(exc: Exception) -> FunctionResultEntry:
        # Error payload reported back in place of the tool's output.
        return FunctionResultEntry(
            tool_call_id=function_call.tool_call_id,
            result=f"Error while executing {function_call.name}: {str(exc)}",
        )

    # Fail loudly on an unknown wrapper instead of reaching `return` with no result.
    if not isinstance(run_tool, (RunFunction, RunCoroutine, RunMCPTool)):
        raise TypeError(f"Unsupported run tool type: {type(run_tool).__name__}")

    try:
        arguments = (
            json.loads(function_call.arguments)
            if isinstance(function_call.arguments, str)
            else function_call.arguments
        )
    except json.JSONDecodeError as e:
        # Model-produced arguments can be malformed JSON; honour
        # continue_on_fn_error so the run can report it back to the model.
        if continue_on_fn_error is True:
            return _error_result(e)
        raise

    try:
        if isinstance(run_tool, RunFunction):
            res = run_tool.callable(**arguments)
        elif isinstance(run_tool, RunCoroutine):
            res = await run_tool.awaitable(**arguments)
        else:  # RunMCPTool, validated above
            res = await run_tool.mcp_client.execute_tool(function_call.name, arguments)
    except Exception as e:
        if continue_on_fn_error is True:
            return _error_result(e)
        raise RunException(
            f"Failed to execute tool {function_call.name} with arguments '{function_call.arguments}'"
        ) from e

    return FunctionResultEntry(
        tool_call_id=function_call.tool_call_id,
        result=res if isinstance(res, str) else json.dumps(res),
    )
215
+
216
+
217
def get_function_calls(
    output_entries: Sequence[RunOutputEntries],
) -> list[FunctionCallEntry]:
    """Return the ``FunctionCallEntry`` items of ``output_entries``, in their original order."""
    return [entry for entry in output_entries if isinstance(entry, FunctionCallEntry)]
@@ -0,0 +1,36 @@
1
+ import importlib.util
2
+ import sys
3
+ from typing import Callable, TypeVar, Any, cast
4
+ from functools import wraps
5
+
6
+ from mistralai.extra.exceptions import MistralClientException
7
+
8
# Callable type variable so decorators can return the exact type they wrap.
F = TypeVar("F", bound=Callable[..., Any])

# Minimum interpreter version and extra packages checked by `run_requirements`.
REQUIRED_PYTHON_VERSION = (3, 10)
# Human-readable form of REQUIRED_PYTHON_VERSION ("3.10"), derived so the two cannot drift.
REQUIRED_PYTHON_VERSION_STR = ".".join(str(part) for part in REQUIRED_PYTHON_VERSION)
REQUIRED_PACKAGES = ["mcp"]
14
+
15
+
16
def is_module_installed(module_name: str) -> bool:
    """Return whether ``module_name`` can be found by the import system.

    Uses :func:`importlib.util.find_spec`. For a dotted name that function
    imports the parent package(s) and raises ``ModuleNotFoundError`` when a
    parent is missing or is not a package. That case is reported here as
    "not installed" instead of propagating. A module already present in
    ``sys.modules`` whose ``__spec__`` is None (``find_spec`` raises
    ``ValueError``) is loaded, so it is reported as installed.

    Args:
        module_name: Absolute (possibly dotted) module name.

    Returns:
        True if the module can be found, False otherwise.
    """
    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        return False
    except ValueError:
        return True
    return spec is not None
19
+
20
+
21
def run_requirements(func: F) -> F:
    """Wrap ``func`` so every call first checks the runtime requirements.

    Before delegating to ``func``, the wrapper raises ``MistralClientException``
    in two cases: the interpreter is older than ``REQUIRED_PYTHON_VERSION``, or
    a package in ``REQUIRED_PACKAGES`` cannot be found. The wrapper itself is
    synchronous, so for a coroutine function the checks run when the coroutine
    object is created.

    Args:
        func: The function to guard.

    Returns:
        A wrapper with ``func``'s metadata (via ``functools.wraps``) that
        returns ``func``'s result.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        if sys.version_info < REQUIRED_PYTHON_VERSION:
            # The minimum version itself is accepted, hence "or higher".
            raise MistralClientException(
                f"{func.__name__} requires Python {REQUIRED_PYTHON_VERSION_STR} or higher. "
                f"You are using Python {sys.version_info.major}.{sys.version_info.minor}."
            )
        for package in REQUIRED_PACKAGES:
            if not is_module_installed(package):
                raise MistralClientException(
                    f"{func.__name__} requires the sdk to be installed with 'agents' extra dependencies."
                )
        return func(*args, **kwargs)

    return cast(F, wrapper)
@@ -45,7 +45,7 @@ mock_cc_response = ChatCompletionResponse(
45
45
  )
46
46
 
47
47
 
48
- expected_response = ParsedChatCompletionResponse(
48
+ expected_response: ParsedChatCompletionResponse = ParsedChatCompletionResponse(
49
49
  choices=[
50
50
  ParsedChatCompletionChoice(
51
51
  index=0,