proxilion 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- proxilion/__init__.py +136 -0
- proxilion/audit/__init__.py +133 -0
- proxilion/audit/base_exporters.py +527 -0
- proxilion/audit/compliance/__init__.py +130 -0
- proxilion/audit/compliance/base.py +457 -0
- proxilion/audit/compliance/eu_ai_act.py +603 -0
- proxilion/audit/compliance/iso27001.py +544 -0
- proxilion/audit/compliance/soc2.py +491 -0
- proxilion/audit/events.py +493 -0
- proxilion/audit/explainability.py +1173 -0
- proxilion/audit/exporters/__init__.py +58 -0
- proxilion/audit/exporters/aws_s3.py +636 -0
- proxilion/audit/exporters/azure_storage.py +608 -0
- proxilion/audit/exporters/cloud_base.py +468 -0
- proxilion/audit/exporters/gcp_storage.py +570 -0
- proxilion/audit/exporters/multi_exporter.py +498 -0
- proxilion/audit/hash_chain.py +652 -0
- proxilion/audit/logger.py +543 -0
- proxilion/caching/__init__.py +49 -0
- proxilion/caching/tool_cache.py +633 -0
- proxilion/context/__init__.py +73 -0
- proxilion/context/context_window.py +556 -0
- proxilion/context/message_history.py +505 -0
- proxilion/context/session.py +735 -0
- proxilion/contrib/__init__.py +51 -0
- proxilion/contrib/anthropic.py +609 -0
- proxilion/contrib/google.py +1012 -0
- proxilion/contrib/langchain.py +641 -0
- proxilion/contrib/mcp.py +893 -0
- proxilion/contrib/openai.py +646 -0
- proxilion/core.py +3058 -0
- proxilion/decorators.py +966 -0
- proxilion/engines/__init__.py +287 -0
- proxilion/engines/base.py +266 -0
- proxilion/engines/casbin_engine.py +412 -0
- proxilion/engines/opa_engine.py +493 -0
- proxilion/engines/simple.py +437 -0
- proxilion/exceptions.py +887 -0
- proxilion/guards/__init__.py +54 -0
- proxilion/guards/input_guard.py +522 -0
- proxilion/guards/output_guard.py +634 -0
- proxilion/observability/__init__.py +198 -0
- proxilion/observability/cost_tracker.py +866 -0
- proxilion/observability/hooks.py +683 -0
- proxilion/observability/metrics.py +798 -0
- proxilion/observability/session_cost_tracker.py +1063 -0
- proxilion/policies/__init__.py +67 -0
- proxilion/policies/base.py +304 -0
- proxilion/policies/builtin.py +486 -0
- proxilion/policies/registry.py +376 -0
- proxilion/providers/__init__.py +201 -0
- proxilion/providers/adapter.py +468 -0
- proxilion/providers/anthropic_adapter.py +330 -0
- proxilion/providers/gemini_adapter.py +391 -0
- proxilion/providers/openai_adapter.py +294 -0
- proxilion/py.typed +0 -0
- proxilion/resilience/__init__.py +81 -0
- proxilion/resilience/degradation.py +615 -0
- proxilion/resilience/fallback.py +555 -0
- proxilion/resilience/retry.py +554 -0
- proxilion/scheduling/__init__.py +57 -0
- proxilion/scheduling/priority_queue.py +419 -0
- proxilion/scheduling/scheduler.py +459 -0
- proxilion/security/__init__.py +244 -0
- proxilion/security/agent_trust.py +968 -0
- proxilion/security/behavioral_drift.py +794 -0
- proxilion/security/cascade_protection.py +869 -0
- proxilion/security/circuit_breaker.py +428 -0
- proxilion/security/cost_limiter.py +690 -0
- proxilion/security/idor_protection.py +460 -0
- proxilion/security/intent_capsule.py +849 -0
- proxilion/security/intent_validator.py +495 -0
- proxilion/security/memory_integrity.py +767 -0
- proxilion/security/rate_limiter.py +509 -0
- proxilion/security/scope_enforcer.py +680 -0
- proxilion/security/sequence_validator.py +636 -0
- proxilion/security/trust_boundaries.py +784 -0
- proxilion/streaming/__init__.py +70 -0
- proxilion/streaming/detector.py +761 -0
- proxilion/streaming/transformer.py +674 -0
- proxilion/timeouts/__init__.py +55 -0
- proxilion/timeouts/decorators.py +477 -0
- proxilion/timeouts/manager.py +545 -0
- proxilion/tools/__init__.py +69 -0
- proxilion/tools/decorators.py +493 -0
- proxilion/tools/registry.py +732 -0
- proxilion/types.py +339 -0
- proxilion/validation/__init__.py +93 -0
- proxilion/validation/pydantic_schema.py +351 -0
- proxilion/validation/schema.py +651 -0
- proxilion-0.0.1.dist-info/METADATA +872 -0
- proxilion-0.0.1.dist-info/RECORD +94 -0
- proxilion-0.0.1.dist-info/WHEEL +4 -0
- proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,767 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Memory and Context Integrity for Proxilion.
|
|
3
|
+
|
|
4
|
+
Addresses OWASP ASI06: Memory & Context Poisoning.
|
|
5
|
+
|
|
6
|
+
This module provides cryptographic integrity verification for:
|
|
7
|
+
- Conversation history / message windows
|
|
8
|
+
- Vector store embeddings and retrieved documents
|
|
9
|
+
- Long-term memory (knowledge graphs, user preferences)
|
|
10
|
+
- RAG context injection detection
|
|
11
|
+
|
|
12
|
+
Memory poisoning attacks inject malicious content into an agent's
|
|
13
|
+
persistent memory, causing incorrect behavior days or weeks later.
|
|
14
|
+
This module detects such tampering.
|
|
15
|
+
|
|
16
|
+
Example:
|
|
17
|
+
>>> from proxilion.security.memory_integrity import (
|
|
18
|
+
... MemoryIntegrityGuard,
|
|
19
|
+
... ContextWindowGuard,
|
|
20
|
+
... SignedMessage,
|
|
21
|
+
... )
|
|
22
|
+
>>>
|
|
23
|
+
>>> guard = MemoryIntegrityGuard(secret_key="your-secret-key")
|
|
24
|
+
>>>
|
|
25
|
+
>>> context = []
>>> # Sign messages as they're added
|
|
26
|
+
>>> msg = guard.sign_message(role="user", content="Hello")
|
|
27
|
+
>>> context.append(msg)
|
|
28
|
+
>>>
|
|
29
|
+
>>> # Verify context before sending to LLM
|
|
30
|
+
>>> if guard.verify_context(context).valid:
|
|
31
|
+
... response = llm.generate(context)
|
|
32
|
+
... else:
|
|
33
|
+
... raise ContextIntegrityError("Context has been modified")
|
|
34
|
+
>>>
|
|
35
|
+
>>> # Detect RAG poisoning
|
|
36
|
+
>>> docs = retriever.get_relevant_docs(query)
|
|
37
|
+
>>> safe_docs = guard.scan_rag_documents(docs).safe_documents
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
from __future__ import annotations
|
|
41
|
+
|
|
42
|
+
import hashlib
|
|
43
|
+
import hmac
|
|
44
|
+
import json
|
|
45
|
+
import logging
|
|
46
|
+
import re
|
|
47
|
+
import threading
|
|
48
|
+
import time
|
|
49
|
+
from dataclasses import dataclass, field
|
|
50
|
+
from datetime import datetime, timezone
|
|
51
|
+
from enum import Enum
|
|
52
|
+
from typing import Any, Protocol, runtime_checkable
|
|
53
|
+
|
|
54
|
+
logger: logging.Logger = logging.getLogger(__name__)  # module-scoped logger named after this module
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class IntegrityViolationType(Enum):
    """Categories of integrity problem reported by the guards in this module.

    The member value is the string written to the ``"type"`` key when an
    ``IntegrityViolation`` is serialized with ``to_dict``.
    """

    # A message's HMAC signature does not match its content.
    SIGNATURE_MISMATCH = "signature_mismatch"

    # One or more messages are missing from the sequence.
    SEQUENCE_GAP = "sequence_gap"

    # Messages appear out of order.
    SEQUENCE_REORDER = "sequence_reorder"

    # Timestamps are inconsistent (going backwards or an unexpectedly large gap).
    TIMESTAMP_ANOMALY = "timestamp_anomaly"

    # Message content tries to spoof a role marker.
    ROLE_INJECTION = "role_injection"

    # A retrieved document contains malicious instructions.
    RAG_POISONING = "rag_poisoning"

    # The context holds more messages than allowed.
    CONTEXT_OVERFLOW = "context_overflow"

    # A message's previous_hash does not link to the preceding message.
    HASH_CHAIN_BREAK = "hash_chain_break"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
class IntegrityViolation:
    """A single integrity problem detected in a context window or RAG document.

    ``severity`` is a score between 0.0 and 1.0 (higher is more serious);
    ``index`` locates the offending item when the detector knows it, and
    ``expected``/``actual`` carry optional short diagnostic values.
    """

    violation_type: IntegrityViolationType
    message: str
    severity: float  # 0.0 to 1.0
    index: int | None = None  # position in context where the violation occurred
    expected: str | None = None
    actual: str | None = None
    # Detection time, always timezone-aware UTC.
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary (enum as its value, time as ISO 8601)."""
        return dict(
            type=self.violation_type.value,
            message=self.message,
            severity=self.severity,
            index=self.index,
            expected=self.expected,
            actual=self.actual,
            timestamp=self.timestamp.isoformat(),
        )
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@dataclass
class SignedMessage:
    """
    A conversation message carrying the data needed to detect tampering.

    Fields, as populated by ``MemoryIntegrityGuard.sign_message``:
    - ``role`` / ``content``: the message itself.
    - ``sequence``: ordering number assigned at signing time.
    - ``timestamp``: ``time.time()`` value captured at signing time.
    - ``signature``: hex HMAC-SHA256 over role, content, sequence,
      timestamp and ``previous_hash``.
    - ``previous_hash``: ``content_hash()`` of the preceding message, linking
      messages into a chain.
    - ``metadata``: free-form extra data. NOTE(review): it is not part of the
      signed material, so it is not tamper-protected.
    """

    role: str
    content: str
    sequence: int
    timestamp: float
    signature: str
    previous_hash: str
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary; ``from_dict`` is the inverse."""
        return dict(
            role=self.role,
            content=self.content,
            sequence=self.sequence,
            timestamp=self.timestamp,
            signature=self.signature,
            previous_hash=self.previous_hash,
            metadata=self.metadata,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> SignedMessage:
        """Rebuild a message from a dictionary; ``metadata`` may be omitted.

        Raises:
            KeyError: If any of the signed fields is missing.
        """
        required_keys = ("role", "content", "sequence", "timestamp", "signature", "previous_hash")
        signed_fields = {key: data[key] for key in required_keys}
        return cls(metadata=data.get("metadata", {}), **signed_fields)

    def content_hash(self) -> str:
        """Return the SHA-256 hex digest that the next message uses as ``previous_hash``."""
        parts = (self.role, self.content, self.sequence, self.timestamp)
        material = ":".join(str(part) for part in parts)
        return hashlib.sha256(material.encode()).hexdigest()
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@dataclass
class VerificationResult:
    """Outcome of verifying a context window.

    ``MemoryIntegrityGuard.verify_context`` sets ``valid`` to ``True`` only
    when it recorded no violations. The result's truthiness mirrors
    ``valid``. Without that, a bare ``if result:`` check would always pass,
    because a plain dataclass instance is always truthy, and verification
    would silently fail open.
    """

    valid: bool
    violations: list[IntegrityViolation] = field(default_factory=list)
    verified_count: int = 0  # messages whose signature and chain link checked out
    total_count: int = 0  # messages examined

    def __bool__(self) -> bool:
        """Return ``valid`` so that ``if result:`` fails closed on tampering."""
        return bool(self.valid)

    @property
    def violation_count(self) -> int:
        """Number of violations found."""
        return len(self.violations)

    @property
    def max_severity(self) -> float:
        """Highest severity among the violations, or 0.0 when there are none."""
        return max((v.severity for v in self.violations), default=0.0)
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
@dataclass
class RAGDocument:
    """A retrieved document to be screened before it is added to the context.

    Attributes:
        content: Text that is scanned for poisoning patterns.
        source: Optional identifier of where the document came from.
        score: Retrieval score; 0.0 when unknown.
        metadata: Arbitrary extra data carried alongside the document.
    """

    content: str
    source: str | None = None
    score: float = 0.0
    metadata: dict[str, Any] = field(default_factory=dict)  # fresh dict per instance
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
@dataclass
class RAGScanResult:
    """Result of scanning RAG documents for poisoning.

    ``documents`` holds every scanned document in input order,
    ``poisoned_indices`` lists positions that matched at least one poisoning
    pattern, and ``violations`` holds the individual matches.
    """

    safe: bool
    documents: list[RAGDocument]
    poisoned_indices: list[int] = field(default_factory=list)
    violations: list[IntegrityViolation] = field(default_factory=list)

    @property
    def safe_documents(self) -> list[RAGDocument]:
        """Return documents whose position is not in ``poisoned_indices``, in order."""
        # Build the lookup set once: a list membership test per document
        # would make this quadratic in the number of documents.
        poisoned = set(self.poisoned_indices)
        return [doc for i, doc in enumerate(self.documents) if i not in poisoned]
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
# RAG poisoning detection patterns.
# Each entry is (regex, human-readable description, severity in 0.0-1.0).
# MemoryIntegrityGuard compiles these once per instance, and a retrieved
# document is flagged when its content matches at least one of them.
RAG_POISON_PATTERNS: list[tuple[str, str, float]] = [
    (
        r"(?i)ignore\s+(all\s+)?(previous|prior|above)\s+(instructions?|prompts?|rules?)",
        "Instruction override attempt",
        0.95,
    ),
    (
        r"(?i)you\s+are\s+now\s+(\w+\s+)?(mode|persona|character)",
        "Role/persona injection",
        0.9,
    ),
    (
        r"(?i)system\s*:\s*you\s+(are|must|should|will)",
        "Fake system message",
        0.95,
    ),
    (
        r"(?i)\[/?INST\]|\[/?SYS\]|<\|im_start\|>|<\|im_end\|>",
        "Model delimiter injection",
        0.9,
    ),
    (
        r"(?i)admin\s+(mode|access|override)\s*(enabled|activated|on)",
        "Privilege escalation attempt",
        0.85,
    ),
    (
        r"(?i)(reveal|show|display|print)\s+(your\s+)?(system\s+)?(prompt|instructions)",
        "System prompt extraction",
        0.8,
    ),
    (
        r"(?i)forget\s+(everything|all|what)\s+(you\s+)?(know|learned|were\s+told)",
        "Memory wipe attempt",
        0.85,
    ),
    (
        r"(?i)from\s+now\s+on\s*,?\s*(you\s+)?(will|must|should|are)",
        "Behavioral override",
        0.8,
    ),
    (
        r"(?i)disregard\s+(all\s+)?(safety|security|ethical)\s+(guidelines?|rules?|constraints?)",
        "Safety bypass attempt",
        0.95,
    ),
    (
        r"(?i)execute\s+(this\s+)?(code|script|command)\s*:",
        "Code execution injection",
        0.9,
    ),
]
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class MemoryIntegrityGuard:
    """
    Cryptographic integrity guard for agent memory and context.

    Provides:
    - HMAC signing of messages
    - Hash chain verification
    - Sequence validation
    - Timestamp anomaly detection
    - Role-injection detection in message content
    - RAG poisoning detection

    A guard instance keeps one running chain: a sequence counter plus the
    hash of the last signed message. Both are updated under a re-entrant
    lock.

    Example:
        >>> guard = MemoryIntegrityGuard(secret_key="your-key")
        >>>
        >>> # Build a signed context
        >>> context = []
        >>> context.append(guard.sign_message("system", "You are helpful."))
        >>> context.append(guard.sign_message("user", "Hello!"))
        >>>
        >>> # Verify before using
        >>> result = guard.verify_context(context)
        >>> if not result.valid:
        ...     for v in result.violations:
        ...         print(f"Violation: {v.message}")
    """

    # previous_hash of the first message in a chain: 64 hex zeros, the same
    # length as a SHA-256 hex digest.
    GENESIS_HASH = "0" * 64

    # Fake role markers that must not appear inside message content, as
    # (compiled regex, description). Compiled once here rather than on every
    # message check.
    _ROLE_INJECTION_PATTERNS: tuple[tuple[re.Pattern[str], str], ...] = tuple(
        (re.compile(pattern), description)
        for pattern, description in (
            (r"(?i)^(system|assistant|user|tool)\s*:\s*", "Role prefix injection"),
            (r"(?i)\n(system|assistant|user|tool)\s*:\s*", "Inline role injection"),
            (r"(?i)<\|(system|assistant|user|tool)\|>", "Delimiter role injection"),
            (r"(?i)\[INST\]|\[/INST\]", "Llama instruction delimiters"),
            (r"(?i)<\|im_start\|>(system|user|assistant)", "ChatML injection"),
        )
    )

    def __init__(
        self,
        secret_key: str | bytes,
        max_timestamp_drift: float = 60.0,
        max_context_size: int = 1000,
        enable_rag_scan: bool = True,
        custom_rag_patterns: list[tuple[str, str, float]] | None = None,
    ) -> None:
        """
        Initialize the integrity guard.

        Args:
            secret_key: Secret key for HMAC signatures; ``str`` keys are
                encoded with ``str.encode()`` (UTF-8).
            max_timestamp_drift: Largest allowed gap, in seconds, between
                consecutive message timestamps before ``verify_context``
                reports a TIMESTAMP_ANOMALY.
            max_context_size: Message count above which ``verify_context``
                reports a CONTEXT_OVERFLOW.
            enable_rag_scan: When False, ``scan_rag_documents`` skips pattern
                matching and reports every document as safe.
            custom_rag_patterns: Extra ``(regex, description, severity)``
                tuples, checked after ``RAG_POISON_PATTERNS``.
        """
        if isinstance(secret_key, str):
            secret_key = secret_key.encode()

        self._secret_key = secret_key
        self._max_timestamp_drift = max_timestamp_drift
        self._max_context_size = max_context_size
        self._enable_rag_scan = enable_rag_scan

        # Compile built-in and custom RAG patterns once per instance.
        self._rag_patterns: list[tuple[re.Pattern[str], str, float]] = []
        for pattern, desc, severity in RAG_POISON_PATTERNS:
            self._rag_patterns.append((re.compile(pattern), desc, severity))

        if custom_rag_patterns:
            for pattern, desc, severity in custom_rag_patterns:
                self._rag_patterns.append((re.compile(pattern), desc, severity))

        self._sequence_counter = 0
        self._last_hash = self.GENESIS_HASH
        self._lock = threading.RLock()

        logger.debug("MemoryIntegrityGuard initialized")

    def _compute_signature(
        self,
        role: str,
        content: str,
        sequence: int,
        timestamp: float,
        previous_hash: str,
    ) -> str:
        """Compute the hex HMAC-SHA256 signature over a message's signed fields."""
        message = f"{role}|{content}|{sequence}|{timestamp}|{previous_hash}"
        return hmac.new(
            self._secret_key,
            message.encode(),
            hashlib.sha256,
        ).hexdigest()

    def sign_message(
        self,
        role: str,
        content: str,
        metadata: dict[str, Any] | None = None,
    ) -> SignedMessage:
        """
        Create a signed message and advance this guard's chain.

        Args:
            role: Message role (system, user, assistant, tool).
            content: Message content.
            metadata: Optional metadata (stored but not signed).

        Returns:
            SignedMessage with cryptographic signature.
        """
        with self._lock:
            timestamp = time.time()
            sequence = self._sequence_counter
            self._sequence_counter += 1

            signature = self._compute_signature(
                role, content, sequence, timestamp, self._last_hash
            )

            msg = SignedMessage(
                role=role,
                content=content,
                sequence=sequence,
                timestamp=timestamp,
                signature=signature,
                previous_hash=self._last_hash,
                metadata=metadata or {},
            )

            # The next message links to this one.
            self._last_hash = msg.content_hash()

            return msg

    def verify_message(
        self,
        message: SignedMessage,
        expected_previous_hash: str | None = None,
    ) -> tuple[bool, IntegrityViolation | None]:
        """
        Verify a single message's signature and, optionally, its chain link.

        Args:
            message: The message to verify.
            expected_previous_hash: Expected previous hash for chain verification.

        Returns:
            Tuple of (valid, violation or None). A returned violation's
            ``index`` is the message's sequence number, since no context
            position is known here.
        """
        expected_sig = self._compute_signature(
            message.role,
            message.content,
            message.sequence,
            message.timestamp,
            message.previous_hash,
        )

        # Constant-time comparison to avoid leaking signature bytes via timing.
        if not hmac.compare_digest(expected_sig, message.signature):
            return False, IntegrityViolation(
                violation_type=IntegrityViolationType.SIGNATURE_MISMATCH,
                message=f"Signature mismatch for message at sequence {message.sequence}",
                severity=1.0,
                index=message.sequence,
                expected=expected_sig[:16] + "...",
                actual=message.signature[:16] + "...",
            )

        if expected_previous_hash is not None and message.previous_hash != expected_previous_hash:
            return False, IntegrityViolation(
                violation_type=IntegrityViolationType.HASH_CHAIN_BREAK,
                message=f"Hash chain break at sequence {message.sequence}",
                severity=1.0,
                index=message.sequence,
                expected=expected_previous_hash[:16] + "...",
                actual=message.previous_hash[:16] + "...",
            )

        return True, None

    def verify_context(
        self,
        context: list[SignedMessage],
        strict_sequence: bool = True,
        check_timestamps: bool = True,
    ) -> VerificationResult:
        """
        Verify an entire context window.

        The chain is expected to start at ``GENESIS_HASH``. Every message is
        checked for signature, chain link, sequence ordering (optional),
        timestamp ordering (optional) and role injection. All violations carry
        the message's position in ``context`` as ``index``.

        Args:
            context: List of signed messages.
            strict_sequence: Require sequential sequence numbers.
            check_timestamps: Verify timestamp ordering.

        Returns:
            VerificationResult with any violations found.
        """
        violations: list[IntegrityViolation] = []
        verified_count = 0

        if len(context) > self._max_context_size:
            violations.append(IntegrityViolation(
                violation_type=IntegrityViolationType.CONTEXT_OVERFLOW,
                message=f"Context size {len(context)} exceeds max {self._max_context_size}",
                severity=0.7,
            ))

        if not context:
            return VerificationResult(
                valid=not violations,
                violations=violations,
                verified_count=0,
                total_count=0,
            )

        expected_hash = self.GENESIS_HASH
        prev_sequence = -1
        prev_timestamp = 0.0

        for i, msg in enumerate(context):
            valid, violation = self.verify_message(msg, expected_hash)
            if not valid and violation:
                # verify_message reports the sequence number; inside a context
                # the position is what locates the offending entry (they differ
                # exactly when messages were reordered or removed).
                violation.index = i
                violations.append(violation)
            else:
                verified_count += 1

            if strict_sequence:
                sequence_violation = self._check_sequence(msg, i, prev_sequence)
                if sequence_violation is not None:
                    violations.append(sequence_violation)

            # prev_timestamp is 0.0 only before the first message.
            if check_timestamps and prev_timestamp > 0:
                timestamp_violation = self._check_timestamp(msg, i, prev_timestamp)
                if timestamp_violation is not None:
                    violations.append(timestamp_violation)

            role_injection = self._detect_role_injection(msg.content)
            if role_injection:
                violations.append(IntegrityViolation(
                    violation_type=IntegrityViolationType.ROLE_INJECTION,
                    message=f"Role injection detected in message {i}: {role_injection}",
                    severity=0.85,
                    index=i,
                ))

            expected_hash = msg.content_hash()
            prev_sequence = msg.sequence
            prev_timestamp = msg.timestamp

        return VerificationResult(
            valid=len(violations) == 0,
            violations=violations,
            verified_count=verified_count,
            total_count=len(context),
        )

    @staticmethod
    def _check_sequence(
        msg: SignedMessage,
        position: int,
        prev_sequence: int,
    ) -> IntegrityViolation | None:
        """Report a reorder or gap when ``msg`` does not directly follow ``prev_sequence``."""
        if msg.sequence == prev_sequence + 1:
            return None
        if msg.sequence <= prev_sequence:
            return IntegrityViolation(
                violation_type=IntegrityViolationType.SEQUENCE_REORDER,
                message=f"Message {position} has sequence {msg.sequence}, expected {prev_sequence + 1}",
                severity=0.9,
                index=position,
                expected=str(prev_sequence + 1),
                actual=str(msg.sequence),
            )
        return IntegrityViolation(
            violation_type=IntegrityViolationType.SEQUENCE_GAP,
            message=f"Gap in sequence: {prev_sequence} -> {msg.sequence}",
            severity=0.8,
            index=position,
        )

    def _check_timestamp(
        self,
        msg: SignedMessage,
        position: int,
        prev_timestamp: float,
    ) -> IntegrityViolation | None:
        """Report a backwards timestamp or a gap larger than ``max_timestamp_drift``."""
        if msg.timestamp < prev_timestamp:
            return IntegrityViolation(
                violation_type=IntegrityViolationType.TIMESTAMP_ANOMALY,
                message=f"Timestamp goes backwards at message {position}",
                severity=0.7,
                index=position,
            )
        gap = msg.timestamp - prev_timestamp
        if gap > self._max_timestamp_drift:
            # A large gap might indicate injected messages.
            return IntegrityViolation(
                violation_type=IntegrityViolationType.TIMESTAMP_ANOMALY,
                message=f"Large timestamp gap at message {position}: {gap:.1f}s",
                severity=0.5,
                index=position,
            )
        return None

    def _detect_role_injection(self, content: str) -> str | None:
        """Return the description of the first fake role marker in ``content``, or None."""
        for pattern, description in self._ROLE_INJECTION_PATTERNS:
            if pattern.search(content):
                return description
        return None

    def scan_rag_documents(
        self,
        documents: list[RAGDocument] | list[str] | list[dict[str, Any]],
    ) -> RAGScanResult:
        """
        Scan RAG documents for poisoning attempts.

        Args:
            documents: Documents to scan. Can be RAGDocument objects,
                plain strings, or dicts with a 'content' (or 'text') key.

        Returns:
            RAGScanResult with safe documents and violations. When the guard
            was created with ``enable_rag_scan=False``, no patterns are
            checked and every document is reported as safe.
        """
        normalized = [self._normalize_rag_document(doc) for doc in documents]

        if not self._enable_rag_scan:
            return RAGScanResult(safe=True, documents=normalized)

        violations: list[IntegrityViolation] = []
        poisoned_indices: list[int] = []

        for i, doc in enumerate(normalized):
            doc_violations = self._scan_document_content(doc.content, i)
            if doc_violations:
                violations.extend(doc_violations)
                poisoned_indices.append(i)

        return RAGScanResult(
            safe=len(poisoned_indices) == 0,
            documents=normalized,
            poisoned_indices=poisoned_indices,
            violations=violations,
        )

    @staticmethod
    def _normalize_rag_document(doc: Any) -> RAGDocument:
        """Coerce one supported document representation into a RAGDocument.

        For dicts, 'content' is preferred, then 'text', then the dict's own
        string form. A missing or ``None`` value falls through to the next
        option, and non-string content is converted with ``str()`` so the
        regex scan cannot fail with a TypeError.
        """
        if isinstance(doc, RAGDocument):
            return doc
        if isinstance(doc, str):
            return RAGDocument(content=doc)
        if isinstance(doc, dict):
            content = doc.get("content")
            if content is None:
                content = doc.get("text")
            if content is None:
                content = str(doc)
            return RAGDocument(
                content=content if isinstance(content, str) else str(content),
                source=doc.get("source"),
                score=doc.get("score", 0.0),
                metadata=doc.get("metadata", {}),
            )
        return RAGDocument(content=str(doc))

    def _scan_document_content(
        self,
        content: str,
        index: int,
    ) -> list[IntegrityViolation]:
        """Return one RAG_POISONING violation per pattern that matches ``content``."""
        violations: list[IntegrityViolation] = []

        for pattern, description, severity in self._rag_patterns:
            if pattern.search(content):
                violations.append(IntegrityViolation(
                    violation_type=IntegrityViolationType.RAG_POISONING,
                    message=f"RAG poisoning detected in document {index}: {description}",
                    severity=severity,
                    index=index,
                ))

        return violations

    def reset(self) -> None:
        """Reset the guard state (sequence counter and hash chain)."""
        with self._lock:
            self._sequence_counter = 0
            self._last_hash = self.GENESIS_HASH

    def get_state(self) -> dict[str, Any]:
        """Get current guard state for serialization."""
        with self._lock:
            return {
                "sequence_counter": self._sequence_counter,
                "last_hash": self._last_hash,
            }

    def restore_state(self, state: dict[str, Any]) -> None:
        """Restore guard state from ``get_state`` output; missing keys reset to defaults."""
        with self._lock:
            self._sequence_counter = state.get("sequence_counter", 0)
            self._last_hash = state.get("last_hash", self.GENESIS_HASH)
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
class ContextWindowGuard:
    """
    High-level guard for managing signed context windows.

    Wraps a private ``MemoryIntegrityGuard`` and keeps the list of messages
    it has signed, so the whole window can be verified before it is sent to
    an LLM. All operations take a re-entrant lock.

    Example:
        >>> guard = ContextWindowGuard(secret_key="key")
        >>>
        >>> # Build context
        >>> guard.add_system("You are a helpful assistant.")
        >>> guard.add_user("Hello!")
        >>> guard.add_assistant("Hi there! How can I help?")
        >>>
        >>> # Get verified context for LLM
        >>> messages = guard.get_verified_messages()
        >>> response = llm.generate(messages)
        >>>
        >>> # Add response to context
        >>> guard.add_assistant(response)
    """

    def __init__(
        self,
        secret_key: str | bytes,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the context window guard.

        Args:
            secret_key: Secret key for signing.
            **kwargs: Additional args passed to MemoryIntegrityGuard.
        """
        self._guard = MemoryIntegrityGuard(secret_key, **kwargs)
        self._messages: list[SignedMessage] = []
        self._lock = threading.RLock()

    def add_system(self, content: str, **metadata: Any) -> SignedMessage:
        """Add a system message."""
        return self._add_message("system", content, metadata)

    def add_user(self, content: str, **metadata: Any) -> SignedMessage:
        """Add a user message."""
        return self._add_message("user", content, metadata)

    def add_assistant(self, content: str, **metadata: Any) -> SignedMessage:
        """Add an assistant message."""
        return self._add_message("assistant", content, metadata)

    def add_tool(self, content: str, tool_name: str = "", **metadata: Any) -> SignedMessage:
        """Add a tool result message; ``tool_name`` is stored in its metadata."""
        metadata["tool_name"] = tool_name
        return self._add_message("tool", content, metadata)

    def _add_message(
        self,
        role: str,
        content: str,
        metadata: dict[str, Any],
    ) -> SignedMessage:
        """Sign a message with the given role and append it to the window."""
        with self._lock:
            msg = self._guard.sign_message(role, content, metadata)
            self._messages.append(msg)
            return msg

    def verify(self) -> VerificationResult:
        """Verify the current context."""
        with self._lock:
            return self._guard.verify_context(self._messages)

    def get_messages(self) -> list[SignedMessage]:
        """Get all messages (unverified) as a shallow copy of the list."""
        with self._lock:
            return list(self._messages)

    def get_verified_messages(self) -> list[dict[str, str]]:
        """
        Get messages if context is valid, else raise.

        Verification and the returned snapshot happen under one lock
        acquisition. This means a message added concurrently cannot slip into
        the result without having been verified.

        Returns:
            List of message dicts suitable for LLM API.

        Raises:
            ContextIntegrityError: If verification fails.
        """
        with self._lock:
            result = self.verify()
            if not result.valid:
                from proxilion.exceptions import ContextIntegrityError
                raise ContextIntegrityError(
                    f"Context integrity violated: {result.violations[0].message}",
                    violations=result.violations,
                )
            return self._api_messages()

    def get_messages_for_api(self) -> list[dict[str, str]]:
        """Get messages in API format (role, content only), without verifying."""
        with self._lock:
            return self._api_messages()

    def _api_messages(self) -> list[dict[str, str]]:
        """Project messages onto role/content dicts; caller must hold the lock."""
        return [{"role": msg.role, "content": msg.content} for msg in self._messages]

    def clear(self) -> None:
        """Clear all messages and reset state."""
        with self._lock:
            self._messages.clear()
            self._guard.reset()

    def pop(self) -> SignedMessage | None:
        """
        Remove and return the last message, or None if the window is empty.

        The underlying guard's chain is rewound to the popped message's
        position. The next message added is then signed with the same
        sequence number and previous hash the popped one had. Without this
        rewind, the next message would carry a sequence gap and a broken hash
        link, and every later ``verify()`` would fail.
        """
        with self._lock:
            if not self._messages:
                return None
            msg = self._messages.pop()
            self._guard.restore_state({
                "sequence_counter": msg.sequence,
                "last_hash": msg.previous_hash,
            })
            return msg

    def __len__(self) -> int:
        """Number of messages in context."""
        with self._lock:
            return len(self._messages)
|
|
751
|
+
|
|
752
|
+
|
|
753
|
+
# Public API of this module (what ``from ... import *`` exposes).
__all__ = [
    # Guards
    "MemoryIntegrityGuard",
    "ContextWindowGuard",
    # Result and record types
    "SignedMessage",
    "VerificationResult",
    "IntegrityViolation",
    "IntegrityViolationType",
    "RAGDocument",
    "RAGScanResult",
    # Detection pattern table
    "RAG_POISON_PATTERNS",
]
|