cade-cli 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. cade_cli-0.3.3.dist-info/METADATA +151 -0
  2. cade_cli-0.3.3.dist-info/RECORD +44 -0
  3. cade_cli-0.3.3.dist-info/WHEEL +4 -0
  4. cade_cli-0.3.3.dist-info/entry_points.txt +2 -0
  5. cadecoder/__init__.py +1 -0
  6. cadecoder/ai/__init__.py +6 -0
  7. cadecoder/ai/prompts.py +572 -0
  8. cadecoder/cli/__init__.py +0 -0
  9. cadecoder/cli/app.py +147 -0
  10. cadecoder/cli/auth.py +483 -0
  11. cadecoder/cli/commands/__init__.py +5 -0
  12. cadecoder/cli/commands/auth.py +143 -0
  13. cadecoder/cli/commands/chat.py +264 -0
  14. cadecoder/cli/commands/mcp.py +477 -0
  15. cadecoder/cli/commands/tools.py +226 -0
  16. cadecoder/core/__init__.py +12 -0
  17. cadecoder/core/config.py +380 -0
  18. cadecoder/core/constants.py +281 -0
  19. cadecoder/core/errors.py +145 -0
  20. cadecoder/core/logging.py +148 -0
  21. cadecoder/core/types.py +235 -0
  22. cadecoder/core/utils.py +279 -0
  23. cadecoder/execution/__init__.py +46 -0
  24. cadecoder/execution/context_window.py +521 -0
  25. cadecoder/execution/orchestrator.py +562 -0
  26. cadecoder/execution/parallel.py +287 -0
  27. cadecoder/providers/__init__.py +60 -0
  28. cadecoder/providers/base.py +294 -0
  29. cadecoder/providers/openai.py +251 -0
  30. cadecoder/storage/__init__.py +0 -0
  31. cadecoder/storage/threads.py +489 -0
  32. cadecoder/templates/login_failed.html +21 -0
  33. cadecoder/templates/login_success.html +21 -0
  34. cadecoder/templates/styles.css +87 -0
  35. cadecoder/tools/__init__.py +19 -0
  36. cadecoder/tools/builtin.py +644 -0
  37. cadecoder/tools/filesystem.py +315 -0
  38. cadecoder/tools/git.py +221 -0
  39. cadecoder/tools/manager.py +1635 -0
  40. cadecoder/ui/__init__.py +7 -0
  41. cadecoder/ui/display.py +338 -0
  42. cadecoder/ui/input.py +145 -0
  43. cadecoder/ui/session.py +455 -0
  44. cadecoder/ui/state.py +20 -0
@@ -0,0 +1,1635 @@
1
+ """Unified tool management system supporting local, remote, and MCP tools."""
2
+
3
+ import asyncio
4
+ import inspect
5
+ import json
6
+ import uuid
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass
9
+ from datetime import datetime, timedelta
10
+ from enum import Enum
11
+ from pathlib import Path
12
+ from typing import Annotated, Any, get_args, get_origin
13
+
14
+ import httpx
15
+ from arcadepy import AsyncArcade
16
+ from rich.console import Console
17
+ from rich.panel import Panel
18
+ from rich.text import Text
19
+
20
+ from cadecoder.core.config import get_config
21
+ from cadecoder.core.logging import log
22
+
23
# Alias for a tool definition expressed as a JSON-schema style dictionary.
JsonToolSchema = dict[str, Any]

# Module-wide Rich console. It is bound to stderr, so output printed through
# it does not mix with anything written to stdout.
console: Console = Console(stderr=True)
28
+
29
+
30
+ # =============================================================================
31
+ # MCP Server Configuration
32
+ # =============================================================================
33
+
34
+
35
class MCPAuthType(str, Enum):
    """Supported ways of authenticating against an MCP server.

    The ``str`` mixin makes every member compare equal to its raw string
    value, e.g. ``MCPAuthType.OAUTH == "oauth"``.
    """

    # No credentials are attached to requests.
    NONE = "none"
    # A static bearer token.
    BEARER = "bearer"
    # A static API key.
    API_KEY = "api_key"
    # Full OAuth 2.1 flow (MCP spec compliant).
    OAUTH = "oauth"
42
+
43
+
44
+ @dataclass
45
+ class MCPOAuthTokens:
46
+ """OAuth tokens for an MCP server."""
47
+
48
+ access_token: str
49
+ token_type: str = "Bearer"
50
+ refresh_token: str | None = None
51
+ expires_at: datetime | None = None
52
+ scope: str | None = None
53
+
54
+ def is_expired(self) -> bool:
55
+ """Check if access token is expired."""
56
+ if self.expires_at is None:
57
+ return False
58
+ return datetime.now() >= self.expires_at
59
+
60
+ def to_dict(self) -> dict[str, Any]:
61
+ """Convert to dictionary."""
62
+ return {
63
+ "access_token": self.access_token,
64
+ "token_type": self.token_type,
65
+ "refresh_token": self.refresh_token,
66
+ "expires_at": self.expires_at.isoformat() if self.expires_at else None,
67
+ "scope": self.scope,
68
+ }
69
+
70
+ @classmethod
71
+ def from_dict(cls, data: dict[str, Any]) -> "MCPOAuthTokens":
72
+ """Create from dictionary."""
73
+ expires_at = None
74
+ if data.get("expires_at"):
75
+ expires_at = datetime.fromisoformat(data["expires_at"])
76
+ return cls(
77
+ access_token=data["access_token"],
78
+ token_type=data.get("token_type", "Bearer"),
79
+ refresh_token=data.get("refresh_token"),
80
+ expires_at=expires_at,
81
+ scope=data.get("scope"),
82
+ )
83
+
84
+
85
@dataclass
class MCPServerConfig:
    """Settings for one MCP server connection, plus any cached OAuth state.

    Persisted as JSON through ``to_dict`` / ``from_dict``: datetimes are
    stored as ISO-8601 strings and the auth type as its enum value.
    """

    name: str
    url: str
    auth_type: MCPAuthType = MCPAuthType.NONE
    auth_value: str | None = None  # static secret for bearer / api_key auth
    enabled: bool = True
    last_connected: datetime | None = None
    tool_count: int = 0
    # OAuth-only state (relevant when auth_type is OAUTH).
    oauth_tokens: MCPOAuthTokens | None = None
    oauth_authorization_server: str | None = None
    oauth_client_id: str | None = None
    oauth_scopes: list[str] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable snapshot of this configuration."""
        connected = self.last_connected
        tokens = self.oauth_tokens
        return dict(
            name=self.name,
            url=self.url,
            auth_type=self.auth_type.value,
            auth_value=self.auth_value,
            enabled=self.enabled,
            last_connected=connected.isoformat() if connected else None,
            tool_count=self.tool_count,
            oauth_tokens=tokens.to_dict() if tokens else None,
            oauth_authorization_server=self.oauth_authorization_server,
            oauth_client_id=self.oauth_client_id,
            oauth_scopes=self.oauth_scopes,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "MCPServerConfig":
        """Rebuild a configuration from the output of ``to_dict``.

        Missing optional keys fall back to the field defaults; ``name`` and
        ``url`` are required.
        """
        # Parse the nested/structured values first, before the plain fields.
        connected_raw = data.get("last_connected")
        last_connected = datetime.fromisoformat(data["last_connected"]) if connected_raw else None

        tokens_raw = data.get("oauth_tokens")
        oauth_tokens = MCPOAuthTokens.from_dict(data["oauth_tokens"]) if tokens_raw else None

        return cls(
            name=data["name"],
            url=data["url"],
            auth_type=MCPAuthType(data.get("auth_type", "none")),
            auth_value=data.get("auth_value"),
            enabled=data.get("enabled", True),
            last_connected=last_connected,
            tool_count=data.get("tool_count", 0),
            oauth_tokens=oauth_tokens,
            oauth_authorization_server=data.get("oauth_authorization_server"),
            oauth_client_id=data.get("oauth_client_id"),
            oauth_scopes=data.get("oauth_scopes"),
        )
142
+
143
+
144
class MCPServerStore:
    """Persistent, JSON-backed storage for MCP server configurations.

    Servers are held in memory keyed by name and written to
    ``<config_dir>/mcp_servers.json`` after every mutation. That file contains
    credentials (static bearer tokens / API keys and OAuth tokens), so it is
    written atomically with owner-only permissions. Persistence is
    best-effort: load and save failures are logged as warnings, not raised.
    """

    def __init__(self, config_dir: Path | None = None) -> None:
        """Create the store and load any previously saved servers.

        Args:
            config_dir: Directory containing ``mcp_servers.json``. Defaults to
                ``get_config().app_dir``.
        """
        if config_dir is None:
            config_dir = Path(get_config().app_dir)
        self.config_dir = config_dir
        self.config_file = config_dir / "mcp_servers.json"
        self._servers: dict[str, MCPServerConfig] = {}
        self._load()

    def _load(self) -> None:
        """Load servers from disk.

        A missing file is not an error. An unreadable or structurally invalid
        file is logged and ignored. Individual malformed entries are skipped,
        so one bad record no longer hides every server stored after it.
        """
        if not self.config_file.exists():
            return

        try:
            with open(self.config_file, encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:
            log.warning(f"Failed to load MCP server configs: {e}")
            return

        servers = data.get("servers", []) if isinstance(data, dict) else None
        if not isinstance(servers, list):
            log.warning("Failed to load MCP server configs: unexpected file structure")
            return

        for server_data in servers:
            try:
                server = MCPServerConfig.from_dict(server_data)
            except Exception as e:
                log.warning(f"Skipping invalid MCP server config entry: {e}")
                continue
            self._servers[server.name] = server

        log.debug(f"Loaded {len(self._servers)} MCP server configs")

    def _save(self) -> None:
        """Persist all servers to ``config_file`` (best-effort).

        The JSON is first written to a temporary file in the same directory.
        ``tempfile.mkstemp`` creates that file with mode 0o600 on POSIX, and it
        is then atomically moved over the real file with ``os.replace``. This
        keeps the stored secrets private. It also means an interrupted write
        can never leave a truncated file, which ``_load`` would otherwise
        reject wholesale. Any failure, including creating ``config_dir``, is
        logged rather than raised.
        """
        import os
        import tempfile

        tmp_path: str | None = None
        try:
            self.config_dir.mkdir(parents=True, exist_ok=True)
            data = {
                "servers": [s.to_dict() for s in self._servers.values()],
                "updated_at": datetime.now().isoformat(),
            }
            fd, tmp_path = tempfile.mkstemp(
                dir=self.config_dir, prefix=".mcp_servers.", suffix=".tmp"
            )
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.config_file)
            tmp_path = None  # ownership passed to config_file; nothing to clean up
        except Exception as e:
            log.warning(f"Failed to save MCP server configs: {e}")
        finally:
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

    def add(self, server: MCPServerConfig) -> None:
        """Add or update a server configuration (keyed by ``server.name``) and save."""
        self._servers[server.name] = server
        self._save()

    def remove(self, name: str) -> bool:
        """Remove a server configuration; return True if it existed."""
        if name in self._servers:
            del self._servers[name]
            self._save()
            return True
        return False

    def get(self, name: str) -> MCPServerConfig | None:
        """Get a server configuration by name, or None if unknown."""
        return self._servers.get(name)

    def list_all(self) -> list[MCPServerConfig]:
        """List all server configurations."""
        return list(self._servers.values())

    def list_enabled(self) -> list[MCPServerConfig]:
        """List server configurations whose ``enabled`` flag is set."""
        return [s for s in self._servers.values() if s.enabled]

    def update_status(self, name: str, connected: bool, tool_count: int = 0) -> None:
        """Record a connection attempt for ``name`` and save.

        ``last_connected`` is only bumped when ``connected`` is True, but
        ``tool_count`` is always overwritten. Unknown names are ignored.
        """
        if name in self._servers:
            if connected:
                self._servers[name].last_connected = datetime.now()
            self._servers[name].tool_count = tool_count
            self._save()
218
+
219
+
220
+ class ToolAuthorizationRequired(Exception):
221
+ """Exception raised when a tool requires user authorization."""
222
+
223
+ def __init__(self, tool_name: str, authorization_url: str | None = None):
224
+ self.tool_name = tool_name
225
+ self.authorization_url = authorization_url
226
+
227
+ if authorization_url:
228
+ message = (
229
+ f"Authorization required for '{tool_name}'.\n\n"
230
+ f"Please authorize by clicking this link:\n{authorization_url}\n\n"
231
+ f"After authorizing, let me know and I'll retry the operation."
232
+ )
233
+ else:
234
+ message = (
235
+ f"Authorization required for '{tool_name}'.\n\n"
236
+ f"Please authorize this tool in your Arcade account."
237
+ )
238
+
239
+ super().__init__(message)
240
+
241
+
242
class ToolManager(ABC):
    """Abstract base class for tool managers.

    A concrete manager lists its tools as JSON-schema dictionaries and can
    execute any of them by name.
    """

    @abstractmethod
    async def get_tools(self) -> list[dict[str, Any]]:
        """Return all available tools, each described as a JSON-schema dict."""
        ...

    @abstractmethod
    async def execute(self, name: str, inputs: dict[str, Any]) -> Any:
        """Run the tool called ``name`` with ``inputs`` and return its result."""
        ...
254
+
255
+
256
+ # =============================================================================
257
+ # MCP OAuth 2.1 Support (RFC 9728, RFC 8414)
258
+ # =============================================================================
259
+
260
+
261
class MCPOAuthHandler:
    """Handles OAuth 2.1 authorization for MCP servers per MCP spec.

    Covers the client side of the flow:

    * parsing ``WWW-Authenticate`` challenges;
    * discovering Protected Resource Metadata (RFC 9728) and Authorization
      Server Metadata (RFC 8414);
    * building a PKCE (S256) authorization URL;
    * exchanging the authorization code and refreshing tokens.

    The authorization request and both token requests carry the server URL
    as an RFC 8707 ``resource`` indicator. No client registration is done.
    The client id is the explicit argument if given, otherwise the
    configured ``oauth_client_id``, otherwise ``cade-mcp-<server name>``.
    """

    # Loopback redirect URI. OAuth requires the value sent in the
    # authorization request and in the code exchange to match exactly, so it
    # is defined once here.
    REDIRECT_URI = "http://127.0.0.1:9876/callback"

    def __init__(self, server_config: MCPServerConfig) -> None:
        """Bind the handler to ``server_config``; the HTTP client is created lazily."""
        self.config = server_config
        self._client: httpx.AsyncClient | None = None

    async def _ensure_client(self) -> httpx.AsyncClient:
        """Return the shared HTTP client, creating it (30 s timeout) on first use."""
        if self._client is None:
            self._client = httpx.AsyncClient(timeout=httpx.Timeout(30.0))
        return self._client

    def _resolve_client_id(self, client_id: str | None) -> str:
        """Return ``client_id``, else the configured id, else ``cade-mcp-<name>``."""
        if client_id:
            return client_id
        return self.config.oauth_client_id or f"cade-mcp-{self.config.name}"

    def parse_www_authenticate(self, header: str) -> dict[str, str]:
        """Parse a ``WWW-Authenticate`` header to extract OAuth parameters.

        Per MCP spec, the interesting keys are ``resource_metadata``,
        ``scope``, ``error`` and ``error_description``. As RFC 7235 allows,
        values may be quoted strings (``scope="a b"``) or bare tokens
        (``error=insufficient_scope``). Backslash escapes inside quoted
        strings are undone, and parameter names are lower-cased because they
        are case-insensitive.

        Args:
            header: Raw header value, with or without a leading ``Bearer``.

        Returns:
            Mapping of lower-cased parameter name to value (empty if none).
        """
        import re

        params: dict[str, str] = {}
        # Remove "Bearer " prefix if present
        if header.lower().startswith("bearer "):
            header = header[7:]

        # name = "quoted \"string\"" | name = token
        pattern = r'(\w+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,"=]+))'
        for match in re.finditer(pattern, header):
            name, quoted, token = match.groups()
            value = re.sub(r"\\(.)", r"\1", quoted) if quoted is not None else token
            params[name.lower()] = value

        return params

    async def discover_from_401(self, response: httpx.Response) -> tuple[str | None, str | None]:
        """Discover authorization hints from a 401 response.

        Returns:
            ``(resource_metadata_url, required_scope)``; either is None when
            the ``WWW-Authenticate`` header does not provide it.
        """
        www_auth = response.headers.get("WWW-Authenticate", "")
        params = self.parse_www_authenticate(www_auth)
        return params.get("resource_metadata"), params.get("scope")

    async def _get_json_object(self, url: str) -> dict[str, Any] | None:
        """GET ``url`` and return its body if it is a 200 response with a JSON object.

        Network errors, non-200 statuses, invalid JSON and non-object JSON all
        yield None, so discovery can fall through to the next candidate URL.
        """
        client = await self._ensure_client()
        try:
            resp = await client.get(url)
            if resp.status_code != 200:
                return None
            body = resp.json()
        except Exception as e:
            log.debug(f"OAuth metadata fetch from {url} failed: {e}")
            return None
        return body if isinstance(body, dict) else None

    async def fetch_protected_resource_metadata(
        self, metadata_url: str | None = None
    ) -> dict[str, Any] | None:
        """Fetch Protected Resource Metadata per RFC 9728.

        Tries, in order:
        1. Provided metadata_url (from WWW-Authenticate)
        2. Well-known URI with path: /.well-known/oauth-protected-resource/<path>
        3. Well-known URI at root: /.well-known/oauth-protected-resource

        Returns:
            The first metadata document found, or None.
        """
        from urllib.parse import urlparse

        parsed = urlparse(self.config.url)
        base = f"{parsed.scheme}://{parsed.netloc}"

        urls_to_try: list[str] = []
        if metadata_url:
            urls_to_try.append(metadata_url)

        # Well-known with path
        if parsed.path and parsed.path != "/":
            path = parsed.path.rstrip("/")
            urls_to_try.append(f"{base}/.well-known/oauth-protected-resource{path}")

        # Well-known at root
        urls_to_try.append(f"{base}/.well-known/oauth-protected-resource")

        for url in urls_to_try:
            metadata = await self._get_json_object(url)
            if metadata is not None:
                return metadata

        return None

    async def fetch_authorization_server_metadata(self, issuer: str) -> dict[str, Any] | None:
        """Fetch Authorization Server Metadata per RFC 8414.

        Tries both OAuth 2.0 and OpenID Connect discovery endpoints, in the
        priority order used by the MCP spec. A warning is logged (but the
        metadata is still returned) when the server does not advertise S256
        PKCE support.

        Returns:
            The first metadata document found, or None.
        """
        from urllib.parse import urlparse

        parsed = urlparse(issuer)
        base = f"{parsed.scheme}://{parsed.netloc}"
        path = parsed.path.rstrip("/") if parsed.path else ""

        if path:
            # Issuer with a path component
            urls_to_try = [
                f"{base}/.well-known/oauth-authorization-server{path}",
                f"{base}/.well-known/openid-configuration{path}",
                f"{base}{path}/.well-known/openid-configuration",
            ]
        else:
            urls_to_try = [
                f"{base}/.well-known/oauth-authorization-server",
                f"{base}/.well-known/openid-configuration",
            ]

        for url in urls_to_try:
            metadata = await self._get_json_object(url)
            if metadata is None:
                continue
            # Verify PKCE support (required by MCP spec)
            methods = metadata.get("code_challenge_methods_supported")
            if methods is None:
                log.warning(f"Auth server {url} doesn't advertise PKCE support")
            elif "S256" not in methods:
                log.warning(f"Auth server {url} doesn't advertise S256 PKCE support")
            return metadata

        return None

    def generate_pkce(self) -> tuple[str, str]:
        """Generate a PKCE ``(code_verifier, code_challenge)`` pair using S256.

        The verifier is 43 URL-safe characters (32 random bytes); the
        challenge is the unpadded base64url SHA-256 of the verifier.
        """
        import base64
        import hashlib
        import secrets

        code_verifier = secrets.token_urlsafe(32)
        digest = hashlib.sha256(code_verifier.encode()).digest()
        code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
        return code_verifier, code_challenge

    async def start_oauth_flow(
        self,
        auth_server_metadata: dict[str, Any],
        scope: str | None = None,
        client_id: str | None = None,
    ) -> tuple[str, str, str]:
        """Start the OAuth authorization-code flow.

        Scope precedence: ``scope`` argument, then the configured
        ``oauth_scopes``, then every entry of ``scopes_supported`` from the
        metadata. If none apply, no scope parameter is sent.

        Returns:
            ``(authorization_url, code_verifier, state)``. The caller must
            keep the verifier for the code exchange and check ``state`` on
            the callback.

        Raises:
            Exception: If the metadata has no ``authorization_endpoint``.
        """
        import secrets
        from urllib.parse import urlencode

        auth_endpoint = auth_server_metadata.get("authorization_endpoint")
        if not auth_endpoint:
            raise Exception("No authorization_endpoint in auth server metadata")

        code_verifier, code_challenge = self.generate_pkce()
        state = secrets.token_urlsafe(16)
        client_id = self._resolve_client_id(client_id)

        if not scope:
            scope = " ".join(self.config.oauth_scopes or [])
            if not scope:
                supported = auth_server_metadata.get("scopes_supported", [])
                scope = " ".join(supported) if supported else ""

        params = {
            "response_type": "code",
            "client_id": client_id,
            "redirect_uri": self.REDIRECT_URI,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "state": state,
            "resource": self.config.url,  # RFC 8707 resource indicator
        }
        if scope:
            params["scope"] = scope

        auth_url = f"{auth_endpoint}?{urlencode(params)}"
        return auth_url, code_verifier, state

    @staticmethod
    def _parse_token_response(
        token_data: Any, fallback_refresh_token: str | None = None
    ) -> MCPOAuthTokens:
        """Convert a decoded token-endpoint response (RFC 6749 section 5.1) into tokens.

        Args:
            token_data: Decoded JSON body of a successful token response.
            fallback_refresh_token: Refresh token to keep when the response
                does not issue a new one (RFC 6749 section 6).

        Raises:
            Exception: If the body is not an object or has no access_token.
        """
        if not isinstance(token_data, dict) or not token_data.get("access_token"):
            raise Exception("Token response did not contain an access_token")

        expires_at = None
        expires_in = token_data.get("expires_in")
        if expires_in is not None:
            try:
                # Some servers send expires_in as a string; accept any numeric form.
                expires_at = datetime.now() + timedelta(seconds=float(expires_in))
            except (TypeError, ValueError, OverflowError):
                log.warning(f"Ignoring invalid expires_in in token response: {expires_in!r}")

        return MCPOAuthTokens(
            access_token=token_data["access_token"],
            token_type=token_data.get("token_type") or "Bearer",
            refresh_token=token_data.get("refresh_token") or fallback_refresh_token,
            expires_at=expires_at,
            scope=token_data.get("scope"),
        )

    async def exchange_code_for_tokens(
        self,
        auth_server_metadata: dict[str, Any],
        code: str,
        code_verifier: str,
        client_id: str | None = None,
    ) -> MCPOAuthTokens:
        """Exchange an authorization code (plus PKCE verifier) for tokens.

        Raises:
            Exception: If there is no ``token_endpoint``, the endpoint returns
                a non-200 status, or the response lacks an access token.
        """
        client = await self._ensure_client()

        token_endpoint = auth_server_metadata.get("token_endpoint")
        if not token_endpoint:
            raise Exception("No token_endpoint in auth server metadata")

        data = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.REDIRECT_URI,
            "client_id": self._resolve_client_id(client_id),
            "code_verifier": code_verifier,
            "resource": self.config.url,
        }

        resp = await client.post(
            token_endpoint,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        if resp.status_code != 200:
            raise Exception(f"Token exchange failed: {resp.text}")

        return self._parse_token_response(resp.json())

    async def refresh_tokens(
        self,
        auth_server_metadata: dict[str, Any],
        refresh_token: str,
        client_id: str | None = None,
    ) -> MCPOAuthTokens:
        """Obtain a new access token using ``refresh_token``.

        If the server does not rotate the refresh token, the existing one is
        kept on the returned tokens.

        Raises:
            Exception: If there is no ``token_endpoint``, the endpoint returns
                a non-200 status, or the response lacks an access token.
        """
        client = await self._ensure_client()

        token_endpoint = auth_server_metadata.get("token_endpoint")
        if not token_endpoint:
            raise Exception("No token_endpoint in auth server metadata")

        data = {
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": self._resolve_client_id(client_id),
            "resource": self.config.url,
        }

        resp = await client.post(
            token_endpoint,
            data=data,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        if resp.status_code != 200:
            raise Exception(f"Token refresh failed: {resp.text}")

        return self._parse_token_response(resp.json(), fallback_refresh_token=refresh_token)

    async def close(self) -> None:
        """Close the HTTP client if one was created."""
        if self._client:
            await self._client.aclose()
            self._client = None
550
+
551
+
552
+ # =============================================================================
553
+ # MCP Tool Manager (HTTP Streamable JSON-RPC)
554
+ # =============================================================================
555
+
556
+
557
+ class MCPToolManager(ToolManager):
558
+ """Manages tools from MCP servers over HTTP (JSON-RPC streamable transport)."""
559
+
560
+ def __init__(
561
+ self,
562
+ server_config: MCPServerConfig,
563
+ server_store: "MCPServerStore | None" = None,
564
+ ) -> None:
565
+ self.config = server_config
566
+ self._server_store = server_store
567
+ self._tools_cache: list[dict[str, Any]] | None = None
568
+ self._initialized = False
569
+ self._client: httpx.AsyncClient | None = None
570
+ self._session_id: str | None = None
571
+ self._oauth_handler = MCPOAuthHandler(server_config)
572
+ self._auth_server_metadata: dict[str, Any] | None = None
573
+
574
+ def _get_headers(self) -> dict[str, str]:
575
+ """Get HTTP headers including authentication and session ID."""
576
+ headers = {
577
+ "Content-Type": "application/json",
578
+ "Accept": "application/json, text/event-stream",
579
+ }
580
+
581
+ # Include MCP session ID if we have one (required for HTTP transport)
582
+ if self._session_id:
583
+ headers["Mcp-Session-Id"] = self._session_id
584
+
585
+ # Handle different auth types
586
+ if self.config.auth_type == MCPAuthType.BEARER and self.config.auth_value:
587
+ headers["Authorization"] = f"Bearer {self.config.auth_value}"
588
+ elif self.config.auth_type == MCPAuthType.API_KEY and self.config.auth_value:
589
+ headers["X-API-Key"] = self.config.auth_value
590
+ elif self.config.auth_type == MCPAuthType.OAUTH and self.config.oauth_tokens:
591
+ # Use OAuth token
592
+ token = self.config.oauth_tokens
593
+ headers["Authorization"] = f"{token.token_type} {token.access_token}"
594
+
595
+ return headers
596
+
597
+ async def _ensure_client(self) -> httpx.AsyncClient:
598
+ """Ensure HTTP client is created with current auth headers."""
599
+ # Only recreate if no client exists or if we need to refresh tokens
600
+ if self._client is None:
601
+ self._client = httpx.AsyncClient(
602
+ timeout=httpx.Timeout(60.0, connect=10.0),
603
+ )
604
+ return self._client
605
+
606
+ async def _maybe_refresh_oauth_token(self) -> bool:
607
+ """Refresh OAuth token if expired. Returns True if refreshed."""
608
+ if self.config.auth_type != MCPAuthType.OAUTH:
609
+ return False
610
+
611
+ tokens = self.config.oauth_tokens
612
+ if not tokens or not tokens.is_expired():
613
+ return False
614
+
615
+ if not tokens.refresh_token:
616
+ return False
617
+
618
+ if not self._auth_server_metadata:
619
+ # Try to fetch it
620
+ rs_meta = await self._oauth_handler.fetch_protected_resource_metadata()
621
+ if rs_meta:
622
+ auth_servers = rs_meta.get("authorization_servers", [])
623
+ if auth_servers:
624
+ self._auth_server_metadata = (
625
+ await self._oauth_handler.fetch_authorization_server_metadata(
626
+ auth_servers[0]
627
+ )
628
+ )
629
+
630
+ if not self._auth_server_metadata:
631
+ return False
632
+
633
+ try:
634
+ new_tokens = await self._oauth_handler.refresh_tokens(
635
+ self._auth_server_metadata,
636
+ tokens.refresh_token,
637
+ self.config.oauth_client_id,
638
+ )
639
+ self.config.oauth_tokens = new_tokens
640
+
641
+ # Persist updated tokens
642
+ if self._server_store:
643
+ self._server_store.add(self.config)
644
+
645
+ log.info(f"Refreshed OAuth token for MCP server '{self.config.name}'")
646
+ return True
647
+ except Exception as e:
648
+ log.warning(f"Failed to refresh OAuth token: {e}")
649
+ return False
650
+
651
+ async def _handle_401_response(self, response: httpx.Response) -> str | None:
652
+ """Handle 401 response per MCP OAuth spec.
653
+
654
+ Returns authorization URL if OAuth flow should be initiated.
655
+ """
656
+ # Parse WWW-Authenticate header
657
+ metadata_url, scope = await self._oauth_handler.discover_from_401(response)
658
+
659
+ # Fetch Protected Resource Metadata
660
+ rs_metadata = await self._oauth_handler.fetch_protected_resource_metadata(metadata_url)
661
+ if not rs_metadata:
662
+ return None
663
+
664
+ # Get authorization server
665
+ auth_servers = rs_metadata.get("authorization_servers", [])
666
+ if not auth_servers:
667
+ return None
668
+
669
+ # Fetch Authorization Server Metadata
670
+ as_metadata = await self._oauth_handler.fetch_authorization_server_metadata(auth_servers[0])
671
+ if not as_metadata:
672
+ return None
673
+
674
+ self._auth_server_metadata = as_metadata
675
+ self.config.oauth_authorization_server = auth_servers[0]
676
+
677
+ # Generate authorization URL
678
+ auth_url, code_verifier, state = await self._oauth_handler.start_oauth_flow(
679
+ as_metadata, scope, self.config.oauth_client_id
680
+ )
681
+
682
+ # Store PKCE verifier for later token exchange (temporary)
683
+ self._pending_code_verifier = code_verifier
684
+ self._pending_state = state
685
+
686
+ return auth_url
687
+
688
+ async def _send_notification(self, method: str) -> None:
689
+ """Send a JSON-RPC notification (no response expected).
690
+
691
+ Per MCP spec, notifications don't have an 'id' and don't expect a response.
692
+ Used for 'notifications/initialized' and similar methods.
693
+ """
694
+ client = await self._ensure_client()
695
+
696
+ # Notifications don't have an 'id' field
697
+ payload: dict[str, Any] = {
698
+ "jsonrpc": "2.0",
699
+ "method": method,
700
+ }
701
+
702
+ headers = self._get_headers()
703
+ try:
704
+ response = await client.post(self.config.url, json=payload, headers=headers)
705
+ # We don't expect a meaningful response for notifications
706
+ # Some servers return 202 Accepted, others return 200
707
+ if response.status_code not in (200, 202, 204):
708
+ log.warning(
709
+ f"MCP notification '{method}' got unexpected status: {response.status_code}"
710
+ )
711
+ except Exception as e:
712
+ log.warning(f"MCP notification '{method}' failed: {e}")
713
+
714
+ async def _send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
715
+ """Send a JSON-RPC request to the MCP server.
716
+
717
+ Handles MCP HTTP transport session management by capturing and including
718
+ the Mcp-Session-Id header across requests.
719
+ """
720
+ # Refresh token if needed
721
+ await self._maybe_refresh_oauth_token()
722
+
723
+ client = await self._ensure_client()
724
+
725
+ request_id = str(uuid.uuid4())
726
+ payload: dict[str, Any] = {
727
+ "jsonrpc": "2.0",
728
+ "id": request_id,
729
+ "method": method,
730
+ }
731
+ if params:
732
+ payload["params"] = params
733
+
734
+ try:
735
+ # Include current headers (with session ID if available)
736
+ headers = self._get_headers()
737
+ response = await client.post(self.config.url, json=payload, headers=headers)
738
+
739
+ # Capture session ID from response for subsequent requests
740
+ # Per MCP HTTP transport spec, server sends Mcp-Session-Id header
741
+ if "mcp-session-id" in response.headers:
742
+ self._session_id = response.headers["mcp-session-id"]
743
+ log.debug(
744
+ f"Captured MCP session ID for '{self.config.name}': {self._session_id[:8]}..."
745
+ )
746
+
747
+ # Handle 401 - OAuth discovery
748
+ if response.status_code == 401:
749
+ auth_url = await self._handle_401_response(response)
750
+ raise ToolAuthorizationRequired(
751
+ f"MCP:{self.config.name}",
752
+ authorization_url=auth_url,
753
+ )
754
+
755
+ # Handle 403 - Insufficient scope
756
+ if response.status_code == 403:
757
+ www_auth = response.headers.get("WWW-Authenticate", "")
758
+ auth_params = self._oauth_handler.parse_www_authenticate(www_auth)
759
+ if auth_params.get("error") == "insufficient_scope":
760
+ raise ToolAuthorizationRequired(
761
+ f"MCP:{self.config.name}",
762
+ authorization_url=None,
763
+ )
764
+
765
+ response.raise_for_status()
766
+
767
+ # Handle SSE or direct JSON response
768
+ content_type = response.headers.get("content-type", "")
769
+
770
+ if "text/event-stream" in content_type:
771
+ # Parse SSE response
772
+ return await self._parse_sse_response(response, request_id)
773
+ else:
774
+ # Direct JSON response
775
+ result = response.json()
776
+ if "error" in result:
777
+ raise Exception(f"MCP error: {result['error'].get('message', 'Unknown error')}")
778
+ return result.get("result")
779
+
780
+ except httpx.HTTPStatusError as e:
781
+ raise Exception(f"MCP HTTP error: {e}")
782
+ except ToolAuthorizationRequired:
783
+ raise
784
+ except Exception as e:
785
+ log.error(f"MCP request failed for {self.config.name}: {e}")
786
+ raise
787
+
788
+ async def _parse_sse_response(self, response: httpx.Response, request_id: str) -> Any:
789
+ """Parse Server-Sent Events response."""
790
+ result = None
791
+ async for line in response.aiter_lines():
792
+ if line.startswith("data: "):
793
+ data = line[6:]
794
+ if data.strip():
795
+ try:
796
+ event = json.loads(data)
797
+ if event.get("id") == request_id:
798
+ if "error" in event:
799
+ raise Exception(f"MCP error: {event['error'].get('message')}")
800
+ result = event.get("result")
801
+ except json.JSONDecodeError:
802
+ continue
803
+ return result
804
+
805
+ async def initialize(self) -> bool:
806
+ """Initialize connection to MCP server.
807
+
808
+ Follows MCP protocol:
809
+ 1. Send 'initialize' request
810
+ 2. Receive server capabilities
811
+ 3. Send 'notifications/initialized' notification
812
+ """
813
+ if self._initialized:
814
+ return True
815
+
816
+ try:
817
+ result = await self._send_request(
818
+ "initialize",
819
+ {
820
+ "protocolVersion": "2024-11-05",
821
+ "capabilities": {"tools": {}},
822
+ "clientInfo": {"name": "cade", "version": "1.0.0"},
823
+ },
824
+ )
825
+
826
+ if result:
827
+ self._initialized = True
828
+ # Send initialized notification (no response expected)
829
+ await self._send_notification("notifications/initialized")
830
+ log.info(f"MCP server '{self.config.name}' initialized")
831
+ return True
832
+
833
+ return False
834
+
835
+ except Exception as e:
836
+ log.error(f"Failed to initialize MCP server '{self.config.name}': {e}")
837
+ return False
838
+
839
+ async def get_tools(self) -> list[dict[str, Any]]:
840
+ """Get available tools from the MCP server."""
841
+ if self._tools_cache is not None:
842
+ return self._tools_cache
843
+
844
+ if not self._initialized:
845
+ success = await self.initialize()
846
+ if not success:
847
+ return []
848
+
849
+ try:
850
+ result = await self._send_request("tools/list")
851
+
852
+ if not result or "tools" not in result:
853
+ return []
854
+
855
+ # Convert MCP tool format to OpenAI function format
856
+ self._tools_cache = []
857
+ for mcp_tool in result["tools"]:
858
+ openai_tool = self._convert_mcp_to_openai(mcp_tool)
859
+ self._tools_cache.append(openai_tool)
860
+
861
+ log.info(f"Loaded {len(self._tools_cache)} tools from MCP server '{self.config.name}'")
862
+ return self._tools_cache
863
+
864
+ except Exception as e:
865
+ log.error(f"Failed to list tools from MCP '{self.config.name}': {e}")
866
+ return []
867
+
868
+ def _convert_mcp_to_openai(self, mcp_tool: dict[str, Any]) -> dict[str, Any]:
869
+ """Convert MCP tool schema to OpenAI function schema.
870
+
871
+ Uses the original tool name from the MCP server without modification.
872
+ """
873
+ return {
874
+ "type": "function",
875
+ "function": {
876
+ "name": mcp_tool["name"],
877
+ "description": mcp_tool.get("description", ""),
878
+ "parameters": mcp_tool.get("inputSchema", {"type": "object", "properties": {}}),
879
+ },
880
+ }
881
+
882
async def execute(self, name: str, inputs: dict[str, Any]) -> Any:
    """Call ``tools/call`` on the MCP server and return the tool's output.

    The session is initialized on demand.

    Args:
        name: Tool name as advertised by the server.
        inputs: Sent as the request's ``arguments``.

    Returns:
        If the result has ``content``: the ``text`` of every dict item with
        ``type == "text"``, joined with newlines, or the raw result when there
        is no such item. Otherwise: the raw result unchanged.

    Raises:
        Exception: if the server cannot be initialized, or if the request
            fails. In the second case the original error is chained as
            ``__cause__``.
    """
    if not self._initialized:
        success = await self.initialize()
        if not success:
            raise Exception(f"MCP server '{self.config.name}' not initialized")

    try:
        result = await self._send_request(
            "tools/call",
            {"name": name, "arguments": inputs},
        )

        if result and "content" in result:
            # Only text items are surfaced. Non-dict entries and a null content
            # list are tolerated instead of failing the whole call.
            text_parts = [
                item.get("text", "")
                for item in (result["content"] or [])
                if isinstance(item, dict) and item.get("type") == "text"
            ]
            return "\n".join(text_parts) if text_parts else result

        return result

    except Exception as e:
        log.error(f"MCP tool execution failed for '{name}': {e}")
        raise Exception(f"Failed to execute MCP tool '{name}': {e}") from e
909
+
910
async def check_status(self) -> tuple[bool, str]:
    """Probe the server by running a fresh ``initialize()`` handshake.

    The session flags are cleared first, so a full handshake always happens.
    On success the session stays initialized, and ``get_tools()`` can be
    called straight away.

    Returns:
        - ``(True, "Connected")`` on success.
        - ``(False, "Initialization failed")`` when ``initialize()`` returns False.
        - ``(False, "Connection failed")`` when it raises ``httpx.ConnectError``.
        - ``(False, "Authentication required")`` when it raises
          ``ToolAuthorizationRequired``.
        - ``(False, <first 50 characters of the error>)`` for any other exception.

    NOTE(review): ``initialize()`` appears to catch ``Exception`` itself and
    return False. If so, only the first two outcomes can occur. Confirm before
    relying on the specific messages.
    """
    try:
        # Forget any previous session so initialize() performs a full handshake.
        self._initialized = False
        self._session_id = None

        if await self.initialize():
            return True, "Connected"
        return False, "Initialization failed"
    except httpx.ConnectError:
        return False, "Connection failed"
    except ToolAuthorizationRequired:
        return False, "Authentication required"
    except Exception as exc:
        message = str(exc)
        return False, message[:50]
934
+
935
async def complete_oauth_flow(self, authorization_code: str) -> bool:
    """Exchange an OAuth authorization code for tokens and store them.

    Call this once the user has authorized and the callback has delivered
    ``authorization_code``. Two things recorded when the flow started are
    required: the auth-server metadata and the pending PKCE code verifier.

    On success:
    - the tokens are written to ``self.config`` and its auth type is switched
      to OAuth;
    - the config is persisted through ``self._server_store``, if one is set;
    - the pending verifier and state are discarded;
    - the current HTTP client is closed so that it is rebuilt with the new
      tokens.

    Returns:
        True on success. False when a prerequisite is missing or any step
        raises; the error is logged.
    """
    if not self._auth_server_metadata:
        log.error("No auth server metadata - cannot complete OAuth flow")
        return False

    if not hasattr(self, "_pending_code_verifier"):
        log.error("No pending PKCE verifier - OAuth flow not started")
        return False

    try:
        handler = self._oauth_handler
        tokens = await handler.exchange_code_for_tokens(
            self._auth_server_metadata,
            authorization_code,
            self._pending_code_verifier,
            self.config.oauth_client_id,
        )

        cfg = self.config
        cfg.auth_type = MCPAuthType.OAUTH
        cfg.oauth_tokens = tokens

        store = self._server_store
        if store:
            store.add(cfg)

        # The verifier and state belong to this single flow; drop them once the code is spent.
        del self._pending_code_verifier
        if hasattr(self, "_pending_state"):
            del self._pending_state

        # Discard the current client so a new one is created with the new tokens.
        client = self._client
        if client:
            await client.aclose()
            self._client = None

        log.info(f"OAuth flow completed for MCP server '{self.config.name}'")
        return True

    except Exception as exc:
        log.error(f"Failed to complete OAuth flow: {exc}")
        return False
980
+
981
async def close(self) -> None:
    """Close the HTTP client and OAuth handler, and reset the session state.

    Cleanup happens in a ``finally`` block, so the manager ends up reset and
    the OAuth handler is closed even when ``aclose()`` on the client raises.
    The client's exception is still propagated.
    """
    try:
        if self._client:
            await self._client.aclose()
    finally:
        self._client = None
        self._session_id = None
        self._initialized = False
        await self._oauth_handler.close()
989
+
990
+
991
+ class CacheEntry:
992
+ """Cache entry for tool schemas with expiration."""
993
+
994
+ def __init__(self, tool_name: str, tool_schema: dict[str, Any], last_updated: datetime) -> None:
995
+ self.tool_name = tool_name
996
+ self.tool_schema = tool_schema
997
+ self.last_updated = last_updated
998
+
999
+ def is_expired(self, max_age: timedelta) -> bool:
1000
+ """Check if cache entry has expired."""
1001
+ return datetime.now() - self.last_updated > max_age
1002
+
1003
+ def update(self, tool_schema: dict[str, Any]) -> None:
1004
+ self.tool_schema = tool_schema
1005
+ self.last_updated = datetime.now()
1006
+
1007
+
1008
+ class ToolCache:
1009
+ """Cache for tool schemas with TTL support and persistent storage."""
1010
+
1011
+ def __init__(
1012
+ self,
1013
+ max_age: timedelta = timedelta(hours=24),
1014
+ cache_dir: Path | None = None,
1015
+ ) -> None:
1016
+ self._tools_cache: dict[str, CacheEntry] = {}
1017
+ self.max_age = max_age
1018
+
1019
+ if cache_dir is None:
1020
+ app_dir = Path(get_config().app_dir)
1021
+ cache_dir = app_dir / "tool_cache"
1022
+ self.cache_dir = cache_dir
1023
+ self.cache_file = self.cache_dir / "remote_tools.json"
1024
+
1025
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
1026
+ self._load_from_disk()
1027
+
1028
+ def _load_from_disk(self) -> None:
1029
+ """Load cached tools from disk."""
1030
+ if not self.cache_file.exists():
1031
+ return
1032
+
1033
+ try:
1034
+ with open(self.cache_file, encoding="utf-8") as f:
1035
+ data = json.load(f)
1036
+
1037
+ for tool_name, entry_data in data.get("tools", {}).items():
1038
+ last_updated = datetime.fromisoformat(entry_data["last_updated"])
1039
+ tool_schema = entry_data["tool_schema"]
1040
+ entry = CacheEntry(tool_name, tool_schema, last_updated)
1041
+ if not entry.is_expired(self.max_age):
1042
+ self._tools_cache[tool_name] = entry
1043
+
1044
+ log.debug(f"Loaded {len(self._tools_cache)} tools from persistent cache")
1045
+ except Exception as e:
1046
+ log.warning(f"Failed to load tool cache from disk: {e}")
1047
+
1048
+ def _save_to_disk(self) -> None:
1049
+ """Save cached tools to disk."""
1050
+ try:
1051
+ data = {
1052
+ "tools": {
1053
+ tool_name: {
1054
+ "tool_schema": entry.tool_schema,
1055
+ "last_updated": entry.last_updated.isoformat(),
1056
+ }
1057
+ for tool_name, entry in self._tools_cache.items()
1058
+ },
1059
+ "updated_at": datetime.now().isoformat(),
1060
+ }
1061
+
1062
+ with open(self.cache_file, "w", encoding="utf-8") as f:
1063
+ json.dump(data, f, indent=2)
1064
+ except Exception as e:
1065
+ log.warning(f"Failed to save tool cache to disk: {e}")
1066
+
1067
+ def get_tools(
1068
+ self, tools: list[str] | None = None, toolkits: list[str] | None = None
1069
+ ) -> list[dict[str, Any]]:
1070
+ """Get cached tools filtered by names or toolkits."""
1071
+ if tools:
1072
+ return [
1073
+ entry.tool_schema
1074
+ for entry in self._tools_cache.values()
1075
+ if entry.tool_name in tools and not entry.is_expired(self.max_age)
1076
+ ]
1077
+ if toolkits:
1078
+ # Filter by toolkit prefix
1079
+ toolkit_lower = {t.lower() for t in toolkits}
1080
+ return [
1081
+ entry.tool_schema
1082
+ for entry in self._tools_cache.values()
1083
+ if not entry.is_expired(self.max_age)
1084
+ and any(
1085
+ entry.tool_name.lower().startswith(f"{tk}_")
1086
+ or entry.tool_name.lower().startswith(f"{tk}.")
1087
+ for tk in toolkit_lower
1088
+ )
1089
+ ]
1090
+ return [
1091
+ entry.tool_schema
1092
+ for entry in self._tools_cache.values()
1093
+ if not entry.is_expired(self.max_age)
1094
+ ]
1095
+
1096
+ def update_tools(self, tool_name: str, tool_schema: dict[str, Any], save: bool = False):
1097
+ """Update or add a tool to the cache."""
1098
+ self._tools_cache[tool_name] = CacheEntry(tool_name, tool_schema, datetime.now())
1099
+ if save:
1100
+ self._save_to_disk()
1101
+
1102
+ def save(self):
1103
+ """Save all cached tools to disk."""
1104
+ self._save_to_disk()
1105
+
1106
+ def has_valid_cache(self) -> bool:
1107
+ """Check if cache has any valid (non-expired) entries."""
1108
+ return any(not entry.is_expired(self.max_age) for entry in self._tools_cache.values())
1109
+
1110
+
1111
class LocalToolManager(ToolManager):
    """Discovers and executes the built-in (local) tools.

    On the first ``get_tools()`` call, a schema is derived from each tool
    function's signature. The schemas are then reused for the lifetime of the
    manager.
    """

    # Python annotation -> JSON Schema "type". Compared by identity, so bool is
    # not mistaken for int. Unknown annotations fall back to "string".
    _JSON_TYPE_MAP: tuple[tuple[Any, str], ...] = (
        (str, "string"),
        (int, "integer"),
        (float, "number"),
        (bool, "boolean"),
        (list, "array"),
        (dict, "object"),
    )

    def __init__(self) -> None:
        self._tools_cache: list[JsonToolSchema] | None = None
        self._tool_funcs: dict[str, Any] = {}
        self._interactive_tools: set[str] = set()

    @classmethod
    def _json_type_for(cls, annotation: Any) -> str:
        """Map a parameter annotation to a JSON Schema type name.

        Handles:
        - plain types (``int``);
        - parameterized generics (``list[str]``, ``dict[str, Any]``);
        - ``Annotated[T, ...]``;
        - optionals (``int | None``, ``Optional[int]``).

        Anything else, including string annotations, maps to ``"string"``.
        """
        import types
        from typing import Union

        origin = get_origin(annotation)
        if origin is Annotated:
            return cls._json_type_for(get_args(annotation)[0])
        if origin is Union or origin is types.UnionType:
            non_none = [arg for arg in get_args(annotation) if arg is not type(None)]
            if len(non_none) == 1:
                return cls._json_type_for(non_none[0])
            return "string"

        base = annotation if origin is None else origin
        for py_type, json_type in cls._JSON_TYPE_MAP:
            if base is py_type:
                return json_type
        return "string"

    @classmethod
    def _build_parameters(cls, tool: Any) -> dict[str, Any]:
        """Build the JSON Schema ``parameters`` object for a tool function.

        Skipped parameters: ``context`` (injected at execution time) and
        ``*args``/``**kwargs``.

        Annotations: ``Annotated[T, description]`` supplies both the type and
        the description; a plain annotation supplies only the type.

        Parameters without a default value are listed as required.
        """
        parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}

        for param_name, param in inspect.signature(tool).parameters.items():
            if param_name == "context":
                continue
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue

            param_type = "string"
            param_desc = f"Parameter {param_name}"
            annotation = param.annotation

            if annotation is not inspect.Parameter.empty:
                if get_origin(annotation) is Annotated:
                    args = get_args(annotation)
                    param_type = cls._json_type_for(args[0])
                    if len(args) >= 2:
                        # JSON Schema descriptions must be strings.
                        param_desc = str(args[1])
                else:
                    param_type = cls._json_type_for(annotation)

            parameters["properties"][param_name] = {
                "type": param_type,
                "description": param_desc,
            }

            if param.default is inspect.Parameter.empty:
                parameters["required"].append(param_name)

        return parameters

    async def get_tools(self) -> list[dict[str, Any]]:
        """Return OpenAI-style schemas for all built-in tools.

        The first call imports ``get_all_tools`` from ``cadecoder.tools.builtin``
        and records each function for ``execute()``. Tools flagged
        ``__interactive__`` are also recorded. The schema list is published
        only after every tool has been processed, so an error cannot leave a
        partial cache behind.
        """
        if self._tools_cache is None:
            from cadecoder.tools.builtin import get_all_tools

            schemas: list[dict[str, Any]] = []
            for tool in get_all_tools():
                tool_name = getattr(tool, "__tool_name__", getattr(tool, "__name__", "unknown"))

                # Remember the callable so execute() can dispatch by name.
                self._tool_funcs[tool_name] = tool
                if getattr(tool, "__interactive__", False):
                    self._interactive_tools.add(tool_name)

                schema: dict[str, Any] = {
                    "type": "function",
                    "function": {
                        "name": tool_name,
                        "description": getattr(tool, "__tool_description__", ""),
                    },
                }
                parameters = self._build_parameters(tool)
                if parameters["properties"]:
                    schema["function"]["parameters"] = parameters

                schemas.append(schema)

            self._tools_cache = schemas

        return self._tools_cache

    async def execute(self, name: str, inputs: dict[str, Any]) -> Any:
        """Run the local tool ``name`` with ``inputs`` as keyword arguments.

        A ``ToolContext`` with ``user_id="local_user"`` is passed as the first
        positional argument.

        Raises:
            Exception: if no tool with that name is registered.
        """
        # Ensure tools (and _tool_funcs) are loaded.
        await self.get_tools()

        tool_func = self._tool_funcs.get(name)
        if not tool_func:
            raise Exception(f"Tool '{name}' not found")

        from arcade_tdk import ToolContext

        context = ToolContext(user_id="local_user")  # type: ignore[call-arg]

        result = tool_func(context, **inputs)
        # Await coroutines from async tools. This also covers async tools hidden
        # behind a synchronous decorator wrapper, which iscoroutinefunction()
        # would not detect.
        if inspect.isawaitable(result):
            result = await result
        return result

    def is_interactive_tool(self, name: str) -> bool:
        """Check if the named tool requires exclusive terminal access."""
        return name in self._interactive_tools
1215
+
1216
+
1217
+ def _extract_toolkit_names_from_tools(tool_names: list[str]) -> set[str]:
1218
+ """Extract toolkit names from tool names."""
1219
+ toolkits = set()
1220
+ for tool_name in tool_names:
1221
+ if "_" in tool_name:
1222
+ toolkit = tool_name.split("_")[0].lower()
1223
+ toolkits.add(toolkit)
1224
+ elif "." in tool_name:
1225
+ toolkit = tool_name.split(".")[0].lower()
1226
+ toolkits.add(toolkit)
1227
+ return toolkits
1228
+
1229
+
1230
class RemoteToolManager(ToolManager):
    """Lists and executes Arcade Cloud tools through the ``arcadepy`` async client.

    Tool schemas are fetched in OpenAI format and stored in a persistent
    ``ToolCache``. They are served from that cache whenever it can satisfy a
    request.
    """

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        user_email: str | None = None,
    ):
        """Create the Arcade client.

        Args:
            api_key: Arcade API key; falls back to the configured ``api_key``.
            base_url: Arcade API base URL; falls back to the configured
                ``base_url``.
            user_email: Default user id for listing and executing tools; falls
                back to the configured ``user_email``.
        """
        log.info(f"Initializing RemoteToolManager with user_email: {user_email}")
        cfg = get_config()
        self.arcade_client = AsyncArcade(
            api_key=api_key or cfg.api_key,
            base_url=base_url or cfg.base_url,
        )
        self._default_user_id = user_email or cfg.user_email
        self._tools_cache = ToolCache()
        self._all_tools_fetched = False
        # Set while one coroutine is fetching the catalogue; other callers wait for it.
        self._fetching_lock = False

    async def get_tools(
        self, tools: list[str] | None = None, toolkits: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Return Arcade tool schemas, filtered by tool name and/or toolkit.

        Cached schemas are returned when they match the request. Otherwise the
        full catalogue is fetched from the API (by one caller at a time),
        cached, saved, and then filtered.

        Returns:
            Matching OpenAI-format tool schemas, or ``[]`` on any error (the
            error is logged).
        """
        try:
            cached_tools = self._cached_tools(tools, toolkits)
            if cached_tools:
                return cached_tools

            if self._fetching_lock:
                # Another coroutine is fetching. Poll in a loop rather than
                # recursing, so the await chain does not grow every 100 ms.
                # There is no await between leaving this loop and setting the
                # flag below, so only one waiter can start a new fetch.
                while self._fetching_lock:
                    await asyncio.sleep(0.1)
                cached_tools = self._cached_tools(tools, toolkits)
                if cached_tools:
                    return cached_tools

            self._fetching_lock = True
            try:
                with console.status("[cyan]Updating tools...", spinner="dots"):
                    await self._fetch_all_tools()
            finally:
                self._fetching_lock = False

            return self._filter_tools_from_cache(tools, toolkits)

        except Exception as e:
            log.error(f"Failed to fetch remote tools: {e}", exc_info=True)
            return []

    def _cached_tools(
        self, tools: list[str] | None, toolkits: list[str] | None
    ) -> list[dict[str, Any]]:
        """Return cached matches when the cache is usable; otherwise ``[]``."""
        if self._all_tools_fetched or self._tools_cache.has_valid_cache():
            cached = self._filter_tools_from_cache(tools, toolkits)
            if cached:
                log.debug(f"Using {len(cached)} cached tools")
            return cached
        return []

    async def _fetch_all_tools(self) -> None:
        """Page through the Arcade tool catalogue and store every tool in the cache."""
        log.info("Fetching remote tools from Arcade API")

        all_items: list[Any] = []
        offset = 0
        limit = 100

        while True:
            tools_response = await self.arcade_client.tools.formatted.list(
                limit=limit,
                offset=offset,
                format="openai",
                user_id=self._default_user_id,
            )

            items = getattr(tools_response, "items", None) or []
            if not items:
                break
            all_items.extend(items)

            # A missing or None total_count is treated as 0, which stops after
            # this page. Previously None raised TypeError.
            total_count = getattr(tools_response, "total_count", None) or 0
            if len(all_items) >= total_count or len(items) < limit:
                break

            offset += limit

        log.info(f"Fetched {len(all_items)} total tools from Arcade API")

        for tool in all_items:
            tool_name = self._extract_tool_name(tool)
            if tool_name:
                self._tools_cache.update_tools(tool_name, tool, save=False)

        self._tools_cache.save()
        self._all_tools_fetched = True

    def _extract_tool_name(self, tool: dict[str, Any]) -> str | None:
        """Return ``tool["function"]["name"]``.

        Returns ``""`` if a dict has no name, and ``None`` for non-dict input.
        """
        if isinstance(tool, dict):
            return tool.get("function", {}).get("name", "")
        return None

    def _filter_tools_from_cache(
        self, tools: list[str] | None = None, toolkits: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Select cached schemas that match ``tools`` or ``toolkits``.

        A schema is kept if its name is in ``tools`` (exact match), or if its
        lower-cased name starts with ``"<toolkit>_"`` or ``"<toolkit>."`` for
        one of ``toolkits`` (case-insensitive). With no filters, every
        unexpired cached schema is returned.
        """
        all_cached = self._tools_cache.get_tools()

        if not all_cached:
            return []

        if not tools and not toolkits:
            return all_cached

        wanted_names = set(tools or [])
        prefixes = tuple(
            prefix
            for tk in (toolkits or [])
            for prefix in (f"{tk.lower()}_", f"{tk.lower()}.")
        )

        filtered_tools: list[dict[str, Any]] = []
        for tool in all_cached:
            tool_name = self._extract_tool_name(tool)
            if not tool_name:
                continue
            if tool_name in wanted_names or tool_name.lower().startswith(prefixes):
                filtered_tools.append(tool)

        return filtered_tools

    async def _execute_once(
        self, name: str, inputs: dict[str, Any], user_id: str | None, timeout: float
    ) -> Any:
        """Make one ``tools.execute`` call, bounded by ``timeout`` seconds."""
        return await asyncio.wait_for(
            self.arcade_client.tools.execute(
                tool_name=name,
                input=inputs,
                user_id=user_id,
            ),
            timeout=timeout,
        )

    @staticmethod
    def _unwrap_result(result: Any) -> Any:
        """Return ``result.output.value``, raising if the output reports an error.

        Returns ``result`` itself when it has no ``output``.
        """
        if result.output and result.output.error:
            raise Exception(f"Tool execution failed: {result.output.error.message}")
        return result.output.value if result.output else result

    async def execute(
        self,
        name: str,
        inputs: dict[str, Any],
        user_id: str | None = None,
        timeout: float = 120.0,
    ) -> Any:
        """Execute a remote tool via Arcade Cloud, handling authorization.

        The tool is re-executed once after authorization completes, in either
        of two cases:
        - the response carries a pending or not-started authorization;
        - the call fails with a "403 ... authorization required" error.

        Every execution attempt, including that retry, is bounded by ``timeout``.

        Args:
            name: Arcade tool name.
            inputs: Tool input arguments.
            user_id: User to execute as; defaults to the manager's default user.
            timeout: Seconds allowed per execution attempt.

        Returns:
            The tool's output value.

        Raises:
            Exception: if the tool reports an error, authorization fails, or the
                call fails. Failures other than the authorization retry are
                wrapped, with the original error chained.
        """
        effective_user = user_id or self._default_user_id
        try:
            result = await self._execute_once(name, inputs, effective_user, timeout)

            if result.output and result.output.authorization:
                auth_response = result.output.authorization
                if auth_response.status in ("not_started", "pending"):
                    await self._handle_authorization_and_retry(name, inputs, user_id, auth_response)
                    result = await self._execute_once(name, inputs, effective_user, timeout)

            return self._unwrap_result(result)

        except Exception as e:
            error_str = str(e)

            if "403" in error_str and "authorization required" in error_str.lower():
                try:
                    auth_response = await self.arcade_client.tools.authorize(
                        tool_name=name,
                        user_id=effective_user,
                    )
                    await self._handle_authorization_and_retry(name, inputs, user_id, auth_response)
                    # Use the same timeout as the first attempt; previously this
                    # retry could hang indefinitely.
                    result = await self._execute_once(name, inputs, effective_user, timeout)
                    return self._unwrap_result(result)
                except Exception as retry_error:
                    log.error(f"Failed to authorize and retry tool '{name}': {retry_error}")
                    raise

            log.error(f"Remote tool execution error for '{name}': {e}")
            # str(TimeoutError()) is empty; fall back to the exception type name.
            raise Exception(
                f"Failed to execute remote tool '{name}': {error_str or type(e).__name__}"
            ) from e

    async def _handle_authorization_and_retry(
        self,
        tool_name: str,
        inputs: dict[str, Any],
        user_id: str | None,
        auth_response: Any,
    ) -> None:
        """Show the authorization link for ``tool_name`` and wait for completion.

        Despite its name, this does not retry the tool; callers re-execute it
        after this returns. ``inputs`` and ``user_id`` are currently unused.
        No timeout is applied to the wait.

        Raises:
            Exception: if ``auth_response`` has no URL, or if the completed
                authorization's status is not ``"completed"``.
        """
        if not auth_response.url:
            raise Exception(f"Authorization required for '{tool_name}' but no URL provided")

        content = Text()
        content.append(f"Authorization required for '{tool_name}'.\n\n", style="bold yellow")
        content.append("Click this link to authorize:\n", style="white")
        content.append(auth_response.url, style="bold cyan underline")
        content.append("\n\nWaiting for authorization to complete...", style="white dim")

        panel = Panel(
            content,
            title="[bold yellow]⚠ Authorization Required[/bold yellow]",
            title_align="left",
            border_style="yellow",
            padding=(0, 1),
            width=110,
        )
        console.print(panel)

        log.info(f"Waiting for authorization to complete for tool '{tool_name}'...")
        completed_auth = await self.arcade_client.auth.wait_for_completion(auth_response)

        if completed_auth.status != "completed":
            raise Exception(
                f"Authorization failed for '{tool_name}': status={completed_auth.status}"
            )

        console.print(f"[green]✓[/green] Authorization completed for '{tool_name}'")
        log.info(f"Authorization completed for tool '{tool_name}'")
1474
+
1475
+
1476
class CompositeToolManager(ToolManager):
    """Manages local, remote (Arcade) and MCP tools behind one interface.

    ``get_tools()`` gathers schemas from every source and records which source
    owns each tool name. ``execute()`` routes calls using that record.
    """

    # Prepended to remote tool descriptions in the schemas returned by get_tools().
    _REMOTE_DESCRIPTION_PREFIX = "[Arcade Cloud] "

    def __init__(
        self,
        local_manager: LocalToolManager | None = None,
        remote_manager: RemoteToolManager | None = None,
        enable_mcp: bool = True,
    ):
        """Create or adopt the per-source managers.

        Args:
            local_manager: Manager for built-in tools; created if omitted.
            remote_manager: Arcade manager; created for the configured user if
                omitted.
            enable_mcp: If True, create an ``MCPToolManager`` for every enabled
                server in ``MCPServerStore``.
        """
        self.local_manager = local_manager or LocalToolManager()
        self.remote_manager = remote_manager or RemoteToolManager(
            user_email=get_config().user_email,
        )
        self._tool_source_map: dict[str, str] = {}
        self._mcp_tool_to_manager: dict[str, MCPToolManager] = {}
        self._tools_cache = ToolCache()
        self._mcp_managers: list[MCPToolManager] = []
        self._enable_mcp = enable_mcp

        tool_cfg = get_config().tool_settings
        self._default_remote_tool: list[str] = tool_cfg.included_tools.copy()
        self._default_remote_toolkit: list[str] = tool_cfg.included_toolkits.copy()

        # Load MCP servers if enabled
        if enable_mcp:
            self._mcp_store = MCPServerStore()
            for server_config in self._mcp_store.list_enabled():
                mcp_manager = MCPToolManager(server_config)
                self._mcp_managers.append(mcp_manager)

    @classmethod
    def _label_remote_tool(cls, tool: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of ``tool`` whose description carries the Arcade prefix.

        The input is not modified. The remote manager hands out the schema dicts
        stored in its ``ToolCache``. Mutating them in place made the prefix pile
        up on every call, and the duplicated prefix was persisted on the next
        cache save. Prefixes already present in the description are stripped
        before the single prefix is added.
        """
        function = dict(tool.get("function", {}))
        description = function.get("description", "") or ""
        prefix = cls._REMOTE_DESCRIPTION_PREFIX
        while description.startswith(prefix):
            description = description[len(prefix):]
        function["description"] = f"{prefix}{description}"
        return {**tool, "function": function}

    async def get_tools(self) -> list[dict[str, Any]]:
        """Get all available tools from local, remote and MCP sources.

        Each source is loaded independently. A failing source is logged and
        skipped. The name-to-source mapping used by ``execute()`` is updated as
        a side effect.
        """
        all_tools: list[dict[str, Any]] = []
        counts = {"local": 0, "remote": 0, "mcp": 0}

        # Local tools
        try:
            local_tools = await self.local_manager.get_tools()
            for tool in local_tools:
                tool_name = tool.get("function", {}).get("name", "unknown")
                self._tool_source_map[tool_name] = "local"
                all_tools.append(tool)
            counts["local"] = len(local_tools)
            log.info(f"Loaded {len(local_tools)} local tools")
        except Exception as e:
            log.error(f"Failed to load local tools: {e}")

        # Remote (Arcade) tools
        if self.remote_manager:
            try:
                remote_tools = await self.remote_manager.get_tools(
                    tools=self._default_remote_tool,
                    toolkits=self._default_remote_toolkit,
                )
                for tool in remote_tools:
                    tool_name = tool.get("function", {}).get("name", "unknown")
                    self._tools_cache.update_tools(tool_name, tool)
                    self._tool_source_map[tool_name] = "remote"
                    all_tools.append(self._label_remote_tool(tool))
                counts["remote"] = len(remote_tools)
                log.info(f"Loaded {len(remote_tools)} remote tools from Arcade Cloud")
            except Exception as e:
                log.error(f"Failed to load remote tools: {e}")

        # MCP tools
        for mcp_manager in self._mcp_managers:
            try:
                mcp_tools = await mcp_manager.get_tools()
                for tool in mcp_tools:
                    tool_name = tool.get("function", {}).get("name", "unknown")
                    self._tool_source_map[tool_name] = "mcp"
                    self._mcp_tool_to_manager[tool_name] = mcp_manager
                    all_tools.append(tool)
                counts["mcp"] += len(mcp_tools)

                # Update server status
                if self._enable_mcp and hasattr(self, "_mcp_store"):
                    self._mcp_store.update_status(mcp_manager.config.name, True, len(mcp_tools))
            except Exception as e:
                log.error(f"Failed to load tools from MCP '{mcp_manager.config.name}': {e}")

        if counts["mcp"] > 0:
            log.info(f"Loaded {counts['mcp']} tools from MCP servers")

        # Summary of this call. Stale entries from earlier calls that are still
        # in _tool_source_map are not counted.
        log.info(
            f"Total tools: {len(all_tools)} "
            f"(Local: {counts['local']}, Arcade: {counts['remote']}, MCP: {counts['mcp']})"
        )
        return all_tools

    async def execute(self, name: str, inputs: dict[str, Any]) -> Any:
        """Execute a tool by routing it to the manager recorded in ``get_tools()``.

        Raises:
            Exception: if the tool's source is unknown or its MCP manager is missing.
        """
        source = self._tool_source_map.get(name)

        if source == "local":
            log.debug(f"Executing local tool: {name}")
            return await self.local_manager.execute(name, inputs)
        elif source == "remote" and self.remote_manager:
            log.debug(f"Executing remote tool: {name}")
            return await self.remote_manager.execute(name, inputs, user_id=get_config().user_email)
        elif source == "mcp":
            mcp_manager = self._mcp_tool_to_manager.get(name)
            if mcp_manager:
                log.debug(f"Executing MCP tool: {name}")
                return await mcp_manager.execute(name, inputs)
            raise Exception(f"MCP manager not found for tool '{name}'")
        else:
            raise Exception(f"Tool '{name}' not found in local, remote, or MCP tools")

    def is_interactive_tool(self, name: str) -> bool:
        """Return True only for local tools that their manager flags as interactive."""
        source = self._tool_source_map.get(name)
        if source == "local" and hasattr(self.local_manager, "is_interactive_tool"):
            return bool(self.local_manager.is_interactive_tool(name))
        return False

    def get_tool_source(self, name: str) -> str | None:
        """Return ``"local"``, ``"remote"``, ``"mcp"``, or None if the tool is unknown."""
        return self._tool_source_map.get(name)

    def get_all_tool_info(self) -> list[dict[str, Any]]:
        """List ``{"name", "source"}`` for every known tool.

        MCP tools also carry a ``"server"`` key with the server name.
        """
        tool_info = []
        for name, source in self._tool_source_map.items():
            info = {"name": name, "source": source}
            if source == "mcp" and name in self._mcp_tool_to_manager:
                info["server"] = self._mcp_tool_to_manager[name].config.name
            tool_info.append(info)
        return tool_info

    async def close(self) -> None:
        """Close all MCP connections.

        Every manager gets a close attempt even if an earlier one fails. The
        first failure is re-raised once all attempts have been made.
        """
        first_error: Exception | None = None
        for mcp_manager in self._mcp_managers:
            try:
                await mcp_manager.close()
            except Exception as e:
                log.warning(f"Failed to close MCP server '{mcp_manager.config.name}': {e}")
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error
1619
+
1620
+
1621
# Public API of this module. The names and their order are part of the
# module's star-import surface.
__all__ = [
    # Managers
    "ToolManager",
    "LocalToolManager",
    "RemoteToolManager",
    "MCPToolManager",
    "MCPOAuthHandler",
    "CompositeToolManager",
    # Schema cache
    "ToolCache",
    "CacheEntry",
    # Errors
    "ToolAuthorizationRequired",
    # MCP server configuration
    "MCPServerConfig",
    "MCPServerStore",
    "MCPAuthType",
    "MCPOAuthTokens",
]