tweek 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. tweek/__init__.py +16 -0
  2. tweek/cli.py +3390 -0
  3. tweek/cli_helpers.py +193 -0
  4. tweek/config/__init__.py +13 -0
  5. tweek/config/allowed_dirs.yaml +23 -0
  6. tweek/config/manager.py +1064 -0
  7. tweek/config/patterns.yaml +751 -0
  8. tweek/config/tiers.yaml +129 -0
  9. tweek/diagnostics.py +589 -0
  10. tweek/hooks/__init__.py +1 -0
  11. tweek/hooks/pre_tool_use.py +861 -0
  12. tweek/integrations/__init__.py +3 -0
  13. tweek/integrations/moltbot.py +243 -0
  14. tweek/licensing.py +398 -0
  15. tweek/logging/__init__.py +9 -0
  16. tweek/logging/bundle.py +350 -0
  17. tweek/logging/json_logger.py +150 -0
  18. tweek/logging/security_log.py +745 -0
  19. tweek/mcp/__init__.py +24 -0
  20. tweek/mcp/approval.py +456 -0
  21. tweek/mcp/approval_cli.py +356 -0
  22. tweek/mcp/clients/__init__.py +37 -0
  23. tweek/mcp/clients/chatgpt.py +112 -0
  24. tweek/mcp/clients/claude_desktop.py +203 -0
  25. tweek/mcp/clients/gemini.py +178 -0
  26. tweek/mcp/proxy.py +667 -0
  27. tweek/mcp/screening.py +175 -0
  28. tweek/mcp/server.py +317 -0
  29. tweek/platform/__init__.py +131 -0
  30. tweek/plugins/__init__.py +835 -0
  31. tweek/plugins/base.py +1080 -0
  32. tweek/plugins/compliance/__init__.py +30 -0
  33. tweek/plugins/compliance/gdpr.py +333 -0
  34. tweek/plugins/compliance/gov.py +324 -0
  35. tweek/plugins/compliance/hipaa.py +285 -0
  36. tweek/plugins/compliance/legal.py +322 -0
  37. tweek/plugins/compliance/pci.py +361 -0
  38. tweek/plugins/compliance/soc2.py +275 -0
  39. tweek/plugins/detectors/__init__.py +30 -0
  40. tweek/plugins/detectors/continue_dev.py +206 -0
  41. tweek/plugins/detectors/copilot.py +254 -0
  42. tweek/plugins/detectors/cursor.py +192 -0
  43. tweek/plugins/detectors/moltbot.py +205 -0
  44. tweek/plugins/detectors/windsurf.py +214 -0
  45. tweek/plugins/git_discovery.py +395 -0
  46. tweek/plugins/git_installer.py +491 -0
  47. tweek/plugins/git_lockfile.py +338 -0
  48. tweek/plugins/git_registry.py +503 -0
  49. tweek/plugins/git_security.py +482 -0
  50. tweek/plugins/providers/__init__.py +30 -0
  51. tweek/plugins/providers/anthropic.py +181 -0
  52. tweek/plugins/providers/azure_openai.py +289 -0
  53. tweek/plugins/providers/bedrock.py +248 -0
  54. tweek/plugins/providers/google.py +197 -0
  55. tweek/plugins/providers/openai.py +230 -0
  56. tweek/plugins/scope.py +130 -0
  57. tweek/plugins/screening/__init__.py +26 -0
  58. tweek/plugins/screening/llm_reviewer.py +149 -0
  59. tweek/plugins/screening/pattern_matcher.py +273 -0
  60. tweek/plugins/screening/rate_limiter.py +174 -0
  61. tweek/plugins/screening/session_analyzer.py +159 -0
  62. tweek/proxy/__init__.py +302 -0
  63. tweek/proxy/addon.py +223 -0
  64. tweek/proxy/interceptor.py +313 -0
  65. tweek/proxy/server.py +315 -0
  66. tweek/sandbox/__init__.py +71 -0
  67. tweek/sandbox/executor.py +382 -0
  68. tweek/sandbox/linux.py +278 -0
  69. tweek/sandbox/profile_generator.py +323 -0
  70. tweek/screening/__init__.py +13 -0
  71. tweek/screening/context.py +81 -0
  72. tweek/security/__init__.py +22 -0
  73. tweek/security/llm_reviewer.py +348 -0
  74. tweek/security/rate_limiter.py +682 -0
  75. tweek/security/secret_scanner.py +506 -0
  76. tweek/security/session_analyzer.py +600 -0
  77. tweek/vault/__init__.py +40 -0
  78. tweek/vault/cross_platform.py +251 -0
  79. tweek/vault/keychain.py +288 -0
  80. tweek-0.1.0.dist-info/METADATA +335 -0
  81. tweek-0.1.0.dist-info/RECORD +85 -0
  82. tweek-0.1.0.dist-info/WHEEL +5 -0
  83. tweek-0.1.0.dist-info/entry_points.txt +25 -0
  84. tweek-0.1.0.dist-info/licenses/LICENSE +190 -0
  85. tweek-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,682 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Tweek Rate Limiter
4
+
5
+ Protects against resource theft attacks (MCP sampling abuse, quota drain)
6
+ by detecting:
7
+ - Burst patterns (many commands in short time)
8
+ - Repeated identical commands
9
+ - Unusual invocation volume
10
+ - Suspicious velocity changes
11
+
12
+ Based on Unit42 research on MCP sampling attack vectors.
13
+ """
14
+
15
+ import hashlib
16
+ import json
17
+ import sqlite3
18
+ from dataclasses import dataclass, field
19
+ from datetime import datetime, timedelta
20
+ from enum import Enum
21
+ from pathlib import Path
22
+ from typing import Optional, List, Dict, Any, Tuple
23
+
24
+ from tweek.logging.security_log import SecurityLogger, get_logger
25
+
26
+
27
class RateLimitViolation(Enum):
    """
    Kinds of rate-limit violation that a check can report.

    The string value of each member is what ends up in the joined
    "Rate limit violations: ..." message built by the rate limiter.
    """

    # Too many commands inside the short burst window
    BURST = "burst"
    # The very same command was executed too many times
    REPEATED_COMMAND = "repeated"
    # Overall invocation volume went over its threshold
    HIGH_VOLUME = "high_volume"
    # Too many dangerous-tier commands in a short period
    DANGEROUS_SPIKE = "dangerous_spike"
    # Activity accelerated abnormally compared with the baseline
    VELOCITY_ANOMALY = "velocity"
    # The circuit breaker refused the request
    CIRCUIT_OPEN = "circuit_open"
35
+
36
+
37
class CircuitState(Enum):
    """
    The three positions of a circuit breaker.

    CLOSED lets traffic through, OPEN rejects it, and HALF_OPEN admits a
    limited number of trial requests to decide whether to close again.
    """

    # Normal operation; requests are allowed
    CLOSED = "closed"
    # Failure threshold exceeded; requests are rejected
    OPEN = "open"
    # Probing whether the protected path has recovered
    HALF_OPEN = "half_open"
42
+
43
+
44
@dataclass
class RateLimitConfig:
    """
    Tunable thresholds used by the rate limiter.

    Window lengths are expressed in seconds; the baseline window is in hours.
    """

    # ---- Time windows (seconds) ----
    burst_window: int = 5            # span examined for burst detection
    short_window: int = 60           # the "per minute" span
    long_window: int = 300           # longer observation span

    # ---- Limits ----
    burst_threshold: int = 15        # ceiling on commands within burst_window
    max_per_minute: int = 60         # ceiling on commands per minute
    max_dangerous_per_minute: int = 10  # ceiling on dangerous-tier commands per minute
    max_same_command: int = 5        # ceiling on identical commands per minute
    velocity_multiplier: float = 3.0  # flag when velocity exceeds N x baseline

    # ---- Baseline learning ----
    baseline_window_hours: int = 24  # history (hours) used to compute the baseline
    min_baseline_samples: int = 100  # below this many events, no baseline is used
62
+
63
+
64
@dataclass
class CircuitBreakerConfig:
    """
    Settings that govern when a circuit breaker opens, probes, and closes.
    """

    # ---- Thresholds ----
    failure_threshold: int = 5   # failures (while CLOSED) needed to open the circuit
    success_threshold: int = 3   # successes (while HALF_OPEN) needed to close it

    # ---- Timing ----
    open_timeout: int = 60           # seconds the circuit stays OPEN before probing
    half_open_max_requests: int = 3  # probe requests admitted while HALF_OPEN

    # ---- Failure classification ----
    count_rate_limit_as_failure: bool = True
    count_timeout_as_failure: bool = True
78
+
79
+
80
@dataclass
class CircuitBreakerState:
    """
    Mutable bookkeeping for one circuit: where it is and how it got there.
    """

    state: CircuitState = CircuitState.CLOSED       # current position of the breaker
    failure_count: int = 0                          # failures recorded so far
    success_count: int = 0                          # successes recorded so far
    last_failure_time: Optional[datetime] = None    # when the latest failure happened
    last_state_change: Optional[datetime] = None    # when `state` last changed
    half_open_requests: int = 0                     # probes admitted in the current HALF_OPEN phase
89
+
90
+
91
@dataclass
class RateLimitResult:
    """
    Outcome of a single rate-limit check.

    `allowed` is the verdict; `violations` lists the reasons for a denial and
    `details` carries the counters that were measured. `retry_after`, when
    set, is the number of seconds the caller should wait before retrying.
    """

    allowed: bool
    violations: List[RateLimitViolation] = field(default_factory=list)
    details: Dict[str, Any] = field(default_factory=dict)
    message: Optional[str] = None
    circuit_state: CircuitState = CircuitState.CLOSED
    retry_after: Optional[int] = None  # seconds to wait before retrying

    @property
    def is_burst(self) -> bool:
        """True when a burst violation was detected."""
        found = RateLimitViolation.BURST in self.violations
        return found

    @property
    def is_repeated(self) -> bool:
        """True when the same command was repeated too often."""
        found = RateLimitViolation.REPEATED_COMMAND in self.violations
        return found

    @property
    def is_circuit_open(self) -> bool:
        """True only when the reported circuit state is OPEN (not HALF_OPEN)."""
        is_open = self.circuit_state == CircuitState.OPEN
        return is_open
112
+
113
+
114
class CircuitBreaker:
    """
    Circuit breaker pattern implementation for fault tolerance.

    States:
    - CLOSED: Normal operation, requests allowed, failures tracked
    - OPEN: Too many failures, requests blocked, waiting for timeout
    - HALF_OPEN: Testing recovery, limited requests allowed

    Each circuit is identified by a string key and tracked independently in
    an in-memory dict (nothing is persisted). Callers are expected to call
    can_execute() before a request and then report the outcome with
    record_success() / record_failure().

    Based on moltbot's circuit breaker implementation for resilience.
    """

    def __init__(self, config: Optional[CircuitBreakerConfig] = None):
        """
        Initialize the circuit breaker.

        Args:
            config: Circuit breaker configuration (defaults to CircuitBreakerConfig())
        """
        self.config = config or CircuitBreakerConfig()
        self._states: Dict[str, CircuitBreakerState] = {}

    def _get_state(self, key: str) -> CircuitBreakerState:
        """Get the state for a circuit key, creating a CLOSED one on first use."""
        if key not in self._states:
            self._states[key] = CircuitBreakerState(
                last_state_change=datetime.now()
            )
        return self._states[key]

    def _transition_to(self, state: CircuitBreakerState, new_state: CircuitState) -> None:
        """Move a circuit to `new_state`; entering HALF_OPEN resets the probe counters."""
        state.state = new_state
        state.last_state_change = datetime.now()
        if new_state == CircuitState.HALF_OPEN:
            state.half_open_requests = 0
            state.success_count = 0

    def _start_half_open_trial(self, state: CircuitBreakerState) -> None:
        """Enter HALF_OPEN and count the current request as the first probe."""
        self._transition_to(state, CircuitState.HALF_OPEN)
        # The request that triggered the trial is itself a probe; counting it
        # keeps the number of admitted probes at half_open_max_requests.
        state.half_open_requests = 1

    def can_execute(self, key: str = "default") -> Tuple[bool, CircuitState, Optional[int]]:
        """
        Check if a request can be executed.

        May change the circuit's state: an OPEN circuit whose timeout has
        elapsed moves to HALF_OPEN, and every admitted HALF_OPEN request
        consumes one probe from the half-open budget.

        Args:
            key: Circuit breaker key (e.g., "session:123" or "tool:Bash")

        Returns:
            (allowed, circuit_state, retry_after_seconds). retry_after is None
            when the request is allowed and at least 1 when it is refused.
        """
        import math  # local import: only needed to round retry hints up

        state = self._get_state(key)
        now = datetime.now()

        if state.state == CircuitState.CLOSED:
            return True, CircuitState.CLOSED, None

        if state.state == CircuitState.OPEN:
            if state.last_state_change is None:
                return False, CircuitState.OPEN, self.config.open_timeout

            elapsed = (now - state.last_state_change).total_seconds()
            if elapsed >= self.config.open_timeout:
                self._start_half_open_trial(state)
                return True, CircuitState.HALF_OPEN, None

            # Round up: a still-blocked caller must never be told "retry after 0".
            retry_after = max(1, math.ceil(self.config.open_timeout - elapsed))
            return False, CircuitState.OPEN, retry_after

        if state.state == CircuitState.HALF_OPEN:
            # Allow limited requests in half-open state
            if state.half_open_requests < self.config.half_open_max_requests:
                state.half_open_requests += 1
                return True, CircuitState.HALF_OPEN, None

            # Probe budget spent. If no recorded outcome has closed or reopened
            # the circuit within open_timeout (e.g. a caller admitted a probe
            # but never reported its result), restart the trial instead of
            # staying blocked forever.
            if state.last_state_change is not None:
                elapsed = (now - state.last_state_change).total_seconds()
                if elapsed >= self.config.open_timeout:
                    self._start_half_open_trial(state)
                    return True, CircuitState.HALF_OPEN, None

            return False, CircuitState.HALF_OPEN, 5  # Short wait

        return True, CircuitState.CLOSED, None

    def record_success(self, key: str = "default") -> CircuitState:
        """
        Record a successful request.

        In HALF_OPEN, reaching success_threshold closes the circuit and clears
        the failure count. In CLOSED, any success clears the failure count, so
        only consecutive failures can open the circuit.

        Args:
            key: Circuit breaker key

        Returns:
            New circuit state
        """
        state = self._get_state(key)

        if state.state == CircuitState.HALF_OPEN:
            state.success_count += 1
            if state.success_count >= self.config.success_threshold:
                # Recovery confirmed, close the circuit
                self._transition_to(state, CircuitState.CLOSED)
                state.failure_count = 0

        elif state.state == CircuitState.CLOSED:
            # Reset failure count on success
            state.failure_count = 0

        return state.state

    def record_failure(self, key: str = "default") -> CircuitState:
        """
        Record a failed request.

        Any failure while HALF_OPEN reopens the circuit; while CLOSED the
        circuit opens once failure_threshold is reached.

        Args:
            key: Circuit breaker key

        Returns:
            New circuit state
        """
        state = self._get_state(key)
        state.failure_count += 1
        state.last_failure_time = datetime.now()

        if state.state == CircuitState.HALF_OPEN:
            # Any failure in half-open reopens the circuit
            self._transition_to(state, CircuitState.OPEN)

        elif state.state == CircuitState.CLOSED:
            if state.failure_count >= self.config.failure_threshold:
                # Too many failures, open the circuit
                self._transition_to(state, CircuitState.OPEN)

        return state.state

    def get_state(self, key: str = "default") -> CircuitBreakerState:
        """
        Get the current state of a circuit (creating a CLOSED one if unknown).

        Args:
            key: Circuit breaker key

        Returns:
            The live CircuitBreakerState object for the key
        """
        return self._get_state(key)

    def reset(self, key: str = "default") -> None:
        """
        Reset a circuit breaker to closed state by forgetting its history.

        Args:
            key: Circuit breaker key
        """
        self._states.pop(key, None)

    def get_all_states(self) -> Dict[str, CircuitBreakerState]:
        """Get all circuit breaker states.

        The returned dict is a shallow copy: adding/removing keys does not
        affect the breaker, but the CircuitBreakerState values are the live
        objects.
        """
        return self._states.copy()

    def get_metrics(self) -> Dict[str, Any]:
        """
        Get circuit breaker metrics.

        Returns:
            Dictionary with per-state circuit counts and, under "circuits",
            the state value and counters of every known circuit
        """
        metrics = {
            "total_circuits": len(self._states),
            "open_circuits": 0,
            "half_open_circuits": 0,
            "closed_circuits": 0,
            "circuits": {}
        }

        for key, state in self._states.items():
            if state.state == CircuitState.OPEN:
                metrics["open_circuits"] += 1
            elif state.state == CircuitState.HALF_OPEN:
                metrics["half_open_circuits"] += 1
            else:
                metrics["closed_circuits"] += 1

            metrics["circuits"][key] = {
                "state": state.state.value,
                "failure_count": state.failure_count,
                "success_count": state.success_count,
            }

        return metrics
298
+
299
+
300
class RateLimiter:
    """
    Rate limiter for detecting resource theft and abuse patterns.

    Uses the security.db to track invocation patterns and detect anomalies.
    Includes circuit breaker for fault tolerance.

    Each check() counts 'tool_invoked' rows of the security_events table for
    the session and compares them with the RateLimitConfig thresholds. The
    outcome of every completed check is reported to a per-session circuit
    breaker (key "session:<session_id>"), which refuses the session outright
    after repeated violations.
    """

    def __init__(
        self,
        config: Optional[RateLimitConfig] = None,
        circuit_config: Optional[CircuitBreakerConfig] = None,
        logger: Optional[SecurityLogger] = None
    ):
        """Initialize the rate limiter.

        Args:
            config: Rate limiting configuration (defaults to RateLimitConfig())
            circuit_config: Circuit breaker configuration
            logger: Security logger for database access (defaults to get_logger())
        """
        self.config = config or RateLimitConfig()
        self.logger = logger or get_logger()
        self.circuit_breaker = CircuitBreaker(circuit_config)
        self._ensure_indexes()

    def _ensure_indexes(self):
        """Ensure necessary database indexes exist for efficient queries.

        Best-effort: any error is ignored so that constructing a RateLimiter
        never fails because of index creation.
        """
        try:
            with self.logger._get_connection() as conn:
                conn.executescript("""
                    -- Index for session + timestamp queries (rate limiting)
                    CREATE INDEX IF NOT EXISTS idx_events_session_time
                    ON security_events(session_id, timestamp);

                    -- Index for command hash queries (repeated command detection)
                    CREATE INDEX IF NOT EXISTS idx_events_command_hash
                    ON security_events(tool_name, command);
                """)
        except Exception:
            # Indexes may already exist or db not initialized; they are only a
            # performance aid, so failure here is deliberately tolerated.
            pass

    def _hash_command(self, command: str) -> str:
        """Create a short (16 hex chars) hash of a command for comparison.

        MD5 is used purely as a non-cryptographic fingerprint. The checks in
        this class compare raw command text and do not call this helper.
        """
        return hashlib.md5(command.encode()).hexdigest()[:16]

    def _get_recent_count(
        self,
        conn: sqlite3.Connection,
        session_id: str,
        window_seconds: int,
        tool_name: Optional[str] = None,
        tier: Optional[str] = None
    ) -> int:
        """Get count of recent 'tool_invoked' events in a time window.

        Optionally restricted to one tool_name and/or one tier.
        """
        # NOTE(review): the cutoff comes from SQLite datetime('now', ...), which
        # is UTC text in 'YYYY-MM-DD HH:MM:SS' form and is compared as a string.
        # This assumes security_events.timestamp uses the same format and
        # timezone -- confirm against SecurityLogger.
        query = """
            SELECT COUNT(*) as count FROM security_events
            WHERE session_id = ?
            AND timestamp > datetime('now', ?)
            AND event_type = 'tool_invoked'
        """
        params = [session_id, f'-{window_seconds} seconds']

        if tool_name:
            query += " AND tool_name = ?"
            params.append(tool_name)

        if tier:
            query += " AND tier = ?"
            params.append(tier)

        return conn.execute(query, params).fetchone()[0]

    def _get_command_count(
        self,
        conn: sqlite3.Connection,
        session_id: str,
        command: str,
        window_seconds: int
    ) -> int:
        """Get count of identical commands (exact text match) in a time window."""
        query = """
            SELECT COUNT(*) as count FROM security_events
            WHERE session_id = ?
            AND timestamp > datetime('now', ?)
            AND command = ?
            AND event_type = 'tool_invoked'
        """
        return conn.execute(
            query,
            [session_id, f'-{window_seconds} seconds', command]
        ).fetchone()[0]

    def _get_baseline_velocity(
        self,
        conn: sqlite3.Connection,
        session_id: str
    ) -> Optional[float]:
        """Get baseline commands per minute for comparison.

        Returns None when fewer than min_baseline_samples events exist in the
        baseline window, when all events share one timestamp, or when the
        stored timestamps cannot be parsed by datetime.fromisoformat.
        """
        query = """
            SELECT COUNT(*) as count,
                   MIN(timestamp) as first_ts,
                   MAX(timestamp) as last_ts
            FROM security_events
            WHERE session_id = ?
            AND timestamp > datetime('now', ?)
            AND event_type = 'tool_invoked'
        """
        result = conn.execute(
            query,
            [session_id, f'-{self.config.baseline_window_hours} hours']
        ).fetchone()

        count = result[0]
        if count < self.config.min_baseline_samples:
            return None

        # Calculate average commands per minute
        try:
            first_ts = datetime.fromisoformat(result[1])
            last_ts = datetime.fromisoformat(result[2])
            duration_minutes = (last_ts - first_ts).total_seconds() / 60
            if duration_minutes > 0:
                return count / duration_minutes
        except (ValueError, TypeError):
            pass

        return None

    def _get_current_velocity(
        self,
        conn: sqlite3.Connection,
        session_id: str,
        window_seconds: int = 60
    ) -> float:
        """Get current commands per minute, extrapolated from `window_seconds`."""
        count = self._get_recent_count(conn, session_id, window_seconds)
        return count * (60 / window_seconds)

    def _collect_violations(
        self,
        conn: sqlite3.Connection,
        session_id: str,
        command: Optional[str],
        tier: Optional[str]
    ) -> Tuple[List[RateLimitViolation], Dict[str, Any]]:
        """Run every rate-limit check against the database; return (violations, details)."""
        violations: List[RateLimitViolation] = []
        details: Dict[str, Any] = {}

        # Check 1: Burst detection (many commands in very short window)
        burst_count = self._get_recent_count(
            conn, session_id, self.config.burst_window
        )
        details["burst_count"] = burst_count
        if burst_count >= self.config.burst_threshold:
            violations.append(RateLimitViolation.BURST)
            details["burst_threshold"] = self.config.burst_threshold

        # Check 2: Per-minute volume
        minute_count = self._get_recent_count(
            conn, session_id, self.config.short_window
        )
        details["minute_count"] = minute_count
        if minute_count >= self.config.max_per_minute:
            violations.append(RateLimitViolation.HIGH_VOLUME)
            details["max_per_minute"] = self.config.max_per_minute

        # Check 3: Dangerous tier spike (only evaluated for dangerous calls)
        if tier == "dangerous":
            dangerous_count = self._get_recent_count(
                conn, session_id, self.config.short_window, tier="dangerous"
            )
            details["dangerous_count"] = dangerous_count
            if dangerous_count >= self.config.max_dangerous_per_minute:
                violations.append(RateLimitViolation.DANGEROUS_SPIKE)
                details["max_dangerous"] = self.config.max_dangerous_per_minute

        # Check 4: Repeated command detection
        if command:
            cmd_count = self._get_command_count(
                conn, session_id, command, self.config.short_window
            )
            details["same_command_count"] = cmd_count
            if cmd_count >= self.config.max_same_command:
                violations.append(RateLimitViolation.REPEATED_COMMAND)
                details["max_same_command"] = self.config.max_same_command

        # Check 5: Velocity anomaly (skipped until a baseline exists)
        baseline = self._get_baseline_velocity(conn, session_id)
        current = self._get_current_velocity(conn, session_id)
        details["current_velocity"] = round(current, 2)

        if baseline:
            details["baseline_velocity"] = round(baseline, 2)
            if current > baseline * self.config.velocity_multiplier:
                violations.append(RateLimitViolation.VELOCITY_ANOMALY)
                details["velocity_ratio"] = round(current / baseline, 2)

        return violations, details

    def check(
        self,
        tool_name: str,
        command: Optional[str],
        session_id: Optional[str],
        tier: Optional[str] = None
    ) -> RateLimitResult:
        """
        Check if an invocation should be rate limited.

        Rate limiting is free and open source.

        Fails open: without a session_id, or when the database query raises,
        the invocation is allowed (and the circuit breaker is not updated).

        Args:
            tool_name: Name of the tool being invoked
            command: The command being executed (for Bash)
            session_id: Current session identifier
            tier: Security tier of the operation

        Returns:
            RateLimitResult with allowed status and any violations
        """
        if not session_id:
            # No session tracking - allow but log
            return RateLimitResult(allowed=True, message="No session ID for rate limiting")

        # Check circuit breaker first
        circuit_key = f"session:{session_id}"
        can_exec, circuit_state, retry_after = self.circuit_breaker.can_execute(circuit_key)

        if not can_exec:
            return RateLimitResult(
                allowed=False,
                violations=[RateLimitViolation.CIRCUIT_OPEN],
                message=f"Circuit breaker is {circuit_state.value}. Too many rate limit violations.",
                circuit_state=circuit_state,
                retry_after=retry_after,
                details={"circuit_key": circuit_key}
            )

        try:
            with self.logger._get_connection() as conn:
                violations, details = self._collect_violations(
                    conn, session_id, command, tier
                )
        except Exception as e:
            # Database error - fail open but log
            return RateLimitResult(
                allowed=True,
                message=f"Rate limit check failed: {e}",
                details={"error": str(e)}
            )

        # Determine if we should block
        allowed = not violations

        # Every completed check counts as a circuit-breaker success or failure.
        if allowed:
            new_state = self.circuit_breaker.record_success(circuit_key)
        else:
            new_state = self.circuit_breaker.record_failure(circuit_key)

        message = None
        if not allowed:
            violation_names = [v.value for v in violations]
            message = f"Rate limit violations: {', '.join(violation_names)}"

        return RateLimitResult(
            allowed=allowed,
            violations=violations,
            details=details,
            message=message,
            circuit_state=new_state
        )

    def get_session_stats(self, session_id: str) -> Dict[str, Any]:
        """
        Get statistics for a session.

        Args:
            session_id: Session to get stats for

        Returns:
            Dictionary with session statistics, or {"error": ...} if the
            database query fails
        """
        # Fixed one-hour window: "total_invocations_1h" must not silently
        # change meaning when long_window is reconfigured.
        one_hour_seconds = 60 * 60

        try:
            with self.logger._get_connection() as conn:
                # Total invocations
                total = self._get_recent_count(conn, session_id, one_hour_seconds)

                # By tier
                tiers = {
                    tier: self._get_recent_count(
                        conn, session_id, one_hour_seconds, tier=tier
                    )
                    for tier in ["safe", "default", "risky", "dangerous"]
                }

                # Velocity
                current = self._get_current_velocity(conn, session_id)
                baseline = self._get_baseline_velocity(conn, session_id)

                return {
                    "session_id": session_id,
                    "total_invocations_1h": total,
                    "by_tier": tiers,
                    "current_velocity_per_min": round(current, 2),
                    "baseline_velocity_per_min": round(baseline, 2) if baseline else None,
                    "config": {
                        "burst_threshold": self.config.burst_threshold,
                        "max_per_minute": self.config.max_per_minute,
                        "max_dangerous_per_minute": self.config.max_dangerous_per_minute,
                    }
                }
        except Exception as e:
            return {"error": str(e)}

    def format_violation_message(self, result: RateLimitResult) -> str:
        """Format a user-friendly violation message ("" for allowed results)."""
        if result.allowed:
            return ""

        lines = [
            "Rate Limit Alert",
            "=" * 40,
        ]

        if result.is_burst:
            lines.append(
                f" Burst detected: {result.details.get('burst_count', '?')} "
                f"commands in {self.config.burst_window}s "
                f"(limit: {self.config.burst_threshold})"
            )

        if result.is_repeated:
            lines.append(
                f" Repeated command: {result.details.get('same_command_count', '?')} "
                f"times in 1 minute (limit: {self.config.max_same_command})"
            )

        if RateLimitViolation.HIGH_VOLUME in result.violations:
            lines.append(
                f" High volume: {result.details.get('minute_count', '?')} "
                f"commands/min (limit: {self.config.max_per_minute})"
            )

        if RateLimitViolation.DANGEROUS_SPIKE in result.violations:
            lines.append(
                f" Dangerous tier spike: {result.details.get('dangerous_count', '?')} "
                f"dangerous commands (limit: {self.config.max_dangerous_per_minute})"
            )

        if RateLimitViolation.VELOCITY_ANOMALY in result.violations:
            lines.append(
                f" Velocity anomaly: {result.details.get('velocity_ratio', '?')}x "
                f"above baseline"
            )

        # A refusal by the breaker can happen while it is HALF_OPEN (probe
        # budget exhausted), not only while OPEN; report both so the alert
        # never has an empty body.
        if result.is_circuit_open or RateLimitViolation.CIRCUIT_OPEN in result.violations:
            state_label = result.circuit_state.value.upper()
            lines.append(
                f" Circuit breaker {state_label}: Too many rate limit violations"
            )
            if result.retry_after:
                lines.append(f" Retry after: {result.retry_after} seconds")

        lines.append("=" * 40)
        lines.append("This may indicate automated abuse or attack.")

        return "\n".join(lines)

    def get_circuit_metrics(self) -> Dict[str, Any]:
        """Get circuit breaker metrics for all sessions."""
        return self.circuit_breaker.get_metrics()

    def reset_circuit(self, session_id: str) -> None:
        """Reset circuit breaker for a session."""
        circuit_key = f"session:{session_id}"
        self.circuit_breaker.reset(circuit_key)
668
+
669
+
670
# Process-wide RateLimiter, created lazily by get_rate_limiter().
_rate_limiter: Optional[RateLimiter] = None


def get_rate_limiter(
    config: Optional[RateLimitConfig] = None,
    circuit_config: Optional[CircuitBreakerConfig] = None
) -> RateLimiter:
    """Return the shared RateLimiter, constructing it on the first call.

    `config` and `circuit_config` are only used when the instance is first
    created; on later calls they are ignored and the existing instance is
    returned unchanged.
    """
    global _rate_limiter
    if _rate_limiter is not None:
        return _rate_limiter
    _rate_limiter = RateLimiter(config=config, circuit_config=circuit_config)
    return _rate_limiter