proxilion-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- proxilion/__init__.py +136 -0
- proxilion/audit/__init__.py +133 -0
- proxilion/audit/base_exporters.py +527 -0
- proxilion/audit/compliance/__init__.py +130 -0
- proxilion/audit/compliance/base.py +457 -0
- proxilion/audit/compliance/eu_ai_act.py +603 -0
- proxilion/audit/compliance/iso27001.py +544 -0
- proxilion/audit/compliance/soc2.py +491 -0
- proxilion/audit/events.py +493 -0
- proxilion/audit/explainability.py +1173 -0
- proxilion/audit/exporters/__init__.py +58 -0
- proxilion/audit/exporters/aws_s3.py +636 -0
- proxilion/audit/exporters/azure_storage.py +608 -0
- proxilion/audit/exporters/cloud_base.py +468 -0
- proxilion/audit/exporters/gcp_storage.py +570 -0
- proxilion/audit/exporters/multi_exporter.py +498 -0
- proxilion/audit/hash_chain.py +652 -0
- proxilion/audit/logger.py +543 -0
- proxilion/caching/__init__.py +49 -0
- proxilion/caching/tool_cache.py +633 -0
- proxilion/context/__init__.py +73 -0
- proxilion/context/context_window.py +556 -0
- proxilion/context/message_history.py +505 -0
- proxilion/context/session.py +735 -0
- proxilion/contrib/__init__.py +51 -0
- proxilion/contrib/anthropic.py +609 -0
- proxilion/contrib/google.py +1012 -0
- proxilion/contrib/langchain.py +641 -0
- proxilion/contrib/mcp.py +893 -0
- proxilion/contrib/openai.py +646 -0
- proxilion/core.py +3058 -0
- proxilion/decorators.py +966 -0
- proxilion/engines/__init__.py +287 -0
- proxilion/engines/base.py +266 -0
- proxilion/engines/casbin_engine.py +412 -0
- proxilion/engines/opa_engine.py +493 -0
- proxilion/engines/simple.py +437 -0
- proxilion/exceptions.py +887 -0
- proxilion/guards/__init__.py +54 -0
- proxilion/guards/input_guard.py +522 -0
- proxilion/guards/output_guard.py +634 -0
- proxilion/observability/__init__.py +198 -0
- proxilion/observability/cost_tracker.py +866 -0
- proxilion/observability/hooks.py +683 -0
- proxilion/observability/metrics.py +798 -0
- proxilion/observability/session_cost_tracker.py +1063 -0
- proxilion/policies/__init__.py +67 -0
- proxilion/policies/base.py +304 -0
- proxilion/policies/builtin.py +486 -0
- proxilion/policies/registry.py +376 -0
- proxilion/providers/__init__.py +201 -0
- proxilion/providers/adapter.py +468 -0
- proxilion/providers/anthropic_adapter.py +330 -0
- proxilion/providers/gemini_adapter.py +391 -0
- proxilion/providers/openai_adapter.py +294 -0
- proxilion/py.typed +0 -0
- proxilion/resilience/__init__.py +81 -0
- proxilion/resilience/degradation.py +615 -0
- proxilion/resilience/fallback.py +555 -0
- proxilion/resilience/retry.py +554 -0
- proxilion/scheduling/__init__.py +57 -0
- proxilion/scheduling/priority_queue.py +419 -0
- proxilion/scheduling/scheduler.py +459 -0
- proxilion/security/__init__.py +244 -0
- proxilion/security/agent_trust.py +968 -0
- proxilion/security/behavioral_drift.py +794 -0
- proxilion/security/cascade_protection.py +869 -0
- proxilion/security/circuit_breaker.py +428 -0
- proxilion/security/cost_limiter.py +690 -0
- proxilion/security/idor_protection.py +460 -0
- proxilion/security/intent_capsule.py +849 -0
- proxilion/security/intent_validator.py +495 -0
- proxilion/security/memory_integrity.py +767 -0
- proxilion/security/rate_limiter.py +509 -0
- proxilion/security/scope_enforcer.py +680 -0
- proxilion/security/sequence_validator.py +636 -0
- proxilion/security/trust_boundaries.py +784 -0
- proxilion/streaming/__init__.py +70 -0
- proxilion/streaming/detector.py +761 -0
- proxilion/streaming/transformer.py +674 -0
- proxilion/timeouts/__init__.py +55 -0
- proxilion/timeouts/decorators.py +477 -0
- proxilion/timeouts/manager.py +545 -0
- proxilion/tools/__init__.py +69 -0
- proxilion/tools/decorators.py +493 -0
- proxilion/tools/registry.py +732 -0
- proxilion/types.py +339 -0
- proxilion/validation/__init__.py +93 -0
- proxilion/validation/pydantic_schema.py +351 -0
- proxilion/validation/schema.py +651 -0
- proxilion-0.0.1.dist-info/METADATA +872 -0
- proxilion-0.0.1.dist-info/RECORD +94 -0
- proxilion-0.0.1.dist-info/WHEEL +4 -0
- proxilion-0.0.1.dist-info/licenses/LICENSE +21 -0
proxilion/security/intent_validator.py
@@ -0,0 +1,495 @@
+"""
+Intent validation for Proxilion.
+
+This module provides deterministic intent validation to detect
+anomalous patterns in tool calls without relying on LLM analysis.
+"""
+
+from __future__ import annotations
+
+import logging
+import threading
+import time
+from collections import defaultdict
+from collections.abc import Callable
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+logger = logging.getLogger(__name__)
+
+
+class ValidationResult(Enum):
+    """Result of intent validation."""
+    VALID = "valid"
+    SUSPICIOUS = "suspicious"
+    BLOCKED = "blocked"
+
+
+@dataclass
+class ValidationOutcome:
+    """Outcome of intent validation."""
+    result: ValidationResult
+    reason: str | None = None
+    risk_score: float = 0.0
+    details: dict[str, Any] = field(default_factory=dict)
+
+    @property
+    def is_valid(self) -> bool:
+        """Check if the validation passed."""
+        return self.result == ValidationResult.VALID
+
+    @property
+    def should_block(self) -> bool:
+        """Check if the request should be blocked."""
+        return self.result == ValidationResult.BLOCKED
+
+
+@dataclass
+class WorkflowState:
+    """State of a user's workflow."""
+    current_state: str = "initial"
+    allowed_transitions: set[str] = field(default_factory=set)
+    history: list[str] = field(default_factory=list)
+    context: dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class AnomalyThresholds:
+    """Thresholds for anomaly detection."""
+    max_calls_per_minute: int = 60
+    max_unique_resources_per_minute: int = 20
+    max_consecutive_failures: int = 5
+    max_data_volume_mb: float = 10.0
+    suspicious_hour_start: int = 2  # 2 AM
+    suspicious_hour_end: int = 5  # 5 AM
+
+
+class IntentValidator:
+    """
+    Validates the intent of tool calls using deterministic rules.
+
+    This validator checks for anomalous patterns that might indicate
+    malicious use, without relying on LLM-based analysis.
+
+    Features:
+    - Workflow state validation
+    - Parameter consistency checking
+    - Anomaly detection (mass operations, unusual patterns)
+    - Time-based suspicion (unusual hours)
+    - Resource access pattern analysis
+
+    Example:
+        >>> validator = IntentValidator()
+        >>>
+        >>> # Register workflow
+        >>> validator.register_workflow("document_workflow", {
+        ...     "initial": ["search"],
+        ...     "search": ["view", "search"],
+        ...     "view": ["edit", "download", "search"],
+        ...     "edit": ["save", "view"],
+        ... })
+        >>>
+        >>> # Validate a tool call
+        >>> outcome = validator.validate(
+        ...     user_id="user_123",
+        ...     tool_name="view_document",
+        ...     arguments={"doc_id": "doc_456"},
+        ...     workflow_name="document_workflow",
+        ... )
+    """
+
+    def __init__(
+        self,
+        thresholds: AnomalyThresholds | None = None,
+    ) -> None:
+        """
+        Initialize the intent validator.
+
+        Args:
+            thresholds: Custom anomaly detection thresholds.
+        """
+        self.thresholds = thresholds or AnomalyThresholds()
+
+        # Workflows: workflow_name -> {state -> allowed_next_states}
+        self._workflows: dict[str, dict[str, set[str]]] = {}
+
+        # User workflow states: user_id -> workflow_name -> WorkflowState
+        self._user_states: dict[str, dict[str, WorkflowState]] = defaultdict(dict)
+
+        # Call history for anomaly detection: user_id -> list of (timestamp, tool_name, arguments)
+        self._call_history: dict[str, list[tuple[float, str, dict[str, Any]]]] = defaultdict(list)
+
+        # Failure tracking
+        self._failure_counts: dict[str, int] = defaultdict(int)
+
+        # Custom validators
+        self._validators: list[Callable[[str, str, dict[str, Any]], ValidationOutcome | None]] = []
+
+        self._lock = threading.RLock()
+
+    def register_workflow(
+        self,
+        workflow_name: str,
+        transitions: dict[str, list[str]],
+    ) -> None:
+        """
+        Register a workflow state machine.
+
+        Args:
+            workflow_name: Name of the workflow.
+            transitions: Dictionary mapping states to allowed next states.
+
+        Example:
+            >>> validator.register_workflow("order_flow", {
+            ...     "initial": ["browse"],
+            ...     "browse": ["add_to_cart", "browse"],
+            ...     "add_to_cart": ["checkout", "browse", "remove_from_cart"],
+            ...     "checkout": ["pay", "cancel"],
+            ...     "pay": ["complete"],
+            ... })
+        """
+        with self._lock:
+            self._workflows[workflow_name] = {
+                state: set(next_states)
+                for state, next_states in transitions.items()
+            }
+            logger.debug(f"Registered workflow: {workflow_name}")
+
+    def register_validator(
+        self,
+        validator: Callable[[str, str, dict[str, Any]], ValidationOutcome | None],
+    ) -> None:
+        """
+        Register a custom validation function.
+
+        The function should return a ValidationOutcome if it has a decision,
+        or None to defer to other validators.
+
+        Args:
+            validator: Function(user_id, tool_name, arguments) -> ValidationOutcome | None
+        """
+        with self._lock:
+            self._validators.append(validator)
+
+    def validate(
+        self,
+        user_id: str,
+        tool_name: str,
+        arguments: dict[str, Any],
+        workflow_name: str | None = None,
+        tool_to_state: Callable[[str], str] | None = None,
+    ) -> ValidationOutcome:
+        """
+        Validate a tool call intent.
+
+        Args:
+            user_id: The user's ID.
+            tool_name: The tool being called.
+            arguments: The tool arguments.
+            workflow_name: Optional workflow to validate against.
+            tool_to_state: Optional function to map tool name to workflow state.
+
+        Returns:
+            ValidationOutcome with the validation result.
+        """
+        with self._lock:
+            # Run custom validators first
+            for validator in self._validators:
+                try:
+                    outcome = validator(user_id, tool_name, arguments)
+                    if outcome is not None:
+                        return outcome
+                except Exception as e:
+                    logger.error(f"Custom validator failed: {e}")
+
+            # Record call for history
+            self._record_call(user_id, tool_name, arguments)
+
+            # Run built-in checks
+            outcomes: list[ValidationOutcome] = []
+
+            # Workflow validation
+            if workflow_name:
+                state_name = tool_to_state(tool_name) if tool_to_state else tool_name
+                workflow_outcome = self._validate_workflow(
+                    user_id, workflow_name, state_name
+                )
+                outcomes.append(workflow_outcome)
+
+            # Anomaly detection
+            anomaly_outcome = self._detect_anomalies(user_id, tool_name, arguments)
+            outcomes.append(anomaly_outcome)
+
+            # Parameter consistency
+            consistency_outcome = self._check_parameter_consistency(
+                tool_name, arguments
+            )
+            outcomes.append(consistency_outcome)
+
+            # Combine outcomes
+            return self._combine_outcomes(outcomes)
+
+    def _record_call(
+        self,
+        user_id: str,
+        tool_name: str,
+        arguments: dict[str, Any],
+    ) -> None:
+        """Record a tool call for history analysis."""
+        now = time.time()
+        self._call_history[user_id].append((now, tool_name, arguments))
+
+        # Cleanup old entries (keep last hour)
+        cutoff = now - 3600
+        self._call_history[user_id] = [
+            entry for entry in self._call_history[user_id]
+            if entry[0] > cutoff
+        ]
+
+    def _validate_workflow(
+        self,
+        user_id: str,
+        workflow_name: str,
+        state_name: str,
+    ) -> ValidationOutcome:
+        """Validate against workflow state machine."""
+        if workflow_name not in self._workflows:
+            return ValidationOutcome(
+                result=ValidationResult.VALID,
+                reason="Unknown workflow, skipping validation",
+            )
+
+        workflow = self._workflows[workflow_name]
+
+        # Get or create user's workflow state
+        if workflow_name not in self._user_states[user_id]:
+            self._user_states[user_id][workflow_name] = WorkflowState(
+                current_state="initial",
+                allowed_transitions=workflow.get("initial", set()),
+            )
+
+        user_state = self._user_states[user_id][workflow_name]
+
+        # Check if transition is allowed
+        not_allowed = state_name not in user_state.allowed_transitions
+        not_initial = user_state.current_state != "initial"
+        if not_allowed and not_initial:
+            return ValidationOutcome(
+                result=ValidationResult.SUSPICIOUS,
+                reason=(
+                    f"Unexpected workflow transition: "
+                    f"{user_state.current_state} -> {state_name}"
+                ),
+                risk_score=0.5,
+                details={
+                    "current_state": user_state.current_state,
+                    "attempted_state": state_name,
+                    "allowed": list(user_state.allowed_transitions),
+                },
+            )
+
+        # Update state
+        user_state.current_state = state_name
+        user_state.allowed_transitions = workflow.get(state_name, set())
+        user_state.history.append(state_name)
+
+        return ValidationOutcome(
+            result=ValidationResult.VALID,
+            reason="Workflow transition valid",
+        )
+
+    def _detect_anomalies(
+        self,
+        user_id: str,
+        tool_name: str,
+        arguments: dict[str, Any],
+    ) -> ValidationOutcome:
+        """Detect anomalous patterns in tool usage."""
+        now = time.time()
+        minute_ago = now - 60
+
+        # Get recent calls
+        recent_calls = [
+            entry for entry in self._call_history[user_id]
+            if entry[0] > minute_ago
+        ]
+
+        # Check call rate
+        if len(recent_calls) > self.thresholds.max_calls_per_minute:
+            return ValidationOutcome(
+                result=ValidationResult.SUSPICIOUS,
+                reason=f"High call rate: {len(recent_calls)} calls in last minute",
+                risk_score=0.7,
+                details={"calls_per_minute": len(recent_calls)},
+            )
+
+        # Check unique resources accessed
+        unique_resources: set[str] = set()
+        for _, _, args in recent_calls:
+            for key, value in args.items():
+                if key.endswith("_id") and isinstance(value, str):
+                    unique_resources.add(f"{key}:{value}")
+
+        if len(unique_resources) > self.thresholds.max_unique_resources_per_minute:
+            return ValidationOutcome(
+                result=ValidationResult.SUSPICIOUS,
+                reason=f"Mass resource access: {len(unique_resources)} unique resources",
+                risk_score=0.8,
+                details={"unique_resources": len(unique_resources)},
+            )
+
+        # Check for unusual hours (local time check would require timezone)
+        hour = time.localtime(now).tm_hour
+        if self.thresholds.suspicious_hour_start <= hour < self.thresholds.suspicious_hour_end:
+            return ValidationOutcome(
+                result=ValidationResult.SUSPICIOUS,
+                reason=f"Access during unusual hours ({hour}:00)",
+                risk_score=0.3,
+                details={"hour": hour},
+            )
+
+        # Check for mass operations (same tool, different IDs)
+        same_tool_calls = [c for c in recent_calls if c[1] == tool_name]
+        if len(same_tool_calls) > 10:
+            different_ids = len({
+                str(c[2].get("id") or c[2].get("document_id") or c[2].get("user_id"))
+                for c in same_tool_calls
+            })
+            if different_ids > 5:
+                return ValidationOutcome(
+                    result=ValidationResult.SUSPICIOUS,
+                    reason=(
+                        f"Mass operation detected: {tool_name} "
+                        f"on {different_ids} different resources"
+                    ),
+                    risk_score=0.6,
+                    details={
+                        "tool": tool_name,
+                        "call_count": len(same_tool_calls),
+                        "unique_targets": different_ids,
+                    },
+                )
+
+        return ValidationOutcome(result=ValidationResult.VALID)
+
+    def _check_parameter_consistency(
+        self,
+        tool_name: str,
+        arguments: dict[str, Any],
+    ) -> ValidationOutcome:
+        """Check for inconsistent or suspicious parameters."""
+        # Check for suspiciously long strings (potential injection)
+        for key, value in arguments.items():
+            if isinstance(value, str) and len(value) > 10000:
+                return ValidationOutcome(
+                    result=ValidationResult.SUSPICIOUS,
+                    reason=f"Unusually long parameter: {key} ({len(value)} chars)",
+                    risk_score=0.4,
+                    details={"parameter": key, "length": len(value)},
+                )
+
+        # Check for null bytes (potential injection)
+        for key, value in arguments.items():
+            if isinstance(value, str) and "\x00" in value:
+                return ValidationOutcome(
+                    result=ValidationResult.BLOCKED,
+                    reason=f"Null byte in parameter: {key}",
+                    risk_score=1.0,
+                    details={"parameter": key},
+                )
+
+        # Check for excessive nesting in dicts/lists
+        def check_depth(obj: Any, depth: int = 0) -> int:
+            if depth > 10:
+                return depth
+            if isinstance(obj, dict):
+                return max(
+                    (check_depth(v, depth + 1) for v in obj.values()),
+                    default=depth,
+                )
+            elif isinstance(obj, list):
+                return max(
+                    (check_depth(v, depth + 1) for v in obj),
+                    default=depth,
+                )
+            return depth
+
+        max_depth = check_depth(arguments)
+        if max_depth > 10:
+            return ValidationOutcome(
+                result=ValidationResult.SUSPICIOUS,
+                reason=f"Deeply nested parameters (depth: {max_depth})",
+                risk_score=0.5,
+                details={"max_depth": max_depth},
+            )
+
+        return ValidationOutcome(result=ValidationResult.VALID)
+
+    def _combine_outcomes(
+        self,
+        outcomes: list[ValidationOutcome],
+    ) -> ValidationOutcome:
+        """Combine multiple validation outcomes."""
+        # Any BLOCKED result blocks
+        blocked = [o for o in outcomes if o.result == ValidationResult.BLOCKED]
+        if blocked:
+            return blocked[0]
+
+        # Aggregate suspicious results
+        suspicious = [o for o in outcomes if o.result == ValidationResult.SUSPICIOUS]
+        if suspicious:
+            total_risk = sum(o.risk_score for o in suspicious) / len(suspicious)
+
+            # If combined risk is high enough, block
+            if total_risk > 0.8:
+                return ValidationOutcome(
+                    result=ValidationResult.BLOCKED,
+                    reason="Multiple suspicious indicators",
+                    risk_score=total_risk,
+                    details={"indicators": [o.reason for o in suspicious]},
+                )
+
+            # Return the highest risk suspicious result
+            return max(suspicious, key=lambda o: o.risk_score)
+
+        return ValidationOutcome(result=ValidationResult.VALID)
+
+    def record_failure(self, user_id: str) -> None:
+        """Record a tool call failure for the user."""
+        with self._lock:
+            self._failure_counts[user_id] += 1
+
+    def record_success(self, user_id: str) -> None:
+        """Record a tool call success (resets failure count)."""
+        with self._lock:
+            self._failure_counts[user_id] = 0
+
+    def get_failure_count(self, user_id: str) -> int:
+        """Get consecutive failure count for a user."""
+        with self._lock:
+            return self._failure_counts.get(user_id, 0)
+
+    def reset_user_state(
+        self,
+        user_id: str,
+        workflow_name: str | None = None,
+    ) -> None:
+        """Reset a user's workflow state."""
+        with self._lock:
+            if workflow_name:
+                if workflow_name in self._user_states.get(user_id, {}):
+                    del self._user_states[user_id][workflow_name]
+            else:
+                self._user_states.pop(user_id, None)
+                self._call_history.pop(user_id, None)
+                self._failure_counts.pop(user_id, None)
+
+    def get_user_state(
+        self,
+        user_id: str,
+        workflow_name: str,
+    ) -> WorkflowState | None:
+        """Get a user's current workflow state."""
+        with self._lock:
+            return self._user_states.get(user_id, {}).get(workflow_name)
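For orientation, the sketch below (not part of the published wheel) drives the IntentValidator from this diff the way its own docstrings describe. The import path follows the file layout in the listing above; the `block_delete_all` helper and the threshold override are illustrative assumptions.

```python
# Illustrative sketch only -- not shipped in proxilion 0.0.1.
# Assumes the module layout shown in the listing above.
from __future__ import annotations

from proxilion.security.intent_validator import (
    AnomalyThresholds,
    IntentValidator,
    ValidationOutcome,
    ValidationResult,
)

# Tighten the call-rate threshold from its default of 60 calls per minute.
validator = IntentValidator(thresholds=AnomalyThresholds(max_calls_per_minute=30))

# Workflow state machine; tool names double as state names here,
# so no tool_to_state mapping is passed to validate().
validator.register_workflow("document_workflow", {
    "initial": ["search"],
    "search": ["view", "search"],
    "view": ["edit", "download", "search"],
    "edit": ["save", "view"],
})

# Custom validator (hypothetical rule): block a "delete_all" tool outright,
# return None for everything else so the built-in checks still run.
def block_delete_all(user_id: str, tool_name: str, arguments: dict) -> ValidationOutcome | None:
    if tool_name == "delete_all":
        return ValidationOutcome(
            result=ValidationResult.BLOCKED,
            reason="delete_all is never permitted",
            risk_score=1.0,
        )
    return None

validator.register_validator(block_delete_all)

# Validate a tool call; "search" is an allowed first transition from "initial".
outcome = validator.validate(
    user_id="user_123",
    tool_name="search",
    arguments={"query": "quarterly report"},
    workflow_name="document_workflow",
)
print(outcome.result, outcome.is_valid, outcome.should_block)
```

As `_combine_outcomes` shows, any BLOCKED check wins outright; several SUSPICIOUS checks whose average risk_score exceeds 0.8 are escalated to BLOCKED, and otherwise the highest-risk suspicious outcome is returned.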