jaf-py 2.4.5__py3-none-any.whl → 2.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/core/engine.py +169 -65
- jaf/core/guardrails.py +666 -0
- jaf/core/types.py +83 -1
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.6.dist-info}/METADATA +1 -1
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.6.dist-info}/RECORD +9 -8
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.6.dist-info}/WHEEL +0 -0
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.6.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.6.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.4.5.dist-info → jaf_py-2.4.6.dist-info}/top_level.txt +0 -0
jaf/core/engine.py
CHANGED
|
@@ -32,6 +32,8 @@ from .types import (
|
|
|
32
32
|
Interruption,
|
|
33
33
|
GuardrailEvent,
|
|
34
34
|
GuardrailEventData,
|
|
35
|
+
GuardrailViolationEvent,
|
|
36
|
+
GuardrailViolationEventData,
|
|
35
37
|
MemoryEvent,
|
|
36
38
|
MemoryEventData,
|
|
37
39
|
OutputParseEvent,
|
|
@@ -61,6 +63,15 @@ from .types import (
|
|
|
61
63
|
ToolCallFunction,
|
|
62
64
|
ToolCallStartEvent,
|
|
63
65
|
ToolCallStartEventData,
|
|
66
|
+
Guardrail,
|
|
67
|
+
ValidValidationResult,
|
|
68
|
+
InvalidValidationResult,
|
|
69
|
+
)
|
|
70
|
+
from .guardrails import (
|
|
71
|
+
build_effective_guardrails,
|
|
72
|
+
execute_input_guardrails_sequential,
|
|
73
|
+
execute_input_guardrails_parallel,
|
|
74
|
+
execute_output_guardrails,
|
|
64
75
|
)
|
|
65
76
|
|
|
66
77
|
|
|
@@ -399,36 +410,6 @@ async def _run_internal(
|
|
|
399
410
|
if resumed:
|
|
400
411
|
return resumed
|
|
401
412
|
|
|
402
|
-
# Check initial input guardrails on first turn
|
|
403
|
-
if state.turn_count == 0:
|
|
404
|
-
first_user_message = next((m for m in state.messages if m.role == ContentRole.USER or m.role == 'user'), None)
|
|
405
|
-
if first_user_message and config.initial_input_guardrails:
|
|
406
|
-
for guardrail in config.initial_input_guardrails:
|
|
407
|
-
if config.on_event:
|
|
408
|
-
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
409
|
-
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
410
|
-
content=get_text_content(first_user_message.content)
|
|
411
|
-
)))
|
|
412
|
-
if asyncio.iscoroutinefunction(guardrail):
|
|
413
|
-
result = await guardrail(get_text_content(first_user_message.content))
|
|
414
|
-
else:
|
|
415
|
-
result = guardrail(get_text_content(first_user_message.content))
|
|
416
|
-
|
|
417
|
-
if not result.is_valid:
|
|
418
|
-
if config.on_event:
|
|
419
|
-
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
420
|
-
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
421
|
-
content=get_text_content(first_user_message.content),
|
|
422
|
-
is_valid=False,
|
|
423
|
-
error_message=result.error_message
|
|
424
|
-
)))
|
|
425
|
-
return RunResult(
|
|
426
|
-
final_state=state,
|
|
427
|
-
outcome=ErrorOutcome(error=InputGuardrailTripwire(
|
|
428
|
-
reason=result.error_message or "Input guardrail failed"
|
|
429
|
-
))
|
|
430
|
-
)
|
|
431
|
-
|
|
432
413
|
# Check max turns
|
|
433
414
|
max_turns = config.max_turns or 50
|
|
434
415
|
if state.turn_count >= max_turns:
|
|
@@ -445,6 +426,105 @@ async def _run_internal(
|
|
|
445
426
|
outcome=ErrorOutcome(error=AgentNotFound(agent_name=state.current_agent_name))
|
|
446
427
|
)
|
|
447
428
|
|
|
429
|
+
# Determine if agent has advanced guardrails configuration
|
|
430
|
+
has_advanced_guardrails = bool(
|
|
431
|
+
current_agent.advanced_config and
|
|
432
|
+
current_agent.advanced_config.guardrails and
|
|
433
|
+
(current_agent.advanced_config.guardrails.input_prompt or
|
|
434
|
+
current_agent.advanced_config.guardrails.output_prompt or
|
|
435
|
+
current_agent.advanced_config.guardrails.require_citations)
|
|
436
|
+
)
|
|
437
|
+
|
|
438
|
+
print('[JAF:ENGINE] Debug guardrails setup:', {
|
|
439
|
+
'agent_name': current_agent.name,
|
|
440
|
+
'has_advanced_config': bool(current_agent.advanced_config),
|
|
441
|
+
'has_advanced_guardrails': has_advanced_guardrails,
|
|
442
|
+
'initial_input_guardrails': len(config.initial_input_guardrails or []),
|
|
443
|
+
'final_output_guardrails': len(config.final_output_guardrails or [])
|
|
444
|
+
})
|
|
445
|
+
|
|
446
|
+
# Build effective guardrails
|
|
447
|
+
effective_input_guardrails: List[Guardrail] = []
|
|
448
|
+
effective_output_guardrails: List[Guardrail] = []
|
|
449
|
+
|
|
450
|
+
if has_advanced_guardrails:
|
|
451
|
+
result = await build_effective_guardrails(current_agent, config)
|
|
452
|
+
effective_input_guardrails, effective_output_guardrails = result
|
|
453
|
+
else:
|
|
454
|
+
effective_input_guardrails = list(config.initial_input_guardrails or [])
|
|
455
|
+
effective_output_guardrails = list(config.final_output_guardrails or [])
|
|
456
|
+
|
|
457
|
+
# Execute input guardrails on first turn
|
|
458
|
+
input_guardrails_to_run = (effective_input_guardrails
|
|
459
|
+
if state.turn_count == 0 and effective_input_guardrails
|
|
460
|
+
else [])
|
|
461
|
+
|
|
462
|
+
print('[JAF:ENGINE] Input guardrails to run:', {
|
|
463
|
+
'turn_count': state.turn_count,
|
|
464
|
+
'effective_input_length': len(effective_input_guardrails),
|
|
465
|
+
'input_guardrails_to_run_length': len(input_guardrails_to_run),
|
|
466
|
+
'has_advanced_guardrails': has_advanced_guardrails
|
|
467
|
+
})
|
|
468
|
+
|
|
469
|
+
if input_guardrails_to_run and state.turn_count == 0:
|
|
470
|
+
first_user_message = next((m for m in state.messages if m.role == ContentRole.USER or m.role == 'user'), None)
|
|
471
|
+
if first_user_message:
|
|
472
|
+
if has_advanced_guardrails:
|
|
473
|
+
execution_mode = (current_agent.advanced_config.guardrails.execution_mode
|
|
474
|
+
if current_agent.advanced_config and current_agent.advanced_config.guardrails
|
|
475
|
+
else 'parallel')
|
|
476
|
+
|
|
477
|
+
if execution_mode == 'sequential':
|
|
478
|
+
guardrail_result = await execute_input_guardrails_sequential(
|
|
479
|
+
input_guardrails_to_run, first_user_message, config
|
|
480
|
+
)
|
|
481
|
+
if not guardrail_result.is_valid:
|
|
482
|
+
return RunResult(
|
|
483
|
+
final_state=state,
|
|
484
|
+
outcome=ErrorOutcome(error=InputGuardrailTripwire(
|
|
485
|
+
reason=getattr(guardrail_result, 'error_message', 'Input guardrail violation')
|
|
486
|
+
))
|
|
487
|
+
)
|
|
488
|
+
else:
|
|
489
|
+
# Parallel execution with LLM call overlap
|
|
490
|
+
guardrail_result = await execute_input_guardrails_parallel(
|
|
491
|
+
input_guardrails_to_run, first_user_message, config
|
|
492
|
+
)
|
|
493
|
+
if not guardrail_result.is_valid:
|
|
494
|
+
print(f"🚨 Input guardrail violation: {getattr(guardrail_result, 'error_message', 'Unknown violation')}")
|
|
495
|
+
return RunResult(
|
|
496
|
+
final_state=state,
|
|
497
|
+
outcome=ErrorOutcome(error=InputGuardrailTripwire(
|
|
498
|
+
reason=getattr(guardrail_result, 'error_message', 'Input guardrail violation')
|
|
499
|
+
))
|
|
500
|
+
)
|
|
501
|
+
else:
|
|
502
|
+
# Legacy guardrails path
|
|
503
|
+
print('[JAF:ENGINE] Using LEGACY guardrails path with', len(input_guardrails_to_run), 'guardrails')
|
|
504
|
+
for guardrail in input_guardrails_to_run:
|
|
505
|
+
if config.on_event:
|
|
506
|
+
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
507
|
+
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
508
|
+
content=get_text_content(first_user_message.content)
|
|
509
|
+
)))
|
|
510
|
+
if asyncio.iscoroutinefunction(guardrail):
|
|
511
|
+
result = await guardrail(get_text_content(first_user_message.content))
|
|
512
|
+
else:
|
|
513
|
+
result = guardrail(get_text_content(first_user_message.content))
|
|
514
|
+
|
|
515
|
+
if not result.is_valid:
|
|
516
|
+
if config.on_event:
|
|
517
|
+
config.on_event(GuardrailViolationEvent(data=GuardrailViolationEventData(
|
|
518
|
+
stage='input',
|
|
519
|
+
reason=getattr(result, 'error_message', 'Input guardrail failed')
|
|
520
|
+
)))
|
|
521
|
+
return RunResult(
|
|
522
|
+
final_state=state,
|
|
523
|
+
outcome=ErrorOutcome(error=InputGuardrailTripwire(
|
|
524
|
+
reason=getattr(result, 'error_message', 'Input guardrail failed')
|
|
525
|
+
))
|
|
526
|
+
)
|
|
527
|
+
|
|
448
528
|
# Agent debugging logs removed for performance
|
|
449
529
|
|
|
450
530
|
# Get model name
|
|
@@ -752,13 +832,27 @@ async def _run_internal(
|
|
|
752
832
|
)))
|
|
753
833
|
|
|
754
834
|
# Check final output guardrails
|
|
755
|
-
if
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
835
|
+
if has_advanced_guardrails:
|
|
836
|
+
# Use new advanced system
|
|
837
|
+
output_guardrail_result = await execute_output_guardrails(
|
|
838
|
+
effective_output_guardrails, output_data, config
|
|
839
|
+
)
|
|
840
|
+
if not output_guardrail_result.is_valid:
|
|
841
|
+
return RunResult(
|
|
842
|
+
final_state=replace(state, messages=new_messages),
|
|
843
|
+
outcome=ErrorOutcome(error=OutputGuardrailTripwire(
|
|
844
|
+
reason=getattr(output_guardrail_result, 'error_message', 'Output guardrail violation')
|
|
845
|
+
))
|
|
846
|
+
)
|
|
847
|
+
else:
|
|
848
|
+
# Legacy system
|
|
849
|
+
if effective_output_guardrails:
|
|
850
|
+
for guardrail in effective_output_guardrails:
|
|
851
|
+
if config.on_event:
|
|
852
|
+
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
853
|
+
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
854
|
+
content=output_data
|
|
855
|
+
)))
|
|
762
856
|
if asyncio.iscoroutinefunction(guardrail):
|
|
763
857
|
result = await guardrail(output_data)
|
|
764
858
|
else:
|
|
@@ -766,16 +860,14 @@ async def _run_internal(
|
|
|
766
860
|
|
|
767
861
|
if not result.is_valid:
|
|
768
862
|
if config.on_event:
|
|
769
|
-
config.on_event(
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
is_valid=False,
|
|
773
|
-
error_message=result.error_message
|
|
863
|
+
config.on_event(GuardrailViolationEvent(data=GuardrailViolationEventData(
|
|
864
|
+
stage='output',
|
|
865
|
+
reason=getattr(result, 'error_message', 'Output guardrail failed')
|
|
774
866
|
)))
|
|
775
867
|
return RunResult(
|
|
776
868
|
final_state=replace(state, messages=new_messages, approvals=state.approvals),
|
|
777
869
|
outcome=ErrorOutcome(error=OutputGuardrailTripwire(
|
|
778
|
-
reason=result
|
|
870
|
+
reason=getattr(result, 'error_message', 'Output guardrail failed')
|
|
779
871
|
))
|
|
780
872
|
)
|
|
781
873
|
|
|
@@ -799,32 +891,44 @@ async def _run_internal(
|
|
|
799
891
|
)
|
|
800
892
|
else:
|
|
801
893
|
# No output codec, return content as string
|
|
802
|
-
if
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
894
|
+
if has_advanced_guardrails:
|
|
895
|
+
# Use new advanced system
|
|
896
|
+
output_guardrail_result = await execute_output_guardrails(
|
|
897
|
+
effective_output_guardrails, get_text_content(assistant_message.content), config
|
|
898
|
+
)
|
|
899
|
+
if not output_guardrail_result.is_valid:
|
|
900
|
+
return RunResult(
|
|
901
|
+
final_state=replace(state, messages=new_messages),
|
|
902
|
+
outcome=ErrorOutcome(error=OutputGuardrailTripwire(
|
|
903
|
+
reason=getattr(output_guardrail_result, 'error_message', 'Output guardrail violation')
|
|
904
|
+
))
|
|
905
|
+
)
|
|
906
|
+
else:
|
|
907
|
+
# Legacy system
|
|
908
|
+
if effective_output_guardrails:
|
|
909
|
+
for guardrail in effective_output_guardrails:
|
|
815
910
|
if config.on_event:
|
|
816
911
|
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
817
912
|
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
818
|
-
content=get_text_content(assistant_message.content)
|
|
819
|
-
is_valid=False,
|
|
820
|
-
error_message=result.error_message
|
|
913
|
+
content=get_text_content(assistant_message.content)
|
|
821
914
|
)))
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
915
|
+
if asyncio.iscoroutinefunction(guardrail):
|
|
916
|
+
result = await guardrail(get_text_content(assistant_message.content))
|
|
917
|
+
else:
|
|
918
|
+
result = guardrail(get_text_content(assistant_message.content))
|
|
919
|
+
|
|
920
|
+
if not result.is_valid:
|
|
921
|
+
if config.on_event:
|
|
922
|
+
config.on_event(GuardrailViolationEvent(data=GuardrailViolationEventData(
|
|
923
|
+
stage='output',
|
|
924
|
+
reason=getattr(result, 'error_message', 'Output guardrail failed')
|
|
925
|
+
)))
|
|
926
|
+
return RunResult(
|
|
927
|
+
final_state=replace(state, messages=new_messages, approvals=state.approvals),
|
|
928
|
+
outcome=ErrorOutcome(error=OutputGuardrailTripwire(
|
|
929
|
+
reason=getattr(result, 'error_message', 'Output guardrail failed')
|
|
930
|
+
))
|
|
931
|
+
)
|
|
828
932
|
|
|
829
933
|
return RunResult(
|
|
830
934
|
final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
|
jaf/core/guardrails.py
ADDED
|
@@ -0,0 +1,666 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Advanced guardrails implementation for JAF framework.
|
|
3
|
+
|
|
4
|
+
This module provides LLM-based guardrails with caching, circuit breaking,
|
|
5
|
+
and execution strategies for input validation and output filtering.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
|
|
9
|
+
import json
|
|
10
|
+
import re
|
|
11
|
+
import time
|
|
12
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
13
|
+
from dataclasses import dataclass, field
|
|
14
|
+
|
|
15
|
+
from .types import (
|
|
16
|
+
Agent,
|
|
17
|
+
RunConfig,
|
|
18
|
+
RunState,
|
|
19
|
+
ValidationResult,
|
|
20
|
+
ValidValidationResult,
|
|
21
|
+
InvalidValidationResult,
|
|
22
|
+
Guardrail,
|
|
23
|
+
AdvancedGuardrailsConfig,
|
|
24
|
+
validate_guardrails_config,
|
|
25
|
+
json_parse_llm_output,
|
|
26
|
+
get_text_content,
|
|
27
|
+
Message,
|
|
28
|
+
ContentRole,
|
|
29
|
+
create_run_id,
|
|
30
|
+
create_trace_id,
|
|
31
|
+
GuardrailEvent,
|
|
32
|
+
GuardrailEventData,
|
|
33
|
+
GuardrailViolationEvent,
|
|
34
|
+
GuardrailViolationEventData
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
# Content-length ceilings (in characters) for LLM guardrail evaluation. The
# short ceiling is used when a guardrail's timeout is under 10 000 ms.
SHORT_TIMEOUT_MAX_CONTENT = 10_000
LONG_TIMEOUT_MAX_CONTENT = 50_000
# Idle age, in milliseconds, for circuit-breaker cleanup: 10 minutes.
# Presumably passed to GuardrailCircuitBreaker.should_be_cleaned_up — TODO confirm.
CIRCUIT_BREAKER_CLEANUP_MAX_AGE = 10 * 60 * 1_000

# Timeout values, all in milliseconds.
DEFAULT_FAST_MODEL_TIMEOUT_MS = 10_000
DEFAULT_TIMEOUT_MS = 5_000
GUARDRAIL_TIMEOUT_MS = 10_000
OUTPUT_GUARDRAIL_TIMEOUT_MS = 15_000
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class GuardrailCircuitBreaker:
    """Stop calling a failing guardrail backend until a quiet period has passed.

    The breaker opens once ``max_failures`` consecutive failures have been
    recorded. It closes again (failure count reset to zero) when it is checked
    more than ``reset_time_ms`` after the most recent failure. All times are
    wall-clock milliseconds (``time.time() * 1000``).
    """

    def __init__(self, max_failures: int = 5, reset_time_ms: int = 60000):
        self.max_failures = max_failures
        self.reset_time_ms = reset_time_ms
        self.failures = 0
        # 0 means "no failure recorded yet".
        self.last_failure_time = 0

    def is_open(self) -> bool:
        """Return True while requests should be blocked.

        Side effect: once the quiet period has elapsed, this resets
        ``failures`` to 0 before reporting the breaker as closed.
        """
        if self.failures >= self.max_failures:
            elapsed_ms = time.time() * 1000 - self.last_failure_time
            if elapsed_ms <= self.reset_time_ms:
                return True
            # Quiet period is over: close the breaker again.
            self.failures = 0
        return False

    def record_failure(self) -> None:
        """Count one more failure and remember when it happened."""
        self.last_failure_time = time.time() * 1000
        self.failures += 1

    def record_success(self) -> None:
        """Clear the consecutive-failure count."""
        self.failures = 0

    def should_be_cleaned_up(self, max_age: int) -> bool:
        """Return True for a closed breaker whose last failure is older than ``max_age`` ms.

        Breakers that have never failed are never reported as stale.
        """
        if self.last_failure_time <= 0:
            return False
        idle_ms = time.time() * 1000 - self.last_failure_time
        if idle_ms <= max_age:
            return False
        return not self.is_open()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
class CacheEntry:
    """One memoised guardrail verdict together with its bookkeeping fields."""

    # The verdict handed back to callers on a cache hit.
    result: ValidationResult
    # Milliseconds since the epoch (time.time() * 1000). GuardrailCache sets it
    # on insert and refreshes it on every hit.
    timestamp: float
    # Times the entry has been stored or served; feeds the eviction score.
    hit_count: int = field(default=1)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class GuardrailCache:
    """Bounded, TTL-limited cache of guardrail verdicts.

    Keys combine the stage, the model name and SHA-256 digests of the rule
    prompt and of the *entire* content. A hit skips the guardrail model call
    entirely, so the key must identify the content exactly.

    Eviction is not strictly LRU:
    1. Expired entries are dropped first.
    2. If the cache is still full, the entry with the lowest
       ``hit_count / (1 + age_hours)`` score is removed.
    """

    def __init__(self, max_size: int = 1000, ttl_ms: int = 300000):
        self.cache: Dict[str, CacheEntry] = {}
        self.max_size = max_size
        self.ttl_ms = ttl_ms

    def _create_key(self, stage: str, rule_prompt: str, content: str, model_name: str) -> str:
        """Create a cache key from stage, model and digests of rule prompt and full content."""
        # Hash the whole content. Hashing only a prefix would let two inputs
        # that share a prefix and a length reuse each other's verdict.
        content_hash = self._hash_string(content)
        rule_hash = self._hash_string(rule_prompt)
        return f"guardrail_{stage}_{model_name}_{rule_hash}_{content_hash}_{len(content)}"

    def _hash_string(self, s: str) -> str:
        """Return the hex SHA-256 digest of ``s``.

        A cryptographic digest replaces the former 32-bit rolling hash, whose
        collisions are easy to construct on purpose. 'surrogatepass' keeps
        strings containing lone surrogates hashable.
        """
        import hashlib

        return hashlib.sha256(s.encode('utf-8', 'surrogatepass')).hexdigest()

    def _is_expired(self, entry: CacheEntry) -> bool:
        """Check if cache entry is older than ``ttl_ms`` (measured from its timestamp)."""
        return (time.time() * 1000) - entry.timestamp > self.ttl_ms

    def _evict_lru(self) -> None:
        """Free one slot when the cache is at ``max_size``.

        Expired entries are purged first. If that does not free space, the
        entry with the lowest ``hit_count / (1 + age_hours)`` score is removed.
        """
        if len(self.cache) < self.max_size:
            return

        # Expired entries can never be served (get() deletes them on access),
        # so they are the cheapest thing to drop.
        expired_keys = [key for key, entry in self.cache.items() if self._is_expired(entry)]
        for key in expired_keys:
            del self.cache[key]
        if len(self.cache) < self.max_size or not self.cache:
            return

        now = time.time() * 1000

        def _score(entry: CacheEntry) -> float:
            # Frequently hit, recently touched entries score highest.
            age_hours = (now - entry.timestamp) / (1000 * 60 * 60)
            return entry.hit_count / (1 + age_hours)

        victim = min(self.cache, key=lambda key: _score(self.cache[key]))
        del self.cache[victim]

    def get(self, stage: str, rule_prompt: str, content: str, model_name: str) -> Optional[ValidationResult]:
        """Return the cached verdict, or ``None`` on a miss or an expired entry.

        A hit increments ``hit_count`` and refreshes ``timestamp``. An expired
        entry is deleted.
        """
        key = self._create_key(stage, rule_prompt, content, model_name)
        entry = self.cache.get(key)
        if entry is None:
            return None
        if self._is_expired(entry):
            del self.cache[key]
            return None

        # NOTE(review): refreshing timestamp on every hit makes the TTL sliding,
        # so an entry hit at least once per ttl_ms never expires — confirm intended.
        entry.hit_count += 1
        entry.timestamp = time.time() * 1000
        return entry.result

    def set(self, stage: str, rule_prompt: str, content: str, model_name: str, result: ValidationResult) -> None:
        """Cache ``result``, evicting another entry only when adding a new key to a full cache."""
        key = self._create_key(stage, rule_prompt, content, model_name)

        # Overwriting an existing key does not grow the cache, so no eviction is needed.
        if key not in self.cache:
            self._evict_lru()

        self.cache[key] = CacheEntry(
            result=result,
            timestamp=time.time() * 1000,
            hit_count=1
        )

    def clear(self) -> None:
        """Clear all cached entries."""
        self.cache.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Return the current entry count and the configured capacity."""
        return {
            'size': len(self.cache),
            'max_size': self.max_size
        }
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
# Module-level singletons shared by every guardrail this module creates.
_guardrail_cache = GuardrailCache()
_circuit_breakers: Dict[str, GuardrailCircuitBreaker] = {}


def _get_circuit_breaker(stage: str, model_name: str) -> GuardrailCircuitBreaker:
    """Return the shared breaker for ``stage``/``model_name``, creating it on first use."""
    breaker_key = f"{stage}-{model_name}"
    breaker = _circuit_breakers.get(breaker_key)
    if breaker is None:
        breaker = GuardrailCircuitBreaker()
        _circuit_breakers[breaker_key] = breaker
    return breaker
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
async def _with_timeout(awaitable, timeout_ms: int, error_message: str):
|
|
193
|
+
"""Run an awaitable with a timeout."""
|
|
194
|
+
try:
|
|
195
|
+
return await asyncio.wait_for(awaitable, timeout=timeout_ms / 1000)
|
|
196
|
+
except asyncio.TimeoutError:
|
|
197
|
+
raise TimeoutError(f"Timeout: {error_message}")
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _coerce_allowed_flag(value: Any) -> Optional[bool]:
    """Interpret the ``allowed`` field of a guardrail model's JSON verdict.

    Returns ``None`` when the value cannot be read as a boolean, so the caller
    applies its fail-safe policy instead of guessing. Plain ``bool(value)`` is
    not used because ``bool("false")`` is ``True``.
    """
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return value != 0
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in ('true', 'yes', '1'):
            return True
        if lowered in ('false', 'no', '0'):
            return False
    return None


async def _create_llm_guardrail(
    config: RunConfig,
    stage: str,
    rule_prompt: str,
    fast_model: Optional[str] = None,
    fail_safe: str = 'allow',
    timeout_ms: int = 30000
) -> Guardrail:
    """Create an async guardrail that asks an LLM whether content obeys ``rule_prompt``.

    Args:
        config: Run configuration. Supplies the model provider, the agent
            registry and the ``default_fast_model`` fallback.
        stage: ``'input'`` or ``'output'``. Used in the prompt, the cache key
            and the circuit-breaker key.
        rule_prompt: Natural-language rules the content must satisfy.
        fast_model: Model used for evaluation. Falls back to
            ``config.default_fast_model``.
        fail_safe: Policy when no definite verdict can be obtained (no model,
            open circuit breaker, empty or oversized content, call failure,
            unusable reply). ``'allow'`` returns a valid result; any other
            value returns an invalid one.
        timeout_ms: Deadline for the model call. Below 10000 ms the content
            limit is SHORT_TIMEOUT_MAX_CONTENT, otherwise LONG_TIMEOUT_MAX_CONTENT.

    Returns:
        An async callable ``(content) -> ValidationResult``. Definite verdicts
        are cached; fail-safe results are not.
    """

    def fail_safe_result(error_message: str) -> ValidationResult:
        # 'allow' fails open; any other policy value fails closed.
        if fail_safe == 'allow':
            return ValidValidationResult()
        return InvalidValidationResult(error_message=error_message)

    async def guardrail_func(content: Any) -> ValidationResult:
        content_str = str(content) if not isinstance(content, str) else content

        model_to_use = fast_model or config.default_fast_model
        if not model_to_use:
            print(f"[JAF:GUARDRAILS] No fast model available for LLM guardrail evaluation, using failSafe: {fail_safe}")
            return fail_safe_result('No model available for guardrail evaluation')

        # Check cache first
        cached_result = _guardrail_cache.get(stage, rule_prompt, content_str, model_to_use)
        if cached_result is not None:
            print(f"[JAF:GUARDRAILS] Cache hit for {stage} guardrail")
            return cached_result

        # Check circuit breaker
        circuit_breaker = _get_circuit_breaker(stage, model_to_use)
        if circuit_breaker.is_open():
            print(f"[JAF:GUARDRAILS] Circuit breaker open for {stage} guardrail on model {model_to_use}, using failSafe: {fail_safe}")
            return fail_safe_result('Circuit breaker open - too many recent failures')

        # Validate content
        if not content_str:
            print(f"[JAF:GUARDRAILS] Invalid content provided to {stage} guardrail")
            return fail_safe_result('Invalid content provided to guardrail')

        # Check content length
        max_content_length = SHORT_TIMEOUT_MAX_CONTENT if timeout_ms < 10000 else LONG_TIMEOUT_MAX_CONTENT
        if len(content_str) > max_content_length:
            print(f"[JAF:GUARDRAILS] Content too large for {stage} guardrail ({len(content_str)} chars, max: {max_content_length})")
            return fail_safe_result(f'Content too large for guardrail evaluation ({len(content_str)} > {max_content_length} chars)')

        # Neutralise the prompt's own delimiter and flatten newlines.
        # NOTE(review): only the first 2000 characters reach the model, although
        # content up to max_content_length is accepted — text past that point is
        # never evaluated. Confirm this is intended.
        sanitized_content = (content_str
                             .replace('"""', '[TRIPLE_QUOTE]')
                             .replace('\n', ' ')[:2000])

        eval_prompt = f"""You are a guardrail validator for {stage}.
Rules:
{rule_prompt}

Decide if the {stage} {'user message' if stage == 'input' else 'assistant output'} complies with the rules.
Return a JSON object with keys: {{"allowed": boolean, "reason": string}}. Do not include extra text.
{stage.capitalize()} {'user message' if stage == 'input' else 'assistant output'}:
\"\"\"
{sanitized_content}
\"\"\""""

        try:
            # Create temporary state for guardrail evaluation
            temp_state = RunState(
                run_id=create_run_id('guardrail-eval'),
                trace_id=create_trace_id('guardrail-eval'),
                messages=[Message(role=ContentRole.USER, content=eval_prompt)],
                current_agent_name='guardrail-evaluator',
                context={},
                turn_count=0
            )

            # Create evaluation agent
            def eval_instructions(state: RunState) -> str:
                return 'You are a guardrail validator. Return only valid JSON.'

            # NOTE(review): hasattr(config, 'ModelConfig') looks up an attribute named
            # ModelConfig on the run config and is presumably always False — verify.
            # The evaluation model is selected through model_override below.
            eval_agent = Agent(
                name='guardrail-evaluator',
                instructions=eval_instructions,
                model_config={'name': model_to_use} if hasattr(config, 'ModelConfig') else None
            )

            # Create guardrail config (no guardrails to avoid recursion)
            guardrail_config = RunConfig(
                agent_registry=config.agent_registry,
                model_provider=config.model_provider,
                max_turns=1,
                default_fast_model=config.default_fast_model,
                model_override=model_to_use,
                initial_input_guardrails=None,
                final_output_guardrails=None,
                on_event=None
            )

            # Execute with timeout
            completion_promise = config.model_provider.get_completion(temp_state, eval_agent, guardrail_config)
            response = await _with_timeout(
                completion_promise,
                timeout_ms,
                f"{stage} guardrail evaluation timed out after {timeout_ms}ms"
            )

            # Handle different response formats
            response_content = None
            if hasattr(response, 'message') and response.message:
                if hasattr(response.message, 'content'):
                    response_content = response.message.content
            elif isinstance(response, dict):
                if 'message' in response and response['message']:
                    if isinstance(response['message'], dict) and 'content' in response['message']:
                        response_content = response['message']['content']
                    elif hasattr(response['message'], 'content'):
                        response_content = response['message'].content

            # An unusable reply is not a verdict. Raising routes it through the
            # failure path below, so fail_safe decides and nothing is cached.
            if not response_content:
                raise ValueError('Guardrail model returned an empty response')

            parsed = json_parse_llm_output(response_content)
            if not isinstance(parsed, dict):
                raise ValueError('Guardrail model response was not a JSON object')
            allowed = _coerce_allowed_flag(parsed.get('allowed'))
            if allowed is None:
                raise ValueError("Guardrail model response has no usable 'allowed' field")
            reason = str(parsed.get('reason') or 'Guardrail violation')

            circuit_breaker.record_success()

            result = (ValidValidationResult() if allowed
                      else InvalidValidationResult(error_message=reason))

            _guardrail_cache.set(stage, rule_prompt, content_str, model_to_use, result)
            return result

        except Exception as e:
            circuit_breaker.record_failure()

            error_message = str(e)
            is_timeout = 'Timeout' in error_message

            log_message = f"[JAF:GUARDRAILS] {stage} guardrail evaluation failed"
            if is_timeout:
                print(f"{log_message} due to timeout ({timeout_ms}ms), using failSafe: {fail_safe}")
            else:
                print(f"{log_message}, using failSafe: {fail_safe} - {error_message}")

            return fail_safe_result(f'Guardrail evaluation failed: {error_message}')

    return guardrail_func
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
async def build_effective_guardrails(
    current_agent: Agent,
    config: RunConfig
) -> Tuple[List[Guardrail], List[Guardrail]]:
    """Assemble the input and output guardrail lists for ``current_agent``.

    The run-level ``config.initial_input_guardrails`` and
    ``config.final_output_guardrails`` always come first. On top of those, the
    agent's ``advanced_config.guardrails`` (after ``validate_guardrails_config``)
    may add:

    * an LLM input guardrail, when ``input_prompt`` is non-blank;
    * a citation check requiring at least one ``[n]`` marker, when
      ``require_citations`` is set;
    * an LLM output guardrail, when ``output_prompt`` is non-blank.

    Any exception raised while building is printed, and the function falls
    back to the run-level guardrails alone.

    Returns:
        ``(input_guardrails, output_guardrails)``.
    """
    inputs: List[Guardrail] = []
    outputs: List[Guardrail] = []

    try:
        agent_advanced = current_agent.advanced_config
        cfg = validate_guardrails_config(agent_advanced.guardrails if agent_advanced else None)

        model_for_llm = cfg.fast_model or config.default_fast_model
        if not model_for_llm and (cfg.input_prompt or cfg.output_prompt):
            # NOTE(review): despite the wording, the LLM guardrails are still
            # created below. With no model they resolve to their fail_safe
            # result at evaluation time.
            print('[JAF:GUARDRAILS] No fast model available for LLM guardrails - skipping LLM-based validation')

        print('[JAF:GUARDRAILS] Configuration:', {
            'hasInputPrompt': bool(cfg.input_prompt),
            'hasOutputPrompt': bool(cfg.output_prompt),
            'requireCitations': cfg.require_citations,
            'executionMode': cfg.execution_mode,
            'failSafe': cfg.fail_safe,
            'timeoutMs': cfg.timeout_ms,
            'fastModel': model_for_llm or 'none'
        })

        # Run-level guardrails come first.
        inputs = list(config.initial_input_guardrails or [])
        outputs = list(config.final_output_guardrails or [])

        input_rules = cfg.input_prompt
        if input_rules and input_rules.strip():
            llm_input_check = await _create_llm_guardrail(
                config, 'input', input_rules,
                model_for_llm, cfg.fail_safe, cfg.timeout_ms
            )
            inputs.append(llm_input_check)

        if cfg.require_citations:
            def citation_guardrail(output: Any) -> ValidationResult:
                def flatten(node: Any) -> str:
                    # Collect all text from nested lists/dicts; dict keys are ignored.
                    if isinstance(node, str):
                        return node
                    if isinstance(node, dict):
                        node = list(node.values())
                    if isinstance(node, list):
                        return ' '.join(flatten(part) for part in node)
                    return str(node)

                if re.search(r'\[(\d+)\]', flatten(output)):
                    return ValidValidationResult()
                return InvalidValidationResult(error_message="Missing required [n] citation in output")

            outputs.append(citation_guardrail)

        output_rules = cfg.output_prompt
        if output_rules and output_rules.strip():
            llm_output_check = await _create_llm_guardrail(
                config, 'output', output_rules,
                model_for_llm, cfg.fail_safe, cfg.timeout_ms
            )
            outputs.append(llm_output_check)

    except Exception as e:
        print(f'[JAF:GUARDRAILS] Failed to configure advanced guardrails: {e}')
        # Fall back to the run-level guardrails only.
        inputs = list(config.initial_input_guardrails or [])
        outputs = list(config.final_output_guardrails or [])

    return inputs, outputs
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
async def execute_input_guardrails_sequential(
    input_guardrails: List[Guardrail],
    first_user_message: Message,
    config: RunConfig
) -> ValidationResult:
    """Run input guardrails one by one, stopping at the first blocking result.

    Every guardrail receives the text content of ``first_user_message``.
    Guardrails may be synchronous (returning a ``ValidationResult``) or
    asynchronous (returning an awaitable); only awaitables are passed through
    ``_with_timeout`` with ``GUARDRAIL_TIMEOUT_MS``.

    Exceptions classified as system errors (timeouts, circuit-breaker trips)
    are logged and that guardrail is skipped; any other exception blocks the
    input.

    Args:
        input_guardrails: Guardrails to run, in order.
        first_user_message: Message whose text content is validated.
        config: Run configuration; ``config.on_event`` (when set) receives a
            ``GuardrailViolationEvent`` for every blocking outcome.

    Returns:
        The first invalid result, an ``InvalidValidationResult`` carrying the
        message of a non-system exception, or ``ValidValidationResult()`` when
        every guardrail passed or was skipped.
    """
    import inspect

    if not input_guardrails:
        return ValidValidationResult()

    print(f"[JAF:GUARDRAILS] Starting {len(input_guardrails)} input guardrails (sequential)")

    content = get_text_content(first_user_message.content)

    for i, guardrail in enumerate(input_guardrails):
        guardrail_name = f"input-guardrail-{i + 1}"

        try:
            print(f"[JAF:GUARDRAILS] Starting {guardrail_name}")

            timeout_ms = GUARDRAIL_TIMEOUT_MS
            # Call first and await only when an awaitable comes back: synchronous
            # guardrails return their ValidationResult directly and must not be
            # wrapped in _with_timeout. Checking the return value (rather than
            # iscoroutinefunction) also covers callables with an async __call__.
            result = guardrail(content)
            if inspect.isawaitable(result):
                result = await _with_timeout(
                    result,
                    timeout_ms,
                    f"{guardrail_name} execution timed out after {timeout_ms}ms"
                )

            print(f"[JAF:GUARDRAILS] {guardrail_name} completed: {result}")

            if not result.is_valid:
                error_message = getattr(result, 'error_message', 'Guardrail violation')
                print(f"🚨 {guardrail_name} violation: {error_message}")
                if config.on_event:
                    config.on_event(GuardrailViolationEvent(
                        data=GuardrailViolationEventData(stage='input', reason=error_message)
                    ))
                return result

        except Exception as error:
            error_message = str(error)
            print(f"[JAF:GUARDRAILS] {guardrail_name} failed: {error_message}")

            # Timeouts and circuit-breaker trips are infrastructure failures, not
            # verdicts on the input, so they are skipped. Match case-insensitively
            # (the timeout message built above says "timed out") and recognise a
            # bare asyncio.TimeoutError, whose str() is empty.
            lowered = error_message.lower()
            is_system_error = (
                isinstance(error, asyncio.TimeoutError)
                or 'timeout' in lowered
                or 'timed out' in lowered
                or 'circuit breaker' in lowered
            )

            if is_system_error:
                print(f"[JAF:GUARDRAILS] {guardrail_name} system error, continuing: {error_message}")
                continue

            if config.on_event:
                config.on_event(GuardrailViolationEvent(
                    data=GuardrailViolationEventData(stage='input', reason=error_message)
                ))
            return InvalidValidationResult(error_message=error_message)

    print("✅ All input guardrails passed (sequential).")
    return ValidValidationResult()
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
async def execute_input_guardrails_parallel(
    input_guardrails: List[Guardrail],
    first_user_message: Message,
    config: RunConfig
) -> ValidationResult:
    """Run all input guardrails concurrently and report the first violation.

    Every guardrail receives the text content of ``first_user_message``.
    Awaitable results are bounded by ``_with_timeout`` (using
    ``DEFAULT_FAST_MODEL_TIMEOUT_MS`` when ``config.default_fast_model`` is set,
    otherwise ``DEFAULT_TIMEOUT_MS``); synchronous results are used as-is.

    A guardrail that raises is treated as passing and recorded as a warning.
    Results are inspected in guardrail order, so the lowest-index violation is
    the one returned. If the orchestration itself fails, the input is allowed.

    Args:
        input_guardrails: Guardrails to run.
        first_user_message: Message whose text content is validated.
        config: Run configuration; ``config.on_event`` (when set) receives a
            ``GuardrailViolationEvent`` for the reported violation.

    Returns:
        The first invalid ``ValidationResult`` by guardrail index, otherwise
        ``ValidValidationResult()``.
    """
    import inspect

    if not input_guardrails:
        return ValidValidationResult()

    print(f"[JAF:GUARDRAILS] Starting {len(input_guardrails)} input guardrails")

    content = get_text_content(first_user_message.content)

    async def run_guardrail(guardrail: Guardrail, index: int):
        # Runs one guardrail; never raises, returns a dict with its outcome.
        guardrail_name = f"input-guardrail-{index + 1}"

        try:
            print(f"[JAF:GUARDRAILS] Starting {guardrail_name}")

            timeout_ms = DEFAULT_FAST_MODEL_TIMEOUT_MS if config.default_fast_model else DEFAULT_TIMEOUT_MS

            # Await any awaitable result, not only results of functions that
            # iscoroutinefunction() recognises: an un-awaited coroutine would
            # otherwise reach the `.is_valid` check below, raise, and trip the
            # catastrophic-failure path that lets every input through.
            result = guardrail(content)
            if inspect.isawaitable(result):
                result = await _with_timeout(result, timeout_ms,
                                             f"{guardrail_name} execution timed out after {timeout_ms}ms")

            print(f"[JAF:GUARDRAILS] {guardrail_name} completed: {result}")
            return {'result': result, 'guardrail_index': index}

        except Exception as error:
            error_message = str(error)
            print(f"[JAF:GUARDRAILS] {guardrail_name} failed: {error_message}")

            # Deliberate fail-open: a broken guardrail is skipped with a warning.
            return {
                'result': ValidValidationResult(),
                'guardrail_index': index,
                'warning': f"Guardrail {index + 1} failed but was skipped: {error_message}"
            }

    try:
        # Run all guardrails in parallel
        tasks = [run_guardrail(guardrail, i) for i, guardrail in enumerate(input_guardrails)]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        print("[JAF:GUARDRAILS] Input guardrails completed. Checking results...")

        warnings = []

        for i, result in enumerate(results):
            if isinstance(result, Exception):
                error_message = str(result)
                print(f"[JAF:GUARDRAILS] Input guardrail {i + 1} promise rejected: {error_message}")
                warnings.append(f"Guardrail {i + 1} failed: {error_message}")
                continue

            if 'warning' in result:
                warnings.append(result['warning'])

            validation_result = result['result']
            if not validation_result.is_valid:
                error_message = getattr(validation_result, 'error_message', 'Guardrail violation')
                print(f"🚨 Input guardrail {result['guardrail_index'] + 1} violation: {error_message}")
                if config.on_event:
                    config.on_event(GuardrailViolationEvent(
                        data=GuardrailViolationEventData(stage='input', reason=error_message)
                    ))
                return validation_result

        if warnings:
            print(f"[JAF:GUARDRAILS] {len(warnings)} guardrail warnings: {warnings}")

        print("✅ All input guardrails passed.")
        return ValidValidationResult()

    except Exception as error:
        print(f"[JAF:GUARDRAILS] Catastrophic failure in input guardrail execution: {error}")
        return ValidValidationResult()  # Fail gracefully
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
async def execute_output_guardrails(
    output_guardrails: List[Guardrail],
    output: Any,
    config: RunConfig
) -> ValidationResult:
    """Run output guardrails one by one, stopping at the first blocking result.

    Each guardrail receives ``output`` unchanged. Awaitable results are bounded
    by ``_with_timeout`` with ``OUTPUT_GUARDRAIL_TIMEOUT_MS``; synchronous
    results are used as-is.

    Exceptions classified as system errors (timeouts, circuit-breaker trips)
    are logged and the output is allowed past that guardrail; any other
    exception blocks the output.

    Args:
        output_guardrails: Guardrails to run, in order.
        output: The final output to validate.
        config: Run configuration; ``config.on_event`` (when set) receives a
            ``GuardrailViolationEvent`` for every blocking outcome.

    Returns:
        The first invalid result, an ``InvalidValidationResult`` carrying the
        message of a non-system exception, or ``ValidValidationResult()``.
    """
    import inspect

    if not output_guardrails:
        return ValidValidationResult()

    print(f"[JAF:GUARDRAILS] Checking {len(output_guardrails)} output guardrails")

    for i, guardrail in enumerate(output_guardrails):
        guardrail_name = f"output-guardrail-{i + 1}"

        try:
            timeout_ms = OUTPUT_GUARDRAIL_TIMEOUT_MS

            # Await any awaitable result (covers callables with an async
            # __call__ that iscoroutinefunction() does not detect).
            result = guardrail(output)
            if inspect.isawaitable(result):
                result = await _with_timeout(result, timeout_ms,
                                             f"{guardrail_name} execution timed out after {timeout_ms}ms")

            if not result.is_valid:
                error_message = getattr(result, 'error_message', 'Guardrail violation')
                print(f"🚨 {guardrail_name} violation: {error_message}")
                if config.on_event:
                    config.on_event(GuardrailViolationEvent(
                        data=GuardrailViolationEventData(stage='output', reason=error_message)
                    ))
                return result

            print(f"✅ {guardrail_name} passed")

        except Exception as error:
            error_message = str(error)
            print(f"[JAF:GUARDRAILS] {guardrail_name} failed: {error_message}")

            # Infrastructure failures do not block the output. Match
            # case-insensitively (the timeout message above says "timed out")
            # and recognise a bare asyncio.TimeoutError, whose str() is empty.
            lowered = error_message.lower()
            is_system_error = (
                isinstance(error, asyncio.TimeoutError)
                or 'timeout' in lowered
                or 'timed out' in lowered
                or 'circuit breaker' in lowered
            )

            if is_system_error:
                print(f"[JAF:GUARDRAILS] {guardrail_name} system error, allowing output: {error_message}")
                continue

            if config.on_event:
                config.on_event(GuardrailViolationEvent(
                    data=GuardrailViolationEventData(stage='output', reason=error_message)
                ))
            return InvalidValidationResult(error_message=error_message)

    print("✅ All output guardrails passed")
    return ValidValidationResult()
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
def cleanup_circuit_breakers() -> None:
    """Evict circuit breakers that report themselves as stale.

    Each breaker decides via ``should_be_cleaned_up`` using
    ``CIRCUIT_BREAKER_CLEANUP_MAX_AGE``. Stale keys are gathered first and
    removed afterwards so the registry is never mutated while being iterated.
    """
    stale_keys = [
        name
        for name, circuit in _circuit_breakers.items()
        if circuit.should_be_cleaned_up(CIRCUIT_BREAKER_CLEANUP_MAX_AGE)
    ]
    for name in stale_keys:
        del _circuit_breakers[name]
|
|
628
|
+
|
|
629
|
+
|
|
630
|
+
class GuardrailCacheManager:
    """Static facade over the module-level guardrail cache and circuit breakers."""

    @staticmethod
    def get_stats() -> Dict[str, Any]:
        """Return the statistics reported by the module-level guardrail cache."""
        return _guardrail_cache.get_stats()

    @staticmethod
    def clear() -> None:
        """Remove every entry from the module-level guardrail cache."""
        _guardrail_cache.clear()

    @staticmethod
    def get_metrics() -> Dict[str, Any]:
        """Return cache stats plus derived metrics.

        All keys from the cache's ``get_stats()`` are passed through. Two keys
        are added:

        - ``utilization_percent``: ``size / max_size * 100``, or ``0.0`` when
          ``max_size`` is zero/falsy.
        - ``circuit_breakers_count``: the number of registered circuit breakers.
        """
        stats = _guardrail_cache.get_stats()
        max_size = stats['max_size']
        # A zero-capacity cache would otherwise raise ZeroDivisionError from a
        # purely informational metric.
        utilization = (stats['size'] / max_size) * 100 if max_size else 0.0
        return {
            **stats,
            'utilization_percent': utilization,
            'circuit_breakers_count': len(_circuit_breakers)
        }

    @staticmethod
    def log_stats() -> None:
        """Print the metrics from ``get_metrics()``."""
        metrics = GuardrailCacheManager.get_metrics()
        print('[JAF:GUARDRAILS] Cache stats:', metrics)

    @staticmethod
    def cleanup() -> None:
        """Evict stale circuit breakers via ``cleanup_circuit_breakers()``."""
        cleanup_circuit_breakers()


# Export the cache manager
guardrail_cache_manager = GuardrailCacheManager()
|
jaf/core/types.py
CHANGED
|
@@ -288,6 +288,7 @@ class Agent(Generic[Ctx, Out]):
|
|
|
288
288
|
output_codec: Optional[Any] = None # Type that can validate Out (like Pydantic model or Zod equivalent)
|
|
289
289
|
handoffs: Optional[List[str]] = None
|
|
290
290
|
model_config: Optional[ModelConfig] = None
|
|
291
|
+
advanced_config: Optional['AdvancedConfig'] = None
|
|
291
292
|
|
|
292
293
|
def as_tool(
|
|
293
294
|
self,
|
|
@@ -331,6 +332,74 @@ class Agent(Generic[Ctx, Out]):
|
|
|
331
332
|
# Guardrail: a callable that takes the value under validation (user input text
# or agent output) and returns a ValidationResult, either directly or as an
# awaitable, so both synchronous and asynchronous guardrails fit.
Guardrail = Callable[
    [Any],
    Union[ValidationResult, Awaitable[ValidationResult]],
]
|
|
333
334
|
|
|
335
|
+
@dataclass(frozen=True)
class AdvancedGuardrailsConfig:
    """Settings for the LLM-backed input/output guardrails of an agent.

    ``timeout_ms`` is raised to a floor of 1000 ms during construction.
    """
    # Prompt for the input-stage LLM guardrail; None disables it.
    input_prompt: Optional[str] = None
    # Prompt for the output-stage LLM guardrail; None disables it.
    output_prompt: Optional[str] = None
    # When True, outputs must contain a "[n]"-style citation.
    require_citations: bool = False
    # Model name used for guardrail checks (presumably overrides a run-level default; verify).
    fast_model: Optional[str] = None
    # Failure policy handed to the LLM guardrails.
    fail_safe: Literal['allow', 'block'] = 'allow'
    # Whether input guardrails run concurrently or one by one.
    execution_mode: Literal['parallel', 'sequential'] = 'parallel'
    # Per-guardrail timeout in milliseconds (minimum 1000).
    timeout_ms: int = 30000

    def __post_init__(self):
        """Enforce the 1000 ms minimum on ``timeout_ms``."""
        minimum_timeout_ms = 1000
        if self.timeout_ms < minimum_timeout_ms:
            # The dataclass is frozen, so bypass its generated __setattr__.
            object.__setattr__(self, 'timeout_ms', minimum_timeout_ms)
|
|
350
|
+
|
|
351
|
+
@dataclass(frozen=True)
class AdvancedConfig:
    """Optional per-agent settings; currently holds only the guardrail configuration."""
    # Advanced guardrail settings; None when the agent configures none.
    guardrails: Optional[AdvancedGuardrailsConfig] = field(default=None)
|
|
355
|
+
|
|
356
|
+
def validate_guardrails_config(config: Optional[AdvancedGuardrailsConfig]) -> AdvancedGuardrailsConfig:
    """Return a normalized copy of ``config``, or an all-defaults config for ``None``.

    The string fields ``input_prompt``, ``output_prompt`` and ``fast_model`` are
    stripped. Empty, whitespace-only or non-string values become ``None``, so
    "unset" has a single representation. ``timeout_ms`` is floored at 1000 ms.

    Args:
        config: Configuration to normalize; may be ``None``.

    Returns:
        A new ``AdvancedGuardrailsConfig``.
    """
    if config is None:
        return AdvancedGuardrailsConfig()

    def _clean(value: Any) -> Optional[str]:
        # Strip strings and collapse blank results to None (a whitespace-only
        # prompt would otherwise survive as "").
        if not isinstance(value, str):
            return None
        return value.strip() or None

    return AdvancedGuardrailsConfig(
        input_prompt=_clean(config.input_prompt),
        output_prompt=_clean(config.output_prompt),
        require_citations=config.require_citations,
        fast_model=_clean(config.fast_model),
        fail_safe=config.fail_safe,
        execution_mode=config.execution_mode,
        timeout_ms=max(1000, config.timeout_ms)
    )
|
|
370
|
+
|
|
371
|
+
def json_parse_llm_output(text: str) -> Optional[Dict[str, Any]]:
    """Extract a JSON object from free-form LLM output.

    Strategies, in order:
      1. parse the whole text as JSON;
      2. parse the first fenced code block (```json ... ``` or ``` ... ```)
         containing an object;
      3. scan each ``{`` in the text and decode the first complete JSON object
         starting there (nested objects are supported).

    Only JSON objects are returned. Other JSON values (lists, numbers, null)
    are ignored so the result always matches the ``Dict`` return type.

    Args:
        text: Raw model output.

    Returns:
        The decoded object, or ``None`` if no JSON object could be found.
    """
    import json
    import re

    if not text:
        return None

    # 1. Try direct parsing first
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(parsed, dict):
            return parsed

    # 2. Try to extract JSON from markdown code blocks
    json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
    if json_match:
        try:
            parsed = json.loads(json_match.group(1))
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(parsed, dict):
                return parsed

    # 3. Find the first decodable JSON object embedded in the text.
    # raw_decode consumes exactly one JSON value and ignores what follows, so
    # nested braces are handled (a non-greedy r'\{.*?\}' would stop at the
    # first '}').
    decoder = json.JSONDecoder()
    start = text.find('{')
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(parsed, dict):
                return parsed
        start = text.find('{', start + 1)

    return None
|
|
402
|
+
|
|
334
403
|
@dataclass(frozen=True)
|
|
335
404
|
class ApprovalValue:
|
|
336
405
|
"""Represents an approval decision with context."""
|
|
@@ -600,6 +669,17 @@ class GuardrailEvent:
|
|
|
600
669
|
type: Literal['guardrail_check'] = 'guardrail_check'
|
|
601
670
|
data: GuardrailEventData = field(default_factory=lambda: GuardrailEventData(""))
|
|
602
671
|
|
|
672
|
+
@dataclass(frozen=True)
class GuardrailViolationEventData:
    """Payload of a guardrail violation: which stage failed and why."""
    # Guardrail stage that reported the violation.
    stage: Literal['input', 'output']
    # Violation message from the guardrail (or from the exception it raised).
    reason: str

@dataclass(frozen=True)
class GuardrailViolationEvent:
    """Trace event emitted when an input or output guardrail reports a violation."""
    type: Literal['guardrail_violation'] = 'guardrail_violation'
    data: GuardrailViolationEventData = field(
        default_factory=lambda: GuardrailViolationEventData(stage="input", reason="")
    )
|
|
682
|
+
|
|
603
683
|
@dataclass(frozen=True)
|
|
604
684
|
class MemoryEventData:
|
|
605
685
|
"""Data for memory operation events."""
|
|
@@ -632,6 +712,7 @@ class OutputParseEvent:
|
|
|
632
712
|
TraceEvent = Union[
|
|
633
713
|
RunStartEvent,
|
|
634
714
|
GuardrailEvent,
|
|
715
|
+
GuardrailViolationEvent,
|
|
635
716
|
MemoryEvent,
|
|
636
717
|
OutputParseEvent,
|
|
637
718
|
LLMCallStartEvent,
|
|
@@ -710,7 +791,8 @@ class RunConfig(Generic[Ctx]):
|
|
|
710
791
|
initial_input_guardrails: Optional[List[Guardrail]] = None
|
|
711
792
|
final_output_guardrails: Optional[List[Guardrail]] = None
|
|
712
793
|
on_event: Optional[Callable[[TraceEvent], None]] = None
|
|
713
|
-
memory: Optional[
|
|
794
|
+
memory: Optional[Any] = None # MemoryConfig - avoiding circular import
|
|
714
795
|
conversation_id: Optional[str] = None
|
|
796
|
+
default_fast_model: Optional[str] = None # Default model for fast operations like guardrails
|
|
715
797
|
default_tool_timeout: Optional[float] = 300.0 # Default timeout for tool execution in seconds
|
|
716
798
|
approval_storage: Optional['ApprovalStorage'] = None # Storage for approval decisions
|
|
@@ -42,8 +42,9 @@ jaf/core/__init__.py,sha256=PIGKm8n6OQ8jcXRS0Hn3_Zsl8m2qX91N80YJoLCJ4eU,1762
|
|
|
42
42
|
jaf/core/agent_tool.py,sha256=tfLNaTIcOZ0dR9GBP1AHLPkLExm_dLbURnVIN4R84FQ,11806
|
|
43
43
|
jaf/core/analytics.py,sha256=zFHIWqWal0bbEFCmJDc4DKeM0Ja7b_D19PqVaBI12pA,23338
|
|
44
44
|
jaf/core/composition.py,sha256=IVxRO1Q9nK7JRH32qQ4p8WMIUu66BhqPNrlTNMGFVwE,26317
|
|
45
|
-
jaf/core/engine.py,sha256=
|
|
45
|
+
jaf/core/engine.py,sha256=7j8LRf52inRKN4gcCPNuXzoBKMr19S9VMyjrrb3Xlek,57406
|
|
46
46
|
jaf/core/errors.py,sha256=5fwTNhkojKRQ4wZj3lZlgDnAsrYyjYOwXJkIr5EGNUc,5539
|
|
47
|
+
jaf/core/guardrails.py,sha256=nv7pQuCx7-9DDZrecWO1DsDqFoujL81FBDrafOsXgcI,26179
|
|
47
48
|
jaf/core/parallel_agents.py,sha256=ahwYoTnkrF4xQgV-hjc5sUaWhQWQFENMZG5riNa_Ieg,12165
|
|
48
49
|
jaf/core/performance.py,sha256=jedQmTEkrKMD6_Aw1h8PdG-5TsdYSFFT7Or6k5dmN2g,9974
|
|
49
50
|
jaf/core/proxy.py,sha256=_WM3cpRlSQLYpgSBrnY30UPMe2iZtlqDQ65kppE-WY0,4609
|
|
@@ -53,7 +54,7 @@ jaf/core/streaming.py,sha256=h_lYHQA9ee_D5QsDO9-Vhevgi7rFXPslPzd9605AJGo,17034
|
|
|
53
54
|
jaf/core/tool_results.py,sha256=-bTOqOX02lMyslp5Z4Dmuhx0cLd5o7kgR88qK2HO_sw,11323
|
|
54
55
|
jaf/core/tools.py,sha256=84N9A7QQ3xxcOs2eUUot3nmCnt5i7iZT9VwkuzuFBxQ,16274
|
|
55
56
|
jaf/core/tracing.py,sha256=iuVgykFUSkoBjem1k6jdVLrhRZzJn-avyxc_6W9BXPI,40159
|
|
56
|
-
jaf/core/types.py,sha256=
|
|
57
|
+
jaf/core/types.py,sha256=FCc9uWTUS6P1iU-_RxJM7k-HNorsHM-0XHqwwaUGLkE,26267
|
|
57
58
|
jaf/core/workflows.py,sha256=Ul-82gzjIXtkhnSMSPv-8igikjkMtW1EBo9yrfodtvI,26294
|
|
58
59
|
jaf/memory/__init__.py,sha256=-L98xlvihurGAzF0DnXtkueDVvO_wV2XxxEwAWdAj50,1400
|
|
59
60
|
jaf/memory/approval_storage.py,sha256=HHZ_b57kIthdR53QE5XNSII9xy1Cg-1cFUCSAZ8A4Rk,11083
|
|
@@ -85,9 +86,9 @@ jaf/visualization/functional_core.py,sha256=zedMDZbvjuOugWwnh6SJ2stvRNQX1Hlkb9Ab
|
|
|
85
86
|
jaf/visualization/graphviz.py,sha256=WTOM6UP72-lVKwI4_SAr5-GCC3ouckxHv88ypCDQWJ0,12056
|
|
86
87
|
jaf/visualization/imperative_shell.py,sha256=GpMrAlMnLo2IQgyB2nardCz09vMvAzaYI46MyrvJ0i4,2593
|
|
87
88
|
jaf/visualization/types.py,sha256=QQcbVeQJLuAOXk8ynd08DXIS-PVCnv3R-XVE9iAcglw,1389
|
|
88
|
-
jaf_py-2.4.
|
|
89
|
-
jaf_py-2.4.
|
|
90
|
-
jaf_py-2.4.
|
|
91
|
-
jaf_py-2.4.
|
|
92
|
-
jaf_py-2.4.
|
|
93
|
-
jaf_py-2.4.
|
|
89
|
+
jaf_py-2.4.6.dist-info/licenses/LICENSE,sha256=LXUQBJxdyr-7C4bk9cQBwvsF_xwA-UVstDTKabpcjlI,1063
|
|
90
|
+
jaf_py-2.4.6.dist-info/METADATA,sha256=ep-RyxTMs_RhA4h10yXfZ5s4RQOiIZ0A5hRoP4ZV3sg,27712
|
|
91
|
+
jaf_py-2.4.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
92
|
+
jaf_py-2.4.6.dist-info/entry_points.txt,sha256=OtIJeNJpb24kgGrqRx9szGgDx1vL9ayq8uHErmu7U5w,41
|
|
93
|
+
jaf_py-2.4.6.dist-info/top_level.txt,sha256=Xu1RZbGaM4_yQX7bpalo881hg7N_dybaOW282F15ruE,4
|
|
94
|
+
jaf_py-2.4.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|