cade-cli 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cade_cli-0.3.3.dist-info/METADATA +151 -0
- cade_cli-0.3.3.dist-info/RECORD +44 -0
- cade_cli-0.3.3.dist-info/WHEEL +4 -0
- cade_cli-0.3.3.dist-info/entry_points.txt +2 -0
- cadecoder/__init__.py +1 -0
- cadecoder/ai/__init__.py +6 -0
- cadecoder/ai/prompts.py +572 -0
- cadecoder/cli/__init__.py +0 -0
- cadecoder/cli/app.py +147 -0
- cadecoder/cli/auth.py +483 -0
- cadecoder/cli/commands/__init__.py +5 -0
- cadecoder/cli/commands/auth.py +143 -0
- cadecoder/cli/commands/chat.py +264 -0
- cadecoder/cli/commands/mcp.py +477 -0
- cadecoder/cli/commands/tools.py +226 -0
- cadecoder/core/__init__.py +12 -0
- cadecoder/core/config.py +380 -0
- cadecoder/core/constants.py +281 -0
- cadecoder/core/errors.py +145 -0
- cadecoder/core/logging.py +148 -0
- cadecoder/core/types.py +235 -0
- cadecoder/core/utils.py +279 -0
- cadecoder/execution/__init__.py +46 -0
- cadecoder/execution/context_window.py +521 -0
- cadecoder/execution/orchestrator.py +562 -0
- cadecoder/execution/parallel.py +287 -0
- cadecoder/providers/__init__.py +60 -0
- cadecoder/providers/base.py +294 -0
- cadecoder/providers/openai.py +251 -0
- cadecoder/storage/__init__.py +0 -0
- cadecoder/storage/threads.py +489 -0
- cadecoder/templates/login_failed.html +21 -0
- cadecoder/templates/login_success.html +21 -0
- cadecoder/templates/styles.css +87 -0
- cadecoder/tools/__init__.py +19 -0
- cadecoder/tools/builtin.py +644 -0
- cadecoder/tools/filesystem.py +315 -0
- cadecoder/tools/git.py +221 -0
- cadecoder/tools/manager.py +1635 -0
- cadecoder/ui/__init__.py +7 -0
- cadecoder/ui/display.py +338 -0
- cadecoder/ui/input.py +145 -0
- cadecoder/ui/session.py +455 -0
- cadecoder/ui/state.py +20 -0
|
@@ -0,0 +1,1635 @@
|
|
|
1
|
+
"""Unified tool management system supporting local, remote, and MCP tools."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import inspect
|
|
5
|
+
import json
|
|
6
|
+
import uuid
|
|
7
|
+
from abc import ABC, abstractmethod
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from datetime import datetime, timedelta
|
|
10
|
+
from enum import Enum
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Annotated, Any, get_args, get_origin
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
15
|
+
from arcadepy import AsyncArcade
|
|
16
|
+
from rich.console import Console
|
|
17
|
+
from rich.panel import Panel
|
|
18
|
+
from rich.text import Text
|
|
19
|
+
|
|
20
|
+
from cadecoder.core.config import get_config
|
|
21
|
+
from cadecoder.core.logging import log
|
|
22
|
+
|
|
23
|
+
# Alias for a JSON-style mapping with string keys (used for tool schemas,
# going by its name).
JsonToolSchema = dict[str, Any]

# Module-wide Rich console; output goes to stderr instead of stdout.
console = Console(stderr=True)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# =============================================================================
|
|
31
|
+
# MCP Server Configuration
|
|
32
|
+
# =============================================================================
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MCPAuthType(str, Enum):
    """Ways a client can authenticate against an MCP server.

    Members are ``str`` subclasses, so each one compares equal to its raw
    value and can be rebuilt from it with ``MCPAuthType(value)``.
    """

    # No authentication.
    NONE = "none"
    # Static bearer token.
    BEARER = "bearer"
    # Static API key.
    API_KEY = "api_key"
    # Full OAuth 2.1 flow (MCP spec compliant).
    OAUTH = "oauth"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
|
|
45
|
+
class MCPOAuthTokens:
|
|
46
|
+
"""OAuth tokens for an MCP server."""
|
|
47
|
+
|
|
48
|
+
access_token: str
|
|
49
|
+
token_type: str = "Bearer"
|
|
50
|
+
refresh_token: str | None = None
|
|
51
|
+
expires_at: datetime | None = None
|
|
52
|
+
scope: str | None = None
|
|
53
|
+
|
|
54
|
+
def is_expired(self) -> bool:
|
|
55
|
+
"""Check if access token is expired."""
|
|
56
|
+
if self.expires_at is None:
|
|
57
|
+
return False
|
|
58
|
+
return datetime.now() >= self.expires_at
|
|
59
|
+
|
|
60
|
+
def to_dict(self) -> dict[str, Any]:
|
|
61
|
+
"""Convert to dictionary."""
|
|
62
|
+
return {
|
|
63
|
+
"access_token": self.access_token,
|
|
64
|
+
"token_type": self.token_type,
|
|
65
|
+
"refresh_token": self.refresh_token,
|
|
66
|
+
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
|
67
|
+
"scope": self.scope,
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
@classmethod
|
|
71
|
+
def from_dict(cls, data: dict[str, Any]) -> "MCPOAuthTokens":
|
|
72
|
+
"""Create from dictionary."""
|
|
73
|
+
expires_at = None
|
|
74
|
+
if data.get("expires_at"):
|
|
75
|
+
expires_at = datetime.fromisoformat(data["expires_at"])
|
|
76
|
+
return cls(
|
|
77
|
+
access_token=data["access_token"],
|
|
78
|
+
token_type=data.get("token_type", "Bearer"),
|
|
79
|
+
refresh_token=data.get("refresh_token"),
|
|
80
|
+
expires_at=expires_at,
|
|
81
|
+
scope=data.get("scope"),
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
class MCPServerConfig:
    """Settings and connection bookkeeping for a single MCP server."""

    name: str
    url: str
    auth_type: MCPAuthType = MCPAuthType.NONE
    auth_value: str | None = None  # For bearer/api_key
    enabled: bool = True
    last_connected: datetime | None = None
    tool_count: int = 0
    # OAuth-specific fields
    oauth_tokens: MCPOAuthTokens | None = None
    oauth_authorization_server: str | None = None
    oauth_client_id: str | None = None
    oauth_scopes: list[str] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable snapshot of this configuration.

        The auth type is stored by value, the last-connected time as an ISO
        string, and the OAuth tokens through their own ``to_dict``.
        """
        connected_iso = None
        if self.last_connected:
            connected_iso = self.last_connected.isoformat()

        tokens_payload = None
        if self.oauth_tokens:
            tokens_payload = self.oauth_tokens.to_dict()

        return {
            "name": self.name,
            "url": self.url,
            "auth_type": self.auth_type.value,
            "auth_value": self.auth_value,
            "enabled": self.enabled,
            "last_connected": connected_iso,
            "tool_count": self.tool_count,
            "oauth_tokens": tokens_payload,
            "oauth_authorization_server": self.oauth_authorization_server,
            "oauth_client_id": self.oauth_client_id,
            "oauth_scopes": self.oauth_scopes,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MCPServerConfig":
        """Rebuild a configuration from a dictionary shaped like ``to_dict`` output.

        ``name`` and ``url`` are required; every other key falls back to the
        field default when absent.
        """
        # Parse the nested values first, before reading the required keys.
        raw_connected = data.get("last_connected")
        last_connected = datetime.fromisoformat(raw_connected) if raw_connected else None

        raw_tokens = data.get("oauth_tokens")
        oauth_tokens = MCPOAuthTokens.from_dict(raw_tokens) if raw_tokens else None

        return cls(
            name=data["name"],
            url=data["url"],
            auth_type=MCPAuthType(data.get("auth_type", "none")),
            auth_value=data.get("auth_value"),
            enabled=data.get("enabled", True),
            last_connected=last_connected,
            tool_count=data.get("tool_count", 0),
            oauth_tokens=oauth_tokens,
            oauth_authorization_server=data.get("oauth_authorization_server"),
            oauth_client_id=data.get("oauth_client_id"),
            oauth_scopes=data.get("oauth_scopes"),
        )
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class MCPServerStore:
    """Persistent storage for MCP server configurations.

    Servers are kept in memory keyed by name; every mutation rewrites
    ``mcp_servers.json`` in the configuration directory. That file can hold
    static tokens and OAuth tokens, so it is written with owner-only
    permissions and replaced atomically.
    """

    def __init__(self, config_dir: Path | None = None) -> None:
        """Open the store and load any previously saved servers.

        Args:
            config_dir: Directory containing ``mcp_servers.json``. When None,
                the ``app_dir`` of ``get_config()`` is used.
        """
        if config_dir is None:
            config_dir = Path(get_config().app_dir)
        self.config_dir = config_dir
        self.config_file = config_dir / "mcp_servers.json"
        self._servers: dict[str, MCPServerConfig] = {}
        self._load()

    def _load(self) -> None:
        """Load servers from disk.

        Best-effort: an unreadable file leaves the store empty, and a single
        malformed entry is skipped with a warning instead of discarding every
        entry that follows it.
        """
        if not self.config_file.exists():
            return

        try:
            with open(self.config_file, encoding="utf-8") as f:
                data = json.load(f)
            entries = data.get("servers", [])
        except Exception as e:
            log.warning(f"Failed to load MCP server configs: {e}")
            return

        if not isinstance(entries, list):
            log.warning("Failed to load MCP server configs: 'servers' is not a list")
            return

        for server_data in entries:
            try:
                server = MCPServerConfig.from_dict(server_data)
            except Exception as e:
                log.warning(f"Skipping malformed MCP server config: {e}")
                continue
            self._servers[server.name] = server

        log.debug(f"Loaded {len(self._servers)} MCP server configs")

    def _save(self) -> None:
        """Save servers to disk.

        The directory is created if needed (a failure there propagates).
        Serialization and write errors are logged and swallowed, leaving any
        existing file untouched.
        """
        self.config_dir.mkdir(parents=True, exist_ok=True)

        # Write to a sibling temp file and rename over the target, so a crash
        # mid-write cannot leave a truncated config behind.
        tmp_file = self.config_file.with_name(self.config_file.name + ".tmp")
        try:
            data = {
                "servers": [s.to_dict() for s in self._servers.values()],
                "updated_at": datetime.now().isoformat(),
            }
            # Restrict to the owner before any secret is written; chmod also
            # covers a leftover temp file created with other permissions.
            tmp_file.touch(mode=0o600, exist_ok=True)
            tmp_file.chmod(0o600)
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            tmp_file.replace(self.config_file)
        except Exception as e:
            log.warning(f"Failed to save MCP server configs: {e}")
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                # Cleanup is best-effort; the failure was already reported.
                pass

    def add(self, server: MCPServerConfig) -> None:
        """Add a server configuration, replacing any with the same name, and save."""
        self._servers[server.name] = server
        self._save()

    def remove(self, name: str) -> bool:
        """Remove a server configuration.

        Returns:
            True if the server existed and was removed, False otherwise.
        """
        if name in self._servers:
            del self._servers[name]
            self._save()
            return True
        return False

    def get(self, name: str) -> MCPServerConfig | None:
        """Get a server configuration by name, or None if unknown."""
        return self._servers.get(name)

    def list_all(self) -> list[MCPServerConfig]:
        """List all server configurations."""
        return list(self._servers.values())

    def list_enabled(self) -> list[MCPServerConfig]:
        """List the server configurations whose ``enabled`` flag is set."""
        return [s for s in self._servers.values() if s.enabled]

    def update_status(self, name: str, connected: bool, tool_count: int = 0) -> None:
        """Update server connection status and save.

        Unknown names are ignored. On a successful connection the timestamp
        and tool count are refreshed.
        """
        if name in self._servers:
            if connected:
                self._servers[name].last_connected = datetime.now()
                self._servers[name].tool_count = tool_count
            self._save()
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
class ToolAuthorizationRequired(Exception):
|
|
221
|
+
"""Exception raised when a tool requires user authorization."""
|
|
222
|
+
|
|
223
|
+
def __init__(self, tool_name: str, authorization_url: str | None = None):
|
|
224
|
+
self.tool_name = tool_name
|
|
225
|
+
self.authorization_url = authorization_url
|
|
226
|
+
|
|
227
|
+
if authorization_url:
|
|
228
|
+
message = (
|
|
229
|
+
f"Authorization required for '{tool_name}'.\n\n"
|
|
230
|
+
f"Please authorize by clicking this link:\n{authorization_url}\n\n"
|
|
231
|
+
f"After authorizing, let me know and I'll retry the operation."
|
|
232
|
+
)
|
|
233
|
+
else:
|
|
234
|
+
message = (
|
|
235
|
+
f"Authorization required for '{tool_name}'.\n\n"
|
|
236
|
+
f"Please authorize this tool in your Arcade account."
|
|
237
|
+
)
|
|
238
|
+
|
|
239
|
+
super().__init__(message)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
class ToolManager(ABC):
    """Abstract contract shared by every tool manager."""

    @abstractmethod
    async def get_tools(self) -> list[dict[str, Any]]:
        """Return all available tools as JSON schemas."""
        ...

    @abstractmethod
    async def execute(self, name: str, inputs: dict[str, Any]) -> Any:
        """Run the tool called ``name`` with ``inputs`` and return its result."""
        ...
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
# =============================================================================
|
|
257
|
+
# MCP OAuth 2.1 Support (RFC 9728, RFC 8414)
|
|
258
|
+
# =============================================================================
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class MCPOAuthHandler:
    """Handles OAuth 2.1 authorization for MCP servers per MCP spec.

    Covers metadata discovery (RFC 9728 protected resource metadata and
    RFC 8414 authorization server metadata), PKCE generation, building the
    authorization URL, and the authorization-code and refresh-token grants.
    """

    # Redirect URI sent with both the authorization request and the code
    # exchange; defined once so the two values cannot drift apart.
    REDIRECT_URI = "http://127.0.0.1:9876/callback"

    def __init__(self, server_config: MCPServerConfig) -> None:
        """Bind the handler to one server; the HTTP client is created lazily."""
        self.config = server_config
        self._client: httpx.AsyncClient | None = None

    async def _ensure_client(self) -> httpx.AsyncClient:
        """Return the handler's HTTP client, creating it on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=httpx.Timeout(30.0))
        return self._client

    def _resolve_client_id(self, client_id: str | None) -> str:
        """Pick the explicit client id, else the configured one, else a name-based default."""
        return client_id or self.config.oauth_client_id or f"cade-mcp-{self.config.name}"

    async def _fetch_first_json(self, urls: list[str]) -> tuple[str, dict[str, Any]] | None:
        """Return ``(url, body)`` for the first URL answering 200 with a JSON object."""
        client = await self._ensure_client()
        for url in urls:
            try:
                resp = await client.get(url)
                if resp.status_code != 200:
                    continue
                body = resp.json()
            except Exception as e:
                # Discovery is best-effort: note the failure, try the next URL.
                log.debug(f"OAuth metadata fetch failed for {url}: {e}")
                continue
            # Callers use .get() on the result, so only accept JSON objects.
            if isinstance(body, dict):
                return url, body
        return None

    async def _request_tokens(
        self,
        auth_server_metadata: dict[str, Any],
        form: dict[str, str],
        failure_prefix: str,
        fallback_refresh_token: str | None = None,
    ) -> MCPOAuthTokens:
        """POST ``form`` to the token endpoint and convert the reply into tokens."""
        client = await self._ensure_client()

        token_endpoint = auth_server_metadata.get("token_endpoint")
        if not token_endpoint:
            raise Exception("No token_endpoint in auth server metadata")

        resp = await client.post(
            token_endpoint,
            data=form,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        if resp.status_code != 200:
            raise Exception(f"{failure_prefix}: {resp.text}")

        token_data = resp.json()

        expires_at = None
        expires_in = token_data.get("expires_in")
        if expires_in is not None:
            # Some servers send the lifetime as a string; a null or
            # unparseable value means "no known expiry", not a crash.
            try:
                expires_at = datetime.now() + timedelta(seconds=float(expires_in))
            except (TypeError, ValueError, OverflowError):
                log.warning(f"Ignoring invalid expires_in from token endpoint: {expires_in!r}")

        return MCPOAuthTokens(
            access_token=token_data["access_token"],
            token_type=token_data.get("token_type", "Bearer"),
            # `or` (not a .get default) so an explicit null keeps the fallback.
            refresh_token=token_data.get("refresh_token") or fallback_refresh_token,
            expires_at=expires_at,
            scope=token_data.get("scope"),
        )

    def parse_www_authenticate(self, header: str) -> dict[str, str]:
        """Parse WWW-Authenticate header to extract OAuth parameters.

        Per MCP spec, extracts: resource_metadata, scope, error, error_description.
        Both quoted (``scope="a b"``) and unquoted (``error=invalid_token``)
        parameter values are recognised.

        Returns:
            Mapping of parameter name to value; empty when nothing parses.
        """
        import re

        # Remove "Bearer " prefix if present
        if header.lower().startswith("bearer "):
            header = header[7:]

        params: dict[str, str] = {}
        # name = "quoted string with \-escapes"  |  name = bare-token
        pattern = r'([\w-]+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,"=]+))'
        for match in re.finditer(pattern, header):
            key, quoted, bare = match.groups()
            if quoted is not None:
                # Undo backslash escapes inside the quoted string.
                params[key] = re.sub(r"\\(.)", r"\1", quoted)
            else:
                params[key] = bare

        return params

    async def discover_from_401(self, response: httpx.Response) -> tuple[str | None, str | None]:
        """Discover authorization server from 401 response.

        Returns: (resource_metadata_url, required_scope), each None when the
        WWW-Authenticate header does not carry it.
        """
        www_auth = response.headers.get("WWW-Authenticate", "")
        params = self.parse_www_authenticate(www_auth)
        return params.get("resource_metadata"), params.get("scope")

    async def fetch_protected_resource_metadata(
        self, metadata_url: str | None = None
    ) -> dict[str, Any] | None:
        """Fetch Protected Resource Metadata per RFC 9728.

        Tries:
        1. Provided metadata_url (from WWW-Authenticate)
        2. Well-known URI with path: /.well-known/oauth-protected-resource/<path>
        3. Well-known URI at root: /.well-known/oauth-protected-resource

        Returns:
            The first JSON object served with status 200, or None.
        """
        from urllib.parse import urlparse

        parsed = urlparse(self.config.url)
        base = f"{parsed.scheme}://{parsed.netloc}"

        urls_to_try = []
        if metadata_url:
            urls_to_try.append(metadata_url)

        # Well-known with path
        if parsed.path and parsed.path != "/":
            path = parsed.path.rstrip("/")
            urls_to_try.append(f"{base}/.well-known/oauth-protected-resource{path}")

        # Well-known at root
        urls_to_try.append(f"{base}/.well-known/oauth-protected-resource")

        found = await self._fetch_first_json(urls_to_try)
        return found[1] if found else None

    async def fetch_authorization_server_metadata(self, issuer: str) -> dict[str, Any] | None:
        """Fetch Authorization Server Metadata per RFC 8414.

        Tries both OAuth 2.0 and OpenID Connect discovery endpoints and logs
        a warning when the server found does not advertise PKCE support.

        Returns:
            The first JSON object served with status 200, or None.
        """
        from urllib.parse import urlparse

        parsed = urlparse(issuer)
        base = f"{parsed.scheme}://{parsed.netloc}"
        path = parsed.path.rstrip("/") if parsed.path else ""

        # Priority order per MCP spec
        if path:
            urls_to_try = [
                f"{base}/.well-known/oauth-authorization-server{path}",
                f"{base}/.well-known/openid-configuration{path}",
                f"{base}{path}/.well-known/openid-configuration",
            ]
        else:
            urls_to_try = [
                f"{base}/.well-known/oauth-authorization-server",
                f"{base}/.well-known/openid-configuration",
            ]

        found = await self._fetch_first_json(urls_to_try)
        if found is None:
            return None

        url, metadata = found
        # Verify PKCE support (required by MCP spec)
        if "code_challenge_methods_supported" not in metadata:
            log.warning(f"Auth server {url} doesn't advertise PKCE support")
        return metadata

    def generate_pkce(self) -> tuple[str, str]:
        """Generate PKCE code_verifier and code_challenge (S256).

        Returns:
            ``(code_verifier, code_challenge)``; the challenge is the
            unpadded base64url SHA-256 digest of the verifier.
        """
        import base64
        import hashlib
        import secrets

        # 32 random bytes encode to a 43-character verifier (allowed: 43-128).
        code_verifier = secrets.token_urlsafe(32)

        digest = hashlib.sha256(code_verifier.encode()).digest()
        code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()

        return code_verifier, code_challenge

    async def start_oauth_flow(
        self,
        auth_server_metadata: dict[str, Any],
        scope: str | None = None,
        client_id: str | None = None,
    ) -> tuple[str, str, str]:
        """Start OAuth authorization flow.

        Args:
            auth_server_metadata: Authorization server metadata; must contain
                ``authorization_endpoint``.
            scope: Space-separated scopes. Falls back to the configured
                scopes, then to the server's ``scopes_supported``.
            client_id: Client id override; see ``_resolve_client_id``.

        Returns: (authorization_url, code_verifier, state)

        Raises:
            Exception: If the metadata has no ``authorization_endpoint``.
        """
        import secrets
        from urllib.parse import urlencode

        auth_endpoint = auth_server_metadata.get("authorization_endpoint")
        if not auth_endpoint:
            raise Exception("No authorization_endpoint in auth server metadata")

        code_verifier, code_challenge = self.generate_pkce()
        state = secrets.token_urlsafe(16)

        # Determine scope
        if not scope:
            scope = " ".join(self.config.oauth_scopes or [])
        if not scope:
            # Use scopes_supported from metadata
            scope = " ".join(auth_server_metadata.get("scopes_supported") or [])

        params = {
            "response_type": "code",
            "client_id": self._resolve_client_id(client_id),
            "redirect_uri": self.REDIRECT_URI,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "state": state,
            "resource": self.config.url,  # RFC 8707 resource indicator
        }
        if scope:
            params["scope"] = scope

        # The endpoint may already carry a query string, which must be kept
        # (RFC 6749 section 3.1); extend it instead of adding a second "?".
        if "?" not in auth_endpoint:
            separator = "?"
        elif auth_endpoint.endswith(("?", "&")):
            separator = ""
        else:
            separator = "&"

        auth_url = f"{auth_endpoint}{separator}{urlencode(params)}"
        return auth_url, code_verifier, state

    async def exchange_code_for_tokens(
        self,
        auth_server_metadata: dict[str, Any],
        code: str,
        code_verifier: str,
        client_id: str | None = None,
    ) -> MCPOAuthTokens:
        """Exchange authorization code for tokens.

        Raises:
            Exception: If the metadata has no ``token_endpoint`` or the token
                endpoint does not answer with status 200.
        """
        form = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.REDIRECT_URI,
            "client_id": self._resolve_client_id(client_id),
            "code_verifier": code_verifier,
            "resource": self.config.url,
        }
        return await self._request_tokens(auth_server_metadata, form, "Token exchange failed")

    async def refresh_tokens(
        self,
        auth_server_metadata: dict[str, Any],
        refresh_token: str,
        client_id: str | None = None,
    ) -> MCPOAuthTokens:
        """Refresh access token using refresh token.

        When the response carries no new refresh token, the one passed in is
        kept on the returned tokens.

        Raises:
            Exception: If the metadata has no ``token_endpoint`` or the token
                endpoint does not answer with status 200.
        """
        form = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": self._resolve_client_id(client_id),
            "resource": self.config.url,
        }
        return await self._request_tokens(
            auth_server_metadata,
            form,
            "Token refresh failed",
            fallback_refresh_token=refresh_token,
        )

    async def close(self) -> None:
        """Close HTTP client."""
        if self._client:
            await self._client.aclose()
            self._client = None
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
# =============================================================================
|
|
553
|
+
# MCP Tool Manager (HTTP Streamable JSON-RPC)
|
|
554
|
+
# =============================================================================
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
class MCPToolManager(ToolManager):
|
|
558
|
+
"""Manages tools from MCP servers over HTTP (JSON-RPC streamable transport)."""
|
|
559
|
+
|
|
560
|
+
def __init__(
    self,
    server_config: MCPServerConfig,
    server_store: "MCPServerStore | None" = None,
) -> None:
    """Set up a manager for one MCP server; no network I/O happens here.

    Args:
        server_config: Connection settings of the server.
        server_store: Optional store used to persist refreshed OAuth tokens.
    """
    self.config = server_config
    self._server_store = server_store
    self._tools_cache: list[dict[str, Any]] | None = None
    self._initialized = False
    self._client: httpx.AsyncClient | None = None
    self._session_id: str | None = None
    self._oauth_handler = MCPOAuthHandler(server_config)
    self._auth_server_metadata: dict[str, Any] | None = None
    # PKCE verifier and state of an OAuth flow started by
    # _handle_401_response. Defined up front so reading them before any 401
    # has occurred yields None instead of raising AttributeError.
    self._pending_code_verifier: str | None = None
    self._pending_state: str | None = None
|
|
573
|
+
|
|
574
|
+
def _get_headers(self) -> dict[str, str]:
    """Build the HTTP headers for a request to this server.

    Always includes the JSON content type and an Accept header covering JSON
    and event streams. Adds the MCP session id once one is known, plus the
    credential header that matches the configured auth type.
    """
    cfg = self.config
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
    }

    # Include MCP session ID if we have one (required for HTTP transport)
    if self._session_id:
        headers["Mcp-Session-Id"] = self._session_id

    kind = cfg.auth_type
    if kind == MCPAuthType.BEARER:
        if cfg.auth_value:
            headers["Authorization"] = f"Bearer {cfg.auth_value}"
    elif kind == MCPAuthType.API_KEY:
        if cfg.auth_value:
            headers["X-API-Key"] = cfg.auth_value
    elif kind == MCPAuthType.OAUTH:
        tokens = cfg.oauth_tokens
        if tokens:
            headers["Authorization"] = f"{tokens.token_type} {tokens.access_token}"

    return headers
|
|
596
|
+
|
|
597
|
+
async def _ensure_client(self) -> httpx.AsyncClient:
    """Return the shared HTTP client, creating it on first use.

    The client is built without default headers (requests pass their own
    from ``_get_headers``) and uses a 60 second timeout with a 10 second
    connect timeout.
    """
    client = self._client
    if client is None:
        client = httpx.AsyncClient(timeout=httpx.Timeout(60.0, connect=10.0))
        self._client = client
    return client
|
|
605
|
+
|
|
606
|
+
async def _maybe_refresh_oauth_token(self) -> bool:
    """Refresh the OAuth access token when it has expired.

    Returns:
        True only if new tokens were obtained. False when the server does
        not use OAuth, the token is still valid, no refresh token exists,
        the authorization server cannot be discovered, or the refresh fails.
    """
    cfg = self.config
    if cfg.auth_type != MCPAuthType.OAUTH:
        return False

    current = cfg.oauth_tokens
    refreshable = bool(current) and current.is_expired() and bool(current.refresh_token)
    if not refreshable:
        return False

    handler = self._oauth_handler
    if not self._auth_server_metadata:
        # Metadata not cached yet: rediscover it via the resource metadata.
        resource_meta = await handler.fetch_protected_resource_metadata()
        issuers = resource_meta.get("authorization_servers", []) if resource_meta else []
        if issuers:
            self._auth_server_metadata = await handler.fetch_authorization_server_metadata(
                issuers[0]
            )

    if not self._auth_server_metadata:
        return False

    try:
        cfg.oauth_tokens = await handler.refresh_tokens(
            self._auth_server_metadata,
            current.refresh_token,
            cfg.oauth_client_id,
        )

        # Persist updated tokens
        if self._server_store:
            self._server_store.add(cfg)

        log.info(f"Refreshed OAuth token for MCP server '{self.config.name}'")
    except Exception as exc:
        log.warning(f"Failed to refresh OAuth token: {exc}")
        return False
    return True
|
|
650
|
+
|
|
651
|
+
async def _handle_401_response(self, response: httpx.Response) -> str | None:
|
|
652
|
+
"""Handle 401 response per MCP OAuth spec.
|
|
653
|
+
|
|
654
|
+
Returns authorization URL if OAuth flow should be initiated.
|
|
655
|
+
"""
|
|
656
|
+
# Parse WWW-Authenticate header
|
|
657
|
+
metadata_url, scope = await self._oauth_handler.discover_from_401(response)
|
|
658
|
+
|
|
659
|
+
# Fetch Protected Resource Metadata
|
|
660
|
+
rs_metadata = await self._oauth_handler.fetch_protected_resource_metadata(metadata_url)
|
|
661
|
+
if not rs_metadata:
|
|
662
|
+
return None
|
|
663
|
+
|
|
664
|
+
# Get authorization server
|
|
665
|
+
auth_servers = rs_metadata.get("authorization_servers", [])
|
|
666
|
+
if not auth_servers:
|
|
667
|
+
return None
|
|
668
|
+
|
|
669
|
+
# Fetch Authorization Server Metadata
|
|
670
|
+
as_metadata = await self._oauth_handler.fetch_authorization_server_metadata(auth_servers[0])
|
|
671
|
+
if not as_metadata:
|
|
672
|
+
return None
|
|
673
|
+
|
|
674
|
+
self._auth_server_metadata = as_metadata
|
|
675
|
+
self.config.oauth_authorization_server = auth_servers[0]
|
|
676
|
+
|
|
677
|
+
# Generate authorization URL
|
|
678
|
+
auth_url, code_verifier, state = await self._oauth_handler.start_oauth_flow(
|
|
679
|
+
as_metadata, scope, self.config.oauth_client_id
|
|
680
|
+
)
|
|
681
|
+
|
|
682
|
+
# Store PKCE verifier for later token exchange (temporary)
|
|
683
|
+
self._pending_code_verifier = code_verifier
|
|
684
|
+
self._pending_state = state
|
|
685
|
+
|
|
686
|
+
return auth_url
|
|
687
|
+
|
|
688
|
+
async def _send_notification(self, method: str) -> None:
|
|
689
|
+
"""Send a JSON-RPC notification (no response expected).
|
|
690
|
+
|
|
691
|
+
Per MCP spec, notifications don't have an 'id' and don't expect a response.
|
|
692
|
+
Used for 'notifications/initialized' and similar methods.
|
|
693
|
+
"""
|
|
694
|
+
client = await self._ensure_client()
|
|
695
|
+
|
|
696
|
+
# Notifications don't have an 'id' field
|
|
697
|
+
payload: dict[str, Any] = {
|
|
698
|
+
"jsonrpc": "2.0",
|
|
699
|
+
"method": method,
|
|
700
|
+
}
|
|
701
|
+
|
|
702
|
+
headers = self._get_headers()
|
|
703
|
+
try:
|
|
704
|
+
response = await client.post(self.config.url, json=payload, headers=headers)
|
|
705
|
+
# We don't expect a meaningful response for notifications
|
|
706
|
+
# Some servers return 202 Accepted, others return 200
|
|
707
|
+
if response.status_code not in (200, 202, 204):
|
|
708
|
+
log.warning(
|
|
709
|
+
f"MCP notification '{method}' got unexpected status: {response.status_code}"
|
|
710
|
+
)
|
|
711
|
+
except Exception as e:
|
|
712
|
+
log.warning(f"MCP notification '{method}' failed: {e}")
|
|
713
|
+
|
|
714
|
+
async def _send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
|
715
|
+
"""Send a JSON-RPC request to the MCP server.
|
|
716
|
+
|
|
717
|
+
Handles MCP HTTP transport session management by capturing and including
|
|
718
|
+
the Mcp-Session-Id header across requests.
|
|
719
|
+
"""
|
|
720
|
+
# Refresh token if needed
|
|
721
|
+
await self._maybe_refresh_oauth_token()
|
|
722
|
+
|
|
723
|
+
client = await self._ensure_client()
|
|
724
|
+
|
|
725
|
+
request_id = str(uuid.uuid4())
|
|
726
|
+
payload: dict[str, Any] = {
|
|
727
|
+
"jsonrpc": "2.0",
|
|
728
|
+
"id": request_id,
|
|
729
|
+
"method": method,
|
|
730
|
+
}
|
|
731
|
+
if params:
|
|
732
|
+
payload["params"] = params
|
|
733
|
+
|
|
734
|
+
try:
|
|
735
|
+
# Include current headers (with session ID if available)
|
|
736
|
+
headers = self._get_headers()
|
|
737
|
+
response = await client.post(self.config.url, json=payload, headers=headers)
|
|
738
|
+
|
|
739
|
+
# Capture session ID from response for subsequent requests
|
|
740
|
+
# Per MCP HTTP transport spec, server sends Mcp-Session-Id header
|
|
741
|
+
if "mcp-session-id" in response.headers:
|
|
742
|
+
self._session_id = response.headers["mcp-session-id"]
|
|
743
|
+
log.debug(
|
|
744
|
+
f"Captured MCP session ID for '{self.config.name}': {self._session_id[:8]}..."
|
|
745
|
+
)
|
|
746
|
+
|
|
747
|
+
# Handle 401 - OAuth discovery
|
|
748
|
+
if response.status_code == 401:
|
|
749
|
+
auth_url = await self._handle_401_response(response)
|
|
750
|
+
raise ToolAuthorizationRequired(
|
|
751
|
+
f"MCP:{self.config.name}",
|
|
752
|
+
authorization_url=auth_url,
|
|
753
|
+
)
|
|
754
|
+
|
|
755
|
+
# Handle 403 - Insufficient scope
|
|
756
|
+
if response.status_code == 403:
|
|
757
|
+
www_auth = response.headers.get("WWW-Authenticate", "")
|
|
758
|
+
auth_params = self._oauth_handler.parse_www_authenticate(www_auth)
|
|
759
|
+
if auth_params.get("error") == "insufficient_scope":
|
|
760
|
+
raise ToolAuthorizationRequired(
|
|
761
|
+
f"MCP:{self.config.name}",
|
|
762
|
+
authorization_url=None,
|
|
763
|
+
)
|
|
764
|
+
|
|
765
|
+
response.raise_for_status()
|
|
766
|
+
|
|
767
|
+
# Handle SSE or direct JSON response
|
|
768
|
+
content_type = response.headers.get("content-type", "")
|
|
769
|
+
|
|
770
|
+
if "text/event-stream" in content_type:
|
|
771
|
+
# Parse SSE response
|
|
772
|
+
return await self._parse_sse_response(response, request_id)
|
|
773
|
+
else:
|
|
774
|
+
# Direct JSON response
|
|
775
|
+
result = response.json()
|
|
776
|
+
if "error" in result:
|
|
777
|
+
raise Exception(f"MCP error: {result['error'].get('message', 'Unknown error')}")
|
|
778
|
+
return result.get("result")
|
|
779
|
+
|
|
780
|
+
except httpx.HTTPStatusError as e:
|
|
781
|
+
raise Exception(f"MCP HTTP error: {e}")
|
|
782
|
+
except ToolAuthorizationRequired:
|
|
783
|
+
raise
|
|
784
|
+
except Exception as e:
|
|
785
|
+
log.error(f"MCP request failed for {self.config.name}: {e}")
|
|
786
|
+
raise
|
|
787
|
+
|
|
788
|
+
async def _parse_sse_response(self, response: httpx.Response, request_id: str) -> Any:
    """Extract the JSON-RPC reply matching ``request_id`` from an SSE body.

    Every ``data:`` line is decoded as JSON on its own. Lines that are not
    ``data`` fields, carry no payload, are not valid JSON, do not decode to
    a JSON object, or belong to a different request id are skipped.

    Args:
        response: Response whose body is iterated with ``aiter_lines()``.
        request_id: Id of the JSON-RPC request whose reply is wanted.

    Returns:
        The ``result`` member of the first matching event, or ``None`` when
        the stream ends without a matching event.

    Raises:
        Exception: If the matching event carries an ``error`` member.
    """
    async for line in response.aiter_lines():
        # The space after the colon is optional in the SSE wire format.
        if not line.startswith("data:"):
            continue

        payload = line[5:].strip()
        if not payload:
            continue

        try:
            event = json.loads(payload)
        except json.JSONDecodeError:
            continue

        # Only JSON objects can be JSON-RPC replies; anything else is ignored.
        if not isinstance(event, dict) or event.get("id") != request_id:
            continue

        if "error" in event:
            error = event["error"]
            message = error.get("message", "Unknown error") if isinstance(error, dict) else error
            raise Exception(f"MCP error: {message}")

        # Stop at the reply instead of waiting for the server to end the stream.
        return event.get("result")

    return None
|
|
804
|
+
|
|
805
|
+
async def initialize(self) -> bool:
    """Initialize connection to MCP server.

    Follows MCP protocol:
    1. Send 'initialize' request
    2. Receive server capabilities
    3. Send 'notifications/initialized' notification

    Returns:
        True when the session is (or already was) initialized, False when
        the server returned an empty result or any step raised. Every
        exception is logged and reported as False rather than propagated.
    """
    if self._initialized:
        return True

    try:
        result = await self._send_request(
            "initialize",
            {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "clientInfo": {"name": "cade", "version": "1.0.0"},
            },
        )

        if not result:
            return False

        self._initialized = True
        # Send initialized notification (no response expected)
        await self._send_notification("notifications/initialized")
        log.info(f"MCP server '{self.config.name}' initialized")
        return True

    except Exception as e:
        # A failure after the flag was raised (e.g. the notification could not
        # be sent) must not leave the manager claiming to be initialized while
        # this call reports False.
        self._initialized = False
        log.error(f"Failed to initialize MCP server '{self.config.name}': {e}")
        return False
|
|
838
|
+
|
|
839
|
+
async def get_tools(self) -> list[dict[str, Any]]:
    """Get available tools from the MCP server.

    The converted list is cached on the instance after the first successful
    listing. Pages are followed through the ``nextCursor`` value of each
    ``tools/list`` result; the first request is sent without parameters.

    Returns:
        Tool schemas in OpenAI function format. An empty list is returned
        (and nothing is cached) when initialization fails, when the first
        page has no ``tools`` member, or when any request or conversion
        raises.
    """
    if self._tools_cache is not None:
        return self._tools_cache

    if not self._initialized:
        success = await self.initialize()
        if not success:
            return []

    try:
        # Build into a local list so a failure part-way through never leaves
        # a truncated list behind in the cache.
        converted: list[dict[str, Any]] = []
        seen_cursors: set[Any] = set()
        cursor = None

        while True:
            if cursor is None:
                result = await self._send_request("tools/list")
            else:
                result = await self._send_request("tools/list", {"cursor": cursor})

            if not result or "tools" not in result:
                if cursor is None:
                    return []
                break

            # Convert MCP tool format to OpenAI function format
            converted.extend(self._convert_mcp_to_openai(mcp_tool) for mcp_tool in result["tools"])

            cursor = result.get("nextCursor")
            # Stop on the last page, or if the server repeats a cursor.
            if not cursor or cursor in seen_cursors:
                break
            seen_cursors.add(cursor)

        self._tools_cache = converted
        log.info(f"Loaded {len(self._tools_cache)} tools from MCP server '{self.config.name}'")
        return self._tools_cache

    except Exception as e:
        log.error(f"Failed to list tools from MCP '{self.config.name}': {e}")
        return []
|
|
867
|
+
|
|
868
|
+
def _convert_mcp_to_openai(self, mcp_tool: dict[str, Any]) -> dict[str, Any]:
    """Convert MCP tool schema to OpenAI function schema.

    Uses the original tool name from the MCP server without modification.

    Args:
        mcp_tool: Tool entry with a required ``name`` and optional
            ``description`` and ``inputSchema`` members.

    Returns:
        A ``{"type": "function", "function": {...}}`` dict. A missing, null
        or empty description becomes ``""``; a missing, null or empty input
        schema becomes an object schema with no properties.

    Raises:
        KeyError: If ``mcp_tool`` has no ``name``.
    """
    # `or` (rather than a .get default) also covers members sent as null.
    description = mcp_tool.get("description") or ""
    parameters = mcp_tool.get("inputSchema") or {"type": "object", "properties": {}}
    return {
        "type": "function",
        "function": {
            "name": mcp_tool["name"],
            "description": description,
            "parameters": parameters,
        },
    }
|
|
881
|
+
|
|
882
|
+
async def execute(self, name: str, inputs: dict[str, Any]) -> Any:
    """Execute a tool on the MCP server.

    Args:
        name: Tool name as reported by the server.
        inputs: Arguments forwarded as the ``arguments`` of ``tools/call``.

    Returns:
        The text items of the result's ``content`` joined with newlines when
        at least one text item exists; otherwise the raw result.

    Raises:
        ToolAuthorizationRequired: Propagated unchanged from the request so
            callers can still recognise that authorization is needed.
        Exception: If the server cannot be initialized, or if the call or
            the result handling fails for any other reason.
    """
    if not self._initialized:
        success = await self.initialize()
        if not success:
            raise Exception(f"MCP server '{self.config.name}' not initialized")

    try:
        result = await self._send_request(
            "tools/call",
            {"name": name, "arguments": inputs},
        )

        if result and "content" in result:
            # Extract text content from MCP response
            text_parts = [
                item.get("text", "") for item in result["content"] if item.get("type") == "text"
            ]
            return "\n".join(text_parts) if text_parts else result

        return result

    except ToolAuthorizationRequired:
        # Keep the dedicated exception type intact instead of flattening it
        # into a generic failure below.
        raise
    except Exception as e:
        log.error(f"MCP tool execution failed for '{name}': {e}")
        raise Exception(f"Failed to execute MCP tool '{name}': {e}") from e
|
|
909
|
+
|
|
910
|
+
async def check_status(self) -> tuple[bool, str]:
    """Probe the MCP server by running a fresh initialization handshake.

    The initialized flag and session id are cleared first, so the handshake
    is always repeated. A successful probe leaves the session ready for a
    following get_tools() call.

    Returns:
        ``(True, "Connected")`` when initialize() succeeds,
        ``(False, "Initialization failed")`` when it returns False, and
        ``(False, <reason>)`` if an exception reaches this method.

    NOTE(review): initialize() in this file reports its own failures by
    returning False, so the except clauses below look like a safety net.
    """
    try:
        # Drop any earlier handshake state so initialize() does real work.
        self._session_id = None
        self._initialized = False

        connected = await self.initialize()
    except httpx.ConnectError:
        return False, "Connection failed"
    except ToolAuthorizationRequired:
        return False, "Authentication required"
    except Exception as e:
        # Keep the reason short enough for a one-line status display.
        return False, str(e)[:50]

    return (True, "Connected") if connected else (False, "Initialization failed")
|
|
934
|
+
|
|
935
|
+
async def complete_oauth_flow(self, authorization_code: str) -> bool:
    """Finish a pending OAuth flow by trading the code for tokens.

    Meant to be called once the user has authorized and the callback has
    delivered ``authorization_code``. On success the tokens are stored on
    the server config (and persisted when a server store is set), the
    pending PKCE state is discarded, and the HTTP client is closed and
    cleared.

    Returns:
        True on success. False when there is no auth server metadata, no
        pending PKCE verifier, or the exchange (or any later step) raises.
    """
    metadata = self._auth_server_metadata
    if not metadata:
        log.error("No auth server metadata - cannot complete OAuth flow")
        return False

    if not hasattr(self, "_pending_code_verifier"):
        log.error("No pending PKCE verifier - OAuth flow not started")
        return False

    try:
        issued_tokens = await self._oauth_handler.exchange_code_for_tokens(
            metadata,
            authorization_code,
            self._pending_code_verifier,
            self.config.oauth_client_id,
        )

        # Record the new credentials on the server configuration.
        server_config = self.config
        server_config.auth_type = MCPAuthType.OAUTH
        server_config.oauth_tokens = issued_tokens

        # Write the updated configuration back when a store is attached.
        store = self._server_store
        if store:
            store.add(server_config)

        # The one-shot PKCE verifier and state are no longer needed.
        for leftover in ("_pending_code_verifier", "_pending_state"):
            if hasattr(self, leftover):
                delattr(self, leftover)

        # Discard the current client so the next one picks up the new tokens.
        stale_client = self._client
        if stale_client:
            await stale_client.aclose()
            self._client = None

        log.info(f"OAuth flow completed for MCP server '{self.config.name}'")
        return True

    except Exception as e:
        log.error(f"Failed to complete OAuth flow: {e}")
        return False
|
|
980
|
+
|
|
981
|
+
async def close(self) -> None:
    """Close the HTTP client and reset session state.

    The session id and initialized flag are cleared and the OAuth handler
    is closed even when closing the HTTP client raises; in that case the
    client's exception still propagates to the caller.
    """
    # Detach the client and reset state first so a failing aclose() cannot
    # leave a half-closed manager that still looks initialized.
    client, self._client = self._client, None
    self._session_id = None
    self._initialized = False

    try:
        if client:
            await client.aclose()
    finally:
        await self._oauth_handler.close()
|
|
989
|
+
|
|
990
|
+
|
|
991
|
+
class CacheEntry:
    """One cached tool schema plus the time it was last refreshed."""

    def __init__(self, tool_name: str, tool_schema: dict[str, Any], last_updated: datetime) -> None:
        self.tool_name, self.tool_schema, self.last_updated = tool_name, tool_schema, last_updated

    def is_expired(self, max_age: timedelta) -> bool:
        """Return True when the entry is strictly older than ``max_age``."""
        age = datetime.now() - self.last_updated
        return age > max_age

    def update(self, tool_schema: dict[str, Any]) -> None:
        """Replace the stored schema and stamp the entry with the current time."""
        self.tool_schema = tool_schema
        refreshed_at = datetime.now()
        self.last_updated = refreshed_at
|
|
1006
|
+
|
|
1007
|
+
|
|
1008
|
+
class ToolCache:
    """Cache for tool schemas with TTL support and persistent storage.

    Entries live in memory keyed by tool name and are mirrored to
    ``remote_tools.json`` inside ``cache_dir``. Expired entries are never
    returned by the query methods.
    """

    def __init__(
        self,
        max_age: timedelta = timedelta(hours=24),
        cache_dir: Path | None = None,
    ) -> None:
        """Create the cache directory if needed and load any saved entries.

        Args:
            max_age: Age beyond which an entry counts as expired.
            cache_dir: Directory for the cache file. Defaults to the
                ``tool_cache`` folder under the configured application dir.
        """
        self._tools_cache: dict[str, CacheEntry] = {}
        self.max_age = max_age

        if cache_dir is None:
            app_dir = Path(get_config().app_dir)
            cache_dir = app_dir / "tool_cache"
        self.cache_dir = cache_dir
        self.cache_file = self.cache_dir / "remote_tools.json"

        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._load_from_disk()

    def _load_from_disk(self) -> None:
        """Load cached tools from disk, keeping only non-expired entries.

        An unreadable file is logged and ignored. A malformed entry is
        logged and skipped without discarding the entries around it.
        """
        if not self.cache_file.exists():
            return

        try:
            with open(self.cache_file, encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            log.warning(f"Failed to load tool cache from disk: {e}")
            return

        stored = data.get("tools", {}) if isinstance(data, dict) else {}
        if not isinstance(stored, dict):
            log.warning("Failed to load tool cache from disk: unexpected file layout")
            return

        for tool_name, entry_data in stored.items():
            try:
                entry = CacheEntry(
                    tool_name,
                    entry_data["tool_schema"],
                    datetime.fromisoformat(entry_data["last_updated"]),
                )
                fresh = not entry.is_expired(self.max_age)
            except (KeyError, TypeError, ValueError) as e:
                log.warning(f"Skipping malformed tool cache entry '{tool_name}': {e}")
                continue
            if fresh:
                self._tools_cache[tool_name] = entry

        log.debug(f"Loaded {len(self._tools_cache)} tools from persistent cache")

    def _save_to_disk(self) -> None:
        """Save cached tools to disk; failures are logged, not raised."""
        try:
            data = {
                "tools": {
                    tool_name: {
                        "tool_schema": entry.tool_schema,
                        "last_updated": entry.last_updated.isoformat(),
                    }
                    for tool_name, entry in self._tools_cache.items()
                },
                "updated_at": datetime.now().isoformat(),
            }

            # Write to a sibling file and rename it into place so an
            # interrupted or failed dump never truncates the existing cache.
            tmp_file = self.cache_file.with_name(self.cache_file.name + ".tmp")
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            tmp_file.replace(self.cache_file)
        except Exception as e:
            log.warning(f"Failed to save tool cache to disk: {e}")

    def get_tools(
        self, tools: list[str] | None = None, toolkits: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Get cached tools filtered by names or toolkits.

        Args:
            tools: Exact tool names to include.
            toolkits: Toolkit names; a tool belongs to a toolkit when its
                lower-cased name starts with ``<toolkit>_`` or ``<toolkit>.``.

        Returns:
            Schemas of non-expired entries. With no filter, all of them;
            otherwise every entry matching either filter.
        """
        wanted_names = set(tools or ())
        prefixes = tuple(
            f"{toolkit.lower()}{separator}"
            for toolkit in (toolkits or ())
            for separator in ("_", ".")
        )
        unfiltered = not wanted_names and not prefixes

        return [
            entry.tool_schema
            for entry in self._tools_cache.values()
            if not entry.is_expired(self.max_age)
            and (
                unfiltered
                or entry.tool_name in wanted_names
                # startswith(()) is False, so an empty prefix tuple matches nothing.
                or entry.tool_name.lower().startswith(prefixes)
            )
        ]

    def update_tools(self, tool_name: str, tool_schema: dict[str, Any], save: bool = False) -> None:
        """Update or add a tool to the cache, optionally saving right away."""
        self._tools_cache[tool_name] = CacheEntry(tool_name, tool_schema, datetime.now())
        if save:
            self._save_to_disk()

    def save(self) -> None:
        """Save all cached tools to disk."""
        self._save_to_disk()

    def has_valid_cache(self) -> bool:
        """Check if cache has any valid (non-expired) entries."""
        return any(not entry.is_expired(self.max_age) for entry in self._tools_cache.values())
|
|
1109
|
+
|
|
1110
|
+
|
|
1111
|
+
class LocalToolManager(ToolManager):
    """Manages local tool execution."""

    # Python annotation (or its generic origin) -> JSON Schema type name.
    _JSON_SCHEMA_TYPES = {
        str: "string",
        int: "integer",
        float: "number",
        bool: "boolean",
        list: "array",
        dict: "object",
    }

    def __init__(self) -> None:
        self._tools_cache: list[JsonToolSchema] | None = None
        self._tool_funcs: dict[str, Any] = {}
        self._interactive_tools: set[str] = set()

    async def get_tools(self) -> list[dict[str, Any]]:
        """Get all available local tools as jsonschemas.

        Schemas are built once from the functions returned by
        ``get_all_tools()`` and cached; the functions themselves are kept
        for execute(), and tools flagged ``__interactive__`` are recorded.
        """
        if self._tools_cache is None:
            from cadecoder.tools.builtin import get_all_tools

            # Collect into a local list so an error part-way through does not
            # leave a truncated list behind as the permanent cache.
            schemas = []
            for tool in get_all_tools():
                tool_name = getattr(tool, "__tool_name__", getattr(tool, "__name__", "unknown"))

                # Cache the function for execution
                self._tool_funcs[tool_name] = tool

                if getattr(tool, "__interactive__", False):
                    self._interactive_tools.add(tool_name)

                schemas.append(self._build_schema(tool_name, tool))

            self._tools_cache = schemas

        return self._tools_cache

    def _build_schema(self, tool_name: str, tool: Any) -> dict[str, Any]:
        """Build the OpenAI-style function schema for one tool callable."""
        schema: dict[str, Any] = {
            "type": "function",
            "function": {
                "name": tool_name,
                "description": getattr(tool, "__tool_description__", ""),
            },
        }

        # Build parameters schema from function signature
        properties: dict[str, Any] = {}
        required: list[str] = []
        variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

        for param_name, param in inspect.signature(tool).parameters.items():
            # "context" is supplied by execute(); *args/**kwargs cannot be
            # described as a single named property.
            if param_name == "context" or param.kind in variadic:
                continue

            param_type, param_desc = self._describe_parameter(param)
            properties[param_name] = {
                "type": param_type,
                "description": param_desc,
            }

            if param.default is inspect.Parameter.empty:
                required.append(param_name)

        if properties:
            schema["function"]["parameters"] = {
                "type": "object",
                "properties": properties,
                "required": required,
            }

        return schema

    @classmethod
    def _describe_parameter(cls, param: inspect.Parameter) -> tuple[str, str]:
        """Return the (JSON Schema type, description) pair for a parameter.

        ``Annotated[T, "text"]`` contributes ``T`` as the type and, when the
        first metadata item is a string, that string as the description.
        Plain annotations are mapped as well. Anything unrecognised falls
        back to ``"string"``.
        """
        json_type = "string"
        description = f"Parameter {param.name}"

        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            return json_type, description

        if get_origin(annotation) is Annotated:
            args = get_args(annotation)
            annotation = args[0]
            if len(args) >= 2 and isinstance(args[1], str):
                description = args[1]

        # list[int] / dict[str, Any] report their origin; bare types do not.
        base = get_origin(annotation) or annotation
        try:
            json_type = cls._JSON_SCHEMA_TYPES.get(base, "string")
        except TypeError:
            # Unhashable annotation object: keep the default type.
            json_type = "string"

        return json_type, description

    async def execute(self, name: str, inputs: dict[str, Any]) -> Any:
        """Execute a local tool by name with inputs.

        The tool receives a minimal ``ToolContext`` as its first positional
        argument and ``inputs`` as keyword arguments. Coroutine functions
        are awaited; plain functions are called directly.

        Raises:
            Exception: If no tool with that name is registered.
        """
        # Ensure tools are loaded
        await self.get_tools()

        tool_func = self._tool_funcs.get(name)
        if not tool_func:
            raise Exception(f"Tool '{name}' not found")

        # Create minimal context
        from arcade_tdk import ToolContext

        context = ToolContext(user_id="local_user")  # type: ignore[call-arg]

        # Execute (handle async and sync)
        if inspect.iscoroutinefunction(tool_func):
            return await tool_func(context, **inputs)
        return tool_func(context, **inputs)

    def is_interactive_tool(self, name: str) -> bool:
        """Check if the named tool requires exclusive terminal access."""
        return name in self._interactive_tools
|
|
1215
|
+
|
|
1216
|
+
|
|
1217
|
+
def _extract_toolkit_names_from_tools(tool_names: list[str]) -> set[str]:
    """Extract toolkit names from tool names.

    The toolkit is the lower-cased text before the first ``_`` or ``.`` in
    the tool name, whichever comes first. Names without a separator, or
    with nothing before it, contribute no toolkit.

    Args:
        tool_names: Tool names such as ``"Github_CreateIssue"`` or
            ``"Slack.SendMessage"``.

    Returns:
        The set of distinct toolkit names found.
    """
    toolkits: set[str] = set()
    for tool_name in tool_names:
        # Cut at the earliest separator so "Google.list_emails" yields
        # "google" rather than "google.list".
        positions = [pos for pos in (tool_name.find("_"), tool_name.find(".")) if pos != -1]
        if not positions:
            continue
        toolkit = tool_name[: min(positions)].lower()
        if toolkit:
            toolkits.add(toolkit)
    return toolkits
|
|
1228
|
+
|
|
1229
|
+
|
|
1230
|
+
class RemoteToolManager(ToolManager):
    """Manages remote Arcade tool execution using arcadepy."""

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        user_email: str | None = None,
    ):
        """Create the Arcade client; unset arguments fall back to the app config."""
        log.info(f"Initializing RemoteToolManager with user_email: {user_email}")
        cfg = get_config()
        self.arcade_client = AsyncArcade(
            api_key=api_key or cfg.api_key,
            base_url=base_url or cfg.base_url,
        )
        self._default_user_id = user_email or cfg.user_email
        self._tools_cache = ToolCache()
        self._all_tools_fetched = False
        # Plain boolean flag (not a real lock): True while a fetch is running.
        self._fetching_lock = False

    async def get_tools(
        self, tools: list[str] | None = None, toolkits: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Get available tools from Arcade Cloud.

        Cached tools are used when they satisfy the filter. Otherwise the
        full tool list is downloaded, cached and saved, and the filter is
        applied to the refreshed cache. Once the full list has been
        downloaded by this instance and the cache is still valid, an empty
        filter result is returned as-is instead of downloading again.

        Args:
            tools: Exact tool names to include.
            toolkits: Toolkit names to include (matched by name prefix).

        Returns:
            Matching tool schemas; an empty list if anything fails.
        """
        try:
            # Wait in a loop (rather than by recursion) while another call is
            # fetching, re-checking the cache after every pause.
            while True:
                if self._all_tools_fetched or self._tools_cache.has_valid_cache():
                    cached_tools = self._filter_tools_from_cache(tools, toolkits)
                    if cached_tools:
                        log.debug(f"Using {len(cached_tools)} cached tools")
                        return cached_tools
                    if self._all_tools_fetched and self._tools_cache.has_valid_cache():
                        # The whole catalogue is already cached; downloading
                        # it again cannot make this filter match.
                        return cached_tools

                if not self._fetching_lock:
                    break
                await asyncio.sleep(0.1)

            self._fetching_lock = True
            try:
                await self._refresh_cache()
            finally:
                self._fetching_lock = False

            return self._filter_tools_from_cache(tools, toolkits)

        except Exception as e:
            log.error(f"Failed to fetch remote tools: {e}", exc_info=True)
            return []

    async def _refresh_cache(self) -> None:
        """Download every tool page from the Arcade API and store it in the cache."""
        with console.status("[cyan]Updating tools...", spinner="dots"):
            log.info("Fetching remote tools from Arcade API")

            all_items: list[dict[str, Any]] = []
            offset = 0
            limit = 100
            total_fetched = 0

            while True:
                tools_response = await self.arcade_client.tools.formatted.list(
                    limit=limit,
                    offset=offset,
                    format="openai",
                    user_id=self._default_user_id,
                )

                items = getattr(tools_response, "items", [])
                if not items:
                    break

                all_items.extend(items)
                total_fetched += len(items)

                # Stop once the reported total is reached or a short page arrives.
                total_count = getattr(tools_response, "total_count", 0)
                if total_fetched >= total_count or len(items) < limit:
                    break

                offset += limit

            log.info(f"Fetched {total_fetched} total tools from Arcade API")

            for tool in all_items:
                tool_name = self._extract_tool_name(tool)
                if tool_name:
                    self._tools_cache.update_tools(tool_name, tool, save=False)

            # One write for the whole batch instead of one per tool.
            self._tools_cache.save()
            self._all_tools_fetched = True

    def _extract_tool_name(self, tool: dict[str, Any]) -> str | None:
        """Extract tool name from tool schema.

        Returns the ``function.name`` value (``""`` when absent) for dict
        schemas and ``None`` for anything that is not a dict.
        """
        if isinstance(tool, dict):
            return tool.get("function", {}).get("name", "")
        return None

    def _filter_tools_from_cache(
        self, tools: list[str] | None = None, toolkits: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Filter tools from cache based on tool names and toolkits.

        With no filter every cached tool is returned; otherwise a tool is
        kept when its name is listed in ``tools`` or its lower-cased name
        starts with ``<toolkit>_`` or ``<toolkit>.`` for a listed toolkit.
        """
        all_cached = self._tools_cache.get_tools()

        if not all_cached:
            return []

        if not tools and not toolkits:
            return all_cached

        tool_names_set = set(tools or [])
        prefixes = tuple(
            f"{toolkit.lower()}{separator}"
            for toolkit in (toolkits or [])
            for separator in ("_", ".")
        )

        filtered_tools: list[dict[str, Any]] = []
        for tool in all_cached:
            tool_name = self._extract_tool_name(tool)
            if not tool_name:
                continue
            # startswith(()) is False, so no toolkits means no prefix match.
            if tool_name in tool_names_set or tool_name.lower().startswith(prefixes):
                filtered_tools.append(tool)

        return filtered_tools

    async def _execute_once(
        self, name: str, inputs: dict[str, Any], user_id: str | None, timeout: float
    ) -> Any:
        """Run a single tool execution request bounded by ``timeout`` seconds."""
        return await asyncio.wait_for(
            self.arcade_client.tools.execute(
                tool_name=name,
                input=inputs,
                user_id=user_id or self._default_user_id,
            ),
            timeout=timeout,
        )

    @staticmethod
    def _unwrap_result(result: Any) -> Any:
        """Return the output value of an execution result, raising on a tool error."""
        if result.output and result.output.error:
            error = result.output.error
            raise Exception(f"Tool execution failed: {error.message}")
        return result.output.value if result.output else result

    async def execute(
        self,
        name: str,
        inputs: dict[str, Any],
        user_id: str | None = None,
        timeout: float = 120.0,
    ) -> Any:
        """Execute a remote tool via Arcade Cloud.

        If the first response asks for authorization, or the request fails
        with a 403 "authorization required" error, the authorization flow is
        run and the tool is executed once more.

        Args:
            name: Tool name.
            inputs: Tool input passed as ``input`` to the API.
            user_id: User to execute as; defaults to the manager's user.
            timeout: Seconds allowed for each execution request.

        Returns:
            ``result.output.value`` when the result has an output, otherwise
            the raw result.

        Raises:
            Exception: If the tool reports an error, authorization fails,
                the request times out, or the request fails otherwise.
        """
        try:
            result = await self._execute_once(name, inputs, user_id, timeout)

            if result.output and result.output.authorization:
                auth_response = result.output.authorization
                if auth_response.status in ("not_started", "pending"):
                    await self._handle_authorization_and_retry(name, inputs, user_id, auth_response)
                    result = await self._execute_once(name, inputs, user_id, timeout)

            return self._unwrap_result(result)

        except asyncio.TimeoutError as e:
            # str() of a timeout error is empty, so spell the reason out.
            log.error(f"Remote tool execution error for '{name}': timed out after {timeout}s")
            raise Exception(
                f"Failed to execute remote tool '{name}': timed out after {timeout}s"
            ) from e

        except Exception as e:
            error_str = str(e)

            if "403" in error_str and "authorization required" in error_str.lower():
                try:
                    auth_response = await self.arcade_client.tools.authorize(
                        tool_name=name,
                        user_id=user_id or self._default_user_id,
                    )

                    await self._handle_authorization_and_retry(name, inputs, user_id, auth_response)

                    # The retry gets the same time limit as the first attempt.
                    result = await self._execute_once(name, inputs, user_id, timeout)
                    return self._unwrap_result(result)

                except Exception as retry_error:
                    log.error(f"Failed to authorize and retry tool '{name}': {retry_error}")
                    raise

            log.error(f"Remote tool execution error for '{name}': {e}")
            raise Exception(f"Failed to execute remote tool '{name}': {str(e)}") from e

    async def _handle_authorization_and_retry(
        self,
        tool_name: str,
        inputs: dict[str, Any],
        user_id: str | None,
        auth_response: Any,
    ) -> None:
        """Handle authorization flow.

        Shows the authorization link and waits until the authorization
        completes. Despite the name, re-running the tool is left to the
        caller; ``inputs`` and ``user_id`` are not used here.

        Raises:
            Exception: If no authorization URL is available or the final
                status is not ``"completed"``.
        """
        if not auth_response.url:
            raise Exception(f"Authorization required for '{tool_name}' but no URL provided")

        content = Text()
        content.append(f"Authorization required for '{tool_name}'.\n\n", style="bold yellow")
        content.append("Click this link to authorize:\n", style="white")
        content.append(auth_response.url, style="bold cyan underline")
        content.append("\n\nWaiting for authorization to complete...", style="white dim")

        panel = Panel(
            content,
            title="[bold yellow]⚠ Authorization Required[/bold yellow]",
            title_align="left",
            border_style="yellow",
            padding=(0, 1),
            width=110,
        )
        console.print(panel)

        log.info(f"Waiting for authorization to complete for tool '{tool_name}'...")
        completed_auth = await self.arcade_client.auth.wait_for_completion(auth_response)

        if completed_auth.status != "completed":
            raise Exception(
                f"Authorization failed for '{tool_name}': status={completed_auth.status}"
            )

        console.print(f"[green]✓[/green] Authorization completed for '{tool_name}'")
        log.info(f"Authorization completed for tool '{tool_name}'")
|
|
1474
|
+
|
|
1475
|
+
|
|
1476
|
+
class CompositeToolManager(ToolManager):
    """Manages local, remote (Arcade), and MCP tools simultaneously."""

    # Marker prepended to the description of every Arcade Cloud tool.
    _REMOTE_LABEL = "[Arcade Cloud] "

    def __init__(
        self,
        local_manager: LocalToolManager | None = None,
        remote_manager: RemoteToolManager | None = None,
        enable_mcp: bool = True,
    ):
        """Set up the sub-managers and, if enabled, one manager per enabled MCP server."""
        self.local_manager = local_manager or LocalToolManager()
        self.remote_manager = remote_manager or RemoteToolManager(
            user_email=get_config().user_email,
        )
        # Tool name -> "local" | "remote" | "mcp"; filled in by get_tools().
        self._tool_source_map: dict[str, str] = {}
        self._mcp_tool_to_manager: dict[str, MCPToolManager] = {}
        self._tools_cache = ToolCache()
        self._mcp_managers: list[MCPToolManager] = []
        self._enable_mcp = enable_mcp

        tool_cfg = get_config().tool_settings
        self._default_remote_tool: list[str] = tool_cfg.included_tools.copy()
        self._default_remote_toolkit: list[str] = tool_cfg.included_toolkits.copy()

        # Load MCP servers if enabled
        if enable_mcp:
            self._mcp_store = MCPServerStore()
            for server_config in self._mcp_store.list_enabled():
                mcp_manager = MCPToolManager(server_config)
                self._mcp_managers.append(mcp_manager)

    async def get_tools(self) -> list[dict[str, Any]]:
        """Get all available tools from local, remote, and MCP sources.

        Each source is loaded independently: a failure is logged and the
        other sources are still returned. The source of every tool is
        recorded so execute() can route calls. Remote tools are returned as
        copies carrying the "[Arcade Cloud] " description marker.
        """
        all_tools: list[dict[str, Any]] = []

        # Local tools
        try:
            local_tools = await self.local_manager.get_tools()
            for tool in local_tools:
                tool_name = tool.get("function", {}).get("name", "unknown")
                self._tool_source_map[tool_name] = "local"
                all_tools.append(tool)
            log.info(f"Loaded {len(local_tools)} local tools")
        except Exception as e:
            log.error(f"Failed to load local tools: {e}")

        # Remote (Arcade) tools
        if self.remote_manager:
            try:
                remote_tools = await self.remote_manager.get_tools(
                    tools=self._default_remote_tool,
                    toolkits=self._default_remote_toolkit,
                )
                for tool in remote_tools:
                    function = tool.get("function", {})
                    tool_name = function.get("name", "unknown")
                    self._tools_cache.update_tools(tool_name, tool)

                    # Label a copy: the dicts handed out by the remote manager
                    # are the ones held in its cache, so editing them in place
                    # would stack another marker on every call.
                    desc = function.get("description") or ""
                    if not desc.startswith(self._REMOTE_LABEL):
                        desc = f"{self._REMOTE_LABEL}{desc}"
                    labelled = {**tool, "function": {**function, "description": desc}}

                    self._tool_source_map[tool_name] = "remote"
                    all_tools.append(labelled)
                log.info(f"Loaded {len(remote_tools)} remote tools from Arcade Cloud")
            except Exception as e:
                log.error(f"Failed to load remote tools: {e}")

        # MCP tools
        mcp_tool_count = 0
        for mcp_manager in self._mcp_managers:
            try:
                mcp_tools = await mcp_manager.get_tools()
                for tool in mcp_tools:
                    tool_name = tool.get("function", {}).get("name", "unknown")
                    self._tool_source_map[tool_name] = "mcp"
                    self._mcp_tool_to_manager[tool_name] = mcp_manager
                    all_tools.append(tool)
                mcp_tool_count += len(mcp_tools)

                # Update server status (the store only exists when MCP is enabled)
                if self._enable_mcp and hasattr(self, "_mcp_store"):
                    self._mcp_store.update_status(mcp_manager.config.name, True, len(mcp_tools))
            except Exception as e:
                log.error(f"Failed to load tools from MCP '{mcp_manager.config.name}': {e}")

        if mcp_tool_count > 0:
            log.info(f"Loaded {mcp_tool_count} tools from MCP servers")

        # Summary
        sources = list(self._tool_source_map.values())
        local_count = sources.count("local")
        remote_count = sources.count("remote")
        mcp_count = sources.count("mcp")
        log.info(
            f"Total tools: {len(all_tools)} "
            f"(Local: {local_count}, Arcade: {remote_count}, MCP: {mcp_count})"
        )
        return all_tools

    async def execute(self, name: str, inputs: dict[str, Any]) -> Any:
        """Execute a tool by routing to the appropriate manager.

        Routing uses the source recorded by get_tools().

        Raises:
            Exception: If the tool has no recorded source, or it is an MCP
                tool whose manager is not registered.
        """
        source = self._tool_source_map.get(name)

        if source == "local":
            log.debug(f"Executing local tool: {name}")
            return await self.local_manager.execute(name, inputs)

        if source == "remote" and self.remote_manager:
            log.debug(f"Executing remote tool: {name}")
            return await self.remote_manager.execute(name, inputs, user_id=get_config().user_email)

        if source == "mcp":
            mcp_manager = self._mcp_tool_to_manager.get(name)
            if mcp_manager:
                log.debug(f"Executing MCP tool: {name}")
                return await mcp_manager.execute(name, inputs)
            raise Exception(f"MCP manager not found for tool '{name}'")

        raise Exception(f"Tool '{name}' not found in local, remote, or MCP tools")

    def is_interactive_tool(self, name: str) -> bool:
        """Determine whether a tool is interactive based on its source.

        Only local tools can be interactive; every other source is False.
        """
        source = self._tool_source_map.get(name)
        if source == "local" and hasattr(self.local_manager, "is_interactive_tool"):
            return bool(self.local_manager.is_interactive_tool(name))
        return False

    def get_tool_source(self, name: str) -> str | None:
        """Get the source type for a tool, or None if it is unknown."""
        return self._tool_source_map.get(name)

    def get_all_tool_info(self) -> list[dict[str, Any]]:
        """Get info about all tools including their source.

        Each item has ``name`` and ``source``; MCP tools with a registered
        manager also carry the ``server`` name.
        """
        tool_info = []
        for name, source in self._tool_source_map.items():
            info = {"name": name, "source": source}
            if source == "mcp" and name in self._mcp_tool_to_manager:
                info["server"] = self._mcp_tool_to_manager[name].config.name
            tool_info.append(info)
        return tool_info

    async def close(self) -> None:
        """Close all MCP connections.

        A manager that fails to close is logged and skipped so the
        remaining managers are still closed.
        """
        for mcp_manager in self._mcp_managers:
            try:
                await mcp_manager.close()
            except Exception as e:
                log.error(f"Failed to close MCP manager '{mcp_manager.config.name}': {e}")
|
|
1619
|
+
|
|
1620
|
+
|
|
1621
|
+
# Names exported by `from <this module> import *`, grouped by role.
__all__ = [
    # Tool managers
    "ToolManager",
    "LocalToolManager",
    "RemoteToolManager",
    "MCPToolManager",
    "CompositeToolManager",
    # Tool schema caching
    "ToolCache",
    "CacheEntry",
    # MCP server configuration and OAuth
    "MCPOAuthHandler",
    "MCPServerConfig",
    "MCPServerStore",
    "MCPAuthType",
    "MCPOAuthTokens",
    # Errors
    "ToolAuthorizationRequired",
]
|