tweek-0.3.1-py3-none-any.whl → tweek-0.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. tweek/__init__.py +2 -2
  2. tweek/audit.py +2 -2
  3. tweek/cli.py +78 -6605
  4. tweek/cli_config.py +643 -0
  5. tweek/cli_configure.py +413 -0
  6. tweek/cli_core.py +718 -0
  7. tweek/cli_dry_run.py +390 -0
  8. tweek/cli_helpers.py +316 -0
  9. tweek/cli_install.py +1666 -0
  10. tweek/cli_logs.py +301 -0
  11. tweek/cli_mcp.py +148 -0
  12. tweek/cli_memory.py +343 -0
  13. tweek/cli_plugins.py +748 -0
  14. tweek/cli_protect.py +564 -0
  15. tweek/cli_proxy.py +405 -0
  16. tweek/cli_security.py +236 -0
  17. tweek/cli_skills.py +289 -0
  18. tweek/cli_uninstall.py +551 -0
  19. tweek/cli_vault.py +313 -0
  20. tweek/config/allowed_dirs.yaml +16 -17
  21. tweek/config/families.yaml +4 -1
  22. tweek/config/manager.py +17 -0
  23. tweek/config/patterns.yaml +29 -5
  24. tweek/config/templates/config.yaml.template +212 -0
  25. tweek/config/templates/env.template +45 -0
  26. tweek/config/templates/overrides.yaml.template +121 -0
  27. tweek/config/templates/tweek.yaml.template +20 -0
  28. tweek/config/templates.py +136 -0
  29. tweek/config/tiers.yaml +5 -4
  30. tweek/diagnostics.py +112 -32
  31. tweek/hooks/overrides.py +4 -0
  32. tweek/hooks/post_tool_use.py +46 -1
  33. tweek/hooks/pre_tool_use.py +149 -49
  34. tweek/integrations/openclaw.py +84 -0
  35. tweek/licensing.py +1 -1
  36. tweek/mcp/__init__.py +7 -9
  37. tweek/mcp/clients/chatgpt.py +2 -2
  38. tweek/mcp/clients/claude_desktop.py +2 -2
  39. tweek/mcp/clients/gemini.py +2 -2
  40. tweek/mcp/proxy.py +165 -1
  41. tweek/memory/provenance.py +438 -0
  42. tweek/memory/queries.py +2 -0
  43. tweek/memory/safety.py +23 -4
  44. tweek/memory/schemas.py +1 -0
  45. tweek/memory/store.py +101 -71
  46. tweek/plugins/screening/heuristic_scorer.py +1 -1
  47. tweek/security/integrity.py +77 -0
  48. tweek/security/llm_reviewer.py +162 -68
  49. tweek/security/local_reviewer.py +44 -2
  50. tweek/security/model_registry.py +73 -7
  51. tweek/skill_template/overrides-reference.md +1 -1
  52. tweek/skills/context.py +221 -0
  53. tweek/skills/scanner.py +2 -2
  54. {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/METADATA +8 -7
  55. {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/RECORD +60 -38
  56. tweek/mcp/server.py +0 -320
  57. {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/WHEEL +0 -0
  58. {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/entry_points.txt +0 -0
  59. {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/licenses/LICENSE +0 -0
  60. {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/licenses/NOTICE +0 -0
  61. {tweek-0.3.1.dist-info → tweek-0.4.0.dist-info}/top_level.txt +0 -0
tweek/mcp/proxy.py CHANGED
@@ -28,6 +28,9 @@ from typing import Any, Dict, List, Optional, Tuple
 
 logger = logging.getLogger(__name__)
 
+# Version for MCP server identification
+MCP_SERVER_VERSION = "0.2.0"
+
 try:
     from mcp.client.session import ClientSession
     from mcp.client.stdio import StdioServerParameters, stdio_client
@@ -264,8 +267,56 @@ class TweekMCPProxy:
 
         @self.server.list_tools()
         async def list_tools() -> list[Tool]:
-            """Return merged tools from all connected upstreams."""
+            """Return built-in tools plus merged tools from all connected upstreams."""
             merged = []
+
+            # Built-in tools (vault + status)
+            tool_configs = self.config.get("mcp", {}).get("proxy", {}).get("tools", {})
+
+            if tool_configs.get("vault", True):
+                merged.append(Tool(
+                    name="tweek_vault",
+                    description=(
+                        "Retrieve a credential from Tweek's secure vault. "
+                        "Credentials are stored in the system keychain, not in .env files. "
+                        "Use this instead of reading .env files or hardcoding secrets."
+                    ),
+                    inputSchema={
+                        "type": "object",
+                        "properties": {
+                            "skill": {
+                                "type": "string",
+                                "description": "Skill namespace for the credential",
+                            },
+                            "key": {
+                                "type": "string",
+                                "description": "Credential key name",
+                            },
+                        },
+                        "required": ["skill", "key"],
+                    },
+                ))
+
+            if tool_configs.get("status", True):
+                merged.append(Tool(
+                    name="tweek_status",
+                    description=(
+                        "Show Tweek security status including active plugins, "
+                        "recent activity, threat summary, and proxy statistics."
+                    ),
+                    inputSchema={
+                        "type": "object",
+                        "properties": {
+                            "detail": {
+                                "type": "string",
+                                "enum": ["summary", "plugins", "activity", "threats"],
+                                "description": "Level of detail (default: summary)",
+                            },
+                        },
+                    },
+                ))
+
+            # Upstream tools
             for upstream_name, upstream in self.upstreams.items():
                 if not upstream.connected:
                     continue
@@ -279,6 +330,24 @@ class TweekMCPProxy:
         async def call_tool(name: str, arguments: dict) -> list[TextContent]:
            """Handle tool calls with security screening and approval."""
            self._request_count += 1
+
+            # Handle built-in tools (vault, status)
+            builtin_handlers = {
+                "tweek_vault": self._handle_vault,
+                "tweek_status": self._handle_status,
+            }
+            handler = builtin_handlers.get(name)
+            if handler is not None:
+                try:
+                    result = await handler(arguments)
+                    return [TextContent(type="text", text=result)]
+                except Exception as e:
+                    logger.error(f"Tool {name} failed: {e}")
+                    return [TextContent(
+                        type="text",
+                        text=json.dumps({"error": str(e), "tool": name}),
+                    )]
+
            return await self._handle_call_tool(name, arguments)
 
     async def _handle_call_tool(
@@ -578,6 +647,101 @@ class TweekMCPProxy:
         except Exception as e:
             logger.debug(f"Failed to log security event: {e}")
 
+    async def _handle_vault(self, arguments: Dict[str, Any]) -> str:
+        """Handle tweek_vault tool call."""
+        skill = arguments.get("skill", "")
+        key = arguments.get("key", "")
+
+        # Screen vault access
+        context = self._build_context(
+            tool_name="Vault",
+            content=f"vault:{skill}/{key}",
+            upstream_name="_builtin",
+            tool_input=arguments,
+        )
+        screening = self._run_screening(context)
+
+        if screening.get("blocked"):
+            return json.dumps({
+                "blocked": True,
+                "reason": screening.get("reason"),
+            })
+
+        try:
+            from tweek.vault.cross_platform import CrossPlatformVault
+
+            vault = CrossPlatformVault()
+            value = vault.get(skill, key)
+
+            if value is None:
+                return json.dumps({
+                    "error": f"Credential not found: {skill}/{key}",
+                    "available": False,
+                })
+
+            return json.dumps({
+                "value": value,
+                "skill": skill,
+                "key": key,
+            })
+
+        except Exception as e:
+            # Don't leak internal details across trust boundary
+            logger.error(f"Vault operation failed: {e}")
+            return json.dumps({"error": "Vault operation failed"})
+
+    async def _handle_status(self, arguments: Dict[str, Any]) -> str:
+        """Handle tweek_status tool call."""
+        detail = arguments.get("detail", "summary")
+
+        try:
+            status = {
+                "version": MCP_SERVER_VERSION,
+                "source": "mcp_proxy",
+                "mode": "proxy",
+                "proxy_requests": self._request_count,
+                "proxy_blocked": self._blocked_count,
+                "proxy_approvals": self._approval_count,
+            }
+
+            if detail in ("summary", "plugins"):
+                try:
+                    from tweek.plugins import get_registry
+                    registry = get_registry()
+                    stats = registry.get_stats()
+                    status["plugins"] = stats
+                except ImportError:
+                    status["plugins"] = {"error": "Plugin system not available"}
+
+            if detail in ("summary", "activity"):
+                try:
+                    from tweek.logging.security_log import get_logger as get_sec_logger
+                    sec_logger = get_sec_logger()
+                    recent = sec_logger.get_recent(limit=10)
+                    status["recent_activity"] = [
+                        {
+                            "timestamp": str(e.timestamp),
+                            "event_type": e.event_type.value,
+                            "tool": e.tool_name,
+                            "decision": e.decision,
+                        }
+                        for e in recent
+                    ] if recent else []
+                except (ImportError, Exception):
+                    status["recent_activity"] = []
+
+            # Include approval queue stats if available
+            try:
+                queue = self._get_approval_queue()
+                status["approval_queue"] = queue.get_stats()
+            except Exception:
+                pass
+
+            return json.dumps(status, indent=2)
+
+        except Exception as e:
+            return json.dumps({"error": str(e)})
+
     async def start(self) -> None:
         """
         Start the proxy: connect to upstreams and serve on stdio.
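Not part of the diff: a minimal client-side sketch of how the new built-in tools could be exercised over stdio. It reuses the same mcp client imports the proxy itself uses above; the launch command ("tweek", ["mcp", "proxy"]) and the credential names are assumptions, not taken from this package.

# Sketch only: calling the proxy's new built-in tools from an MCP client.
# The launch command and example arguments below are assumptions.
import asyncio

from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client


async def main() -> None:
    # Hypothetical command that starts the Tweek MCP proxy on stdio.
    params = StdioServerParameters(command="tweek", args=["mcp", "proxy"])

    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()

            # The built-ins appear alongside upstream tools in list_tools().
            tools = await session.list_tools()
            print([t.name for t in tools.tools])

            # tweek_status takes an optional "detail" argument.
            status = await session.call_tool("tweek_status", {"detail": "summary"})
            print(status.content)

            # tweek_vault requires both "skill" and "key".
            cred = await session.call_tool("tweek_vault", {"skill": "github", "key": "token"})
            print(cred.content)


asyncio.run(main())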
tweek/memory/provenance.py ADDED
@@ -0,0 +1,438 @@
+"""
+Tweek Session Provenance & Taint Tracking
+
+Tracks the "trust lineage" of commands within a session.
+When external content is ingested (Read, WebFetch, WebSearch)
+and that content contains suspicious patterns, the session
+becomes "tainted" — subsequent commands receive heightened scrutiny.
+
+When a session is "clean" (no external content or all content
+from trusted sources), enforcement thresholds are relaxed to
+reduce false positives.
+
+Taint levels: clean → low → medium → high → critical
+Taint decays by one level every DECAY_INTERVAL tool calls
+without new external content or pattern matches.
+
+Storage: SQLite persistent in memory.db (same DB as pattern decisions).
+"""
+
+from __future__ import annotations
+
+import sqlite3
+from datetime import datetime
+from pathlib import Path
+from typing import Optional, Dict, Any
+
+# =========================================================================
+# Constants
+# =========================================================================
+
+# Taint levels in order of severity
+TAINT_LEVELS = ("clean", "low", "medium", "high", "critical")
+TAINT_RANK = {level: i for i, level in enumerate(TAINT_LEVELS)}
+
+# How many tool calls between taint decay steps
+DECAY_INTERVAL = 5
+
+# Tools classified as external content sources
+EXTERNAL_SOURCE_TOOLS = frozenset({"Read", "WebFetch", "WebSearch", "Grep"})
+
+# Tools classified as action tools (user-context)
+ACTION_TOOLS = frozenset({"Bash", "Write", "Edit", "NotebookEdit"})
+
+
+# =========================================================================
+# Schema
+# =========================================================================
+
+PROVENANCE_SCHEMA = """
+CREATE TABLE IF NOT EXISTS session_taint (
+    session_id TEXT PRIMARY KEY,
+    taint_level TEXT NOT NULL DEFAULT 'clean',
+    last_taint_source TEXT,
+    last_taint_reason TEXT,
+    turns_since_taint INTEGER NOT NULL DEFAULT 0,
+    total_tool_calls INTEGER NOT NULL DEFAULT 0,
+    total_external_ingests INTEGER NOT NULL DEFAULT 0,
+    total_taint_escalations INTEGER NOT NULL DEFAULT 0,
+    created_at TEXT NOT NULL DEFAULT (datetime('now')),
+    updated_at TEXT NOT NULL DEFAULT (datetime('now'))
+);
+
+CREATE INDEX IF NOT EXISTS idx_st_taint_level
+    ON session_taint(taint_level);
+"""
+
+
+# =========================================================================
+# Taint Level Operations
+# =========================================================================
+
+def escalate_taint(current: str, to_level: str) -> str:
+    """Escalate taint level (never downgrade via this function).
+
+    Returns the higher of current and to_level.
+    """
+    current_rank = TAINT_RANK.get(current, 0)
+    target_rank = TAINT_RANK.get(to_level, 0)
+    if target_rank > current_rank:
+        return to_level
+    return current
+
+
+def decay_taint(current: str) -> str:
+    """Decay taint level by one step toward 'clean'.
+
+    Returns the next lower taint level.
+    """
+    rank = TAINT_RANK.get(current, 0)
+    if rank <= 0:
+        return "clean"
+    return TAINT_LEVELS[rank - 1]
+
+
+def severity_to_taint(pattern_severity: str) -> str:
+    """Map a pattern severity to a taint level.
+
+    critical pattern → critical taint
+    high pattern → high taint
+    medium pattern → medium taint
+    low pattern → low taint
+    """
+    mapping = {
+        "critical": "critical",
+        "high": "high",
+        "medium": "medium",
+        "low": "low",
+    }
+    return mapping.get(pattern_severity, "medium")
+
+
+# =========================================================================
+# Session Taint Store
+# =========================================================================
+
+class SessionTaintStore:
+    """SQLite-backed session taint tracking.
+
+    Uses the same database as MemoryStore (memory.db) but manages
+    its own table (session_taint).
+    """
+
+    def __init__(self, db_path: Optional[Path] = None):
+        from tweek.memory.store import GLOBAL_MEMORY_PATH
+        self.db_path = db_path or GLOBAL_MEMORY_PATH
+        self.db_path.parent.mkdir(parents=True, exist_ok=True)
+        self._conn: Optional[sqlite3.Connection] = None
+        self._ensure_schema()
+
+    def _get_connection(self) -> sqlite3.Connection:
+        if self._conn is None:
+            self._conn = sqlite3.connect(
+                str(self.db_path),
+                timeout=5.0,
+                isolation_level=None,
+            )
+            self._conn.row_factory = sqlite3.Row
+            self._conn.execute("PRAGMA journal_mode=WAL")
+        return self._conn
+
+    def close(self):
+        if self._conn is not None:
+            self._conn.close()
+            self._conn = None
+
+    def _ensure_schema(self):
+        conn = self._get_connection()
+        conn.executescript(PROVENANCE_SCHEMA)
+
+    def get_session_taint(self, session_id: str) -> Dict[str, Any]:
+        """Get current taint state for a session.
+
+        Returns a dict with taint_level, turns_since_taint, etc.
+        Returns a "clean" default if session not found.
+        """
+        conn = self._get_connection()
+        row = conn.execute(
+            "SELECT * FROM session_taint WHERE session_id = ?",
+            (session_id,),
+        ).fetchone()
+
+        if row is None:
+            return {
+                "taint_level": "clean",
+                "turns_since_taint": 0,
+                "total_tool_calls": 0,
+                "total_external_ingests": 0,
+                "total_taint_escalations": 0,
+                "last_taint_source": None,
+                "last_taint_reason": None,
+            }
+
+        return dict(row)
+
+    def record_tool_call(self, session_id: str, tool_name: str) -> Dict[str, Any]:
+        """Record a tool call and apply taint decay if applicable.
+
+        Called from pre_tool_use to track tool calls and decay taint.
+        Returns the updated taint state.
+        """
+        conn = self._get_connection()
+        state = self.get_session_taint(session_id)
+
+        new_total = state["total_tool_calls"] + 1
+        new_turns_since_taint = state["turns_since_taint"] + 1
+
+        # Apply decay if enough clean turns have passed
+        current_taint = state["taint_level"]
+        if (current_taint != "clean"
+                and new_turns_since_taint >= DECAY_INTERVAL):
+            current_taint = decay_taint(current_taint)
+            new_turns_since_taint = 0  # Reset counter after decay
+
+        # Upsert
+        conn.execute("""
+            INSERT INTO session_taint
+                (session_id, taint_level, turns_since_taint,
+                 total_tool_calls, total_external_ingests,
+                 total_taint_escalations,
+                 last_taint_source, last_taint_reason,
+                 updated_at)
+            VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
+            ON CONFLICT(session_id) DO UPDATE SET
+                taint_level = excluded.taint_level,
+                turns_since_taint = excluded.turns_since_taint,
+                total_tool_calls = excluded.total_tool_calls,
+                updated_at = datetime('now')
+        """, (
+            session_id, current_taint, new_turns_since_taint,
+            new_total, state["total_external_ingests"],
+            state["total_taint_escalations"],
+            state["last_taint_source"], state["last_taint_reason"],
+        ))
+
+        state["taint_level"] = current_taint
+        state["turns_since_taint"] = new_turns_since_taint
+        state["total_tool_calls"] = new_total
+        return state
+
+    def record_taint(
+        self,
+        session_id: str,
+        taint_level: str,
+        source: str,
+        reason: str,
+    ) -> Dict[str, Any]:
+        """Escalate session taint after finding suspicious content.
+
+        Called from post_tool_use when patterns are found in ingested content.
+        Taint only escalates, never downgrades via this method.
+        Returns the updated taint state.
+        """
+        conn = self._get_connection()
+        state = self.get_session_taint(session_id)
+
+        new_taint = escalate_taint(state["taint_level"], taint_level)
+        new_escalations = state["total_taint_escalations"]
+        if new_taint != state["taint_level"]:
+            new_escalations += 1
+
+        new_ingests = state["total_external_ingests"] + 1
+
+        conn.execute("""
+            INSERT INTO session_taint
+                (session_id, taint_level, turns_since_taint,
+                 total_tool_calls, total_external_ingests,
+                 total_taint_escalations,
+                 last_taint_source, last_taint_reason,
+                 updated_at)
+            VALUES (?, ?, 0, ?, ?, ?, ?, ?, datetime('now'))
+            ON CONFLICT(session_id) DO UPDATE SET
+                taint_level = excluded.taint_level,
+                turns_since_taint = 0,
+                total_tool_calls = excluded.total_tool_calls,
+                total_external_ingests = excluded.total_external_ingests,
+                total_taint_escalations = excluded.total_taint_escalations,
+                last_taint_source = excluded.last_taint_source,
+                last_taint_reason = excluded.last_taint_reason,
+                updated_at = datetime('now')
+        """, (
+            session_id, new_taint,
+            state["total_tool_calls"], new_ingests,
+            new_escalations, source, reason,
+        ))
+
+        state["taint_level"] = new_taint
+        state["turns_since_taint"] = 0
+        state["total_external_ingests"] = new_ingests
+        state["total_taint_escalations"] = new_escalations
+        state["last_taint_source"] = source
+        state["last_taint_reason"] = reason
+        return state
+
+    def record_external_ingest(self, session_id: str, source: str):
+        """Record an external content ingest without escalating taint.
+
+        Called when Read/WebFetch returns content that passes screening.
+        Tracks the ingest count but doesn't change taint level.
+        """
+        conn = self._get_connection()
+        state = self.get_session_taint(session_id)
+
+        conn.execute("""
+            INSERT INTO session_taint
+                (session_id, taint_level, turns_since_taint,
+                 total_tool_calls, total_external_ingests,
+                 total_taint_escalations,
+                 last_taint_source, last_taint_reason,
+                 updated_at)
+            VALUES (?, ?, ?, ?, ?, ?, ?, ?, datetime('now'))
+            ON CONFLICT(session_id) DO UPDATE SET
+                total_external_ingests = excluded.total_external_ingests,
+                updated_at = datetime('now')
+        """, (
+            session_id, state["taint_level"],
+            state["turns_since_taint"],
+            state["total_tool_calls"],
+            state["total_external_ingests"] + 1,
+            state["total_taint_escalations"],
+            source, state["last_taint_reason"],
+        ))
+
+    def clear_session(self, session_id: str):
+        """Remove taint tracking for a session."""
+        conn = self._get_connection()
+        conn.execute(
+            "DELETE FROM session_taint WHERE session_id = ?",
+            (session_id,),
+        )
+
+    def get_stats(self) -> Dict[str, Any]:
+        """Get aggregate stats across all tracked sessions."""
+        conn = self._get_connection()
+        row = conn.execute("""
+            SELECT
+                COUNT(*) as total_sessions,
+                SUM(CASE WHEN taint_level = 'clean' THEN 1 ELSE 0 END) as clean_sessions,
+                SUM(CASE WHEN taint_level != 'clean' THEN 1 ELSE 0 END) as tainted_sessions,
+                SUM(total_taint_escalations) as total_escalations,
+                SUM(total_external_ingests) as total_ingests
+            FROM session_taint
+        """).fetchone()
+
+        return dict(row) if row else {
+            "total_sessions": 0,
+            "clean_sessions": 0,
+            "tainted_sessions": 0,
+            "total_escalations": 0,
+            "total_ingests": 0,
+        }
+
+
+# =========================================================================
+# Enforcement Adjustment
+# =========================================================================
+
+# The "balanced" preset enforcement matrix for CLEAN sessions
+# Compared to "cautious" default, this logs instead of asking for
+# heuristic/contextual patterns in high severity
+BALANCED_CLEAN_OVERRIDES = {
+    "critical": {"contextual": "log"},  # Don't prompt on broad contextual
+    "high": {"heuristic": "log", "contextual": "log"},  # Don't prompt on heuristic
+    "medium": {"deterministic": "log", "heuristic": "log", "contextual": "log"},
+}
+
+
+def adjust_enforcement_for_taint(
+    base_decision: str,
+    severity: str,
+    confidence: str,
+    taint_level: str,
+) -> str:
+    """Adjust an enforcement decision based on session taint level.
+
+    In CLEAN sessions: relax heuristic/contextual patterns to "log"
+    In TAINTED sessions: keep base enforcement (or escalate)
+
+    Args:
+        base_decision: The decision from EnforcementPolicy.resolve()
+        severity: Pattern severity (critical/high/medium/low)
+        confidence: Pattern confidence (deterministic/heuristic/contextual)
+        taint_level: Current session taint level
+
+    Returns:
+        Adjusted decision: "deny", "ask", or "log"
+    """
+    # Never relax a "deny" decision — deny is hardcoded by policy.
+    # This also covers critical+deterministic (always deny from policy).
+    # Note: break-glass may downgrade deny→ask before we see it, and
+    # we should respect that intentional override.
+    if base_decision == "deny":
+        return "deny"
+
+    # In clean sessions, apply the balanced overrides
+    if taint_level == "clean":
+        override = BALANCED_CLEAN_OVERRIDES.get(severity, {}).get(confidence)
+        if override is not None:
+            return override
+
+    # In tainted sessions (medium+), consider escalation
+    if TAINT_RANK.get(taint_level, 0) >= TAINT_RANK["medium"]:
+        # Escalate "log" to "ask" for heuristic patterns
+        if base_decision == "log" and confidence in ("heuristic", "deterministic"):
+            if severity in ("critical", "high"):
+                return "ask"
+
+    # In critically tainted sessions, escalate everything
+    if taint_level == "critical":
+        if base_decision == "log" and severity in ("critical", "high", "medium"):
+            return "ask"
+
+    return base_decision
+
+
+def should_skip_llm_for_clean_session(
+    taint_level: str,
+    tool_tier: str,
+) -> bool:
+    """Determine if LLM review can be skipped for clean sessions.
+
+    In clean sessions, default-tier tools don't need LLM review
+    since there's no external content that could contain injection.
+
+    Args:
+        taint_level: Current session taint level
+        tool_tier: The tier of the tool (safe/default/risky/dangerous)
+
+    Returns:
+        True if LLM review can be skipped
+    """
+    if taint_level != "clean":
+        return False
+    # Only skip for default tier (Read, Edit, NotebookEdit)
+    # Risky and dangerous always get LLM review
+    return tool_tier == "default"
+
+
+# =========================================================================
+# Singleton
+# =========================================================================
+
+_taint_store: Optional[SessionTaintStore] = None
+
+
+def get_taint_store(db_path: Optional[Path] = None) -> SessionTaintStore:
+    """Get the singleton SessionTaintStore instance."""
+    global _taint_store
+    if _taint_store is None:
+        _taint_store = SessionTaintStore(db_path)
+    return _taint_store
+
+
+def reset_taint_store():
+    """Reset the singleton (for testing)."""
+    global _taint_store
+    if _taint_store is not None:
+        _taint_store.close()
+    _taint_store = None
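Not part of the diff: a minimal sketch of the call flow the module docstrings above describe (record_tool_call from pre_tool_use, record_taint from post_tool_use, then enforcement adjustment), pointed at a throwaway database. The session id, tool names, and pattern values are made up for illustration.

# Sketch only: exercising the new provenance module against a temp database.
from pathlib import Path
import tempfile

from tweek.memory.provenance import (
    SessionTaintStore,
    adjust_enforcement_for_taint,
    severity_to_taint,
    should_skip_llm_for_clean_session,
)

db = Path(tempfile.mkdtemp()) / "memory.db"
store = SessionTaintStore(db_path=db)
session_id = "demo-session"

# pre_tool_use: count the call and let taint decay over clean turns.
state = store.record_tool_call(session_id, "Bash")
print(state["taint_level"])  # "clean"

# Clean session + default-tier tool: LLM review can be skipped.
print(should_skip_llm_for_clean_session(state["taint_level"], "default"))  # True

# post_tool_use: WebFetch returned content matching a high-severity pattern.
state = store.record_taint(
    session_id,
    taint_level=severity_to_taint("high"),
    source="WebFetch",
    reason="suspicious instruction found in fetched page",
)
print(state["taint_level"])  # "high"

# In the tainted session, a would-be "log" decision escalates to "ask".
print(adjust_enforcement_for_taint("log", "high", "heuristic", state["taint_level"]))  # "ask"

store.close()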
tweek/memory/queries.py CHANGED
@@ -46,6 +46,8 @@ def memory_read_for_pattern(
         current_decision=current_decision,
         original_severity=pattern_severity,
         original_confidence=pattern_confidence,
+        tool_name=tool_name,
+        project_hash=project_hash,
     )
 
     if adjustment is None:
tweek/memory/safety.py CHANGED
@@ -31,8 +31,17 @@ MAX_RELAXATION = {
     "allow": "allow",  # Already at minimum
 }
 
-# Minimum weighted decisions before memory can suggest adjustments
-MIN_DECISION_THRESHOLD = 10
+# Context-scoped decision thresholds: narrower context = fewer decisions needed.
+# The system tries scopes narrowest-first and returns the first match.
+# Global (pattern-only) is intentionally absent — too broad to be safe.
+SCOPED_THRESHOLDS = {
+    "exact": 1,         # pattern + tool + path + project
+    "tool_project": 3,  # pattern + tool + project
+    "path": 5,          # pattern + path_prefix
+}
+
+# Minimum weighted decisions (backward compat — smallest scope threshold)
+MIN_DECISION_THRESHOLD = SCOPED_THRESHOLDS["exact"]
 
 # Minimum approval ratio to suggest relaxation
 MIN_APPROVAL_RATIO = 0.90  # 90% approval rate
@@ -115,18 +124,28 @@ def compute_suggested_decision(
     total_weighted_decisions: float,
     original_severity: str,
     original_confidence: str,
+    min_threshold: Optional[float] = None,
 ) -> Optional[str]:
     """Compute what decision memory would suggest, if any.
 
     Returns None if memory has no suggestion (insufficient data or
     pattern is immune).
+
+    Args:
+        min_threshold: Override the minimum weighted-decision threshold.
+            Used by scoped queries where narrower context requires fewer
+            decisions. Defaults to SCOPED_THRESHOLDS["path"] (broadest
+            allowed scope).
     """
+    if min_threshold is None:
+        min_threshold = SCOPED_THRESHOLDS["path"]
+
     # Immune patterns get no suggestions
     if is_immune_pattern(original_severity, original_confidence):
         return None
 
-    # Insufficient data
-    if total_weighted_decisions < MIN_DECISION_THRESHOLD:
+    # Insufficient data for this scope
+    if total_weighted_decisions < min_threshold:
         return None
 
     # deny is never relaxed
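Not part of the diff: the scope-selection logic itself lives in the memory query path, which this diff does not show, so the helper below is a hypothetical illustration of the narrowest-first rule described in the safety.py comments, not the package's implementation.

# Hypothetical sketch of the narrowest-first threshold rule from safety.py.
from typing import Dict, Optional

SCOPED_THRESHOLDS = {
    "exact": 1,         # pattern + tool + path + project
    "tool_project": 3,  # pattern + tool + project
    "path": 5,          # pattern + path_prefix
}

# Ordered narrowest → broadest, mirroring the comment in safety.py.
SCOPE_ORDER = ("exact", "tool_project", "path")


def pick_scope(weighted_decisions: Dict[str, float]) -> Optional[str]:
    """Return the first (narrowest) scope whose evidence meets its threshold."""
    for scope in SCOPE_ORDER:
        if weighted_decisions.get(scope, 0.0) >= SCOPED_THRESHOLDS[scope]:
            return scope
    return None  # not enough evidence at any allowed scope


# One exact-context approval is enough; broader scopes need more evidence.
print(pick_scope({"exact": 1.0}))                       # "exact"
print(pick_scope({"exact": 0.0, "tool_project": 3.0}))  # "tool_project"
print(pick_scope({"path": 4.0}))                        # None (below 5)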