tweek 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tweek/__init__.py +16 -0
- tweek/cli.py +3390 -0
- tweek/cli_helpers.py +193 -0
- tweek/config/__init__.py +13 -0
- tweek/config/allowed_dirs.yaml +23 -0
- tweek/config/manager.py +1064 -0
- tweek/config/patterns.yaml +751 -0
- tweek/config/tiers.yaml +129 -0
- tweek/diagnostics.py +589 -0
- tweek/hooks/__init__.py +1 -0
- tweek/hooks/pre_tool_use.py +861 -0
- tweek/integrations/__init__.py +3 -0
- tweek/integrations/moltbot.py +243 -0
- tweek/licensing.py +398 -0
- tweek/logging/__init__.py +9 -0
- tweek/logging/bundle.py +350 -0
- tweek/logging/json_logger.py +150 -0
- tweek/logging/security_log.py +745 -0
- tweek/mcp/__init__.py +24 -0
- tweek/mcp/approval.py +456 -0
- tweek/mcp/approval_cli.py +356 -0
- tweek/mcp/clients/__init__.py +37 -0
- tweek/mcp/clients/chatgpt.py +112 -0
- tweek/mcp/clients/claude_desktop.py +203 -0
- tweek/mcp/clients/gemini.py +178 -0
- tweek/mcp/proxy.py +667 -0
- tweek/mcp/screening.py +175 -0
- tweek/mcp/server.py +317 -0
- tweek/platform/__init__.py +131 -0
- tweek/plugins/__init__.py +835 -0
- tweek/plugins/base.py +1080 -0
- tweek/plugins/compliance/__init__.py +30 -0
- tweek/plugins/compliance/gdpr.py +333 -0
- tweek/plugins/compliance/gov.py +324 -0
- tweek/plugins/compliance/hipaa.py +285 -0
- tweek/plugins/compliance/legal.py +322 -0
- tweek/plugins/compliance/pci.py +361 -0
- tweek/plugins/compliance/soc2.py +275 -0
- tweek/plugins/detectors/__init__.py +30 -0
- tweek/plugins/detectors/continue_dev.py +206 -0
- tweek/plugins/detectors/copilot.py +254 -0
- tweek/plugins/detectors/cursor.py +192 -0
- tweek/plugins/detectors/moltbot.py +205 -0
- tweek/plugins/detectors/windsurf.py +214 -0
- tweek/plugins/git_discovery.py +395 -0
- tweek/plugins/git_installer.py +491 -0
- tweek/plugins/git_lockfile.py +338 -0
- tweek/plugins/git_registry.py +503 -0
- tweek/plugins/git_security.py +482 -0
- tweek/plugins/providers/__init__.py +30 -0
- tweek/plugins/providers/anthropic.py +181 -0
- tweek/plugins/providers/azure_openai.py +289 -0
- tweek/plugins/providers/bedrock.py +248 -0
- tweek/plugins/providers/google.py +197 -0
- tweek/plugins/providers/openai.py +230 -0
- tweek/plugins/scope.py +130 -0
- tweek/plugins/screening/__init__.py +26 -0
- tweek/plugins/screening/llm_reviewer.py +149 -0
- tweek/plugins/screening/pattern_matcher.py +273 -0
- tweek/plugins/screening/rate_limiter.py +174 -0
- tweek/plugins/screening/session_analyzer.py +159 -0
- tweek/proxy/__init__.py +302 -0
- tweek/proxy/addon.py +223 -0
- tweek/proxy/interceptor.py +313 -0
- tweek/proxy/server.py +315 -0
- tweek/sandbox/__init__.py +71 -0
- tweek/sandbox/executor.py +382 -0
- tweek/sandbox/linux.py +278 -0
- tweek/sandbox/profile_generator.py +323 -0
- tweek/screening/__init__.py +13 -0
- tweek/screening/context.py +81 -0
- tweek/security/__init__.py +22 -0
- tweek/security/llm_reviewer.py +348 -0
- tweek/security/rate_limiter.py +682 -0
- tweek/security/secret_scanner.py +506 -0
- tweek/security/session_analyzer.py +600 -0
- tweek/vault/__init__.py +40 -0
- tweek/vault/cross_platform.py +251 -0
- tweek/vault/keychain.py +288 -0
- tweek-0.1.0.dist-info/METADATA +335 -0
- tweek-0.1.0.dist-info/RECORD +85 -0
- tweek-0.1.0.dist-info/WHEEL +5 -0
- tweek-0.1.0.dist-info/entry_points.txt +25 -0
- tweek-0.1.0.dist-info/licenses/LICENSE +190 -0
- tweek-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,682 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Tweek Rate Limiter
|
|
4
|
+
|
|
5
|
+
Protects against resource theft attacks (MCP sampling abuse, quota drain)
|
|
6
|
+
by detecting:
|
|
7
|
+
- Burst patterns (many commands in short time)
|
|
8
|
+
- Repeated identical commands
|
|
9
|
+
- Unusual invocation volume
|
|
10
|
+
- Suspicious velocity changes
|
|
11
|
+
|
|
12
|
+
Based on Unit42 research on MCP sampling attack vectors.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import hashlib
import json
import math
import sqlite3
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from pathlib import Path
from typing import Optional, List, Dict, Any, Tuple

from tweek.logging.security_log import SecurityLogger, get_logger
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class RateLimitViolation(Enum):
    """
    Kinds of rate limit violation a check can report.

    Members:
        BURST: too many commands inside the short burst window.
        REPEATED_COMMAND: the same command executed too many times.
        HIGH_VOLUME: total volume exceeded the threshold.
        DANGEROUS_SPIKE: a spike in dangerous-tier commands.
        VELOCITY_ANOMALY: unusual acceleration in activity.
        CIRCUIT_OPEN: the circuit breaker refused the request.
    """

    BURST = "burst"
    REPEATED_COMMAND = "repeated"
    HIGH_VOLUME = "high_volume"
    DANGEROUS_SPIKE = "dangerous_spike"
    VELOCITY_ANOMALY = "velocity"
    CIRCUIT_OPEN = "circuit_open"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class CircuitState(Enum):
    """
    Lifecycle states of a circuit breaker.

    CLOSED    -- normal operation; requests are allowed.
    OPEN      -- failures exceeded the threshold; requests are blocked.
    HALF_OPEN -- probing whether it is safe to resume.
    """

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class RateLimitConfig:
    """
    Thresholds and windows consumed by RateLimiter.

    Window lengths are in seconds, except ``baseline_window_hours``.
    """

    # --- sliding windows (seconds) ---
    burst_window: int = 5
    short_window: int = 60
    long_window: int = 300

    # --- limits ---
    # commands allowed inside burst_window
    burst_threshold: int = 15
    # commands allowed per minute (RateLimiter counts them over short_window)
    max_per_minute: int = 60
    # dangerous-tier commands allowed per minute
    max_dangerous_per_minute: int = 10
    # identical commands allowed per minute
    max_same_command: int = 5
    # alert when current velocity exceeds this multiple of the baseline
    velocity_multiplier: float = 3.0

    # --- baseline learning ---
    # hours of history used to compute the baseline
    baseline_window_hours: int = 24
    # below this many events, no baseline is computed
    min_baseline_samples: int = 100
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class CircuitBreakerConfig:
    """
    Tuning knobs for CircuitBreaker.

    NOTE(review): ``count_rate_limit_as_failure`` and
    ``count_timeout_as_failure`` are not read anywhere in this module;
    confirm whether another component consults them.
    """

    # --- failure thresholds ---
    # failures while CLOSED that trip the breaker OPEN
    failure_threshold: int = 5
    # successes required while HALF_OPEN before closing again
    success_threshold: int = 3

    # --- timing ---
    # seconds to remain OPEN before moving to HALF_OPEN
    open_timeout: int = 60
    # maximum probe requests admitted while HALF_OPEN
    half_open_max_requests: int = 3

    # --- failure classification ---
    count_rate_limit_as_failure: bool = True
    count_timeout_as_failure: bool = True
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@dataclass
class CircuitBreakerState:
    """
    Mutable bookkeeping for one circuit, as maintained by CircuitBreaker.
    """

    # where the circuit currently is in its CLOSED/OPEN/HALF_OPEN lifecycle
    state: CircuitState = CircuitState.CLOSED
    # failures recorded so far
    failure_count: int = 0
    # successes recorded while probing in HALF_OPEN
    success_count: int = 0
    # wall-clock time of the most recent recorded failure
    last_failure_time: Optional[datetime] = None
    # wall-clock time of the most recent state transition
    last_state_change: Optional[datetime] = None
    # probe requests admitted since entering HALF_OPEN
    half_open_requests: int = 0
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass
class RateLimitResult:
    """
    Outcome of a single rate limit check.

    The boolean properties are convenience views over ``violations`` and
    ``circuit_state``.
    """

    allowed: bool
    violations: List[RateLimitViolation] = field(default_factory=list)
    details: Dict[str, Any] = field(default_factory=dict)
    message: Optional[str] = None
    circuit_state: CircuitState = CircuitState.CLOSED
    # seconds the caller should wait before retrying, if known
    retry_after: Optional[int] = None

    def _flagged(self, kind: RateLimitViolation) -> bool:
        """Return whether *kind* appears among the recorded violations."""
        return kind in self.violations

    @property
    def is_burst(self) -> bool:
        """True when a burst violation was recorded."""
        return self._flagged(RateLimitViolation.BURST)

    @property
    def is_repeated(self) -> bool:
        """True when a repeated-command violation was recorded."""
        return self._flagged(RateLimitViolation.REPEATED_COMMAND)

    @property
    def is_circuit_open(self) -> bool:
        """True when the circuit breaker is reported as OPEN."""
        open_state = CircuitState.OPEN
        return self.circuit_state == open_state
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class CircuitBreaker:
    """
    Circuit breaker pattern implementation for fault tolerance.

    States:
    - CLOSED: Normal operation, requests allowed, failures tracked
    - OPEN: Too many failures, requests blocked, waiting for timeout
    - HALF_OPEN: Testing recovery, limited requests allowed

    Each circuit is identified by a string key; its state lives in an
    in-memory dict on this instance (nothing is persisted).

    Based on moltbot's circuit breaker implementation for resilience.
    """

    # Seconds a caller is told to wait when every half-open probe slot is taken.
    _HALF_OPEN_RETRY_AFTER = 5

    def __init__(self, config: Optional[CircuitBreakerConfig] = None):
        """
        Initialize the circuit breaker.

        Args:
            config: Circuit breaker configuration (defaults to CircuitBreakerConfig())
        """
        self.config = config or CircuitBreakerConfig()
        self._states: Dict[str, CircuitBreakerState] = {}

    def _get_state(self, key: str) -> CircuitBreakerState:
        """Get or create state for a circuit key (new circuits start CLOSED)."""
        state = self._states.get(key)
        if state is None:
            state = CircuitBreakerState(last_state_change=datetime.now())
            self._states[key] = state
        return state

    def _transition_to(self, state: CircuitBreakerState, new_state: CircuitState) -> None:
        """Move a circuit to *new_state*, resetting probe counters on HALF_OPEN."""
        state.state = new_state
        state.last_state_change = datetime.now()
        if new_state == CircuitState.HALF_OPEN:
            state.half_open_requests = 0
            state.success_count = 0

    def _half_open_success_target(self) -> int:
        """
        Number of successes required in HALF_OPEN before the circuit closes.

        Capped at ``half_open_max_requests``: only that many probes are
        admitted, so a larger ``success_threshold`` could never be reached
        and the circuit would block every request until some failure arrived.
        """
        return min(self.config.success_threshold, self.config.half_open_max_requests)

    def can_execute(self, key: str = "default") -> Tuple[bool, CircuitState, Optional[int]]:
        """
        Check if a request can be executed.

        May move an OPEN circuit to HALF_OPEN once ``open_timeout`` has
        elapsed, and reserves a probe slot for requests admitted in HALF_OPEN.

        Args:
            key: Circuit breaker key (e.g., "session:123" or "tool:Bash")

        Returns:
            (allowed, circuit_state, retry_after_seconds)
        """
        state = self._get_state(key)

        if state.state == CircuitState.CLOSED:
            return True, CircuitState.CLOSED, None

        if state.state == CircuitState.OPEN:
            if state.last_state_change is None:
                return False, CircuitState.OPEN, self.config.open_timeout

            elapsed = (datetime.now() - state.last_state_change).total_seconds()
            if elapsed < self.config.open_timeout:
                # Round up: truncation could tell a blocked caller to retry in 0s.
                return False, CircuitState.OPEN, math.ceil(self.config.open_timeout - elapsed)

            self._transition_to(state, CircuitState.HALF_OPEN)
            # This request is the first recovery probe, so it counts against
            # half_open_max_requests like every later probe does.
            state.half_open_requests = 1
            return True, CircuitState.HALF_OPEN, None

        if state.state == CircuitState.HALF_OPEN:
            # Allow limited requests in half-open state
            if state.half_open_requests < self.config.half_open_max_requests:
                state.half_open_requests += 1
                return True, CircuitState.HALF_OPEN, None
            return False, CircuitState.HALF_OPEN, self._HALF_OPEN_RETRY_AFTER

        return True, CircuitState.CLOSED, None

    def record_success(self, key: str = "default") -> CircuitState:
        """
        Record a successful request.

        In HALF_OPEN, enough successes close the circuit; in CLOSED, any
        success clears the failure count.

        Args:
            key: Circuit breaker key

        Returns:
            New circuit state
        """
        state = self._get_state(key)

        if state.state == CircuitState.HALF_OPEN:
            state.success_count += 1
            if state.success_count >= self._half_open_success_target():
                # Recovery confirmed, close the circuit
                self._transition_to(state, CircuitState.CLOSED)
                state.failure_count = 0

        elif state.state == CircuitState.CLOSED:
            # Reset failure count on success
            state.failure_count = 0

        return state.state

    def record_failure(self, key: str = "default") -> CircuitState:
        """
        Record a failed request.

        Any failure in HALF_OPEN reopens the circuit; in CLOSED the circuit
        opens once ``failure_threshold`` failures accumulate.

        Args:
            key: Circuit breaker key

        Returns:
            New circuit state
        """
        state = self._get_state(key)
        state.failure_count += 1
        state.last_failure_time = datetime.now()

        if state.state == CircuitState.HALF_OPEN:
            self._transition_to(state, CircuitState.OPEN)
        elif state.state == CircuitState.CLOSED:
            if state.failure_count >= self.config.failure_threshold:
                self._transition_to(state, CircuitState.OPEN)

        return state.state

    def get_state(self, key: str = "default") -> CircuitBreakerState:
        """
        Get the current state of a circuit (creating a CLOSED one if unknown).

        Args:
            key: Circuit breaker key

        Returns:
            Current circuit breaker state (the live, mutable object)
        """
        return self._get_state(key)

    def reset(self, key: str = "default") -> None:
        """
        Reset a circuit breaker to closed state by forgetting its state.

        Args:
            key: Circuit breaker key
        """
        self._states.pop(key, None)

    def get_all_states(self) -> Dict[str, CircuitBreakerState]:
        """Get all circuit breaker states (shallow copy of the mapping)."""
        return self._states.copy()

    def get_metrics(self) -> Dict[str, Any]:
        """
        Get circuit breaker metrics.

        Returns:
            Dictionary with per-state counts and per-circuit details
        """
        metrics: Dict[str, Any] = {
            "total_circuits": len(self._states),
            "open_circuits": 0,
            "half_open_circuits": 0,
            "closed_circuits": 0,
            "circuits": {},
        }

        for key, state in self._states.items():
            if state.state == CircuitState.OPEN:
                metrics["open_circuits"] += 1
            elif state.state == CircuitState.HALF_OPEN:
                metrics["half_open_circuits"] += 1
            else:
                metrics["closed_circuits"] += 1

            metrics["circuits"][key] = {
                "state": state.state.value,
                "failure_count": state.failure_count,
                "success_count": state.success_count,
            }

        return metrics
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class RateLimiter:
    """
    Rate limiter for detecting resource theft and abuse patterns.

    Uses the security.db to track invocation patterns and detect anomalies.
    Includes circuit breaker for fault tolerance.

    Event counts are read from the ``security_events`` table (rows with
    ``event_type = 'tool_invoked'``) through the logger's
    ``_get_connection()``. Circuit breaker state is keyed per session as
    ``"session:<session_id>"`` and lives only in this instance's memory.
    """

    # Look-back window for the "last hour" figures in get_session_stats().
    _STATS_WINDOW_SECONDS = 3600

    def __init__(
        self,
        config: Optional[RateLimitConfig] = None,
        circuit_config: Optional[CircuitBreakerConfig] = None,
        logger: Optional[SecurityLogger] = None
    ):
        """Initialize the rate limiter.

        Args:
            config: Rate limiting configuration (defaults to RateLimitConfig())
            circuit_config: Circuit breaker configuration
            logger: Security logger for database access (defaults to get_logger())
        """
        self.config = config or RateLimitConfig()
        self.logger = logger or get_logger()
        self.circuit_breaker = CircuitBreaker(circuit_config)
        self._ensure_indexes()

    def _ensure_indexes(self):
        """Ensure necessary database indexes exist for efficient queries.

        Best effort: any failure is ignored so constructing the limiter never
        raises; missing indexes only cost query speed.
        """
        try:
            with self.logger._get_connection() as conn:
                conn.executescript("""
                    -- Index for session + timestamp queries (rate limiting)
                    CREATE INDEX IF NOT EXISTS idx_events_session_time
                    ON security_events(session_id, timestamp);

                    -- Index for command hash queries (repeated command detection)
                    CREATE INDEX IF NOT EXISTS idx_events_command_hash
                    ON security_events(tool_name, command);
                """)
        except Exception:
            # Indexes may already exist or db not initialized
            pass

    def _hash_command(self, command: str) -> str:
        """Create a short (16 hex chars) MD5 digest of a command for comparison."""
        return hashlib.md5(command.encode()).hexdigest()[:16]

    def _get_recent_count(
        self,
        conn: sqlite3.Connection,
        session_id: str,
        window_seconds: int,
        tool_name: Optional[str] = None,
        tier: Optional[str] = None
    ) -> int:
        """Get count of recent ``tool_invoked`` events in a time window.

        Optional ``tool_name`` / ``tier`` narrow the count further.
        """
        # NOTE(review): the stored ``timestamp`` text is compared against
        # SQLite's datetime('now', ...), which produces UTC in
        # 'YYYY-MM-DD HH:MM:SS' form. This is only meaningful if SecurityLogger
        # writes timestamps in that same format/timezone -- confirm against
        # security_log.py.
        query = """
            SELECT COUNT(*) as count FROM security_events
            WHERE session_id = ?
            AND timestamp > datetime('now', ?)
            AND event_type = 'tool_invoked'
        """
        params = [session_id, f'-{window_seconds} seconds']

        if tool_name:
            query += " AND tool_name = ?"
            params.append(tool_name)

        if tier:
            query += " AND tier = ?"
            params.append(tier)

        return conn.execute(query, params).fetchone()[0]

    def _get_command_count(
        self,
        conn: sqlite3.Connection,
        session_id: str,
        command: str,
        window_seconds: int
    ) -> int:
        """Get count of identical commands in a time window."""
        query = """
            SELECT COUNT(*) as count FROM security_events
            WHERE session_id = ?
            AND timestamp > datetime('now', ?)
            AND command = ?
            AND event_type = 'tool_invoked'
        """
        return conn.execute(
            query,
            [session_id, f'-{window_seconds} seconds', command]
        ).fetchone()[0]

    def _get_baseline_velocity(
        self,
        conn: sqlite3.Connection,
        session_id: str
    ) -> Optional[float]:
        """Get baseline commands per minute for comparison.

        Returns None when there are fewer than ``min_baseline_samples``
        events in the baseline window, or when the span between the first
        and last event is zero or unparsable.
        """
        query = """
            SELECT COUNT(*) as count,
                   MIN(timestamp) as first_ts,
                   MAX(timestamp) as last_ts
            FROM security_events
            WHERE session_id = ?
            AND timestamp > datetime('now', ?)
            AND event_type = 'tool_invoked'
        """
        result = conn.execute(
            query,
            [session_id, f'-{self.config.baseline_window_hours} hours']
        ).fetchone()

        count = result[0]
        if count < self.config.min_baseline_samples:
            return None

        # Calculate average commands per minute
        try:
            first_ts = datetime.fromisoformat(result[1])
            last_ts = datetime.fromisoformat(result[2])
            duration_minutes = (last_ts - first_ts).total_seconds() / 60
            if duration_minutes > 0:
                return count / duration_minutes
        except (ValueError, TypeError):
            pass

        return None

    def _get_current_velocity(
        self,
        conn: sqlite3.Connection,
        session_id: str,
        window_seconds: int = 60
    ) -> float:
        """Get current commands per minute, extrapolated from ``window_seconds``."""
        count = self._get_recent_count(conn, session_id, window_seconds)
        return count * (60 / window_seconds)

    def _release_half_open_slot(self, circuit_key: str) -> None:
        """Give back a half-open probe slot for a check that recorded no outcome."""
        state = self.circuit_breaker.get_state(circuit_key)
        if state.state == CircuitState.HALF_OPEN and state.half_open_requests > 0:
            state.half_open_requests -= 1

    def check(
        self,
        tool_name: str,
        command: Optional[str],
        session_id: Optional[str],
        tier: Optional[str] = None
    ) -> RateLimitResult:
        """
        Check if an invocation should be rate limited.

        Rate limiting is free and open source.

        Args:
            tool_name: Name of the tool being invoked
            command: The command being executed (for Bash)
            session_id: Current session identifier
            tier: Security tier of the operation

        Returns:
            RateLimitResult with allowed status and any violations. Database
            errors fail open (allowed=True with the error in ``details``).
        """
        if not session_id:
            # No session tracking - allow but log
            return RateLimitResult(allowed=True, message="No session ID for rate limiting")

        # Check circuit breaker first
        circuit_key = f"session:{session_id}"
        can_exec, circuit_state, retry_after = self.circuit_breaker.can_execute(circuit_key)

        if not can_exec:
            return RateLimitResult(
                allowed=False,
                violations=[RateLimitViolation.CIRCUIT_OPEN],
                message=f"Circuit breaker is {circuit_state.value}. Too many rate limit violations.",
                circuit_state=circuit_state,
                retry_after=retry_after,
                details={"circuit_key": circuit_key}
            )

        violations: List[RateLimitViolation] = []
        details: Dict[str, Any] = {}

        try:
            with self.logger._get_connection() as conn:
                # Check 1: Burst detection (many commands in very short window)
                burst_count = self._get_recent_count(
                    conn, session_id, self.config.burst_window
                )
                details["burst_count"] = burst_count
                if burst_count >= self.config.burst_threshold:
                    violations.append(RateLimitViolation.BURST)
                    details["burst_threshold"] = self.config.burst_threshold

                # Check 2: Per-minute volume
                minute_count = self._get_recent_count(
                    conn, session_id, self.config.short_window
                )
                details["minute_count"] = minute_count
                if minute_count >= self.config.max_per_minute:
                    violations.append(RateLimitViolation.HIGH_VOLUME)
                    details["max_per_minute"] = self.config.max_per_minute

                # Check 3: Dangerous tier spike
                if tier == "dangerous":
                    dangerous_count = self._get_recent_count(
                        conn, session_id, self.config.short_window, tier="dangerous"
                    )
                    details["dangerous_count"] = dangerous_count
                    if dangerous_count >= self.config.max_dangerous_per_minute:
                        violations.append(RateLimitViolation.DANGEROUS_SPIKE)
                        details["max_dangerous"] = self.config.max_dangerous_per_minute

                # Check 4: Repeated command detection
                if command:
                    cmd_count = self._get_command_count(
                        conn, session_id, command, self.config.short_window
                    )
                    details["same_command_count"] = cmd_count
                    if cmd_count >= self.config.max_same_command:
                        violations.append(RateLimitViolation.REPEATED_COMMAND)
                        details["max_same_command"] = self.config.max_same_command

                # Check 5: Velocity anomaly
                baseline = self._get_baseline_velocity(conn, session_id)
                current = self._get_current_velocity(conn, session_id)
                details["current_velocity"] = round(current, 2)

                if baseline:
                    details["baseline_velocity"] = round(baseline, 2)
                    if current > baseline * self.config.velocity_multiplier:
                        violations.append(RateLimitViolation.VELOCITY_ANOMALY)
                        details["velocity_ratio"] = round(current / baseline, 2)

        except Exception as e:
            # Database error - fail open. No outcome is recorded with the
            # circuit breaker, so return the half-open probe slot that
            # can_execute() may have reserved; otherwise repeated DB errors
            # exhaust the slots and leave the circuit stuck in HALF_OPEN.
            if circuit_state == CircuitState.HALF_OPEN:
                self._release_half_open_slot(circuit_key)
            return RateLimitResult(
                allowed=True,
                message=f"Rate limit check failed: {e}",
                details={"error": str(e)}
            )

        # Determine if we should block
        allowed = not violations

        # Update circuit breaker based on result
        if allowed:
            new_state = self.circuit_breaker.record_success(circuit_key)
        else:
            new_state = self.circuit_breaker.record_failure(circuit_key)

        # If this check tripped the breaker, tell the caller how long it stays open.
        new_retry_after = None
        if new_state == CircuitState.OPEN:
            new_retry_after = self.circuit_breaker.config.open_timeout

        # Build message
        message = None
        if not allowed:
            violation_names = [v.value for v in violations]
            message = f"Rate limit violations: {', '.join(violation_names)}"

        return RateLimitResult(
            allowed=allowed,
            violations=violations,
            details=details,
            message=message,
            circuit_state=new_state,
            retry_after=new_retry_after
        )

    def get_session_stats(self, session_id: str) -> Dict[str, Any]:
        """
        Get statistics for a session.

        Args:
            session_id: Session to get stats for

        Returns:
            Dictionary with session statistics (or ``{"error": ...}`` on failure)
        """
        # Always one hour, matching the "_1h" key, regardless of long_window.
        window = self._STATS_WINDOW_SECONDS
        try:
            with self.logger._get_connection() as conn:
                # Total invocations
                total = self._get_recent_count(conn, session_id, window)

                # By tier
                tiers = {
                    tier: self._get_recent_count(conn, session_id, window, tier=tier)
                    for tier in ("safe", "default", "risky", "dangerous")
                }

                # Velocity
                current = self._get_current_velocity(conn, session_id)
                baseline = self._get_baseline_velocity(conn, session_id)

                return {
                    "session_id": session_id,
                    "total_invocations_1h": total,
                    "by_tier": tiers,
                    "current_velocity_per_min": round(current, 2),
                    "baseline_velocity_per_min": round(baseline, 2) if baseline is not None else None,
                    "config": {
                        "burst_threshold": self.config.burst_threshold,
                        "max_per_minute": self.config.max_per_minute,
                        "max_dangerous_per_minute": self.config.max_dangerous_per_minute,
                    }
                }
        except Exception as e:
            return {"error": str(e)}

    def format_violation_message(self, result: RateLimitResult) -> str:
        """Format a user-friendly violation message.

        Returns an empty string for allowed results; otherwise one line per
        reported violation, including circuit breaker rejections.
        """
        if result.allowed:
            return ""

        lines = [
            "Rate Limit Alert",
            "=" * 40,
        ]

        if result.is_burst:
            lines.append(
                f"  Burst detected: {result.details.get('burst_count', '?')} "
                f"commands in {self.config.burst_window}s "
                f"(limit: {self.config.burst_threshold})"
            )

        if result.is_repeated:
            lines.append(
                f"  Repeated command: {result.details.get('same_command_count', '?')} "
                f"times in 1 minute (limit: {self.config.max_same_command})"
            )

        if RateLimitViolation.HIGH_VOLUME in result.violations:
            lines.append(
                f"  High volume: {result.details.get('minute_count', '?')} "
                f"commands/min (limit: {self.config.max_per_minute})"
            )

        if RateLimitViolation.DANGEROUS_SPIKE in result.violations:
            lines.append(
                f"  Dangerous tier spike: {result.details.get('dangerous_count', '?')} "
                f"dangerous commands (limit: {self.config.max_dangerous_per_minute})"
            )

        if RateLimitViolation.VELOCITY_ANOMALY in result.violations:
            lines.append(
                f"  Velocity anomaly: {result.details.get('velocity_ratio', '?')}x "
                "above baseline"
            )

        circuit_blocked = (
            result.is_circuit_open
            or RateLimitViolation.CIRCUIT_OPEN in result.violations
        )
        if circuit_blocked:
            if result.is_circuit_open:
                lines.append("  Circuit breaker OPEN: Too many rate limit violations")
            else:
                # A HALF_OPEN breaker rejects requests once its probe slots are
                # used up; say so instead of emitting an alert with no reason.
                lines.append(
                    f"  Circuit breaker {result.circuit_state.value.upper()}: "
                    "waiting for recovery checks to complete"
                )
            if result.retry_after:
                lines.append(f"  Retry after: {result.retry_after} seconds")

        lines.append("=" * 40)
        lines.append("This may indicate automated abuse or attack.")

        return "\n".join(lines)

    def get_circuit_metrics(self) -> Dict[str, Any]:
        """Get circuit breaker metrics for all sessions."""
        return self.circuit_breaker.get_metrics()

    def reset_circuit(self, session_id: str) -> None:
        """Reset circuit breaker for a session."""
        circuit_key = f"session:{session_id}"
        self.circuit_breaker.reset(circuit_key)
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
# Process-wide shared limiter, created lazily by get_rate_limiter().
_rate_limiter: Optional[RateLimiter] = None


def get_rate_limiter(
    config: Optional[RateLimitConfig] = None,
    circuit_config: Optional[CircuitBreakerConfig] = None
) -> RateLimiter:
    """
    Return the module-wide RateLimiter, building it on first use.

    ``config`` and ``circuit_config`` only take effect on the call that
    constructs the instance; later calls return the existing limiter and
    ignore them.
    """
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = RateLimiter(config=config, circuit_config=circuit_config)
    return _rate_limiter
|