open-edison 0.1.19__py3-none-any.whl → 0.1.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
src/server.py CHANGED
@@ -9,36 +9,37 @@ import asyncio
9
9
  import json
10
10
  import traceback
11
11
  from collections.abc import Awaitable, Callable, Coroutine
12
+ from contextlib import suppress
12
13
  from pathlib import Path
13
- from typing import Any, cast
14
+ from typing import Any, Literal, cast
14
15
 
15
16
  import uvicorn
16
17
  from fastapi import Depends, FastAPI, HTTPException, status
17
18
  from fastapi.middleware.cors import CORSMiddleware
18
- from fastapi.responses import FileResponse, JSONResponse, Response
19
+ from fastapi.responses import (
20
+ FileResponse,
21
+ JSONResponse,
22
+ RedirectResponse,
23
+ Response,
24
+ StreamingResponse,
25
+ )
19
26
  from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
20
27
  from fastapi.staticfiles import StaticFiles
21
28
  from fastmcp import FastMCP
22
29
  from loguru import logger as log
23
30
  from pydantic import BaseModel, Field
24
31
 
25
- from src.config import MCPServerConfig, config
32
+ from src import events
33
+ from src.config import Config, MCPServerConfig, load_json_file
26
34
  from src.config import get_config_dir as _get_cfg_dir # type: ignore[attr-defined]
27
35
  from src.middleware.session_tracking import (
28
36
  MCPSessionModel,
29
37
  create_db_session,
30
38
  )
39
+ from src.oauth_manager import OAuthStatus, get_oauth_manager
31
40
  from src.single_user_mcp import SingleUserMCP
32
41
  from src.telemetry import initialize_telemetry, set_servers_installed
33
42
 
34
-
35
- def _get_current_config():
36
- """Get current config, allowing for test mocking."""
37
- from src.config import config as current_config
38
-
39
- return current_config
40
-
41
-
42
43
  # Module-level dependency singletons
43
44
  _security = HTTPBearer()
44
45
  _auth_dependency = Depends(_security)
@@ -102,6 +103,15 @@ class OpenEdisonProxy:
102
103
  StaticFiles(directory=str(assets_dir), html=False),
103
104
  name="dashboard-assets",
104
105
  )
106
+ # Serve service worker at root path for registration at /sw.js
107
+ sw_path = static_dir / "sw.js"
108
+ if sw_path.exists():
109
+
110
+ async def _sw() -> FileResponse: # type: ignore[override]
111
+ # Service workers must be served from the origin root scope
112
+ return FileResponse(str(sw_path), media_type="application/javascript")
113
+
114
+ app.add_api_route("/sw.js", _sw, methods=["GET"]) # type: ignore[arg-type]
105
115
  favicon_path = static_dir / "favicon.ico"
106
116
  if favicon_path.exists():
107
117
 
@@ -116,49 +126,29 @@ class OpenEdisonProxy:
116
126
  log.warning(f"Failed to mount dashboard static assets: {mount_err}")
117
127
 
118
128
  # Special-case: serve SQLite db and config JSONs for dashboard (prod replacement for Vite @fs)
119
- def _resolve_db_path() -> Path | None:
120
- try:
121
- # Try configured database path first
122
- db_cfg = getattr(config.logging, "database_path", None)
123
- if isinstance(db_cfg, str) and db_cfg:
124
- db_path = Path(db_cfg)
125
- if db_path.is_absolute() and db_path.exists():
126
- return db_path
127
- # Check relative to config dir
128
- try:
129
- cfg_dir = _get_cfg_dir()
130
- except Exception:
131
- cfg_dir = Path.cwd()
132
- rel1 = cfg_dir / db_path
133
- if rel1.exists():
134
- return rel1
135
- # Also check relative to cwd as a fallback
136
- rel2 = Path.cwd() / db_path
137
- if rel2.exists():
138
- return rel2
139
- except Exception:
140
- pass
141
-
142
- # Fallback common locations
129
+ def _resolve_db_path() -> Path:
130
+ # Try configured database path first
131
+ db_cfg = Config().logging.database_path
132
+ db_path = Path(db_cfg)
133
+ if db_path.is_absolute() and db_path.exists():
134
+ return db_path
135
+ # Check relative to config dir
143
136
  try:
144
137
  cfg_dir = _get_cfg_dir()
145
138
  except Exception:
146
139
  cfg_dir = Path.cwd()
147
- candidates = [
148
- cfg_dir / "sessions.db",
149
- cfg_dir / "sessions.db",
150
- Path.cwd() / "edison.db",
151
- Path.cwd() / "sessions.db",
152
- ]
153
- for c in candidates:
154
- if c.exists():
155
- return c
156
- return None
140
+ rel1 = cfg_dir / db_path
141
+ if rel1.exists():
142
+ return rel1
143
+ # Also check relative to cwd as a fallback
144
+ rel2 = Path.cwd() / db_path
145
+ if rel2.exists():
146
+ return rel2
147
+
148
+ raise FileNotFoundError(f"Database file not found at {db_path}")
157
149
 
158
150
  async def _serve_db() -> FileResponse: # type: ignore[override]
159
151
  db_file = _resolve_db_path()
160
- if db_file is None:
161
- raise HTTPException(status_code=404, detail="Database file not found")
162
152
  return FileResponse(str(db_file), media_type="application/octet-stream")
163
153
 
164
154
  # Provide multiple paths the SPA might attempt (both edison.db legacy and sessions.db canonical)
@@ -200,12 +190,10 @@ class OpenEdisonProxy:
200
190
  # 2) Repository/package defaults next to src/
201
191
  repo_candidate = Path(__file__).parent.parent / filename
202
192
  if repo_candidate.exists():
203
- # Bootstrap a copy into config dir when possible
204
- try:
193
+ # Bootstrap a copy into config dir when possible (best effort)
194
+ with suppress(Exception):
205
195
  target.parent.mkdir(parents=True, exist_ok=True)
206
196
  target.write_text(repo_candidate.read_text(encoding="utf-8"), encoding="utf-8")
207
- except Exception:
208
- pass
209
197
  return target if target.exists() else repo_candidate
210
198
 
211
199
  # 3) Fall back to config dir path (will be created on save)
@@ -262,12 +250,57 @@ class OpenEdisonProxy:
262
250
  _ = json.loads(content or "{}")
263
251
  target.write_text(content or "{}", encoding="utf-8")
264
252
  log.debug(f"Saved JSON config to {target}")
253
+
254
+ # Clear cache for the config file, if it was config.json
255
+ if name == "config.json":
256
+ load_json_file.cache_clear()
257
+
265
258
  return {"status": "ok"}
266
259
  except Exception as e: # noqa: BLE001
267
260
  raise HTTPException(status_code=400, detail=f"Save failed: {e}") from e
268
261
 
269
262
  app.add_api_route("/__save_json__", _save_json, methods=["POST"]) # type: ignore[arg-type]
270
263
 
264
+ # SSE events endpoint
265
+ async def _events() -> StreamingResponse: # type: ignore[override]
266
+ queue = await events.subscribe()
267
+ return StreamingResponse(
268
+ events.sse_stream(queue),
269
+ media_type="text/event-stream",
270
+ )
271
+
272
+ app.add_api_route("/events", _events, methods=["GET"]) # type: ignore[arg-type]
273
+
274
+ # Approval endpoint to allow an item for the rest of the session
275
+ class _ApprovalBody(BaseModel):
276
+ session_id: str
277
+ kind: Literal["tool", "resource", "prompt"]
278
+ name: str
279
+
280
+ async def _approve(body: _ApprovalBody) -> dict[str, Any]: # type: ignore[override]
281
+ try:
282
+ # Mark approval once; no persistent overrides
283
+ await events.approve_once(body.session_id, body.kind, body.name)
284
+
285
+ # Notify listeners (best effort, log failure)
286
+ events.fire_and_forget(
287
+ {
288
+ "type": "mcp_approved_once",
289
+ "session_id": body.session_id,
290
+ "kind": body.kind,
291
+ "name": body.name,
292
+ }
293
+ )
294
+
295
+ return {"status": "ok"}
296
+ except HTTPException:
297
+ raise
298
+ except Exception as e: # noqa: BLE001
299
+ log.error(f"Approval failed: {e}")
300
+ raise HTTPException(status_code=500, detail="Failed to approve item") from e
301
+
302
+ app.add_api_route("/api/approve", _approve, methods=["POST"]) # type: ignore[arg-type]
303
+
271
304
  # Catch-all for @fs patterns; serve known db and json filenames
272
305
  async def _serve_fs_path(rest: str): # type: ignore[override]
273
306
  target = rest.strip("/")
@@ -282,6 +315,12 @@ class OpenEdisonProxy:
282
315
  app.add_api_route("/@fs/{rest:path}", _serve_fs_path, methods=["GET"]) # type: ignore[arg-type]
283
316
  app.add_api_route("/%40fs/{rest:path}", _serve_fs_path, methods=["GET"]) # type: ignore[arg-type]
284
317
 
318
+ # Redirect root to dashboard
319
+ async def _root_redirect() -> RedirectResponse: # type: ignore[override]
320
+ return RedirectResponse(url="/dashboard")
321
+
322
+ app.add_api_route("/", _root_redirect, methods=["GET"]) # type: ignore[arg-type]
323
+
285
324
  return app
286
325
 
287
326
  def _build_backend_config_top(
@@ -315,7 +354,7 @@ class OpenEdisonProxy:
315
354
  await self.single_user_mcp.initialize()
316
355
 
317
356
  # Emit snapshot of enabled servers
318
- enabled_count = len([s for s in config.mcp_servers if s.enabled])
357
+ enabled_count = len([s for s in Config().mcp_servers if s.enabled])
319
358
  set_servers_installed(enabled_count)
320
359
 
321
360
  # Add CORS middleware to FastAPI
@@ -335,7 +374,7 @@ class OpenEdisonProxy:
335
374
  app=self.fastapi_app,
336
375
  host=self.host,
337
376
  port=self.port + 1,
338
- log_level=config.logging.level.lower(),
377
+ log_level=Config().logging.level.lower(),
339
378
  )
340
379
  fastapi_server = uvicorn.Server(fastapi_config)
341
380
  servers_to_run.append(fastapi_server.serve())
@@ -346,7 +385,7 @@ class OpenEdisonProxy:
346
385
  app=mcp_app,
347
386
  host=self.host,
348
387
  port=self.port,
349
- log_level=config.logging.level.lower(),
388
+ log_level=Config().logging.level.lower(),
350
389
  )
351
390
  fastmcp_server = uvicorn.Server(fastmcp_config)
352
391
  servers_to_run.append(fastmcp_server.serve())
@@ -364,6 +403,12 @@ class OpenEdisonProxy:
364
403
  self.mcp_status,
365
404
  methods=["GET"],
366
405
  )
406
+ # Endpoint to notify server that permissions JSONs changed; invalidate caches
407
+ app.add_api_route(
408
+ "/api/permissions-changed",
409
+ self.permissions_changed,
410
+ methods=["POST"],
411
+ )
367
412
  app.add_api_route(
368
413
  "/mcp/validate",
369
414
  self.validate_mcp_server,
@@ -382,17 +427,55 @@ class OpenEdisonProxy:
382
427
  methods=["POST"],
383
428
  dependencies=[Depends(self.verify_api_key)],
384
429
  )
430
+ app.add_api_route(
431
+ "/mcp/mount/{server_name}",
432
+ self.mount_mcp_server,
433
+ methods=["POST"],
434
+ dependencies=[Depends(self.verify_api_key)],
435
+ )
436
+ app.add_api_route(
437
+ "/mcp/mount/{server_name}",
438
+ self.unmount_mcp_server,
439
+ methods=["DELETE"],
440
+ dependencies=[Depends(self.verify_api_key)],
441
+ )
385
442
  # Public sessions endpoint (no auth) for simple local dashboard
386
443
  app.add_api_route(
387
444
  "/sessions",
388
445
  self.get_sessions,
389
446
  methods=["GET"],
390
447
  )
391
- # Cache invalidation endpoint (no auth required - allowed to fail)
448
+
449
+ # OAuth endpoints
450
+ app.add_api_route(
451
+ "/mcp/oauth/status",
452
+ self.get_oauth_status_all,
453
+ methods=["GET"],
454
+ dependencies=[Depends(self.verify_api_key)],
455
+ )
456
+ app.add_api_route(
457
+ "/mcp/oauth/status/{server_name}",
458
+ self.get_oauth_status,
459
+ methods=["GET"],
460
+ dependencies=[Depends(self.verify_api_key)],
461
+ )
462
+ app.add_api_route(
463
+ "/mcp/oauth/test-connection/{server_name}",
464
+ self.oauth_test_connection,
465
+ methods=["POST"],
466
+ dependencies=[Depends(self.verify_api_key)],
467
+ )
468
+ app.add_api_route(
469
+ "/mcp/oauth/tokens/{server_name}",
470
+ self.oauth_clear_tokens,
471
+ methods=["DELETE"],
472
+ dependencies=[Depends(self.verify_api_key)],
473
+ )
392
474
  app.add_api_route(
393
- "/api/clear-caches",
394
- self.clear_caches,
475
+ "/mcp/oauth/refresh/{server_name}",
476
+ self.oauth_refresh_status,
395
477
  methods=["POST"],
478
+ dependencies=[Depends(self.verify_api_key)],
396
479
  )
397
480
 
398
481
  async def verify_api_key(
@@ -403,11 +486,27 @@ class OpenEdisonProxy:
403
486
 
404
487
  Returns the API key string if valid, otherwise raises HTTPException.
405
488
  """
406
- current_config = _get_current_config()
407
- if credentials.credentials != current_config.server.api_key:
489
+ if credentials.credentials != Config().server.api_key:
408
490
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key")
409
491
  return credentials.credentials
410
492
 
493
+ async def permissions_changed(self) -> dict[str, Any]:
494
+ """Invalidate SingleUserMCP manager caches after permissions JSON changed.
495
+
496
+ This attempts to clear any known cache methods on the internal managers and then
497
+ warms the lists to ensure subsequent list calls reflect current state.
498
+ """
499
+ try:
500
+ mcp = self.single_user_mcp
501
+ # Warm managers so any internal caches are refreshed
502
+ await mcp._tool_manager.list_tools() # type: ignore[attr-defined]
503
+ await mcp._resource_manager.list_resources() # type: ignore[attr-defined]
504
+ await mcp._prompt_manager.list_prompts() # type: ignore[attr-defined]
505
+ return {"status": "ok"}
506
+ except Exception as e: # noqa: BLE001
507
+ log.error(f"Failed to process permissions-changed: {e}")
508
+ raise HTTPException(status_code=500, detail="Failed to invalidate caches") from e
509
+
411
510
  async def mcp_status(self) -> dict[str, list[dict[str, Any]]]:
412
511
  """Get status of configured MCP servers (auth required)."""
413
512
  return {
@@ -416,34 +515,13 @@ class OpenEdisonProxy:
416
515
  "name": server.name,
417
516
  "enabled": server.enabled,
418
517
  }
419
- for server in config.mcp_servers
518
+ for server in Config().mcp_servers
420
519
  ]
421
520
  }
422
521
 
423
- def _handle_server_operation_error(
424
- self, operation: str, server_name: str, error: Exception
425
- ) -> HTTPException:
426
- """Handle common server operation errors."""
427
- log.error(f"Failed to {operation} server {server_name}: {error}")
428
- return HTTPException(
429
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
430
- detail=f"Failed to {operation} server: {str(error)}",
431
- )
432
-
433
- def _find_server_config(self, server_name: str) -> MCPServerConfig:
434
- """Find server configuration by name."""
435
- current_config = _get_current_config()
436
- for config_server in current_config.mcp_servers:
437
- if config_server.name == server_name:
438
- return config_server
439
- raise HTTPException(
440
- status_code=404,
441
- detail=f"Server configuration not found: {server_name}",
442
- )
443
-
444
522
  async def health_check(self) -> dict[str, Any]:
445
523
  """Health check endpoint"""
446
- return {"status": "healthy", "version": "0.1.0", "mcp_servers": len(config.mcp_servers)}
524
+ return {"status": "healthy", "version": "0.1.0", "mcp_servers": len(Config().mcp_servers)}
447
525
 
448
526
  async def get_mounted_servers(self) -> dict[str, Any]:
449
527
  """Get list of currently mounted MCP servers."""
@@ -458,48 +536,64 @@ class OpenEdisonProxy:
458
536
  ) from e
459
537
 
460
538
  async def reinitialize_mcp_servers(self) -> dict[str, Any]:
461
- """Reinitialize all MCP servers by creating a fresh instance and reloading config."""
462
- old_mcp = None
539
+ """Reinitialize all MCP servers by creating a fresh instance and reloading config.
540
+
541
+ Returns a JSON payload summarizing the final mounted servers so callers can display status.
542
+ """
463
543
  try:
464
544
  log.info("🔄 Reinitializing MCP servers via API endpoint")
465
545
 
466
- # Reload configuration from disk
467
- log.info("Reloading configuration from disk")
468
- from src.config import Config
469
-
470
- fresh_config = Config.load()
471
- log.info("✅ Configuration reloaded from disk")
472
-
473
546
  # Create a completely new SingleUserMCP instance to ensure clean state
474
- old_mcp = self.single_user_mcp
475
- self.single_user_mcp = SingleUserMCP()
547
+ # old_mcp = self.single_user_mcp
548
+ # self.single_user_mcp = SingleUserMCP()
549
+ # del old_mcp
476
550
 
477
551
  # Initialize the new instance with fresh config
478
- await self.single_user_mcp.initialize(fresh_config)
552
+ await self.single_user_mcp.initialize()
479
553
 
480
- # Get final status
481
- final_mounted = await self.single_user_mcp.get_mounted_servers()
554
+ # Summarize final mounted servers
555
+ try:
556
+ mounted = await self.single_user_mcp.get_mounted_servers()
557
+ except Exception:
558
+ log.error("Failed to get mounted servers")
559
+ mounted = []
482
560
 
483
- result = {
484
- "status": "success",
485
- "message": "MCP servers reinitialized successfully",
486
- "final_mounted_servers": [server["name"] for server in final_mounted],
487
- "total_final_mounted": len(final_mounted),
561
+ names = [m.get("name", "") for m in mounted]
562
+ return {
563
+ "status": "ok",
564
+ "total_final_mounted": len(mounted),
565
+ "mounted_servers": names,
488
566
  }
489
567
 
490
- log.info("✅ MCP servers reinitialized successfully via API")
491
- return result
492
-
493
568
  except Exception as e:
494
569
  log.error(f"❌ Failed to reinitialize MCP servers: {e}")
495
- # Restore the old instance on failure
496
- if old_mcp is not None:
497
- self.single_user_mcp = old_mcp
498
570
  raise HTTPException(
499
571
  status_code=500,
500
572
  detail=f"Failed to reinitialize MCP servers: {str(e)}",
501
573
  ) from e
502
574
 
575
+ async def mount_mcp_server(self, server_name: str) -> dict[str, Any]:
576
+ """Mount a single MCP server by name (auth required)."""
577
+ try:
578
+ ok = await self.single_user_mcp.mount_server(server_name)
579
+ return {"mounted": bool(ok), "server": server_name}
580
+ except Exception as e:
581
+ log.error(f"❌ Failed to mount server {server_name}: {e}")
582
+ raise HTTPException(
583
+ status_code=500, detail=f"Failed to mount server {server_name}: {str(e)}"
584
+ ) from e
585
+
586
+ async def unmount_mcp_server(self, server_name: str) -> dict[str, Any]:
587
+ """Unmount a previously mounted MCP server by name (auth required)."""
588
+ try:
589
+ ok = await self.single_user_mcp.unmount(server_name)
590
+ return {"unmounted": bool(ok), "server": server_name}
591
+ except Exception as e:
592
+ log.error(f"❌ Failed to unmount server {server_name}: {e}")
593
+ raise HTTPException(
594
+ status_code=500, detail=f"Failed to unmount server {server_name}: {str(e)}"
595
+ ) from e
596
+
503
597
  async def get_sessions(self) -> dict[str, Any]:
504
598
  """Return recent MCP session summaries from local SQLite.
505
599
 
@@ -549,21 +643,6 @@ class OpenEdisonProxy:
549
643
  log.error(f"Failed to fetch sessions: {e}")
550
644
  raise HTTPException(status_code=500, detail="Failed to fetch sessions") from e
551
645
 
552
- async def clear_caches(self) -> dict[str, str]:
553
- """Clear all permission caches to force reload from configuration files."""
554
- try:
555
- from src.middleware.data_access_tracker import clear_all_permissions_caches
556
-
557
- log.info("🔄 Clearing all permission caches via API endpoint")
558
- clear_all_permissions_caches()
559
- log.info("✅ All permission caches cleared successfully")
560
-
561
- return {"status": "success", "message": "All permission caches cleared"}
562
- except Exception as e:
563
- log.error(f"❌ Failed to clear permission caches: {e}")
564
- # Don't raise HTTPException - allow to fail gracefully as requested
565
- return {"status": "error", "message": f"Failed to clear caches: {str(e)}"}
566
-
567
646
  # ---- MCP validation ----
568
647
  class _ValidateRequest(BaseModel):
569
648
  name: str | None = Field(None, description="Optional server name label")
@@ -659,18 +738,6 @@ class OpenEdisonProxy:
659
738
  except Exception as cleanup_err: # noqa: BLE001
660
739
  log.debug(f"Validator cleanup skipped/failed: {cleanup_err}")
661
740
 
662
- def _build_backend_config(
663
- self, server_name: str, body: "OpenEdisonProxy._ValidateRequest"
664
- ) -> dict[str, Any]:
665
- backend_entry: dict[str, Any] = {
666
- "command": body.command,
667
- "args": body.args,
668
- "env": body.env or {},
669
- }
670
- if body.roots:
671
- backend_entry["roots"] = body.roots
672
- return {"mcpServers": {server_name: backend_entry}}
673
-
674
741
  async def _list_all_capabilities(
675
742
  self, server: FastMCP[Any], body: "OpenEdisonProxy._ValidateRequest"
676
743
  ) -> tuple[list[Any], list[Any], list[Any]]:
@@ -717,3 +784,293 @@ class OpenEdisonProxy:
717
784
  "name": prefix + "_" + str(name) if name is not None else "",
718
785
  "description": description,
719
786
  }
787
+
788
+ # ---- OAuth endpoints ----
789
+
790
+ async def get_oauth_status_all(self) -> dict[str, Any]:
791
+ """Get OAuth status for all configured MCP servers."""
792
+ try:
793
+ oauth_manager = get_oauth_manager()
794
+
795
+ servers_info = {}
796
+ for server_config in Config().mcp_servers:
797
+ server_name = server_config.name
798
+ info = oauth_manager.get_server_info(server_name)
799
+
800
+ if info:
801
+ # Use cached OAuth info
802
+ servers_info[server_name] = {
803
+ "server_name": info.server_name,
804
+ "status": info.status.value,
805
+ "error_message": info.error_message,
806
+ "token_expires_at": info.token_expires_at,
807
+ "has_refresh_token": info.has_refresh_token,
808
+ "scopes": info.scopes,
809
+ }
810
+ else:
811
+ # OAuth status not checked yet - check proactively for remote servers
812
+ if server_config.is_remote_server():
813
+ remote_url = server_config.get_remote_url()
814
+ log.info(f"🔍 Proactively checking OAuth for remote server {server_name}")
815
+
816
+ # Check OAuth requirements for this remote server
817
+ oauth_info = await oauth_manager.check_oauth_requirement(
818
+ server_name, remote_url
819
+ )
820
+
821
+ servers_info[server_name] = {
822
+ "server_name": oauth_info.server_name,
823
+ "status": oauth_info.status.value,
824
+ "error_message": oauth_info.error_message,
825
+ "token_expires_at": oauth_info.token_expires_at,
826
+ "has_refresh_token": oauth_info.has_refresh_token,
827
+ "scopes": oauth_info.scopes,
828
+ }
829
+ else:
830
+ # Local server - no OAuth needed
831
+ servers_info[server_name] = {
832
+ "server_name": server_name,
833
+ "status": OAuthStatus.NOT_REQUIRED.value,
834
+ "error_message": None,
835
+ "token_expires_at": None,
836
+ "has_refresh_token": False,
837
+ "scopes": None,
838
+ }
839
+
840
+ return {"oauth_status": servers_info}
841
+
842
+ except Exception as e:
843
+ log.error(f"Failed to get OAuth status for all servers: {e}")
844
+ raise HTTPException(
845
+ status_code=500,
846
+ detail=f"Failed to get OAuth status: {str(e)}",
847
+ ) from e
848
+
849
+ def _find_server_config(self, server_name: str) -> MCPServerConfig:
850
+ """Find server configuration by name."""
851
+ for config_server in Config().mcp_servers:
852
+ if config_server.name == server_name:
853
+ return config_server
854
+ raise HTTPException(
855
+ status_code=404,
856
+ detail=f"Server configuration not found: {server_name}",
857
+ )
858
+
859
+ async def get_oauth_status(self, server_name: str) -> dict[str, Any]:
860
+ """Get OAuth status for a specific MCP server."""
861
+ try:
862
+ server_config = self._find_server_config(server_name)
863
+ oauth_manager = get_oauth_manager()
864
+
865
+ # Get the remote URL if this is a remote server
866
+ remote_url = server_config.get_remote_url()
867
+
868
+ # Check or refresh OAuth status
869
+ oauth_info = await oauth_manager.check_oauth_requirement(server_name, remote_url)
870
+
871
+ return {
872
+ "server_name": oauth_info.server_name,
873
+ "mcp_url": oauth_info.mcp_url,
874
+ "status": oauth_info.status.value,
875
+ "error_message": oauth_info.error_message,
876
+ "token_expires_at": oauth_info.token_expires_at,
877
+ "has_refresh_token": oauth_info.has_refresh_token,
878
+ "scopes": oauth_info.scopes,
879
+ "client_name": oauth_info.client_name,
880
+ }
881
+
882
+ except HTTPException:
883
+ raise
884
+ except Exception as e:
885
+ log.error(f"Failed to get OAuth status for {server_name}: {e}")
886
+ raise HTTPException(
887
+ status_code=500,
888
+ detail=f"Failed to get OAuth status: {str(e)}",
889
+ ) from e
890
+
891
+ class _OAuthAuthorizeRequest(BaseModel):
892
+ scopes: list[str] | None = Field(None, description="OAuth scopes to request")
893
+ client_name: str | None = Field(None, description="Client name for OAuth registration")
894
+
895
+ async def oauth_test_connection(
896
+ self, server_name: str, body: _OAuthAuthorizeRequest | None = None
897
+ ) -> dict[str, Any]:
898
+ """
899
+ Test connection to a remote MCP server, triggering OAuth flow if needed.
900
+
901
+ This endpoint creates a temporary FastMCP client with OAuth authentication
902
+ and attempts to make a connection. This automatically triggers FastMCP's
903
+ OAuth flow, which will open a browser for user authorization.
904
+ """
905
+ try:
906
+ server_config = self._find_server_config(server_name)
907
+ oauth_manager = get_oauth_manager()
908
+
909
+ # Check if this is a remote server
910
+ if not server_config.is_remote_server():
911
+ raise HTTPException(
912
+ status_code=400,
913
+ detail=f"Server {server_name} is a local server and does not support OAuth",
914
+ )
915
+
916
+ # Get the remote URL
917
+ remote_url = server_config.get_remote_url()
918
+ if not remote_url:
919
+ raise HTTPException(
920
+ status_code=400, detail=f"Server {server_name} does not have a valid remote URL"
921
+ )
922
+
923
+ # Get OAuth configuration
924
+ scopes = None
925
+ client_name = None
926
+
927
+ if body:
928
+ scopes = body.scopes
929
+ client_name = body.client_name
930
+
931
+ # Use server config OAuth settings if not provided in request
932
+ if not scopes and server_config.oauth_scopes:
933
+ scopes = server_config.oauth_scopes
934
+ if not client_name and server_config.oauth_client_name:
935
+ client_name = server_config.oauth_client_name
936
+
937
+ log.info(f"🔗 Testing connection to {server_name} at {remote_url}")
938
+
939
+ # Import FastMCP client for testing
940
+ from fastmcp import Client as FastMCPClient
941
+ from fastmcp.client.auth import OAuth
942
+
943
+ # Create OAuth auth object
944
+ oauth = OAuth(
945
+ mcp_url=remote_url,
946
+ scopes=scopes,
947
+ client_name=client_name or "OpenEdison MCP Gateway",
948
+ token_storage_cache_dir=oauth_manager.cache_dir,
949
+ )
950
+
951
+ # Create a temporary client and test the connection
952
+ # This will automatically trigger OAuth flow if tokens don't exist
953
+ try:
954
+ async with FastMCPClient(remote_url, auth=oauth) as client:
955
+ # Try to ping the server - this triggers OAuth if needed
956
+ log.info(
957
+ f"🔐 Attempting to connect to {server_name} (may open browser for OAuth)..."
958
+ )
959
+ await client.ping()
960
+ log.info(f"✅ Successfully connected to {server_name}")
961
+
962
+ # Update OAuth status in manager
963
+ await oauth_manager.check_oauth_requirement(server_name, remote_url)
964
+
965
+ return {
966
+ "status": "connection_successful",
967
+ "message": f"Successfully connected to {server_name}. OAuth tokens are now cached.",
968
+ "server_name": server_name,
969
+ }
970
+
971
+ except Exception as e:
972
+ log.error(f"❌ Failed to connect to {server_name}: {e}")
973
+
974
+ # Check if this was an OAuth-related error
975
+ error_message = str(e)
976
+ if "oauth" in error_message.lower() or "authorization" in error_message.lower():
977
+ return {
978
+ "status": "oauth_required",
979
+ "message": f"OAuth authorization completed for {server_name}. Please try connecting again.",
980
+ "server_name": server_name,
981
+ }
982
+ raise HTTPException(
983
+ status_code=500, detail=f"Connection test failed: {error_message}"
984
+ ) from None
985
+
986
+ except HTTPException:
987
+ raise
988
+ except Exception as e:
989
+ log.error(f"Failed to test connection for {server_name}: {e}")
990
+ raise HTTPException(
991
+ status_code=500,
992
+ detail=f"Failed to test connection: {str(e)}",
993
+ ) from e
994
+
995
+ async def oauth_clear_tokens(self, server_name: str) -> dict[str, Any]:
996
+ """Clear stored OAuth tokens for a server."""
997
+ try:
998
+ server_config = self._find_server_config(server_name)
999
+ oauth_manager = get_oauth_manager()
1000
+
1001
+ # Check if this is a remote server
1002
+ if not server_config.is_remote_server():
1003
+ raise HTTPException(
1004
+ status_code=400,
1005
+ detail=f"Server {server_name} is a local server and does not support OAuth",
1006
+ )
1007
+
1008
+ # Get the remote URL
1009
+ remote_url = server_config.get_remote_url()
1010
+ if not remote_url:
1011
+ raise HTTPException(
1012
+ status_code=400, detail=f"Server {server_name} does not have a valid remote URL"
1013
+ )
1014
+
1015
+ success = oauth_manager.clear_tokens(server_name, remote_url)
1016
+
1017
+ if success:
1018
+ return {
1019
+ "status": "success",
1020
+ "message": f"OAuth tokens cleared for {server_name}",
1021
+ "server_name": server_name,
1022
+ }
1023
+ raise HTTPException(
1024
+ status_code=500, detail=f"Failed to clear OAuth tokens for {server_name}"
1025
+ )
1026
+
1027
+ except HTTPException:
1028
+ raise
1029
+ except Exception as e:
1030
+ log.error(f"Failed to clear OAuth tokens for {server_name}: {e}")
1031
+ raise HTTPException(
1032
+ status_code=500,
1033
+ detail=f"Failed to clear OAuth tokens: {str(e)}",
1034
+ ) from e
1035
+
1036
+ async def oauth_refresh_status(self, server_name: str) -> dict[str, Any]:
1037
+ """Refresh OAuth status for a server."""
1038
+ try:
1039
+ server_config = self._find_server_config(server_name)
1040
+ oauth_manager = get_oauth_manager()
1041
+
1042
+ # Check if this is a remote server
1043
+ if not server_config.is_remote_server():
1044
+ raise HTTPException(
1045
+ status_code=400,
1046
+ detail=f"Server {server_name} is a local server and does not support OAuth",
1047
+ )
1048
+
1049
+ # Get the remote URL (now guaranteed to be non-None for remote servers)
1050
+ remote_url = server_config.get_remote_url()
1051
+ if not remote_url:
1052
+ raise HTTPException(
1053
+ status_code=400, detail=f"Server {server_name} does not have a valid remote URL"
1054
+ )
1055
+
1056
+ # Refresh OAuth status
1057
+ oauth_info = await oauth_manager.refresh_server_status(server_name, remote_url)
1058
+
1059
+ return {
1060
+ "status": "refreshed",
1061
+ "server_name": oauth_info.server_name,
1062
+ "oauth_status": oauth_info.status.value,
1063
+ "error_message": oauth_info.error_message,
1064
+ "token_expires_at": oauth_info.token_expires_at,
1065
+ "has_refresh_token": oauth_info.has_refresh_token,
1066
+ "scopes": oauth_info.scopes,
1067
+ }
1068
+
1069
+ except HTTPException:
1070
+ raise
1071
+ except Exception as e:
1072
+ log.error(f"Failed to refresh OAuth status for {server_name}: {e}")
1073
+ raise HTTPException(
1074
+ status_code=500,
1075
+ detail=f"Failed to refresh OAuth status: {str(e)}",
1076
+ ) from e