mistralai 1.7.1__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_version.py +2 -2
- mistralai/beta.py +20 -0
- mistralai/conversations.py +2657 -0
- mistralai/extra/__init__.py +10 -2
- mistralai/extra/exceptions.py +14 -0
- mistralai/extra/mcp/__init__.py +0 -0
- mistralai/extra/mcp/auth.py +166 -0
- mistralai/extra/mcp/base.py +155 -0
- mistralai/extra/mcp/sse.py +165 -0
- mistralai/extra/mcp/stdio.py +22 -0
- mistralai/extra/run/__init__.py +0 -0
- mistralai/extra/run/context.py +295 -0
- mistralai/extra/run/result.py +212 -0
- mistralai/extra/run/tools.py +225 -0
- mistralai/extra/run/utils.py +36 -0
- mistralai/extra/tests/test_struct_chat.py +1 -1
- mistralai/mistral_agents.py +1158 -0
- mistralai/models/__init__.py +470 -1
- mistralai/models/agent.py +129 -0
- mistralai/models/agentconversation.py +71 -0
- mistralai/models/agentcreationrequest.py +109 -0
- mistralai/models/agenthandoffdoneevent.py +33 -0
- mistralai/models/agenthandoffentry.py +75 -0
- mistralai/models/agenthandoffstartedevent.py +33 -0
- mistralai/models/agents_api_v1_agents_getop.py +16 -0
- mistralai/models/agents_api_v1_agents_listop.py +24 -0
- mistralai/models/agents_api_v1_agents_update_versionop.py +21 -0
- mistralai/models/agents_api_v1_agents_updateop.py +23 -0
- mistralai/models/agents_api_v1_conversations_append_streamop.py +28 -0
- mistralai/models/agents_api_v1_conversations_appendop.py +28 -0
- mistralai/models/agents_api_v1_conversations_getop.py +33 -0
- mistralai/models/agents_api_v1_conversations_historyop.py +16 -0
- mistralai/models/agents_api_v1_conversations_listop.py +37 -0
- mistralai/models/agents_api_v1_conversations_messagesop.py +16 -0
- mistralai/models/agents_api_v1_conversations_restart_streamop.py +26 -0
- mistralai/models/agents_api_v1_conversations_restartop.py +26 -0
- mistralai/models/agentupdaterequest.py +111 -0
- mistralai/models/builtinconnectors.py +13 -0
- mistralai/models/codeinterpretertool.py +17 -0
- mistralai/models/completionargs.py +100 -0
- mistralai/models/completionargsstop.py +13 -0
- mistralai/models/completionjobout.py +3 -3
- mistralai/models/conversationappendrequest.py +35 -0
- mistralai/models/conversationappendstreamrequest.py +37 -0
- mistralai/models/conversationevents.py +72 -0
- mistralai/models/conversationhistory.py +58 -0
- mistralai/models/conversationinputs.py +14 -0
- mistralai/models/conversationmessages.py +28 -0
- mistralai/models/conversationrequest.py +133 -0
- mistralai/models/conversationresponse.py +51 -0
- mistralai/models/conversationrestartrequest.py +42 -0
- mistralai/models/conversationrestartstreamrequest.py +44 -0
- mistralai/models/conversationstreamrequest.py +135 -0
- mistralai/models/conversationusageinfo.py +63 -0
- mistralai/models/documentlibrarytool.py +22 -0
- mistralai/models/functioncallentry.py +76 -0
- mistralai/models/functioncallentryarguments.py +15 -0
- mistralai/models/functioncallevent.py +36 -0
- mistralai/models/functionresultentry.py +69 -0
- mistralai/models/functiontool.py +21 -0
- mistralai/models/imagegenerationtool.py +17 -0
- mistralai/models/inputentries.py +18 -0
- mistralai/models/messageentries.py +18 -0
- mistralai/models/messageinputcontentchunks.py +26 -0
- mistralai/models/messageinputentry.py +89 -0
- mistralai/models/messageoutputcontentchunks.py +30 -0
- mistralai/models/messageoutputentry.py +100 -0
- mistralai/models/messageoutputevent.py +93 -0
- mistralai/models/modelconversation.py +127 -0
- mistralai/models/outputcontentchunks.py +30 -0
- mistralai/models/responsedoneevent.py +25 -0
- mistralai/models/responseerrorevent.py +27 -0
- mistralai/models/responsestartedevent.py +24 -0
- mistralai/models/ssetypes.py +18 -0
- mistralai/models/toolexecutiondoneevent.py +34 -0
- mistralai/models/toolexecutionentry.py +70 -0
- mistralai/models/toolexecutionstartedevent.py +31 -0
- mistralai/models/toolfilechunk.py +61 -0
- mistralai/models/toolreferencechunk.py +61 -0
- mistralai/models/websearchpremiumtool.py +17 -0
- mistralai/models/websearchtool.py +17 -0
- mistralai/sdk.py +3 -0
- {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/METADATA +42 -7
- {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/RECORD +86 -10
- {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/LICENSE +0 -0
- {mistralai-1.7.1.dist-info → mistralai-1.8.0.dist-info}/WHEEL +0 -0
mistralai/extra/__init__.py
CHANGED
|
@@ -1,5 +1,13 @@
|
|
|
1
|
-
from .struct_chat import
|
|
1
|
+
from .struct_chat import (
|
|
2
|
+
ParsedChatCompletionResponse,
|
|
3
|
+
convert_to_parsed_chat_completion_response,
|
|
4
|
+
)
|
|
2
5
|
from .utils import response_format_from_pydantic_model
|
|
3
6
|
from .utils.response_format import CustomPydanticModel
|
|
4
7
|
|
|
5
|
-
__all__ = [
|
|
8
|
+
__all__ = [
|
|
9
|
+
"convert_to_parsed_chat_completion_response",
|
|
10
|
+
"response_format_from_pydantic_model",
|
|
11
|
+
"CustomPydanticModel",
|
|
12
|
+
"ParsedChatCompletionResponse",
|
|
13
|
+
]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
class MistralClientException(Exception):
    """Root of every error raised by the Mistral client-side helpers."""
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class RunException(MistralClientException):
    """Raised when something goes wrong while running a conversation."""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MCPException(MistralClientException):
    """Raised for failures while talking to, or converting data from, an MCP server."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MCPAuthException(MCPException):
    """Raised when authenticating against an MCP server is missing or fails."""
|
|
File without changes
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from authlib.oauth2.rfc8414 import AuthorizationServerMetadata
|
|
4
|
+
from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncOAuth2ClientBase
|
|
5
|
+
import httpx
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
from mistralai.types import BaseModel
|
|
9
|
+
|
|
10
|
+
logger = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Oauth2AuthorizationScheme(BaseModel):
    """Endpoints and scope describing the OAuth flow to run against an authorization server."""

    # Endpoint the user is redirected to in order to grant access.
    authorization_url: str
    # Endpoint where an authorization code is exchanged for a token.
    token_url: str
    # Scopes requested for the token.
    scope: list[str]
    # Optional free-form description of the scheme.
    description: Optional[str] = None
    # Endpoint used to refresh tokens; left as None when refreshing is not available.
    refresh_url: Optional[str] = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class OAuthParams(BaseModel):
    """Bundle of the OAuth scheme and the client credentials needed to authorize."""

    # Authorization / token endpoints and requested scope.
    scheme: Oauth2AuthorizationScheme
    # Client credentials, either user-supplied or obtained through dynamic registration.
    client_id: str
    client_secret: str
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AsyncOAuth2Client(AsyncOAuth2ClientBase):
    """Async httpx OAuth2 client that can be built directly from ``OAuthParams``."""

    @classmethod
    def from_oauth_params(cls, oauth_params: OAuthParams) -> "AsyncOAuth2Client":
        """Create a client configured with the credentials and scope of *oauth_params*."""
        scheme = oauth_params.scheme
        return cls(
            client_id=oauth_params.client_id,
            client_secret=oauth_params.client_secret,
            scope=scheme.scope,
        )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
async def get_well_known_authorization_server_metadata(
    server_url: str,
) -> Optional[AuthorizationServerMetadata]:
    """Fetch the OAuth authorization server metadata from its well-known location.

    This should be available on MCP servers as described by the specification:
    https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#2-3-server-metadata-discovery.

    Args:
        server_url: Base URL of the MCP server. A trailing slash is ignored.

    Returns:
        The validated metadata, or ``None`` when the server answers with a
        non-2xx status or the document fails to parse/validate.
    """
    base_url = server_url.rstrip("/")
    well_known_url = f"{base_url}/.well-known/oauth-authorization-server"
    # Use the client as a context manager so its connection pool is released;
    # a bare ``httpx.AsyncClient()`` would never be closed. The (non-streamed)
    # body is fully read by ``get`` before the client closes.
    async with httpx.AsyncClient() as client:
        response = await client.get(well_known_url)

    if not 200 <= response.status_code < 300:
        logger.error(
            "Failed to get oauth well-known metadata from %s (status=%s)",
            server_url,
            response.status_code,
        )
        return None

    try:
        server_metadata = AuthorizationServerMetadata(**response.json())
        server_metadata.validate()
    except ValueError:
        # Covers both invalid JSON and metadata that fails validation.
        logger.exception("Failed to parse oauth well-known metadata")
        return None
    return server_metadata
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
async def get_oauth_server_metadata(server_url: str) -> AuthorizationServerMetadata:
    """Fetch the metadata from the authorization server to perform the oauth flow.

    Discovery order:
      1. Protected-resource metadata (not implemented yet).
      2. ``/.well-known/oauth-authorization-server`` on the server.
      3. The spec's default endpoints relative to *server_url*.

    Args:
        server_url: Base URL of the MCP server.

    Returns:
        Authorization server metadata. It is never ``None``, because step 3
        always succeeds.
    """
    # 1) attempt to get the metadata from the resource server at /.well-known/oauth-protected-resource
    # TODO: new self-discovery protocol, not released yet

    # 2) attempt to get the metadata from the authorization server at /.well-known/oauth-authorization-server
    metadata = await get_well_known_authorization_server_metadata(server_url=server_url)
    if metadata is not None:
        return metadata

    # 3) fallback on default endpoints
    # https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#2-3-3-fallbacks-for-servers-without-metadata-discovery
    base_url = server_url.rstrip("/")
    return AuthorizationServerMetadata(
        issuer=server_url,
        authorization_endpoint=f"{base_url}/authorize",
        token_endpoint=f"{base_url}/token",
        # RFC 8414 names this key ``registration_endpoint``. The previous
        # ``register_endpoint`` key made ``metadata.registration_endpoint``
        # resolve to None, which broke dynamic client registration.
        registration_endpoint=f"{base_url}/register",
        response_types_supported=["code"],
        response_modes_supported=["query"],
        grant_types_supported=["authorization_code", "refresh_token"],
        token_endpoint_auth_methods_supported=["client_secret_basic"],
        code_challenge_methods_supported=["S256", "plain"],
    )
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
async def dynamic_client_registration(
    register_endpoint: str,
    redirect_url: str,
    async_client: httpx.AsyncClient,
) -> tuple[str, str]:
    """Register this SDK as an OAuth client with an MCP server (dynamic client registration).

    Args:
        register_endpoint: URL of the server's registration endpoint.
        redirect_url: Redirect URI to register for the authorization-code flow.
        async_client: HTTP client used to send the registration request.

    Returns:
        The ``(client_id, client_secret)`` pair issued by the server.

    Raises:
        ValueError: if the server answers with an error status, the body is not
            valid JSON, or ``client_id``/``client_secret`` are missing.
    """
    response = await async_client.post(
        register_endpoint,
        json={
            "client_name": "MistralSDKClient",
            "grant_types": ["authorization_code", "refresh_token"],
            "token_endpoint_auth_method": "client_secret_basic",
            "response_types": ["code"],
            "redirect_uris": [redirect_url],
        },
    )
    try:
        response.raise_for_status()
        info = response.json()
        credentials = (info["client_id"], info["client_secret"])
    except Exception as exc:
        raise ValueError(
            f"Client registration failed: status={response.status_code}, error={response.text}"
        ) from exc
    return credentials
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
async def build_oauth_params(
    server_url: str,
    redirect_url: str,
    client_id: Optional[str] = None,
    client_secret: Optional[str] = None,
    scope: Optional[list[str]] = None,
    async_client: Optional[httpx.AsyncClient] = None,
) -> OAuthParams:
    """Get issuer metadata and build the oauth required params.

    If both *client_id* and *client_secret* are given, they are used as-is.
    Otherwise the client is registered dynamically with the server.

    Args:
        server_url: Base URL of the MCP server.
        redirect_url: Redirect URI used for dynamic registration.
        client_id: Pre-registered client id, if any.
        client_secret: Pre-registered client secret, if any.
        scope: Scopes to request; defaults to none.
        async_client: HTTP client for the registration request. When omitted,
            a temporary client is created and closed here.

    Raises:
        ValueError: if registration is needed but the server advertises no
            registration endpoint, or if registration fails.
    """
    from contextlib import AsyncExitStack

    metadata = await get_oauth_server_metadata(server_url=server_url)
    oauth_scheme = Oauth2AuthorizationScheme(
        authorization_url=metadata.authorization_endpoint,
        token_url=metadata.token_endpoint,
        scope=scope or [],
        refresh_url=metadata.token_endpoint
        if "refresh_token" in metadata.grant_types_supported
        else None,
    )
    if client_id and client_secret:
        return OAuthParams(
            client_id=client_id,
            client_secret=client_secret,
            scheme=oauth_scheme,
        )

    # Dynamic registration needs an endpoint. Fail with a clear message
    # instead of letting httpx choke on a ``None`` URL.
    registration_endpoint = metadata.registration_endpoint
    if not registration_endpoint:
        raise ValueError(
            f"No client credentials provided and the authorization server for "
            f"{server_url} does not advertise a registration_endpoint."
        )

    # Try to dynamically register the client, owning (and closing) an HTTP
    # client only when the caller did not supply one.
    async with AsyncExitStack() as stack:
        client = async_client
        if client is None:
            client = await stack.enter_async_context(httpx.AsyncClient())
        reg_client_id, reg_client_secret = await dynamic_client_registration(
            register_endpoint=registration_endpoint,
            redirect_url=redirect_url,
            async_client=client,
        )
    return OAuthParams(
        client_id=reg_client_id,
        client_secret=reg_client_secret,
        scheme=oauth_scheme,
    )
|
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
import logging
|
|
3
|
+
import typing
|
|
4
|
+
from contextlib import AsyncExitStack
|
|
5
|
+
from typing import Protocol, Any
|
|
6
|
+
|
|
7
|
+
from mcp import ClientSession
|
|
8
|
+
from mcp.types import ListPromptsResult, EmbeddedResource, ImageContent, TextContent
|
|
9
|
+
|
|
10
|
+
from mistralai.extra.exceptions import MCPException
|
|
11
|
+
from mistralai.models import (
|
|
12
|
+
FunctionTool,
|
|
13
|
+
Function,
|
|
14
|
+
SystemMessageTypedDict,
|
|
15
|
+
AssistantMessageTypedDict,
|
|
16
|
+
TextChunkTypedDict,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class MCPSystemPrompt(typing.TypedDict):
    """A prompt fetched from an MCP server, with messages in Mistral typed-dict form."""

    # Human-readable description reported by the server, if any.
    description: Optional[str]
    # The prompt's messages, typed as system or assistant message dicts.
    messages: list[Union[SystemMessageTypedDict, AssistantMessageTypedDict]]
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class MCPClientProtocol(Protocol):
    """MCP client that converts MCP artifacts to Mistral format.

    Signatures mirror ``MCPClientBase``: ``initialize`` accepts an optional exit
    stack, and tool/prompt arguments are ``dict[str, Any]``.
    """

    _name: str

    async def initialize(self, exit_stack: Optional[AsyncExitStack] = None) -> None:
        """Connect to the server, registering resources on *exit_stack* if given."""
        ...

    async def aclose(self) -> None:
        """Close the connection opened by ``initialize``."""
        ...

    async def get_tools(self) -> list[FunctionTool]:
        """Return the server's tools as Mistral function tools."""
        ...

    async def execute_tool(
        self, name: str, arguments: dict[str, Any]
    ) -> list[TextChunkTypedDict]:
        """Run tool *name* with *arguments* and return its output as text chunks."""
        ...

    async def get_system_prompt(
        self, name: str, arguments: dict[str, Any]
    ) -> MCPSystemPrompt:
        """Fetch prompt *name* rendered with *arguments*."""
        ...

    async def list_system_prompts(self) -> ListPromptsResult:
        """List the prompts offered by the server."""
        ...
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class MCPClientBase(MCPClientProtocol):
    """Base class to implement functionalities from an initialized MCP session.

    Subclasses provide the transport by overriding ``_get_transport``. This
    class opens a ``ClientSession`` on an ``AsyncExitStack``, converts MCP
    tools, prompts and tool results into Mistral types, and closes the session
    when it owns the stack.
    """

    _session: ClientSession

    def __init__(self, name: Optional[str] = None):
        self._name = name or self.__class__.__name__
        # Only set while initialized with a stack this client created itself,
        # i.e. when this client is responsible for closing the session.
        self._exit_stack: Optional[AsyncExitStack] = None
        self._is_initialized = False

    def _convert_content(
        self, mcp_content: Union[TextContent, ImageContent, EmbeddedResource]
    ) -> TextChunkTypedDict:
        """Convert one MCP content item into a Mistral text chunk.

        Raises:
            MCPException: if the item is not text content.
        """
        if mcp_content.type != "text":
            raise MCPException("Only supporting text tool responses for now.")
        return {"type": "text", "text": mcp_content.text}

    def _convert_content_list(
        self, mcp_contents: list[Union[TextContent, ImageContent, EmbeddedResource]]
    ) -> list[TextChunkTypedDict]:
        """Convert each MCP content item, in order, into a Mistral text chunk."""
        return [self._convert_content(mcp_content) for mcp_content in mcp_contents]

    async def get_tools(self) -> list[FunctionTool]:
        """List the server's tools as Mistral ``FunctionTool`` definitions."""
        mcp_tools = await self._session.list_tools()
        return [
            FunctionTool(
                type="function",
                function=Function(
                    name=mcp_tool.name,
                    description=mcp_tool.description,
                    parameters=mcp_tool.inputSchema,
                    strict=True,
                ),
            )
            for mcp_tool in mcp_tools.tools
        ]

    async def execute_tool(
        self, name: str, arguments: dict[str, Any]
    ) -> list[TextChunkTypedDict]:
        """Call tool *name* on the server and convert its result content.

        Raises:
            MCPException: if the result contains non-text content.
        """
        contents = await self._session.call_tool(name=name, arguments=arguments)
        return self._convert_content_list(contents.content)

    async def get_system_prompt(
        self, name: str, arguments: dict[str, Any]
    ) -> MCPSystemPrompt:
        """Fetch prompt *name* rendered with *arguments* and convert its messages."""
        prompt_result = await self._session.get_prompt(name=name, arguments=arguments)
        # NOTE(review): each message's "content" is a single chunk dict (not a
        # list), and "role" is passed through from MCP unchanged. Confirm this
        # matches what consumers of MCPSystemPrompt expect.
        return {
            "description": prompt_result.description,
            "messages": [
                typing.cast(
                    Union[SystemMessageTypedDict, AssistantMessageTypedDict],
                    {
                        "role": message.role,
                        "content": self._convert_content(mcp_content=message.content),
                    },
                )
                for message in prompt_result.messages
            ],
        }

    async def list_system_prompts(self) -> ListPromptsResult:
        """Return the server's prompt listing unchanged."""
        return await self._session.list_prompts()

    async def initialize(self, exit_stack: Optional[AsyncExitStack] = None) -> None:
        """Initialize the MCP session; a no-op if already initialized.

        Args:
            exit_stack: Stack on which to register the transport and session.
                When omitted, the client creates and owns its own stack, which
                ``aclose`` unwinds. When given, the caller must close it.
        """
        # client is already initialized so return
        if self._is_initialized:
            return
        owns_stack = exit_stack is None
        if exit_stack is None:
            exit_stack = AsyncExitStack()
            self._exit_stack = exit_stack
        try:
            transport = await self._get_transport(exit_stack=exit_stack)
            mcp_session = await exit_stack.enter_async_context(
                ClientSession(
                    read_stream=transport[0],
                    write_stream=transport[1],
                )
            )
            await mcp_session.initialize()
        except Exception:
            # Do not leak a half-opened transport on a stack nobody else can close.
            if owns_stack:
                self._exit_stack = None
                await exit_stack.aclose()
            raise
        self._session = mcp_session
        self._is_initialized = True

    async def aclose(self) -> None:
        """Close the MCP session if this client owns its exit stack.

        Resets the initialized state so ``initialize`` can reconnect afterwards.
        When ``initialize`` received an external exit stack, this is a no-op.
        """
        if self._exit_stack is None:
            return
        exit_stack, self._exit_stack = self._exit_stack, None
        self._is_initialized = False
        await exit_stack.aclose()

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} name={self._name!r} id=0x{id(self):x}>"

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(name={self._name})"

    async def _get_transport(self, exit_stack: AsyncExitStack):
        """Return the ``(read_stream, write_stream)`` pair; subclasses must override."""
        raise NotImplementedError
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
import http
|
|
2
|
+
import logging
|
|
3
|
+
import typing
|
|
4
|
+
from typing import Any, Optional
|
|
5
|
+
from contextlib import AsyncExitStack
|
|
6
|
+
from functools import cached_property
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
|
|
10
|
+
from mistralai.extra.exceptions import MCPAuthException
|
|
11
|
+
from mistralai.extra.mcp.base import (
|
|
12
|
+
MCPClientBase,
|
|
13
|
+
)
|
|
14
|
+
from mistralai.extra.mcp.auth import OAuthParams, AsyncOAuth2Client
|
|
15
|
+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
|
16
|
+
|
|
17
|
+
from mcp.client.sse import sse_client
|
|
18
|
+
from mcp.shared.message import SessionMessage
|
|
19
|
+
from authlib.oauth2.rfc6749 import OAuth2Token
|
|
20
|
+
|
|
21
|
+
from mistralai.types import BaseModel
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SSEServerParams(BaseModel):
    """Connection settings for an MCP server reached over SSE."""

    # URL of the server's SSE endpoint.
    url: str
    # Extra HTTP headers sent with each request; an auth header, when present, overrides these.
    headers: Optional[dict[str, Any]] = None
    # Timeout passed to the HTTP requests made against the server.
    timeout: float = 5
    # Read timeout passed to ``sse_client`` for the event stream.
    sse_read_timeout: float = 60 * 5
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MCPClientSSE(MCPClientBase):
    """MCP client that uses sse for communication.

    The client provides authentication for OAuth2 protocol following the current MCP authorization spec:
    https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization.

    This is possibly going to change in the future since the protocol has ongoing discussions.
    """

    _oauth_params: Optional[OAuthParams]
    _sse_params: SSEServerParams

    def __init__(
        self,
        sse_params: SSEServerParams,
        name: Optional[str] = None,
        oauth_params: Optional[OAuthParams] = None,
        auth_token: Optional[OAuth2Token] = None,
    ):
        super().__init__(name=name)
        self._sse_params = sse_params
        self._oauth_params: Optional[OAuthParams] = oauth_params
        self._auth_token: Optional[OAuth2Token] = auth_token

    @cached_property
    def base_url(self) -> str:
        """Server URL with trailing slashes and a trailing ``/sse`` segment removed.

        ``str.rstrip("/sse")`` would strip any trailing run of the characters
        ``/``, ``s`` and ``e`` (e.g. ``.../messages`` -> ``.../messag``), so the
        suffix is removed explicitly instead.
        """
        url = self._sse_params.url.rstrip("/")
        sse_suffix = "/sse"
        if url.endswith(sse_suffix):
            url = url[: -len(sse_suffix)]
        return url

    def set_oauth_params(self, oauth_params: OAuthParams):
        """Replace the oauth params used by subsequent authorization calls."""
        if self._oauth_params is not None:
            logger.warning("Overriding current oauth params for %s", self._name)
        self._oauth_params = oauth_params

    async def get_auth_url_and_state(self, redirect_url: str) -> tuple[str, str]:
        """Create the authorization url for client to start oauth flow.

        Returns:
            ``(auth_url, state)``; *state* must be passed back to
            ``get_token_from_auth_response``.

        Raises:
            MCPAuthException: if no oauth params have been set.
        """
        if self._oauth_params is None:
            raise MCPAuthException(
                "Can't generate an authorization url without oauth_params being set, "
                "make sure the oauth params have been set."
            )
        # The OAuth client is an httpx.AsyncClient subclass: close it when done.
        async with AsyncOAuth2Client.from_oauth_params(self._oauth_params) as oauth_client:
            auth_url, state = oauth_client.create_authorization_url(
                self._oauth_params.scheme.authorization_url, redirect_uri=redirect_url
            )
        return auth_url, state

    async def get_token_from_auth_response(
        self,
        authorization_response: str,
        redirect_url: str,
        state: str,
    ) -> OAuth2Token:
        """Fetch the authentication token from the server.

        Raises:
            MCPAuthException: if no oauth params have been set.
        """
        if self._oauth_params is None:
            raise MCPAuthException(
                "Can't fetch a token without oauth_params, make sure they have been set."
            )
        async with AsyncOAuth2Client.from_oauth_params(self._oauth_params) as oauth_client:
            oauth_token = await oauth_client.fetch_token(
                url=self._oauth_params.scheme.token_url,
                authorization_response=authorization_response,
                redirect_uri=redirect_url,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                state=state,
            )
        return oauth_token

    async def refresh_auth_token(self):
        """Refresh an expired token and store the refreshed one on this client.

        Raises:
            MCPAuthException: if there is no refresh url, no current token, or
                the current token carries no ``refresh_token``.
        """
        if self._oauth_params is None or self._oauth_params.scheme.refresh_url is None:
            raise MCPAuthException(
                "Can't refresh a token without a refresh url make sure the oauth params have been set."
            )
        if self._auth_token is None:
            raise MCPAuthException(
                "Can't refresh a token without a refresh token, use the `set_auth_token` to add a OAuth2Token."
            )
        # OAuth2Token is a dict; surface a missing refresh token as an auth
        # error rather than a bare KeyError.
        refresh_token = self._auth_token.get("refresh_token")
        if not refresh_token:
            raise MCPAuthException(
                "Can't refresh a token without a refresh token, the current OAuth2Token has no refresh_token."
            )
        async with AsyncOAuth2Client.from_oauth_params(self._oauth_params) as oauth_client:
            oauth_token = await oauth_client.refresh_token(
                url=self._oauth_params.scheme.refresh_url,
                refresh_token=refresh_token,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
        self.set_auth_token(oauth_token)

    def set_auth_token(self, token: OAuth2Token) -> None:
        """Register the authentication token with this client."""
        self._auth_token = token

    def _format_headers(self) -> dict[str, str]:
        """Merge the configured headers with a Bearer ``Authorization`` header, if a token is set."""
        headers: dict[str, str] = {}
        if self._sse_params.headers:
            headers |= self._sse_params.headers
        if self._auth_token:
            headers["Authorization"] = f"Bearer {self._auth_token['access_token']}"
        return headers

    async def requires_auth(self) -> bool:
        """Check if the client requires authentication to communicate with the server.

        Only the response status is inspected. The response is streamed and its
        body never read, because an accessible SSE endpoint keeps the body open.
        The previous blocking ``httpx.get`` read the whole body and stalled the
        event loop until the read timeout fired.
        """
        async with httpx.AsyncClient(timeout=self._sse_params.timeout) as client:
            async with client.stream(
                "GET", self._sse_params.url, headers=self._format_headers()
            ) as response:
                return response.status_code == http.HTTPStatus.UNAUTHORIZED

    async def _get_transport(
        self, exit_stack: AsyncExitStack
    ) -> tuple[
        MemoryObjectReceiveStream[typing.Union[SessionMessage, Exception]],
        MemoryObjectSendStream[SessionMessage],
    ]:
        """Open the SSE transport on *exit_stack* and return its (read, write) streams.

        Raises:
            MCPAuthException: if the server answers 401 Unauthorized.
        """
        try:
            return await exit_stack.enter_async_context(
                sse_client(
                    url=self._sse_params.url,
                    headers=self._format_headers(),
                    timeout=self._sse_params.timeout,
                    sse_read_timeout=self._sse_params.sse_read_timeout,
                )
            )
        except httpx.HTTPStatusError as e:
            # NOTE(review): depending on mcp/anyio versions, the status error may
            # arrive wrapped in an exception group and bypass this translation.
            if e.response.status_code == http.HTTPStatus.UNAUTHORIZED:
                if self._oauth_params is None:
                    raise MCPAuthException(
                        "Authentication required but no auth params provided."
                    ) from e
                raise MCPAuthException("Authentication required.") from e
            raise
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
import logging
|
|
3
|
+
from contextlib import AsyncExitStack
|
|
4
|
+
|
|
5
|
+
from mistralai.extra.mcp.base import (
|
|
6
|
+
MCPClientBase,
|
|
7
|
+
)
|
|
8
|
+
|
|
9
|
+
from mcp import stdio_client, StdioServerParameters
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MCPClientSTDIO(MCPClientBase):
    """MCP client talking to a server process over its standard input/output."""

    def __init__(self, stdio_params: StdioServerParameters, name: Optional[str] = None):
        super().__init__(name=name)
        # Launch parameters for the server, consumed by ``stdio_client``.
        self._stdio_params = stdio_params

    async def _get_transport(self, exit_stack: AsyncExitStack):
        """Enter ``stdio_client`` on *exit_stack* and return its (read, write) streams."""
        transport_cm = stdio_client(self._stdio_params)
        return await exit_stack.enter_async_context(transport_cm)
|
|
File without changes
|