universal-mcp 0.1.22rc1__py3-none-any.whl → 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -30,6 +30,8 @@ sys.path.append(str(UNIVERSAL_MCP_HOME))
30
30
  # Name are in the format of "app-name", eg, google-calendar
31
31
  # Class name is NameApp, eg, GoogleCalendarApp
32
32
 
33
# Module-level cache of application classes resolved by app_from_slug, keyed by
# slug (e.g. "google-calendar"); a hit returns the class without importing again.
app_cache: dict[str, type[BaseApplication]] = {}
34
+
33
35
 
34
36
  def _install_or_upgrade_package(package_name: str, repository_path: str):
35
37
  """
@@ -71,6 +73,8 @@ def app_from_slug(slug: str):
71
73
  Dynamically resolve and return the application class for the given slug.
72
74
  Attempts installation from GitHub if the package is not found locally.
73
75
  """
76
+ if slug in app_cache:
77
+ return app_cache[slug]
74
78
  class_name = get_default_class_name(slug)
75
79
  module_path = get_default_module_path(slug)
76
80
  package_name = get_default_package_name(slug)
@@ -81,6 +85,7 @@ def app_from_slug(slug: str):
81
85
  module = importlib.import_module(module_path)
82
86
  class_ = getattr(module, class_name)
83
87
  logger.debug(f"Loaded class '{class_}' from module '{module_path}'")
88
+ app_cache[slug] = class_
84
89
  return class_
85
90
  except ModuleNotFoundError as e:
86
91
  raise ModuleNotFoundError(f"Package '{module_path}' not found locally. Please install it first.") from e
@@ -149,6 +149,30 @@ class APIApplication(BaseApplication):
149
149
  )
150
150
  return self._client
151
151
 
152
def _handle_response(self, response: httpx.Response) -> dict[str, Any]:
    """
    Raise on unsuccessful HTTP statuses, otherwise return the decoded response body.

    Args:
        response: The HTTP response to process.

    Returns:
        dict[str, Any]: The parsed JSON body when the response body is valid JSON
        (a JSON array or scalar body is returned as decoded). Otherwise a dict of the
        form ``{"status": "success", "status_code": <int>, "text": <body text>}``,
        e.g. for an empty ``204 No Content`` body.

    Raises:
        httpx.HTTPStatusError: If the response status is not a success status,
            as decided by ``response.raise_for_status()``.
    """
    response.raise_for_status()
    try:
        return response.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses: the body
        # simply isn't JSON. Any other exception is a genuine error and must not be masked.
        return {"status": "success", "status_code": response.status_code, "text": response.text}
175
+
152
176
  def _get(self, url: str, params: dict[str, Any] | None = None) -> httpx.Response:
153
177
  """
154
178
  Make a GET request to the specified URL.
@@ -158,14 +182,13 @@ class APIApplication(BaseApplication):
158
182
  params: Optional query parameters
159
183
 
160
184
  Returns:
161
- httpx.Response: The response from the server
185
+ httpx.Response: The raw HTTP response object
162
186
 
163
187
  Raises:
164
- httpx.HTTPError: If the request fails
188
+ httpx.HTTPStatusError: If the request fails (when raise_for_status() is called)
165
189
  """
166
190
  logger.debug(f"Making GET request to {url} with params: {params}")
167
191
  response = self.client.get(url, params=params)
168
- response.raise_for_status()
169
192
  logger.debug(f"GET request successful with status code: {response.status_code}")
170
193
  return response
171
194
 
@@ -193,10 +216,10 @@ class APIApplication(BaseApplication):
193
216
  Example: {'file_field_name': ('filename.txt', open('file.txt', 'rb'), 'text/plain')}
194
217
 
195
218
  Returns:
196
- httpx.Response: The response from the server
219
+ httpx.Response: The raw HTTP response object
197
220
 
198
221
  Raises:
199
- httpx.HTTPError: If the request fails
222
+ httpx.HTTPStatusError: If the request fails (when raise_for_status() is called)
200
223
  """
201
224
  logger.debug(
202
225
  f"Making POST request to {url} with params: {params}, data type: {type(data)}, content_type={content_type}, files: {'yes' if files else 'no'}"
@@ -235,7 +258,6 @@ class APIApplication(BaseApplication):
235
258
  content=data, # Expect data to be bytes or str
236
259
  params=params,
237
260
  )
238
- response.raise_for_status()
239
261
  logger.debug(f"POST request successful with status code: {response.status_code}")
240
262
  return response
241
263
 
@@ -263,10 +285,10 @@ class APIApplication(BaseApplication):
263
285
  Example: {'file_field_name': ('filename.txt', open('file.txt', 'rb'), 'text/plain')}
264
286
 
265
287
  Returns:
266
- httpx.Response: The response from the server
288
+ httpx.Response: The raw HTTP response object
267
289
 
268
290
  Raises:
269
- httpx.HTTPError: If the request fails
291
+ httpx.HTTPStatusError: If the request fails (when raise_for_status() is called)
270
292
  """
271
293
  logger.debug(
272
294
  f"Making PUT request to {url} with params: {params}, data type: {type(data)}, content_type={content_type}, files: {'yes' if files else 'no'}"
@@ -306,7 +328,6 @@ class APIApplication(BaseApplication):
306
328
  content=data, # Expect data to be bytes or str
307
329
  params=params,
308
330
  )
309
- response.raise_for_status()
310
331
  logger.debug(f"PUT request successful with status code: {response.status_code}")
311
332
  return response
312
333
 
@@ -319,14 +340,13 @@ class APIApplication(BaseApplication):
319
340
  params: Optional query parameters
320
341
 
321
342
  Returns:
322
- httpx.Response: The response from the server
343
+ httpx.Response: The raw HTTP response object
323
344
 
324
345
  Raises:
325
- httpx.HTTPError: If the request fails
346
+ httpx.HTTPStatusError: If the request fails (when raise_for_status() is called)
326
347
  """
327
348
  logger.debug(f"Making DELETE request to {url} with params: {params}")
328
349
  response = self.client.delete(url, params=params, timeout=self.default_timeout)
329
- response.raise_for_status()
330
350
  logger.debug(f"DELETE request successful with status code: {response.status_code}")
331
351
  return response
332
352
 
@@ -340,10 +360,10 @@ class APIApplication(BaseApplication):
340
360
  params: Optional query parameters
341
361
 
342
362
  Returns:
343
- httpx.Response: The response from the server
363
+ httpx.Response: The raw HTTP response object
344
364
 
345
365
  Raises:
346
- httpx.HTTPError: If the request fails
366
+ httpx.HTTPStatusError: If the request fails (when raise_for_status() is called)
347
367
  """
348
368
  logger.debug(f"Making PATCH request to {url} with params: {params} and data: {data}")
349
369
  response = self.client.patch(
@@ -351,7 +371,6 @@ class APIApplication(BaseApplication):
351
371
  json=data,
352
372
  params=params,
353
373
  )
354
- response.raise_for_status()
355
374
  logger.debug(f"PATCH request successful with status code: {response.status_code}")
356
375
  return response
357
376
 
universal_mcp/cli.py CHANGED
@@ -270,6 +270,7 @@ def preprocess(
270
270
  def split_api(
271
271
  input_app_file: Path = typer.Argument(..., help="Path to the generated app.py file to split"),
272
272
  output_dir: Path = typer.Option(..., "--output-dir", "-o", help="Directory to save the split files"),
273
+ package_name: str = typer.Option(None, "--package-name", "-p", help="Package name for absolute imports (e.g., 'hubspot')"),
273
274
  ):
274
275
  """Splits a single generated API client file into multiple files based on path groups."""
275
276
  from universal_mcp.utils.openapi.api_splitter import split_generated_app_file
@@ -286,7 +287,7 @@ def split_api(
286
287
  raise typer.Exit(1)
287
288
 
288
289
  try:
289
- split_generated_app_file(input_app_file, output_dir)
290
+ split_generated_app_file(input_app_file, output_dir, package_name)
290
291
  console.print(f"[green]Successfully split {input_app_file} into {output_dir}[/green]")
291
292
  except Exception as e:
292
293
  console.print(f"[red]Error splitting API client: {e}[/red]")
@@ -0,0 +1,30 @@
1
+ import asyncio
2
+ import os
3
+ import sys
4
+
5
+ from loguru import logger
6
+ from pydantic import ValidationError
7
+
8
+ from universal_mcp.client.agent import ChatSession
9
+ from universal_mcp.client.client import MultiClientServer
10
+ from universal_mcp.config import ClientConfig
11
+
12
+
13
async def main() -> None:
    """Load the client configuration and run an interactive chat session.

    The configuration file path is read from the ``MCP_CONFIG_PATH`` environment
    variable and defaults to ``servers.json``. If the file cannot be found or fails
    validation, the error is logged and the process exits with status 1.
    """
    path = os.environ.get("MCP_CONFIG_PATH", "servers.json")
    try:
        config = ClientConfig.load_json_config(path)
    except (FileNotFoundError, ValidationError) as err:
        logger.error(f"Error loading config: {err}")
        sys.exit(1)

    # The server context connects every configured MCP server and tears them down on exit.
    async with MultiClientServer(config.mcpServers) as server:
        session = ChatSession(server, config.llm)
        await session.interactive_loop()


if __name__ == "__main__":
    asyncio.run(main())
@@ -0,0 +1,96 @@
1
+ import json
2
+
3
+ from loguru import logger
4
+ from mcp.server import Server as MCPServer
5
+ from openai import AsyncOpenAI
6
+
7
+ from universal_mcp.config import LLMConfig
8
+
9
+
10
class ChatSession:
    """Orchestrates the interaction between user, LLM, and tools."""

    def __init__(self, mcp_server: MCPServer, llm: LLMConfig | None) -> None:
        """Create a chat session.

        Args:
            mcp_server: Server used to list and call tools (its ``list_tools`` and
                ``call_tool`` are awaited).
            llm: LLM connection settings (``api_key``, ``base_url``, ``model``). When
                ``None`` no LLM client is created and only the ``list``/``call``
                commands are usable.
        """
        self.mcp_server: MCPServer = mcp_server
        self.llm: AsyncOpenAI | None = AsyncOpenAI(api_key=llm.api_key, base_url=llm.base_url) if llm else None
        self.model = llm.model if llm else None

    @staticmethod
    def _tool_result_text(result) -> str:
        """Flatten an MCP tool result's content items into one string."""
        parts = []
        for item in result.content or []:
            # Non-text content items have no ``text`` attribute; fall back to str().
            text = getattr(item, "text", None)
            parts.append(text if text is not None else str(item))
        return "\n".join(parts) if parts else "No result"

    async def _execute_tool_call(self, tool_call) -> str:
        """Run one LLM-requested tool call and return its output, or an error description, as text."""
        name = tool_call.function.name
        try:
            arguments = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
        except json.JSONDecodeError as e:
            return f"Error: invalid JSON arguments for tool '{name}': {e}"
        try:
            result = await self.mcp_server.call_tool(tool_name=name, arguments=arguments)
        except Exception as e:
            # Report the failure back to the LLM instead of aborting the whole turn.
            logger.error(f"Tool '{name}' failed: {e}")
            return f"Error: tool '{name}' failed: {e}"
        return self._tool_result_text(result)

    async def run(self, messages, tools, max_tool_rounds: int = 10) -> list:
        """Get the LLM's reply to *messages*, executing any tools it requests.

        Each assistant message (including its ``tool_calls``) is appended, followed by
        one ``tool`` message per call; the LLM is then queried again with the results
        until it answers without requesting tools or *max_tool_rounds* requests were made.

        Args:
            messages: Chat history in OpenAI message format; extended in place.
            tools: OpenAI tool definitions to offer to the model (may be empty).
            max_tool_rounds: Maximum number of LLM requests for this turn; guards
                against endless tool-call loops.

        Returns:
            The same *messages* list; its last element is an assistant message dict.

        Raises:
            RuntimeError: If no LLM is configured for this session.
        """
        if self.llm is None:
            raise RuntimeError("No LLM configured for this chat session.")
        request = {"model": self.model, "messages": messages}
        if tools:
            # Only offer tools when there are some; an empty tools list is rejected by the API.
            request["tools"] = tools
            request["tool_choice"] = "auto"

        for _ in range(max_tool_rounds):
            llm_response = await self.llm.chat.completions.create(**request)
            message = llm_response.choices[0].message
            # Store a plain dict: it can be resent verbatim and indexed by callers.
            # The assistant message with tool_calls must precede the matching tool messages.
            messages.append(message.model_dump(exclude_none=True))
            if not message.tool_calls:
                return messages
            for tool_call in message.tool_calls:
                messages.append(
                    {
                        "tool_call_id": tool_call.id,
                        "role": "tool",
                        "name": tool_call.function.name,
                        "content": await self._execute_tool_call(tool_call),
                    }
                )

        messages.append({"role": "assistant", "content": "Stopped: too many consecutive tool calls."})
        return messages

    async def _handle_call_command(self, user_input: str) -> None:
        """Handle ``call <tool_name> [json-args]`` by calling the tool directly and printing its result."""
        parts = user_input.split(maxsplit=2)
        tool_name = parts[1] if len(parts) > 1 else ""

        if not tool_name:
            print("❌ Please specify a tool name")
            return

        # Parse arguments (simple JSON-like format)
        arguments = {}
        if len(parts) > 2:
            try:
                arguments = json.loads(parts[2])
            except json.JSONDecodeError:
                print("❌ Invalid arguments format (expected JSON)")
                return
            if not isinstance(arguments, dict):
                print("❌ Invalid arguments format (expected a JSON object)")
                return

        try:
            result = await self.mcp_server.call_tool(tool_name, arguments)
        except Exception as e:
            print(f"❌ Tool call failed: {e}")
            return
        print(f"\nResult: {self._tool_result_text(result)}")

    async def interactive_loop(self) -> None:
        """Main chat session handler.

        Reads lines from stdin until ``quit``/``exit``, Ctrl-C or end of input.
        ``list`` and ``call`` are handled locally; any other input is sent to the LLM.
        """
        all_openai_tools = await self.mcp_server.list_tools(format="openai")
        system_message = "You are a helpful assistant"
        messages = [{"role": "system", "content": system_message}]

        print("\n🎯 Interactive MCP Client")
        print("Commands:")
        print(" list - List available tools")
        print(" call <tool_name> [args] - Call a tool")
        print(" quit - Exit the client")
        print()
        while True:
            try:
                user_input = input("You: ").strip()
                if not user_input:
                    continue
                if user_input.lower() in {"quit", "exit"}:
                    logger.info("\nExiting...")
                    break
                if user_input.lower() == "list":
                    tools = await self.mcp_server.list_tools()
                    print("\nAvailable tools:")
                    for tool in tools:
                        print(f" {tool.name}")
                    continue
                if user_input.startswith("call "):
                    # Direct tool calls are not part of the LLM conversation.
                    await self._handle_call_command(user_input)
                    continue
                if self.llm is None:
                    print("❌ No LLM configured; only 'list' and 'call' commands are available")
                    continue

                checkpoint = len(messages)
                messages.append({"role": "user", "content": user_input})
                try:
                    messages = await self.run(messages, all_openai_tools)
                except Exception as e:
                    # Drop this turn so a half-finished tool exchange is not resent later.
                    del messages[checkpoint:]
                    logger.error(f"LLM request failed: {e}")
                    print(f"❌ LLM request failed: {e}")
                    continue
                print("\nAssistant: ", messages[-1].get("content"))

            except (KeyboardInterrupt, EOFError):
                print("\nExiting...")
                break
@@ -0,0 +1,198 @@
1
+ import os
2
+ import webbrowser
3
+ from contextlib import AsyncExitStack
4
+ from typing import Any, Literal
5
+
6
+ from loguru import logger
7
+ from mcp import ClientSession, StdioServerParameters
8
+ from mcp.client.auth import OAuthClientProvider
9
+ from mcp.client.sse import sse_client
10
+ from mcp.client.stdio import stdio_client
11
+ from mcp.client.streamable_http import streamablehttp_client
12
+ from mcp.server import Server
13
+ from mcp.shared.auth import OAuthClientMetadata
14
+ from mcp.types import (
15
+ CallToolResult as MCPCallToolResult,
16
+ )
17
+ from mcp.types import (
18
+ Tool as MCPTool,
19
+ )
20
+ from openai.types.chat import ChatCompletionToolParam
21
+
22
+ from universal_mcp.client.oauth import CallbackServer
23
+ from universal_mcp.client.token_store import TokenStore
24
+ from universal_mcp.config import ClientTransportConfig
25
+ from universal_mcp.stores.store import KeyringStore
26
+ from universal_mcp.tools.adapters import transform_mcp_tool_to_openai_tool
27
+
28
+
29
class MCPClient:
    """Manages a single MCP server connection and tool execution.

    Supports the ``stdio``, ``streamable_http`` and ``sse`` transports. For a
    URL-based server configured without explicit headers, an OAuth provider is set
    up whose redirect is received on ``http://localhost:3000/callback``.
    """

    def __init__(self, name: str, config: ClientTransportConfig) -> None:
        """
        Args:
            name: Identifier of this server; also the namespace of its KeyringStore.
            config: Transport configuration (transport, command/args/env or url/headers).
        """
        self.name: str = name
        self.config: ClientTransportConfig = config
        self.session: ClientSession | None = None
        self.server_url: str = config.url

        # Shared OAuth callback server. It is started lazily in the redirect handler,
        # so clients that never go through OAuth do not bind port 3000.
        self.callback_server = CallbackServer(port=3000)

        # Create OAuth authentication handler using the new interface
        if self.server_url and not self.config.headers:
            self.store = KeyringStore(self.name)
            self.auth = OAuthClientProvider(
                # The configured URL with its last path segment removed.
                server_url="/".join(self.server_url.split("/")[:-1]),
                client_metadata=OAuthClientMetadata.model_validate(self.client_metadata_dict),
                storage=TokenStore(self.store),
                redirect_handler=self._default_redirect_handler,
                callback_handler=self._callback_handler,
            )
        else:
            self.auth = None

    async def _callback_handler(self) -> tuple[str, str | None]:
        """Wait (up to 300 s) for the OAuth callback and return ``(auth_code, state)``.

        The blocking poll runs in a worker thread so the event loop stays responsive.
        The callback server is stopped afterwards whether or not a code arrived.
        """
        import asyncio

        print("⏳ Waiting for authorization callback...")
        try:
            auth_code = await asyncio.to_thread(self.callback_server.wait_for_callback, timeout=300)
            return auth_code, self.callback_server.get_state()
        finally:
            self.callback_server.stop()

    @property
    def client_metadata_dict(self) -> dict[str, Any]:
        """OAuth client registration metadata for this local client."""
        return {
            "client_name": "Simple Auth Client",
            "redirect_uris": ["http://localhost:3000/callback"],
            "grant_types": ["authorization_code", "refresh_token"],
            "response_types": ["code"],
            "token_endpoint_auth_method": "client_secret_post",
        }

    async def _default_redirect_handler(self, authorization_url: str) -> None:
        """Ensure the callback server is listening, then open the URL in a browser."""
        # start() returns immediately when the server is already running.
        self.callback_server.start()
        print(f"Opening browser for authorization: {authorization_url}")
        webbrowser.open(authorization_url)

    async def initialize(self, exit_stack: AsyncExitStack):
        """Open the configured transport and an initialized ``ClientSession``.

        Both contexts are entered on *exit_stack*, which owns their lifetime.

        Raises:
            ValueError: If the transport is unknown or a required setting
                (``command`` for stdio, ``url`` for HTTP transports) is missing.
        """
        transport = self.config.transport
        try:
            if transport == "stdio":
                command = self.config.command
                if command is None:
                    raise ValueError("The command must be a valid string and cannot be None.")
                server_params = StdioServerParameters(
                    command=command,
                    args=self.config.args,
                    # Overlay the configured variables on the current process environment.
                    env={**os.environ, **self.config.env} if self.config.env else None,
                )
                read, write = await exit_stack.enter_async_context(stdio_client(server_params))
            elif transport == "streamable_http":
                url = self.config.url
                if not url:
                    raise ValueError("'url' must be provided for streamable_http transport.")
                read, write, _ = await exit_stack.enter_async_context(
                    streamablehttp_client(url=url, headers=self.config.headers, auth=self.auth)
                )
            elif transport == "sse":
                url = self.config.url
                if not url:
                    raise ValueError("'url' must be provided for sse transport.")
                read, write = await exit_stack.enter_async_context(
                    sse_client(url=url, headers=self.config.headers, auth=self.auth)
                )
            else:
                raise ValueError(f"Unknown transport: {transport}")

            session = await exit_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()
            self.session = session
        except Exception as e:
            logger.error(f"Error initializing server {self.name}: {e}")
            raise

    async def list_tools(self) -> list[MCPTool]:
        """List available tools from the server (empty if not initialized)."""
        if self.session:
            tools = await self.session.list_tools()
            return list(tools.tools)
        return []

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> MCPCallToolResult:
        """Call a tool on the server; returns an error result if not initialized."""
        if self.session:
            return await self.session.call_tool(tool_name, arguments)
        from mcp.types import TextContent

        return MCPCallToolResult(
            content=[TextContent(type="text", text=f"Server '{self.name}' is not initialized")],
            isError=True,
        )
143
+
144
+
145
class MultiClientServer(Server):
    """
    Manages multiple MCP servers and maintains a mapping from tool name to the server that provides it.

    Use as an async context manager: entering connects every client and collects
    their tools; exiting closes every connection.
    """

    def __init__(self, clients: dict[str, ClientTransportConfig]):
        """
        Args:
            clients: Mapping of server name to its transport configuration.
        """
        # NOTE(review): Server.__init__ is not called; this class relies only on its
        # own list_tools/call_tool overrides — confirm no base-class state is needed.
        self.clients: list[MCPClient] = [MCPClient(name, config) for name, config in clients.items()]
        self.tool_to_client: dict[str, MCPClient] = {}
        self._mcp_tools: list[MCPTool] = []
        self._exit_stack: AsyncExitStack = AsyncExitStack()

    async def __aenter__(self):
        """Connect every client and build the tool mapping.

        If any client fails to initialize, connections opened so far are closed
        before the error propagates (``__aexit__`` is not called in that case).
        """
        try:
            for client in self.clients:
                await client.initialize(self._exit_stack)
            await self._populate_tool_mapping()
        except BaseException:
            await self._exit_stack.aclose()
            raise
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Clean up the server connections."""
        self.clients.clear()
        self.tool_to_client.clear()
        self._mcp_tools.clear()
        await self._exit_stack.aclose()

    async def _populate_tool_mapping(self):
        """Populate the mapping from tool name to the client that provides it.

        When two clients expose the same tool name, the later client wins and a
        warning is logged.
        """
        self.tool_to_client.clear()
        self._mcp_tools.clear()
        for client in self.clients:
            try:
                tools = await client.list_tools()
                for tool in tools:
                    self._mcp_tools.append(tool)
                    tool_name = tool.name
                    logger.info(f"Found tool: {tool_name} from client: {client.name}")
                    if not tool_name:
                        continue
                    previous = self.tool_to_client.get(tool_name)
                    if previous is not None and previous is not client:
                        logger.warning(
                            f"Tool '{tool_name}' from client {client.name} overrides the one from {previous.name}"
                        )
                    self.tool_to_client[tool_name] = client
            except Exception as e:
                logger.warning(f"Failed to list tools for client {client.name}: {e}")

    async def list_tools(self, format: Literal["mcp", "openai"] = "mcp") -> list[MCPTool | ChatCompletionToolParam]:
        """List available tools from all servers, as MCP tools or OpenAI tool params.

        Raises:
            ValueError: If *format* is not ``"mcp"`` or ``"openai"``.
        """
        if format == "mcp":
            # Return a copy so callers cannot mutate the internal tool list.
            return list(self._mcp_tools)
        elif format == "openai":
            return [transform_mcp_tool_to_openai_tool(tool) for tool in self._mcp_tools]
        else:
            raise ValueError(f"Invalid format: {format}")

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> MCPCallToolResult:
        """Call a tool on the server that provides it.

        Raises:
            KeyError: If no connected server provides *tool_name*.
        """
        client = self.tool_to_client.get(tool_name)
        if client is None:
            raise KeyError(f"Unknown tool: {tool_name}")
        return await client.call_tool(tool_name, arguments)
@@ -0,0 +1,114 @@
1
+ import threading
2
+ import time
3
+ from http.server import BaseHTTPRequestHandler, HTTPServer
4
+ from urllib.parse import parse_qs, urlparse
5
+
6
+ from universal_mcp.utils.singleton import Singleton
7
+
8
+
9
+ class CallbackHandler(BaseHTTPRequestHandler):
10
+ """Simple HTTP handler to capture OAuth callback."""
11
+
12
+ def __init__(self, request, client_address, server, callback_data):
13
+ """Initialize with callback data storage."""
14
+ self.callback_data = callback_data
15
+ super().__init__(request, client_address, server)
16
+
17
+ def do_GET(self):
18
+ """Handle GET request from OAuth redirect."""
19
+ parsed = urlparse(self.path)
20
+ query_params = parse_qs(parsed.query)
21
+
22
+ if "code" in query_params:
23
+ self.callback_data["authorization_code"] = query_params["code"][0]
24
+ self.callback_data["state"] = query_params.get("state", [None])[0]
25
+ self.send_response(200)
26
+ self.send_header("Content-type", "text/html")
27
+ self.end_headers()
28
+ self.wfile.write(b"""
29
+ <html>
30
+ <body>
31
+ <h1>Authorization Successful!</h1>
32
+ <p>You can close this window and return to the terminal.</p>
33
+ <script>setTimeout(() => window.close(), 2000);</script>
34
+ </body>
35
+ </html>
36
+ """)
37
+ elif "error" in query_params:
38
+ self.callback_data["error"] = query_params["error"][0]
39
+ self.send_response(400)
40
+ self.send_header("Content-type", "text/html")
41
+ self.end_headers()
42
+ self.wfile.write(
43
+ f"""
44
+ <html>
45
+ <body>
46
+ <h1>Authorization Failed</h1>
47
+ <p>Error: {query_params['error'][0]}</p>
48
+ <p>You can close this window and return to the terminal.</p>
49
+ </body>
50
+ </html>
51
+ """.encode()
52
+ )
53
+ else:
54
+ self.send_response(404)
55
+ self.end_headers()
56
+
57
+ def log_message(self, format, *args):
58
+ """Suppress default logging."""
59
+ pass
60
+
61
+
62
+ class CallbackServer(metaclass=Singleton):
63
+ """Simple server to handle OAuth callbacks."""
64
+
65
+ def __init__(self, port=3000):
66
+ self.port = port
67
+ self.server = None
68
+ self.thread = None
69
+ self.callback_data = {"authorization_code": None, "state": None, "error": None}
70
+ self._running = False
71
+
72
+ def _create_handler_with_data(self):
73
+ """Create a handler class with access to callback data."""
74
+ callback_data = self.callback_data
75
+
76
+ class DataCallbackHandler(CallbackHandler):
77
+ def __init__(self, request, client_address, server):
78
+ super().__init__(request, client_address, server, callback_data)
79
+
80
+ return DataCallbackHandler
81
+
82
+ def start(self):
83
+ """Start the callback server in a background thread."""
84
+ if self._running:
85
+ return
86
+ handler_class = self._create_handler_with_data()
87
+ self.server = HTTPServer(("localhost", self.port), handler_class)
88
+ self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
89
+ self.thread.start()
90
+ print(f"🖥️ Started callback server on http://localhost:{self.port}")
91
+ self._running = True
92
+
93
+ def stop(self):
94
+ """Stop the callback server."""
95
+ if self.server:
96
+ self.server.shutdown()
97
+ self.server.server_close()
98
+ if self.thread:
99
+ self.thread.join(timeout=1)
100
+
101
+ def wait_for_callback(self, timeout=300):
102
+ """Wait for OAuth callback with timeout."""
103
+ start_time = time.time()
104
+ while time.time() - start_time < timeout:
105
+ if self.callback_data["authorization_code"]:
106
+ return self.callback_data["authorization_code"]
107
+ elif self.callback_data["error"]:
108
+ raise Exception(f"OAuth error: {self.callback_data['error']}")
109
+ time.sleep(0.1)
110
+ raise Exception("Timeout waiting for OAuth callback")
111
+
112
+ def get_state(self):
113
+ """Get the received state parameter."""
114
+ return self.callback_data["state"]
@@ -0,0 +1,32 @@
1
+ from mcp.client.auth import TokenStorage as MCPTokenStorage
2
+ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
3
+
4
+ from universal_mcp.exceptions import KeyNotFoundError
5
+ from universal_mcp.stores.store import KeyringStore
6
+
7
+
8
+ class TokenStore(MCPTokenStorage):
9
+ """Simple in-memory token storage implementation."""
10
+
11
+ def __init__(self, store: KeyringStore):
12
+ self.store = store
13
+ self._tokens: OAuthToken | None = None
14
+ self._client_info: OAuthClientInformationFull | None = None
15
+
16
+ async def get_tokens(self) -> OAuthToken | None:
17
+ try:
18
+ return OAuthToken.model_validate_json(self.store.get("tokens"))
19
+ except KeyNotFoundError:
20
+ return None
21
+
22
+ async def set_tokens(self, tokens: OAuthToken) -> None:
23
+ self.store.set("tokens", tokens.model_dump_json())
24
+
25
+ async def get_client_info(self) -> OAuthClientInformationFull | None:
26
+ try:
27
+ return OAuthClientInformationFull.model_validate_json(self.store.get("client_info"))
28
+ except KeyNotFoundError:
29
+ return None
30
+
31
+ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
32
+ self.store.set("client_info", client_info.model_dump_json())