proxilion 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94) hide show
  1. proxilion/__init__.py +136 -0
  2. proxilion/audit/__init__.py +133 -0
  3. proxilion/audit/base_exporters.py +527 -0
  4. proxilion/audit/compliance/__init__.py +130 -0
  5. proxilion/audit/compliance/base.py +457 -0
  6. proxilion/audit/compliance/eu_ai_act.py +603 -0
  7. proxilion/audit/compliance/iso27001.py +544 -0
  8. proxilion/audit/compliance/soc2.py +491 -0
  9. proxilion/audit/events.py +493 -0
  10. proxilion/audit/explainability.py +1173 -0
  11. proxilion/audit/exporters/__init__.py +58 -0
  12. proxilion/audit/exporters/aws_s3.py +636 -0
  13. proxilion/audit/exporters/azure_storage.py +608 -0
  14. proxilion/audit/exporters/cloud_base.py +468 -0
  15. proxilion/audit/exporters/gcp_storage.py +570 -0
  16. proxilion/audit/exporters/multi_exporter.py +498 -0
  17. proxilion/audit/hash_chain.py +652 -0
  18. proxilion/audit/logger.py +543 -0
  19. proxilion/caching/__init__.py +49 -0
  20. proxilion/caching/tool_cache.py +633 -0
  21. proxilion/context/__init__.py +73 -0
  22. proxilion/context/context_window.py +556 -0
  23. proxilion/context/message_history.py +505 -0
  24. proxilion/context/session.py +735 -0
  25. proxilion/contrib/__init__.py +51 -0
  26. proxilion/contrib/anthropic.py +609 -0
  27. proxilion/contrib/google.py +1012 -0
  28. proxilion/contrib/langchain.py +641 -0
  29. proxilion/contrib/mcp.py +893 -0
  30. proxilion/contrib/openai.py +646 -0
  31. proxilion/core.py +3058 -0
  32. proxilion/decorators.py +966 -0
  33. proxilion/engines/__init__.py +287 -0
  34. proxilion/engines/base.py +266 -0
  35. proxilion/engines/casbin_engine.py +412 -0
  36. proxilion/engines/opa_engine.py +493 -0
  37. proxilion/engines/simple.py +437 -0
  38. proxilion/exceptions.py +887 -0
  39. proxilion/guards/__init__.py +54 -0
  40. proxilion/guards/input_guard.py +522 -0
  41. proxilion/guards/output_guard.py +634 -0
  42. proxilion/observability/__init__.py +198 -0
  43. proxilion/observability/cost_tracker.py +866 -0
  44. proxilion/observability/hooks.py +683 -0
  45. proxilion/observability/metrics.py +798 -0
  46. proxilion/observability/session_cost_tracker.py +1063 -0
  47. proxilion/policies/__init__.py +67 -0
  48. proxilion/policies/base.py +304 -0
  49. proxilion/policies/builtin.py +486 -0
  50. proxilion/policies/registry.py +376 -0
  51. proxilion/providers/__init__.py +201 -0
  52. proxilion/providers/adapter.py +468 -0
  53. proxilion/providers/anthropic_adapter.py +330 -0
  54. proxilion/providers/gemini_adapter.py +391 -0
  55. proxilion/providers/openai_adapter.py +294 -0
  56. proxilion/py.typed +0 -0
  57. proxilion/resilience/__init__.py +81 -0
  58. proxilion/resilience/degradation.py +615 -0
  59. proxilion/resilience/fallback.py +555 -0
  60. proxilion/resilience/retry.py +554 -0
  61. proxilion/scheduling/__init__.py +57 -0
  62. proxilion/scheduling/priority_queue.py +419 -0
  63. proxilion/scheduling/scheduler.py +459 -0
  64. proxilion/security/__init__.py +244 -0
  65. proxilion/security/agent_trust.py +968 -0
  66. proxilion/security/behavioral_drift.py +794 -0
  67. proxilion/security/cascade_protection.py +869 -0
  68. proxilion/security/circuit_breaker.py +428 -0
  69. proxilion/security/cost_limiter.py +690 -0
  70. proxilion/security/idor_protection.py +460 -0
  71. proxilion/security/intent_capsule.py +849 -0
  72. proxilion/security/intent_validator.py +495 -0
  73. proxilion/security/memory_integrity.py +767 -0
  74. proxilion/security/rate_limiter.py +509 -0
  75. proxilion/security/scope_enforcer.py +680 -0
  76. proxilion/security/sequence_validator.py +636 -0
  77. proxilion/security/trust_boundaries.py +784 -0
  78. proxilion/streaming/__init__.py +70 -0
  79. proxilion/streaming/detector.py +761 -0
  80. proxilion/streaming/transformer.py +674 -0
  81. proxilion/timeouts/__init__.py +55 -0
  82. proxilion/timeouts/decorators.py +477 -0
  83. proxilion/timeouts/manager.py +545 -0
  84. proxilion/tools/__init__.py +69 -0
  85. proxilion/tools/decorators.py +493 -0
  86. proxilion/tools/registry.py +732 -0
  87. proxilion/types.py +339 -0
  88. proxilion/validation/__init__.py +93 -0
  89. proxilion/validation/pydantic_schema.py +351 -0
  90. proxilion/validation/schema.py +651 -0
  91. proxilion-0.0.1.dist-info/METADATA +872 -0
  92. proxilion-0.0.1.dist-info/RECORD +94 -0
  93. proxilion-0.0.1.dist-info/WHEEL +4 -0
  94. proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,767 @@
1
+ """
2
+ Memory and Context Integrity for Proxilion.
3
+
4
+ Addresses OWASP ASI06: Memory & Context Poisoning.
5
+
6
+ This module provides cryptographic integrity verification for:
7
+ - Conversation history / message windows
8
+ - Vector store embeddings and retrieved documents
9
+ - Long-term memory (knowledge graphs, user preferences)
10
+ - RAG context injection detection
11
+
12
+ Memory poisoning attacks inject malicious content into an agent's
13
+ persistent memory, causing incorrect behavior days or weeks later.
14
+ This module detects such tampering.
15
+
16
+ Example:
17
+ >>> from proxilion.security.memory_integrity import (
18
+ ... MemoryIntegrityGuard,
19
+ ... ContextWindowGuard,
20
+ ... SignedMessage,
21
+ ... )
22
+ >>>
23
+ >>> guard = MemoryIntegrityGuard(secret_key="your-secret-key")
24
+ >>>
25
+ >>> # Sign messages as they're added
26
+ >>> msg = guard.sign_message(role="user", content="Hello")
27
+ >>> context.append(msg)
28
+ >>>
29
+ >>> # Verify context before sending to LLM
30
+ >>> if guard.verify_context(context).valid:
31
+ ... response = llm.generate(context)
32
+ ... else:
33
+ ... raise ContextTamperingError("Context has been modified")
34
+ >>>
35
+ >>> # Detect RAG poisoning
36
+ >>> docs = retriever.get_relevant_docs(query)
37
+ >>> safe_docs = guard.scan_rag_documents(docs).safe_documents
38
+ """
39
+
40
+ from __future__ import annotations
41
+
42
+ import hashlib
43
+ import hmac
44
+ import json
45
+ import logging
46
+ import re
47
+ import threading
48
+ import time
49
+ from dataclasses import dataclass, field
50
+ from datetime import datetime, timezone
51
+ from enum import Enum
52
+ from typing import Any, Protocol, runtime_checkable
53
+
54
+ logger = logging.getLogger(__name__)
55
+
56
+
57
class IntegrityViolationType(Enum):
    """Categories of integrity problems this module can report.

    Each member's value is the stable string used when a violation is
    serialized (see ``IntegrityViolation.to_dict``).
    """

    # The stored signature of a message does not match its content.
    SIGNATURE_MISMATCH = "signature_mismatch"

    # One or more messages are missing from the sequence.
    SEQUENCE_GAP = "sequence_gap"

    # Messages appear out of order.
    SEQUENCE_REORDER = "sequence_reorder"

    # A timestamp inconsistency was detected.
    TIMESTAMP_ANOMALY = "timestamp_anomaly"

    # Message content attempts to spoof a conversation role.
    ROLE_INJECTION = "role_injection"

    # A retrieved (RAG) document contains malicious content.
    RAG_POISONING = "rag_poisoning"

    # The context exceeds its expected bounds.
    CONTEXT_OVERFLOW = "context_overflow"

    # The hash chain linking consecutive messages is broken.
    HASH_CHAIN_BREAK = "hash_chain_break"
83
+
84
+
85
@dataclass
class IntegrityViolation:
    """Record describing one detected integrity violation.

    Attributes:
        violation_type: Category of the violation.
        message: Human-readable description.
        severity: Score between 0.0 and 1.0.
        index: Position in the context where the violation occurred, if known.
        expected: Expected value (possibly truncated), if applicable.
        actual: Observed value (possibly truncated), if applicable.
        timestamp: UTC time at which this record was created.
    """

    violation_type: IntegrityViolationType
    message: str
    severity: float  # 0.0 to 1.0
    index: int | None = None  # Position in context where violation occurred
    expected: str | None = None
    actual: str | None = None
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict view of this violation.

        The enum is flattened to its string value and the timestamp to an
        ISO-8601 string; the remaining fields are copied as-is.
        """
        serialized: dict[str, Any] = {"type": self.violation_type.value}
        for attr in ("message", "severity", "index", "expected", "actual"):
            serialized[attr] = getattr(self, attr)
        serialized["timestamp"] = self.timestamp.isoformat()
        return serialized
108
+
109
+
110
@dataclass
class SignedMessage:
    """
    A context message carrying the data needed to verify its integrity.

    Fields:
    - role / content: the message itself
    - sequence: ordering number
    - timestamp: creation time, as a float
    - signature: HMAC signature used for tamper detection
    - previous_hash: link to the preceding message in the hash chain
    - metadata: free-form extra data (defaults to an empty dict)
    """

    role: str
    content: str
    sequence: int
    timestamp: float
    signature: str
    previous_hash: str
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize every field into a plain dictionary."""
        field_names = (
            "role",
            "content",
            "sequence",
            "timestamp",
            "signature",
            "previous_hash",
            "metadata",
        )
        return {name: getattr(self, name) for name in field_names}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> SignedMessage:
        """Build a message from a dictionary produced by ``to_dict``.

        All keys except ``metadata`` are required; a missing one raises
        ``KeyError``. ``metadata`` falls back to an empty dict.
        """
        required_keys = (
            "role",
            "content",
            "sequence",
            "timestamp",
            "signature",
            "previous_hash",
        )
        required = {key: data[key] for key in required_keys}
        return cls(metadata=data.get("metadata", {}), **required)

    def content_hash(self) -> str:
        """Return the SHA-256 hex digest used to chain the next message.

        The digest covers role, content, sequence and timestamp joined by
        colons; signature, previous_hash and metadata are not included.
        """
        chained = "{}:{}:{}:{}".format(
            self.role, self.content, self.sequence, self.timestamp
        )
        return hashlib.sha256(chained.encode()).hexdigest()
160
+
161
+
162
@dataclass
class VerificationResult:
    """Outcome of verifying a context window.

    Attributes:
        valid: Whether the context passed verification.
        violations: Every violation that was found.
        verified_count: Number of messages that passed their individual check.
        total_count: Number of messages that were examined.
    """

    valid: bool
    violations: list[IntegrityViolation] = field(default_factory=list)
    verified_count: int = 0
    total_count: int = 0

    @property
    def violation_count(self) -> int:
        """How many violations were recorded."""
        return len(self.violations)

    @property
    def max_severity(self) -> float:
        """Highest severity among the violations, or 0.0 when there are none."""
        return max((entry.severity for entry in self.violations), default=0.0)
182
+
183
+
184
@dataclass
class RAGDocument:
    """One document returned by RAG retrieval.

    Attributes:
        content: The document text.
        source: Optional identifier of where the document came from.
        score: Score attached to the document; defaults to 0.0.
        metadata: Free-form extra data; each instance gets its own dict.
    """

    content: str
    source: str | None = None
    score: float = 0.0
    metadata: dict[str, Any] = field(default_factory=dict)
192
+
193
+
194
@dataclass
class RAGScanResult:
    """Result of scanning RAG documents for poisoning.

    Attributes:
        safe: True when no document was flagged.
        documents: All scanned documents, in scan order.
        poisoned_indices: Positions in ``documents`` that were flagged.
        violations: Violations recorded for the flagged documents.
    """

    safe: bool
    documents: list[RAGDocument]
    poisoned_indices: list[int] = field(default_factory=list)
    violations: list[IntegrityViolation] = field(default_factory=list)

    @property
    def safe_documents(self) -> list[RAGDocument]:
        """Return the documents whose position is not in ``poisoned_indices``.

        Original order is preserved.
        """
        # Build the lookup set once so each membership test is O(1) instead
        # of a linear scan of the index list for every document.
        flagged = set(self.poisoned_indices)
        return [
            doc for position, doc in enumerate(self.documents)
            if position not in flagged
        ]
210
+
211
+
212
# Built-in detection rules for poisoned RAG documents.
# Each entry is (regex source, human-readable description, severity 0.0-1.0).
# Every regex begins with the inline ``(?i)`` flag, so matching ignores case.
RAG_POISON_PATTERNS: list[tuple[str, str, float]] = [
    (
        r"(?i)ignore\s+(all\s+)?(previous|prior|above)\s+(instructions?|prompts?|rules?)",
        "Instruction override attempt",
        0.95,
    ),
    (
        r"(?i)you\s+are\s+now\s+(\w+\s+)?(mode|persona|character)",
        "Role/persona injection",
        0.9,
    ),
    (
        r"(?i)system\s*:\s*you\s+(are|must|should|will)",
        "Fake system message",
        0.95,
    ),
    (
        r"(?i)\[/?INST\]|\[/?SYS\]|<\|im_start\|>|<\|im_end\|>",
        "Model delimiter injection",
        0.9,
    ),
    (
        r"(?i)admin\s+(mode|access|override)\s*(enabled|activated|on)",
        "Privilege escalation attempt",
        0.85,
    ),
    (
        r"(?i)(reveal|show|display|print)\s+(your\s+)?(system\s+)?(prompt|instructions)",
        "System prompt extraction",
        0.8,
    ),
    (
        r"(?i)forget\s+(everything|all|what)\s+(you\s+)?(know|learned|were\s+told)",
        "Memory wipe attempt",
        0.85,
    ),
    (
        r"(?i)from\s+now\s+on\s*,?\s*(you\s+)?(will|must|should|are)",
        "Behavioral override",
        0.8,
    ),
    (
        r"(?i)disregard\s+(all\s+)?(safety|security|ethical)\s+(guidelines?|rules?|constraints?)",
        "Safety bypass attempt",
        0.95,
    ),
    (
        r"(?i)execute\s+(this\s+)?(code|script|command)\s*:",
        "Code execution injection",
        0.9,
    ),
]
236
+
237
+
238
class MemoryIntegrityGuard:
    """
    Cryptographic integrity guard for agent memory and context.

    Provides:
    - HMAC signing of messages
    - Hash chain verification
    - Sequence validation
    - Timestamp anomaly detection
    - RAG poisoning detection

    The guard is stateful: it keeps the next sequence number and the hash of
    the most recently signed message, both protected by a re-entrant lock.

    Example:
        >>> guard = MemoryIntegrityGuard(secret_key="your-key")
        >>>
        >>> # Build a signed context
        >>> context = []
        >>> context.append(guard.sign_message("system", "You are helpful."))
        >>> context.append(guard.sign_message("user", "Hello!"))
        >>>
        >>> # Verify before using
        >>> result = guard.verify_context(context)
        >>> if not result.valid:
        ...     for v in result.violations:
        ...         print(f"Violation: {v.message}")
    """

    # previous_hash value carried by the first message of a chain.
    GENESIS_HASH = "0" * 64

    # Role-spoofing patterns, compiled once when the class is created rather
    # than rebuilt on every _detect_role_injection call. Order matters: the
    # description of the first matching pattern is the one reported.
    _ROLE_INJECTION_PATTERNS = tuple(
        (re.compile(pattern), description)
        for pattern, description in (
            (r"(?i)^(system|assistant|user|tool)\s*:\s*", "Role prefix injection"),
            (r"(?i)\n(system|assistant|user|tool)\s*:\s*", "Inline role injection"),
            (r"(?i)<\|(system|assistant|user|tool)\|>", "Delimiter role injection"),
            (r"(?i)\[INST\]|\[/INST\]", "Llama instruction delimiters"),
            (r"(?i)<\|im_start\|>(system|user|assistant)", "ChatML injection"),
        )
    )

    def __init__(
        self,
        secret_key: str | bytes,
        max_timestamp_drift: float = 60.0,
        max_context_size: int = 1000,
        enable_rag_scan: bool = True,
        custom_rag_patterns: list[tuple[str, str, float]] | None = None,
    ) -> None:
        """
        Initialize the integrity guard.

        Args:
            secret_key: Secret key for HMAC signatures. A ``str`` is encoded
                to bytes with the default encoding.
            max_timestamp_drift: Max allowed gap, in seconds, between two
                consecutive messages before a timestamp anomaly is reported.
            max_context_size: Maximum allowed number of context messages.
            enable_rag_scan: Enable RAG poisoning detection. When False,
                ``scan_rag_documents`` normalizes its input but flags nothing.
            custom_rag_patterns: Additional ``(regex, description, severity)``
                patterns appended after the built-in ones.

        Raises:
            re.error: If a custom pattern is not a valid regular expression.
        """
        if isinstance(secret_key, str):
            secret_key = secret_key.encode()

        self._secret_key = secret_key
        self._max_timestamp_drift = max_timestamp_drift
        self._max_context_size = max_context_size
        self._enable_rag_scan = enable_rag_scan

        # Compile built-in patterns first, then any caller-supplied ones.
        self._rag_patterns: list[tuple[re.Pattern[str], str, float]] = [
            (re.compile(pattern), desc, severity)
            for pattern, desc, severity in RAG_POISON_PATTERNS
        ]
        if custom_rag_patterns:
            self._rag_patterns.extend(
                (re.compile(pattern), desc, severity)
                for pattern, desc, severity in custom_rag_patterns
            )

        self._sequence_counter = 0
        self._last_hash = self.GENESIS_HASH
        self._lock = threading.RLock()

        logger.debug("MemoryIntegrityGuard initialized")

    def _compute_signature(
        self,
        role: str,
        content: str,
        sequence: int,
        timestamp: float,
        previous_hash: str,
    ) -> str:
        """Compute the HMAC-SHA256 hex signature for a message's fields.

        Message metadata is not part of the signed payload.
        """
        # NOTE(review): fields are joined with "|" without escaping, so a "|"
        # inside role/content makes the field boundaries ambiguous. Changing
        # the format would invalidate previously issued signatures, so it is
        # left as-is here.
        message = f"{role}|{content}|{sequence}|{timestamp}|{previous_hash}"
        return hmac.new(
            self._secret_key,
            message.encode(),
            hashlib.sha256,
        ).hexdigest()

    def sign_message(
        self,
        role: str,
        content: str,
        metadata: dict[str, Any] | None = None,
    ) -> SignedMessage:
        """
        Create a signed message and advance the guard's chain state.

        Args:
            role: Message role (system, user, assistant, tool).
            content: Message content.
            metadata: Optional metadata (stored on the message, not signed).

        Returns:
            SignedMessage with cryptographic signature.
        """
        with self._lock:
            timestamp = time.time()
            sequence = self._sequence_counter
            self._sequence_counter += 1

            signature = self._compute_signature(
                role, content, sequence, timestamp, self._last_hash
            )

            msg = SignedMessage(
                role=role,
                content=content,
                sequence=sequence,
                timestamp=timestamp,
                signature=signature,
                previous_hash=self._last_hash,
                metadata=metadata or {},
            )

            # The next message must link back to this one.
            self._last_hash = msg.content_hash()

            return msg

    def verify_message(
        self,
        message: SignedMessage,
        expected_previous_hash: str | None = None,
    ) -> tuple[bool, IntegrityViolation | None]:
        """
        Verify a single message's signature and, optionally, its chain link.

        Args:
            message: The message to verify.
            expected_previous_hash: Expected previous hash for chain
                verification; the chain check is skipped when None.

        Returns:
            Tuple of (valid, violation or None). The violation's ``index`` is
            the message's sequence number, since the position in a context is
            not known here.
        """
        expected_sig = self._compute_signature(
            message.role,
            message.content,
            message.sequence,
            message.timestamp,
            message.previous_hash,
        )

        # hmac.compare_digest raises TypeError for non-ASCII str arguments, so
        # a tampered signature must not be compared as str. Compare bytes and
        # treat a non-string signature as a plain mismatch.
        raw_signature = message.signature
        actual_sig = raw_signature if isinstance(raw_signature, str) else ""
        signatures_match = hmac.compare_digest(
            expected_sig.encode(),
            actual_sig.encode("utf-8", "replace"),
        )

        if not signatures_match:
            return False, IntegrityViolation(
                violation_type=IntegrityViolationType.SIGNATURE_MISMATCH,
                message=f"Signature mismatch for message at sequence {message.sequence}",
                severity=1.0,
                index=message.sequence,
                expected=expected_sig[:16] + "...",
                actual=actual_sig[:16] + "...",
            )

        # Verify chain if expected hash provided
        if (
            expected_previous_hash is not None
            and message.previous_hash != expected_previous_hash
        ):
            return False, IntegrityViolation(
                violation_type=IntegrityViolationType.HASH_CHAIN_BREAK,
                message=f"Hash chain break at sequence {message.sequence}",
                severity=1.0,
                index=message.sequence,
                expected=expected_previous_hash[:16] + "...",
                actual=message.previous_hash[:16] + "...",
            )

        return True, None

    def verify_context(
        self,
        context: list[SignedMessage],
        strict_sequence: bool = True,
        check_timestamps: bool = True,
    ) -> VerificationResult:
        """
        Verify an entire context window.

        The first message is expected to link to ``GENESIS_HASH`` and, with
        ``strict_sequence``, to carry sequence 0. Any recorded violation,
        including a large timestamp gap, makes the result invalid.

        Args:
            context: List of signed messages.
            strict_sequence: Require sequential sequence numbers.
            check_timestamps: Verify timestamp ordering and gaps.

        Returns:
            VerificationResult with any violations found. Every per-message
            violation's ``index`` is the message's position in ``context``.
        """
        violations: list[IntegrityViolation] = []
        verified_count = 0

        # Check context size
        if len(context) > self._max_context_size:
            violations.append(IntegrityViolation(
                violation_type=IntegrityViolationType.CONTEXT_OVERFLOW,
                message=f"Context size {len(context)} exceeds max {self._max_context_size}",
                severity=0.7,
            ))

        if not context:
            # Stay consistent with the final return: valid only when nothing
            # was recorded (matters if max_context_size is negative).
            return VerificationResult(
                valid=not violations,
                violations=violations,
                verified_count=0,
                total_count=0,
            )

        expected_hash = self.GENESIS_HASH
        prev_sequence = -1
        prev_timestamp = 0.0

        for i, msg in enumerate(context):
            # Verify signature and chain
            valid, violation = self.verify_message(msg, expected_hash)
            if not valid and violation:
                # verify_message only knows the sequence number; report the
                # context position like every other violation in this loop.
                violation.index = i
                violations.append(violation)
            else:
                verified_count += 1

            # Check sequence ordering
            if strict_sequence and msg.sequence != prev_sequence + 1:
                if msg.sequence <= prev_sequence:
                    violations.append(IntegrityViolation(
                        violation_type=IntegrityViolationType.SEQUENCE_REORDER,
                        message=f"Message {i} has sequence {msg.sequence}, expected {prev_sequence + 1}",
                        severity=0.9,
                        index=i,
                        expected=str(prev_sequence + 1),
                        actual=str(msg.sequence),
                    ))
                else:
                    violations.append(IntegrityViolation(
                        violation_type=IntegrityViolationType.SEQUENCE_GAP,
                        message=f"Gap in sequence: {prev_sequence} -> {msg.sequence}",
                        severity=0.8,
                        index=i,
                    ))

            # Check timestamp ordering (skipped until a previous positive
            # timestamp has been seen).
            if check_timestamps and prev_timestamp > 0:
                if msg.timestamp < prev_timestamp:
                    violations.append(IntegrityViolation(
                        violation_type=IntegrityViolationType.TIMESTAMP_ANOMALY,
                        message=f"Timestamp goes backwards at message {i}",
                        severity=0.7,
                        index=i,
                    ))
                elif msg.timestamp - prev_timestamp > self._max_timestamp_drift:
                    # Large gap might indicate injection
                    violations.append(IntegrityViolation(
                        violation_type=IntegrityViolationType.TIMESTAMP_ANOMALY,
                        message=f"Large timestamp gap at message {i}: {msg.timestamp - prev_timestamp:.1f}s",
                        severity=0.5,
                        index=i,
                    ))

            # Check for role injection in content
            role_injection = self._detect_role_injection(msg.content)
            if role_injection:
                violations.append(IntegrityViolation(
                    violation_type=IntegrityViolationType.ROLE_INJECTION,
                    message=f"Role injection detected in message {i}: {role_injection}",
                    severity=0.85,
                    index=i,
                ))

            # Update state for next iteration
            expected_hash = msg.content_hash()
            prev_sequence = msg.sequence
            prev_timestamp = msg.timestamp

        return VerificationResult(
            valid=not violations,
            violations=violations,
            verified_count=verified_count,
            total_count=len(context),
        )

    def _detect_role_injection(self, content: str) -> str | None:
        """Return the description of the first role-injection pattern found.

        Returns None when no pattern matches.
        """
        for pattern, description in self._ROLE_INJECTION_PATTERNS:
            if pattern.search(content):
                return description
        return None

    @staticmethod
    def _normalize_rag_document(doc: Any) -> RAGDocument:
        """Coerce a RAGDocument, string, dict or other object to a RAGDocument."""
        if isinstance(doc, RAGDocument):
            return doc
        if isinstance(doc, str):
            return RAGDocument(content=doc)
        if isinstance(doc, dict):
            # Prefer "content", then "text", then the dict's own repr.
            return RAGDocument(
                content=doc.get("content", doc.get("text", str(doc))),
                source=doc.get("source"),
                score=doc.get("score", 0.0),
                metadata=doc.get("metadata", {}),
            )
        return RAGDocument(content=str(doc))

    def scan_rag_documents(
        self,
        documents: list[RAGDocument] | list[str] | list[dict[str, Any]],
    ) -> RAGScanResult:
        """
        Scan RAG documents for poisoning attempts.

        Args:
            documents: Documents to scan. Can be RAGDocument objects,
                plain strings, or dicts with a 'content' (or 'text') key.

        Returns:
            RAGScanResult with the normalized documents and any violations.
            When the guard was created with ``enable_rag_scan=False`` the
            documents are normalized but none is flagged.
        """
        normalized = [self._normalize_rag_document(doc) for doc in documents]

        # Honor the constructor flag: previously it was stored but ignored,
        # so scanning could not actually be turned off.
        if not self._enable_rag_scan:
            return RAGScanResult(safe=True, documents=normalized)

        violations: list[IntegrityViolation] = []
        poisoned_indices: list[int] = []

        for i, doc in enumerate(normalized):
            doc_violations = self._scan_document_content(doc.content, i)
            if doc_violations:
                violations.extend(doc_violations)
                poisoned_indices.append(i)

        return RAGScanResult(
            safe=not poisoned_indices,
            documents=normalized,
            poisoned_indices=poisoned_indices,
            violations=violations,
        )

    def _scan_document_content(
        self,
        content: str,
        index: int,
    ) -> list[IntegrityViolation]:
        """Return one violation per configured pattern that matches ``content``."""
        return [
            IntegrityViolation(
                violation_type=IntegrityViolationType.RAG_POISONING,
                message=f"RAG poisoning detected in document {index}: {description}",
                severity=severity,
                index=index,
            )
            for pattern, description, severity in self._rag_patterns
            if pattern.search(content)
        ]

    def reset(self) -> None:
        """Reset the guard state (sequence counter and hash chain)."""
        with self._lock:
            self._sequence_counter = 0
            self._last_hash = self.GENESIS_HASH

    def get_state(self) -> dict[str, Any]:
        """Get current guard state (sequence counter, last hash) for serialization."""
        with self._lock:
            return {
                "sequence_counter": self._sequence_counter,
                "last_hash": self._last_hash,
            }

    def restore_state(self, state: dict[str, Any]) -> None:
        """Restore guard state from ``get_state`` output.

        Missing keys fall back to the initial values (0 and ``GENESIS_HASH``).
        """
        with self._lock:
            self._sequence_counter = state.get("sequence_counter", 0)
            self._last_hash = state.get("last_hash", self.GENESIS_HASH)
623
+
624
+
625
class ContextWindowGuard:
    """
    High-level guard for managing signed context windows.

    Wraps a MemoryIntegrityGuard together with the list of messages it has
    signed, providing a simple API for building and verifying context
    that will be sent to an LLM.

    Example:
        >>> guard = ContextWindowGuard(secret_key="key")
        >>>
        >>> # Build context
        >>> guard.add_system("You are a helpful assistant.")
        >>> guard.add_user("Hello!")
        >>> guard.add_assistant("Hi there! How can I help?")
        >>>
        >>> # Get verified context for LLM
        >>> messages = guard.get_verified_messages()
        >>> response = llm.generate(messages)
        >>>
        >>> # Add response to context
        >>> guard.add_assistant(response)
    """

    def __init__(
        self,
        secret_key: str | bytes,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the context window guard.

        Args:
            secret_key: Secret key for signing.
            **kwargs: Additional args passed to MemoryIntegrityGuard.
        """
        self._guard = MemoryIntegrityGuard(secret_key, **kwargs)
        self._messages: list[SignedMessage] = []
        self._lock = threading.RLock()

    def add_system(self, content: str, **metadata: Any) -> SignedMessage:
        """Sign and append a system message; extra keywords become metadata."""
        return self._add_message("system", content, metadata)

    def add_user(self, content: str, **metadata: Any) -> SignedMessage:
        """Sign and append a user message; extra keywords become metadata."""
        return self._add_message("user", content, metadata)

    def add_assistant(self, content: str, **metadata: Any) -> SignedMessage:
        """Sign and append an assistant message; extra keywords become metadata."""
        return self._add_message("assistant", content, metadata)

    def add_tool(self, content: str, tool_name: str = "", **metadata: Any) -> SignedMessage:
        """Sign and append a tool result message.

        ``tool_name`` is stored in the message metadata under "tool_name".
        """
        metadata["tool_name"] = tool_name
        return self._add_message("tool", content, metadata)

    def _add_message(
        self,
        role: str,
        content: str,
        metadata: dict[str, Any],
    ) -> SignedMessage:
        """Sign a message with the given role and append it to the window."""
        with self._lock:
            msg = self._guard.sign_message(role, content, metadata)
            self._messages.append(msg)
            return msg

    def verify(self) -> VerificationResult:
        """Verify the current context with the underlying guard."""
        with self._lock:
            return self._guard.verify_context(self._messages)

    def get_messages(self) -> list[SignedMessage]:
        """Get a copy of the list of all messages (unverified)."""
        with self._lock:
            return list(self._messages)

    def get_verified_messages(self) -> list[dict[str, str]]:
        """
        Get messages if context is valid, else raise.

        Returns:
            List of message dicts (role, content) suitable for LLM API.

        Raises:
            ContextIntegrityError: If verification fails.
        """
        # Hold the (re-entrant) lock across verification and snapshot so the
        # returned messages are exactly the ones that were verified.
        with self._lock:
            result = self.verify()
            if not result.valid:
                from proxilion.exceptions import ContextIntegrityError
                raise ContextIntegrityError(
                    f"Context integrity violated: {result.violations[0].message}",
                    violations=result.violations,
                )
            return self.get_messages_for_api()

    def get_messages_for_api(self) -> list[dict[str, str]]:
        """Get messages in API format (role, content only), without verifying."""
        with self._lock:
            return [
                {"role": msg.role, "content": msg.content}
                for msg in self._messages
            ]

    def clear(self) -> None:
        """Clear all messages and reset the signing state."""
        with self._lock:
            self._messages.clear()
            self._guard.reset()

    def pop(self) -> SignedMessage | None:
        """Remove and return the last message, or None if the window is empty.

        The signing state is rewound to just before the removed message so
        that the next added message reuses its sequence number and links to
        the same predecessor; otherwise the window would fail verification
        with a sequence gap and a hash chain break.
        """
        with self._lock:
            if not self._messages:
                return None
            removed = self._messages.pop()
            self._guard.restore_state({
                "sequence_counter": removed.sequence,
                "last_hash": removed.previous_hash,
            })
            return removed

    def __len__(self) -> int:
        """Number of messages in context."""
        return len(self._messages)
751
+
752
+
753
# Public API of this module, grouped by kind.
__all__ = [
    # Guards
    "MemoryIntegrityGuard",
    "ContextWindowGuard",
    # Data containers and enums
    "SignedMessage",
    "VerificationResult",
    "IntegrityViolation",
    "IntegrityViolationType",
    "RAGDocument",
    "RAGScanResult",
    # Built-in RAG poisoning patterns
    "RAG_POISON_PATTERNS",
]