mistralai 1.7.1__py3-none-any.whl → 1.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91) hide show
  1. mistralai/_version.py +2 -2
  2. mistralai/beta.py +22 -0
  3. mistralai/conversations.py +2660 -0
  4. mistralai/embeddings.py +12 -0
  5. mistralai/extra/__init__.py +10 -2
  6. mistralai/extra/exceptions.py +14 -0
  7. mistralai/extra/mcp/__init__.py +0 -0
  8. mistralai/extra/mcp/auth.py +166 -0
  9. mistralai/extra/mcp/base.py +155 -0
  10. mistralai/extra/mcp/sse.py +165 -0
  11. mistralai/extra/mcp/stdio.py +22 -0
  12. mistralai/extra/run/__init__.py +0 -0
  13. mistralai/extra/run/context.py +295 -0
  14. mistralai/extra/run/result.py +212 -0
  15. mistralai/extra/run/tools.py +225 -0
  16. mistralai/extra/run/utils.py +36 -0
  17. mistralai/extra/tests/test_struct_chat.py +1 -1
  18. mistralai/mistral_agents.py +1160 -0
  19. mistralai/models/__init__.py +472 -1
  20. mistralai/models/agent.py +129 -0
  21. mistralai/models/agentconversation.py +71 -0
  22. mistralai/models/agentcreationrequest.py +109 -0
  23. mistralai/models/agenthandoffdoneevent.py +33 -0
  24. mistralai/models/agenthandoffentry.py +75 -0
  25. mistralai/models/agenthandoffstartedevent.py +33 -0
  26. mistralai/models/agents_api_v1_agents_getop.py +16 -0
  27. mistralai/models/agents_api_v1_agents_listop.py +24 -0
  28. mistralai/models/agents_api_v1_agents_update_versionop.py +21 -0
  29. mistralai/models/agents_api_v1_agents_updateop.py +23 -0
  30. mistralai/models/agents_api_v1_conversations_append_streamop.py +28 -0
  31. mistralai/models/agents_api_v1_conversations_appendop.py +28 -0
  32. mistralai/models/agents_api_v1_conversations_getop.py +33 -0
  33. mistralai/models/agents_api_v1_conversations_historyop.py +16 -0
  34. mistralai/models/agents_api_v1_conversations_listop.py +37 -0
  35. mistralai/models/agents_api_v1_conversations_messagesop.py +16 -0
  36. mistralai/models/agents_api_v1_conversations_restart_streamop.py +26 -0
  37. mistralai/models/agents_api_v1_conversations_restartop.py +26 -0
  38. mistralai/models/agentupdaterequest.py +111 -0
  39. mistralai/models/builtinconnectors.py +13 -0
  40. mistralai/models/chatcompletionresponse.py +6 -6
  41. mistralai/models/codeinterpretertool.py +17 -0
  42. mistralai/models/completionargs.py +100 -0
  43. mistralai/models/completionargsstop.py +13 -0
  44. mistralai/models/completionjobout.py +3 -3
  45. mistralai/models/conversationappendrequest.py +35 -0
  46. mistralai/models/conversationappendstreamrequest.py +37 -0
  47. mistralai/models/conversationevents.py +72 -0
  48. mistralai/models/conversationhistory.py +58 -0
  49. mistralai/models/conversationinputs.py +14 -0
  50. mistralai/models/conversationmessages.py +28 -0
  51. mistralai/models/conversationrequest.py +133 -0
  52. mistralai/models/conversationresponse.py +51 -0
  53. mistralai/models/conversationrestartrequest.py +42 -0
  54. mistralai/models/conversationrestartstreamrequest.py +44 -0
  55. mistralai/models/conversationstreamrequest.py +135 -0
  56. mistralai/models/conversationusageinfo.py +63 -0
  57. mistralai/models/documentlibrarytool.py +22 -0
  58. mistralai/models/embeddingdtype.py +7 -0
  59. mistralai/models/embeddingrequest.py +43 -3
  60. mistralai/models/fimcompletionresponse.py +6 -6
  61. mistralai/models/functioncallentry.py +76 -0
  62. mistralai/models/functioncallentryarguments.py +15 -0
  63. mistralai/models/functioncallevent.py +36 -0
  64. mistralai/models/functionresultentry.py +69 -0
  65. mistralai/models/functiontool.py +21 -0
  66. mistralai/models/imagegenerationtool.py +17 -0
  67. mistralai/models/inputentries.py +18 -0
  68. mistralai/models/messageentries.py +18 -0
  69. mistralai/models/messageinputcontentchunks.py +26 -0
  70. mistralai/models/messageinputentry.py +89 -0
  71. mistralai/models/messageoutputcontentchunks.py +30 -0
  72. mistralai/models/messageoutputentry.py +100 -0
  73. mistralai/models/messageoutputevent.py +93 -0
  74. mistralai/models/modelconversation.py +127 -0
  75. mistralai/models/outputcontentchunks.py +30 -0
  76. mistralai/models/responsedoneevent.py +25 -0
  77. mistralai/models/responseerrorevent.py +27 -0
  78. mistralai/models/responsestartedevent.py +24 -0
  79. mistralai/models/ssetypes.py +18 -0
  80. mistralai/models/toolexecutiondoneevent.py +34 -0
  81. mistralai/models/toolexecutionentry.py +70 -0
  82. mistralai/models/toolexecutionstartedevent.py +31 -0
  83. mistralai/models/toolfilechunk.py +61 -0
  84. mistralai/models/toolreferencechunk.py +61 -0
  85. mistralai/models/websearchpremiumtool.py +17 -0
  86. mistralai/models/websearchtool.py +17 -0
  87. mistralai/sdk.py +3 -0
  88. {mistralai-1.7.1.dist-info → mistralai-1.8.1.dist-info}/METADATA +42 -7
  89. {mistralai-1.7.1.dist-info → mistralai-1.8.1.dist-info}/RECORD +91 -14
  90. {mistralai-1.7.1.dist-info → mistralai-1.8.1.dist-info}/LICENSE +0 -0
  91. {mistralai-1.7.1.dist-info → mistralai-1.8.1.dist-info}/WHEEL +0 -0
mistralai/embeddings.py CHANGED
@@ -18,6 +18,8 @@ class Embeddings(BaseSDK):
18
18
  inputs: Union[
19
19
  models.EmbeddingRequestInputs, models.EmbeddingRequestInputsTypedDict
20
20
  ],
21
+ output_dimension: OptionalNullable[int] = UNSET,
22
+ output_dtype: Optional[models.EmbeddingDtype] = None,
21
23
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
22
24
  server_url: Optional[str] = None,
23
25
  timeout_ms: Optional[int] = None,
@@ -29,6 +31,8 @@ class Embeddings(BaseSDK):
29
31
 
30
32
  :param model: ID of the model to use.
31
33
  :param inputs: Text to embed.
34
+ :param output_dimension: The dimension of the output embeddings.
35
+ :param output_dtype:
32
36
  :param retries: Override the default retry configuration for this method
33
37
  :param server_url: Override the default server URL for this method
34
38
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -47,6 +51,8 @@ class Embeddings(BaseSDK):
47
51
  request = models.EmbeddingRequest(
48
52
  model=model,
49
53
  inputs=inputs,
54
+ output_dimension=output_dimension,
55
+ output_dtype=output_dtype,
50
56
  )
51
57
 
52
58
  req = self._build_request(
@@ -125,6 +131,8 @@ class Embeddings(BaseSDK):
125
131
  inputs: Union[
126
132
  models.EmbeddingRequestInputs, models.EmbeddingRequestInputsTypedDict
127
133
  ],
134
+ output_dimension: OptionalNullable[int] = UNSET,
135
+ output_dtype: Optional[models.EmbeddingDtype] = None,
128
136
  retries: OptionalNullable[utils.RetryConfig] = UNSET,
129
137
  server_url: Optional[str] = None,
130
138
  timeout_ms: Optional[int] = None,
@@ -136,6 +144,8 @@ class Embeddings(BaseSDK):
136
144
 
137
145
  :param model: ID of the model to use.
138
146
  :param inputs: Text to embed.
147
+ :param output_dimension: The dimension of the output embeddings.
148
+ :param output_dtype:
139
149
  :param retries: Override the default retry configuration for this method
140
150
  :param server_url: Override the default server URL for this method
141
151
  :param timeout_ms: Override the default request timeout configuration for this method in milliseconds
@@ -154,6 +164,8 @@ class Embeddings(BaseSDK):
154
164
  request = models.EmbeddingRequest(
155
165
  model=model,
156
166
  inputs=inputs,
167
+ output_dimension=output_dimension,
168
+ output_dtype=output_dtype,
157
169
  )
158
170
 
159
171
  req = self._build_request_async(
@@ -1,5 +1,13 @@
1
- from .struct_chat import ParsedChatCompletionResponse, convert_to_parsed_chat_completion_response
1
+ from .struct_chat import (
2
+ ParsedChatCompletionResponse,
3
+ convert_to_parsed_chat_completion_response,
4
+ )
2
5
  from .utils import response_format_from_pydantic_model
3
6
  from .utils.response_format import CustomPydanticModel
4
7
 
5
- __all__ = ["convert_to_parsed_chat_completion_response", "response_format_from_pydantic_model", "CustomPydanticModel", "ParsedChatCompletionResponse"]
8
+ __all__ = [
9
+ "convert_to_parsed_chat_completion_response",
10
+ "response_format_from_pydantic_model",
11
+ "CustomPydanticModel",
12
+ "ParsedChatCompletionResponse",
13
+ ]
@@ -0,0 +1,14 @@
1
class MistralClientException(Exception):
    """Common base class for every error raised by the client."""
3
+
4
+
5
class RunException(MistralClientException):
    """Error raised when something goes wrong during a conversation run."""
7
+
8
+
9
class MCPException(MistralClientException):
    """Error raised when an MCP-related operation fails."""
11
+
12
+
13
class MCPAuthException(MCPException):
    """Error raised when authenticating against an MCP server fails."""
File without changes
@@ -0,0 +1,166 @@
1
+ from typing import Optional
2
+
3
+ from authlib.oauth2.rfc8414 import AuthorizationServerMetadata
4
+ from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncOAuth2ClientBase
5
+ import httpx
6
+ import logging
7
+
8
+ from mistralai.types import BaseModel
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
class Oauth2AuthorizationScheme(BaseModel):
    """Describe the OAuth2 flow to run against an authorization server."""

    # Endpoint the user is sent to in order to grant access.
    authorization_url: str
    # Endpoint used to exchange the authorization response for a token.
    token_url: str
    # Scopes requested during the flow.
    scope: list[str]
    # Optional free-form description of the scheme.
    description: Optional[str] = None
    # Endpoint used to refresh a token; None when no refresh url is known.
    refresh_url: Optional[str] = None
21
+
22
+
23
class OAuthParams(BaseModel):
    """Bundle the values needed to authorize: the scheme plus client credentials."""

    # Endpoints and scopes of the OAuth2 flow.
    scheme: Oauth2AuthorizationScheme
    # Credentials identifying this client to the authorization server.
    client_id: str
    client_secret: str
29
+
30
+
31
class AsyncOAuth2Client(AsyncOAuth2ClientBase):
    """Async httpx OAuth2 client with an alternate constructor taking ``OAuthParams``."""

    @classmethod
    def from_oauth_params(cls, oauth_params: OAuthParams) -> "AsyncOAuth2Client":
        """Create a client from the id, secret and scope held by ``oauth_params``."""
        init_kwargs = {
            "client_id": oauth_params.client_id,
            "client_secret": oauth_params.client_secret,
            "scope": oauth_params.scheme.scope,
        }
        return cls(**init_kwargs)
41
+
42
+
43
async def get_well_known_authorization_server_metadata(
    server_url: str,
) -> Optional[AuthorizationServerMetadata]:
    """Fetch the authorization server metadata from the well-known location.

    This should be available on MCP servers as described by the specification:
    https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#2-3-server-metadata-discovery.

    :param server_url: Base url of the server; the well-known path is appended to it.
    :return: The validated metadata, or ``None`` when the server answers with a
        non-2xx status or the payload cannot be parsed/validated.
    """
    well_known_url = f"{server_url}/.well-known/oauth-authorization-server"
    # Scope the client to this single request so its connection pool is always
    # closed, instead of leaking an un-closed AsyncClient.
    async with httpx.AsyncClient() as client:
        response = await client.get(well_known_url)

    if not 200 <= response.status_code < 300:
        logger.error(f"Failed to get oauth well-known metadata from {server_url}")
        return None

    try:
        server_metadata = AuthorizationServerMetadata(**response.json())
        server_metadata.validate()
    except ValueError:
        logger.exception("Failed to parse oauth well-known metadata")
        return None
    return server_metadata
64
+
65
+
66
async def get_oauth_server_metadata(server_url: str) -> AuthorizationServerMetadata:
    """Fetch the metadata from the authorization server to perform the oauth flow.

    :param server_url: Base url of the server to discover metadata for.
    :return: The discovered metadata when the well-known endpoint provides it,
        otherwise metadata built from the default endpoints under ``server_url``.
    """
    # 1) attempt to get the metadata from the resource server at /.well-known/oauth-protected-resource
    # TODO: new self-discovery protocol, not released yet

    # 2) attempt to get the metadata from the authorization server at /.well-known/oauth-authorization-server
    metadata = await get_well_known_authorization_server_metadata(server_url=server_url)
    if metadata is not None:
        return metadata

    # 3) fallback on default endpoints
    # https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization#2-3-3-fallbacks-for-servers-without-metadata-discovery
    return AuthorizationServerMetadata(
        issuer=server_url,
        authorization_endpoint=f"{server_url}/authorize",
        token_endpoint=f"{server_url}/token",
        # The metadata key is "registration_endpoint" (RFC 8414), which is also
        # the attribute read when registering the client dynamically.
        registration_endpoint=f"{server_url}/register",
        response_types_supported=["code"],
        response_modes_supported=["query"],
        grant_types_supported=["authorization_code", "refresh_token"],
        token_endpoint_auth_methods_supported=["client_secret_basic"],
        code_challenge_methods_supported=["S256", "plain"],
    )
89
+
90
+
91
async def dynamic_client_registration(
    register_endpoint: str,
    redirect_url: str,
    async_client: httpx.AsyncClient,
) -> tuple[str, str]:
    """Attempt a dynamic client registration against an MCP server.

    :param register_endpoint: Url the registration request is posted to.
    :param redirect_url: Redirect uri declared for the registered client.
    :param async_client: HTTP client used to send the request.
    :return: The ``(client_id, client_secret)`` pair from the server response.
    :raises ValueError: If the request fails or the response lacks the credentials.
    """
    payload = {
        "client_name": "MistralSDKClient",
        "grant_types": ["authorization_code", "refresh_token"],
        "token_endpoint_auth_method": "client_secret_basic",
        "response_types": ["code"],
        "redirect_uris": [redirect_url],
    }

    response = await async_client.post(register_endpoint, json=payload)
    try:
        response.raise_for_status()
        info = response.json()
        credentials = (info["client_id"], info["client_secret"])
    except Exception as e:
        # Any failure (HTTP status, decoding, missing key) is reported uniformly.
        raise ValueError(
            f"Client registration failed: status={response.status_code}, error={response.text}"
        ) from e
    return credentials
121
+
122
+
123
async def build_oauth_params(
    server_url: str,
    redirect_url: str,
    client_id: Optional[str] = None,
    client_secret: Optional[str] = None,
    scope: Optional[list[str]] = None,
    async_client: Optional[httpx.AsyncClient] = None,
) -> OAuthParams:
    """Get issuer metadata and build the oauth required params.

    :param server_url: Base url of the server whose metadata is fetched.
    :param redirect_url: Redirect uri used if the client has to be registered.
    :param client_id: Existing client id; used only together with ``client_secret``.
    :param client_secret: Existing client secret; used only together with ``client_id``.
    :param scope: Scopes to request; defaults to an empty list.
    :param async_client: Optional HTTP client for the registration request; a
        temporary one is created (and closed) when omitted.
    :return: The oauth params, with either the given or freshly registered credentials.
    :raises ValueError: If registration is needed but the metadata exposes no
        registration endpoint, or if the registration itself fails.
    """
    metadata = await get_oauth_server_metadata(server_url=server_url)

    # "grant_types_supported" is optional in discovered metadata; treat a
    # missing value as "refresh not advertised" instead of failing on `in None`.
    supported_grants = metadata.grant_types_supported or []
    oauth_scheme = Oauth2AuthorizationScheme(
        authorization_url=metadata.authorization_endpoint,
        token_url=metadata.token_endpoint,
        scope=scope or [],
        refresh_url=metadata.token_endpoint
        if "refresh_token" in supported_grants
        else None,
    )
    if client_id and client_secret:
        return OAuthParams(
            client_id=client_id,
            client_secret=client_secret,
            scheme=oauth_scheme,
        )

    # Try to dynamically register the client
    register_endpoint = metadata.registration_endpoint
    if not register_endpoint:
        raise ValueError(
            "Client registration failed: the authorization server metadata "
            "does not provide a registration endpoint."
        )

    if async_client is not None:
        reg_client_id, reg_client_secret = await dynamic_client_registration(
            register_endpoint=register_endpoint,
            redirect_url=redirect_url,
            async_client=async_client,
        )
    else:
        # No client supplied: use a short-lived one that is closed afterwards.
        async with httpx.AsyncClient() as owned_client:
            reg_client_id, reg_client_secret = await dynamic_client_registration(
                register_endpoint=register_endpoint,
                redirect_url=redirect_url,
                async_client=owned_client,
            )
    return OAuthParams(
        client_id=reg_client_id,
        client_secret=reg_client_secret,
        scheme=oauth_scheme,
    )
@@ -0,0 +1,155 @@
1
+ from typing import Optional, Union
2
+ import logging
3
+ import typing
4
+ from contextlib import AsyncExitStack
5
+ from typing import Protocol, Any
6
+
7
+ from mcp import ClientSession
8
+ from mcp.types import ListPromptsResult, EmbeddedResource, ImageContent, TextContent
9
+
10
+ from mistralai.extra.exceptions import MCPException
11
+ from mistralai.models import (
12
+ FunctionTool,
13
+ Function,
14
+ SystemMessageTypedDict,
15
+ AssistantMessageTypedDict,
16
+ TextChunkTypedDict,
17
+ )
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class MCPSystemPrompt(typing.TypedDict):
    """Typed dict for an MCP prompt: an optional description and its messages."""

    # Prompt description; may be None.
    description: Optional[str]
    # Prompt messages typed as system/assistant message dicts.
    messages: list[Union[SystemMessageTypedDict, AssistantMessageTypedDict]]
25
+
26
+
27
class MCPClientProtocol(Protocol):
    """Interface of an MCP client that exposes MCP artifacts in Mistral format."""

    _name: str

    async def initialize(self, exit_stack: Optional[AsyncExitStack]) -> None:
        """Initialize the client, optionally using the given exit stack."""

    async def aclose(self) -> None:
        """Close the client."""

    async def get_tools(self) -> list[FunctionTool]:
        """Return the available tools as ``FunctionTool`` objects."""

    async def execute_tool(
        self, name: str, arguments: dict
    ) -> list[TextChunkTypedDict]:
        """Run the tool ``name`` with ``arguments`` and return text chunks."""

    async def get_system_prompt(
        self, name: str, arguments: dict[str, Any]
    ) -> MCPSystemPrompt:
        """Return the prompt ``name`` rendered with ``arguments``."""

    async def list_system_prompts(self) -> ListPromptsResult:
        """Return the listing of available prompts."""
53
+
54
+
55
class MCPClientBase(MCPClientProtocol):
    """Base class to implement functionalities from an initialized MCP session.

    Subclasses provide the transport through ``_get_transport``; this class
    opens the ``ClientSession`` on top of it and converts MCP results.
    """

    _session: ClientSession

    def __init__(self, name: Optional[str] = None):
        """Store the client name (defaults to the class name) and reset state."""
        self._name = name or self.__class__.__name__
        # Only set when this client creates (and therefore owns) the exit stack.
        self._exit_stack: Optional[AsyncExitStack] = None
        self._is_initialized = False

    def _convert_content(
        self, mcp_content: Union[TextContent, ImageContent, EmbeddedResource]
    ) -> TextChunkTypedDict:
        """Convert a single MCP content item to a text chunk.

        :raises MCPException: If the content is not of type ``"text"``.
        """
        if mcp_content.type != "text":
            raise MCPException("Only supporting text tool responses for now.")
        return {"type": "text", "text": mcp_content.text}

    def _convert_content_list(
        self, mcp_contents: list[Union[TextContent, ImageContent, EmbeddedResource]]
    ) -> list[TextChunkTypedDict]:
        """Convert every MCP content item of the list to a text chunk."""
        return [self._convert_content(mcp_content) for mcp_content in mcp_contents]

    async def get_tools(self) -> list[FunctionTool]:
        """List the session's tools and wrap each one as a strict ``FunctionTool``."""
        mcp_tools = await self._session.list_tools()
        return [
            FunctionTool(
                type="function",
                function=Function(
                    name=mcp_tool.name,
                    description=mcp_tool.description,
                    parameters=mcp_tool.inputSchema,
                    strict=True,
                ),
            )
            for mcp_tool in mcp_tools.tools
        ]

    async def execute_tool(
        self, name: str, arguments: dict[str, Any]
    ) -> list[TextChunkTypedDict]:
        """Call the tool ``name`` with ``arguments`` and return its content as text chunks."""
        contents = await self._session.call_tool(name=name, arguments=arguments)
        return self._convert_content_list(contents.content)

    async def get_system_prompt(
        self, name: str, arguments: dict[str, Any]
    ) -> MCPSystemPrompt:
        """Fetch the prompt ``name`` and convert its messages to role/content dicts."""
        prompt_result = await self._session.get_prompt(name=name, arguments=arguments)
        return {
            "description": prompt_result.description,
            "messages": [
                typing.cast(
                    Union[SystemMessageTypedDict, AssistantMessageTypedDict],
                    {
                        "role": message.role,
                        "content": self._convert_content(mcp_content=message.content),
                    },
                )
                for message in prompt_result.messages
            ],
        }

    async def list_system_prompts(self) -> ListPromptsResult:
        """Return the prompts listed by the session."""
        return await self._session.list_prompts()

    async def initialize(self, exit_stack: Optional[AsyncExitStack] = None) -> None:
        """Initialize the MCP session.

        :param exit_stack: Stack that will own the transport and session
            contexts. When omitted, the client creates its own stack, which is
            released by ``aclose``.
        """
        # client is already initialized so return
        if self._is_initialized:
            return
        if exit_stack is None:
            self._exit_stack = AsyncExitStack()
            exit_stack = self._exit_stack
        # The transport comes from the subclass; its first two items are the
        # read and write streams handed to the session.
        transport = await self._get_transport(exit_stack=exit_stack)
        mcp_session = await exit_stack.enter_async_context(
            ClientSession(
                read_stream=transport[0],
                write_stream=transport[1],
            )
        )
        await mcp_session.initialize()
        self._session = mcp_session
        self._is_initialized = True

    async def aclose(self):
        """Close the MCP session.

        Only a stack created by this client is closed. The state is reset
        afterwards so that ``initialize`` can open a fresh session instead of
        returning early with a closed one.
        """
        if self._exit_stack is None:
            return
        try:
            await self._exit_stack.aclose()
        finally:
            self._exit_stack = None
            self._is_initialized = False

    def __repr__(self):
        return f"<{self.__class__.__name__} name={self._name!r} id=0x{id(self):x}>"

    def __str__(self):
        return f"{self.__class__.__name__}(name={self._name})"

    async def _get_transport(self, exit_stack: AsyncExitStack):
        """Return the transport streams; must be implemented by subclasses."""
        raise NotImplementedError
@@ -0,0 +1,165 @@
1
+ import http
2
+ import logging
3
+ import typing
4
+ from typing import Any, Optional
5
+ from contextlib import AsyncExitStack
6
+ from functools import cached_property
7
+
8
+ import httpx
9
+
10
+ from mistralai.extra.exceptions import MCPAuthException
11
+ from mistralai.extra.mcp.base import (
12
+ MCPClientBase,
13
+ )
14
+ from mistralai.extra.mcp.auth import OAuthParams, AsyncOAuth2Client
15
+ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
16
+
17
+ from mcp.client.sse import sse_client
18
+ from mcp.shared.message import SessionMessage
19
+ from authlib.oauth2.rfc6749 import OAuth2Token
20
+
21
+ from mistralai.types import BaseModel
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class SSEServerParams(BaseModel):
    """Connection settings for an MCP client that uses the SSE transport."""

    # Url of the SSE endpoint.
    url: str
    # Extra headers to send with requests, if any.
    headers: Optional[dict[str, Any]] = None
    # Timeout value passed to the HTTP calls and to the SSE client.
    timeout: float = 5
    # Read timeout value passed to the SSE client.
    sse_read_timeout: float = 5 * 60
33
+
34
+
35
class MCPClientSSE(MCPClientBase):
    """MCP client that uses sse for communication.

    The client provides authentication for OAuth2 protocol following the current MCP authorization spec:
    https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization.

    This is possibly going to change in the future since the protocol has ongoing discussions.
    """

    _oauth_params: Optional[OAuthParams]
    _sse_params: SSEServerParams

    def __init__(
        self,
        sse_params: SSEServerParams,
        name: Optional[str] = None,
        oauth_params: Optional[OAuthParams] = None,
        auth_token: Optional[OAuth2Token] = None,
    ):
        """Create the client.

        :param sse_params: Url, headers and timeouts of the SSE endpoint.
        :param name: Optional client name, forwarded to the base class.
        :param oauth_params: Optional oauth params used by the auth helpers.
        :param auth_token: Optional token sent as a bearer ``Authorization`` header.
        """
        super().__init__(name=name)
        self._sse_params = sse_params
        self._oauth_params: Optional[OAuthParams] = oauth_params
        self._auth_token: Optional[OAuth2Token] = auth_token

    @cached_property
    def base_url(self) -> str:
        """Return the server url without its trailing ``/sse`` path segment."""
        # str.rstrip("/sse") strips any trailing run of the characters '/', 's'
        # and 'e' (e.g. ".../services/sse" -> ".../servic"); remove the exact
        # suffix instead, tolerating a trailing slash.
        return self._sse_params.url.rstrip("/").removesuffix("/sse")

    def set_oauth_params(self, oauth_params: OAuthParams):
        """Update the oauth params and client accordingly."""
        if self._oauth_params is not None:
            logger.warning(f"Overriding current oauth params for {self._name}")
        self._oauth_params = oauth_params

    async def get_auth_url_and_state(self, redirect_url: str) -> tuple[str, str]:
        """Create the authorization url for client to start oauth flow.

        :raises MCPAuthException: If no oauth params have been set.
        """
        if self._oauth_params is None:
            raise MCPAuthException(
                "Can't generate an authorization url without oauth_params being set, "
                "make sure the oauth params have been set."
            )
        oauth_client = AsyncOAuth2Client.from_oauth_params(self._oauth_params)
        auth_url, state = oauth_client.create_authorization_url(
            self._oauth_params.scheme.authorization_url, redirect_uri=redirect_url
        )
        return auth_url, state

    async def get_token_from_auth_response(
        self,
        authorization_response: str,
        redirect_url: str,
        state: str,
    ) -> OAuth2Token:
        """Fetch the authentication token from the server.

        :raises MCPAuthException: If no oauth params have been set.
        """
        if self._oauth_params is None:
            raise MCPAuthException(
                "Can't fetch a token without oauth_params, make sure they have been set."
            )
        oauth_client = AsyncOAuth2Client.from_oauth_params(self._oauth_params)
        oauth_token = await oauth_client.fetch_token(
            url=self._oauth_params.scheme.token_url,
            authorization_response=authorization_response,
            redirect_uri=redirect_url,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            state=state,
        )
        return oauth_token

    async def refresh_auth_token(self):
        """Refresh an expired token and register the new one with this client.

        :raises MCPAuthException: If the oauth params, their refresh url or the
            current token are missing.
        """
        if self._oauth_params is None or self._oauth_params.scheme.refresh_url is None:
            raise MCPAuthException(
                "Can't refresh a token without a refresh url make sure the oauth params have been set."
            )
        if self._auth_token is None:
            raise MCPAuthException(
                "Can't refresh a token without a refresh token, use the `set_auth_token` to add a OAuth2Token."
            )
        oauth_client = AsyncOAuth2Client.from_oauth_params(self._oauth_params)
        oauth_token = await oauth_client.refresh_token(
            url=self._oauth_params.scheme.refresh_url,
            refresh_token=self._auth_token["refresh_token"],
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        self.set_auth_token(oauth_token)

    def set_auth_token(self, token: OAuth2Token) -> None:
        """Register the authentication token with this client."""
        self._auth_token = token

    def _format_headers(self) -> dict[str, str]:
        """Return a new dict of the configured headers plus the bearer token, if any."""
        headers: dict[str, str] = {}
        if self._sse_params.headers:
            headers.update(self._sse_params.headers)
        if self._auth_token:
            headers["Authorization"] = f"Bearer {self._auth_token['access_token']}"
        return headers

    async def requires_auth(self) -> bool:
        """Check if the client requires authentication to communicate with the server.

        :return: ``True`` when the server answers the request with 401 Unauthorized.
        """
        # Use the async client so the event loop is not blocked, and stream the
        # response so only the status line/headers are read: the endpoint is an
        # SSE stream whose body would otherwise be consumed by the request.
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "GET",
                self._sse_params.url,
                headers=self._format_headers(),
                timeout=self._sse_params.timeout,
            ) as response:
                return response.status_code == http.HTTPStatus.UNAUTHORIZED

    async def _get_transport(
        self, exit_stack: AsyncExitStack
    ) -> tuple[
        MemoryObjectReceiveStream[typing.Union[SessionMessage, Exception]],
        MemoryObjectSendStream[SessionMessage],
    ]:
        """Open the SSE transport on ``exit_stack`` and return its streams.

        :raises MCPAuthException: If opening the transport fails with a 401
            Unauthorized HTTP status error.
        """
        try:
            return await exit_stack.enter_async_context(
                sse_client(
                    url=self._sse_params.url,
                    headers=self._format_headers(),
                    timeout=self._sse_params.timeout,
                    sse_read_timeout=self._sse_params.sse_read_timeout,
                )
            )
        except httpx.HTTPStatusError as e:
            # Only a 401 is translated; every other error propagates unchanged.
            if e.response.status_code != http.HTTPStatus.UNAUTHORIZED:
                raise
            if self._oauth_params is None:
                raise MCPAuthException(
                    "Authentication required but no auth params provided."
                ) from e
            raise MCPAuthException("Authentication required.") from e
@@ -0,0 +1,22 @@
1
+ from typing import Optional
2
+ import logging
3
+ from contextlib import AsyncExitStack
4
+
5
+ from mistralai.extra.mcp.base import (
6
+ MCPClientBase,
7
+ )
8
+
9
+ from mcp import stdio_client, StdioServerParameters
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class MCPClientSTDIO(MCPClientBase):
    """MCP client whose transport is a stdio connection."""

    def __init__(self, stdio_params: StdioServerParameters, name: Optional[str] = None):
        """Keep the stdio server parameters and forward ``name`` to the base class."""
        super().__init__(name=name)
        self._stdio_params = stdio_params

    async def _get_transport(self, exit_stack: AsyncExitStack):
        """Enter the stdio client context on ``exit_stack`` and return its result."""
        transport_cm = stdio_client(self._stdio_params)
        return await exit_stack.enter_async_context(transport_cm)
File without changes