open-edison 0.1.17__py3-none-any.whl → 0.1.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
src/server.py CHANGED
@@ -9,36 +9,37 @@ import asyncio
9
9
  import json
10
10
  import traceback
11
11
  from collections.abc import Awaitable, Callable, Coroutine
12
+ from contextlib import suppress
12
13
  from pathlib import Path
13
- from typing import Any, cast
14
+ from typing import Any, Literal, cast
14
15
 
15
16
  import uvicorn
16
17
  from fastapi import Depends, FastAPI, HTTPException, status
17
18
  from fastapi.middleware.cors import CORSMiddleware
18
- from fastapi.responses import FileResponse, JSONResponse, Response
19
+ from fastapi.responses import (
20
+ FileResponse,
21
+ JSONResponse,
22
+ RedirectResponse,
23
+ Response,
24
+ StreamingResponse,
25
+ )
19
26
  from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
20
27
  from fastapi.staticfiles import StaticFiles
21
28
  from fastmcp import FastMCP
22
29
  from loguru import logger as log
23
30
  from pydantic import BaseModel, Field
24
31
 
25
- from src.config import MCPServerConfig, config
32
+ from src import events
33
+ from src.config import Config, MCPServerConfig
26
34
  from src.config import get_config_dir as _get_cfg_dir # type: ignore[attr-defined]
27
35
  from src.middleware.session_tracking import (
28
36
  MCPSessionModel,
29
37
  create_db_session,
30
38
  )
39
+ from src.oauth_manager import OAuthStatus, get_oauth_manager
31
40
  from src.single_user_mcp import SingleUserMCP
32
41
  from src.telemetry import initialize_telemetry, set_servers_installed
33
42
 
34
-
35
- def _get_current_config():
36
- """Get current config, allowing for test mocking."""
37
- from src.config import config as current_config
38
-
39
- return current_config
40
-
41
-
42
43
  # Module-level dependency singletons
43
44
  _security = HTTPBearer()
44
45
  _auth_dependency = Depends(_security)
@@ -102,6 +103,15 @@ class OpenEdisonProxy:
102
103
  StaticFiles(directory=str(assets_dir), html=False),
103
104
  name="dashboard-assets",
104
105
  )
106
+ # Serve service worker at root path for registration at /sw.js
107
+ sw_path = static_dir / "sw.js"
108
+ if sw_path.exists():
109
+
110
async def _sw() -> FileResponse:  # type: ignore[override]
    """Return the dashboard service worker script.

    Exposed at ``/sw.js`` because service workers must be served from the
    origin root scope.
    """
    script_location = str(sw_path)
    return FileResponse(script_location, media_type="application/javascript")
113
+
114
+ app.add_api_route("/sw.js", _sw, methods=["GET"]) # type: ignore[arg-type]
105
115
  favicon_path = static_dir / "favicon.ico"
106
116
  if favicon_path.exists():
107
117
 
@@ -116,49 +126,29 @@ class OpenEdisonProxy:
116
126
  log.warning(f"Failed to mount dashboard static assets: {mount_err}")
117
127
 
118
128
  # Special-case: serve SQLite db and config JSONs for dashboard (prod replacement for Vite @fs)
119
def _resolve_db_path() -> Path:
    """Locate the configured SQLite sessions database on disk.

    ``Config().logging.database_path`` is resolved in this order:

    1. the path itself, when it is absolute;
    2. relative to the config directory (the CWD if that cannot be determined);
    3. relative to the current working directory.

    Returns:
        Path of an existing regular file.

    Raises:
        FileNotFoundError: if no database path is configured or none of the
            candidates is a regular file.
    """
    db_cfg = Config().logging.database_path
    # An empty (or missing) setting would become Path("."), which exists as a
    # directory and would otherwise be handed to FileResponse.
    if not db_cfg:
        raise FileNotFoundError("Database path is not configured")
    db_path = Path(db_cfg)

    if db_path.is_absolute():
        # Joining an absolute path onto another directory yields the path itself,
        # so the relative candidates would add nothing.
        candidates = [db_path]
    else:
        try:
            cfg_dir = _get_cfg_dir()
        except Exception:
            cfg_dir = Path.cwd()
        candidates = [cfg_dir / db_path, Path.cwd() / db_path]

    for candidate in candidates:
        # is_file(), not exists(): a directory must never be served as the DB.
        if candidate.is_file():
            return candidate

    raise FileNotFoundError(f"Database file not found at {db_path}")
157
149
 
async def _serve_db() -> FileResponse:  # type: ignore[override]
    """Serve the sessions SQLite database file to the dashboard.

    Returns:
        The database as ``application/octet-stream``.

    Raises:
        HTTPException: 404 when ``_resolve_db_path`` cannot find the database.
    """
    try:
        db_file = _resolve_db_path()
    except FileNotFoundError as e:
        # _resolve_db_path raises instead of returning None; without this handler
        # a missing database would surface as a 500 rather than a 404.
        raise HTTPException(status_code=404, detail="Database file not found") from e
    return FileResponse(str(db_file), media_type="application/octet-stream")
163
153
 
164
154
  # Provide multiple paths the SPA might attempt (both edison.db legacy and sessions.db canonical)
@@ -200,12 +190,10 @@ class OpenEdisonProxy:
200
190
  # 2) Repository/package defaults next to src/
201
191
  repo_candidate = Path(__file__).parent.parent / filename
202
192
  if repo_candidate.exists():
203
- # Bootstrap a copy into config dir when possible
204
- try:
193
+ # Bootstrap a copy into config dir when possible (best effort)
194
+ with suppress(Exception):
205
195
  target.parent.mkdir(parents=True, exist_ok=True)
206
196
  target.write_text(repo_candidate.read_text(encoding="utf-8"), encoding="utf-8")
207
- except Exception:
208
- pass
209
197
  return target if target.exists() else repo_candidate
210
198
 
211
199
  # 3) Fall back to config dir path (will be created on save)
@@ -268,6 +256,46 @@ class OpenEdisonProxy:
268
256
 
269
257
  app.add_api_route("/__save_json__", _save_json, methods=["POST"]) # type: ignore[arg-type]
270
258
 
259
+ # SSE events endpoint
260
async def _events() -> StreamingResponse:  # type: ignore[override]
    """Open a Server-Sent Events stream for the dashboard.

    A fresh subscriber queue is obtained from ``events.subscribe()`` and its
    contents are streamed through ``events.sse_stream`` as ``text/event-stream``.
    """
    subscriber_queue = await events.subscribe()
    stream_body = events.sse_stream(subscriber_queue)
    return StreamingResponse(stream_body, media_type="text/event-stream")
266
+
267
+ app.add_api_route("/events", _events, methods=["GET"]) # type: ignore[arg-type]
268
+
269
+ # Approval endpoint to allow an item for the rest of the session
270
class _ApprovalBody(BaseModel):
    # JSON payload accepted by POST /api/approve.
    session_id: str  # session the one-time approval applies to
    kind: Literal["tool", "resource", "prompt"]  # category of the approved item
    name: str  # name of the tool / resource / prompt being approved
274
+
275
async def _approve(body: _ApprovalBody) -> dict[str, Any]:  # type: ignore[override]
    """Approve one tool/resource/prompt for a session and notify listeners.

    The approval is recorded once via ``events.approve_once`` (no persistent
    override), then an ``mcp_approved_once`` event is emitted with
    ``events.fire_and_forget``.

    Returns:
        ``{"status": "ok"}`` on success.

    Raises:
        HTTPException: re-raised unchanged, or a 500 wrapping any other failure.
    """
    try:
        sid, item_kind, item_name = body.session_id, body.kind, body.name
        await events.approve_once(sid, item_kind, item_name)

        # Best-effort broadcast so connected dashboards can react.
        notification = {
            "type": "mcp_approved_once",
            "session_id": sid,
            "kind": item_kind,
            "name": item_name,
        }
        events.fire_and_forget(notification)
    except HTTPException:
        raise
    except Exception as err:  # noqa: BLE001
        log.error(f"Approval failed: {err}")
        raise HTTPException(status_code=500, detail="Failed to approve item") from err
    return {"status": "ok"}
296
+
297
+ app.add_api_route("/api/approve", _approve, methods=["POST"]) # type: ignore[arg-type]
298
+
271
299
  # Catch-all for @fs patterns; serve known db and json filenames
272
300
  async def _serve_fs_path(rest: str): # type: ignore[override]
273
301
  target = rest.strip("/")
@@ -282,6 +310,12 @@ class OpenEdisonProxy:
282
310
  app.add_api_route("/@fs/{rest:path}", _serve_fs_path, methods=["GET"]) # type: ignore[arg-type]
283
311
  app.add_api_route("/%40fs/{rest:path}", _serve_fs_path, methods=["GET"]) # type: ignore[arg-type]
284
312
 
313
+ # Redirect root to dashboard
314
async def _root_redirect() -> RedirectResponse:  # type: ignore[override]
    """Send requests for ``/`` to the dashboard UI."""
    dashboard_url = "/dashboard"
    return RedirectResponse(url=dashboard_url)
316
+
317
+ app.add_api_route("/", _root_redirect, methods=["GET"]) # type: ignore[arg-type]
318
+
285
319
  return app
286
320
 
287
321
  def _build_backend_config_top(
@@ -315,7 +349,7 @@ class OpenEdisonProxy:
315
349
  await self.single_user_mcp.initialize()
316
350
 
317
351
  # Emit snapshot of enabled servers
318
- enabled_count = len([s for s in config.mcp_servers if s.enabled])
352
+ enabled_count = len([s for s in Config().mcp_servers if s.enabled])
319
353
  set_servers_installed(enabled_count)
320
354
 
321
355
  # Add CORS middleware to FastAPI
@@ -335,7 +369,7 @@ class OpenEdisonProxy:
335
369
  app=self.fastapi_app,
336
370
  host=self.host,
337
371
  port=self.port + 1,
338
- log_level=config.logging.level.lower(),
372
+ log_level=Config().logging.level.lower(),
339
373
  )
340
374
  fastapi_server = uvicorn.Server(fastapi_config)
341
375
  servers_to_run.append(fastapi_server.serve())
@@ -346,7 +380,7 @@ class OpenEdisonProxy:
346
380
  app=mcp_app,
347
381
  host=self.host,
348
382
  port=self.port,
349
- log_level=config.logging.level.lower(),
383
+ log_level=Config().logging.level.lower(),
350
384
  )
351
385
  fastmcp_server = uvicorn.Server(fastmcp_config)
352
386
  servers_to_run.append(fastmcp_server.serve())
@@ -363,7 +397,12 @@ class OpenEdisonProxy:
363
397
  "/mcp/status",
364
398
  self.mcp_status,
365
399
  methods=["GET"],
366
- dependencies=[Depends(self.verify_api_key)],
400
+ )
401
+ # Endpoint to notify server that permissions JSONs changed; invalidate caches
402
+ app.add_api_route(
403
+ "/api/permissions-changed",
404
+ self.permissions_changed,
405
+ methods=["POST"],
367
406
  )
368
407
  app.add_api_route(
369
408
  "/mcp/validate",
@@ -377,6 +416,24 @@ class OpenEdisonProxy:
377
416
  methods=["GET"],
378
417
  dependencies=[Depends(self.verify_api_key)],
379
418
  )
419
+ app.add_api_route(
420
+ "/mcp/reinitialize",
421
+ self.reinitialize_mcp_servers,
422
+ methods=["POST"],
423
+ dependencies=[Depends(self.verify_api_key)],
424
+ )
425
+ app.add_api_route(
426
+ "/mcp/mount/{server_name}",
427
+ self.mount_mcp_server,
428
+ methods=["POST"],
429
+ dependencies=[Depends(self.verify_api_key)],
430
+ )
431
+ app.add_api_route(
432
+ "/mcp/mount/{server_name}",
433
+ self.unmount_mcp_server,
434
+ methods=["DELETE"],
435
+ dependencies=[Depends(self.verify_api_key)],
436
+ )
380
437
  # Public sessions endpoint (no auth) for simple local dashboard
381
438
  app.add_api_route(
382
439
  "/sessions",
@@ -384,6 +441,38 @@ class OpenEdisonProxy:
384
441
  methods=["GET"],
385
442
  )
386
443
 
444
+ # OAuth endpoints
445
+ app.add_api_route(
446
+ "/mcp/oauth/status",
447
+ self.get_oauth_status_all,
448
+ methods=["GET"],
449
+ dependencies=[Depends(self.verify_api_key)],
450
+ )
451
+ app.add_api_route(
452
+ "/mcp/oauth/status/{server_name}",
453
+ self.get_oauth_status,
454
+ methods=["GET"],
455
+ dependencies=[Depends(self.verify_api_key)],
456
+ )
457
+ app.add_api_route(
458
+ "/mcp/oauth/test-connection/{server_name}",
459
+ self.oauth_test_connection,
460
+ methods=["POST"],
461
+ dependencies=[Depends(self.verify_api_key)],
462
+ )
463
+ app.add_api_route(
464
+ "/mcp/oauth/tokens/{server_name}",
465
+ self.oauth_clear_tokens,
466
+ methods=["DELETE"],
467
+ dependencies=[Depends(self.verify_api_key)],
468
+ )
469
+ app.add_api_route(
470
+ "/mcp/oauth/refresh/{server_name}",
471
+ self.oauth_refresh_status,
472
+ methods=["POST"],
473
+ dependencies=[Depends(self.verify_api_key)],
474
+ )
475
+
387
476
  async def verify_api_key(
388
477
  self, credentials: HTTPAuthorizationCredentials = _auth_dependency
389
478
  ) -> str:
@@ -392,11 +481,27 @@ class OpenEdisonProxy:
392
481
 
393
482
  Returns the API key string if valid, otherwise raises HTTPException.
394
483
  """
395
- current_config = _get_current_config()
396
- if credentials.credentials != current_config.server.api_key:
484
+ if credentials.credentials != Config().server.api_key:
397
485
  raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key")
398
486
  return credentials.credentials
399
487
 
488
async def permissions_changed(self) -> dict[str, Any]:
    """Re-list SingleUserMCP tools, resources and prompts after a permissions change.

    Calls ``list_tools``, ``list_resources`` and ``list_prompts`` on the internal
    managers of ``self.single_user_mcp``, in that order. No cache-clearing method
    is invoked. Whether re-listing actually refreshes any cached permission state
    depends on the managers' implementations, which are not visible here.
    NOTE(review): confirm this is sufficient to pick up edited permission JSONs.

    Returns:
        ``{"status": "ok"}`` once all three listings complete.

    Raises:
        HTTPException: 500 if any of the listings fails.
    """
    try:
        mcp = self.single_user_mcp
        # Private manager attributes: re-listing is the only refresh hook used here.
        await mcp._tool_manager.list_tools()  # type: ignore[attr-defined]
        await mcp._resource_manager.list_resources()  # type: ignore[attr-defined]
        await mcp._prompt_manager.list_prompts()  # type: ignore[attr-defined]
    except Exception as e:  # noqa: BLE001
        log.error(f"Failed to process permissions-changed: {e}")
        raise HTTPException(status_code=500, detail="Failed to invalidate caches") from e
    return {"status": "ok"}
504
+
400
505
  async def mcp_status(self) -> dict[str, list[dict[str, Any]]]:
401
506
  """Get status of configured MCP servers (auth required)."""
402
507
  return {
@@ -405,34 +510,13 @@ class OpenEdisonProxy:
405
510
  "name": server.name,
406
511
  "enabled": server.enabled,
407
512
  }
408
- for server in config.mcp_servers
513
+ for server in Config().mcp_servers
409
514
  ]
410
515
  }
411
516
 
412
- def _handle_server_operation_error(
413
- self, operation: str, server_name: str, error: Exception
414
- ) -> HTTPException:
415
- """Handle common server operation errors."""
416
- log.error(f"Failed to {operation} server {server_name}: {error}")
417
- return HTTPException(
418
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
419
- detail=f"Failed to {operation} server: {str(error)}",
420
- )
421
-
422
- def _find_server_config(self, server_name: str) -> MCPServerConfig:
423
- """Find server configuration by name."""
424
- current_config = _get_current_config()
425
- for config_server in current_config.mcp_servers:
426
- if config_server.name == server_name:
427
- return config_server
428
- raise HTTPException(
429
- status_code=404,
430
- detail=f"Server configuration not found: {server_name}",
431
- )
432
-
async def health_check(self) -> dict[str, Any]:
    """Report liveness along with the number of configured MCP servers."""
    configured_servers = len(Config().mcp_servers)
    return {"status": "healthy", "version": "0.1.0", "mcp_servers": configured_servers}
436
520
 
437
521
  async def get_mounted_servers(self) -> dict[str, Any]:
438
522
  """Get list of currently mounted MCP servers."""
@@ -446,6 +530,65 @@ class OpenEdisonProxy:
446
530
  detail=f"Failed to get mounted servers: {str(e)}",
447
531
  ) from e
448
532
 
533
async def reinitialize_mcp_servers(self) -> dict[str, Any]:
    """Replace ``self.single_user_mcp`` with a freshly initialized instance.

    A new ``SingleUserMCP`` is created and ``initialize()``d (which, per the
    original description, reloads config). It replaces the current instance only
    after initialization succeeds, so a failure leaves the previously working
    instance in place. NOTE(review): the old instance is simply dropped; no
    explicit shutdown is performed (none is visible in this file).

    Returns:
        ``{"status": "ok", "total_final_mounted": <count>, "mounted_servers": [names]}``.
        If listing mounted servers fails, the count is 0 and the list is empty.

    Raises:
        HTTPException: 500 if creating or initializing the new instance fails.
    """
    try:
        log.info("🔄 Reinitializing MCP servers via API endpoint")

        replacement = SingleUserMCP()
        await replacement.initialize()
        # Swap only after a successful initialize() so a failure keeps the old instance.
        self.single_user_mcp = replacement

        try:
            mounted = await self.single_user_mcp.get_mounted_servers()
        except Exception as list_err:  # noqa: BLE001
            # Summary is best effort; keep the reason in the log.
            log.error(f"Failed to get mounted servers: {list_err}")
            mounted = []

        names = [m.get("name", "") for m in mounted]
        return {
            "status": "ok",
            "total_final_mounted": len(mounted),
            "mounted_servers": names,
        }

    except Exception as e:
        log.error(f"❌ Failed to reinitialize MCP servers: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to reinitialize MCP servers: {str(e)}",
        ) from e
569
+
570
async def mount_mcp_server(self, server_name: str) -> dict[str, Any]:
    """Mount one MCP server by name (the route is registered with API-key auth).

    Returns:
        ``{"mounted": <truthiness of mount_server's result>, "server": server_name}``.

    Raises:
        HTTPException: 500 wrapping any error raised while mounting.
    """
    try:
        succeeded = bool(await self.single_user_mcp.mount_server(server_name))
    except Exception as err:
        log.error(f"❌ Failed to mount server {server_name}: {err}")
        raise HTTPException(
            status_code=500, detail=f"Failed to mount server {server_name}: {err!s}"
        ) from err
    return {"mounted": succeeded, "server": server_name}
580
+
581
async def unmount_mcp_server(self, server_name: str) -> dict[str, Any]:
    """Unmount one MCP server by name (the route is registered with API-key auth).

    Returns:
        ``{"unmounted": <truthiness of unmount's result>, "server": server_name}``.

    Raises:
        HTTPException: 500 wrapping any error raised while unmounting.
    """
    try:
        removed = bool(await self.single_user_mcp.unmount(server_name))
    except Exception as err:
        log.error(f"❌ Failed to unmount server {server_name}: {err}")
        raise HTTPException(
            status_code=500, detail=f"Failed to unmount server {server_name}: {err!s}"
        ) from err
    return {"unmounted": removed, "server": server_name}
591
+
449
592
  async def get_sessions(self) -> dict[str, Any]:
450
593
  """Return recent MCP session summaries from local SQLite.
451
594
 
@@ -551,9 +694,9 @@ class OpenEdisonProxy:
551
694
  "args": body.args,
552
695
  "has_roots": bool(body.roots),
553
696
  },
554
- "tools": [self._safe_tool(t) for t in tools],
697
+ "tools": [self._safe_tool(t, prefix=server_name) for t in tools],
555
698
  "resources": [self._safe_resource(r) for r in resources],
556
- "prompts": [self._safe_prompt(p) for p in prompts],
699
+ "prompts": [self._safe_prompt(p, prefix=server_name) for p in prompts],
557
700
  }
558
701
  except TimeoutError as te: # noqa: PERF203
559
702
  log.error(f"MCP validation timed out: {te}\n{traceback.format_exc()}")
@@ -590,18 +733,6 @@ class OpenEdisonProxy:
590
733
  except Exception as cleanup_err: # noqa: BLE001
591
734
  log.debug(f"Validator cleanup skipped/failed: {cleanup_err}")
592
735
 
593
- def _build_backend_config(
594
- self, server_name: str, body: "OpenEdisonProxy._ValidateRequest"
595
- ) -> dict[str, Any]:
596
- backend_entry: dict[str, Any] = {
597
- "command": body.command,
598
- "args": body.args,
599
- "env": body.env or {},
600
- }
601
- if body.roots:
602
- backend_entry["roots"] = body.roots
603
- return {"mcpServers": {server_name: backend_entry}}
604
-
605
736
  async def _list_all_capabilities(
606
737
  self, server: FastMCP[Any], body: "OpenEdisonProxy._ValidateRequest"
607
738
  ) -> tuple[list[Any], list[Any], list[Any]]:
@@ -624,10 +755,13 @@ class OpenEdisonProxy:
624
755
  timeout = body.timeout_s if isinstance(body.timeout_s, (int | float)) else 20.0
625
756
  return await asyncio.wait_for(list_all(), timeout=timeout)
626
757
 
627
- def _safe_tool(self, t: Any) -> dict[str, Any]:
758
+ def _safe_tool(self, t: Any, prefix: str) -> dict[str, Any]:
628
759
  name = getattr(t, "name", None)
629
760
  description = getattr(t, "description", None)
630
- return {"name": str(name) if name is not None else "", "description": description}
761
+ return {
762
+ "name": prefix + "_" + str(name) if name is not None else "",
763
+ "description": description,
764
+ }
631
765
 
632
766
  def _safe_resource(self, r: Any) -> dict[str, Any]:
633
767
  uri = getattr(r, "uri", None)
@@ -638,7 +772,300 @@ class OpenEdisonProxy:
638
772
  description = getattr(r, "description", None)
639
773
  return {"uri": uri_str, "description": description}
640
774
 
641
- def _safe_prompt(self, p: Any) -> dict[str, Any]:
775
+ def _safe_prompt(self, p: Any, prefix: str) -> dict[str, Any]:
642
776
  name = getattr(p, "name", None)
643
777
  description = getattr(p, "description", None)
644
- return {"name": str(name) if name is not None else "", "description": description}
778
+ return {
779
+ "name": prefix + "_" + str(name) if name is not None else "",
780
+ "description": description,
781
+ }
782
+
783
+ # ---- OAuth endpoints ----
784
+
785
async def get_oauth_status_all(self) -> dict[str, Any]:
    """Report the OAuth status of every configured MCP server.

    For each server in ``Config().mcp_servers``:

    * cached info from the OAuth manager is used when present;
    * otherwise a remote server is checked on the spot with
      ``check_oauth_requirement``;
    * a local server is reported as ``OAuthStatus.NOT_REQUIRED``.

    Returns:
        ``{"oauth_status": {server_name: status_dict, ...}}``.

    Raises:
        HTTPException: 500 if any lookup or check fails.
    """

    def _describe(info: Any) -> dict[str, Any]:
        # Common projection of an OAuth info object into the response shape.
        return {
            "server_name": info.server_name,
            "status": info.status.value,
            "error_message": info.error_message,
            "token_expires_at": info.token_expires_at,
            "has_refresh_token": info.has_refresh_token,
            "scopes": info.scopes,
        }

    try:
        manager = get_oauth_manager()
        statuses: dict[str, dict[str, Any]] = {}

        for cfg in Config().mcp_servers:
            name = cfg.name
            cached = manager.get_server_info(name)
            if cached:
                statuses[name] = _describe(cached)
                continue

            if not cfg.is_remote_server():
                # Local servers never need OAuth.
                statuses[name] = {
                    "server_name": name,
                    "status": OAuthStatus.NOT_REQUIRED.value,
                    "error_message": None,
                    "token_expires_at": None,
                    "has_refresh_token": False,
                    "scopes": None,
                }
                continue

            # Not checked yet: probe the remote server now.
            url = cfg.get_remote_url()
            log.info(f"🔍 Proactively checking OAuth for remote server {name}")
            checked = await manager.check_oauth_requirement(name, url)
            statuses[name] = _describe(checked)

        return {"oauth_status": statuses}

    except Exception as err:
        log.error(f"Failed to get OAuth status for all servers: {err}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get OAuth status: {err!s}",
        ) from err
843
+
844
def _find_server_config(self, server_name: str) -> MCPServerConfig:
    """Look up a configured MCP server by name.

    Args:
        server_name: Name compared against each entry of ``Config().mcp_servers``.

    Returns:
        The first server config whose ``name`` equals ``server_name``.

    Raises:
        HTTPException: 404 when no configured server has that name.
    """
    found = next(
        (srv for srv in Config().mcp_servers if srv.name == server_name),
        None,
    )
    if found is not None:
        return found
    raise HTTPException(
        status_code=404,
        detail=f"Server configuration not found: {server_name}",
    )
853
+
854
async def get_oauth_status(self, server_name: str) -> dict[str, Any]:
    """Check (or re-check) the OAuth requirement of one configured server.

    Args:
        server_name: Name of a configured MCP server.

    Returns:
        The OAuth info fields, including ``mcp_url`` and ``client_name``.

    Raises:
        HTTPException: 404 for an unknown server (from ``_find_server_config``),
            500 for any other failure.
    """
    try:
        cfg = self._find_server_config(server_name)
        manager = get_oauth_manager()

        # Possibly falsy for non-remote servers; passed to the manager unchanged.
        url = cfg.get_remote_url()
        info = await manager.check_oauth_requirement(server_name, url)

        payload: dict[str, Any] = {
            "server_name": info.server_name,
            "mcp_url": info.mcp_url,
            "status": info.status.value,
            "error_message": info.error_message,
            "token_expires_at": info.token_expires_at,
            "has_refresh_token": info.has_refresh_token,
            "scopes": info.scopes,
            "client_name": info.client_name,
        }
        return payload

    except HTTPException:
        raise
    except Exception as err:
        log.error(f"Failed to get OAuth status for {server_name}: {err}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get OAuth status: {err!s}",
        ) from err
885
+
886
class _OAuthAuthorizeRequest(BaseModel):
    # Optional overrides for the OAuth connection-test endpoint; empty values
    # fall back to the server's configured OAuth settings.
    scopes: list[str] | None = Field(
        default=None, description="OAuth scopes to request"
    )
    client_name: str | None = Field(
        default=None, description="Client name for OAuth registration"
    )
889
+
890
async def oauth_test_connection(
    self, server_name: str, body: _OAuthAuthorizeRequest | None = None
) -> dict[str, Any]:
    """Test the connection to a remote MCP server, triggering OAuth if needed.

    Builds a temporary FastMCP client with an ``OAuth`` auth object whose token
    cache lives in ``oauth_manager.cache_dir``, then pings the server. Per the
    original notes, FastMCP's OAuth flow may open a browser for authorization.

    Args:
        server_name: Name of a configured MCP server.
        body: Optional scopes / client name. Empty values fall back to the
            server's ``oauth_scopes`` / ``oauth_client_name`` settings.

    Returns:
        ``status == "connection_successful"`` after a successful ping, or
        ``status == "oauth_required"`` when the connection error mentions
        "oauth" or "authorization".

    Raises:
        HTTPException: 404 for an unknown server; 400 for a local server or a
            missing remote URL; 500 for any other connection or setup failure.
    """
    try:
        server_config = self._find_server_config(server_name)
        oauth_manager = get_oauth_manager()

        if not server_config.is_remote_server():
            raise HTTPException(
                status_code=400,
                detail=f"Server {server_name} is a local server and does not support OAuth",
            )

        remote_url = server_config.get_remote_url()
        if not remote_url:
            raise HTTPException(
                status_code=400, detail=f"Server {server_name} does not have a valid remote URL"
            )

        # Request values take precedence; fall back to the server's configuration.
        scopes = body.scopes if body else None
        client_name = body.client_name if body else None
        if not scopes and server_config.oauth_scopes:
            scopes = server_config.oauth_scopes
        if not client_name and server_config.oauth_client_name:
            client_name = server_config.oauth_client_name

        log.info(f"🔗 Testing connection to {server_name} at {remote_url}")

        from fastmcp import Client as FastMCPClient
        from fastmcp.client.auth import OAuth

        oauth = OAuth(
            mcp_url=remote_url,
            scopes=scopes,
            client_name=client_name or "OpenEdison MCP Gateway",
            token_storage_cache_dir=oauth_manager.cache_dir,
        )

        # Guard only the connection itself, so that later failures (e.g. the status
        # refresh below) are not misreported as connection errors.
        try:
            async with FastMCPClient(remote_url, auth=oauth) as client:
                # The ping triggers the OAuth flow if no tokens are cached yet.
                log.info(
                    f"🔐 Attempting to connect to {server_name} (may open browser for OAuth)..."
                )
                await client.ping()
        except Exception as e:
            log.error(f"❌ Failed to connect to {server_name}: {e}")

            error_message = str(e)
            lowered = error_message.lower()
            if "oauth" in lowered or "authorization" in lowered:
                # NOTE(review): the status says "oauth_required" but the message
                # says authorization "completed"; confirm which the dashboard expects.
                return {
                    "status": "oauth_required",
                    "message": f"OAuth authorization completed for {server_name}. Please try connecting again.",
                    "server_name": server_name,
                }
            raise HTTPException(
                status_code=500, detail=f"Connection test failed: {error_message}"
            ) from None

        log.info(f"✅ Successfully connected to {server_name}")

        # Best effort: the connection already succeeded, so a failure to refresh the
        # cached OAuth status must not turn this into a failed connection test.
        try:
            await oauth_manager.check_oauth_requirement(server_name, remote_url)
        except Exception as status_err:  # noqa: BLE001
            log.warning(
                f"Connected to {server_name} but failed to refresh its OAuth status: {status_err}"
            )

        return {
            "status": "connection_successful",
            "message": f"Successfully connected to {server_name}. OAuth tokens are now cached.",
            "server_name": server_name,
        }

    except HTTPException:
        raise
    except Exception as e:
        log.error(f"Failed to test connection for {server_name}: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to test connection: {str(e)}",
        ) from e
989
+
990
async def oauth_clear_tokens(self, server_name: str) -> dict[str, Any]:
    """Delete the cached OAuth tokens of a remote MCP server.

    Args:
        server_name: Name of a configured MCP server.

    Returns:
        A ``"success"`` status dict when the OAuth manager reports the tokens cleared.

    Raises:
        HTTPException: 404 for an unknown server; 400 for a local server or a
            missing remote URL; 500 if clearing reports failure or anything else
            goes wrong.
    """
    try:
        cfg = self._find_server_config(server_name)
        manager = get_oauth_manager()

        if not cfg.is_remote_server():
            raise HTTPException(
                status_code=400,
                detail=f"Server {server_name} is a local server and does not support OAuth",
            )

        url = cfg.get_remote_url()
        if not url:
            raise HTTPException(
                status_code=400, detail=f"Server {server_name} does not have a valid remote URL"
            )

        cleared = manager.clear_tokens(server_name, url)
        if not cleared:
            raise HTTPException(
                status_code=500, detail=f"Failed to clear OAuth tokens for {server_name}"
            )

        return {
            "status": "success",
            "message": f"OAuth tokens cleared for {server_name}",
            "server_name": server_name,
        }

    except HTTPException:
        raise
    except Exception as err:
        log.error(f"Failed to clear OAuth tokens for {server_name}: {err}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to clear OAuth tokens: {err!s}",
        ) from err
1030
+
1031
async def oauth_refresh_status(self, server_name: str) -> dict[str, Any]:
    """Force a fresh OAuth status check for a remote MCP server.

    Args:
        server_name: Name of a configured MCP server.

    Returns:
        A dict with ``"status": "refreshed"`` plus the refreshed OAuth fields;
        the OAuth state itself is reported under ``"oauth_status"``.

    Raises:
        HTTPException: 404 for an unknown server; 400 for a local server or a
            missing remote URL; 500 for any other failure.
    """
    try:
        cfg = self._find_server_config(server_name)
        manager = get_oauth_manager()

        if not cfg.is_remote_server():
            raise HTTPException(
                status_code=400,
                detail=f"Server {server_name} is a local server and does not support OAuth",
            )

        # A remote server may still lack a usable URL, so it is checked explicitly.
        url = cfg.get_remote_url()
        if not url:
            raise HTTPException(
                status_code=400, detail=f"Server {server_name} does not have a valid remote URL"
            )

        refreshed = await manager.refresh_server_status(server_name, url)
        return {
            "status": "refreshed",
            "server_name": refreshed.server_name,
            "oauth_status": refreshed.status.value,
            "error_message": refreshed.error_message,
            "token_expires_at": refreshed.token_expires_at,
            "has_refresh_token": refreshed.has_refresh_token,
            "scopes": refreshed.scopes,
        }

    except HTTPException:
        raise
    except Exception as err:
        log.error(f"Failed to refresh OAuth status for {server_name}: {err}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to refresh OAuth status: {err!s}",
        ) from err