mcpower-proxy 0.0.58__py3-none-any.whl → 0.0.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. ide_tools/__init__.py +12 -0
  2. ide_tools/common/__init__.py +6 -0
  3. ide_tools/common/hooks/__init__.py +6 -0
  4. ide_tools/common/hooks/init.py +125 -0
  5. ide_tools/common/hooks/output.py +64 -0
  6. ide_tools/common/hooks/prompt_submit.py +186 -0
  7. ide_tools/common/hooks/read_file.py +170 -0
  8. ide_tools/common/hooks/shell_execution.py +196 -0
  9. ide_tools/common/hooks/types.py +35 -0
  10. ide_tools/common/hooks/utils.py +276 -0
  11. ide_tools/cursor/__init__.py +11 -0
  12. ide_tools/cursor/constants.py +58 -0
  13. ide_tools/cursor/format.py +35 -0
  14. ide_tools/cursor/router.py +100 -0
  15. ide_tools/router.py +48 -0
  16. main.py +11 -4
  17. {mcpower_proxy-0.0.58.dist-info → mcpower_proxy-0.0.73.dist-info}/METADATA +15 -3
  18. mcpower_proxy-0.0.73.dist-info/RECORD +59 -0
  19. {mcpower_proxy-0.0.58.dist-info → mcpower_proxy-0.0.73.dist-info}/top_level.txt +1 -0
  20. modules/apis/security_policy.py +11 -6
  21. modules/decision_handler.py +219 -0
  22. modules/logs/audit_trail.py +22 -17
  23. modules/logs/logger.py +14 -18
  24. modules/redaction/redactor.py +112 -107
  25. modules/ui/__init__.py +1 -1
  26. modules/ui/confirmation.py +0 -1
  27. modules/utils/cli.py +36 -6
  28. modules/utils/ids.py +55 -10
  29. modules/utils/json.py +3 -3
  30. wrapper/__version__.py +1 -1
  31. wrapper/middleware.py +121 -210
  32. wrapper/server.py +19 -11
  33. mcpower_proxy-0.0.58.dist-info/RECORD +0 -43
  34. {mcpower_proxy-0.0.58.dist-info → mcpower_proxy-0.0.73.dist-info}/WHEEL +0 -0
  35. {mcpower_proxy-0.0.58.dist-info → mcpower_proxy-0.0.73.dist-info}/entry_points.txt +0 -0
  36. {mcpower_proxy-0.0.58.dist-info → mcpower_proxy-0.0.73.dist-info}/licenses/LICENSE +0 -0
modules/utils/ids.py CHANGED
@@ -2,6 +2,7 @@
2
2
  Utilities for generating event IDs, session IDs, app UIDs, and timing helpers
3
3
  """
4
4
  import os
5
+ import sys
5
6
  import time
6
7
  import uuid
7
8
  from pathlib import Path
@@ -23,6 +24,16 @@ def generate_event_id() -> str:
23
24
  return f"{timestamp}-{unique_part}"
24
25
 
25
26
 
27
+ def generate_prompt_id() -> str:
28
+ """
29
+ Generate truly-random 8-character prompt ID for user request correlation
30
+
31
+ Returns:
32
+ 8-character random ID string
33
+ """
34
+ return str(uuid.uuid4())[:8]
35
+
36
+
26
37
  def get_session_id() -> str:
27
38
  """
28
39
  Get session ID for the current process. Returns the same value for all calls
@@ -67,7 +78,8 @@ def _atomic_write_uuid(file_path: Path, new_uuid: str) -> bool:
67
78
  True if write succeeded, False if file exists
68
79
  """
69
80
  try:
70
- fd = os.open(str(file_path), os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
81
+ mode = 0o666 if sys.platform == 'win32' else 0o600
82
+ fd = os.open(str(file_path), os.O_CREAT | os.O_EXCL | os.O_WRONLY, mode)
71
83
  try:
72
84
  os.write(fd, new_uuid.encode('utf-8'))
73
85
  finally:
@@ -91,7 +103,7 @@ def _get_or_create_uuid(uid_path: Path, logger, id_type: str) -> str:
91
103
  UUID string
92
104
  """
93
105
  uid_path.parent.mkdir(parents=True, exist_ok=True)
94
-
106
+
95
107
  max_attempts = 3
96
108
  for attempt in range(max_attempts):
97
109
  if uid_path.exists():
@@ -107,20 +119,53 @@ def _get_or_create_uuid(uid_path: Path, logger, id_type: str) -> str:
107
119
  time.sleep(0.1 * (2 ** attempt))
108
120
  continue
109
121
  raise
110
-
122
+
111
123
  new_uid = str(uuid.uuid4())
112
-
124
+
113
125
  if _atomic_write_uuid(uid_path, new_uid):
114
- logger.info(f"Generated {id_type}: {new_uid}")
126
+ logger.info(f"Generated {id_type}: {new_uid} at {uid_path}")
115
127
  return new_uid
116
-
117
- logger.debug(f"{id_type.title()} file created by another process, reading (attempt {attempt + 1}/{max_attempts})")
128
+
129
+ logger.debug(
130
+ f"{id_type.title()} file created by another process, reading (attempt {attempt + 1}/{max_attempts})")
118
131
  if attempt < max_attempts - 1:
119
132
  time.sleep(0.05)
120
-
133
+
121
134
  raise RuntimeError(f"Failed to get or create {id_type} after {max_attempts} attempts")
122
135
 
123
136
 
137
+ def get_home_mcpower_dir() -> Path:
138
+ """
139
+ Get the global MCPower directory path in user's home directory
140
+
141
+ Returns:
142
+ Path to ~/.mcpower directory
143
+ """
144
+ return Path.home() / ".mcpower"
145
+
146
+
147
+ def get_project_mcpower_dir(project_path: Optional[str] = None) -> str:
148
+ """
149
+ Get the MCPower directory path, with fallback to global ~/.mcpower
150
+
151
+ Args:
152
+ project_path: Optional project/workspace path. If None or invalid, falls back to ~/.mcpower
153
+
154
+ Returns:
155
+ Path to use for MCPower data (either project/.mcpower or ~/.mcpower)
156
+ """
157
+ if project_path:
158
+ try:
159
+ path = Path(project_path)
160
+ if path.exists() and path.is_dir():
161
+ return str(path)
162
+ except Exception:
163
+ pass
164
+
165
+ # Fallback to global ~/.mcpower
166
+ return str(get_home_mcpower_dir())
167
+
168
+
124
169
  def get_or_create_user_id(logger) -> str:
125
170
  """
126
171
  Get or create machine-wide user ID from ~/.mcpower/uid
@@ -132,7 +177,7 @@ def get_or_create_user_id(logger) -> str:
132
177
  Returns:
133
178
  User ID string
134
179
  """
135
- uid_path = Path.home() / ".mcpower" / "uid"
180
+ uid_path = get_home_mcpower_dir() / "uid"
136
181
  return _get_or_create_uuid(uid_path, logger, "user ID")
137
182
 
138
183
 
@@ -156,5 +201,5 @@ def read_app_uid(logger, project_folder_path: str) -> str:
156
201
  else:
157
202
  # Project-specific case
158
203
  uid_path = project_path / ".mcpower" / "app_uid"
159
-
204
+
160
205
  return _get_or_create_uuid(uid_path, logger, "app UID")
modules/utils/json.py CHANGED
@@ -52,7 +52,7 @@ def safe_json_dumps(obj: Any, **kwargs) -> str:
52
52
  # If it's a Pydantic BaseModel, use its built-in JSON serialization
53
53
  if isinstance(obj, BaseModel):
54
54
  return obj.model_dump_json(**kwargs)
55
-
55
+
56
56
  # If it's a dict or list that might contain Pydantic objects, use custom serializer
57
57
  def default_serializer(o):
58
58
  if isinstance(o, BaseModel):
@@ -72,7 +72,7 @@ def safe_json_dumps(obj: Any, **kwargs) -> str:
72
72
  return o.__dict__
73
73
  # Fallback to string representation
74
74
  return str(o)
75
-
75
+
76
76
  return json.dumps(obj, default=default_serializer, **kwargs)
77
77
 
78
78
 
@@ -117,4 +117,4 @@ def parse_jsonc(text: str) -> Any:
117
117
  return json.loads(text)
118
118
  except json.JSONDecodeError:
119
119
  # Re-raise the original JSONC error if JSON also fails
120
- raise json.JSONDecodeError(f"JSONC parsing failed: {str(e)}", text, 0)
120
+ raise json.JSONDecodeError(f"JSONC parsing failed: {str(e)}", text, 0)
wrapper/__version__.py CHANGED
@@ -3,4 +3,4 @@
3
3
  Wrapper MCP Server Version
4
4
  """
5
5
 
6
- __version__ = "0.0.58"
6
+ __version__ = "0.0.73"
wrapper/middleware.py CHANGED
@@ -2,6 +2,8 @@
2
2
  FastMCP middleware for security policy enforcement
3
3
  Implements pre/post interception for all MCP operations
4
4
  """
5
+ import asyncio
6
+ import sys
5
7
  import time
6
8
  import urllib.parse
7
9
  from datetime import datetime, timezone
@@ -11,22 +13,23 @@ from typing import Any, Dict, List, Optional
11
13
  from fastmcp.exceptions import FastMCPError
12
14
  from fastmcp.server.middleware.middleware import Middleware, MiddlewareContext, CallNext
13
15
  from fastmcp.server.proxy import ProxyClient
16
+ from httpx import HTTPStatusError
17
+ from mcp import ErrorData
18
+
19
+ from mcpower_shared.mcp_types import (create_policy_request, create_policy_response, AgentContext, EnvironmentContext,
20
+ InitRequest,
21
+ ServerRef, ToolRef)
14
22
  from modules.apis.security_policy import SecurityPolicyClient
23
+ from modules.decision_handler import DecisionHandler, DecisionEnforcementError
15
24
  from modules.logs.audit_trail import AuditTrailLogger
16
25
  from modules.logs.logger import MCPLogger
17
26
  from modules.redaction import redact
18
- from modules.ui.classes import ConfirmationRequest, DialogOptions, UserDecision
19
- from modules.ui.confirmation import UserConfirmationDialog, UserConfirmationError
20
27
  from modules.utils.copy import safe_copy
21
- from modules.utils.ids import generate_event_id, get_session_id, read_app_uid
28
+ from modules.utils.ids import generate_event_id, get_session_id, read_app_uid, get_project_mcpower_dir
22
29
  from modules.utils.json import safe_json_dumps, to_dict
23
30
  from modules.utils.mcp_configs import extract_wrapped_server_info
24
31
  from wrapper.schema import merge_input_schema_with_existing
25
32
 
26
- from mcpower_shared.mcp_types import (create_policy_request, create_policy_response, AgentContext, EnvironmentContext,
27
- InitRequest,
28
- ServerRef, ToolRef, UserConfirmation)
29
-
30
33
 
31
34
  class MockContext:
32
35
  """Mock context for internal operations"""
@@ -52,9 +55,7 @@ class MockContext:
52
55
  class SecurityMiddleware(Middleware):
53
56
  """FastMCP middleware for security policy enforcement"""
54
57
 
55
- app_id: str = ""
56
58
  _TOOLS_INIT_DEBOUNCE_SECONDS = 60
57
- _last_tools_init_time: Optional[float] = None
58
59
 
59
60
  def __init__(self,
60
61
  wrapped_server_configs: dict,
@@ -69,6 +70,10 @@ class SecurityMiddleware(Middleware):
69
70
  self.logger = logger
70
71
  self.audit_logger = audit_logger
71
72
  self.app_id = ""
73
+ self._last_workspace_root = None
74
+ self._last_tools_init_time: Optional[float] = None
75
+ self._tools_list_in_progress: Optional[asyncio.Task] = None
76
+ self._tools_list_lock = asyncio.Lock()
72
77
 
73
78
  self.wrapped_server_name, self.wrapped_server_transport = (
74
79
  extract_wrapped_server_info(self.wrapper_server_name, self.logger, self.wrapped_server_configs)
@@ -85,18 +90,36 @@ class SecurityMiddleware(Middleware):
85
90
  async def on_message(self, context: MiddlewareContext, call_next: CallNext) -> Any:
86
91
  self.logger.info(f"on_message: {redact(safe_json_dumps(context))}")
87
92
 
88
- # Ensure app_id is set before making API calls
89
- if not self.app_id:
93
+ # Skip workspace check for `initialize` calls to avoid premature app_uid changes.
94
+ # The `initialize` request doesn't contain workspace data, so checking it would
95
+ # cause unnecessary audit log flushes before the actual workspace init arrives.
96
+ if context.method != "initialize":
97
+ # Check workspace roots and re-initialize app_uid if workspace changed
90
98
  workspace_roots = await self._extract_workspace_roots(context)
91
- if workspace_roots:
92
- self.app_id = read_app_uid(logger=self.logger, project_folder_path=workspace_roots[0])
93
- else:
94
- # Fallback: read app_uid from ~/.mcpower when no workspace roots
95
- self.app_id = read_app_uid(logger=self.logger, project_folder_path=str(Path.home() / ".mcpower"))
96
- self.audit_logger.set_app_uid(self.app_id)
99
+ current_workspace_root = get_project_mcpower_dir(workspace_roots[0] if workspace_roots else None)
100
+ if current_workspace_root != self._last_workspace_root:
101
+ self.logger.debug(
102
+ f"Workspace root changed from {self._last_workspace_root} to {current_workspace_root}")
103
+ self._last_workspace_root = current_workspace_root
104
+ self.app_id = read_app_uid(logger=self.logger, project_folder_path=current_workspace_root)
105
+ self.audit_logger.set_app_uid(self.app_id)
97
106
 
98
107
  operation_type = "message"
99
- call_next_callback = call_next
108
+
109
+ async def call_next_wrapper(ctx):
110
+ try:
111
+ return await call_next(ctx)
112
+ except HTTPStatusError as e:
113
+ if e.response.status_code in (401, 403):
114
+ raise FastMCPError(ErrorData(
115
+ code=-32000,
116
+ message="Authentication required",
117
+ data={
118
+ "type": "unauthorized",
119
+ "details": "Please provide valid authentication credentials"
120
+ }
121
+ ))
122
+ raise e
100
123
 
101
124
  match context.type:
102
125
  case "request":
@@ -113,13 +136,13 @@ class SecurityMiddleware(Middleware):
113
136
  operation_type = "prompt"
114
137
  case "tools/list":
115
138
  # Special handling for tools/list - call /init instead of normal inspection
116
- return await self._handle_tools_list(context, call_next)
117
- case "resources/list" | "resources/templates/list" | "prompts/list":
118
- return await call_next_callback(context)
139
+ return await self._handle_tools_list(context, call_next_wrapper)
140
+ case "initialize" | "resources/list" | "resources/templates/list" | "prompts/list":
141
+ return await call_next_wrapper(context)
119
142
 
120
143
  return await self._handle_operation(
121
144
  context=context,
122
- call_next=call_next_callback,
145
+ call_next=call_next_wrapper,
123
146
  error_class=FastMCPError,
124
147
  operation_type=operation_type
125
148
  )
@@ -179,15 +202,15 @@ class SecurityMiddleware(Middleware):
179
202
  return await ProxyClient.default_progress_handler(progress, total, message)
180
203
 
181
204
  async def secure_log_handler(self, log_message):
182
- # FIXME: log_message should be redacted before logging,
205
+ # FIXME: log_message should be redacted before logging,
183
206
  self.logger.info(f"secure_log_handler: {str(log_message)[:100]}...")
184
207
  # FIXME: log_message should be reviewed with policy before forwarding
185
-
208
+
186
209
  # Handle case where log_message.data is a string instead of dict
187
210
  # The default_log_handler expects data to be a dict with 'msg' and 'extra' keys
188
211
  if hasattr(log_message, 'data') and isinstance(log_message.data, str):
189
212
  log_message = safe_copy(log_message, {'data': {'msg': log_message.data, 'extra': None}})
190
-
213
+
191
214
  return await ProxyClient.default_log_handler(log_message)
192
215
 
193
216
  async def _handle_operation(self, context: MiddlewareContext, call_next, error_class, operation_type: str):
@@ -220,19 +243,28 @@ class SecurityMiddleware(Middleware):
220
243
  prompt_id=prompt_id
221
244
  )
222
245
  on_inspect_request_duration = time.time() - on_inspect_request_start_time
223
- self.logger.info(f"PROFILE: {operation_type} id: {event_id} inspect_request duration: {on_inspect_request_duration:.2f} seconds")
246
+ self.logger.debug(
247
+ f"PROFILE: {operation_type} id: {event_id} inspect_request duration: {on_inspect_request_duration:.2f} seconds")
224
248
 
225
- await self._enforce_decision(
226
- decision=request_decision,
227
- error_class=error_class,
228
- base_message=f"{operation_type.title()} request blocked by security policy",
229
- is_request=True,
230
- event_id=event_id,
231
- tool_name=tool_name,
232
- content_data=tool_args,
233
- operation_type=operation_type,
234
- prompt_id=prompt_id
235
- )
249
+ try:
250
+ await DecisionHandler(
251
+ logger=self.logger,
252
+ audit_logger=self.audit_logger,
253
+ session_id=self.session_id,
254
+ app_id=self.app_id
255
+ ).enforce_decision(
256
+ decision=request_decision,
257
+ is_request=True,
258
+ event_id=event_id,
259
+ tool_name=tool_name,
260
+ content_data=tool_args,
261
+ operation_type=operation_type,
262
+ prompt_id=prompt_id,
263
+ server_name=self.wrapped_server_name,
264
+ error_message_prefix=f"{operation_type.title()} request blocked by security policy"
265
+ )
266
+ except DecisionEnforcementError as e:
267
+ raise error_class(str(e))
236
268
 
237
269
  self.audit_logger.log_event(
238
270
  "agent_request_forwarded",
@@ -249,7 +281,8 @@ class SecurityMiddleware(Middleware):
249
281
  # Call wrapped MCP with cleaned context (e.g., no wrapper args)
250
282
  result = await call_next(cleaned_context)
251
283
  on_call_next_duration = time.time() - on_call_next_start_time
252
- self.logger.info(f"PROFILE: {operation_type} id: {event_id} call_next duration: {on_call_next_duration:.2f} seconds")
284
+ self.logger.debug(
285
+ f"PROFILE: {operation_type} id: {event_id} call_next duration: {on_call_next_duration:.2f} seconds")
253
286
 
254
287
  response_content = self._extract_response_content(result)
255
288
 
@@ -274,19 +307,28 @@ class SecurityMiddleware(Middleware):
274
307
  prompt_id=prompt_id
275
308
  )
276
309
  on_inspect_response_duration = time.time() - on_inspect_response_start_time
277
- self.logger.info(f"PROFILE: {operation_type} id: {event_id} inspect_response duration: {on_inspect_response_duration:.2f} seconds")
310
+ self.logger.debug(
311
+ f"PROFILE: {operation_type} id: {event_id} inspect_response duration: {on_inspect_response_duration:.2f} seconds")
278
312
 
279
- await self._enforce_decision(
280
- decision=response_decision,
281
- error_class=error_class,
282
- base_message=f"{operation_type.title()} response blocked by security policy",
283
- is_request=False,
284
- event_id=event_id,
285
- tool_name=tool_name,
286
- content_data=response_content,
287
- operation_type=operation_type,
288
- prompt_id=prompt_id
289
- )
313
+ try:
314
+ await DecisionHandler(
315
+ logger=self.logger,
316
+ audit_logger=self.audit_logger,
317
+ session_id=self.session_id,
318
+ app_id=self.app_id
319
+ ).enforce_decision(
320
+ decision=response_decision,
321
+ is_request=False,
322
+ event_id=event_id,
323
+ tool_name=tool_name,
324
+ content_data=response_content,
325
+ operation_type=operation_type,
326
+ prompt_id=prompt_id,
327
+ server_name=self.wrapped_server_name,
328
+ error_message_prefix=f"{operation_type.title()} response blocked by security policy"
329
+ )
330
+ except DecisionEnforcementError as e:
331
+ raise error_class(str(e))
290
332
 
291
333
  self.audit_logger.log_event(
292
334
  "mcp_response_forwarded",
@@ -299,15 +341,30 @@ class SecurityMiddleware(Middleware):
299
341
  prompt_id=prompt_id
300
342
  )
301
343
  on_handle_operation_duration = time.time() - on_handle_operation_start_time
302
- self.logger.info(f"PROFILE: {operation_type} id: {event_id} duration: {on_handle_operation_duration:.2f} seconds")
344
+ self.logger.debug(
345
+ f"PROFILE: {operation_type} id: {event_id} duration: {on_handle_operation_duration:.2f} seconds")
303
346
  return result
304
347
 
305
348
  async def _handle_tools_list(self, context: MiddlewareContext, call_next: CallNext) -> Any:
306
- """Handle tools/list by calling /init API and modifying schemas"""
349
+ """Handle tools/list by calling /init API and modifying schemas with deduplication"""
307
350
  event_id = generate_event_id()
308
351
  on_handle_tools_list_start_time = time.time()
309
- result = await call_next(context)
310
- self.logger.info(f"PROFILE: tools/list call_next duration: {time.time() - on_handle_tools_list_start_time:.2f} seconds id: {event_id}")
352
+
353
+ async with self._tools_list_lock:
354
+ if not self._tools_list_in_progress or self._tools_list_in_progress.done():
355
+ self._tools_list_in_progress = asyncio.create_task(call_next(context))
356
+ shared_task = self._tools_list_in_progress
357
+
358
+ try:
359
+ result = await shared_task
360
+ except Exception as e:
361
+ async with self._tools_list_lock:
362
+ if self._tools_list_in_progress is shared_task:
363
+ self._tools_list_in_progress = None
364
+ raise
365
+ self.logger.debug(
366
+ f"PROFILE: tools/list call_next duration: {time.time() - on_handle_tools_list_start_time:.2f} seconds id: {event_id}")
367
+
311
368
  tools_list = None
312
369
  if isinstance(result, list):
313
370
  tools_list = result
@@ -337,11 +394,13 @@ class SecurityMiddleware(Middleware):
337
394
  enhanced_result = result
338
395
 
339
396
  on_handle_tools_list_duration = time.time() - on_handle_tools_list_start_time
340
- self.logger.info(f"PROFILE: tools/list enhanced_result duration: {on_handle_tools_list_duration:.2f} seconds id: {event_id}")
397
+ self.logger.debug(
398
+ f"PROFILE: tools/list enhanced_result duration: {on_handle_tools_list_duration:.2f} seconds id: {event_id}")
341
399
  return enhanced_result
342
400
 
343
401
  on_handle_tools_list_duration = time.time() - on_handle_tools_list_start_time
344
- self.logger.info(f"PROFILE: tools/list result duration: {on_handle_tools_list_duration:.2f} seconds id: {event_id}")
402
+ self.logger.debug(
403
+ f"PROFILE: tools/list result duration: {on_handle_tools_list_duration:.2f} seconds id: {event_id}")
345
404
 
346
405
  return result
347
406
 
@@ -480,6 +539,12 @@ class SecurityMiddleware(Middleware):
480
539
  file_path_prefix = 'file://'
481
540
  if uri.startswith(file_path_prefix):
482
541
  path = urllib.parse.unquote(uri[len(file_path_prefix):])
542
+
543
+ # Windows fix: remove leading slash before drive letter
544
+ # file:///C:/path becomes /C:/path, should be C:/path
545
+ if sys.platform == 'win32' and len(path) >= 3 and path[0] == '/' and path[2] == ':':
546
+ path = path[1:]
547
+
483
548
  try:
484
549
  resolved_path = str(Path(path).resolve())
485
550
  workspace_roots.append(resolved_path)
@@ -585,28 +650,6 @@ class SecurityMiddleware(Middleware):
585
650
  )
586
651
  }
587
652
 
588
- async def _record_user_confirmation(self, event_id: str, is_request: bool, user_decision: UserDecision,
589
- prompt_id: str, call_type: str = None):
590
- """Record user confirmation decision with the security API"""
591
- try:
592
- direction = "request" if is_request else "response"
593
-
594
- user_confirmation = UserConfirmation(
595
- event_id=event_id,
596
- direction=direction,
597
- user_decision=user_decision,
598
- call_type=call_type
599
- )
600
-
601
- async with SecurityPolicyClient(session_id=self.session_id, logger=self.logger,
602
- audit_logger=self.audit_logger, app_id=self.app_id) as client:
603
- result = await client.record_user_confirmation(user_confirmation, prompt_id=prompt_id)
604
- self.logger.debug(f"User confirmation recorded: {result}")
605
- except Exception as e:
606
- # Don't fail the operation if API call fails - just log the error
607
- self.logger.error(f"Failed to record user confirmation: {e}")
608
-
609
-
610
653
  @staticmethod
611
654
  def _create_security_api_failure_decision(error: Exception) -> Dict[str, Any]:
612
655
  """Create a standard failure decision when security API is unavailable/failing/unreachable"""
@@ -616,135 +659,3 @@ class SecurityMiddleware(Middleware):
616
659
  "reasons": [f"Security API unavailable: {error}"],
617
660
  "matched_rules": ["security_api.error"]
618
661
  }
619
-
620
- async def _enforce_decision(self, decision: Dict[str, Any], error_class, base_message: str,
621
- is_request: bool, event_id: str, tool_name: str, content_data: Dict[str, Any],
622
- operation_type: str, prompt_id: str):
623
- """Enforce security decision with user confirmation support"""
624
- decision_type = decision.get("decision", "block")
625
-
626
- if decision_type == "allow":
627
- return
628
-
629
- elif decision_type == "block":
630
- policy_reasons = decision.get("reasons", ["Policy violation"])
631
- severity = decision.get("severity", "unknown")
632
- call_type = decision.get("call_type")
633
-
634
- try:
635
- # Show a blocking dialog and wait for user decision
636
- confirmation_request = ConfirmationRequest(
637
- is_request=is_request,
638
- tool_name=tool_name,
639
- policy_reasons=policy_reasons,
640
- content_data=content_data,
641
- severity=severity,
642
- event_id=event_id,
643
- operation_type=operation_type,
644
- server_name=self.wrapped_server_name,
645
- timeout_seconds=60
646
- )
647
-
648
- response = UserConfirmationDialog(
649
- self.logger, self.audit_logger
650
- ).request_blocking_confirmation(confirmation_request, prompt_id, call_type)
651
-
652
- # If we got here, user chose "Allow Anyway"
653
- self.logger.info(f"User chose to 'allow anyway' a blocked {confirmation_request.operation_type} "
654
- f"operation for tool '{tool_name}' (event: {event_id})")
655
-
656
- await self._record_user_confirmation(event_id, is_request, response.user_decision, prompt_id, call_type)
657
- return
658
-
659
- except UserConfirmationError as e:
660
- # User chose to block or dialog failed
661
- self.logger.warning(f"User blocking confirmation failed: {e}")
662
- await self._record_user_confirmation(event_id, is_request, UserDecision.BLOCK, prompt_id, call_type)
663
- reasons = "; ".join(policy_reasons)
664
- raise error_class("Security Violation. User blocked the operation")
665
-
666
- elif decision_type == "required_explicit_user_confirmation":
667
- policy_reasons = decision.get("reasons", ["Security policy requires confirmation"])
668
- severity = decision.get("severity", "unknown")
669
- call_type = decision.get("call_type")
670
-
671
- try:
672
- confirmation_request = ConfirmationRequest(
673
- is_request=is_request,
674
- tool_name=tool_name,
675
- policy_reasons=policy_reasons,
676
- content_data=content_data,
677
- severity=severity,
678
- event_id=event_id,
679
- operation_type=operation_type,
680
- server_name=self.wrapped_server_name,
681
- timeout_seconds=60
682
- )
683
-
684
- # only show YES_ALWAYS if call_type exists
685
- options = DialogOptions(
686
- show_always_allow=(call_type is not None),
687
- show_always_block=False
688
- )
689
-
690
- response = UserConfirmationDialog(
691
- self.logger, self.audit_logger
692
- ).request_confirmation(confirmation_request, prompt_id, call_type, options)
693
-
694
- # If we got here, user approved the operation
695
- self.logger.info(f"User {response.user_decision.value} {confirmation_request.operation_type} "
696
- f"operation for tool '{tool_name}' (event: {event_id})")
697
-
698
- await self._record_user_confirmation(event_id, is_request, response.user_decision, prompt_id, call_type)
699
- return
700
-
701
- except UserConfirmationError as e:
702
- # User denied confirmation or dialog failed
703
- self.logger.warning(f"User confirmation failed: {e}")
704
- await self._record_user_confirmation(event_id, is_request, UserDecision.BLOCK, prompt_id, call_type)
705
- raise error_class("Security Violation. User blocked the operation")
706
-
707
- elif decision_type == "need_more_info":
708
- stage_title = 'CLIENT REQUEST' if is_request else 'TOOL RESPONSE'
709
-
710
- # Create an actionable error message for the AI agent
711
- reasons = decision.get("reasons", [])
712
- need_fields = decision.get("need_fields", [])
713
-
714
- error_parts = [
715
- f"SECURITY POLICY NEEDS MORE INFORMATION FOR REVIEWING {stage_title}:",
716
- '\n'.join(reasons),
717
- '' # newline
718
- ]
719
-
720
- if need_fields:
721
- # Convert server field names to wrapper field names for the AI agent
722
- wrapper_field_mapping = {
723
- "context.agent.intent": "__wrapper_modelIntent",
724
- "context.agent.plan": "__wrapper_modelPlan",
725
- "context.agent.expectedOutputs": "__wrapper_modelExpectedOutputs",
726
- "context.agent.user_prompt": "__wrapper_userPrompt",
727
- "context.agent.user_prompt_id": "__wrapper_userPromptId",
728
- "context.agent.context_summary": "__wrapper_contextSummary",
729
- "context.workspace.current_files": "__wrapper_currentFiles",
730
- }
731
-
732
- missing_wrapper_fields = []
733
- for field in need_fields:
734
- wrapper_field = wrapper_field_mapping.get(field, field)
735
- missing_wrapper_fields.append(wrapper_field)
736
-
737
- if missing_wrapper_fields:
738
- error_parts.append("AFFECTED FIELDS:")
739
- error_parts.extend(missing_wrapper_fields)
740
- else:
741
- error_parts.append("MISSING INFORMATION:")
742
- error_parts.extend(need_fields)
743
-
744
-
745
- error_parts.append("\nMANDATORY ACTIONS:")
746
- error_parts.append("1. Add/Edit ALL affected fields according to the required information")
747
- error_parts.append("2. Retry the tool call")
748
-
749
- actionable_message = "\n".join(error_parts)
750
- raise error_class(actionable_message)
wrapper/server.py CHANGED
@@ -6,10 +6,11 @@ Implements transparent 1:1 MCP proxying with security middleware
6
6
  import logging
7
7
 
8
8
  from fastmcp.server.middleware.logging import StructuredLoggingMiddleware
9
- from fastmcp.server.proxy import ProxyClient, default_proxy_roots_handler, FastMCPProxy
9
+ from fastmcp.server.proxy import ProxyClient, default_proxy_roots_handler, FastMCPProxy, StatefulProxyClient
10
10
 
11
11
  from modules.logs.audit_trail import AuditTrailLogger
12
12
  from modules.logs.logger import MCPLogger
13
+ from modules.utils.json import safe_json_dumps
13
14
  from .__version__ import __version__
14
15
  from .middleware import SecurityMiddleware
15
16
 
@@ -42,7 +43,7 @@ def create_wrapper_server(wrapper_server_name: str,
42
43
  logger=logger,
43
44
  audit_logger=audit_logger
44
45
  )
45
-
46
+
46
47
  # Log MCPower startup to audit trail
47
48
  audit_logger.log_event("mcpower_start", {
48
49
  "wrapper_version": __version__,
@@ -51,16 +52,23 @@ def create_wrapper_server(wrapper_server_name: str,
51
52
  })
52
53
 
53
54
  # Create FastMCP server as proxy with our security-aware ProxyClient
55
+ # Use StatefulProxyClient for remote servers (mcp-remote or url-based transports)
56
+ config_str = safe_json_dumps(wrapped_server_configs)
57
+ is_remote = '"@mcpower/mcp-remote",' in config_str or '"url":' in config_str
58
+ backend_class = StatefulProxyClient if is_remote else ProxyClient
59
+ backend = backend_class(
60
+ wrapped_server_configs,
61
+ name=wrapper_server_name,
62
+ roots=default_proxy_roots_handler, # Use default for filesystem roots
63
+ sampling_handler=security_middleware.secure_sampling_handler,
64
+ elicitation_handler=security_middleware.secure_elicitation_handler,
65
+ log_handler=security_middleware.secure_log_handler,
66
+ progress_handler=security_middleware.secure_progress_handler,
67
+ )
68
+
54
69
  def client_factory():
55
- return ProxyClient(
56
- wrapped_server_configs,
57
- name=wrapper_server_name,
58
- roots=default_proxy_roots_handler, # Use default for filesystem roots
59
- sampling_handler=security_middleware.secure_sampling_handler,
60
- elicitation_handler=security_middleware.secure_elicitation_handler,
61
- log_handler=security_middleware.secure_log_handler,
62
- progress_handler=security_middleware.secure_progress_handler,
63
- )
70
+ # we must return the same instance, otherwise StatefulProxyClient doesn't play nice with mcp-remote
71
+ return backend
64
72
 
65
73
  server = FastMCPProxy(client_factory=client_factory, name=wrapper_server_name, version=__version__)
66
74