jaf-py 2.5.10 → 2.5.12 (jaf_py-2.5.10-py3-none-any.whl → jaf_py-2.5.12-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +154 -57
- jaf/a2a/__init__.py +42 -21
- jaf/a2a/agent.py +79 -126
- jaf/a2a/agent_card.py +87 -78
- jaf/a2a/client.py +30 -66
- jaf/a2a/examples/client_example.py +12 -12
- jaf/a2a/examples/integration_example.py +38 -47
- jaf/a2a/examples/server_example.py +56 -53
- jaf/a2a/memory/__init__.py +0 -4
- jaf/a2a/memory/cleanup.py +28 -21
- jaf/a2a/memory/factory.py +155 -133
- jaf/a2a/memory/providers/composite.py +21 -26
- jaf/a2a/memory/providers/in_memory.py +89 -83
- jaf/a2a/memory/providers/postgres.py +117 -115
- jaf/a2a/memory/providers/redis.py +128 -121
- jaf/a2a/memory/serialization.py +77 -87
- jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
- jaf/a2a/memory/tests/test_cleanup.py +211 -94
- jaf/a2a/memory/tests/test_serialization.py +73 -68
- jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
- jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
- jaf/a2a/memory/types.py +91 -53
- jaf/a2a/protocol.py +95 -125
- jaf/a2a/server.py +90 -118
- jaf/a2a/standalone_client.py +30 -43
- jaf/a2a/tests/__init__.py +16 -33
- jaf/a2a/tests/run_tests.py +17 -53
- jaf/a2a/tests/test_agent.py +40 -140
- jaf/a2a/tests/test_client.py +54 -117
- jaf/a2a/tests/test_integration.py +28 -82
- jaf/a2a/tests/test_protocol.py +54 -139
- jaf/a2a/tests/test_types.py +50 -136
- jaf/a2a/types.py +58 -34
- jaf/cli.py +21 -41
- jaf/core/__init__.py +7 -1
- jaf/core/agent_tool.py +93 -72
- jaf/core/analytics.py +257 -207
- jaf/core/checkpoint.py +223 -0
- jaf/core/composition.py +249 -235
- jaf/core/engine.py +817 -519
- jaf/core/errors.py +55 -42
- jaf/core/guardrails.py +276 -202
- jaf/core/handoff.py +47 -31
- jaf/core/parallel_agents.py +69 -75
- jaf/core/performance.py +75 -73
- jaf/core/proxy.py +43 -44
- jaf/core/proxy_helpers.py +24 -27
- jaf/core/regeneration.py +220 -129
- jaf/core/state.py +68 -66
- jaf/core/streaming.py +115 -108
- jaf/core/tool_results.py +111 -101
- jaf/core/tools.py +114 -116
- jaf/core/tracing.py +310 -210
- jaf/core/types.py +403 -151
- jaf/core/workflows.py +209 -168
- jaf/exceptions.py +46 -38
- jaf/memory/__init__.py +1 -6
- jaf/memory/approval_storage.py +54 -77
- jaf/memory/factory.py +4 -4
- jaf/memory/providers/in_memory.py +216 -180
- jaf/memory/providers/postgres.py +216 -146
- jaf/memory/providers/redis.py +173 -116
- jaf/memory/types.py +70 -51
- jaf/memory/utils.py +36 -34
- jaf/plugins/__init__.py +12 -12
- jaf/plugins/base.py +105 -96
- jaf/policies/__init__.py +0 -1
- jaf/policies/handoff.py +37 -46
- jaf/policies/validation.py +76 -52
- jaf/providers/__init__.py +6 -3
- jaf/providers/mcp.py +97 -51
- jaf/providers/model.py +475 -283
- jaf/server/__init__.py +1 -1
- jaf/server/main.py +7 -11
- jaf/server/server.py +514 -359
- jaf/server/types.py +208 -52
- jaf/utils/__init__.py +17 -18
- jaf/utils/attachments.py +111 -116
- jaf/utils/document_processor.py +175 -174
- jaf/visualization/__init__.py +1 -1
- jaf/visualization/example.py +111 -110
- jaf/visualization/functional_core.py +46 -71
- jaf/visualization/graphviz.py +154 -189
- jaf/visualization/imperative_shell.py +7 -16
- jaf/visualization/types.py +8 -4
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/METADATA +2 -2
- jaf_py-2.5.12.dist-info/RECORD +97 -0
- jaf_py-2.5.10.dist-info/RECORD +0 -96
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/WHEEL +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/top_level.txt +0 -0
jaf/policies/handoff.py
CHANGED
|
@@ -14,24 +14,24 @@ from ..core.types import Guardrail, ValidationResult
|
|
|
14
14
|
@dataclass
|
|
15
15
|
class HandoffPolicy:
|
|
16
16
|
"""Policy configuration for agent handoffs."""
|
|
17
|
+
|
|
17
18
|
allowed_handoffs: Dict[str, List[str]] # source_agent -> [allowed_target_agents]
|
|
18
19
|
require_permission: bool = True
|
|
19
20
|
validate_context: bool = True
|
|
20
21
|
max_handoff_depth: int = 10
|
|
21
22
|
|
|
23
|
+
|
|
22
24
|
def create_handoff_guardrail(
|
|
23
|
-
policy: HandoffPolicy,
|
|
24
|
-
current_agent: str,
|
|
25
|
-
handoff_history: Optional[List[str]] = None
|
|
25
|
+
policy: HandoffPolicy, current_agent: str, handoff_history: Optional[List[str]] = None
|
|
26
26
|
) -> Guardrail:
|
|
27
27
|
"""
|
|
28
28
|
Create a guardrail that validates agent handoffs.
|
|
29
|
-
|
|
29
|
+
|
|
30
30
|
Args:
|
|
31
31
|
policy: Handoff policy configuration
|
|
32
32
|
current_agent: Name of the current agent
|
|
33
33
|
handoff_history: List of previous agents in the current session
|
|
34
|
-
|
|
34
|
+
|
|
35
35
|
Returns:
|
|
36
36
|
Guardrail function for handoff validation
|
|
37
37
|
"""
|
|
@@ -40,20 +40,16 @@ def create_handoff_guardrail(
|
|
|
40
40
|
def handoff_guardrail(handoff_data: Any) -> ValidationResult:
|
|
41
41
|
# Extract target agent from handoff data
|
|
42
42
|
if isinstance(handoff_data, dict):
|
|
43
|
-
target_agent = handoff_data.get(
|
|
43
|
+
target_agent = handoff_data.get("handoff_to") or handoff_data.get("target_agent")
|
|
44
44
|
elif isinstance(handoff_data, str):
|
|
45
45
|
# Assume the string is the target agent name
|
|
46
46
|
target_agent = handoff_data
|
|
47
47
|
else:
|
|
48
|
-
return ValidationResult(
|
|
49
|
-
is_valid=False,
|
|
50
|
-
error_message="Invalid handoff data format"
|
|
51
|
-
)
|
|
48
|
+
return ValidationResult(is_valid=False, error_message="Invalid handoff data format")
|
|
52
49
|
|
|
53
50
|
if not target_agent:
|
|
54
51
|
return ValidationResult(
|
|
55
|
-
is_valid=False,
|
|
56
|
-
error_message="No target agent specified in handoff"
|
|
52
|
+
is_valid=False, error_message="No target agent specified in handoff"
|
|
57
53
|
)
|
|
58
54
|
|
|
59
55
|
# Check permission
|
|
@@ -62,49 +58,47 @@ def create_handoff_guardrail(
|
|
|
62
58
|
if target_agent not in allowed_targets:
|
|
63
59
|
return ValidationResult(
|
|
64
60
|
is_valid=False,
|
|
65
|
-
error_message=f"Agent {current_agent} is not allowed to handoff to {target_agent}"
|
|
61
|
+
error_message=f"Agent {current_agent} is not allowed to handoff to {target_agent}",
|
|
66
62
|
)
|
|
67
63
|
|
|
68
64
|
# Check handoff depth to prevent infinite loops
|
|
69
65
|
if len(handoff_history) >= policy.max_handoff_depth:
|
|
70
66
|
return ValidationResult(
|
|
71
67
|
is_valid=False,
|
|
72
|
-
error_message=f"Maximum handoff depth ({policy.max_handoff_depth}) exceeded"
|
|
68
|
+
error_message=f"Maximum handoff depth ({policy.max_handoff_depth}) exceeded",
|
|
73
69
|
)
|
|
74
70
|
|
|
75
71
|
# Check for immediate circular handoffs
|
|
76
72
|
if handoff_history and handoff_history[-1] == target_agent:
|
|
77
73
|
return ValidationResult(
|
|
78
74
|
is_valid=False,
|
|
79
|
-
error_message=f"Circular handoff detected: {current_agent} -> {target_agent} -> {current_agent}"
|
|
75
|
+
error_message=f"Circular handoff detected: {current_agent} -> {target_agent} -> {current_agent}",
|
|
80
76
|
)
|
|
81
77
|
|
|
82
78
|
# Validate context if required
|
|
83
79
|
if policy.validate_context and isinstance(handoff_data, dict):
|
|
84
|
-
context = handoff_data.get(
|
|
80
|
+
context = handoff_data.get("context", {})
|
|
85
81
|
if not isinstance(context, dict):
|
|
86
82
|
return ValidationResult(
|
|
87
|
-
is_valid=False,
|
|
88
|
-
error_message="Handoff context must be a dictionary"
|
|
83
|
+
is_valid=False, error_message="Handoff context must be a dictionary"
|
|
89
84
|
)
|
|
90
85
|
|
|
91
86
|
return ValidationResult(is_valid=True)
|
|
92
87
|
|
|
93
88
|
return handoff_guardrail
|
|
94
89
|
|
|
90
|
+
|
|
95
91
|
def validate_handoff_permissions(
|
|
96
|
-
source_agent: str,
|
|
97
|
-
target_agent: str,
|
|
98
|
-
allowed_handoffs: Dict[str, List[str]]
|
|
92
|
+
source_agent: str, target_agent: str, allowed_handoffs: Dict[str, List[str]]
|
|
99
93
|
) -> ValidationResult:
|
|
100
94
|
"""
|
|
101
95
|
Validate if a handoff is allowed between two agents.
|
|
102
|
-
|
|
96
|
+
|
|
103
97
|
Args:
|
|
104
98
|
source_agent: Name of the source agent
|
|
105
99
|
target_agent: Name of the target agent
|
|
106
100
|
allowed_handoffs: Dictionary mapping source agents to allowed targets
|
|
107
|
-
|
|
101
|
+
|
|
108
102
|
Returns:
|
|
109
103
|
ValidationResult indicating if the handoff is allowed
|
|
110
104
|
"""
|
|
@@ -113,22 +107,23 @@ def validate_handoff_permissions(
|
|
|
113
107
|
if target_agent not in allowed_targets:
|
|
114
108
|
return ValidationResult(
|
|
115
109
|
is_valid=False,
|
|
116
|
-
error_message=f"Handoff from {source_agent} to {target_agent} not permitted"
|
|
110
|
+
error_message=f"Handoff from {source_agent} to {target_agent} not permitted",
|
|
117
111
|
)
|
|
118
112
|
|
|
119
113
|
return ValidationResult(is_valid=True)
|
|
120
114
|
|
|
115
|
+
|
|
121
116
|
def create_role_based_handoff_policy(
|
|
122
117
|
agent_roles: Dict[str, str], # agent_name -> role
|
|
123
|
-
role_permissions: Dict[str, List[str]] # role -> [allowed_target_roles]
|
|
118
|
+
role_permissions: Dict[str, List[str]], # role -> [allowed_target_roles]
|
|
124
119
|
) -> HandoffPolicy:
|
|
125
120
|
"""
|
|
126
121
|
Create a handoff policy based on agent roles.
|
|
127
|
-
|
|
122
|
+
|
|
128
123
|
Args:
|
|
129
124
|
agent_roles: Mapping of agent names to their roles
|
|
130
125
|
role_permissions: Mapping of roles to allowed target roles
|
|
131
|
-
|
|
126
|
+
|
|
132
127
|
Returns:
|
|
133
128
|
HandoffPolicy configured for role-based permissions
|
|
134
129
|
"""
|
|
@@ -137,27 +132,27 @@ def create_role_based_handoff_policy(
|
|
|
137
132
|
for agent_name, agent_role in agent_roles.items():
|
|
138
133
|
allowed_target_roles = role_permissions.get(agent_role, [])
|
|
139
134
|
allowed_targets = [
|
|
140
|
-
target_agent
|
|
135
|
+
target_agent
|
|
136
|
+
for target_agent, target_role in agent_roles.items()
|
|
141
137
|
if target_role in allowed_target_roles and target_agent != agent_name
|
|
142
138
|
]
|
|
143
139
|
allowed_handoffs[agent_name] = allowed_targets
|
|
144
140
|
|
|
145
141
|
return HandoffPolicy(
|
|
146
|
-
allowed_handoffs=allowed_handoffs,
|
|
147
|
-
require_permission=True,
|
|
148
|
-
validate_context=True
|
|
142
|
+
allowed_handoffs=allowed_handoffs, require_permission=True, validate_context=True
|
|
149
143
|
)
|
|
150
144
|
|
|
145
|
+
|
|
151
146
|
def create_workflow_handoff_policy(
|
|
152
|
-
workflow_steps: List[List[str]] # List of workflow steps, each step is a list of agent names
|
|
147
|
+
workflow_steps: List[List[str]], # List of workflow steps, each step is a list of agent names
|
|
153
148
|
) -> HandoffPolicy:
|
|
154
149
|
"""
|
|
155
150
|
Create a handoff policy based on a defined workflow.
|
|
156
|
-
|
|
151
|
+
|
|
157
152
|
Args:
|
|
158
153
|
workflow_steps: List of workflow steps, where each step contains agent names
|
|
159
154
|
that can transition to agents in the next step
|
|
160
|
-
|
|
155
|
+
|
|
161
156
|
Returns:
|
|
162
157
|
HandoffPolicy configured for workflow-based transitions
|
|
163
158
|
"""
|
|
@@ -180,14 +175,13 @@ def create_workflow_handoff_policy(
|
|
|
180
175
|
allowed_handoffs[agent] = allowed_targets
|
|
181
176
|
|
|
182
177
|
return HandoffPolicy(
|
|
183
|
-
allowed_handoffs=allowed_handoffs,
|
|
184
|
-
require_permission=True,
|
|
185
|
-
validate_context=True
|
|
178
|
+
allowed_handoffs=allowed_handoffs, require_permission=True, validate_context=True
|
|
186
179
|
)
|
|
187
180
|
|
|
181
|
+
|
|
188
182
|
# Predefined policies for common scenarios
|
|
189
183
|
def create_hierarchical_handoff_policy(
|
|
190
|
-
hierarchy: Dict[str, List[str]] # supervisor -> [subordinates]
|
|
184
|
+
hierarchy: Dict[str, List[str]], # supervisor -> [subordinates]
|
|
191
185
|
) -> HandoffPolicy:
|
|
192
186
|
"""Create a policy for hierarchical agent handoffs."""
|
|
193
187
|
allowed_handoffs = {}
|
|
@@ -214,11 +208,10 @@ def create_hierarchical_handoff_policy(
|
|
|
214
208
|
allowed_handoffs[subordinate] = supervisors + list(siblings)
|
|
215
209
|
|
|
216
210
|
return HandoffPolicy(
|
|
217
|
-
allowed_handoffs=allowed_handoffs,
|
|
218
|
-
require_permission=True,
|
|
219
|
-
validate_context=True
|
|
211
|
+
allowed_handoffs=allowed_handoffs, require_permission=True, validate_context=True
|
|
220
212
|
)
|
|
221
213
|
|
|
214
|
+
|
|
222
215
|
def create_open_handoff_policy(agent_names: List[str]) -> HandoffPolicy:
|
|
223
216
|
"""Create a policy that allows any agent to handoff to any other agent."""
|
|
224
217
|
allowed_handoffs = {}
|
|
@@ -229,14 +222,12 @@ def create_open_handoff_policy(agent_names: List[str]) -> HandoffPolicy:
|
|
|
229
222
|
return HandoffPolicy(
|
|
230
223
|
allowed_handoffs=allowed_handoffs,
|
|
231
224
|
require_permission=False, # Open policy
|
|
232
|
-
validate_context=False
|
|
225
|
+
validate_context=False,
|
|
233
226
|
)
|
|
234
227
|
|
|
228
|
+
|
|
235
229
|
def create_restricted_handoff_policy() -> HandoffPolicy:
|
|
236
230
|
"""Create a policy that doesn't allow any handoffs."""
|
|
237
231
|
return HandoffPolicy(
|
|
238
|
-
allowed_handoffs={},
|
|
239
|
-
require_permission=True,
|
|
240
|
-
validate_context=True,
|
|
241
|
-
max_handoff_depth=0
|
|
232
|
+
allowed_handoffs={}, require_permission=True, validate_context=True, max_handoff_depth=0
|
|
242
233
|
)
|
jaf/policies/validation.py
CHANGED
|
@@ -19,23 +19,23 @@ from ..core.types import Guardrail, InvalidValidationResult, ValidationResult, V
|
|
|
19
19
|
@dataclass
|
|
20
20
|
class GuardrailConfig:
|
|
21
21
|
"""Configuration for guardrails."""
|
|
22
|
+
|
|
22
23
|
enabled: bool = True
|
|
23
24
|
strict_mode: bool = False
|
|
24
25
|
custom_message: Optional[str] = None
|
|
25
26
|
|
|
27
|
+
|
|
26
28
|
def create_length_guardrail(
|
|
27
|
-
max_length: int,
|
|
28
|
-
min_length: int = 0,
|
|
29
|
-
config: Optional[GuardrailConfig] = None
|
|
29
|
+
max_length: int, min_length: int = 0, config: Optional[GuardrailConfig] = None
|
|
30
30
|
) -> Guardrail:
|
|
31
31
|
"""
|
|
32
32
|
Create a guardrail that validates text length.
|
|
33
|
-
|
|
33
|
+
|
|
34
34
|
Args:
|
|
35
35
|
max_length: Maximum allowed length
|
|
36
36
|
min_length: Minimum required length
|
|
37
37
|
config: Optional guardrail configuration
|
|
38
|
-
|
|
38
|
+
|
|
39
39
|
Returns:
|
|
40
40
|
Guardrail function
|
|
41
41
|
"""
|
|
@@ -48,34 +48,37 @@ def create_length_guardrail(
|
|
|
48
48
|
text_length = len(text)
|
|
49
49
|
|
|
50
50
|
if text_length > max_length:
|
|
51
|
-
message = (
|
|
52
|
-
|
|
51
|
+
message = (
|
|
52
|
+
config.custom_message or f"Text length {text_length} exceeds maximum {max_length}"
|
|
53
|
+
)
|
|
53
54
|
return InvalidValidationResult(error_message=message)
|
|
54
55
|
|
|
55
56
|
if text_length < min_length:
|
|
56
|
-
message = (
|
|
57
|
-
|
|
57
|
+
message = (
|
|
58
|
+
config.custom_message or f"Text length {text_length} below minimum {min_length}"
|
|
59
|
+
)
|
|
58
60
|
return InvalidValidationResult(error_message=message)
|
|
59
61
|
|
|
60
62
|
return ValidValidationResult()
|
|
61
63
|
|
|
62
64
|
return length_guardrail
|
|
63
65
|
|
|
66
|
+
|
|
64
67
|
def create_content_filter_guardrail(
|
|
65
68
|
blocked_patterns: List[str],
|
|
66
69
|
allowed_patterns: Optional[List[str]] = None,
|
|
67
70
|
case_sensitive: bool = False,
|
|
68
|
-
config: Optional[GuardrailConfig] = None
|
|
71
|
+
config: Optional[GuardrailConfig] = None,
|
|
69
72
|
) -> Guardrail:
|
|
70
73
|
"""
|
|
71
74
|
Create a guardrail that filters content based on patterns.
|
|
72
|
-
|
|
75
|
+
|
|
73
76
|
Args:
|
|
74
77
|
blocked_patterns: Regex patterns that should be blocked
|
|
75
78
|
allowed_patterns: Regex patterns that override blocks (allowlist)
|
|
76
79
|
case_sensitive: Whether pattern matching is case sensitive
|
|
77
80
|
config: Optional guardrail configuration
|
|
78
|
-
|
|
81
|
+
|
|
79
82
|
Returns:
|
|
80
83
|
Guardrail function
|
|
81
84
|
"""
|
|
@@ -84,8 +87,9 @@ def create_content_filter_guardrail(
|
|
|
84
87
|
# Compile patterns for efficiency
|
|
85
88
|
flags = 0 if case_sensitive else re.IGNORECASE
|
|
86
89
|
compiled_blocked = [re.compile(pattern, flags) for pattern in blocked_patterns]
|
|
87
|
-
compiled_allowed = (
|
|
88
|
-
|
|
90
|
+
compiled_allowed = (
|
|
91
|
+
[re.compile(pattern, flags) for pattern in allowed_patterns] if allowed_patterns else []
|
|
92
|
+
)
|
|
89
93
|
|
|
90
94
|
def content_filter_guardrail(text: str) -> ValidationResult:
|
|
91
95
|
if not config.enabled:
|
|
@@ -98,25 +102,26 @@ def create_content_filter_guardrail(
|
|
|
98
102
|
is_allowed = any(allowed.search(text) for allowed in compiled_allowed)
|
|
99
103
|
|
|
100
104
|
if not is_allowed:
|
|
101
|
-
message = (
|
|
102
|
-
|
|
105
|
+
message = (
|
|
106
|
+
config.custom_message or f"Content blocked by pattern: {pattern.pattern}"
|
|
107
|
+
)
|
|
103
108
|
return InvalidValidationResult(error_message=message)
|
|
104
109
|
|
|
105
110
|
return ValidValidationResult()
|
|
106
111
|
|
|
107
112
|
return content_filter_guardrail
|
|
108
113
|
|
|
114
|
+
|
|
109
115
|
def create_json_validation_guardrail(
|
|
110
|
-
schema_class: type[BaseModel],
|
|
111
|
-
config: Optional[GuardrailConfig] = None
|
|
116
|
+
schema_class: type[BaseModel], config: Optional[GuardrailConfig] = None
|
|
112
117
|
) -> Guardrail:
|
|
113
118
|
"""
|
|
114
119
|
Create a guardrail that validates JSON against a Pydantic schema.
|
|
115
|
-
|
|
120
|
+
|
|
116
121
|
Args:
|
|
117
122
|
schema_class: Pydantic model class for validation
|
|
118
123
|
config: Optional guardrail configuration
|
|
119
|
-
|
|
124
|
+
|
|
120
125
|
Returns:
|
|
121
126
|
Guardrail function
|
|
122
127
|
"""
|
|
@@ -132,8 +137,7 @@ def create_json_validation_guardrail(
|
|
|
132
137
|
try:
|
|
133
138
|
data = json.loads(data)
|
|
134
139
|
except json.JSONDecodeError as e:
|
|
135
|
-
message =
|
|
136
|
-
f"Invalid JSON format: {e!s}")
|
|
140
|
+
message = config.custom_message or f"Invalid JSON format: {e!s}"
|
|
137
141
|
return InvalidValidationResult(error_message=message)
|
|
138
142
|
|
|
139
143
|
# Validate against schema
|
|
@@ -141,36 +145,37 @@ def create_json_validation_guardrail(
|
|
|
141
145
|
return ValidValidationResult()
|
|
142
146
|
|
|
143
147
|
except ValidationError as e:
|
|
144
|
-
message =
|
|
145
|
-
f"Schema validation failed: {e!s}")
|
|
148
|
+
message = config.custom_message or f"Schema validation failed: {e!s}"
|
|
146
149
|
return InvalidValidationResult(error_message=message)
|
|
147
150
|
except Exception as e:
|
|
148
|
-
message =
|
|
149
|
-
f"Validation error: {e!s}")
|
|
151
|
+
message = config.custom_message or f"Validation error: {e!s}"
|
|
150
152
|
return InvalidValidationResult(error_message=message)
|
|
151
153
|
|
|
152
154
|
return json_validation_guardrail
|
|
153
155
|
|
|
156
|
+
|
|
154
157
|
@dataclass
|
|
155
158
|
class RateLimitState:
|
|
156
159
|
"""State for rate limiting."""
|
|
160
|
+
|
|
157
161
|
calls: List[float] = field(default_factory=list)
|
|
158
162
|
window_size: float = 60.0 # seconds
|
|
159
163
|
max_calls: int = 10
|
|
160
164
|
|
|
165
|
+
|
|
161
166
|
def create_rate_limit_guardrail(
|
|
162
167
|
max_calls: int = 10,
|
|
163
168
|
window_size: float = 60.0, # seconds
|
|
164
|
-
config: Optional[GuardrailConfig] = None
|
|
169
|
+
config: Optional[GuardrailConfig] = None,
|
|
165
170
|
) -> Guardrail:
|
|
166
171
|
"""
|
|
167
172
|
Create a guardrail that implements rate limiting.
|
|
168
|
-
|
|
173
|
+
|
|
169
174
|
Args:
|
|
170
175
|
max_calls: Maximum number of calls allowed in the window
|
|
171
176
|
window_size: Time window in seconds
|
|
172
177
|
config: Optional guardrail configuration
|
|
173
|
-
|
|
178
|
+
|
|
174
179
|
Returns:
|
|
175
180
|
Guardrail function
|
|
176
181
|
"""
|
|
@@ -189,8 +194,10 @@ def create_rate_limit_guardrail(
|
|
|
189
194
|
|
|
190
195
|
# Check if we're at the limit
|
|
191
196
|
if len(state.calls) >= state.max_calls:
|
|
192
|
-
message = (
|
|
193
|
-
|
|
197
|
+
message = (
|
|
198
|
+
config.custom_message
|
|
199
|
+
or f"Rate limit exceeded: {len(state.calls)}/{state.max_calls} calls in {state.window_size}s"
|
|
200
|
+
)
|
|
194
201
|
return InvalidValidationResult(error_message=message)
|
|
195
202
|
|
|
196
203
|
# Record this call
|
|
@@ -199,19 +206,18 @@ def create_rate_limit_guardrail(
|
|
|
199
206
|
|
|
200
207
|
return rate_limit_guardrail
|
|
201
208
|
|
|
209
|
+
|
|
202
210
|
def combine_guardrails(
|
|
203
|
-
guardrails: List[Guardrail],
|
|
204
|
-
require_all: bool = True,
|
|
205
|
-
config: Optional[GuardrailConfig] = None
|
|
211
|
+
guardrails: List[Guardrail], require_all: bool = True, config: Optional[GuardrailConfig] = None
|
|
206
212
|
) -> Guardrail:
|
|
207
213
|
"""
|
|
208
214
|
Combine multiple guardrails into a single guardrail.
|
|
209
|
-
|
|
215
|
+
|
|
210
216
|
Args:
|
|
211
217
|
guardrails: List of guardrails to combine
|
|
212
218
|
require_all: If True, all guardrails must pass; if False, at least one must pass
|
|
213
219
|
config: Optional guardrail configuration
|
|
214
|
-
|
|
220
|
+
|
|
215
221
|
Returns:
|
|
216
222
|
Combined guardrail function
|
|
217
223
|
"""
|
|
@@ -228,7 +234,7 @@ def combine_guardrails(
|
|
|
228
234
|
# Handle both sync and async guardrails
|
|
229
235
|
if callable(guardrail):
|
|
230
236
|
result = guardrail(data)
|
|
231
|
-
if hasattr(result,
|
|
237
|
+
if hasattr(result, "__await__"):
|
|
232
238
|
result = await result
|
|
233
239
|
else:
|
|
234
240
|
continue
|
|
@@ -265,24 +271,36 @@ def combine_guardrails(
|
|
|
265
271
|
|
|
266
272
|
return combined_guardrail
|
|
267
273
|
|
|
274
|
+
|
|
268
275
|
# Common guardrail presets
|
|
269
276
|
def create_safe_text_guardrail(config: Optional[GuardrailConfig] = None) -> Guardrail:
|
|
270
277
|
"""Create a guardrail for safe text content."""
|
|
271
|
-
return combine_guardrails(
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
278
|
+
return combine_guardrails(
|
|
279
|
+
[
|
|
280
|
+
create_length_guardrail(max_length=10000, min_length=1),
|
|
281
|
+
create_content_filter_guardrail(
|
|
282
|
+
[
|
|
283
|
+
r"<script\b[^<]*(?:(?!<\/script>)<[^<]*)*<\/script>", # Script tags
|
|
284
|
+
r"javascript:", # JavaScript URLs
|
|
285
|
+
r"data:.*base64", # Base64 data URLs
|
|
286
|
+
],
|
|
287
|
+
case_sensitive=False,
|
|
288
|
+
),
|
|
289
|
+
],
|
|
290
|
+
config=config,
|
|
291
|
+
)
|
|
292
|
+
|
|
279
293
|
|
|
280
294
|
def create_api_input_guardrail(config: Optional[GuardrailConfig] = None) -> Guardrail:
|
|
281
295
|
"""Create a guardrail for API input validation."""
|
|
282
|
-
return combine_guardrails(
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
296
|
+
return combine_guardrails(
|
|
297
|
+
[
|
|
298
|
+
create_length_guardrail(max_length=50000),
|
|
299
|
+
create_rate_limit_guardrail(max_calls=100, window_size=60),
|
|
300
|
+
],
|
|
301
|
+
config=config,
|
|
302
|
+
)
|
|
303
|
+
|
|
286
304
|
|
|
287
305
|
# Compatibility aliases for tests
|
|
288
306
|
def create_content_filter(blocked_patterns: List[str], **kwargs) -> Guardrail:
|
|
@@ -290,7 +308,7 @@ def create_content_filter(blocked_patterns: List[str], **kwargs) -> Guardrail:
|
|
|
290
308
|
guardrail = create_content_filter_guardrail(
|
|
291
309
|
blocked_patterns,
|
|
292
310
|
config=GuardrailConfig(custom_message="Contains inappropriate content"),
|
|
293
|
-
**kwargs
|
|
311
|
+
**kwargs,
|
|
294
312
|
)
|
|
295
313
|
|
|
296
314
|
async def async_wrapper(text: str) -> ValidationResult:
|
|
@@ -298,18 +316,24 @@ def create_content_filter(blocked_patterns: List[str], **kwargs) -> Guardrail:
|
|
|
298
316
|
|
|
299
317
|
return async_wrapper
|
|
300
318
|
|
|
319
|
+
|
|
301
320
|
def create_length_limiter(max_length: int, min_length: int = 0, **kwargs) -> Guardrail:
|
|
302
321
|
"""Create length limiter (test compatibility)."""
|
|
303
322
|
|
|
304
323
|
async def async_wrapper(text: str) -> ValidationResult:
|
|
305
324
|
if len(text) > max_length:
|
|
306
|
-
return InvalidValidationResult(
|
|
325
|
+
return InvalidValidationResult(
|
|
326
|
+
error_message=f"Text exceeds maximum length of {max_length}"
|
|
327
|
+
)
|
|
307
328
|
if len(text) < min_length:
|
|
308
|
-
return InvalidValidationResult(
|
|
329
|
+
return InvalidValidationResult(
|
|
330
|
+
error_message=f"Text below minimum length of {min_length}"
|
|
331
|
+
)
|
|
309
332
|
return ValidValidationResult()
|
|
310
333
|
|
|
311
334
|
return async_wrapper
|
|
312
335
|
|
|
336
|
+
|
|
313
337
|
def create_format_validator(schema_class: type[BaseModel], **kwargs) -> Guardrail:
|
|
314
338
|
"""Create format validator (test compatibility)."""
|
|
315
339
|
guardrail = create_json_validation_guardrail(schema_class, **kwargs)
|
jaf/providers/__init__.py
CHANGED
|
@@ -21,12 +21,13 @@ _DEPRECATED_ALIASES = {
|
|
|
21
21
|
_REMOVED_EXPORTS = {
|
|
22
22
|
# No safe automatic migration known — force an explicit choice.
|
|
23
23
|
"MCPClient": "MCPClient was removed. Use transport-specific tool factories: "
|
|
24
|
-
|
|
24
|
+
"create_mcp_stdio_tools, create_mcp_sse_tools, or create_mcp_http_tools.",
|
|
25
25
|
"create_mcp_tools_from_client": "create_mcp_tools_from_client was removed. "
|
|
26
|
-
|
|
27
|
-
|
|
26
|
+
"Construct tools via create_mcp_stdio_tools, "
|
|
27
|
+
"create_mcp_sse_tools, or create_mcp_http_tools as appropriate.",
|
|
28
28
|
}
|
|
29
29
|
|
|
30
|
+
|
|
30
31
|
def __getattr__(name: str):
|
|
31
32
|
if name in _DEPRECATED_ALIASES:
|
|
32
33
|
obj, new_name = _DEPRECATED_ALIASES[name]
|
|
@@ -46,10 +47,12 @@ def __getattr__(name: str):
|
|
|
46
47
|
raise AttributeError(f"{name} has been removed")
|
|
47
48
|
raise AttributeError(name)
|
|
48
49
|
|
|
50
|
+
|
|
49
51
|
def __dir__():
|
|
50
52
|
# Make deprecated names discoverable in REPLs without advertising in __all__
|
|
51
53
|
return sorted(set(globals()) | set(_DEPRECATED_ALIASES))
|
|
52
54
|
|
|
55
|
+
|
|
53
56
|
__all__ = [
|
|
54
57
|
"FastMCPTool",
|
|
55
58
|
"MCPToolArgs",
|