tweek 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tweek/__init__.py +2 -2
- tweek/audit.py +2 -2
- tweek/cli.py +78 -6605
- tweek/cli_config.py +643 -0
- tweek/cli_configure.py +413 -0
- tweek/cli_core.py +718 -0
- tweek/cli_dry_run.py +390 -0
- tweek/cli_helpers.py +316 -0
- tweek/cli_install.py +1666 -0
- tweek/cli_logs.py +301 -0
- tweek/cli_mcp.py +148 -0
- tweek/cli_memory.py +343 -0
- tweek/cli_plugins.py +748 -0
- tweek/cli_protect.py +564 -0
- tweek/cli_proxy.py +405 -0
- tweek/cli_security.py +236 -0
- tweek/cli_skills.py +289 -0
- tweek/cli_uninstall.py +551 -0
- tweek/cli_vault.py +313 -0
- tweek/config/allowed_dirs.yaml +16 -17
- tweek/config/families.yaml +4 -1
- tweek/config/manager.py +17 -0
- tweek/config/patterns.yaml +29 -5
- tweek/config/templates/config.yaml.template +212 -0
- tweek/config/templates/env.template +45 -0
- tweek/config/templates/overrides.yaml.template +121 -0
- tweek/config/templates/tweek.yaml.template +20 -0
- tweek/config/templates.py +136 -0
- tweek/config/tiers.yaml +5 -4
- tweek/diagnostics.py +112 -32
- tweek/hooks/overrides.py +4 -0
- tweek/hooks/post_tool_use.py +46 -1
- tweek/hooks/pre_tool_use.py +149 -49
- tweek/integrations/openclaw.py +84 -0
- tweek/licensing.py +1 -1
- tweek/mcp/__init__.py +7 -9
- tweek/mcp/clients/chatgpt.py +2 -2
- tweek/mcp/clients/claude_desktop.py +2 -2
- tweek/mcp/clients/gemini.py +2 -2
- tweek/mcp/proxy.py +165 -1
- tweek/memory/provenance.py +438 -0
- tweek/memory/queries.py +2 -0
- tweek/memory/safety.py +23 -4
- tweek/memory/schemas.py +1 -0
- tweek/memory/store.py +101 -71
- tweek/plugins/screening/heuristic_scorer.py +1 -1
- tweek/security/integrity.py +77 -0
- tweek/security/llm_reviewer.py +162 -68
- tweek/security/local_reviewer.py +44 -2
- tweek/security/model_registry.py +73 -7
- tweek/skill_template/overrides-reference.md +1 -1
- tweek/skills/context.py +221 -0
- tweek/skills/scanner.py +2 -2
- {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/METADATA +8 -7
- {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/RECORD +60 -38
- tweek/mcp/server.py +0 -320
- {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/WHEEL +0 -0
- {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/entry_points.txt +0 -0
- {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/licenses/LICENSE +0 -0
- {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/licenses/NOTICE +0 -0
- {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/top_level.txt +0 -0
tweek/mcp/proxy.py
CHANGED
|
@@ -28,6 +28,9 @@ from typing import Any, Dict, List, Optional, Tuple
|
|
|
28
28
|
|
|
29
29
|
logger = logging.getLogger(__name__)
|
|
30
30
|
|
|
31
|
+
# Version for MCP server identification
|
|
32
|
+
MCP_SERVER_VERSION = "0.2.0"
|
|
33
|
+
|
|
31
34
|
try:
|
|
32
35
|
from mcp.client.session import ClientSession
|
|
33
36
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
@@ -264,8 +267,56 @@ class TweekMCPProxy:
|
|
|
264
267
|
|
|
265
268
|
@self.server.list_tools()
|
|
266
269
|
async def list_tools() -> list[Tool]:
|
|
267
|
-
"""Return merged tools from all connected upstreams."""
|
|
270
|
+
"""Return built-in tools plus merged tools from all connected upstreams."""
|
|
268
271
|
merged = []
|
|
272
|
+
|
|
273
|
+
# Built-in tools (vault + status)
|
|
274
|
+
tool_configs = self.config.get("mcp", {}).get("proxy", {}).get("tools", {})
|
|
275
|
+
|
|
276
|
+
if tool_configs.get("vault", True):
|
|
277
|
+
merged.append(Tool(
|
|
278
|
+
name="tweek_vault",
|
|
279
|
+
description=(
|
|
280
|
+
"Retrieve a credential from Tweek's secure vault. "
|
|
281
|
+
"Credentials are stored in the system keychain, not in .env files. "
|
|
282
|
+
"Use this instead of reading .env files or hardcoding secrets."
|
|
283
|
+
),
|
|
284
|
+
inputSchema={
|
|
285
|
+
"type": "object",
|
|
286
|
+
"properties": {
|
|
287
|
+
"skill": {
|
|
288
|
+
"type": "string",
|
|
289
|
+
"description": "Skill namespace for the credential",
|
|
290
|
+
},
|
|
291
|
+
"key": {
|
|
292
|
+
"type": "string",
|
|
293
|
+
"description": "Credential key name",
|
|
294
|
+
},
|
|
295
|
+
},
|
|
296
|
+
"required": ["skill", "key"],
|
|
297
|
+
},
|
|
298
|
+
))
|
|
299
|
+
|
|
300
|
+
if tool_configs.get("status", True):
|
|
301
|
+
merged.append(Tool(
|
|
302
|
+
name="tweek_status",
|
|
303
|
+
description=(
|
|
304
|
+
"Show Tweek security status including active plugins, "
|
|
305
|
+
"recent activity, threat summary, and proxy statistics."
|
|
306
|
+
),
|
|
307
|
+
inputSchema={
|
|
308
|
+
"type": "object",
|
|
309
|
+
"properties": {
|
|
310
|
+
"detail": {
|
|
311
|
+
"type": "string",
|
|
312
|
+
"enum": ["summary", "plugins", "activity", "threats"],
|
|
313
|
+
"description": "Level of detail (default: summary)",
|
|
314
|
+
},
|
|
315
|
+
},
|
|
316
|
+
},
|
|
317
|
+
))
|
|
318
|
+
|
|
319
|
+
# Upstream tools
|
|
269
320
|
for upstream_name, upstream in self.upstreams.items():
|
|
270
321
|
if not upstream.connected:
|
|
271
322
|
continue
|
|
@@ -279,6 +330,24 @@ class TweekMCPProxy:
|
|
|
279
330
|
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
|
280
331
|
"""Handle tool calls with security screening and approval."""
|
|
281
332
|
self._request_count += 1
|
|
333
|
+
|
|
334
|
+
# Handle built-in tools (vault, status)
|
|
335
|
+
builtin_handlers = {
|
|
336
|
+
"tweek_vault": self._handle_vault,
|
|
337
|
+
"tweek_status": self._handle_status,
|
|
338
|
+
}
|
|
339
|
+
handler = builtin_handlers.get(name)
|
|
340
|
+
if handler is not None:
|
|
341
|
+
try:
|
|
342
|
+
result = await handler(arguments)
|
|
343
|
+
return [TextContent(type="text", text=result)]
|
|
344
|
+
except Exception as e:
|
|
345
|
+
logger.error(f"Tool {name} failed: {e}")
|
|
346
|
+
return [TextContent(
|
|
347
|
+
type="text",
|
|
348
|
+
text=json.dumps({"error": str(e), "tool": name}),
|
|
349
|
+
)]
|
|
350
|
+
|
|
282
351
|
return await self._handle_call_tool(name, arguments)
|
|
283
352
|
|
|
284
353
|
async def _handle_call_tool(
|
|
@@ -578,6 +647,101 @@ class TweekMCPProxy:
|
|
|
578
647
|
except Exception as e:
|
|
579
648
|
logger.debug(f"Failed to log security event: {e}")
|
|
580
649
|
|
|
650
|
+
async def _handle_vault(self, arguments: Dict[str, Any]) -> str:
    """Serve the tweek_vault built-in tool.

    Screens the request like any other tool call; on approval, looks
    the credential up in the cross-platform vault and returns a JSON
    payload (or a JSON error that leaks no internal detail).
    """
    skill = arguments.get("skill", "")
    key = arguments.get("key", "")

    # Vault reads cross a trust boundary, so screen them first.
    screening = self._run_screening(
        self._build_context(
            tool_name="Vault",
            content=f"vault:{skill}/{key}",
            upstream_name="_builtin",
            tool_input=arguments,
        )
    )
    if screening.get("blocked"):
        return json.dumps({
            "blocked": True,
            "reason": screening.get("reason"),
        })

    try:
        from tweek.vault.cross_platform import CrossPlatformVault

        secret = CrossPlatformVault().get(skill, key)
        if secret is None:
            return json.dumps({
                "error": f"Credential not found: {skill}/{key}",
                "available": False,
            })
        return json.dumps({
            "value": secret,
            "skill": skill,
            "key": key,
        })
    except Exception as e:
        # Don't leak internal details across trust boundary
        logger.error(f"Vault operation failed: {e}")
        return json.dumps({"error": "Vault operation failed"})
|
|
692
|
+
|
|
693
|
+
async def _handle_status(self, arguments: Dict[str, Any]) -> str:
    """Handle the tweek_status built-in tool call.

    Builds a JSON status report. The ``detail`` argument selects the
    sections included: "summary" (default) adds both plugin stats and
    recent activity; "plugins" and "activity" add only their section;
    "threats" yields just the proxy counters.

    Every sub-section is best-effort: a failure inside one section
    degrades to a placeholder value instead of failing the whole call.
    """
    detail = arguments.get("detail", "summary")

    try:
        status = {
            "version": MCP_SERVER_VERSION,
            "source": "mcp_proxy",
            "mode": "proxy",
            "proxy_requests": self._request_count,
            "proxy_blocked": self._blocked_count,
            "proxy_approvals": self._approval_count,
        }

        if detail in ("summary", "plugins"):
            try:
                from tweek.plugins import get_registry
                registry = get_registry()
                status["plugins"] = registry.get_stats()
            except ImportError:
                status["plugins"] = {"error": "Plugin system not available"}

        if detail in ("summary", "activity"):
            try:
                from tweek.logging.security_log import get_logger as get_sec_logger
                sec_logger = get_sec_logger()
                recent = sec_logger.get_recent(limit=10)
                status["recent_activity"] = [
                    {
                        "timestamp": str(e.timestamp),
                        "event_type": e.event_type.value,
                        "tool": e.tool_name,
                        "decision": e.decision,
                    }
                    for e in recent
                ] if recent else []
            # BUGFIX: was `except (ImportError, Exception)` — ImportError
            # is a subclass of Exception, so the tuple was redundant.
            # Any failure (import, logger, malformed events) degrades to
            # an empty activity list.
            except Exception:
                status["recent_activity"] = []

        # Include approval queue stats if available
        try:
            queue = self._get_approval_queue()
            status["approval_queue"] = queue.get_stats()
        except Exception:
            pass

        return json.dumps(status, indent=2)

    except Exception as e:
        return json.dumps({"error": str(e)})
|
|
744
|
+
|
|
581
745
|
async def start(self) -> None:
|
|
582
746
|
"""
|
|
583
747
|
Start the proxy: connect to upstreams and serve on stdio.
|
|
@@ -0,0 +1,438 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tweek Session Provenance & Taint Tracking
|
|
3
|
+
|
|
4
|
+
Tracks the "trust lineage" of commands within a session.
|
|
5
|
+
When external content is ingested (Read, WebFetch, WebSearch)
|
|
6
|
+
and that content contains suspicious patterns, the session
|
|
7
|
+
becomes "tainted" — subsequent commands receive heightened scrutiny.
|
|
8
|
+
|
|
9
|
+
When a session is "clean" (no external content or all content
|
|
10
|
+
from trusted sources), enforcement thresholds are relaxed to
|
|
11
|
+
reduce false positives.
|
|
12
|
+
|
|
13
|
+
Taint levels: clean → low → medium → high → critical
|
|
14
|
+
Taint decays by one level every DECAY_INTERVAL tool calls
|
|
15
|
+
without new external content or pattern matches.
|
|
16
|
+
|
|
17
|
+
Storage: SQLite persistent in memory.db (same DB as pattern decisions).
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
from __future__ import annotations
|
|
21
|
+
|
|
22
|
+
import sqlite3
|
|
23
|
+
from datetime import datetime
|
|
24
|
+
from pathlib import Path
|
|
25
|
+
from typing import Optional, Dict, Any
|
|
26
|
+
|
|
27
|
+
# =========================================================================
|
|
28
|
+
# Constants
|
|
29
|
+
# =========================================================================
|
|
30
|
+
|
|
31
|
+
# Taint levels ordered least → most severe; tuple index is the rank.
TAINT_LEVELS = ("clean", "low", "medium", "high", "critical")
TAINT_RANK = dict(zip(TAINT_LEVELS, range(len(TAINT_LEVELS))))

# Number of tool calls (without new taint) between decay steps.
DECAY_INTERVAL = 5

# Tools that ingest content from outside the session (taint sources).
EXTERNAL_SOURCE_TOOLS = frozenset({"Read", "WebFetch", "WebSearch", "Grep"})

# Tools that act on the user's environment (user-context actions).
ACTION_TOOLS = frozenset({"Bash", "Write", "Edit", "NotebookEdit"})
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# =========================================================================
|
|
46
|
+
# Schema
|
|
47
|
+
# =========================================================================
|
|
48
|
+
|
|
49
|
+
PROVENANCE_SCHEMA = """
|
|
50
|
+
CREATE TABLE IF NOT EXISTS session_taint (
|
|
51
|
+
session_id TEXT PRIMARY KEY,
|
|
52
|
+
taint_level TEXT NOT NULL DEFAULT 'clean',
|
|
53
|
+
last_taint_source TEXT,
|
|
54
|
+
last_taint_reason TEXT,
|
|
55
|
+
turns_since_taint INTEGER NOT NULL DEFAULT 0,
|
|
56
|
+
total_tool_calls INTEGER NOT NULL DEFAULT 0,
|
|
57
|
+
total_external_ingests INTEGER NOT NULL DEFAULT 0,
|
|
58
|
+
total_taint_escalations INTEGER NOT NULL DEFAULT 0,
|
|
59
|
+
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
|
60
|
+
updated_at TEXT NOT NULL DEFAULT (datetime('now'))
|
|
61
|
+
);
|
|
62
|
+
|
|
63
|
+
CREATE INDEX IF NOT EXISTS idx_st_taint_level
|
|
64
|
+
ON session_taint(taint_level);
|
|
65
|
+
"""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# =========================================================================
|
|
69
|
+
# Taint Level Operations
|
|
70
|
+
# =========================================================================
|
|
71
|
+
|
|
72
|
+
def escalate_taint(current: str, to_level: str) -> str:
    """Return the more severe of *current* and *to_level*.

    One-way ratchet: this never lowers the current taint level.
    Unrecognized level names rank as 0 (same as "clean").
    """
    ranking = {"clean": 0, "low": 1, "medium": 2, "high": 3, "critical": 4}
    if ranking.get(to_level, 0) > ranking.get(current, 0):
        return to_level
    return current
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def decay_taint(current: str) -> str:
    """Step *current* one taint level toward 'clean'.

    'clean' itself — and any unrecognized level — yields 'clean'.
    """
    ladder = ("clean", "low", "medium", "high", "critical")
    if current not in ladder:
        return "clean"
    position = ladder.index(current)
    return ladder[position - 1] if position > 0 else "clean"
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def severity_to_taint(pattern_severity: str) -> str:
    """Map a pattern severity onto the corresponding taint level.

    The four known severities (critical/high/medium/low) map
    one-to-one; anything unrecognized is treated as "medium".
    """
    if pattern_severity in ("critical", "high", "medium", "low"):
        return pattern_severity
    return "medium"
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# =========================================================================
|
|
113
|
+
# Session Taint Store
|
|
114
|
+
# =========================================================================
|
|
115
|
+
|
|
116
|
+
class SessionTaintStore:
    """SQLite-backed session taint tracking.

    Uses the same database as MemoryStore (memory.db) but manages its
    own table (``session_taint``). A single connection is opened lazily
    and reused for the lifetime of the store.
    """

    def __init__(self, db_path: Optional[Path] = None):
        """Initialize the store and ensure the schema exists.

        Args:
            db_path: Database file to use. When omitted, the shared
                memory DB path is used; it is imported lazily so that
                passing an explicit path never requires the default
                location's module to be importable.
        """
        if db_path is None:
            # Lazy import: only needed when falling back to the default.
            from tweek.memory.store import GLOBAL_MEMORY_PATH
            db_path = GLOBAL_MEMORY_PATH
        self.db_path = db_path
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn: Optional[sqlite3.Connection] = None
        self._ensure_schema()

    def _get_connection(self) -> sqlite3.Connection:
        """Return the cached connection, creating it on first use."""
        if self._conn is None:
            self._conn = sqlite3.connect(
                str(self.db_path),
                timeout=5.0,
                # isolation_level=None -> autocommit; every statement
                # commits immediately, so no explicit commit() calls.
                isolation_level=None,
            )
            self._conn.row_factory = sqlite3.Row
            # WAL allows concurrent readers while hooks write.
            self._conn.execute("PRAGMA journal_mode=WAL")
        return self._conn

    def close(self):
        """Close the connection if open (idempotent)."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None

    def _ensure_schema(self):
        """Create the session_taint table and index if missing."""
        self._get_connection().executescript(PROVENANCE_SCHEMA)

    def get_session_taint(self, session_id: str) -> Dict[str, Any]:
        """Get current taint state for a session.

        Returns a dict with taint_level, turns_since_taint, counters,
        and last-taint metadata. Unknown sessions get a "clean" default.
        """
        conn = self._get_connection()
        row = conn.execute(
            "SELECT * FROM session_taint WHERE session_id = ?",
            (session_id,),
        ).fetchone()

        if row is None:
            return {
                "taint_level": "clean",
                "turns_since_taint": 0,
                "total_tool_calls": 0,
                "total_external_ingests": 0,
                "total_taint_escalations": 0,
                "last_taint_source": None,
                "last_taint_reason": None,
            }

        return dict(row)

    def record_tool_call(self, session_id: str, tool_name: str) -> Dict[str, Any]:
        """Record a tool call and apply taint decay if applicable.

        Called from pre_tool_use for every tool invocation. After
        DECAY_INTERVAL consecutive calls without new taint, the level
        decays one step and the clean-turn counter resets.

        Args:
            session_id: Session being tracked.
            tool_name: Invoking tool; currently unused by the decay
                logic, kept for interface stability.

        Returns:
            The updated taint state dict.
        """
        conn = self._get_connection()
        state = self.get_session_taint(session_id)

        new_total = state["total_tool_calls"] + 1
        new_turns_since_taint = state["turns_since_taint"] + 1

        # Apply decay if enough clean turns have passed
        current_taint = state["taint_level"]
        if (current_taint != "clean"
                and new_turns_since_taint >= DECAY_INTERVAL):
            current_taint = decay_taint(current_taint)
            new_turns_since_taint = 0  # Reset counter after decay

        # Upsert: only call/decay-related columns change on conflict.
        conn.execute("""
            INSERT INTO session_taint
                (session_id, taint_level, turns_since_taint,
                 total_tool_calls, total_external_ingests,
                 total_taint_escalations,
                 last_taint_source, last_taint_reason,
                 updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
            ON CONFLICT(session_id) DO UPDATE SET
                taint_level = excluded.taint_level,
                turns_since_taint = excluded.turns_since_taint,
                total_tool_calls = excluded.total_tool_calls,
                updated_at = datetime('now')
        """, (
            session_id, current_taint, new_turns_since_taint,
            new_total, state["total_external_ingests"],
            state["total_taint_escalations"],
            state["last_taint_source"], state["last_taint_reason"],
        ))

        state["taint_level"] = current_taint
        state["turns_since_taint"] = new_turns_since_taint
        state["total_tool_calls"] = new_total
        return state

    def record_taint(
        self,
        session_id: str,
        taint_level: str,
        source: str,
        reason: str,
    ) -> Dict[str, Any]:
        """Escalate session taint after finding suspicious content.

        Called from post_tool_use when patterns are found in ingested
        content. Taint only escalates, never downgrades, via this
        method; the clean-turn counter resets to 0.

        Returns:
            The updated taint state dict.
        """
        conn = self._get_connection()
        state = self.get_session_taint(session_id)

        new_taint = escalate_taint(state["taint_level"], taint_level)
        new_escalations = state["total_taint_escalations"]
        if new_taint != state["taint_level"]:
            new_escalations += 1

        new_ingests = state["total_external_ingests"] + 1

        conn.execute("""
            INSERT INTO session_taint
                (session_id, taint_level, turns_since_taint,
                 total_tool_calls, total_external_ingests,
                 total_taint_escalations,
                 last_taint_source, last_taint_reason,
                 updated_at)
            VALUES (?, ?, 0, ?, ?, ?, ?, ?, datetime('now'))
            ON CONFLICT(session_id) DO UPDATE SET
                taint_level = excluded.taint_level,
                turns_since_taint = 0,
                total_tool_calls = excluded.total_tool_calls,
                total_external_ingests = excluded.total_external_ingests,
                total_taint_escalations = excluded.total_taint_escalations,
                last_taint_source = excluded.last_taint_source,
                last_taint_reason = excluded.last_taint_reason,
                updated_at = datetime('now')
        """, (
            session_id, new_taint,
            state["total_tool_calls"], new_ingests,
            new_escalations, source, reason,
        ))

        state["taint_level"] = new_taint
        state["turns_since_taint"] = 0
        state["total_external_ingests"] = new_ingests
        state["total_taint_escalations"] = new_escalations
        state["last_taint_source"] = source
        state["last_taint_reason"] = reason
        return state

    def record_external_ingest(self, session_id: str, source: str):
        """Record an external content ingest without escalating taint.

        Called when Read/WebFetch returns content that passes screening.
        Only the ingest counter is updated on an existing row; the
        taint level is untouched.
        """
        conn = self._get_connection()
        state = self.get_session_taint(session_id)

        conn.execute("""
            INSERT INTO session_taint
                (session_id, taint_level, turns_since_taint,
                 total_tool_calls, total_external_ingests,
                 total_taint_escalations,
                 last_taint_source, last_taint_reason,
                 updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
            ON CONFLICT(session_id) DO UPDATE SET
                total_external_ingests = excluded.total_external_ingests,
                updated_at = datetime('now')
        """, (
            session_id, state["taint_level"],
            state["turns_since_taint"],
            state["total_tool_calls"],
            state["total_external_ingests"] + 1,
            state["total_taint_escalations"],
            source, state["last_taint_reason"],
        ))

    def clear_session(self, session_id: str):
        """Remove taint tracking for a session."""
        self._get_connection().execute(
            "DELETE FROM session_taint WHERE session_id = ?",
            (session_id,),
        )

    def get_stats(self) -> Dict[str, Any]:
        """Get aggregate stats across all tracked sessions.

        BUGFIX: SQLite's SUM() yields NULL over an empty table, which
        previously surfaced as None values; COALESCE pins them to 0.
        """
        row = self._get_connection().execute("""
            SELECT
                COUNT(*) as total_sessions,
                COALESCE(SUM(CASE WHEN taint_level = 'clean' THEN 1 ELSE 0 END), 0)
                    as clean_sessions,
                COALESCE(SUM(CASE WHEN taint_level != 'clean' THEN 1 ELSE 0 END), 0)
                    as tainted_sessions,
                COALESCE(SUM(total_taint_escalations), 0) as total_escalations,
                COALESCE(SUM(total_external_ingests), 0) as total_ingests
            FROM session_taint
        """).fetchone()

        # An aggregate query without GROUP BY always returns one row.
        return dict(row)
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
# =========================================================================
|
|
334
|
+
# Enforcement Adjustment
|
|
335
|
+
# =========================================================================
|
|
336
|
+
|
|
337
|
+
# The "balanced" preset enforcement matrix for CLEAN sessions.
# Compared to the "cautious" default, heuristic/contextual findings are
# logged instead of prompting. Keyed by severity, then confidence; a
# (severity, confidence) pair missing here keeps its base decision
# ("low" severity is intentionally absent).
BALANCED_CLEAN_OVERRIDES = {
    "critical": {"contextual": "log"},  # Don't prompt on broad contextual
    "high": {"heuristic": "log", "contextual": "log"},  # Don't prompt on heuristic
    "medium": {"deterministic": "log", "heuristic": "log", "contextual": "log"},
}
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def adjust_enforcement_for_taint(
    base_decision: str,
    severity: str,
    confidence: str,
    taint_level: str,
) -> str:
    """Tune an enforcement decision according to session taint.

    Clean sessions get the relaxed "balanced" overrides for lower
    confidence findings; tainted sessions keep the base decision or
    escalate logged findings to prompts.

    Args:
        base_decision: The decision from EnforcementPolicy.resolve()
        severity: Pattern severity (critical/high/medium/low)
        confidence: Pattern confidence (deterministic/heuristic/contextual)
        taint_level: Current session taint level

    Returns:
        Adjusted decision: "deny", "ask", or "log"
    """
    # "deny" is hardcoded by policy and is never relaxed here. This also
    # covers critical+deterministic (always deny from policy). Break-glass
    # may downgrade deny->ask before we see it; we respect that override.
    if base_decision == "deny":
        return "deny"

    # Clean session: apply the balanced override table (if any entry
    # exists for this severity/confidence pair).
    if taint_level == "clean":
        relaxed = BALANCED_CLEAN_OVERRIDES.get(severity, {}).get(confidence)
        return base_decision if relaxed is None else relaxed

    rank = TAINT_RANK.get(taint_level, 0)

    # Medium+ taint: promote logged high-confidence, high-severity
    # findings to a prompt.
    if (rank >= TAINT_RANK["medium"]
            and base_decision == "log"
            and confidence in ("heuristic", "deterministic")
            and severity in ("critical", "high")):
        return "ask"

    # Critical taint: promote nearly any logged finding to a prompt.
    if (taint_level == "critical"
            and base_decision == "log"
            and severity in ("critical", "high", "medium")):
        return "ask"

    return base_decision
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def should_skip_llm_for_clean_session(
    taint_level: str,
    tool_tier: str,
) -> bool:
    """Determine if LLM review can be skipped for this call.

    A clean session contains no externally ingested content that could
    carry an injection, so default-tier tools (Read, Edit, NotebookEdit)
    skip the LLM pass. Risky and dangerous tiers are always reviewed.

    Args:
        taint_level: Current session taint level
        tool_tier: The tier of the tool (safe/default/risky/dangerous)

    Returns:
        True if LLM review can be skipped
    """
    return taint_level == "clean" and tool_tier == "default"
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
# =========================================================================
|
|
419
|
+
# Singleton
|
|
420
|
+
# =========================================================================
|
|
421
|
+
|
|
422
|
+
# Process-wide singleton, created lazily by get_taint_store().
_taint_store: Optional[SessionTaintStore] = None


def get_taint_store(db_path: Optional[Path] = None) -> SessionTaintStore:
    """Get the singleton SessionTaintStore instance.

    Args:
        db_path: Optional database path. Only honored on the call that
            first creates the store; ignored on subsequent calls.
    """
    global _taint_store
    if _taint_store is None:
        _taint_store = SessionTaintStore(db_path)
    return _taint_store


def reset_taint_store():
    """Reset the singleton (for testing).

    Closes the store's connection before discarding it so WAL files are
    released.
    """
    global _taint_store
    if _taint_store is not None:
        _taint_store.close()
        _taint_store = None
|
tweek/memory/queries.py
CHANGED
tweek/memory/safety.py
CHANGED
|
@@ -31,8 +31,17 @@ MAX_RELAXATION = {
|
|
|
31
31
|
"allow": "allow", # Already at minimum
|
|
32
32
|
}
|
|
33
33
|
|
|
34
|
-
#
|
|
35
|
-
|
|
34
|
+
# Context-scoped decision thresholds: narrower context = fewer decisions needed.
|
|
35
|
+
# The system tries scopes narrowest-first and returns the first match.
|
|
36
|
+
# Global (pattern-only) is intentionally absent — too broad to be safe.
|
|
37
|
+
SCOPED_THRESHOLDS = {
|
|
38
|
+
"exact": 1, # pattern + tool + path + project
|
|
39
|
+
"tool_project": 3, # pattern + tool + project
|
|
40
|
+
"path": 5, # pattern + path_prefix
|
|
41
|
+
}
|
|
42
|
+
|
|
43
|
+
# Minimum weighted decisions (backward compat — smallest scope threshold)
|
|
44
|
+
MIN_DECISION_THRESHOLD = SCOPED_THRESHOLDS["exact"]
|
|
36
45
|
|
|
37
46
|
# Minimum approval ratio to suggest relaxation
|
|
38
47
|
MIN_APPROVAL_RATIO = 0.90 # 90% approval rate
|
|
@@ -115,18 +124,28 @@ def compute_suggested_decision(
|
|
|
115
124
|
total_weighted_decisions: float,
|
|
116
125
|
original_severity: str,
|
|
117
126
|
original_confidence: str,
|
|
127
|
+
min_threshold: Optional[float] = None,
|
|
118
128
|
) -> Optional[str]:
|
|
119
129
|
"""Compute what decision memory would suggest, if any.
|
|
120
130
|
|
|
121
131
|
Returns None if memory has no suggestion (insufficient data or
|
|
122
132
|
pattern is immune).
|
|
133
|
+
|
|
134
|
+
Args:
|
|
135
|
+
min_threshold: Override the minimum weighted-decision threshold.
|
|
136
|
+
Used by scoped queries where narrower context requires fewer
|
|
137
|
+
decisions. Defaults to SCOPED_THRESHOLDS["path"] (broadest
|
|
138
|
+
allowed scope).
|
|
123
139
|
"""
|
|
140
|
+
if min_threshold is None:
|
|
141
|
+
min_threshold = SCOPED_THRESHOLDS["path"]
|
|
142
|
+
|
|
124
143
|
# Immune patterns get no suggestions
|
|
125
144
|
if is_immune_pattern(original_severity, original_confidence):
|
|
126
145
|
return None
|
|
127
146
|
|
|
128
|
-
# Insufficient data
|
|
129
|
-
if total_weighted_decisions <
|
|
147
|
+
# Insufficient data for this scope
|
|
148
|
+
if total_weighted_decisions < min_threshold:
|
|
130
149
|
return None
|
|
131
150
|
|
|
132
151
|
# deny is never relaxed
|