tweek 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tweek/__init__.py +16 -0
- tweek/cli.py +3390 -0
- tweek/cli_helpers.py +193 -0
- tweek/config/__init__.py +13 -0
- tweek/config/allowed_dirs.yaml +23 -0
- tweek/config/manager.py +1064 -0
- tweek/config/patterns.yaml +751 -0
- tweek/config/tiers.yaml +129 -0
- tweek/diagnostics.py +589 -0
- tweek/hooks/__init__.py +1 -0
- tweek/hooks/pre_tool_use.py +861 -0
- tweek/integrations/__init__.py +3 -0
- tweek/integrations/moltbot.py +243 -0
- tweek/licensing.py +398 -0
- tweek/logging/__init__.py +9 -0
- tweek/logging/bundle.py +350 -0
- tweek/logging/json_logger.py +150 -0
- tweek/logging/security_log.py +745 -0
- tweek/mcp/__init__.py +24 -0
- tweek/mcp/approval.py +456 -0
- tweek/mcp/approval_cli.py +356 -0
- tweek/mcp/clients/__init__.py +37 -0
- tweek/mcp/clients/chatgpt.py +112 -0
- tweek/mcp/clients/claude_desktop.py +203 -0
- tweek/mcp/clients/gemini.py +178 -0
- tweek/mcp/proxy.py +667 -0
- tweek/mcp/screening.py +175 -0
- tweek/mcp/server.py +317 -0
- tweek/platform/__init__.py +131 -0
- tweek/plugins/__init__.py +835 -0
- tweek/plugins/base.py +1080 -0
- tweek/plugins/compliance/__init__.py +30 -0
- tweek/plugins/compliance/gdpr.py +333 -0
- tweek/plugins/compliance/gov.py +324 -0
- tweek/plugins/compliance/hipaa.py +285 -0
- tweek/plugins/compliance/legal.py +322 -0
- tweek/plugins/compliance/pci.py +361 -0
- tweek/plugins/compliance/soc2.py +275 -0
- tweek/plugins/detectors/__init__.py +30 -0
- tweek/plugins/detectors/continue_dev.py +206 -0
- tweek/plugins/detectors/copilot.py +254 -0
- tweek/plugins/detectors/cursor.py +192 -0
- tweek/plugins/detectors/moltbot.py +205 -0
- tweek/plugins/detectors/windsurf.py +214 -0
- tweek/plugins/git_discovery.py +395 -0
- tweek/plugins/git_installer.py +491 -0
- tweek/plugins/git_lockfile.py +338 -0
- tweek/plugins/git_registry.py +503 -0
- tweek/plugins/git_security.py +482 -0
- tweek/plugins/providers/__init__.py +30 -0
- tweek/plugins/providers/anthropic.py +181 -0
- tweek/plugins/providers/azure_openai.py +289 -0
- tweek/plugins/providers/bedrock.py +248 -0
- tweek/plugins/providers/google.py +197 -0
- tweek/plugins/providers/openai.py +230 -0
- tweek/plugins/scope.py +130 -0
- tweek/plugins/screening/__init__.py +26 -0
- tweek/plugins/screening/llm_reviewer.py +149 -0
- tweek/plugins/screening/pattern_matcher.py +273 -0
- tweek/plugins/screening/rate_limiter.py +174 -0
- tweek/plugins/screening/session_analyzer.py +159 -0
- tweek/proxy/__init__.py +302 -0
- tweek/proxy/addon.py +223 -0
- tweek/proxy/interceptor.py +313 -0
- tweek/proxy/server.py +315 -0
- tweek/sandbox/__init__.py +71 -0
- tweek/sandbox/executor.py +382 -0
- tweek/sandbox/linux.py +278 -0
- tweek/sandbox/profile_generator.py +323 -0
- tweek/screening/__init__.py +13 -0
- tweek/screening/context.py +81 -0
- tweek/security/__init__.py +22 -0
- tweek/security/llm_reviewer.py +348 -0
- tweek/security/rate_limiter.py +682 -0
- tweek/security/secret_scanner.py +506 -0
- tweek/security/session_analyzer.py +600 -0
- tweek/vault/__init__.py +40 -0
- tweek/vault/cross_platform.py +251 -0
- tweek/vault/keychain.py +288 -0
- tweek-0.1.0.dist-info/METADATA +335 -0
- tweek-0.1.0.dist-info/RECORD +85 -0
- tweek-0.1.0.dist-info/WHEEL +5 -0
- tweek-0.1.0.dist-info/entry_points.txt +25 -0
- tweek-0.1.0.dist-info/licenses/LICENSE +190 -0
- tweek-0.1.0.dist-info/top_level.txt +1 -0
tweek/mcp/proxy.py
ADDED
|
@@ -0,0 +1,667 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Tweek MCP Proxy Server
|
|
4
|
+
|
|
5
|
+
Transparent MCP proxy that sits between LLM clients and upstream MCP servers.
|
|
6
|
+
All tool calls are screened through Tweek's defense-in-depth pipeline.
|
|
7
|
+
Flagged calls are queued for human approval via a separate CLI daemon.
|
|
8
|
+
|
|
9
|
+
Architecture:
|
|
10
|
+
LLM Client <--stdio--> TweekMCPProxy <--stdio--> Upstream MCP Server(s)
|
|
11
|
+
|
|
12
|
+
Usage:
|
|
13
|
+
tweek mcp proxy # Start proxy on stdio transport
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import asyncio
|
|
17
|
+
import json
|
|
18
|
+
import logging
|
|
19
|
+
import os
|
|
20
|
+
import signal
|
|
21
|
+
import sys
|
|
22
|
+
import uuid
|
|
23
|
+
from contextlib import AsyncExitStack
|
|
24
|
+
from datetime import timedelta
|
|
25
|
+
from pathlib import Path
|
|
26
|
+
from typing import Any, Dict, List, Optional, Tuple
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
from mcp.client.session import ClientSession
|
|
32
|
+
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
33
|
+
from mcp.server import Server
|
|
34
|
+
from mcp.server.stdio import stdio_server
|
|
35
|
+
from mcp.types import TextContent, Tool
|
|
36
|
+
MCP_AVAILABLE = True
|
|
37
|
+
except ImportError:
|
|
38
|
+
MCP_AVAILABLE = False
|
|
39
|
+
|
|
40
|
+
from tweek.screening.context import ScreeningContext
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Separator for namespaced tool names: {upstream}__{tool}
|
|
44
|
+
NAMESPACE_SEPARATOR = "__"
|
|
45
|
+
|
|
46
|
+
# Default timeout for upstream tool calls (seconds)
|
|
47
|
+
UPSTREAM_CALL_TIMEOUT = 120
|
|
48
|
+
|
|
49
|
+
# Polling interval for approval decisions (seconds)
|
|
50
|
+
APPROVAL_POLL_INTERVAL = 1.0
|
|
51
|
+
|
|
52
|
+
# Background expiry loop interval (seconds)
|
|
53
|
+
EXPIRY_LOOP_INTERVAL = 30
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _check_mcp_available():
|
|
57
|
+
"""Raise RuntimeError if MCP SDK is not installed."""
|
|
58
|
+
if not MCP_AVAILABLE:
|
|
59
|
+
raise RuntimeError(
|
|
60
|
+
"MCP SDK not installed. Install with: pip install 'tweek[mcp]' "
|
|
61
|
+
"or pip install mcp"
|
|
62
|
+
)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class UpstreamConnection:
|
|
66
|
+
"""
|
|
67
|
+
Manages a single connection to an upstream MCP server.
|
|
68
|
+
|
|
69
|
+
Connects via stdio transport, discovers available tools,
|
|
70
|
+
and forwards tool calls.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
def __init__(self, name: str, server_params: StdioServerParameters):
|
|
74
|
+
self.name = name
|
|
75
|
+
self.server_params = server_params
|
|
76
|
+
self.session: Optional[ClientSession] = None
|
|
77
|
+
self.tools: List[Tool] = []
|
|
78
|
+
self.connected: bool = False
|
|
79
|
+
|
|
80
|
+
async def connect(self, exit_stack: AsyncExitStack) -> None:
|
|
81
|
+
"""
|
|
82
|
+
Connect to the upstream server and discover its tools.
|
|
83
|
+
|
|
84
|
+
Uses the provided AsyncExitStack to keep the stdio transport alive
|
|
85
|
+
for the lifetime of the proxy.
|
|
86
|
+
"""
|
|
87
|
+
try:
|
|
88
|
+
read_stream, write_stream = await exit_stack.enter_async_context(
|
|
89
|
+
stdio_client(self.server_params)
|
|
90
|
+
)
|
|
91
|
+
self.session = ClientSession(read_stream, write_stream)
|
|
92
|
+
init_result = await self.session.initialize()
|
|
93
|
+
tools_result = await self.session.list_tools()
|
|
94
|
+
self.tools = tools_result.tools
|
|
95
|
+
self.connected = True
|
|
96
|
+
|
|
97
|
+
logger.info(
|
|
98
|
+
f"Connected to upstream '{self.name}': "
|
|
99
|
+
f"{len(self.tools)} tool(s) available"
|
|
100
|
+
)
|
|
101
|
+
except Exception as e:
|
|
102
|
+
logger.error(f"Failed to connect to upstream '{self.name}': {e}")
|
|
103
|
+
self.connected = False
|
|
104
|
+
self.tools = []
|
|
105
|
+
|
|
106
|
+
async def call_tool(
|
|
107
|
+
self,
|
|
108
|
+
name: str,
|
|
109
|
+
arguments: Optional[Dict[str, Any]] = None,
|
|
110
|
+
timeout: float = UPSTREAM_CALL_TIMEOUT,
|
|
111
|
+
) -> Dict[str, Any]:
|
|
112
|
+
"""
|
|
113
|
+
Forward a tool call to the upstream server.
|
|
114
|
+
|
|
115
|
+
Returns dict with:
|
|
116
|
+
content: List of content items (text/image/etc.)
|
|
117
|
+
isError: Whether the call resulted in an error
|
|
118
|
+
"""
|
|
119
|
+
if not self.connected or self.session is None:
|
|
120
|
+
return {
|
|
121
|
+
"content": [{"type": "text", "text": json.dumps({
|
|
122
|
+
"error": f"Upstream server '{self.name}' is not connected",
|
|
123
|
+
})}],
|
|
124
|
+
"isError": True,
|
|
125
|
+
}
|
|
126
|
+
|
|
127
|
+
try:
|
|
128
|
+
result = await self.session.call_tool(
|
|
129
|
+
name=name,
|
|
130
|
+
arguments=arguments,
|
|
131
|
+
read_timeout_seconds=timedelta(seconds=timeout),
|
|
132
|
+
)
|
|
133
|
+
# Convert CallToolResult to serializable dict
|
|
134
|
+
content_list = []
|
|
135
|
+
for item in result.content:
|
|
136
|
+
if hasattr(item, "text"):
|
|
137
|
+
content_list.append({"type": "text", "text": item.text})
|
|
138
|
+
elif hasattr(item, "data"):
|
|
139
|
+
content_list.append({
|
|
140
|
+
"type": getattr(item, "type", "unknown"),
|
|
141
|
+
"data": item.data,
|
|
142
|
+
})
|
|
143
|
+
else:
|
|
144
|
+
content_list.append({"type": "text", "text": str(item)})
|
|
145
|
+
|
|
146
|
+
return {
|
|
147
|
+
"content": content_list,
|
|
148
|
+
"isError": getattr(result, "isError", False),
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
except Exception as e:
|
|
152
|
+
logger.error(f"Tool call to '{self.name}/{name}' failed: {e}")
|
|
153
|
+
return {
|
|
154
|
+
"content": [{"type": "text", "text": json.dumps({
|
|
155
|
+
"error": f"Upstream call failed: {e}",
|
|
156
|
+
"server": self.name,
|
|
157
|
+
"tool": name,
|
|
158
|
+
})}],
|
|
159
|
+
"isError": True,
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class TweekMCPProxy:
|
|
164
|
+
"""
|
|
165
|
+
MCP Proxy with security screening and human-in-the-loop approval.
|
|
166
|
+
|
|
167
|
+
Presents merged tools from upstream MCP servers to the downstream
|
|
168
|
+
LLM client. All tool calls pass through Tweek's screening pipeline.
|
|
169
|
+
Flagged calls are queued for human approval.
|
|
170
|
+
"""
|
|
171
|
+
|
|
172
|
+
def __init__(self, config: Optional[Dict[str, Any]] = None):
|
|
173
|
+
_check_mcp_available()
|
|
174
|
+
self.config = config or {}
|
|
175
|
+
self.server = Server("tweek-proxy")
|
|
176
|
+
self._exit_stack = AsyncExitStack()
|
|
177
|
+
self.upstreams: Dict[str, UpstreamConnection] = {}
|
|
178
|
+
self._tool_registry: Dict[str, str] = {} # namespaced_name -> upstream_name
|
|
179
|
+
self._request_count = 0
|
|
180
|
+
self._blocked_count = 0
|
|
181
|
+
self._approval_count = 0
|
|
182
|
+
self._approval_queue = None
|
|
183
|
+
self._expiry_task = None
|
|
184
|
+
self._setup_handlers()
|
|
185
|
+
|
|
186
|
+
def _get_approval_queue(self):
|
|
187
|
+
"""Lazy-initialize the approval queue."""
|
|
188
|
+
if self._approval_queue is None:
|
|
189
|
+
from tweek.mcp.approval import ApprovalQueue
|
|
190
|
+
proxy_config = self.config.get("mcp", {}).get("proxy", {})
|
|
191
|
+
timeout = proxy_config.get("approval_timeout", 300)
|
|
192
|
+
self._approval_queue = ApprovalQueue(default_timeout=timeout)
|
|
193
|
+
return self._approval_queue
|
|
194
|
+
|
|
195
|
+
def _get_proxy_config(self) -> Dict[str, Any]:
|
|
196
|
+
"""Get proxy-specific config."""
|
|
197
|
+
return self.config.get("mcp", {}).get("proxy", {})
|
|
198
|
+
|
|
199
|
+
def _build_upstreams(self) -> Dict[str, UpstreamConnection]:
|
|
200
|
+
"""Build upstream connections from config."""
|
|
201
|
+
proxy_config = self._get_proxy_config()
|
|
202
|
+
upstreams_config = proxy_config.get("upstreams", {})
|
|
203
|
+
connections = {}
|
|
204
|
+
|
|
205
|
+
for name, server_config in upstreams_config.items():
|
|
206
|
+
command = server_config.get("command", "")
|
|
207
|
+
args = server_config.get("args", [])
|
|
208
|
+
env_config = server_config.get("env") or None
|
|
209
|
+
cwd = server_config.get("cwd")
|
|
210
|
+
|
|
211
|
+
# Expand environment variables in env dict
|
|
212
|
+
if env_config:
|
|
213
|
+
expanded_env = {}
|
|
214
|
+
for key, value in env_config.items():
|
|
215
|
+
if isinstance(value, str):
|
|
216
|
+
expanded_env[key] = os.path.expandvars(value)
|
|
217
|
+
else:
|
|
218
|
+
expanded_env[key] = value
|
|
219
|
+
env_config = expanded_env
|
|
220
|
+
|
|
221
|
+
params = StdioServerParameters(
|
|
222
|
+
command=command,
|
|
223
|
+
args=args,
|
|
224
|
+
env=env_config,
|
|
225
|
+
cwd=cwd,
|
|
226
|
+
)
|
|
227
|
+
connections[name] = UpstreamConnection(name=name, server_params=params)
|
|
228
|
+
|
|
229
|
+
return connections
|
|
230
|
+
|
|
231
|
+
def _namespace_tool(self, upstream_name: str, tool: Tool) -> Tool:
|
|
232
|
+
"""Create a namespaced copy of a tool for the merged tool list."""
|
|
233
|
+
namespaced_name = f"{upstream_name}{NAMESPACE_SEPARATOR}{tool.name}"
|
|
234
|
+
description = tool.description or ""
|
|
235
|
+
namespaced_desc = f"[{upstream_name}] {description}"
|
|
236
|
+
|
|
237
|
+
return Tool(
|
|
238
|
+
name=namespaced_name,
|
|
239
|
+
description=namespaced_desc,
|
|
240
|
+
inputSchema=tool.inputSchema,
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
def _resolve_tool(self, namespaced_name: str) -> Tuple[str, str]:
|
|
244
|
+
"""
|
|
245
|
+
Resolve a namespaced tool name to (upstream_name, original_tool_name).
|
|
246
|
+
|
|
247
|
+
Raises ValueError if the tool name is not in the registry.
|
|
248
|
+
"""
|
|
249
|
+
if NAMESPACE_SEPARATOR not in namespaced_name:
|
|
250
|
+
raise ValueError(f"Tool '{namespaced_name}' is not namespaced")
|
|
251
|
+
|
|
252
|
+
upstream_name = self._tool_registry.get(namespaced_name)
|
|
253
|
+
if upstream_name is None:
|
|
254
|
+
raise ValueError(f"Unknown tool: {namespaced_name}")
|
|
255
|
+
|
|
256
|
+
# Extract original name by removing the prefix
|
|
257
|
+
prefix = f"{upstream_name}{NAMESPACE_SEPARATOR}"
|
|
258
|
+
original_name = namespaced_name[len(prefix):]
|
|
259
|
+
return upstream_name, original_name
|
|
260
|
+
|
|
261
|
+
def _setup_handlers(self):
|
|
262
|
+
"""Register MCP protocol handlers for the proxy server."""
|
|
263
|
+
|
|
264
|
+
@self.server.list_tools()
|
|
265
|
+
async def list_tools() -> list[Tool]:
|
|
266
|
+
"""Return merged tools from all connected upstreams."""
|
|
267
|
+
merged = []
|
|
268
|
+
for upstream_name, upstream in self.upstreams.items():
|
|
269
|
+
if not upstream.connected:
|
|
270
|
+
continue
|
|
271
|
+
for tool in upstream.tools:
|
|
272
|
+
namespaced = self._namespace_tool(upstream_name, tool)
|
|
273
|
+
merged.append(namespaced)
|
|
274
|
+
self._tool_registry[namespaced.name] = upstream_name
|
|
275
|
+
return merged
|
|
276
|
+
|
|
277
|
+
@self.server.call_tool()
|
|
278
|
+
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
|
279
|
+
"""Handle tool calls with security screening and approval."""
|
|
280
|
+
self._request_count += 1
|
|
281
|
+
return await self._handle_call_tool(name, arguments)
|
|
282
|
+
|
|
283
|
+
async def _handle_call_tool(
|
|
284
|
+
self, name: str, arguments: dict
|
|
285
|
+
) -> list[TextContent]:
|
|
286
|
+
"""Screen and forward a tool call."""
|
|
287
|
+
# Generate correlation ID for this screening pass
|
|
288
|
+
correlation_id = uuid.uuid4().hex[:12]
|
|
289
|
+
|
|
290
|
+
# Resolve upstream and original tool name
|
|
291
|
+
try:
|
|
292
|
+
upstream_name, original_name = self._resolve_tool(name)
|
|
293
|
+
except ValueError as e:
|
|
294
|
+
return [TextContent(
|
|
295
|
+
type="text",
|
|
296
|
+
text=json.dumps({
|
|
297
|
+
"error": str(e),
|
|
298
|
+
"available_tools": list(self._tool_registry.keys()),
|
|
299
|
+
}),
|
|
300
|
+
)]
|
|
301
|
+
|
|
302
|
+
upstream = self.upstreams.get(upstream_name)
|
|
303
|
+
if upstream is None or not upstream.connected:
|
|
304
|
+
return [TextContent(
|
|
305
|
+
type="text",
|
|
306
|
+
text=json.dumps({
|
|
307
|
+
"error": f"Upstream '{upstream_name}' is not connected",
|
|
308
|
+
}),
|
|
309
|
+
)]
|
|
310
|
+
|
|
311
|
+
# Build screening context
|
|
312
|
+
content = self._extract_content_for_screening(original_name, arguments)
|
|
313
|
+
context = self._build_context(
|
|
314
|
+
tool_name=original_name,
|
|
315
|
+
content=content,
|
|
316
|
+
upstream_name=upstream_name,
|
|
317
|
+
tool_input=arguments,
|
|
318
|
+
)
|
|
319
|
+
|
|
320
|
+
# Run screening
|
|
321
|
+
result = self._run_screening(context)
|
|
322
|
+
|
|
323
|
+
if result.get("blocked"):
|
|
324
|
+
self._blocked_count += 1
|
|
325
|
+
self._log_event("blocked", original_name, upstream_name, content, result,
|
|
326
|
+
metadata={"correlation_id": correlation_id}, correlation_id=correlation_id)
|
|
327
|
+
return [TextContent(
|
|
328
|
+
type="text",
|
|
329
|
+
text=json.dumps({
|
|
330
|
+
"blocked": True,
|
|
331
|
+
"reason": result.get("reason", "Blocked by security screening"),
|
|
332
|
+
"server": upstream_name,
|
|
333
|
+
"tool": original_name,
|
|
334
|
+
}),
|
|
335
|
+
)]
|
|
336
|
+
|
|
337
|
+
if result.get("should_prompt"):
|
|
338
|
+
# Queue for human approval
|
|
339
|
+
return await self._handle_approval_flow(
|
|
340
|
+
upstream_name, original_name, arguments, content, result, correlation_id
|
|
341
|
+
)
|
|
342
|
+
|
|
343
|
+
# Allowed - forward to upstream
|
|
344
|
+
self._log_event("allowed", original_name, upstream_name, content, result,
|
|
345
|
+
correlation_id=correlation_id)
|
|
346
|
+
return await self._forward_and_return(upstream, original_name, arguments)
|
|
347
|
+
|
|
348
|
+
async def _handle_approval_flow(
|
|
349
|
+
self,
|
|
350
|
+
upstream_name: str,
|
|
351
|
+
tool_name: str,
|
|
352
|
+
arguments: Dict[str, Any],
|
|
353
|
+
content: str,
|
|
354
|
+
screening_result: Dict[str, Any],
|
|
355
|
+
correlation_id: Optional[str] = None,
|
|
356
|
+
) -> list[TextContent]:
|
|
357
|
+
"""Queue a tool call for human approval and wait for decision."""
|
|
358
|
+
self._approval_count += 1
|
|
359
|
+
queue = self._get_approval_queue()
|
|
360
|
+
proxy_config = self._get_proxy_config()
|
|
361
|
+
timeout = proxy_config.get("approval_timeout", 300)
|
|
362
|
+
|
|
363
|
+
# Log the prompt
|
|
364
|
+
self._log_event(
|
|
365
|
+
"user_prompted", tool_name, upstream_name, content, screening_result,
|
|
366
|
+
correlation_id=correlation_id,
|
|
367
|
+
)
|
|
368
|
+
|
|
369
|
+
# Enqueue
|
|
370
|
+
request_id = queue.enqueue(
|
|
371
|
+
upstream_server=upstream_name,
|
|
372
|
+
tool_name=tool_name,
|
|
373
|
+
arguments=arguments,
|
|
374
|
+
screening_reason=screening_result.get("reason", "Needs confirmation"),
|
|
375
|
+
screening_findings=screening_result.get("findings", []),
|
|
376
|
+
risk_level=screening_result.get("tier", "unknown"),
|
|
377
|
+
timeout_seconds=timeout,
|
|
378
|
+
)
|
|
379
|
+
|
|
380
|
+
logger.info(
|
|
381
|
+
f"Approval queued [{request_id[:8]}]: "
|
|
382
|
+
f"{upstream_name}/{tool_name} - {screening_result.get('reason', '')}"
|
|
383
|
+
)
|
|
384
|
+
|
|
385
|
+
# Wait for decision
|
|
386
|
+
decision = await self._wait_for_approval(request_id, timeout)
|
|
387
|
+
|
|
388
|
+
if decision == "approved":
|
|
389
|
+
self._log_event(
|
|
390
|
+
"user_approved", tool_name, upstream_name, content, screening_result,
|
|
391
|
+
metadata={"request_id": request_id},
|
|
392
|
+
correlation_id=correlation_id,
|
|
393
|
+
)
|
|
394
|
+
upstream = self.upstreams[upstream_name]
|
|
395
|
+
return await self._forward_and_return(upstream, tool_name, arguments)
|
|
396
|
+
else:
|
|
397
|
+
reason = "Approval timed out" if decision == "expired" else "Denied by reviewer"
|
|
398
|
+
self._log_event(
|
|
399
|
+
"user_denied", tool_name, upstream_name, content, screening_result,
|
|
400
|
+
metadata={"request_id": request_id, "reason": reason},
|
|
401
|
+
correlation_id=correlation_id,
|
|
402
|
+
)
|
|
403
|
+
return [TextContent(
|
|
404
|
+
type="text",
|
|
405
|
+
text=json.dumps({
|
|
406
|
+
"blocked": True,
|
|
407
|
+
"reason": reason,
|
|
408
|
+
"server": upstream_name,
|
|
409
|
+
"tool": tool_name,
|
|
410
|
+
"request_id": request_id[:8],
|
|
411
|
+
}),
|
|
412
|
+
)]
|
|
413
|
+
|
|
414
|
+
async def _wait_for_approval(
|
|
415
|
+
self, request_id: str, timeout: float
|
|
416
|
+
) -> str:
|
|
417
|
+
"""
|
|
418
|
+
Poll the approval queue until a decision is made or timeout.
|
|
419
|
+
|
|
420
|
+
Returns: "approved", "denied", or "expired"
|
|
421
|
+
"""
|
|
422
|
+
queue = self._get_approval_queue()
|
|
423
|
+
elapsed = 0.0
|
|
424
|
+
|
|
425
|
+
while elapsed < timeout:
|
|
426
|
+
await asyncio.sleep(APPROVAL_POLL_INTERVAL)
|
|
427
|
+
elapsed += APPROVAL_POLL_INTERVAL
|
|
428
|
+
|
|
429
|
+
status = queue.get_decision(request_id)
|
|
430
|
+
if status is None:
|
|
431
|
+
# Request disappeared (shouldn't happen)
|
|
432
|
+
return "denied"
|
|
433
|
+
|
|
434
|
+
from tweek.mcp.approval import ApprovalStatus
|
|
435
|
+
|
|
436
|
+
if status == ApprovalStatus.APPROVED:
|
|
437
|
+
return "approved"
|
|
438
|
+
elif status == ApprovalStatus.DENIED:
|
|
439
|
+
return "denied"
|
|
440
|
+
elif status == ApprovalStatus.EXPIRED:
|
|
441
|
+
return "expired"
|
|
442
|
+
# Still pending, continue polling
|
|
443
|
+
|
|
444
|
+
# Timeout reached - expire the request
|
|
445
|
+
queue.expire_stale()
|
|
446
|
+
return "expired"
|
|
447
|
+
|
|
448
|
+
async def _forward_and_return(
|
|
449
|
+
self,
|
|
450
|
+
upstream: UpstreamConnection,
|
|
451
|
+
tool_name: str,
|
|
452
|
+
arguments: Dict[str, Any],
|
|
453
|
+
) -> list[TextContent]:
|
|
454
|
+
"""Forward a tool call to upstream and return the result."""
|
|
455
|
+
result = await upstream.call_tool(tool_name, arguments)
|
|
456
|
+
content_list = result.get("content", [])
|
|
457
|
+
|
|
458
|
+
text_contents = []
|
|
459
|
+
for item in content_list:
|
|
460
|
+
if item.get("type") == "text":
|
|
461
|
+
text_contents.append(TextContent(type="text", text=item.get("text", "")))
|
|
462
|
+
else:
|
|
463
|
+
text_contents.append(
|
|
464
|
+
TextContent(type="text", text=json.dumps(item))
|
|
465
|
+
)
|
|
466
|
+
|
|
467
|
+
if not text_contents:
|
|
468
|
+
text_contents = [TextContent(
|
|
469
|
+
type="text",
|
|
470
|
+
text=json.dumps({"result": "empty response from upstream"}),
|
|
471
|
+
)]
|
|
472
|
+
|
|
473
|
+
return text_contents
|
|
474
|
+
|
|
475
|
+
def _build_context(
|
|
476
|
+
self,
|
|
477
|
+
tool_name: str,
|
|
478
|
+
content: str,
|
|
479
|
+
upstream_name: str,
|
|
480
|
+
tool_input: Optional[Dict[str, Any]] = None,
|
|
481
|
+
) -> ScreeningContext:
|
|
482
|
+
"""Build a ScreeningContext for proxy tool calls."""
|
|
483
|
+
proxy_config = self._get_proxy_config()
|
|
484
|
+
overrides = proxy_config.get("screening_overrides", {})
|
|
485
|
+
upstream_override = overrides.get(upstream_name, {})
|
|
486
|
+
default_tier = upstream_override.get("tier", "default")
|
|
487
|
+
|
|
488
|
+
return ScreeningContext(
|
|
489
|
+
tool_name=tool_name,
|
|
490
|
+
content=content,
|
|
491
|
+
tier=default_tier,
|
|
492
|
+
working_dir=os.getcwd(),
|
|
493
|
+
source="mcp_proxy",
|
|
494
|
+
client_name=self.config.get("client_name"),
|
|
495
|
+
mcp_server=upstream_name,
|
|
496
|
+
tool_input=tool_input,
|
|
497
|
+
)
|
|
498
|
+
|
|
499
|
+
def _run_screening(self, context: ScreeningContext) -> Dict[str, Any]:
|
|
500
|
+
"""Run the shared screening pipeline."""
|
|
501
|
+
from tweek.mcp.screening import run_mcp_screening
|
|
502
|
+
return run_mcp_screening(context)
|
|
503
|
+
|
|
504
|
+
def _extract_content_for_screening(
|
|
505
|
+
self, tool_name: str, arguments: Dict[str, Any]
|
|
506
|
+
) -> str:
|
|
507
|
+
"""Extract the primary content string from tool arguments for screening."""
|
|
508
|
+
# Try common parameter names that represent the primary action
|
|
509
|
+
for key in ("command", "query", "sql", "code", "content", "path", "url", "body"):
|
|
510
|
+
if key in arguments:
|
|
511
|
+
value = arguments[key]
|
|
512
|
+
if isinstance(value, str):
|
|
513
|
+
return value
|
|
514
|
+
|
|
515
|
+
# Fallback: serialize all arguments
|
|
516
|
+
return json.dumps(arguments, default=str)
|
|
517
|
+
|
|
518
|
+
def _log_event(
|
|
519
|
+
self,
|
|
520
|
+
event_type: str,
|
|
521
|
+
tool_name: str,
|
|
522
|
+
upstream_name: str,
|
|
523
|
+
content: str,
|
|
524
|
+
screening_result: Dict[str, Any],
|
|
525
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
526
|
+
correlation_id: Optional[str] = None,
|
|
527
|
+
):
|
|
528
|
+
"""Log a screening event to the security logger."""
|
|
529
|
+
try:
|
|
530
|
+
from tweek.logging.security_log import SecurityLogger, SecurityEvent, EventType, get_logger
|
|
531
|
+
|
|
532
|
+
event_map = {
|
|
533
|
+
"allowed": EventType.ALLOWED,
|
|
534
|
+
"blocked": EventType.BLOCKED,
|
|
535
|
+
"user_prompted": EventType.USER_PROMPTED,
|
|
536
|
+
"user_approved": EventType.USER_APPROVED,
|
|
537
|
+
"user_denied": EventType.USER_DENIED,
|
|
538
|
+
}
|
|
539
|
+
evt = event_map.get(event_type, EventType.TOOL_INVOKED)
|
|
540
|
+
|
|
541
|
+
sec_logger = get_logger()
|
|
542
|
+
event_metadata = {
|
|
543
|
+
"upstream_server": upstream_name,
|
|
544
|
+
"findings_count": len(screening_result.get("findings", [])),
|
|
545
|
+
}
|
|
546
|
+
if metadata:
|
|
547
|
+
event_metadata.update(metadata)
|
|
548
|
+
|
|
549
|
+
sec_logger.log(SecurityEvent(
|
|
550
|
+
event_type=evt,
|
|
551
|
+
tool_name=tool_name,
|
|
552
|
+
command=content,
|
|
553
|
+
tier=screening_result.get("tier"),
|
|
554
|
+
decision=event_type,
|
|
555
|
+
decision_reason=screening_result.get("reason"),
|
|
556
|
+
metadata=event_metadata,
|
|
557
|
+
correlation_id=correlation_id,
|
|
558
|
+
source="mcp_proxy",
|
|
559
|
+
))
|
|
560
|
+
except Exception as e:
|
|
561
|
+
logger.debug(f"Failed to log security event: {e}")
|
|
562
|
+
|
|
563
|
+
async def start(self) -> None:
|
|
564
|
+
"""
|
|
565
|
+
Start the proxy: connect to upstreams and serve on stdio.
|
|
566
|
+
"""
|
|
567
|
+
# Build upstream connections from config
|
|
568
|
+
self.upstreams = self._build_upstreams()
|
|
569
|
+
|
|
570
|
+
if not self.upstreams:
|
|
571
|
+
logger.warning(
|
|
572
|
+
"No upstream servers configured. "
|
|
573
|
+
"Add 'mcp.proxy.upstreams' to your config."
|
|
574
|
+
)
|
|
575
|
+
print(
|
|
576
|
+
"Warning: No upstream MCP servers configured.\n"
|
|
577
|
+
"Configure upstreams in ~/.tweek/config.yaml under mcp.proxy.upstreams",
|
|
578
|
+
file=sys.stderr,
|
|
579
|
+
)
|
|
580
|
+
|
|
581
|
+
async with self._exit_stack:
|
|
582
|
+
# Connect to all upstreams
|
|
583
|
+
for name, upstream in self.upstreams.items():
|
|
584
|
+
await upstream.connect(self._exit_stack)
|
|
585
|
+
|
|
586
|
+
connected = sum(1 for u in self.upstreams.values() if u.connected)
|
|
587
|
+
total_tools = sum(
|
|
588
|
+
len(u.tools) for u in self.upstreams.values() if u.connected
|
|
589
|
+
)
|
|
590
|
+
logger.info(
|
|
591
|
+
f"Proxy ready: {connected}/{len(self.upstreams)} upstreams, "
|
|
592
|
+
f"{total_tools} tools available"
|
|
593
|
+
)
|
|
594
|
+
|
|
595
|
+
# Check if approval daemon is reachable
|
|
596
|
+
try:
|
|
597
|
+
queue = self._get_approval_queue()
|
|
598
|
+
pending = queue.count_pending()
|
|
599
|
+
if pending > 0:
|
|
600
|
+
print(
|
|
601
|
+
f"Note: {pending} pending approval request(s) in queue.",
|
|
602
|
+
file=sys.stderr,
|
|
603
|
+
)
|
|
604
|
+
except Exception:
|
|
605
|
+
pass
|
|
606
|
+
|
|
607
|
+
print(
|
|
608
|
+
f"Tweek MCP Proxy ready ({connected} upstream(s), {total_tools} tools)",
|
|
609
|
+
file=sys.stderr,
|
|
610
|
+
)
|
|
611
|
+
|
|
612
|
+
# Start background expiry loop
|
|
613
|
+
self._expiry_task = asyncio.create_task(self._run_expiry_loop())
|
|
614
|
+
|
|
615
|
+
try:
|
|
616
|
+
# Serve on stdio
|
|
617
|
+
async with stdio_server() as (read_stream, write_stream):
|
|
618
|
+
await self.server.run(
|
|
619
|
+
read_stream,
|
|
620
|
+
write_stream,
|
|
621
|
+
self.server.create_initialization_options(),
|
|
622
|
+
)
|
|
623
|
+
finally:
|
|
624
|
+
# Clean up
|
|
625
|
+
if self._expiry_task:
|
|
626
|
+
self._expiry_task.cancel()
|
|
627
|
+
try:
|
|
628
|
+
await self._expiry_task
|
|
629
|
+
except asyncio.CancelledError:
|
|
630
|
+
pass
|
|
631
|
+
|
|
632
|
+
# Expire all pending requests on shutdown
|
|
633
|
+
try:
|
|
634
|
+
queue = self._get_approval_queue()
|
|
635
|
+
expired = queue.expire_stale()
|
|
636
|
+
if expired:
|
|
637
|
+
logger.info(f"Expired {expired} pending requests on shutdown")
|
|
638
|
+
except Exception:
|
|
639
|
+
pass
|
|
640
|
+
|
|
641
|
+
async def _run_expiry_loop(self):
|
|
642
|
+
"""Background task to expire stale approval requests."""
|
|
643
|
+
while True:
|
|
644
|
+
try:
|
|
645
|
+
await asyncio.sleep(EXPIRY_LOOP_INTERVAL)
|
|
646
|
+
queue = self._get_approval_queue()
|
|
647
|
+
queue.expire_stale()
|
|
648
|
+
except asyncio.CancelledError:
|
|
649
|
+
break
|
|
650
|
+
except Exception as e:
|
|
651
|
+
logger.debug(f"Expiry loop error: {e}")
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
async def run_proxy(config: Optional[Dict[str, Any]] = None):
|
|
655
|
+
"""
|
|
656
|
+
Run the Tweek MCP proxy on stdio transport.
|
|
657
|
+
|
|
658
|
+
This is the main entry point for 'tweek mcp proxy'.
|
|
659
|
+
"""
|
|
660
|
+
_check_mcp_available()
|
|
661
|
+
proxy = TweekMCPProxy(config=config)
|
|
662
|
+
await proxy.start()
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
def create_proxy(config: Optional[Dict[str, Any]] = None) -> "TweekMCPProxy":
|
|
666
|
+
"""Create a TweekMCPProxy instance for programmatic use."""
|
|
667
|
+
return TweekMCPProxy(config=config)
|