proxilion 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. proxilion/__init__.py +136 -0
  2. proxilion/audit/__init__.py +133 -0
  3. proxilion/audit/base_exporters.py +527 -0
  4. proxilion/audit/compliance/__init__.py +130 -0
  5. proxilion/audit/compliance/base.py +457 -0
  6. proxilion/audit/compliance/eu_ai_act.py +603 -0
  7. proxilion/audit/compliance/iso27001.py +544 -0
  8. proxilion/audit/compliance/soc2.py +491 -0
  9. proxilion/audit/events.py +493 -0
  10. proxilion/audit/explainability.py +1173 -0
  11. proxilion/audit/exporters/__init__.py +58 -0
  12. proxilion/audit/exporters/aws_s3.py +636 -0
  13. proxilion/audit/exporters/azure_storage.py +608 -0
  14. proxilion/audit/exporters/cloud_base.py +468 -0
  15. proxilion/audit/exporters/gcp_storage.py +570 -0
  16. proxilion/audit/exporters/multi_exporter.py +498 -0
  17. proxilion/audit/hash_chain.py +652 -0
  18. proxilion/audit/logger.py +543 -0
  19. proxilion/caching/__init__.py +49 -0
  20. proxilion/caching/tool_cache.py +633 -0
  21. proxilion/context/__init__.py +73 -0
  22. proxilion/context/context_window.py +556 -0
  23. proxilion/context/message_history.py +505 -0
  24. proxilion/context/session.py +735 -0
  25. proxilion/contrib/__init__.py +51 -0
  26. proxilion/contrib/anthropic.py +609 -0
  27. proxilion/contrib/google.py +1012 -0
  28. proxilion/contrib/langchain.py +641 -0
  29. proxilion/contrib/mcp.py +893 -0
  30. proxilion/contrib/openai.py +646 -0
  31. proxilion/core.py +3058 -0
  32. proxilion/decorators.py +966 -0
  33. proxilion/engines/__init__.py +287 -0
  34. proxilion/engines/base.py +266 -0
  35. proxilion/engines/casbin_engine.py +412 -0
  36. proxilion/engines/opa_engine.py +493 -0
  37. proxilion/engines/simple.py +437 -0
  38. proxilion/exceptions.py +887 -0
  39. proxilion/guards/__init__.py +54 -0
  40. proxilion/guards/input_guard.py +522 -0
  41. proxilion/guards/output_guard.py +634 -0
  42. proxilion/observability/__init__.py +198 -0
  43. proxilion/observability/cost_tracker.py +866 -0
  44. proxilion/observability/hooks.py +683 -0
  45. proxilion/observability/metrics.py +798 -0
  46. proxilion/observability/session_cost_tracker.py +1063 -0
  47. proxilion/policies/__init__.py +67 -0
  48. proxilion/policies/base.py +304 -0
  49. proxilion/policies/builtin.py +486 -0
  50. proxilion/policies/registry.py +376 -0
  51. proxilion/providers/__init__.py +201 -0
  52. proxilion/providers/adapter.py +468 -0
  53. proxilion/providers/anthropic_adapter.py +330 -0
  54. proxilion/providers/gemini_adapter.py +391 -0
  55. proxilion/providers/openai_adapter.py +294 -0
  56. proxilion/py.typed +0 -0
  57. proxilion/resilience/__init__.py +81 -0
  58. proxilion/resilience/degradation.py +615 -0
  59. proxilion/resilience/fallback.py +555 -0
  60. proxilion/resilience/retry.py +554 -0
  61. proxilion/scheduling/__init__.py +57 -0
  62. proxilion/scheduling/priority_queue.py +419 -0
  63. proxilion/scheduling/scheduler.py +459 -0
  64. proxilion/security/__init__.py +244 -0
  65. proxilion/security/agent_trust.py +968 -0
  66. proxilion/security/behavioral_drift.py +794 -0
  67. proxilion/security/cascade_protection.py +869 -0
  68. proxilion/security/circuit_breaker.py +428 -0
  69. proxilion/security/cost_limiter.py +690 -0
  70. proxilion/security/idor_protection.py +460 -0
  71. proxilion/security/intent_capsule.py +849 -0
  72. proxilion/security/intent_validator.py +495 -0
  73. proxilion/security/memory_integrity.py +767 -0
  74. proxilion/security/rate_limiter.py +509 -0
  75. proxilion/security/scope_enforcer.py +680 -0
  76. proxilion/security/sequence_validator.py +636 -0
  77. proxilion/security/trust_boundaries.py +784 -0
  78. proxilion/streaming/__init__.py +70 -0
  79. proxilion/streaming/detector.py +761 -0
  80. proxilion/streaming/transformer.py +674 -0
  81. proxilion/timeouts/__init__.py +55 -0
  82. proxilion/timeouts/decorators.py +477 -0
  83. proxilion/timeouts/manager.py +545 -0
  84. proxilion/tools/__init__.py +69 -0
  85. proxilion/tools/decorators.py +493 -0
  86. proxilion/tools/registry.py +732 -0
  87. proxilion/types.py +339 -0
  88. proxilion/validation/__init__.py +93 -0
  89. proxilion/validation/pydantic_schema.py +351 -0
  90. proxilion/validation/schema.py +651 -0
  91. proxilion-0.0.1.dist-info/METADATA +872 -0
  92. proxilion-0.0.1.dist-info/RECORD +94 -0
  93. proxilion-0.0.1.dist-info/WHEEL +4 -0
  94. proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,794 @@
1
+ """
2
+ Behavioral Drift Detection for Proxilion.
3
+
4
+ Addresses OWASP ASI10: Rogue Agents.
5
+
6
+ This module detects when an agent's behavior deviates significantly
7
+ from its established baseline, potentially indicating:
8
+ - Compromise or injection attack
9
+ - Malfunction or loop behavior
10
+ - Goal hijacking
11
+ - Rogue agent behavior
12
+
13
+ Example:
14
+ >>> from proxilion.security.behavioral_drift import (
15
+ ... BehavioralMonitor,
16
+ ... DriftDetector,
17
+ ... KillSwitch,
18
+ ... )
19
+ >>>
20
+ >>> # Create monitor
21
+ >>> monitor = BehavioralMonitor(agent_id="my_agent")
22
+ >>>
23
+ >>> # Record normal behavior during baseline period
24
+ >>> for i in range(100):
25
+ ... monitor.record_event("tool_call", {"tool": "search"})
26
+ ... monitor.record_event("response", {"length": 150})
27
+ >>>
28
+ >>> # Lock baseline
29
+ >>> monitor.lock_baseline()
30
+ >>>
31
+ >>> # Detect drift during operation
32
+ >>> drift = monitor.check_drift()
33
+ >>> if drift.is_drifting:
34
+ ... print(f"Behavioral drift detected: {drift.reason}")
35
+ ... if drift.severity > 0.8:
36
+ ... kill_switch.activate("Severe behavioral drift")
37
+ """
38
+
39
+ from __future__ import annotations
40
+
41
+ import logging
42
+ import math
43
+ import statistics
44
+ import threading
45
+ import time
46
+ from collections import deque
47
+ from dataclasses import dataclass, field
48
+ from datetime import datetime, timezone
49
+ from enum import Enum
50
+ from typing import Any, Callable
51
+
52
+ from proxilion.exceptions import BehavioralDriftError, EmergencyHaltError
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+
57
class DriftMetric(Enum):
    """Categories of behavioral metrics tracked for drift detection."""

    TOOL_CALL_RATE = "tool_call_rate"        # calls per minute
    RESPONSE_LENGTH = "response_length"      # average response length
    ERROR_RATE = "error_rate"                # errors per minute
    UNIQUE_TOOLS = "unique_tools"            # number of unique tools used
    LATENCY = "latency"                      # average response latency
    TOKEN_USAGE = "token_usage"              # tokens consumed per request
    TOOL_REPETITION = "tool_repetition"      # same tool called consecutively
    SCOPE_VIOLATIONS = "scope_violations"    # attempts to exceed scope
    CONTEXT_SIZE = "context_size"            # size of conversation context
    CUSTOM = "custom"                        # user-defined metric
89
+
90
+
91
@dataclass
class MetricValue:
    """A single metric measurement."""

    # Which behavioral metric this sample belongs to.
    metric: DriftMetric
    # Observed value; units depend on the metric (e.g. calls/minute,
    # characters, milliseconds).
    value: float
    # Unix timestamp (time.time()) at which the sample was recorded.
    timestamp: float
    # Optional free-form context attached to the sample.
    metadata: dict[str, Any] = field(default_factory=dict)
99
+
100
+
101
@dataclass
class BaselineStats:
    """Statistical baseline for a metric.

    Captures the distribution of a metric observed during the baseline
    period so later samples can be scored against it.
    """

    metric: DriftMetric
    mean: float
    std_dev: float
    min_value: float
    max_value: float
    sample_count: int
    percentile_95: float
    percentile_99: float

    def z_score(self, value: float) -> float:
        """Return how many standard deviations *value* lies from the mean.

        With a degenerate baseline (std_dev == 0) any deviation is treated
        as infinitely surprising. The sign of the deviation is preserved
        (the previous implementation always returned +inf, losing the
        direction of the shift).
        """
        if self.std_dev == 0:
            if value == self.mean:
                return 0.0
            return math.copysign(float("inf"), value - self.mean)
        return (value - self.mean) / self.std_dev

    def is_anomaly(self, value: float, threshold: float = 3.0) -> bool:
        """Check if value is anomalous (beyond threshold std devs)."""
        return abs(self.z_score(value)) > threshold
123
+
124
+
125
@dataclass
class DriftResult:
    """Outcome of a single drift check."""

    is_drifting: bool
    severity: float  # normalized: 0.0 (no drift) .. 1.0 (severe)
    drifting_metrics: list[tuple[DriftMetric, float, float]]  # (metric, value, z_score)
    reason: str
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly representation of this result."""
        metric_entries = []
        for metric, value, z in self.drifting_metrics:
            metric_entries.append(
                {"metric": metric.value, "value": value, "z_score": z}
            )
        return {
            "is_drifting": self.is_drifting,
            "severity": self.severity,
            "drifting_metrics": metric_entries,
            "reason": self.reason,
            "timestamp": self.timestamp.isoformat(),
        }
147
+
148
+
149
class BehavioralMonitor:
    """
    Monitors agent behavior and detects drift from baseline.

    Tracks multiple behavioral metrics and uses statistical analysis
    to detect when current behavior deviates from established patterns.

    All public methods acquire the internal RLock before touching shared
    state, so one monitor may be used from several threads.

    Example:
        >>> monitor = BehavioralMonitor(agent_id="my_agent")
        >>>
        >>> # Record events during operation
        >>> monitor.record_tool_call("search", {"query": "test"})
        >>> monitor.record_response({"content": "result", "tokens": 50})
        >>>
        >>> # Check for drift
        >>> result = monitor.check_drift()
        >>> if result.is_drifting:
        ...     handle_drift(result)
    """

    def __init__(
        self,
        agent_id: str,
        baseline_window: int = 100,
        detection_window: int = 10,
        drift_threshold: float = 3.0,
        min_baseline_samples: int = 20,
    ) -> None:
        """
        Initialize the monitor.

        Args:
            agent_id: Unique identifier for the agent.
            baseline_window: Number of samples for baseline calculation.
            detection_window: Recent samples for drift detection.
            drift_threshold: Z-score threshold for drift detection.
            min_baseline_samples: Minimum samples before baseline is valid.
        """
        self.agent_id = agent_id
        self._baseline_window = baseline_window
        self._detection_window = detection_window
        self._drift_threshold = drift_threshold
        self._min_baseline_samples = min_baseline_samples

        # Per-metric rolling sample windows; oldest samples fall off.
        self._metrics: dict[DriftMetric, deque[MetricValue]] = {
            metric: deque(maxlen=baseline_window) for metric in DriftMetric
        }

        # Baseline statistics, frozen by lock_baseline().
        self._baseline: dict[DriftMetric, BaselineStats] = {}
        self._baseline_locked = False

        # Timestamps of tool calls (drives TOOL_CALL_RATE) and recent tool
        # names (drives UNIQUE_TOOLS / TOOL_REPETITION).
        self._event_times: deque[float] = deque(maxlen=1000)
        self._tool_history: deque[str] = deque(maxlen=100)
        self._error_count = 0

        # Observers notified when check_drift() finds drift.
        self._drift_callbacks: list[Callable[[DriftResult], None]] = []

        # RLock: check_drift() re-enters lock_baseline() while holding it.
        self._lock = threading.RLock()

        logger.debug(f"BehavioralMonitor initialized for agent: {agent_id}")

    def record_event(
        self,
        event_type: str,
        data: dict[str, Any],
    ) -> None:
        """
        Record a generic event.

        Args:
            event_type: Type of event (tool_call, response, error, etc.).
            data: Event data.
        """
        now = time.time()
        with self._lock:
            if event_type == "tool_call":
                # Fix: only tool calls feed _event_times. Previously every
                # event type was appended, so TOOL_CALL_RATE ("calls per
                # minute") was inflated by responses and errors.
                self._event_times.append(now)
                self._record_tool_call(data, now)
            elif event_type == "response":
                self._record_response(data, now)
            elif event_type == "error":
                self._record_error(data, now)
            elif event_type == "latency":
                self._record_metric(DriftMetric.LATENCY, data.get("value", 0), now)
            elif event_type == "tokens":
                self._record_metric(DriftMetric.TOKEN_USAGE, data.get("value", 0), now)
            elif event_type == "context_size":
                self._record_metric(DriftMetric.CONTEXT_SIZE, data.get("value", 0), now)
            elif event_type == "scope_violation":
                self._record_metric(DriftMetric.SCOPE_VIOLATIONS, 1.0, now)

    def record_tool_call(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> None:
        """Record a tool call event.

        Args:
            tool_name: Name of the tool invoked.
            arguments: Tool arguments (currently unused, kept for API
                compatibility).
            latency_ms: Optional call latency, recorded as LATENCY.
        """
        now = time.time()
        with self._lock:
            # Appending under the lock keeps rate computation consistent
            # (the original appended before acquiring the lock).
            self._event_times.append(now)
            self._update_tool_metrics(tool_name, latency_ms, now)

    def record_response(
        self,
        response: dict[str, Any],
    ) -> None:
        """Record a response event (length and token usage)."""
        now = time.time()
        with self._lock:
            self._record_response(response, now)

    def record_error(self, error_info: dict[str, Any]) -> None:
        """Record an error event and update the rolling error rate."""
        now = time.time()
        with self._lock:
            # Counter update now happens under the lock (was unlocked).
            self._record_error(error_info, now)

    def _update_tool_metrics(
        self,
        tool_name: str,
        latency_ms: float | None,
        now: float,
    ) -> None:
        """Update all tool-derived metrics. Caller must hold self._lock."""
        self._tool_history.append(tool_name)

        # Calls per minute: tool-call timestamps within the last 60s.
        recent_calls = sum(1 for t in self._event_times if now - t < 60)
        self._record_metric(DriftMetric.TOOL_CALL_RATE, recent_calls, now)

        # Distinct tools seen in the recent history window.
        self._record_metric(DriftMetric.UNIQUE_TOOLS, len(set(self._tool_history)), now)

        # Consecutive repeats of the same tool (loop indicator).
        if len(self._tool_history) >= 2:
            repetition = sum(
                1 for i in range(1, len(self._tool_history))
                if self._tool_history[i] == self._tool_history[i - 1]
            )
            self._record_metric(DriftMetric.TOOL_REPETITION, repetition, now)

        if latency_ms is not None:
            self._record_metric(DriftMetric.LATENCY, latency_ms, now)

    def _record_tool_call(self, data: dict[str, Any], timestamp: float) -> None:
        """Internal tool call recording (record_event path).

        Delegates to the same metric updates as record_tool_call(); the
        original duplicated the logic and silently skipped TOOL_REPETITION
        on this path.
        """
        tool_name = data.get("tool") or data.get("tool_name", "unknown")
        self._update_tool_metrics(tool_name, data.get("latency_ms"), timestamp)

    def _record_response(self, data: dict[str, Any], timestamp: float) -> None:
        """Internal response recording. Caller must hold self._lock."""
        content = data.get("content", "")
        if isinstance(content, str):
            self._record_metric(DriftMetric.RESPONSE_LENGTH, len(content), timestamp)

        # Accept both key spellings, matching the public record_response()
        # (the internal path previously ignored "token_count").
        tokens = data.get("tokens") or data.get("token_count")
        if tokens:
            self._record_metric(DriftMetric.TOKEN_USAGE, tokens, timestamp)

    def _record_error(self, data: dict[str, Any], timestamp: float) -> None:
        """Internal error recording. Caller must hold self._lock."""
        self._error_count += 1
        recent_errors = sum(
            1 for mv in self._metrics[DriftMetric.ERROR_RATE]
            if timestamp - mv.timestamp < 60
        )
        self._record_metric(DriftMetric.ERROR_RATE, recent_errors + 1, timestamp)

    def _record_metric(
        self,
        metric: DriftMetric,
        value: float,
        timestamp: float,
    ) -> None:
        """Append a sample to the metric's rolling window."""
        self._metrics[metric].append(MetricValue(
            metric=metric,
            value=value,
            timestamp=timestamp,
        ))

    def lock_baseline(self) -> dict[DriftMetric, BaselineStats]:
        """
        Lock the current baseline.

        Calculates statistical baselines from current data and
        locks them for future drift detection. Metrics with fewer than
        min_baseline_samples samples are skipped.

        Returns:
            Dictionary of baseline stats per metric.

        Raises:
            ValueError: If no metric has enough samples for a baseline.
        """
        with self._lock:
            self._baseline = {}

            for metric, values in self._metrics.items():
                if len(values) < self._min_baseline_samples:
                    continue

                samples = [v.value for v in values]

                mean = statistics.mean(samples)
                std_dev = statistics.stdev(samples) if len(samples) > 1 else 0.0
                sorted_samples = sorted(samples)
                p95_idx = int(len(sorted_samples) * 0.95)
                p99_idx = int(len(sorted_samples) * 0.99)

                self._baseline[metric] = BaselineStats(
                    metric=metric,
                    mean=mean,
                    std_dev=std_dev,
                    min_value=min(samples),
                    max_value=max(samples),
                    sample_count=len(samples),
                    percentile_95=sorted_samples[p95_idx] if p95_idx < len(sorted_samples) else max(samples),
                    percentile_99=sorted_samples[p99_idx] if p99_idx < len(sorted_samples) else max(samples),
                )

            # Fix: the docstring always promised ValueError, but the original
            # silently locked an EMPTY baseline, making every later
            # check_drift() report "within normal parameters" forever.
            if not self._baseline:
                raise ValueError(
                    f"Not enough samples to establish a baseline "
                    f"(need at least {self._min_baseline_samples} for some metric)"
                )

            self._baseline_locked = True
            logger.info(f"Baseline locked with {len(self._baseline)} metrics")

            return self._baseline

    def check_drift(self) -> DriftResult:
        """
        Check for behavioral drift from baseline.

        Auto-locks the baseline on first call once enough samples exist.

        Returns:
            DriftResult indicating if drift was detected.
        """
        with self._lock:
            if not self._baseline_locked:
                has_enough = any(
                    len(values) >= self._min_baseline_samples
                    for values in self._metrics.values()
                )
                if has_enough:
                    self.lock_baseline()
                else:
                    return DriftResult(
                        is_drifting=False,
                        severity=0.0,
                        drifting_metrics=[],
                        reason="Baseline not yet established",
                    )

            drifting_metrics: list[tuple[DriftMetric, float, float]] = []
            max_severity = 0.0

            for metric, baseline in self._baseline.items():
                # Average of the most recent detection window.
                recent = list(self._metrics[metric])[-self._detection_window:]
                if not recent:
                    continue

                current_value = statistics.mean([v.value for v in recent])
                z_score = baseline.z_score(current_value)

                if abs(z_score) > self._drift_threshold:
                    drifting_metrics.append((metric, current_value, z_score))
                    # Severity saturates at 1.0 when z reaches 2x threshold.
                    severity = min(1.0, abs(z_score) / (self._drift_threshold * 2))
                    max_severity = max(max_severity, severity)

            if not drifting_metrics:
                return DriftResult(
                    is_drifting=False,
                    severity=0.0,
                    drifting_metrics=[],
                    reason="Behavior within normal parameters",
                )

            reasons = [
                f"{m.value}: {v:.2f} (z={z:.1f})"
                for m, v, z in drifting_metrics
            ]
            result = DriftResult(
                is_drifting=True,
                severity=max_severity,
                drifting_metrics=drifting_metrics,
                reason=f"Drift detected in: {', '.join(reasons)}",
            )
            callbacks = list(self._drift_callbacks)

        # Notify outside the lock so slow or re-entrant callbacks cannot
        # stall threads that are recording events.
        for callback in callbacks:
            try:
                callback(result)
            except Exception as e:
                logger.error(f"Drift callback error: {e}")

        return result

    def on_drift(self, callback: Callable[[DriftResult], None]) -> None:
        """Register a callback invoked when check_drift() detects drift."""
        with self._lock:
            self._drift_callbacks.append(callback)

    def get_current_metrics(self) -> dict[str, float]:
        """Get current metric values (mean over the detection window)."""
        with self._lock:
            result = {}
            for metric, values in self._metrics.items():
                if values:
                    recent = list(values)[-self._detection_window:]
                    result[metric.value] = statistics.mean([v.value for v in recent])
            return result

    def get_baseline(self) -> dict[DriftMetric, BaselineStats]:
        """Get a snapshot of the current baseline."""
        with self._lock:
            return self._baseline.copy()

    def reset(self) -> None:
        """Reset all samples, the baseline, and counters (callbacks kept)."""
        with self._lock:
            for values in self._metrics.values():
                values.clear()
            self._baseline = {}
            self._baseline_locked = False
            self._event_times.clear()
            self._tool_history.clear()
            self._error_count = 0
502
+
503
+
504
class KillSwitch:
    """
    Emergency halt mechanism for rogue agent behavior.

    Provides immediate shutdown capability when severe behavioral
    drift or other anomalies are detected. All state transitions are
    guarded by an RLock; halt/reset callbacks are invoked outside the
    lock so they cannot stall other threads.

    Example:
        >>> kill_switch = KillSwitch()
        >>>
        >>> # Register halt handlers
        >>> kill_switch.on_halt(lambda reason: cleanup_resources())
        >>> kill_switch.on_halt(lambda reason: notify_operators(reason))
        >>>
        >>> # Activate when needed
        >>> if drift.severity > 0.9:
        ...     kill_switch.activate("Severe behavioral drift detected")
    """

    def __init__(
        self,
        auto_reset_seconds: float | None = None,
    ) -> None:
        """
        Initialize the kill switch.

        Args:
            auto_reset_seconds: If set, auto-reset after this many seconds.
        """
        self._active = False
        self._activation_time: datetime | None = None
        self._activation_reason: str = ""
        self._auto_reset_seconds = auto_reset_seconds

        self._halt_callbacks: list[Callable[[str], None]] = []
        self._reset_callbacks: list[Callable[[], None]] = []

        self._lock = threading.RLock()

        logger.debug("KillSwitch initialized")

    def _expire_if_due(self) -> None:
        """Clear the active state if the auto-reset window has elapsed.

        Caller must hold self._lock. The activation time is deliberately
        left set on auto-reset, matching the original behavior.
        """
        if self._active and self._auto_reset_seconds and self._activation_time:
            elapsed = (datetime.now(timezone.utc) - self._activation_time).total_seconds()
            if elapsed > self._auto_reset_seconds:
                self._active = False
                self._activation_reason = ""
                logger.info("Kill switch auto-reset")

    @property
    def is_active(self) -> bool:
        """Check if kill switch is active (applying auto-reset first)."""
        with self._lock:
            self._expire_if_due()
            return self._active

    @property
    def reason(self) -> str:
        """Get activation reason (empty string when inactive)."""
        with self._lock:
            return self._activation_reason

    def activate(
        self,
        reason: str,
        triggered_by: str = "system",
        raise_exception: bool = True,
    ) -> None:
        """
        Activate the kill switch.

        Args:
            reason: Why the kill switch was activated.
            triggered_by: What triggered the activation.
            raise_exception: If True, raise EmergencyHaltError.

        Raises:
            EmergencyHaltError: If raise_exception is True.
        """
        with self._lock:
            self._active = True
            self._activation_time = datetime.now(timezone.utc)
            self._activation_reason = reason
            callbacks = list(self._halt_callbacks)

        logger.critical(
            f"KILL SWITCH ACTIVATED: {reason} (triggered by: {triggered_by})"
        )

        # Notify handlers outside the lock; a failing handler never blocks
        # the halt itself.
        for callback in callbacks:
            try:
                callback(reason)
            except Exception as e:
                logger.error(f"Halt callback error: {e}")

        if raise_exception:
            raise EmergencyHaltError(reason=reason, triggered_by=triggered_by)

    def reset(self) -> bool:
        """
        Reset the kill switch.

        Returns:
            True if was active and is now reset.
        """
        with self._lock:
            was_active = self._active
            self._active = False
            self._activation_reason = ""
            self._activation_time = None
            callbacks = list(self._reset_callbacks) if was_active else []

        if was_active:
            logger.warning("Kill switch reset")
            for callback in callbacks:
                try:
                    callback()
                except Exception as e:
                    logger.error(f"Reset callback error: {e}")

        return was_active

    def check(self) -> None:
        """
        Check if kill switch is active and raise if so.

        Raises:
            EmergencyHaltError: If kill switch is active.
        """
        # Read state and reason under one lock acquisition; the original
        # re-read the reason unlocked after the is_active check.
        with self._lock:
            self._expire_if_due()
            if self._active:
                raise EmergencyHaltError(
                    reason=self._activation_reason,
                    triggered_by="kill_switch_check",
                )

    def on_halt(self, callback: Callable[[str], None]) -> None:
        """Register a callback for when kill switch activates."""
        with self._lock:
            self._halt_callbacks.append(callback)

    def on_reset(self, callback: Callable[[], None]) -> None:
        """Register a callback for when kill switch resets."""
        with self._lock:
            self._reset_callbacks.append(callback)

    def get_status(self) -> dict[str, Any]:
        """Get kill switch status.

        Applies the auto-reset check first, so the reported "active" flag
        always agrees with is_active (the original could report stale
        active=True after the auto-reset window had passed).
        """
        with self._lock:
            self._expire_if_due()
            return {
                "active": self._active,
                "reason": self._activation_reason,
                "activation_time": self._activation_time.isoformat() if self._activation_time else None,
            }
652
+
653
+
654
class DriftDetector:
    """
    High-level drift detector with integrated kill switch.

    Bundles a BehavioralMonitor with a KillSwitch and responds to
    detected anomalies automatically: warnings above one severity
    threshold, an emergency halt above another.

    Example:
        >>> detector = DriftDetector(
        ...     agent_id="my_agent",
        ...     auto_halt_threshold=0.9,
        ... )
        >>>
        >>> # Record events
        >>> detector.record_tool_call("search", {"query": "test"})
        >>>
        >>> # This will auto-halt if drift exceeds threshold
        >>> detector.check()
    """

    def __init__(
        self,
        agent_id: str,
        auto_halt_threshold: float = 0.9,
        warning_threshold: float = 0.5,
        monitor_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Initialize the detector.

        Args:
            agent_id: Unique identifier for the agent.
            auto_halt_threshold: Severity threshold for automatic halt.
            warning_threshold: Severity threshold for warnings.
            monitor_kwargs: Additional kwargs for BehavioralMonitor.
        """
        self.agent_id = agent_id
        self._warning_threshold = warning_threshold
        self._auto_halt_threshold = auto_halt_threshold

        self._kill_switch = KillSwitch()
        self._monitor = BehavioralMonitor(agent_id, **(monitor_kwargs or {}))

        # Route detected drift through the severity-based handler.
        # NOTE(review): check() also calls _handle_drift explicitly, so a
        # drifting check invokes the handler via this callback as well —
        # preserved from the original implementation.
        self._monitor.on_drift(self._handle_drift)

    @property
    def monitor(self) -> BehavioralMonitor:
        """The underlying behavioral monitor."""
        return self._monitor

    @property
    def kill_switch(self) -> KillSwitch:
        """The underlying kill switch."""
        return self._kill_switch

    def record_tool_call(
        self,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
        latency_ms: float | None = None,
    ) -> None:
        """Forward a tool call to the monitor, failing fast when halted."""
        self._kill_switch.check()
        self._monitor.record_tool_call(tool_name, arguments, latency_ms)

    def record_response(self, response: dict[str, Any]) -> None:
        """Forward a response to the monitor, failing fast when halted."""
        self._kill_switch.check()
        self._monitor.record_response(response)

    def record_error(self, error_info: dict[str, Any]) -> None:
        """Forward an error to the monitor.

        Deliberately skips the kill-switch check so errors can still be
        recorded after a halt (matches original behavior).
        """
        self._monitor.record_error(error_info)

    def record_event(self, event_type: str, data: dict[str, Any]) -> None:
        """Forward a generic event to the monitor, failing fast when halted."""
        self._kill_switch.check()
        self._monitor.record_event(event_type, data)

    def check(self) -> DriftResult:
        """
        Check for drift and respond accordingly.

        Returns:
            DriftResult from the check.

        Raises:
            EmergencyHaltError: If drift exceeds auto_halt_threshold.
        """
        self._kill_switch.check()
        result = self._monitor.check_drift()
        if result.is_drifting:
            self._handle_drift(result)
        return result

    def _handle_drift(self, result: DriftResult) -> None:
        """Escalate a drift result: halt when severe, warn when moderate."""
        if result.severity < self._warning_threshold:
            return
        if result.severity >= self._auto_halt_threshold:
            self._kill_switch.activate(
                reason=f"Severe behavioral drift: {result.reason}",
                triggered_by="drift_detector",
            )
            return
        logger.warning(
            f"Behavioral drift warning for {self.agent_id}: {result.reason}"
        )

    def lock_baseline(self) -> dict[DriftMetric, BaselineStats]:
        """Freeze the monitor's baseline for future drift detection."""
        return self._monitor.lock_baseline()

    def reset(self) -> None:
        """Reset both the monitor and the kill switch."""
        self._monitor.reset()
        self._kill_switch.reset()

    def get_status(self) -> dict[str, Any]:
        """Get detector status."""
        return {
            "agent_id": self.agent_id,
            "kill_switch": self._kill_switch.get_status(),
            "current_metrics": self._monitor.get_current_metrics(),
            # Reaches into the monitor's private flag; no public accessor
            # exists for it.
            "baseline_locked": self._monitor._baseline_locked,
        }
781
+
782
+
783
# Public names exported by `from ... import *`.
__all__ = [
    "BaselineStats",
    "BehavioralMonitor",
    "DriftDetector",
    "DriftMetric",
    "DriftResult",
    "KillSwitch",
    "MetricValue",
]