proxilion 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- proxilion/__init__.py +136 -0
- proxilion/audit/__init__.py +133 -0
- proxilion/audit/base_exporters.py +527 -0
- proxilion/audit/compliance/__init__.py +130 -0
- proxilion/audit/compliance/base.py +457 -0
- proxilion/audit/compliance/eu_ai_act.py +603 -0
- proxilion/audit/compliance/iso27001.py +544 -0
- proxilion/audit/compliance/soc2.py +491 -0
- proxilion/audit/events.py +493 -0
- proxilion/audit/explainability.py +1173 -0
- proxilion/audit/exporters/__init__.py +58 -0
- proxilion/audit/exporters/aws_s3.py +636 -0
- proxilion/audit/exporters/azure_storage.py +608 -0
- proxilion/audit/exporters/cloud_base.py +468 -0
- proxilion/audit/exporters/gcp_storage.py +570 -0
- proxilion/audit/exporters/multi_exporter.py +498 -0
- proxilion/audit/hash_chain.py +652 -0
- proxilion/audit/logger.py +543 -0
- proxilion/caching/__init__.py +49 -0
- proxilion/caching/tool_cache.py +633 -0
- proxilion/context/__init__.py +73 -0
- proxilion/context/context_window.py +556 -0
- proxilion/context/message_history.py +505 -0
- proxilion/context/session.py +735 -0
- proxilion/contrib/__init__.py +51 -0
- proxilion/contrib/anthropic.py +609 -0
- proxilion/contrib/google.py +1012 -0
- proxilion/contrib/langchain.py +641 -0
- proxilion/contrib/mcp.py +893 -0
- proxilion/contrib/openai.py +646 -0
- proxilion/core.py +3058 -0
- proxilion/decorators.py +966 -0
- proxilion/engines/__init__.py +287 -0
- proxilion/engines/base.py +266 -0
- proxilion/engines/casbin_engine.py +412 -0
- proxilion/engines/opa_engine.py +493 -0
- proxilion/engines/simple.py +437 -0
- proxilion/exceptions.py +887 -0
- proxilion/guards/__init__.py +54 -0
- proxilion/guards/input_guard.py +522 -0
- proxilion/guards/output_guard.py +634 -0
- proxilion/observability/__init__.py +198 -0
- proxilion/observability/cost_tracker.py +866 -0
- proxilion/observability/hooks.py +683 -0
- proxilion/observability/metrics.py +798 -0
- proxilion/observability/session_cost_tracker.py +1063 -0
- proxilion/policies/__init__.py +67 -0
- proxilion/policies/base.py +304 -0
- proxilion/policies/builtin.py +486 -0
- proxilion/policies/registry.py +376 -0
- proxilion/providers/__init__.py +201 -0
- proxilion/providers/adapter.py +468 -0
- proxilion/providers/anthropic_adapter.py +330 -0
- proxilion/providers/gemini_adapter.py +391 -0
- proxilion/providers/openai_adapter.py +294 -0
- proxilion/py.typed +0 -0
- proxilion/resilience/__init__.py +81 -0
- proxilion/resilience/degradation.py +615 -0
- proxilion/resilience/fallback.py +555 -0
- proxilion/resilience/retry.py +554 -0
- proxilion/scheduling/__init__.py +57 -0
- proxilion/scheduling/priority_queue.py +419 -0
- proxilion/scheduling/scheduler.py +459 -0
- proxilion/security/__init__.py +244 -0
- proxilion/security/agent_trust.py +968 -0
- proxilion/security/behavioral_drift.py +794 -0
- proxilion/security/cascade_protection.py +869 -0
- proxilion/security/circuit_breaker.py +428 -0
- proxilion/security/cost_limiter.py +690 -0
- proxilion/security/idor_protection.py +460 -0
- proxilion/security/intent_capsule.py +849 -0
- proxilion/security/intent_validator.py +495 -0
- proxilion/security/memory_integrity.py +767 -0
- proxilion/security/rate_limiter.py +509 -0
- proxilion/security/scope_enforcer.py +680 -0
- proxilion/security/sequence_validator.py +636 -0
- proxilion/security/trust_boundaries.py +784 -0
- proxilion/streaming/__init__.py +70 -0
- proxilion/streaming/detector.py +761 -0
- proxilion/streaming/transformer.py +674 -0
- proxilion/timeouts/__init__.py +55 -0
- proxilion/timeouts/decorators.py +477 -0
- proxilion/timeouts/manager.py +545 -0
- proxilion/tools/__init__.py +69 -0
- proxilion/tools/decorators.py +493 -0
- proxilion/tools/registry.py +732 -0
- proxilion/types.py +339 -0
- proxilion/validation/__init__.py +93 -0
- proxilion/validation/pydantic_schema.py +351 -0
- proxilion/validation/schema.py +651 -0
- proxilion-0.0.1.dist-info/METADATA +872 -0
- proxilion-0.0.1.dist-info/RECORD +94 -0
- proxilion-0.0.1.dist-info/WHEEL +4 -0
- proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,636 @@
|
|
|
1
|
+
"""
Tool sequence validation for Proxilion.

Prevents dangerous tool call sequences by defining allowed/disallowed
patterns. For example, prevent "delete" from being called without a
prior "confirm" step.

Addresses:
- OWASP ASI01 (Agent Goal Hijack)
- OWASP ASI02 (Tool Misuse)

Example:
    >>> from proxilion.security.sequence_validator import (
    ...     SequenceValidator, SequenceRule, SequenceAction
    ... )
    >>>
    >>> validator = SequenceValidator()
    >>>
    >>> # Add rule requiring confirmation before deletion
    >>> validator.add_rule(SequenceRule(
    ...     name="require_confirm",
    ...     action=SequenceAction.REQUIRE_BEFORE,
    ...     target_pattern="delete_*",
    ...     required_pattern="confirm_*",
    ... ))
    >>>
    >>> # Validate a tool call
    >>> allowed, violation = validator.validate_call("delete_file", "user_123")
    >>> if not allowed:
    ...     print(f"Blocked: {violation.message}")
"""
|
|
32
|
+
|
|
33
|
+
from __future__ import annotations

import copy
import fnmatch
import logging
import threading
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
|
|
42
|
+
|
|
43
|
+
logger = logging.getLogger(__name__)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class SequenceAction(Enum):
    """Kinds of check a sequence rule can perform.

    Each member's value is the lowercase spelling of its name.
    """

    # The target tool needs another tool (the required one) to have been called first.
    REQUIRE_BEFORE = "require_before"

    # The target tool cannot be called after another tool within a time window.
    FORBID_AFTER = "forbid_after"

    # Tools must be called in the order given by the rule's step list.
    REQUIRE_SEQUENCE = "require_sequence"

    # Caps how many times a tool can be called consecutively.
    MAX_CONSECUTIVE = "max_consecutive"

    # Enforces a minimum time between calls to the same tool.
    COOLDOWN = "cooldown"
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
@dataclass
class SequenceRule:
    """
    Rule for validating tool call sequences.

    Pattern fields use shell-style wildcards (``*``, ``?``, ``[seq]``) and are
    compared against tool names case-insensitively.

    Attributes:
        name: Unique identifier for the rule.
        action: Type of sequence validation.
        target_pattern: Tool name pattern this rule applies to (supports wildcards).
        required_pattern: For REQUIRE_BEFORE, the pattern that must precede.
        forbidden_pattern: For FORBID_AFTER, the pattern that triggers block.
        sequence_patterns: For REQUIRE_SEQUENCE, ordered list of patterns.
        max_count: For MAX_CONSECUTIVE, maximum consecutive calls.
        cooldown_seconds: For COOLDOWN, minimum seconds between calls.
        window_seconds: Time window for FORBID_AFTER and lookback.
        description: Human-readable description of the rule.
        enabled: Whether the rule is active.
    """

    name: str
    action: SequenceAction
    target_pattern: str = "*"
    required_pattern: str | None = None
    forbidden_pattern: str | None = None
    sequence_patterns: list[str] = field(default_factory=list)
    max_count: int = 5
    cooldown_seconds: float = 60.0
    window_seconds: float = 300.0
    description: str = ""
    enabled: bool = True

    def matches_target(self, tool_name: str) -> bool:
        """Check if tool name matches this rule's target pattern.

        Args:
            tool_name: Name of the tool being checked.

        Returns:
            True when ``tool_name`` matches ``target_pattern``, ignoring case.
        """
        return self.matches_pattern(tool_name, self.target_pattern)

    def matches_pattern(self, tool_name: str, pattern: str) -> bool:
        """Check if tool name matches a wildcard pattern, ignoring case.

        Args:
            tool_name: Name of the tool being checked.
            pattern: Shell-style wildcard pattern to match against.

        Returns:
            True when the lowercased name matches the lowercased pattern.
        """
        # Both sides are lowercased explicitly, so use fnmatchcase: plain
        # fnmatch() also applies os.path.normcase(), a filesystem-path
        # normalization that is platform-dependent and irrelevant to tool names.
        return fnmatch.fnmatchcase(tool_name.lower(), pattern.lower())
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@dataclass
class SequenceViolation:
    """
    Describes why a tool call was rejected by a sequence rule.

    Only the first three fields are required; the remaining ones carry
    detail specific to the kind of rule that was violated and otherwise
    keep their defaults.

    Attributes:
        rule_name: Name of the rule that was violated.
        violation_type: Kind of rule violated (a SequenceAction member).
        tool_name: The tool whose call triggered the violation.
        tool_sequence: Names of recently called tools.
        message: Human-readable explanation of the violation.
        required_prior: REQUIRE_BEFORE only - the tool pattern that was required.
        forbidden_prior: FORBID_AFTER only - the earlier tool that forbids this call.
        consecutive_count: MAX_CONSECUTIVE only - number of calls that were made.
        last_call_seconds_ago: COOLDOWN only - seconds elapsed since the last call.
    """

    # Required identification of the violation.
    rule_name: str
    violation_type: SequenceAction
    tool_name: str

    # Context shared by every violation kind.
    tool_sequence: list[str] = field(default_factory=list)
    message: str = ""

    # Detail that is only meaningful for one violation kind each.
    required_prior: str | None = None
    forbidden_prior: str | None = None
    consecutive_count: int = 0
    last_call_seconds_ago: float = 0.0
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@dataclass
class ToolCallRecord:
    """One entry in a user's tool call history, used for sequence tracking."""

    # Name of the tool that was called.
    tool_name: str
    # When the call was recorded.
    timestamp: datetime
    # Identifier of the user the call belongs to.
    user_id: str
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# Default security rules
|
|
143
|
+
DEFAULT_SEQUENCE_RULES: list[SequenceRule] = [
    # Any delete_* tool needs a confirm_* call earlier in the history.
    SequenceRule(
        action=SequenceAction.REQUIRE_BEFORE,
        name="require_confirm_before_delete",
        description="Deletion requires confirmation first",
        target_pattern="delete_*",
        required_pattern="confirm_*",
    ),
    # Applies to every tool: cap on back-to-back calls.
    SequenceRule(
        action=SequenceAction.MAX_CONSECUTIVE,
        name="max_consecutive_calls",
        description="Prevent runaway tool loops",
        target_pattern="*",
        max_count=10,
    ),
    # execute_* is blocked for 300 seconds after a download_* call.
    SequenceRule(
        action=SequenceAction.FORBID_AFTER,
        name="forbid_download_execute",
        description="Prevent download-and-execute attacks",
        target_pattern="execute_*",
        forbidden_pattern="download_*",
        window_seconds=300.0,
    ),
    # run_* is blocked for 300 seconds after a download_* call.
    SequenceRule(
        action=SequenceAction.FORBID_AFTER,
        name="forbid_download_run",
        description="Prevent download-and-run attacks",
        target_pattern="run_*",
        forbidden_pattern="download_*",
        window_seconds=300.0,
    ),
]
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
class SequenceValidator:
    """
    Validates tool call sequences against defined rules.

    Keeps a bounded, per-user history of recorded tool calls and checks each
    proposed call against the enabled rules whose target pattern matches it.
    Rules are evaluated in registration order and the first violation wins.

    ``validate_call`` only reads history; it does not record the call it is
    asked about. Callers record calls with ``record_call``. Rule and history
    state is accessed under a re-entrant lock.

    Example:
        >>> validator = SequenceValidator()
        >>>
        >>> # Try to delete without confirming
        >>> allowed, violation = validator.validate_call("delete_file", "user_1")
        >>> print(allowed)  # False - needs confirm first
        >>>
        >>> # Confirm first, then delete
        >>> validator.record_call("confirm_delete", "user_1")
        >>> allowed, violation = validator.validate_call("delete_file", "user_1")
        >>> print(allowed)  # True
    """

    def __init__(
        self,
        rules: list[SequenceRule] | None = None,
        history_size: int = 100,
        include_defaults: bool = True,
    ) -> None:
        """
        Initialize the sequence validator.

        Args:
            rules: Custom rules to use. They are registered after the
                defaults and are stored by reference.
            history_size: Maximum history entries per user.
            include_defaults: Whether to include default security rules.
        """
        self._rules: list[SequenceRule] = []
        self._history_size = history_size
        self._user_history: dict[str, deque[ToolCallRecord]] = {}
        self._lock = threading.RLock()

        # Rules grouped by their target pattern string.
        self._rule_index: dict[str, list[SequenceRule]] = {}

        if include_defaults:
            for rule in DEFAULT_SEQUENCE_RULES:
                # Register a private copy: enable_rule/disable_rule mutate the
                # rule object, and the module-level defaults are shared by
                # every validator in the process.
                self.add_rule(copy.deepcopy(rule))

        if rules:
            for rule in rules:
                self.add_rule(rule)

    def add_rule(self, rule: SequenceRule) -> None:
        """
        Add a sequence rule.

        Args:
            rule: The rule to add. It is appended after existing rules.
        """
        with self._lock:
            self._rules.append(rule)
            self._rule_index.setdefault(rule.target_pattern, []).append(rule)

    def remove_rule(self, name: str) -> bool:
        """
        Remove the first registered rule with the given name.

        Args:
            name: The rule name to remove.

        Returns:
            True if rule was removed, False if not found.
        """
        with self._lock:
            for position, rule in enumerate(self._rules):
                if rule.name != name:
                    continue

                del self._rules[position]

                # Drop exactly the removed object from the index so the index
                # stays in step with _rules even when names are duplicated,
                # and discard buckets that become empty.
                bucket = self._rule_index.get(rule.target_pattern)
                if bucket is not None:
                    remaining = [r for r in bucket if r is not rule]
                    if remaining:
                        self._rule_index[rule.target_pattern] = remaining
                    else:
                        del self._rule_index[rule.target_pattern]
                return True
            return False

    def get_rules(self) -> list[SequenceRule]:
        """Get all registered rules as a new list, in registration order."""
        with self._lock:
            return list(self._rules)

    def get_rule(self, name: str) -> SequenceRule | None:
        """Get the first registered rule with the given name, or None."""
        with self._lock:
            for rule in self._rules:
                if rule.name == name:
                    return rule
            return None

    def enable_rule(self, name: str) -> bool:
        """Enable a rule by name. Returns False if no such rule exists."""
        return self._set_rule_enabled(name, True)

    def disable_rule(self, name: str) -> bool:
        """Disable a rule by name. Returns False if no such rule exists."""
        return self._set_rule_enabled(name, False)

    def _set_rule_enabled(self, name: str, enabled: bool) -> bool:
        """Set the ``enabled`` flag of the named rule while holding the lock."""
        with self._lock:
            rule = self.get_rule(name)
            if rule is None:
                return False
            rule.enabled = enabled
            return True

    def record_call(
        self,
        tool_name: str,
        user_id: str,
        timestamp: datetime | None = None,
    ) -> None:
        """
        Record a tool call for sequence tracking.

        Args:
            tool_name: Name of the tool called.
            user_id: ID of the user making the call.
            timestamp: Optional timestamp (defaults to the current UTC time).
                It should be timezone-aware: the FORBID_AFTER and COOLDOWN
                checks subtract it from an aware UTC "now".
        """
        if timestamp is None:
            timestamp = datetime.now(timezone.utc)

        record = ToolCallRecord(
            tool_name=tool_name,
            timestamp=timestamp,
            user_id=user_id,
        )

        with self._lock:
            history = self._user_history.get(user_id)
            if history is None:
                history = deque(maxlen=self._history_size)
                self._user_history[user_id] = history
            history.append(record)

    def validate_call(
        self,
        tool_name: str,
        user_id: str,
    ) -> tuple[bool, SequenceViolation | None]:
        """
        Validate a tool call against sequence rules.

        Args:
            tool_name: Name of the tool to validate.
            user_id: ID of the user making the call.

        Returns:
            Tuple of (allowed, violation). If allowed is False,
            violation contains details about what rule was violated.
        """
        with self._lock:
            history = self._get_user_history(user_id)

            for rule in self._rules:
                if not rule.enabled or not rule.matches_target(tool_name):
                    continue

                violation = self._check_rule(rule, tool_name, history)
                if violation:
                    logger.warning(
                        "Sequence violation for user %s: %s - %s",
                        user_id,
                        violation.rule_name,
                        violation.message,
                    )
                    return False, violation

            return True, None

    def _get_user_history(self, user_id: str) -> list[ToolCallRecord]:
        """Get a snapshot list of a user's history (empty if none recorded)."""
        return list(self._user_history.get(user_id, ()))

    @staticmethod
    def _recent_tools(history: list[ToolCallRecord]) -> list[str]:
        """Names of the last five recorded tools, as reported in violations."""
        return [record.tool_name for record in history[-5:]]

    def _check_rule(
        self,
        rule: SequenceRule,
        tool_name: str,
        history: list[ToolCallRecord],
    ) -> SequenceViolation | None:
        """Dispatch to the check matching the rule's action (None if unknown)."""
        if rule.action == SequenceAction.REQUIRE_BEFORE:
            return self._check_require_before(rule, tool_name, history)
        if rule.action == SequenceAction.FORBID_AFTER:
            return self._check_forbid_after(rule, tool_name, history)
        if rule.action == SequenceAction.REQUIRE_SEQUENCE:
            return self._check_require_sequence(rule, tool_name, history)
        if rule.action == SequenceAction.MAX_CONSECUTIVE:
            return self._check_max_consecutive(rule, tool_name, history)
        if rule.action == SequenceAction.COOLDOWN:
            return self._check_cooldown(rule, tool_name, history)
        return None

    def _check_require_before(
        self,
        rule: SequenceRule,
        tool_name: str,
        history: list[ToolCallRecord],
    ) -> SequenceViolation | None:
        """Check REQUIRE_BEFORE: some recorded call must match the required pattern.

        A rule without ``required_pattern`` never produces a violation. Any
        matching entry still present in the history satisfies the rule.
        """
        required = rule.required_pattern
        if not required:
            return None

        if any(rule.matches_pattern(record.tool_name, required) for record in history):
            return None  # Found required predecessor

        return SequenceViolation(
            rule_name=rule.name,
            violation_type=SequenceAction.REQUIRE_BEFORE,
            tool_name=tool_name,
            tool_sequence=self._recent_tools(history),
            message=f"Tool '{tool_name}' requires '{rule.required_pattern}' to be called first",
            required_prior=rule.required_pattern,
        )

    def _check_forbid_after(
        self,
        rule: SequenceRule,
        tool_name: str,
        history: list[ToolCallRecord],
    ) -> SequenceViolation | None:
        """Check FORBID_AFTER: no forbidden call may lie inside the time window.

        A rule without ``forbidden_pattern`` never produces a violation.
        """
        if not rule.forbidden_pattern:
            return None

        now = datetime.now(timezone.utc)
        window_seconds = rule.window_seconds

        # Newest entries first, so the most recent offender is reported.
        for record in reversed(history):
            age = (now - record.timestamp).total_seconds()
            if age > window_seconds:
                # Skip rather than stop: record_call accepts caller-supplied
                # timestamps, so entries are not guaranteed to be in
                # chronological order.
                continue

            if rule.matches_pattern(record.tool_name, rule.forbidden_pattern):
                return SequenceViolation(
                    rule_name=rule.name,
                    violation_type=SequenceAction.FORBID_AFTER,
                    tool_name=tool_name,
                    tool_sequence=self._recent_tools(history),
                    message=(
                        f"Tool '{tool_name}' cannot be called within "
                        f"{window_seconds}s after '{rule.forbidden_pattern}' "
                        f"('{record.tool_name}' was called {age:.1f}s ago)"
                    ),
                    forbidden_prior=record.tool_name,
                )

        return None

    def _check_require_sequence(
        self,
        rule: SequenceRule,
        tool_name: str,
        history: list[ToolCallRecord],
    ) -> SequenceViolation | None:
        """Check REQUIRE_SEQUENCE: the step before this tool's step must be in history.

        The tool is mapped to the first step pattern it matches. Tools that
        match no step, and tools matching the first step, are always allowed.
        Only the immediately preceding step is looked for; earlier steps are
        not verified here.
        """
        sequence = rule.sequence_patterns
        if not sequence:
            return None

        # Find which step we're on (first matching pattern wins).
        step_index = next(
            (
                index
                for index, pattern in enumerate(sequence)
                if rule.matches_pattern(tool_name, pattern)
            ),
            -1,
        )

        # -1: tool not in sequence; 0: first step is always allowed.
        if step_index <= 0:
            return None

        expected_prior = sequence[step_index - 1]
        if any(rule.matches_pattern(record.tool_name, expected_prior) for record in history):
            return None

        return SequenceViolation(
            rule_name=rule.name,
            violation_type=SequenceAction.REQUIRE_SEQUENCE,
            tool_name=tool_name,
            tool_sequence=self._recent_tools(history),
            message=(
                f"Tool '{tool_name}' requires '{expected_prior}' to be called first "
                f"(sequence: {' -> '.join(sequence)})"
            ),
            required_prior=expected_prior,
        )

    def _check_max_consecutive(
        self,
        rule: SequenceRule,
        tool_name: str,
        history: list[ToolCallRecord],
    ) -> SequenceViolation | None:
        """Check MAX_CONSECUTIVE: limit the trailing run of identical tool names.

        Counts how many of the most recent history entries have exactly this
        tool name (exact, case-sensitive comparison) and reports a violation
        once that count has reached ``max_count``.
        """
        consecutive_count = 0
        for record in reversed(history):
            if record.tool_name != tool_name:
                break
            consecutive_count += 1

        if consecutive_count < rule.max_count:
            return None

        return SequenceViolation(
            rule_name=rule.name,
            violation_type=SequenceAction.MAX_CONSECUTIVE,
            tool_name=tool_name,
            tool_sequence=self._recent_tools(history),
            message=(
                f"Tool '{tool_name}' has been called {consecutive_count} times "
                f"consecutively (max allowed: {rule.max_count})"
            ),
            consecutive_count=consecutive_count,
        )

    def _check_cooldown(
        self,
        rule: SequenceRule,
        tool_name: str,
        history: list[ToolCallRecord],
    ) -> SequenceViolation | None:
        """Check COOLDOWN: the latest call to this exact tool must be old enough.

        Only the most recently recorded entry with the same tool name is
        examined.
        """
        now = datetime.now(timezone.utc)

        for record in reversed(history):
            if record.tool_name != tool_name:
                continue

            # First hit is the last recorded call to this tool; it alone decides.
            age = (now - record.timestamp).total_seconds()
            if age >= rule.cooldown_seconds:
                return None

            return SequenceViolation(
                rule_name=rule.name,
                violation_type=SequenceAction.COOLDOWN,
                tool_name=tool_name,
                tool_sequence=self._recent_tools(history),
                message=(
                    f"Tool '{tool_name}' requires {rule.cooldown_seconds}s cooldown "
                    f"(last called {age:.1f}s ago)"
                ),
                last_call_seconds_ago=age,
            )

        return None

    def get_history(
        self,
        user_id: str,
        limit: int | None = None,
    ) -> list[tuple[str, datetime]]:
        """
        Get tool call history for a user.

        Args:
            user_id: The user ID.
            limit: Maximum entries to return (None for all). Zero or a
                negative value returns no entries.

        Returns:
            List of (tool_name, timestamp) tuples, most recent first.
        """
        with self._lock:
            history = self._get_user_history(user_id)

        result = [(record.tool_name, record.timestamp) for record in reversed(history)]
        # Compare against None explicitly so that limit=0 means "no entries"
        # instead of being treated as "no limit".
        if limit is not None:
            result = result[: max(limit, 0)]
        return result

    def clear_history(self, user_id: str | None = None) -> None:
        """
        Clear tool call history.

        Args:
            user_id: User ID to clear (None to clear all).
        """
        with self._lock:
            if user_id is None:
                self._user_history.clear()
            else:
                self._user_history.pop(user_id, None)

    def configure(
        self,
        history_size: int | None = None,
    ) -> None:
        """
        Update validator configuration.

        Args:
            history_size: New maximum history size per user. Existing
                histories are trimmed to their most recent entries.

        Raises:
            ValueError: If ``history_size`` is negative.
        """
        if history_size is None:
            return
        # Reject up front so a bad value cannot leave the validator with a
        # size that makes later deque construction fail.
        if history_size < 0:
            raise ValueError("history_size must be non-negative")

        with self._lock:
            self._history_size = history_size
            # A bounded deque built from an iterable keeps its last items.
            self._user_history = {
                user_id: deque(entries, maxlen=history_size)
                for user_id, entries in self._user_history.items()
            }
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def create_sequence_validator(
    include_defaults: bool = True,
    custom_rules: list[SequenceRule] | None = None,
    history_size: int = 100,
) -> SequenceValidator:
    """
    Build a SequenceValidator from keyword-style options.

    Args:
        include_defaults: Whether the default security rules are included.
        custom_rules: Additional custom rules, passed to the validator as
            its ``rules`` argument.
        history_size: Maximum history entries kept per user.

    Returns:
        Configured SequenceValidator instance.
    """
    validator_options = {
        "rules": custom_rules,
        "history_size": history_size,
        "include_defaults": include_defaults,
    }
    return SequenceValidator(**validator_options)
|