mcp-use 1.3.10__py3-none-any.whl → 1.3.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcp-use might be problematic. Click here for more details.

@@ -0,0 +1,214 @@
1
+ """OAuth callback server implementation."""
2
+
3
+ import asyncio
4
+ from dataclasses import dataclass
5
+
6
+ import anyio
7
+ import uvicorn
8
+ from starlette.applications import Starlette
9
+ from starlette.requests import Request
10
+ from starlette.responses import HTMLResponse
11
+ from starlette.routing import Route
12
+
13
+ from ..logging import logger
14
+
15
+
16
+ @dataclass
17
+ class CallbackResponse:
18
+ """Response data from OAuth callback."""
19
+
20
+ code: str | None = None # Authorization code (success)
21
+ state: str | None = None # CSRF protection token
22
+ error: str | None = None # Errors code (if failed)
23
+ error_description: str | None = None
24
+ error_uri: str | None = None
25
+
26
+
27
class OAuthCallbackServer:
    """Local HTTP server that receives the OAuth authorization redirect.

    The server listens on ``127.0.0.1:<port>`` and exposes a single
    ``/callback`` route. The first callback received is handed to
    :meth:`wait_for_code` through an in-process queue. The browser is shown a
    success or error page.
    """

    # Upper bound on how long start() waits for uvicorn to report it is listening.
    _STARTUP_TIMEOUT = 5.0
    # How long shutdown() lets uvicorn exit gracefully before the task is cancelled.
    _SHUTDOWN_TIMEOUT = 5.0

    def __init__(self, port: int):
        """Initialize the callback server.

        Args:
            port: Port to listen on.
        """
        self.port = port
        self.redirect_uri: str | None = None
        # Hands the callback result from the HTTP handler to wait_for_code().
        # NOTE: asyncio.Queue is NOT thread-safe. Both sides must run on the same
        # event loop; start() runs uvicorn as a task on the caller's loop.
        # maxsize=1 keeps only the first callback.
        self.response_queue: asyncio.Queue[CallbackResponse] = asyncio.Queue(maxsize=1)
        self.server: uvicorn.Server | None = None
        self._server_task: asyncio.Task[None] | None = None
        self._shutdown_event = anyio.Event()

    async def start(self) -> str:
        """Start the callback server in the background and return the redirect URI.

        Returns:
            The redirect URI, ``http://localhost:<port>/callback``.

        Raises:
            RuntimeError: If the server task finishes before uvicorn reports that
                it has started.
        """
        app = self._create_app()

        config = uvicorn.Config(
            app,
            host="127.0.0.1",
            port=self.port,
            log_level="error",  # Suppress uvicorn logs
        )
        self.server = uvicorn.Server(config)

        # Run uvicorn on the current event loop in the background.
        self._server_task = asyncio.create_task(self.server.serve())

        # Poll uvicorn's ``started`` flag instead of sleeping a fixed interval, so
        # we only hand out the redirect URI once the socket is actually bound.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self._STARTUP_TIMEOUT
        while not self.server.started:
            if self._server_task.done():
                task = self._server_task
                cause = None if task.cancelled() else task.exception()
                raise RuntimeError(f"OAuth callback server failed to start on port {self.port}") from cause
            if loop.time() >= deadline:
                # Best effort: keep the previous behaviour of returning anyway.
                logger.warning(
                    f"OAuth callback server did not report startup within {self._STARTUP_TIMEOUT} seconds"
                )
                break
            await asyncio.sleep(0.05)

        self.redirect_uri = f"http://localhost:{self.port}/callback"
        return self.redirect_uri

    async def wait_for_code(self, timeout: float = 300) -> CallbackResponse:
        """Wait for the OAuth callback, then shut the server down.

        Args:
            timeout: Seconds to wait for the callback (default 5 minutes).

        Returns:
            The first ``CallbackResponse`` delivered by the ``/callback`` route.

        Raises:
            TimeoutError: If no callback arrives within ``timeout`` seconds.
        """
        try:
            return await asyncio.wait_for(self.response_queue.get(), timeout=timeout)
        except asyncio.TimeoutError:
            # asyncio.TimeoutError is only an alias of the builtin from 3.11 on;
            # catching it covers 3.10 too, and callers always get the builtin.
            raise TimeoutError(f"OAuth callback not received within {timeout} seconds") from None
        finally:
            await self.shutdown()

    async def shutdown(self) -> None:
        """Shut down the callback server. Safe to call more than once."""
        self._shutdown_event.set()
        if self.server:
            self.server.should_exit = True

        task = self._server_task
        if task is None:
            return
        if task.done():
            # Already stopped (or cancelled by an earlier call). Retrieve any
            # exception so it is not reported as "never retrieved", and do not
            # re-raise it from cleanup code.
            if not task.cancelled() and task.exception() is not None:
                logger.debug(f"OAuth callback server task ended with error: {task.exception()!r}")
            return

        try:
            await asyncio.wait_for(task, timeout=self._SHUTDOWN_TIMEOUT)
        except asyncio.TimeoutError:
            # wait_for already cancels the awaited task on timeout; this is defensive.
            task.cancel()

    def _create_app(self) -> Starlette:
        """Create the Starlette application with the ``/callback`` route."""

        async def callback(request: Request) -> HTMLResponse:
            """Record the OAuth redirect parameters and render a result page."""
            params = request.query_params

            # Extract OAuth parameters
            response = CallbackResponse(
                code=params.get("code"),
                state=params.get("state"),
                error=params.get("error"),
                error_description=params.get("error_description"),
                error_uri=params.get("error_uri"),
            )

            # Log the callback response
            logger.debug(
                f"OAuth callback received: error={response.error}, error_description={response.error_description}"
            )
            if response.code:
                logger.debug("OAuth callback received authorization code")
            else:
                logger.error(f"OAuth callback error: {response.error} - {response.error_description}")

            # Only the first callback is delivered; later ones are dropped.
            try:
                self.response_queue.put_nowait(response)
            except asyncio.QueueFull:
                logger.debug("Ignoring additional OAuth callback; a response was already received")

            if response.code:
                page = self._success_html()
            else:
                page = self._error_html(response.error, response.error_description)

            return HTMLResponse(content=page)

        routes = [Route("/callback", callback)]
        return Starlette(routes=routes)

    def _success_html(self) -> str:
        """Return the HTML page shown after successful authorization."""
        return """
<!DOCTYPE html>
<html>
<head>
    <title>Authorization Successful</title>
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background-color: #f5f5f5;
        }
        .container {
            text-align: center;
            padding: 2rem;
            background: white;
            border-radius: 8px;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }
        h1 { color: #22c55e; margin-bottom: 0.5rem; }
        p { color: #666; margin-top: 0.5rem; }
        .icon { font-size: 48px; margin-bottom: 1rem; }
    </style>
</head>
<body>
    <div class="container">
        <div class="icon">✅</div>
        <h1>Authorization Successful!</h1>
        <p>You can now close this window and return to your application.</p>
    </div>
    <script>
        // Auto-close after 3 seconds
        setTimeout(() => window.close(), 3000);
    </script>
</body>
</html>
"""

    def _error_html(self, error: str | None, description: str | None) -> str:
        """Return the HTML page shown when authorization failed.

        Args:
            error: ``error`` query parameter from the redirect, if any.
            description: ``error_description`` query parameter, if any.

        Both values come straight from the request URL. Anyone can open a
        crafted ``/callback`` link, so the values are HTML-escaped before
        they are interpolated.
        """
        from html import escape

        error_msg = escape(error or "Unknown error")
        desc_msg = escape(description or "Authorization was not completed successfully.")

        return f"""
<!DOCTYPE html>
<html>
<head>
    <title>Authorization Error</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            display: flex;
            justify-content: center;
            align-items: center;
            height: 100vh;
            margin: 0;
            background-color: #f5f5f5;
        }}
        .container {{
            text-align: center;
            padding: 2rem;
            background: white;
            border-radius: 8px;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
            max-width: 500px;
        }}
        h1 {{ color: #ef4444; margin-bottom: 0.5rem; }}
        .error {{ color: #dc2626; font-weight: 600; margin: 1rem 0; }}
        .description {{ color: #666; margin-top: 0.5rem; }}
        .icon {{ font-size: 48px; margin-bottom: 1rem; }}
    </style>
</head>
<body>
    <div class="container">
        <div class="icon">❌</div>
        <h1>Authorization Error</h1>
        <p class="error">{error_msg}</p>
        <p class="description">{desc_msg}</p>
    </div>
</body>
</html>
"""
mcp_use/client.py CHANGED
@@ -15,6 +15,7 @@ from mcp_use.types.sandbox import SandboxOptions
15
15
 
16
16
  from .config import create_connector_from_config, load_config_file
17
17
  from .logging import logger
18
+ from .middleware import Middleware, default_logging_middleware
18
19
  from .session import MCPSession
19
20
 
20
21
 
@@ -35,6 +36,7 @@ class MCPClient:
35
36
  elicitation_callback: ElicitationFnT | None = None,
36
37
  message_handler: MessageHandlerFnT | None = None,
37
38
  logging_callback: LoggingFnT | None = None,
39
+ middleware: list[Middleware] | None = None,
38
40
  ) -> None:
39
41
  """Initialize a new MCP client.
40
42
 
@@ -55,6 +57,12 @@ class MCPClient:
55
57
  self.elicitation_callback = elicitation_callback
56
58
  self.message_handler = message_handler
57
59
  self.logging_callback = logging_callback
60
+ # Add default logging middleware if no middleware provided, or prepend it to existing middleware
61
+ default_middleware = [default_logging_middleware]
62
+ if middleware:
63
+ self.middleware = default_middleware + middleware
64
+ else:
65
+ self.middleware = default_middleware
58
66
  # Load configuration if provided
59
67
  if config is not None:
60
68
  if isinstance(config, str):
@@ -151,6 +159,21 @@ class MCPClient:
151
159
  if name in self.active_sessions:
152
160
  self.active_sessions.remove(name)
153
161
 
162
def add_middleware(self, middleware: Middleware) -> None:
    """Register ``middleware`` on the client and on every existing session.

    A middleware already present in ``self.middleware`` is ignored.
    Connectors copy the middleware list when they are created, so the new
    middleware is also pushed into each existing session's connector. When
    there are no sessions, that loop simply does nothing.

    Args:
        middleware: The middleware to add
    """
    if middleware in self.middleware:
        return

    self.middleware.append(middleware)
    for active in self.sessions.values():
        active.connector.middleware_manager.add_middleware(middleware)
176
+
154
177
  def get_server_names(self) -> list[str]:
155
178
  """Get the list of configured server names.
156
179
 
@@ -192,7 +215,7 @@ class MCPClient:
192
215
 
193
216
  server_config = servers[server_name]
194
217
 
195
- # Create connector with options
218
+ # Create connector with options and client-level auth
196
219
  connector = create_connector_from_config(
197
220
  server_config,
198
221
  sandbox=self.sandbox,
@@ -201,6 +224,7 @@ class MCPClient:
201
224
  elicitation_callback=self.elicitation_callback,
202
225
  message_handler=self.message_handler,
203
226
  logging_callback=self.logging_callback,
227
+ middleware=self.middleware,
204
228
  )
205
229
 
206
230
  # Create the session
mcp_use/config.py CHANGED
@@ -13,6 +13,7 @@ from mcp_use.types.sandbox import SandboxOptions
13
13
 
14
14
  from .connectors import BaseConnector, HttpConnector, SandboxConnector, StdioConnector, WebSocketConnector
15
15
  from .connectors.utils import is_stdio_server
16
+ from .middleware import Middleware
16
17
 
17
18
 
18
19
  def load_config_file(filepath: str) -> dict[str, Any]:
@@ -36,6 +37,7 @@ def create_connector_from_config(
36
37
  elicitation_callback: ElicitationFnT | None = None,
37
38
  message_handler: MessageHandlerFnT | None = None,
38
39
  logging_callback: LoggingFnT | None = None,
40
+ middleware: list[Middleware] | None = None,
39
41
  ) -> BaseConnector:
40
42
  """Create a connector based on server configuration.
41
43
  This function can be called with just the server_config parameter:
@@ -59,6 +61,7 @@ def create_connector_from_config(
59
61
  elicitation_callback=elicitation_callback,
60
62
  message_handler=message_handler,
61
63
  logging_callback=logging_callback,
64
+ middleware=middleware,
62
65
  )
63
66
 
64
67
  # Sandboxed connector
@@ -72,6 +75,7 @@ def create_connector_from_config(
72
75
  elicitation_callback=elicitation_callback,
73
76
  message_handler=message_handler,
74
77
  logging_callback=logging_callback,
78
+ middleware=middleware,
75
79
  )
76
80
 
77
81
  # HTTP connector
@@ -79,13 +83,14 @@ def create_connector_from_config(
79
83
  return HttpConnector(
80
84
  base_url=server_config["url"],
81
85
  headers=server_config.get("headers", None),
82
- auth_token=server_config.get("auth_token", None),
86
+ auth=server_config.get("auth", {}),
83
87
  timeout=server_config.get("timeout", 5),
84
88
  sse_read_timeout=server_config.get("sse_read_timeout", 60 * 5),
85
89
  sampling_callback=sampling_callback,
86
90
  elicitation_callback=elicitation_callback,
87
91
  message_handler=message_handler,
88
92
  logging_callback=logging_callback,
93
+ middleware=middleware,
89
94
  )
90
95
 
91
96
  # WebSocket connector
@@ -93,7 +98,7 @@ def create_connector_from_config(
93
98
  return WebSocketConnector(
94
99
  url=server_config["ws_url"],
95
100
  headers=server_config.get("headers", None),
96
- auth_token=server_config.get("auth_token", None),
101
+ auth=server_config.get("auth", {}),
97
102
  )
98
103
 
99
104
  raise ValueError("Cannot determine connector type from config")
@@ -31,6 +31,7 @@ from pydantic import AnyUrl
31
31
  import mcp_use
32
32
 
33
33
  from ..logging import logger
34
+ from ..middleware import Middleware, MiddlewareManager
34
35
  from ..task_managers import ConnectionManager
35
36
 
36
37
 
@@ -46,6 +47,7 @@ class BaseConnector(ABC):
46
47
  elicitation_callback: ElicitationFnT | None = None,
47
48
  message_handler: MessageHandlerFnT | None = None,
48
49
  logging_callback: LoggingFnT | None = None,
50
+ middleware: list[Middleware] | None = None,
49
51
  ):
50
52
  """Initialize base connector with common attributes."""
51
53
  self.client_session: ClientSession | None = None
@@ -62,6 +64,12 @@ class BaseConnector(ABC):
62
64
  self.logging_callback = logging_callback
63
65
  self.capabilities: ServerCapabilities | None = None
64
66
 
67
+ # Set up middleware manager
68
+ self.middleware_manager = MiddlewareManager()
69
+ if middleware:
70
+ for mw in middleware:
71
+ self.middleware_manager.add_middleware(mw)
72
+
65
73
  @property
66
74
  def client_info(self) -> Implementation:
67
75
  """Get the client info for the connector."""
@@ -111,29 +119,34 @@ class BaseConnector(ABC):
111
119
  """Clean up all resources associated with this connector."""
112
120
  errors = []
113
121
 
114
- # First close the client session
115
- if self.client_session:
122
+ # First stop the connection manager, this closes the ClientSession inside
123
+ # the same task where it was opened, avoiding cancel-scope mismatches.
124
+ if self._connection_manager:
116
125
  try:
117
- logger.debug("Closing client session")
118
- await self.client_session.__aexit__(None, None, None)
126
+ logger.debug("Stopping connection manager")
127
+ await self._connection_manager.stop()
119
128
  except Exception as e:
120
- error_msg = f"Error closing client session: {e}"
129
+ error_msg = f"Error stopping connection manager: {e}"
121
130
  logger.warning(error_msg)
122
131
  errors.append(error_msg)
123
132
  finally:
124
- self.client_session = None
133
+ self._connection_manager = None
125
134
 
126
- # Then stop the connection manager
127
- if self._connection_manager:
135
+ # Ensure the client_session reference is cleared (it should already be
136
+ # closed by the connection manager). Only attempt a direct __aexit__ if
137
+ # the connection manager did *not* exist, this covers edge-cases like
138
+ # failed connections where no manager was started.
139
+ if self.client_session:
128
140
  try:
129
- logger.debug("Stopping connection manager")
130
- await self._connection_manager.stop()
141
+ if not self._connection_manager:
142
+ logger.debug("Closing client session (no connection manager)")
143
+ await self.client_session.__aexit__(None, None, None)
131
144
  except Exception as e:
132
- error_msg = f"Error stopping connection manager: {e}"
145
+ error_msg = f"Error closing client session: {e}"
133
146
  logger.warning(error_msg)
134
147
  errors.append(error_msg)
135
148
  finally:
136
- self._connection_manager = None
149
+ self.client_session = None
137
150
 
138
151
  # Reset tools
139
152
  self._tools = None