jaf-py 2.4.5__py3-none-any.whl → 2.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/core/engine.py +169 -65
- jaf/core/guardrails.py +666 -0
- jaf/core/types.py +83 -1
- jaf/providers/__init__.py +2 -1
- jaf/providers/model.py +363 -8
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.7.dist-info}/METADATA +2 -1
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.7.dist-info}/RECORD +11 -10
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.7.dist-info}/WHEEL +0 -0
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.7.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.7.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.7.dist-info}/top_level.txt +0 -0
jaf/core/guardrails.py
ADDED
|
@@ -0,0 +1,666 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Advanced guardrails implementation for JAF framework.
|
|
3
|
+
|
|
4
|
+
This module provides LLM-based guardrails with caching, circuit breaking,
|
|
5
|
+
and execution strategies for input validation and output filtering.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
import hashlib
import inspect
import json
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

from .types import (
    Agent,
    RunConfig,
    RunState,
    ValidationResult,
    ValidValidationResult,
    InvalidValidationResult,
    Guardrail,
    AdvancedGuardrailsConfig,
    validate_guardrails_config,
    json_parse_llm_output,
    get_text_content,
    Message,
    ContentRole,
    create_run_id,
    create_trace_id,
    GuardrailEvent,
    GuardrailEventData,
    GuardrailViolationEvent,
    GuardrailViolationEventData
)
|
|
36
|
+
|
|
37
|
+
# Maximum content size (in characters) accepted for LLM evaluation. The short
# limit applies when the guardrail timeout is below 10 seconds, the long limit
# otherwise.
SHORT_TIMEOUT_MAX_CONTENT = 10_000
LONG_TIMEOUT_MAX_CONTENT = 50_000

# Idle time, in milliseconds, after which a circuit breaker may be discarded.
CIRCUIT_BREAKER_CLEANUP_MAX_AGE = 10 * 60 * 1_000  # 10 minutes

# Timeouts, all expressed in milliseconds.
DEFAULT_FAST_MODEL_TIMEOUT_MS = 10_000
DEFAULT_TIMEOUT_MS = 5_000
GUARDRAIL_TIMEOUT_MS = 10_000
OUTPUT_GUARDRAIL_TIMEOUT_MS = 15_000
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class GuardrailCircuitBreaker:
    """Trip after repeated guardrail failures and recover after a quiet period.

    The breaker counts failures since the last success. Once the count reaches
    ``max_failures`` it reports itself as open until ``reset_time_ms``
    milliseconds have passed since the most recent failure.
    """

    def __init__(self, max_failures: int = 5, reset_time_ms: int = 60000):
        self.max_failures = max_failures
        self.reset_time_ms = reset_time_ms
        self.failures = 0
        # Wall-clock time of the latest failure in milliseconds; 0 means "never".
        self.last_failure_time = 0

    def is_open(self) -> bool:
        """Return True while requests should be blocked.

        Side effect: once the reset window has elapsed the failure count is
        cleared, which closes the breaker again.
        """
        if self.failures >= self.max_failures:
            elapsed_ms = (time.time() * 1000) - self.last_failure_time
            if elapsed_ms <= self.reset_time_ms:
                return True
            # Quiet for long enough: forget the earlier failures.
            self.failures = 0
        return False

    def record_failure(self) -> None:
        """Count one more failure and remember when it happened."""
        self.failures += 1
        self.last_failure_time = time.time() * 1000

    def record_success(self) -> None:
        """Clear the failure count after a successful call."""
        self.failures = 0

    def should_be_cleaned_up(self, max_age: int) -> bool:
        """Return True when this breaker is stale enough to be discarded.

        A breaker qualifies when it has failed at least once, its last failure
        is more than ``max_age`` milliseconds old, and it is not open.
        """
        if self.last_failure_time <= 0:
            return False
        idle_ms = (time.time() * 1000) - self.last_failure_time
        if idle_ms <= max_age:
            return False
        return not self.is_open()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
class CacheEntry:
    """One cached guardrail verdict plus the bookkeeping used for expiry and eviction."""

    # The verdict that was stored for the cached content.
    result: ValidationResult
    # Time of the last write or hit, in milliseconds since the epoch.
    timestamp: float
    # Number of times the entry has been stored or served.
    hit_count: int = field(default=1)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class GuardrailCache:
    """Bounded in-memory cache of guardrail verdicts with a sliding TTL.

    Entries are keyed by stage, model name, rule prompt and the complete
    content. When the cache is full, expired entries are dropped first; if it
    is still full, the entry with the lowest ``hit_count / (1 + age_hours)``
    score is evicted (a frequency/recency score, not strict LRU).
    """

    def __init__(self, max_size: int = 1000, ttl_ms: int = 300000):
        self.cache: Dict[str, CacheEntry] = {}
        self.max_size = max_size
        self.ttl_ms = ttl_ms

    def _create_key(self, stage: str, rule_prompt: str, content: str, model_name: str) -> str:
        """Build the cache key for one (stage, model, rule, content) combination.

        The full rule and content are digested with SHA-256. Hashing only a
        prefix would let two different messages that share the same opening
        text and length reuse each other's verdict.
        """
        # 'surrogatepass' keeps strings with lone surrogates from raising here.
        content_hash = hashlib.sha256(content.encode('utf-8', 'surrogatepass')).hexdigest()
        rule_hash = hashlib.sha256(rule_prompt.encode('utf-8', 'surrogatepass')).hexdigest()
        return f"guardrail_{stage}_{model_name}_{rule_hash}_{content_hash}_{len(content)}"

    def _hash_string(self, s: str) -> str:
        """Return a simple 32-bit rolling hash of ``s`` as a decimal string.

        Kept for backward compatibility; cache keys no longer rely on it
        because 32 bits collide too easily.
        """
        hash_val = 0
        for char in s:
            # hash * 31 + code point, truncated to 32 bits (never negative).
            hash_val = (((hash_val << 5) - hash_val) + ord(char)) & 0xFFFFFFFF
        return str(hash_val)

    def _is_expired(self, entry: CacheEntry) -> bool:
        """Return True when ``entry`` was last touched more than ``ttl_ms`` ago."""
        return (time.time() * 1000) - entry.timestamp > self.ttl_ms

    def _evict_lru(self) -> None:
        """Make room for one new entry if the cache is at capacity.

        Expired entries are removed first. Only if that frees nothing is the
        lowest-scoring live entry evicted.
        """
        if len(self.cache) < self.max_size:
            return

        expired_keys = [key for key, entry in self.cache.items() if self._is_expired(entry)]
        for key in expired_keys:
            del self.cache[key]
        if len(self.cache) < self.max_size or not self.cache:
            return

        now = time.time() * 1000

        def retention_score(key: str) -> float:
            entry = self.cache[key]
            # Clamp so a clock that moved backwards cannot produce a zero divisor.
            age_hours = max(now - entry.timestamp, 0) / (1000 * 60 * 60)
            return entry.hit_count / (1 + age_hours)

        del self.cache[min(self.cache, key=retention_score)]

    def get(self, stage: str, rule_prompt: str, content: str, model_name: str) -> Optional[ValidationResult]:
        """Return the cached verdict, or None on a miss or an expired entry.

        A hit increments the entry's hit count and refreshes its timestamp, so
        the TTL is measured from the most recent access.
        """
        key = self._create_key(stage, rule_prompt, content, model_name)
        entry = self.cache.get(key)
        if entry is None:
            return None
        if self._is_expired(entry):
            del self.cache[key]
            return None

        entry.hit_count += 1
        entry.timestamp = time.time() * 1000
        return entry.result

    def set(self, stage: str, rule_prompt: str, content: str, model_name: str, result: ValidationResult) -> None:
        """Store ``result`` for the given combination, evicting only when needed."""
        key = self._create_key(stage, rule_prompt, content, model_name)

        # Overwriting an existing key does not grow the cache, so nothing has
        # to be evicted in that case.
        if key not in self.cache:
            self._evict_lru()

        self.cache[key] = CacheEntry(
            result=result,
            timestamp=time.time() * 1000,
            hit_count=1
        )

    def clear(self) -> None:
        """Remove every cached entry."""
        self.cache.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Return the current entry count and the configured capacity."""
        return {
            'size': len(self.cache),
            'max_size': self.max_size
        }
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# Module-level shared state: one circuit breaker per "<stage>-<model>" key and a
# single verdict cache used by every LLM guardrail created in this module.
_circuit_breakers: Dict[str, GuardrailCircuitBreaker] = {}
_guardrail_cache = GuardrailCache()


def _get_circuit_breaker(stage: str, model_name: str) -> GuardrailCircuitBreaker:
    """Return the breaker for this stage/model pair, creating it on first use."""
    registry_key = f"{stage}-{model_name}"
    breaker = _circuit_breakers.get(registry_key)
    if breaker is None:
        breaker = GuardrailCircuitBreaker()
        _circuit_breakers[registry_key] = breaker
    return breaker
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
async def _with_timeout(awaitable, timeout_ms: int, error_message: str):
    """Await ``awaitable`` for at most ``timeout_ms`` milliseconds.

    Args:
        awaitable: A coroutine, task or future to wait for. A value that is
            not awaitable is treated as an already-computed result and is
            returned unchanged, so callers may pass the return value of a
            synchronous guardrail.
        timeout_ms: Time limit in milliseconds.
        error_message: Text appended to the ``"Timeout: "`` prefix of the
            raised error.

    Returns:
        The result of the awaitable, or the value itself if it is not awaitable.

    Raises:
        TimeoutError: If the limit elapses first. The message always starts
            with ``"Timeout: "``; callers in this module match on that prefix.
    """
    if not inspect.isawaitable(awaitable):
        return awaitable
    try:
        return await asyncio.wait_for(awaitable, timeout=timeout_ms / 1000)
    except asyncio.TimeoutError as exc:
        raise TimeoutError(f"Timeout: {error_message}") from exc
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _fail_safe_result(fail_safe: str, error_message: str) -> ValidationResult:
    """Return the verdict dictated by ``fail_safe`` when no real verdict is available."""
    if fail_safe == 'allow':
        return ValidValidationResult()
    return InvalidValidationResult(error_message=error_message)


def _extract_response_content(response: Any) -> Any:
    """Return the message content of an object- or dict-shaped completion, or None."""
    message = getattr(response, 'message', None)
    if not message and isinstance(response, dict):
        message = response.get('message')
    if not message:
        return None
    if isinstance(message, dict):
        return message.get('content')
    return getattr(message, 'content', None)


def _is_allowed_verdict(value: Any) -> bool:
    """Interpret the model's ``allowed`` field; textual negatives count as False."""
    if isinstance(value, str):
        # bool("false") is True, so strings need explicit handling.
        return value.strip().lower() not in ('', 'false', 'no', '0')
    return bool(value)


def _build_guardrail_prompt(stage: str, rule_prompt: str, content_str: str) -> str:
    """Build the evaluation prompt sent to the guardrail model."""
    # NOTE(review): only the first 2000 characters are sent for evaluation even
    # though much longer content is accepted by the caller - confirm that
    # leaving the remainder unchecked is intentional.
    sanitized_content = (content_str
                         .replace('"""', '[TRIPLE_QUOTE]')
                         .replace('\n', ' ')[:2000])
    subject = 'user message' if stage == 'input' else 'assistant output'

    return f"""You are a guardrail validator for {stage}.
Rules:
{rule_prompt}

Decide if the {stage} {subject} complies with the rules.
Return a JSON object with keys: {{"allowed": boolean, "reason": string}}. Do not include extra text.
{stage.capitalize()} {subject}:
\"\"\"
{sanitized_content}
\"\"\""""


async def _create_llm_guardrail(
    config: RunConfig,
    stage: str,
    rule_prompt: str,
    fast_model: Optional[str] = None,
    fail_safe: str = 'allow',
    timeout_ms: int = 30000
) -> Guardrail:
    """Create an LLM-based guardrail function.

    Args:
        config: Run configuration supplying the model provider, the agent
            registry and the default fast model.
        stage: ``'input'`` or ``'output'``; selects the prompt wording and
            the cache / circuit-breaker namespace.
        rule_prompt: The rules the content must comply with.
        fast_model: Model used for evaluation; falls back to
            ``config.default_fast_model``.
        fail_safe: ``'allow'`` to pass content whenever no verdict can be
            obtained; any other value rejects it in that situation.
        timeout_ms: Time limit for the model call, in milliseconds.

    Returns:
        An async callable that takes the content and returns a
        ``ValidationResult``. It never raises for evaluation problems; those
        are converted into the ``fail_safe`` verdict.
    """

    async def guardrail_func(content: Any) -> ValidationResult:
        content_str = content if isinstance(content, str) else str(content)

        model_to_use = fast_model or config.default_fast_model
        if not model_to_use:
            print(f"[JAF:GUARDRAILS] No fast model available for LLM guardrail evaluation, using failSafe: {fail_safe}")
            return _fail_safe_result(fail_safe, 'No model available for guardrail evaluation')

        # Check cache first
        cached_result = _guardrail_cache.get(stage, rule_prompt, content_str, model_to_use)
        if cached_result is not None:
            print(f"[JAF:GUARDRAILS] Cache hit for {stage} guardrail")
            return cached_result

        # Check circuit breaker
        circuit_breaker = _get_circuit_breaker(stage, model_to_use)
        if circuit_breaker.is_open():
            print(f"[JAF:GUARDRAILS] Circuit breaker open for {stage} guardrail on model {model_to_use}, using failSafe: {fail_safe}")
            return _fail_safe_result(fail_safe, 'Circuit breaker open - too many recent failures')

        # Validate content
        if not content_str:
            print(f"[JAF:GUARDRAILS] Invalid content provided to {stage} guardrail")
            return _fail_safe_result(fail_safe, 'Invalid content provided to guardrail')

        # Check content length
        max_content_length = SHORT_TIMEOUT_MAX_CONTENT if timeout_ms < 10000 else LONG_TIMEOUT_MAX_CONTENT
        if len(content_str) > max_content_length:
            print(f"[JAF:GUARDRAILS] Content too large for {stage} guardrail ({len(content_str)} chars, max: {max_content_length})")
            return _fail_safe_result(
                fail_safe,
                f'Content too large for guardrail evaluation ({len(content_str)} > {max_content_length} chars)'
            )

        eval_prompt = _build_guardrail_prompt(stage, rule_prompt, content_str)

        try:
            # Create temporary state for guardrail evaluation
            temp_state = RunState(
                run_id=create_run_id('guardrail-eval'),
                trace_id=create_trace_id('guardrail-eval'),
                messages=[Message(role=ContentRole.USER, content=eval_prompt)],
                current_agent_name='guardrail-evaluator',
                context={},
                turn_count=0
            )

            # Create evaluation agent
            def eval_instructions(state: RunState) -> str:
                return 'You are a guardrail validator. Return only valid JSON.'

            eval_agent = Agent(
                name='guardrail-evaluator',
                instructions=eval_instructions,
                # NOTE(review): this tests for an attribute named 'ModelConfig' on
                # the run config - confirm that is the intended check. The model
                # is also passed through model_override below.
                model_config={'name': model_to_use} if hasattr(config, 'ModelConfig') else None
            )

            # Create guardrail config (no guardrails to avoid recursion)
            guardrail_config = RunConfig(
                agent_registry=config.agent_registry,
                model_provider=config.model_provider,
                max_turns=1,
                default_fast_model=config.default_fast_model,
                model_override=model_to_use,
                initial_input_guardrails=None,
                final_output_guardrails=None,
                on_event=None
            )

            # Execute with timeout
            completion_promise = config.model_provider.get_completion(temp_state, eval_agent, guardrail_config)
            response = await _with_timeout(
                completion_promise,
                timeout_ms,
                f"{stage} guardrail evaluation timed out after {timeout_ms}ms"
            )

            response_content = _extract_response_content(response)
            parsed = json_parse_llm_output(response_content) if response_content else None

            # The provider answered, so the call itself is healthy even if the
            # answer turns out to be unusable.
            circuit_breaker.record_success()

            if not isinstance(parsed, dict) or 'allowed' not in parsed:
                # No verdict was produced: defer to fail_safe instead of
                # silently allowing, and do not cache the non-answer.
                print(f"[JAF:GUARDRAILS] {stage} guardrail returned no usable verdict, using failSafe: {fail_safe}")
                return _fail_safe_result(fail_safe, 'Guardrail returned no usable verdict')

            if _is_allowed_verdict(parsed['allowed']):
                result = ValidValidationResult()
            else:
                reason = str(parsed.get('reason') or 'Guardrail violation')
                result = InvalidValidationResult(error_message=reason)

            _guardrail_cache.set(stage, rule_prompt, content_str, model_to_use, result)
            return result

        except Exception as e:
            circuit_breaker.record_failure()

            error_message = str(e)
            is_timeout = (isinstance(e, (TimeoutError, asyncio.TimeoutError))
                          or 'Timeout' in error_message)

            log_message = f"[JAF:GUARDRAILS] {stage} guardrail evaluation failed"
            if is_timeout:
                print(f"{log_message} due to timeout ({timeout_ms}ms), using failSafe: {fail_safe}")
            else:
                print(f"{log_message}, using failSafe: {fail_safe} - {error_message}")

            return _fail_safe_result(fail_safe, f'Guardrail evaluation failed: {error_message}')

    return guardrail_func
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
async def build_effective_guardrails(
    current_agent: Agent,
    config: RunConfig
) -> Tuple[List[Guardrail], List[Guardrail]]:
    """Combine the run-wide guardrails with the agent's advanced guardrail settings.

    Returns:
        ``(input_guardrails, output_guardrails)``. Each list starts with the
        guardrails from ``config``. The agent's settings may append an LLM
        input check, a citation check and an LLM output check. If building
        the agent-specific guardrails fails, only the guardrails from
        ``config`` are returned.
    """
    input_chain: List[Guardrail] = []
    output_chain: List[Guardrail] = []

    try:
        advanced = current_agent.advanced_config
        guardrails_cfg = validate_guardrails_config(advanced.guardrails if advanced else None)

        fast_model = guardrails_cfg.fast_model or config.default_fast_model
        if not fast_model:
            if guardrails_cfg.input_prompt or guardrails_cfg.output_prompt:
                print('[JAF:GUARDRAILS] No fast model available for LLM guardrails - skipping LLM-based validation')

        summary = {
            'hasInputPrompt': bool(guardrails_cfg.input_prompt),
            'hasOutputPrompt': bool(guardrails_cfg.output_prompt),
            'requireCitations': guardrails_cfg.require_citations,
            'executionMode': guardrails_cfg.execution_mode,
            'failSafe': guardrails_cfg.fail_safe,
            'timeoutMs': guardrails_cfg.timeout_ms,
            'fastModel': fast_model or 'none'
        }
        print('[JAF:GUARDRAILS] Configuration:', summary)

        # The run-wide guardrails always come first.
        input_chain = list(config.initial_input_guardrails or [])
        output_chain = list(config.final_output_guardrails or [])

        # LLM check on the user's input.
        input_rule = guardrails_cfg.input_prompt
        if input_rule and input_rule.strip():
            input_chain.append(await _create_llm_guardrail(
                config,
                'input',
                input_rule,
                fast_model=fast_model,
                fail_safe=guardrails_cfg.fail_safe,
                timeout_ms=guardrails_cfg.timeout_ms,
            ))

        # Require at least one "[n]" style citation in the output.
        if guardrails_cfg.require_citations:
            def citation_guardrail(output: Any) -> ValidationResult:
                def find_text(val: Any) -> str:
                    # Flatten nested lists and dicts into one string.
                    if isinstance(val, str):
                        return val
                    if isinstance(val, list):
                        return ' '.join(map(find_text, val))
                    if isinstance(val, dict):
                        return ' '.join(map(find_text, val.values()))
                    return str(val)

                if re.search(r'\[(\d+)\]', find_text(output)):
                    return ValidValidationResult()
                return InvalidValidationResult(error_message="Missing required [n] citation in output")

            output_chain.append(citation_guardrail)

        # LLM check on the assistant's output.
        output_rule = guardrails_cfg.output_prompt
        if output_rule and output_rule.strip():
            output_chain.append(await _create_llm_guardrail(
                config,
                'output',
                output_rule,
                fast_model=fast_model,
                fail_safe=guardrails_cfg.fail_safe,
                timeout_ms=guardrails_cfg.timeout_ms,
            ))

    except Exception as e:
        print(f'[JAF:GUARDRAILS] Failed to configure advanced guardrails: {e}')
        # Keep only the run-wide guardrails when the agent settings are unusable.
        input_chain = list(config.initial_input_guardrails or [])
        output_chain = list(config.final_output_guardrails or [])

    return input_chain, output_chain
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
async def execute_input_guardrails_sequential(
    input_guardrails: List[Guardrail],
    first_user_message: Message,
    config: RunConfig
) -> ValidationResult:
    """Run the input guardrails one after another, stopping at the first violation.

    Args:
        input_guardrails: Sync or async callables taking the message text.
        first_user_message: Message whose text content is validated.
        config: Run configuration; ``config.on_event`` (if set) receives a
            ``GuardrailViolationEvent`` for every violation.

    Returns:
        The first invalid result, an ``InvalidValidationResult`` when a
        guardrail raises an unexpected error, or ``ValidValidationResult``
        when all guardrails pass. Timeouts and circuit-breaker errors are
        treated as system errors and the guardrail is skipped.
    """
    if not input_guardrails:
        return ValidValidationResult()

    print(f"[JAF:GUARDRAILS] Starting {len(input_guardrails)} input guardrails (sequential)")

    content = get_text_content(first_user_message.content)

    for i, guardrail in enumerate(input_guardrails):
        guardrail_name = f"input-guardrail-{i + 1}"

        try:
            print(f"[JAF:GUARDRAILS] Starting {guardrail_name}")

            timeout_ms = GUARDRAIL_TIMEOUT_MS
            # Decide by the returned value, not by inspecting the callable.
            # Synchronous guardrails return their verdict directly; wrapping
            # that in a timeout would raise TypeError and reject every input.
            result = guardrail(content)
            if inspect.isawaitable(result):
                result = await _with_timeout(
                    result,
                    timeout_ms,
                    f"{guardrail_name} execution timed out after {timeout_ms}ms"
                )

            print(f"[JAF:GUARDRAILS] {guardrail_name} completed: {result}")

            if not result.is_valid:
                error_message = getattr(result, 'error_message', 'Guardrail violation')
                print(f"🚨 {guardrail_name} violation: {error_message}")
                if config.on_event:
                    config.on_event(GuardrailViolationEvent(
                        data=GuardrailViolationEventData(stage='input', reason=error_message)
                    ))
                return result

        except Exception as error:
            error_message = str(error)
            print(f"[JAF:GUARDRAILS] {guardrail_name} failed: {error_message}")

            is_system_error = (
                isinstance(error, (TimeoutError, asyncio.TimeoutError))
                or 'Timeout' in error_message
                or 'Circuit breaker' in error_message
            )

            if is_system_error:
                print(f"[JAF:GUARDRAILS] {guardrail_name} system error, continuing: {error_message}")
                continue
            else:
                if config.on_event:
                    config.on_event(GuardrailViolationEvent(
                        data=GuardrailViolationEventData(stage='input', reason=error_message)
                    ))
                return InvalidValidationResult(error_message=error_message)

    print("✅ All input guardrails passed (sequential).")
    return ValidValidationResult()
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
async def execute_input_guardrails_parallel(
    input_guardrails: List[Guardrail],
    first_user_message: Message,
    config: RunConfig
) -> ValidationResult:
    """Run all input guardrails concurrently and report the first violation found.

    Args:
        input_guardrails: Sync or async callables taking the message text.
        first_user_message: Message whose text content is validated.
        config: Run configuration; ``config.on_event`` (if set) receives a
            ``GuardrailViolationEvent`` for a violation.

    Returns:
        The first invalid result in guardrail order, otherwise
        ``ValidValidationResult``. A guardrail that raises or times out is
        skipped with a logged warning, and an unexpected failure of the whole
        run also yields ``ValidValidationResult`` (this executor fails open).
    """
    if not input_guardrails:
        return ValidValidationResult()

    print(f"[JAF:GUARDRAILS] Starting {len(input_guardrails)} input guardrails")

    content = get_text_content(first_user_message.content)

    async def run_guardrail(guardrail: Guardrail, index: int):
        guardrail_name = f"input-guardrail-{index + 1}"

        try:
            print(f"[JAF:GUARDRAILS] Starting {guardrail_name}")

            timeout_ms = DEFAULT_FAST_MODEL_TIMEOUT_MS if config.default_fast_model else DEFAULT_TIMEOUT_MS

            # Decide by the returned value. A lambda or callable object that
            # wraps an async guardrail is not a coroutine function, yet it
            # still returns a coroutine that has to be awaited.
            result = guardrail(content)
            if inspect.isawaitable(result):
                result = await _with_timeout(result, timeout_ms,
                                             f"{guardrail_name} execution timed out after {timeout_ms}ms")

            print(f"[JAF:GUARDRAILS] {guardrail_name} completed: {result}")
            return {'result': result, 'guardrail_index': index}

        except Exception as error:
            error_message = str(error)
            print(f"[JAF:GUARDRAILS] {guardrail_name} failed: {error_message}")

            return {
                'result': ValidValidationResult(),
                'guardrail_index': index,
                'warning': f"Guardrail {index + 1} failed but was skipped: {error_message}"
            }

    try:
        # Run all guardrails in parallel
        tasks = [run_guardrail(guardrail, i) for i, guardrail in enumerate(input_guardrails)]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        print("[JAF:GUARDRAILS] Input guardrails completed. Checking results...")

        warnings = []

        for i, result in enumerate(results):
            # BaseException also covers a CancelledError returned by gather.
            if isinstance(result, BaseException):
                error_message = str(result)
                print(f"[JAF:GUARDRAILS] Input guardrail {i + 1} promise rejected: {error_message}")
                warnings.append(f"Guardrail {i + 1} failed: {error_message}")
                continue

            if 'warning' in result:
                warnings.append(result['warning'])

            validation_result = result['result']
            if not validation_result.is_valid:
                error_message = getattr(validation_result, 'error_message', 'Guardrail violation')
                print(f"🚨 Input guardrail {result['guardrail_index'] + 1} violation: {error_message}")
                if config.on_event:
                    config.on_event(GuardrailViolationEvent(
                        data=GuardrailViolationEventData(stage='input', reason=error_message)
                    ))
                return validation_result

        if warnings:
            print(f"[JAF:GUARDRAILS] {len(warnings)} guardrail warnings: {warnings}")

        print("✅ All input guardrails passed.")
        return ValidValidationResult()

    except Exception as error:
        print(f"[JAF:GUARDRAILS] Catastrophic failure in input guardrail execution: {error}")
        return ValidValidationResult()  # Fail gracefully
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
async def execute_output_guardrails(
    output_guardrails: List[Guardrail],
    output: Any,
    config: RunConfig
) -> ValidationResult:
    """Run the output guardrails one after another, stopping at the first violation.

    Args:
        output_guardrails: Sync or async callables taking the output value.
        output: The value to validate; passed to each guardrail unchanged.
        config: Run configuration; ``config.on_event`` (if set) receives a
            ``GuardrailViolationEvent`` for every violation.

    Returns:
        The first invalid result, an ``InvalidValidationResult`` when a
        guardrail raises an unexpected error, or ``ValidValidationResult``
        when all guardrails pass. Timeouts and circuit-breaker errors are
        treated as system errors and the output is allowed through.
    """
    if not output_guardrails:
        return ValidValidationResult()

    print(f"[JAF:GUARDRAILS] Checking {len(output_guardrails)} output guardrails")

    for i, guardrail in enumerate(output_guardrails):
        guardrail_name = f"output-guardrail-{i + 1}"

        try:
            timeout_ms = OUTPUT_GUARDRAIL_TIMEOUT_MS

            # Decide by the returned value so async guardrails wrapped in a
            # lambda or callable object are awaited too.
            result = guardrail(output)
            if inspect.isawaitable(result):
                result = await _with_timeout(result, timeout_ms,
                                             f"{guardrail_name} execution timed out after {timeout_ms}ms")

            if not result.is_valid:
                error_message = getattr(result, 'error_message', 'Guardrail violation')
                print(f"🚨 {guardrail_name} violation: {error_message}")
                if config.on_event:
                    config.on_event(GuardrailViolationEvent(
                        data=GuardrailViolationEventData(stage='output', reason=error_message)
                    ))
                return result

            print(f"✅ {guardrail_name} passed")

        except Exception as error:
            error_message = str(error)
            print(f"[JAF:GUARDRAILS] {guardrail_name} failed: {error_message}")

            is_system_error = (
                isinstance(error, (TimeoutError, asyncio.TimeoutError))
                or 'Timeout' in error_message
                or 'Circuit breaker' in error_message
            )

            if is_system_error:
                print(f"[JAF:GUARDRAILS] {guardrail_name} system error, allowing output: {error_message}")
                continue
            else:
                if config.on_event:
                    config.on_event(GuardrailViolationEvent(
                        data=GuardrailViolationEventData(stage='output', reason=error_message)
                    ))
                return InvalidValidationResult(error_message=error_message)

    print("✅ All output guardrails passed")
    return ValidValidationResult()
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def cleanup_circuit_breakers() -> None:
    """Drop every registered circuit breaker that reports itself as stale."""
    # Collect the keys first so the registry is not mutated while iterating it.
    stale_keys = [
        name
        for name, breaker in _circuit_breakers.items()
        if breaker.should_be_cleaned_up(CIRCUIT_BREAKER_CLEANUP_MAX_AGE)
    ]
    for name in stale_keys:
        del _circuit_breakers[name]
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
class GuardrailCacheManager:
    """Static helpers for inspecting and maintaining the module-wide guardrail cache."""

    @staticmethod
    def get_stats() -> Dict[str, Any]:
        """Return the cache's current size and capacity."""
        return _guardrail_cache.get_stats()

    @staticmethod
    def clear() -> None:
        """Remove every cached verdict."""
        _guardrail_cache.clear()

    @staticmethod
    def get_metrics() -> Dict[str, Any]:
        """Return the cache stats plus utilization and the circuit-breaker count."""
        stats = _guardrail_cache.get_stats()
        capacity = stats['max_size']
        return {
            **stats,
            # A zero-capacity cache reports 0% instead of dividing by zero.
            'utilization_percent': (stats['size'] / capacity) * 100 if capacity else 0.0,
            'circuit_breakers_count': len(_circuit_breakers)
        }

    @staticmethod
    def log_stats() -> None:
        """Print the current metrics."""
        metrics = GuardrailCacheManager.get_metrics()
        print('[JAF:GUARDRAILS] Cache stats:', metrics)

    @staticmethod
    def cleanup() -> None:
        """Drop expired cache entries and stale circuit breakers."""
        # Expired verdicts are otherwise removed only when the same key is
        # looked up again, so purge them here as well.
        expired_keys = [
            key
            for key, entry in _guardrail_cache.cache.items()
            if _guardrail_cache._is_expired(entry)
        ]
        for key in expired_keys:
            _guardrail_cache.cache.pop(key, None)

        cleanup_circuit_breakers()


# Export the cache manager
guardrail_cache_manager = GuardrailCacheManager()
|