mcp-use 1.3.10__py3-none-any.whl → 1.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mcp-use might be problematic. Click here for more details.

@@ -0,0 +1,214 @@
1
+ """OAuth callback server implementation."""
2
+
3
+ import asyncio
4
+ from dataclasses import dataclass
5
+
6
+ import anyio
7
+ import uvicorn
8
+ from starlette.applications import Starlette
9
+ from starlette.requests import Request
10
+ from starlette.responses import HTMLResponse
11
+ from starlette.routing import Route
12
+
13
+ from ..logging import logger
14
+
15
+
16
+ @dataclass
17
+ class CallbackResponse:
18
+ """Response data from OAuth callback."""
19
+
20
+ code: str | None = None # Authorization code (success)
21
+ state: str | None = None # CSRF protection token
22
+ error: str | None = None # Errors code (if failed)
23
+ error_description: str | None = None
24
+ error_uri: str | None = None
25
+
26
+
27
+ class OAuthCallbackServer:
28
+ """Local server to handle OAuth callback."""
29
+
30
+ def __init__(self, port: int):
31
+ """Initialize the callback server.
32
+
33
+ Args:
34
+ port: Port to listen on.
35
+ """
36
+ self.port = port
37
+ self.redirect_uri: str | None = None
38
+ # Thread safe way to pass callback data to the main OAuth flow
39
+ self.response_queue: asyncio.Queue[CallbackResponse] = asyncio.Queue(maxsize=1)
40
+ self.server: uvicorn.Server | None = None
41
+ self._shutdown_event = anyio.Event()
42
+
43
+ async def start(self) -> str:
44
+ """Start the callback server and return the redirect URI."""
45
+ app = self._create_app()
46
+
47
+ # Create the server
48
+ config = uvicorn.Config(
49
+ app,
50
+ host="127.0.0.1",
51
+ port=self.port,
52
+ log_level="error", # Suppress uvicorn logs
53
+ )
54
+ self.server = uvicorn.Server(config)
55
+
56
+ # Start server in background
57
+ self._server_task = asyncio.create_task(self.server.serve())
58
+
59
+ # Wait a moment for server to start
60
+ await asyncio.sleep(0.1)
61
+
62
+ self.redirect_uri = f"http://localhost:{self.port}/callback"
63
+ return self.redirect_uri
64
+
65
+ async def wait_for_code(self, timeout: float = 300) -> CallbackResponse:
66
+ """Wait for the OAuth callback with a timeout (default 5 minutes)."""
67
+ try:
68
+ response = await asyncio.wait_for(self.response_queue.get(), timeout=timeout)
69
+ return response
70
+ except TimeoutError:
71
+ raise TimeoutError(f"OAuth callback not received within {timeout} seconds") from None
72
+ finally:
73
+ await self.shutdown()
74
+
75
+ async def shutdown(self):
76
+ """Shutdown the callback server."""
77
+ self._shutdown_event.set()
78
+ if self.server:
79
+ self.server.should_exit = True
80
+ if hasattr(self, "_server_task"):
81
+ try:
82
+ await asyncio.wait_for(self._server_task, timeout=5.0)
83
+ except TimeoutError:
84
+ self._server_task.cancel()
85
+
86
+ def _create_app(self) -> Starlette:
87
+ """Create the Starlette application."""
88
+
89
+ async def callback(request: Request) -> HTMLResponse:
90
+ """Handle the OAuth callback."""
91
+ params = request.query_params
92
+
93
+ # Extract OAuth parameters
94
+ response = CallbackResponse(
95
+ code=params.get("code"),
96
+ state=params.get("state"),
97
+ error=params.get("error"),
98
+ error_description=params.get("error_description"),
99
+ error_uri=params.get("error_uri"),
100
+ )
101
+
102
+ # Log the callback response
103
+ logger.debug(
104
+ f"OAuth callback received: error={response.error}, error_description={response.error_description}"
105
+ )
106
+ if response.code:
107
+ logger.debug("OAuth callback received authorization code")
108
+ else:
109
+ logger.error(f"OAuth callback error: {response.error} - {response.error_description}")
110
+
111
+ # Put response in queue
112
+ try:
113
+ self.response_queue.put_nowait(response)
114
+ except asyncio.QueueFull:
115
+ pass # Ignore if queue is already full
116
+
117
+ # Return success page
118
+ if response.code:
119
+ html = self._success_html()
120
+ else:
121
+ html = self._error_html(response.error, response.error_description)
122
+
123
+ return HTMLResponse(content=html)
124
+
125
+ routes = [Route("/callback", callback)]
126
+ return Starlette(routes=routes)
127
+
128
+ def _success_html(self) -> str:
129
+ """HTML response for successful authorization."""
130
+ return """
131
+ <!DOCTYPE html>
132
+ <html>
133
+ <head>
134
+ <title>Authorization Successful</title>
135
+ <style>
136
+ body {
137
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
138
+ display: flex;
139
+ justify-content: center;
140
+ align-items: center;
141
+ height: 100vh;
142
+ margin: 0;
143
+ background-color: #f5f5f5;
144
+ }
145
+ .container {
146
+ text-align: center;
147
+ padding: 2rem;
148
+ background: white;
149
+ border-radius: 8px;
150
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
151
+ }
152
+ h1 { color: #22c55e; margin-bottom: 0.5rem; }
153
+ p { color: #666; margin-top: 0.5rem; }
154
+ .icon { font-size: 48px; margin-bottom: 1rem; }
155
+ </style>
156
+ </head>
157
+ <body>
158
+ <div class="container">
159
+ <div class="icon">✅</div>
160
+ <h1>Authorization Successful!</h1>
161
+ <p>You can now close this window and return to your application.</p>
162
+ </div>
163
+ <script>
164
+ // Auto-close after 3 seconds
165
+ setTimeout(() => window.close(), 3000);
166
+ </script>
167
+ </body>
168
+ </html>
169
+ """
170
+
171
+ def _error_html(self, error: str | None, description: str | None) -> str:
172
+ """HTML response for authorization error."""
173
+ error_msg = error or "Unknown error"
174
+ desc_msg = description or "Authorization was not completed successfully."
175
+
176
+ return f"""
177
+ <!DOCTYPE html>
178
+ <html>
179
+ <head>
180
+ <title>Authorization Error</title>
181
+ <style>
182
+ body {{
183
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
184
+ display: flex;
185
+ justify-content: center;
186
+ align-items: center;
187
+ height: 100vh;
188
+ margin: 0;
189
+ background-color: #f5f5f5;
190
+ }}
191
+ .container {{
192
+ text-align: center;
193
+ padding: 2rem;
194
+ background: white;
195
+ border-radius: 8px;
196
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
197
+ max-width: 500px;
198
+ }}
199
+ h1 {{ color: #ef4444; margin-bottom: 0.5rem; }}
200
+ .error {{ color: #dc2626; font-weight: 600; margin: 1rem 0; }}
201
+ .description {{ color: #666; margin-top: 0.5rem; }}
202
+ .icon {{ font-size: 48px; margin-bottom: 1rem; }}
203
+ </style>
204
+ </head>
205
+ <body>
206
+ <div class="container">
207
+ <div class="icon">❌</div>
208
+ <h1>Authorization Error</h1>
209
+ <p class="error">{error_msg}</p>
210
+ <p class="description">{desc_msg}</p>
211
+ </div>
212
+ </body>
213
+ </html>
214
+ """
mcp_use/client.py CHANGED
@@ -192,7 +192,7 @@ class MCPClient:
192
192
 
193
193
  server_config = servers[server_name]
194
194
 
195
- # Create connector with options
195
+ # Create connector with options and client-level auth
196
196
  connector = create_connector_from_config(
197
197
  server_config,
198
198
  sandbox=self.sandbox,
mcp_use/config.py CHANGED
@@ -79,7 +79,7 @@ def create_connector_from_config(
79
79
  return HttpConnector(
80
80
  base_url=server_config["url"],
81
81
  headers=server_config.get("headers", None),
82
- auth_token=server_config.get("auth_token", None),
82
+ auth=server_config.get("auth", {}),
83
83
  timeout=server_config.get("timeout", 5),
84
84
  sse_read_timeout=server_config.get("sse_read_timeout", 60 * 5),
85
85
  sampling_callback=sampling_callback,
@@ -93,7 +93,7 @@ def create_connector_from_config(
93
93
  return WebSocketConnector(
94
94
  url=server_config["ws_url"],
95
95
  headers=server_config.get("headers", None),
96
- auth_token=server_config.get("auth_token", None),
96
+ auth=server_config.get("auth", {}),
97
97
  )
98
98
 
99
99
  raise ValueError("Cannot determine connector type from config")
@@ -111,29 +111,34 @@ class BaseConnector(ABC):
111
111
  """Clean up all resources associated with this connector."""
112
112
  errors = []
113
113
 
114
- # First close the client session
115
- if self.client_session:
114
+ # First stop the connection manager, this closes the ClientSession inside
115
+ # the same task where it was opened, avoiding cancel-scope mismatches.
116
+ if self._connection_manager:
116
117
  try:
117
- logger.debug("Closing client session")
118
- await self.client_session.__aexit__(None, None, None)
118
+ logger.debug("Stopping connection manager")
119
+ await self._connection_manager.stop()
119
120
  except Exception as e:
120
- error_msg = f"Error closing client session: {e}"
121
+ error_msg = f"Error stopping connection manager: {e}"
121
122
  logger.warning(error_msg)
122
123
  errors.append(error_msg)
123
124
  finally:
124
- self.client_session = None
125
+ self._connection_manager = None
125
126
 
126
- # Then stop the connection manager
127
- if self._connection_manager:
127
+ # Ensure the client_session reference is cleared (it should already be
128
+ # closed by the connection manager). Only attempt a direct __aexit__ if
129
+ # the connection manager did *not* exist, this covers edge-cases like
130
+ # failed connections where no manager was started.
131
+ if self.client_session:
128
132
  try:
129
- logger.debug("Stopping connection manager")
130
- await self._connection_manager.stop()
133
+ if not self._connection_manager:
134
+ logger.debug("Closing client session (no connection manager)")
135
+ await self.client_session.__aexit__(None, None, None)
131
136
  except Exception as e:
132
- error_msg = f"Error stopping connection manager: {e}"
137
+ error_msg = f"Error closing client session: {e}"
133
138
  logger.warning(error_msg)
134
139
  errors.append(error_msg)
135
140
  finally:
136
- self._connection_manager = None
141
+ self.client_session = None
137
142
 
138
143
  # Reset tools
139
144
  self._tools = None
@@ -5,10 +5,17 @@ This module provides a connector for communicating with MCP implementations
5
5
  through HTTP APIs with SSE or Streamable HTTP for transport.
6
6
  """
7
7
 
8
+ from typing import Any
9
+
8
10
  import httpx
9
11
  from mcp import ClientSession
10
12
  from mcp.client.session import ElicitationFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
13
+ from mcp.shared.exceptions import McpError
14
+
15
+ from mcp_use.auth.oauth import OAuthClientProvider
11
16
 
17
+ from ..auth import BearerAuth, OAuth
18
+ from ..exceptions import OAuthAuthenticationError, OAuthDiscoveryError
12
19
  from ..logging import logger
13
20
  from ..task_managers import SseConnectionManager, StreamableHttpConnectionManager
14
21
  from .base import BaseConnector
@@ -24,10 +31,10 @@ class HttpConnector(BaseConnector):
24
31
  def __init__(
25
32
  self,
26
33
  base_url: str,
27
- auth_token: str | None = None,
28
34
  headers: dict[str, str] | None = None,
29
35
  timeout: float = 5,
30
36
  sse_read_timeout: float = 60 * 5,
37
+ auth: str | dict[str, Any] | httpx.Auth | None = None,
31
38
  sampling_callback: SamplingFnT | None = None,
32
39
  elicitation_callback: ElicitationFnT | None = None,
33
40
  message_handler: MessageHandlerFnT | None = None,
@@ -37,10 +44,13 @@ class HttpConnector(BaseConnector):
37
44
 
38
45
  Args:
39
46
  base_url: The base URL of the MCP HTTP API.
40
- auth_token: Optional authentication token.
41
47
  headers: Optional additional headers.
42
48
  timeout: Timeout for HTTP operations in seconds.
43
49
  sse_read_timeout: Timeout for SSE read operations in seconds.
50
+ auth: Authentication method - can be:
51
+ - A string token: Use Bearer token authentication
52
+ - A dict with OAuth config: {"client_id": "...", "client_secret": "...", "scope": "..."}
53
+ - An httpx.Auth object: Use custom authentication
44
54
  sampling_callback: Optional sampling callback.
45
55
  elicitation_callback: Optional elicitation callback.
46
56
  """
@@ -51,12 +61,57 @@ class HttpConnector(BaseConnector):
51
61
  logging_callback=logging_callback,
52
62
  )
53
63
  self.base_url = base_url.rstrip("/")
54
- self.auth_token = auth_token
55
64
  self.headers = headers or {}
56
- if auth_token:
57
- self.headers["Authorization"] = f"Bearer {auth_token}"
58
65
  self.timeout = timeout
59
66
  self.sse_read_timeout = sse_read_timeout
67
+ self._auth: httpx.Auth | None = None
68
+ self._oauth: OAuth | None = None
69
+
70
+ # Handle authentication
71
+ if auth is not None:
72
+ self._set_auth(auth)
73
+
74
+ def _set_auth(self, auth: str | dict[str, Any] | httpx.Auth) -> None:
75
+ """Set authentication method.
76
+
77
+ Args:
78
+ auth: Authentication method - can be:
79
+ - A string token: Use Bearer token authentication
80
+ - A dict with OAuth config: {"client_id": "...", "client_secret": "...", "scope": "..."}
81
+ - An httpx.Auth object: Use custom authentication
82
+ """
83
+ if isinstance(auth, str):
84
+ # Treat as bearer token
85
+ self._auth = BearerAuth(token=auth)
86
+ self.headers["Authorization"] = f"Bearer {auth}"
87
+ elif isinstance(auth, dict):
88
+ # Check if this is an OAuth provider configuration
89
+ if "oauth_provider" in auth:
90
+ oauth_provider = auth["oauth_provider"]
91
+ if isinstance(oauth_provider, dict):
92
+ oauth_provider = OAuthClientProvider(**oauth_provider)
93
+ self._oauth = OAuth(
94
+ self.base_url,
95
+ scope=auth.get("scope"),
96
+ client_id=auth.get("client_id"),
97
+ client_secret=auth.get("client_secret"),
98
+ callback_port=auth.get("callback_port"),
99
+ oauth_provider=oauth_provider,
100
+ )
101
+ self._oauth_config = auth
102
+ else:
103
+ self._oauth = OAuth(
104
+ self.base_url,
105
+ scope=auth.get("scope"),
106
+ client_id=auth.get("client_id"),
107
+ client_secret=auth.get("client_secret"),
108
+ callback_port=auth.get("callback_port"),
109
+ )
110
+ self._oauth_config = auth
111
+ elif isinstance(auth, httpx.Auth):
112
+ self._auth = auth
113
+ else:
114
+ raise ValueError(f"Invalid auth type: {type(auth)}")
60
115
 
61
116
  async def connect(self) -> None:
62
117
  """Establish a connection to the MCP implementation."""
@@ -64,6 +119,29 @@ class HttpConnector(BaseConnector):
64
119
  logger.debug("Already connected to MCP implementation")
65
120
  return
66
121
 
122
+ # Handle OAuth if needed
123
+ if self._oauth:
124
+ try:
125
+ # Create a temporary client for OAuth metadata discovery
126
+ async with httpx.AsyncClient() as client:
127
+ bearer_auth = await self._oauth.initialize(client)
128
+ if not bearer_auth:
129
+ # Need to perform OAuth flow
130
+ logger.info("OAuth authentication required")
131
+ bearer_auth = await self._oauth.authenticate()
132
+
133
+ # Update auth and headers
134
+ self._auth = bearer_auth
135
+ self.headers["Authorization"] = f"Bearer {bearer_auth.token.get_secret_value()}"
136
+ except OAuthDiscoveryError:
137
+ # OAuth discovery failed - it means server doesn't support OAuth default urls
138
+ logger.debug("OAuth discovery failed, continuing without initialization.")
139
+ self._oauth = None
140
+ self._auth = None
141
+ except OAuthAuthenticationError as e:
142
+ logger.error(f"OAuth initialization failed: {e}")
143
+ raise
144
+
67
145
  # Try streamable HTTP first (new transport), fall back to SSE (old transport)
68
146
  # This implements backwards compatibility per MCP specification
69
147
  self.transport_type = None
@@ -73,7 +151,7 @@ class HttpConnector(BaseConnector):
73
151
  # First, try the new streamable HTTP transport
74
152
  logger.debug(f"Attempting streamable HTTP connection to: {self.base_url}")
75
153
  connection_manager = StreamableHttpConnectionManager(
76
- self.base_url, self.headers, self.timeout, self.sse_read_timeout
154
+ self.base_url, self.headers, self.timeout, self.sse_read_timeout, auth=self._auth
77
155
  )
78
156
 
79
157
  # Test if this is a streamable HTTP server by attempting initialization
@@ -94,9 +172,9 @@ class HttpConnector(BaseConnector):
94
172
  try:
95
173
  # Try to initialize - this is where streamable HTTP vs SSE difference should show up
96
174
  result = await test_client.initialize()
175
+ logger.debug(f"Streamable HTTP initialization result: {result}")
97
176
 
98
177
  # If we get here, streamable HTTP works
99
-
100
178
  self.client_session = test_client
101
179
  self.transport_type = "streamable HTTP"
102
180
  self._initialized = True # Mark as initialized since we just called initialize()
@@ -125,14 +203,28 @@ class HttpConnector(BaseConnector):
125
203
  else:
126
204
  self._prompts = []
127
205
 
128
- except Exception as init_error:
206
+ # Only McpError is raised from client's initialization because
207
+ # exceptions are handled internally.
208
+ except McpError as mcp_error:
209
+ logger.error("MCP protocol error during initialization: %s", mcp_error.error)
129
210
  # Clean up the test client
211
+ try:
212
+ await test_client.__aexit__(None, None, None)
213
+ except Exception:
214
+ pass
215
+ raise mcp_error
216
+
217
+ except Exception as init_error:
218
+ # This catches non-McpError exceptions, like a direct httpx timeout
219
+ # but in the most cases this won't happen. It's for safety.
130
220
  try:
131
221
  await test_client.__aexit__(None, None, None)
132
222
  except Exception:
133
223
  pass
134
224
  raise init_error
135
225
 
226
+ # Exception from the inner try is propagated here and in
227
+ # the most cases is an McpError, so checking instances is useless
136
228
  except Exception as streamable_error:
137
229
  logger.debug(f"Streamable HTTP failed: {streamable_error}")
138
230
 
@@ -143,24 +235,17 @@ class HttpConnector(BaseConnector):
143
235
  except Exception:
144
236
  pass
145
237
 
146
- # Check if this is a 4xx error that indicates we should try SSE fallback
147
- should_fallback = False
148
- if isinstance(streamable_error, httpx.HTTPStatusError):
149
- if streamable_error.response.status_code in [404, 405]:
150
- should_fallback = True
151
- elif "405 Method Not Allowed" in str(streamable_error) or "404 Not Found" in str(streamable_error):
152
- should_fallback = True
153
- else:
154
- # For other errors, still try fallback but they might indicate
155
- # real connectivity issues
156
- should_fallback = True
238
+ # It doesn't make sense to check error types. Because client
239
+ # always return a McpError, if he can't reach the server
240
+ # because it's offline, or if it has an auth problem.
241
+ should_fallback = True
157
242
 
158
243
  if should_fallback:
159
244
  try:
160
245
  # Fall back to the old SSE transport
161
246
  logger.debug(f"Attempting SSE fallback connection to: {self.base_url}")
162
247
  connection_manager = SseConnectionManager(
163
- self.base_url, self.headers, self.timeout, self.sse_read_timeout
248
+ self.base_url, self.headers, self.timeout, self.sse_read_timeout, auth=self._auth
164
249
  )
165
250
 
166
251
  read_stream, write_stream = await connection_manager.start()
@@ -178,7 +263,18 @@ class HttpConnector(BaseConnector):
178
263
  await self.client_session.__aenter__()
179
264
  self.transport_type = "SSE"
180
265
 
181
- except Exception as sse_error:
266
+ except* Exception as sse_error:
267
+ # Get the exception from the ExceptionGroup, and here we will get the correct type.
268
+ sse_error = sse_error.exceptions[0]
269
+ if isinstance(sse_error, httpx.HTTPStatusError) and sse_error.response.status_code in [
270
+ 401,
271
+ 403,
272
+ 407,
273
+ ]:
274
+ raise OAuthAuthenticationError(
275
+ f"Server requires authentication (HTTP {sse_error.response.status_code}) "
276
+ "but auth failed. Please provide auth configuration manually."
277
+ ) from sse_error
182
278
  logger.error(
183
279
  f"Both transport methods failed. Streamable HTTP: {streamable_error}, SSE: {sse_error}"
184
280
  )
@@ -10,6 +10,7 @@ import json
10
10
  import uuid
11
11
  from typing import Any
12
12
 
13
+ import httpx
13
14
  from mcp.types import Tool
14
15
  from websockets import ClientConnection
15
16
 
@@ -28,21 +29,29 @@ class WebSocketConnector(BaseConnector):
28
29
  def __init__(
29
30
  self,
30
31
  url: str,
31
- auth_token: str | None = None,
32
32
  headers: dict[str, str] | None = None,
33
+ auth: str | dict[str, Any] | httpx.Auth | None = None,
33
34
  ):
34
35
  """Initialize a new WebSocket connector.
35
36
 
36
37
  Args:
37
38
  url: The WebSocket URL to connect to.
38
- auth_token: Optional authentication token.
39
39
  headers: Optional additional headers.
40
+ auth: Authentication method - can be:
41
+ - A string token: Use Bearer token authentication
42
+ - A dict: Not supported for WebSocket (will log warning)
43
+ - An httpx.Auth object: Not supported for WebSocket (will log warning)
40
44
  """
41
45
  self.url = url
42
- self.auth_token = auth_token
43
46
  self.headers = headers or {}
44
- if auth_token:
45
- self.headers["Authorization"] = f"Bearer {auth_token}"
47
+
48
+ # Handle authentication - WebSocket only supports bearer tokens
49
+ # An auth field it's not needed
50
+ if auth is not None:
51
+ if isinstance(auth, str):
52
+ self.headers["Authorization"] = f"Bearer {auth}"
53
+ else:
54
+ logger.warning("WebSocket connector only supports bearer token authentication")
46
55
 
47
56
  self.ws: ClientConnection | None = None
48
57
  self._connection_manager: ConnectionManager | None = None
mcp_use/exceptions.py ADDED
@@ -0,0 +1,31 @@
1
+ """MCP-use exceptions."""
2
+
3
+
4
+ class MCPError(Exception):
5
+ """Base exception for MCP-use."""
6
+
7
+ pass
8
+
9
+
10
+ class OAuthDiscoveryError(MCPError):
11
+ """OAuth discovery auth metadata error"""
12
+
13
+ pass
14
+
15
+
16
+ class OAuthAuthenticationError(MCPError):
17
+ """OAuth authentication-related errors"""
18
+
19
+ pass
20
+
21
+
22
+ class ConnectionError(MCPError):
23
+ """Connection-related errors."""
24
+
25
+ pass
26
+
27
+
28
+ class ConfigurationError(MCPError):
29
+ """Configuration-related errors."""
30
+
31
+ pass
@@ -22,13 +22,14 @@ class ConnectionManager(Generic[T], ABC):
22
22
  used with MCP connectors.
23
23
  """
24
24
 
25
- def __init__(self):
25
+ def __init__(self) -> None:
26
26
  """Initialize a new connection manager."""
27
27
  self._ready_event = asyncio.Event()
28
28
  self._done_event = asyncio.Event()
29
+ self._stop_event = asyncio.Event()
29
30
  self._exception: Exception | None = None
30
31
  self._connection: T | None = None
31
- self._task: asyncio.Task | None = None
32
+ self._task: asyncio.Task[None] | None = None
32
33
 
33
34
  @abstractmethod
34
35
  async def _establish_connection(self) -> T:
@@ -86,20 +87,15 @@ class ConnectionManager(Generic[T], ABC):
86
87
 
87
88
  async def stop(self) -> None:
88
89
  """Stop the connection manager and close the connection."""
90
+ # Signal stop to the connection task instead of cancelling it, avoids
91
+ # propagating CancelledError to unrelated tasks.
89
92
  if self._task and not self._task.done():
90
- # Cancel the task
91
- logger.debug(f"Cancelling {self.__class__.__name__} task")
92
- self._task.cancel()
93
-
94
- # Wait for it to complete
95
- try:
96
- await self._task
97
- except asyncio.CancelledError:
98
- logger.debug(f"{self.__class__.__name__} task cancelled successfully")
99
- except Exception as e:
100
- logger.warning(f"Error stopping {self.__class__.__name__} task: {e}")
101
-
102
- # Wait for the connection to be done
93
+ logger.debug(f"Signaling stop to {self.__class__.__name__} task")
94
+ self._stop_event.set()
95
+ # Wait for it to finish gracefully
96
+ await self._task
97
+
98
+ # Ensure cleanup completed
103
99
  await self._done_event.wait()
104
100
  logger.debug(f"{self.__class__.__name__} task completed")
105
101
 
@@ -125,14 +121,8 @@ class ConnectionManager(Generic[T], ABC):
125
121
  # Signal that the connection is ready
126
122
  self._ready_event.set()
127
123
 
128
- # Wait indefinitely until cancelled
129
- try:
130
- # This keeps the connection open until cancelled
131
- await asyncio.Event().wait()
132
- except asyncio.CancelledError:
133
- # Expected when stopping
134
- logger.debug(f"{self.__class__.__name__} task received cancellation")
135
- pass
124
+ # Wait until stop is requested
125
+ await self._stop_event.wait()
136
126
 
137
127
  except Exception as e:
138
128
  # Store the exception