jaf-py 2.5.10__py3-none-any.whl → 2.5.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +154 -57
- jaf/a2a/__init__.py +42 -21
- jaf/a2a/agent.py +79 -126
- jaf/a2a/agent_card.py +87 -78
- jaf/a2a/client.py +30 -66
- jaf/a2a/examples/client_example.py +12 -12
- jaf/a2a/examples/integration_example.py +38 -47
- jaf/a2a/examples/server_example.py +56 -53
- jaf/a2a/memory/__init__.py +0 -4
- jaf/a2a/memory/cleanup.py +28 -21
- jaf/a2a/memory/factory.py +155 -133
- jaf/a2a/memory/providers/composite.py +21 -26
- jaf/a2a/memory/providers/in_memory.py +89 -83
- jaf/a2a/memory/providers/postgres.py +117 -115
- jaf/a2a/memory/providers/redis.py +128 -121
- jaf/a2a/memory/serialization.py +77 -87
- jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
- jaf/a2a/memory/tests/test_cleanup.py +211 -94
- jaf/a2a/memory/tests/test_serialization.py +73 -68
- jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
- jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
- jaf/a2a/memory/types.py +91 -53
- jaf/a2a/protocol.py +95 -125
- jaf/a2a/server.py +90 -118
- jaf/a2a/standalone_client.py +30 -43
- jaf/a2a/tests/__init__.py +16 -33
- jaf/a2a/tests/run_tests.py +17 -53
- jaf/a2a/tests/test_agent.py +40 -140
- jaf/a2a/tests/test_client.py +54 -117
- jaf/a2a/tests/test_integration.py +28 -82
- jaf/a2a/tests/test_protocol.py +54 -139
- jaf/a2a/tests/test_types.py +50 -136
- jaf/a2a/types.py +58 -34
- jaf/cli.py +21 -41
- jaf/core/__init__.py +7 -1
- jaf/core/agent_tool.py +93 -72
- jaf/core/analytics.py +257 -207
- jaf/core/checkpoint.py +223 -0
- jaf/core/composition.py +249 -235
- jaf/core/engine.py +817 -519
- jaf/core/errors.py +55 -42
- jaf/core/guardrails.py +276 -202
- jaf/core/handoff.py +47 -31
- jaf/core/parallel_agents.py +69 -75
- jaf/core/performance.py +75 -73
- jaf/core/proxy.py +43 -44
- jaf/core/proxy_helpers.py +24 -27
- jaf/core/regeneration.py +220 -129
- jaf/core/state.py +68 -66
- jaf/core/streaming.py +115 -108
- jaf/core/tool_results.py +111 -101
- jaf/core/tools.py +114 -116
- jaf/core/tracing.py +269 -210
- jaf/core/types.py +371 -151
- jaf/core/workflows.py +209 -168
- jaf/exceptions.py +46 -38
- jaf/memory/__init__.py +1 -6
- jaf/memory/approval_storage.py +54 -77
- jaf/memory/factory.py +4 -4
- jaf/memory/providers/in_memory.py +216 -180
- jaf/memory/providers/postgres.py +216 -146
- jaf/memory/providers/redis.py +173 -116
- jaf/memory/types.py +70 -51
- jaf/memory/utils.py +36 -34
- jaf/plugins/__init__.py +12 -12
- jaf/plugins/base.py +105 -96
- jaf/policies/__init__.py +0 -1
- jaf/policies/handoff.py +37 -46
- jaf/policies/validation.py +76 -52
- jaf/providers/__init__.py +6 -3
- jaf/providers/mcp.py +97 -51
- jaf/providers/model.py +360 -279
- jaf/server/__init__.py +1 -1
- jaf/server/main.py +7 -11
- jaf/server/server.py +514 -359
- jaf/server/types.py +208 -52
- jaf/utils/__init__.py +17 -18
- jaf/utils/attachments.py +111 -116
- jaf/utils/document_processor.py +175 -174
- jaf/visualization/__init__.py +1 -1
- jaf/visualization/example.py +111 -110
- jaf/visualization/functional_core.py +46 -71
- jaf/visualization/graphviz.py +154 -189
- jaf/visualization/imperative_shell.py +7 -16
- jaf/visualization/types.py +8 -4
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
- jaf_py-2.5.11.dist-info/RECORD +97 -0
- jaf_py-2.5.10.dist-info/RECORD +0 -96
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
jaf/exceptions.py
CHANGED
|
@@ -21,6 +21,7 @@ class JAFException(Exception):
|
|
|
21
21
|
# Agent and execution errors
|
|
22
22
|
class AgentException(JAFException):
|
|
23
23
|
"""Base exception for agent-related errors."""
|
|
24
|
+
|
|
24
25
|
pass
|
|
25
26
|
|
|
26
27
|
|
|
@@ -35,7 +36,9 @@ class AgentNotFoundError(AgentException):
|
|
|
35
36
|
class HandoffError(AgentException):
|
|
36
37
|
"""Raised when agent handoff fails."""
|
|
37
38
|
|
|
38
|
-
def __init__(
|
|
39
|
+
def __init__(
|
|
40
|
+
self, message: str, source_agent: Optional[str] = None, target_agent: Optional[str] = None
|
|
41
|
+
):
|
|
39
42
|
details = {}
|
|
40
43
|
if source_agent:
|
|
41
44
|
details["source_agent"] = source_agent
|
|
@@ -47,6 +50,7 @@ class HandoffError(AgentException):
|
|
|
47
50
|
# Tool execution errors
|
|
48
51
|
class ToolException(JAFException):
|
|
49
52
|
"""Base exception for tool-related errors."""
|
|
53
|
+
|
|
50
54
|
pass
|
|
51
55
|
|
|
52
56
|
|
|
@@ -54,10 +58,10 @@ class ToolExecutionError(ToolException):
|
|
|
54
58
|
"""Raised when tool execution fails."""
|
|
55
59
|
|
|
56
60
|
def __init__(self, tool_name: str, message: str, cause: Optional[Exception] = None):
|
|
57
|
-
super().__init__(
|
|
58
|
-
"tool_name
|
|
59
|
-
"cause": str(cause) if cause else None
|
|
60
|
-
|
|
61
|
+
super().__init__(
|
|
62
|
+
f"Tool '{tool_name}' execution failed: {message}",
|
|
63
|
+
{"tool_name": tool_name, "cause": str(cause) if cause else None},
|
|
64
|
+
)
|
|
61
65
|
self.tool_name = tool_name
|
|
62
66
|
self.cause = cause
|
|
63
67
|
|
|
@@ -67,10 +71,7 @@ class ToolValidationError(ToolException):
|
|
|
67
71
|
|
|
68
72
|
def __init__(self, tool_name: str, validation_errors: List[str]):
|
|
69
73
|
message = f"Tool '{tool_name}' validation failed: {'; '.join(validation_errors)}"
|
|
70
|
-
super().__init__(message, {
|
|
71
|
-
"tool_name": tool_name,
|
|
72
|
-
"validation_errors": validation_errors
|
|
73
|
-
})
|
|
74
|
+
super().__init__(message, {"tool_name": tool_name, "validation_errors": validation_errors})
|
|
74
75
|
self.tool_name = tool_name
|
|
75
76
|
self.validation_errors = validation_errors
|
|
76
77
|
|
|
@@ -78,6 +79,7 @@ class ToolValidationError(ToolException):
|
|
|
78
79
|
# Model and provider errors
|
|
79
80
|
class ModelException(JAFException):
|
|
80
81
|
"""Base exception for model-related errors."""
|
|
82
|
+
|
|
81
83
|
pass
|
|
82
84
|
|
|
83
85
|
|
|
@@ -85,10 +87,10 @@ class ModelProviderError(ModelException):
|
|
|
85
87
|
"""Raised when model provider encounters an error."""
|
|
86
88
|
|
|
87
89
|
def __init__(self, provider: str, message: str, status_code: Optional[int] = None):
|
|
88
|
-
super().__init__(
|
|
89
|
-
"provider
|
|
90
|
-
"status_code": status_code
|
|
91
|
-
|
|
90
|
+
super().__init__(
|
|
91
|
+
f"Model provider '{provider}' error: {message}",
|
|
92
|
+
{"provider": provider, "status_code": status_code},
|
|
93
|
+
)
|
|
92
94
|
self.provider = provider
|
|
93
95
|
self.status_code = status_code
|
|
94
96
|
|
|
@@ -97,26 +99,27 @@ class ModelResponseError(ModelException):
|
|
|
97
99
|
"""Raised when model response is invalid or cannot be parsed."""
|
|
98
100
|
|
|
99
101
|
def __init__(self, message: str, raw_response: Optional[str] = None):
|
|
100
|
-
super().__init__(f"Model response error: {message}", {
|
|
101
|
-
"raw_response": raw_response
|
|
102
|
-
})
|
|
102
|
+
super().__init__(f"Model response error: {message}", {"raw_response": raw_response})
|
|
103
103
|
self.raw_response = raw_response
|
|
104
104
|
|
|
105
105
|
|
|
106
106
|
# Validation and guardrail errors
|
|
107
107
|
class ValidationException(JAFException):
|
|
108
108
|
"""Base exception for validation errors."""
|
|
109
|
+
|
|
109
110
|
pass
|
|
110
111
|
|
|
111
112
|
|
|
112
113
|
class GuardrailViolationError(ValidationException):
|
|
113
114
|
"""Raised when input or output violates a guardrail."""
|
|
114
115
|
|
|
115
|
-
def __init__(
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
116
|
+
def __init__(
|
|
117
|
+
self, guardrail_type: str, message: str, violation_details: Optional[Dict[str, Any]] = None
|
|
118
|
+
):
|
|
119
|
+
super().__init__(
|
|
120
|
+
f"Guardrail violation ({guardrail_type}): {message}",
|
|
121
|
+
{"guardrail_type": guardrail_type, "violation_details": violation_details or {}},
|
|
122
|
+
)
|
|
120
123
|
self.guardrail_type = guardrail_type
|
|
121
124
|
self.violation_details = violation_details or {}
|
|
122
125
|
|
|
@@ -125,15 +128,14 @@ class InputValidationError(ValidationException):
|
|
|
125
128
|
"""Raised when input validation fails."""
|
|
126
129
|
|
|
127
130
|
def __init__(self, message: str, field_errors: Optional[Dict[str, List[str]]] = None):
|
|
128
|
-
super().__init__(f"Input validation error: {message}", {
|
|
129
|
-
"field_errors": field_errors or {}
|
|
130
|
-
})
|
|
131
|
+
super().__init__(f"Input validation error: {message}", {"field_errors": field_errors or {}})
|
|
131
132
|
self.field_errors = field_errors or {}
|
|
132
133
|
|
|
133
134
|
|
|
134
135
|
# Memory system errors
|
|
135
136
|
class MemoryException(JAFException):
|
|
136
137
|
"""Base exception for memory system errors."""
|
|
138
|
+
|
|
137
139
|
pass
|
|
138
140
|
|
|
139
141
|
|
|
@@ -141,9 +143,7 @@ class MemoryConnectionError(MemoryException):
|
|
|
141
143
|
"""Raised when memory provider connection fails."""
|
|
142
144
|
|
|
143
145
|
def __init__(self, provider: str, message: str):
|
|
144
|
-
super().__init__(f"Memory connection error ({provider}): {message}", {
|
|
145
|
-
"provider": provider
|
|
146
|
-
})
|
|
146
|
+
super().__init__(f"Memory connection error ({provider}): {message}", {"provider": provider})
|
|
147
147
|
self.provider = provider
|
|
148
148
|
|
|
149
149
|
|
|
@@ -162,6 +162,7 @@ class MemoryStorageError(MemoryException):
|
|
|
162
162
|
# Session and workflow errors
|
|
163
163
|
class SessionException(JAFException):
|
|
164
164
|
"""Base exception for session-related errors."""
|
|
165
|
+
|
|
165
166
|
pass
|
|
166
167
|
|
|
167
168
|
|
|
@@ -180,10 +181,10 @@ class MaxTurnsExceededError(SessionException):
|
|
|
180
181
|
"""Raised when maximum number of turns is exceeded."""
|
|
181
182
|
|
|
182
183
|
def __init__(self, max_turns: int, current_turns: int):
|
|
183
|
-
super().__init__(
|
|
184
|
-
"
|
|
185
|
-
"current_turns": current_turns
|
|
186
|
-
|
|
184
|
+
super().__init__(
|
|
185
|
+
f"Maximum turns exceeded: {current_turns}/{max_turns}",
|
|
186
|
+
{"max_turns": max_turns, "current_turns": current_turns},
|
|
187
|
+
)
|
|
187
188
|
self.max_turns = max_turns
|
|
188
189
|
self.current_turns = current_turns
|
|
189
190
|
|
|
@@ -191,13 +192,16 @@ class MaxTurnsExceededError(SessionException):
|
|
|
191
192
|
# A2A protocol errors
|
|
192
193
|
class A2AException(JAFException):
|
|
193
194
|
"""Base exception for A2A protocol errors."""
|
|
195
|
+
|
|
194
196
|
pass
|
|
195
197
|
|
|
196
198
|
|
|
197
199
|
class A2AProtocolError(A2AException):
|
|
198
200
|
"""Raised when A2A protocol operation fails."""
|
|
199
201
|
|
|
200
|
-
def __init__(
|
|
202
|
+
def __init__(
|
|
203
|
+
self, message: str, method: Optional[str] = None, context_id: Optional[str] = None
|
|
204
|
+
):
|
|
201
205
|
details = {}
|
|
202
206
|
if method:
|
|
203
207
|
details["method"] = method
|
|
@@ -222,6 +226,7 @@ class A2ATaskError(A2AException):
|
|
|
222
226
|
# Configuration errors
|
|
223
227
|
class ConfigurationException(JAFException):
|
|
224
228
|
"""Base exception for configuration errors."""
|
|
229
|
+
|
|
225
230
|
pass
|
|
226
231
|
|
|
227
232
|
|
|
@@ -229,10 +234,10 @@ class InvalidConfigurationError(ConfigurationException):
|
|
|
229
234
|
"""Raised when configuration is invalid."""
|
|
230
235
|
|
|
231
236
|
def __init__(self, config_type: str, message: str, config_errors: Optional[List[str]] = None):
|
|
232
|
-
super().__init__(
|
|
233
|
-
"config_type
|
|
234
|
-
"config_errors": config_errors or []
|
|
235
|
-
|
|
237
|
+
super().__init__(
|
|
238
|
+
f"Invalid {config_type} configuration: {message}",
|
|
239
|
+
{"config_type": config_type, "config_errors": config_errors or []},
|
|
240
|
+
)
|
|
236
241
|
self.config_type = config_type
|
|
237
242
|
self.config_errors = config_errors or []
|
|
238
243
|
|
|
@@ -245,7 +250,9 @@ def create_agent_error(message: str, agent_name: Optional[str] = None) -> AgentE
|
|
|
245
250
|
return AgentException(message)
|
|
246
251
|
|
|
247
252
|
|
|
248
|
-
def create_tool_error(
|
|
253
|
+
def create_tool_error(
|
|
254
|
+
tool_name: str, message: str, cause: Optional[Exception] = None
|
|
255
|
+
) -> ToolException:
|
|
249
256
|
"""Create a tool-related error."""
|
|
250
257
|
if "validation" in message.lower():
|
|
251
258
|
return ToolValidationError(tool_name, [message])
|
|
@@ -257,7 +264,8 @@ def create_session_error(message: str, session_id: Optional[str] = None) -> Sess
|
|
|
257
264
|
if "max turns" in message.lower() or "maximum turns" in message.lower():
|
|
258
265
|
# Extract numbers if possible
|
|
259
266
|
import re
|
|
260
|
-
|
|
267
|
+
|
|
268
|
+
numbers = re.findall(r"\d+", message)
|
|
261
269
|
if len(numbers) >= 2:
|
|
262
270
|
return MaxTurnsExceededError(int(numbers[0]), int(numbers[1]))
|
|
263
271
|
return MaxTurnsExceededError(10, 10) # Default fallback
|
jaf/memory/__init__.py
CHANGED
|
@@ -37,29 +37,24 @@ __all__ = [
|
|
|
37
37
|
"MemoryProvider",
|
|
38
38
|
"MemoryQuery",
|
|
39
39
|
"MemoryConfig",
|
|
40
|
-
|
|
41
40
|
# Result types
|
|
42
41
|
"Result",
|
|
43
42
|
"Success",
|
|
44
43
|
"Failure",
|
|
45
|
-
|
|
46
44
|
# Configuration types
|
|
47
45
|
"InMemoryConfig",
|
|
48
46
|
"RedisConfig",
|
|
49
47
|
"PostgresConfig",
|
|
50
48
|
"MemoryProviderConfig",
|
|
51
|
-
|
|
52
49
|
# Error types
|
|
53
50
|
"MemoryError",
|
|
54
51
|
"MemoryConnectionError",
|
|
55
52
|
"MemoryNotFoundError",
|
|
56
53
|
"MemoryStorageError",
|
|
57
|
-
|
|
58
54
|
# Factory functions
|
|
59
55
|
"create_memory_provider_from_env",
|
|
60
|
-
|
|
61
56
|
# Provider factories
|
|
62
57
|
"create_in_memory_provider",
|
|
63
58
|
"create_redis_provider",
|
|
64
|
-
"create_postgres_provider"
|
|
59
|
+
"create_postgres_provider",
|
|
65
60
|
]
|
jaf/memory/approval_storage.py
CHANGED
|
@@ -15,19 +15,19 @@ from ..core.types import RunId, ApprovalValue
|
|
|
15
15
|
|
|
16
16
|
class ApprovalStorageResult:
|
|
17
17
|
"""Result wrapper for approval storage operations."""
|
|
18
|
-
|
|
18
|
+
|
|
19
19
|
def __init__(self, success: bool, data: Any = None, error: Optional[str] = None):
|
|
20
20
|
self.success = success
|
|
21
21
|
self.data = data
|
|
22
22
|
self.error = error
|
|
23
23
|
|
|
24
24
|
@classmethod
|
|
25
|
-
def success_result(cls, data: Any = None) ->
|
|
25
|
+
def success_result(cls, data: Any = None) -> "ApprovalStorageResult":
|
|
26
26
|
"""Create a successful result."""
|
|
27
27
|
return cls(success=True, data=data)
|
|
28
28
|
|
|
29
29
|
@classmethod
|
|
30
|
-
def error_result(cls, error: str) ->
|
|
30
|
+
def error_result(cls, error: str) -> "ApprovalStorageResult":
|
|
31
31
|
"""Create an error result."""
|
|
32
32
|
return cls(success=False, error=error)
|
|
33
33
|
|
|
@@ -41,44 +41,30 @@ class ApprovalStorage(ABC):
|
|
|
41
41
|
run_id: RunId,
|
|
42
42
|
tool_call_id: str,
|
|
43
43
|
approval: ApprovalValue,
|
|
44
|
-
metadata: Optional[Dict[str, Any]] = None
|
|
44
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
45
45
|
) -> ApprovalStorageResult:
|
|
46
46
|
"""Store an approval decision for a tool call."""
|
|
47
47
|
pass
|
|
48
48
|
|
|
49
49
|
@abstractmethod
|
|
50
|
-
async def get_approval(
|
|
51
|
-
self,
|
|
52
|
-
run_id: RunId,
|
|
53
|
-
tool_call_id: str
|
|
54
|
-
) -> ApprovalStorageResult:
|
|
50
|
+
async def get_approval(self, run_id: RunId, tool_call_id: str) -> ApprovalStorageResult:
|
|
55
51
|
"""Retrieve approval for a specific tool call. Returns None if not found."""
|
|
56
52
|
pass
|
|
57
53
|
|
|
58
54
|
@abstractmethod
|
|
59
|
-
async def get_run_approvals(
|
|
60
|
-
self,
|
|
61
|
-
run_id: RunId
|
|
62
|
-
) -> ApprovalStorageResult:
|
|
55
|
+
async def get_run_approvals(self, run_id: RunId) -> ApprovalStorageResult:
|
|
63
56
|
"""Get all approvals for a run as a Dict[str, ApprovalValue]."""
|
|
64
57
|
pass
|
|
65
58
|
|
|
66
59
|
@abstractmethod
|
|
67
60
|
async def update_approval(
|
|
68
|
-
self,
|
|
69
|
-
run_id: RunId,
|
|
70
|
-
tool_call_id: str,
|
|
71
|
-
updates: Dict[str, Any]
|
|
61
|
+
self, run_id: RunId, tool_call_id: str, updates: Dict[str, Any]
|
|
72
62
|
) -> ApprovalStorageResult:
|
|
73
63
|
"""Update existing approval with additional context."""
|
|
74
64
|
pass
|
|
75
65
|
|
|
76
66
|
@abstractmethod
|
|
77
|
-
async def delete_approval(
|
|
78
|
-
self,
|
|
79
|
-
run_id: RunId,
|
|
80
|
-
tool_call_id: str
|
|
81
|
-
) -> ApprovalStorageResult:
|
|
67
|
+
async def delete_approval(self, run_id: RunId, tool_call_id: str) -> ApprovalStorageResult:
|
|
82
68
|
"""Delete approval for a tool call. Returns success status."""
|
|
83
69
|
pass
|
|
84
70
|
|
|
@@ -119,16 +105,16 @@ class InMemoryApprovalStorage(ApprovalStorage):
|
|
|
119
105
|
run_id: RunId,
|
|
120
106
|
tool_call_id: str,
|
|
121
107
|
approval: ApprovalValue,
|
|
122
|
-
metadata: Optional[Dict[str, Any]] = None
|
|
108
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
123
109
|
) -> ApprovalStorageResult:
|
|
124
110
|
"""Store an approval decision."""
|
|
125
111
|
try:
|
|
126
112
|
async with self._lock:
|
|
127
113
|
run_key = self._get_run_key(run_id)
|
|
128
|
-
|
|
114
|
+
|
|
129
115
|
if run_key not in self._approvals:
|
|
130
116
|
self._approvals[run_key] = {}
|
|
131
|
-
|
|
117
|
+
|
|
132
118
|
# Enhance approval with metadata if provided
|
|
133
119
|
enhanced_approval = approval
|
|
134
120
|
if metadata:
|
|
@@ -136,27 +122,23 @@ class InMemoryApprovalStorage(ApprovalStorage):
|
|
|
136
122
|
enhanced_approval = ApprovalValue(
|
|
137
123
|
status=approval.status,
|
|
138
124
|
approved=approval.approved,
|
|
139
|
-
additional_context=additional_context
|
|
125
|
+
additional_context=additional_context,
|
|
140
126
|
)
|
|
141
|
-
|
|
127
|
+
|
|
142
128
|
self._approvals[run_key][tool_call_id] = enhanced_approval
|
|
143
|
-
|
|
129
|
+
|
|
144
130
|
return ApprovalStorageResult.success_result()
|
|
145
131
|
except Exception as e:
|
|
146
132
|
return ApprovalStorageResult.error_result(f"Failed to store approval: {e}")
|
|
147
133
|
|
|
148
|
-
async def get_approval(
|
|
149
|
-
self,
|
|
150
|
-
run_id: RunId,
|
|
151
|
-
tool_call_id: str
|
|
152
|
-
) -> ApprovalStorageResult:
|
|
134
|
+
async def get_approval(self, run_id: RunId, tool_call_id: str) -> ApprovalStorageResult:
|
|
153
135
|
"""Retrieve approval for a specific tool call."""
|
|
154
136
|
try:
|
|
155
137
|
async with self._lock:
|
|
156
138
|
run_key = self._get_run_key(run_id)
|
|
157
139
|
run_approvals = self._approvals.get(run_key, {})
|
|
158
140
|
approval = run_approvals.get(tool_call_id)
|
|
159
|
-
|
|
141
|
+
|
|
160
142
|
return ApprovalStorageResult.success_result(approval)
|
|
161
143
|
except Exception as e:
|
|
162
144
|
return ApprovalStorageResult.error_result(f"Failed to get approval: {e}")
|
|
@@ -167,63 +149,61 @@ class InMemoryApprovalStorage(ApprovalStorage):
|
|
|
167
149
|
async with self._lock:
|
|
168
150
|
run_key = self._get_run_key(run_id)
|
|
169
151
|
run_approvals = self._approvals.get(run_key, {}).copy()
|
|
170
|
-
|
|
152
|
+
|
|
171
153
|
return ApprovalStorageResult.success_result(run_approvals)
|
|
172
154
|
except Exception as e:
|
|
173
155
|
return ApprovalStorageResult.error_result(f"Failed to get run approvals: {e}")
|
|
174
156
|
|
|
175
157
|
async def update_approval(
|
|
176
|
-
self,
|
|
177
|
-
run_id: RunId,
|
|
178
|
-
tool_call_id: str,
|
|
179
|
-
updates: Dict[str, Any]
|
|
158
|
+
self, run_id: RunId, tool_call_id: str, updates: Dict[str, Any]
|
|
180
159
|
) -> ApprovalStorageResult:
|
|
181
160
|
"""Update existing approval."""
|
|
182
161
|
try:
|
|
183
162
|
async with self._lock:
|
|
184
163
|
run_key = self._get_run_key(run_id)
|
|
185
|
-
|
|
164
|
+
|
|
186
165
|
if run_key not in self._approvals or tool_call_id not in self._approvals[run_key]:
|
|
187
166
|
return ApprovalStorageResult.error_result(
|
|
188
167
|
f"Approval not found for tool call {tool_call_id} in run {run_id}"
|
|
189
168
|
)
|
|
190
|
-
|
|
169
|
+
|
|
191
170
|
existing = self._approvals[run_key][tool_call_id]
|
|
192
|
-
|
|
171
|
+
|
|
193
172
|
# Merge additional context
|
|
194
|
-
merged_context = {
|
|
195
|
-
|
|
173
|
+
merged_context = {
|
|
174
|
+
**(existing.additional_context or {}),
|
|
175
|
+
**(updates.get("additional_context", {})),
|
|
176
|
+
}
|
|
177
|
+
|
|
196
178
|
updated_approval = ApprovalValue(
|
|
197
|
-
status=updates.get(
|
|
198
|
-
approved=updates.get(
|
|
199
|
-
additional_context=merged_context
|
|
179
|
+
status=updates.get("status", existing.status),
|
|
180
|
+
approved=updates.get("approved", existing.approved),
|
|
181
|
+
additional_context=merged_context
|
|
182
|
+
if merged_context
|
|
183
|
+
else existing.additional_context,
|
|
200
184
|
)
|
|
201
|
-
|
|
185
|
+
|
|
202
186
|
self._approvals[run_key][tool_call_id] = updated_approval
|
|
203
|
-
|
|
187
|
+
|
|
204
188
|
return ApprovalStorageResult.success_result()
|
|
205
189
|
except Exception as e:
|
|
206
190
|
return ApprovalStorageResult.error_result(f"Failed to update approval: {e}")
|
|
207
191
|
|
|
208
|
-
async def delete_approval(
|
|
209
|
-
self,
|
|
210
|
-
run_id: RunId,
|
|
211
|
-
tool_call_id: str
|
|
212
|
-
) -> ApprovalStorageResult:
|
|
192
|
+
async def delete_approval(self, run_id: RunId, tool_call_id: str) -> ApprovalStorageResult:
|
|
213
193
|
"""Delete approval for a tool call."""
|
|
214
194
|
try:
|
|
215
195
|
async with self._lock:
|
|
216
196
|
run_key = self._get_run_key(run_id)
|
|
217
|
-
|
|
197
|
+
|
|
218
198
|
if run_key not in self._approvals:
|
|
219
199
|
return ApprovalStorageResult.success_result(False)
|
|
220
|
-
|
|
200
|
+
|
|
221
201
|
deleted = self._approvals[run_key].pop(tool_call_id, None) is not None
|
|
222
|
-
|
|
202
|
+
|
|
223
203
|
# Clean up empty run maps
|
|
224
204
|
if not self._approvals[run_key]:
|
|
225
205
|
del self._approvals[run_key]
|
|
226
|
-
|
|
206
|
+
|
|
227
207
|
return ApprovalStorageResult.success_result(deleted)
|
|
228
208
|
except Exception as e:
|
|
229
209
|
return ApprovalStorageResult.error_result(f"Failed to delete approval: {e}")
|
|
@@ -233,13 +213,13 @@ class InMemoryApprovalStorage(ApprovalStorage):
|
|
|
233
213
|
try:
|
|
234
214
|
async with self._lock:
|
|
235
215
|
run_key = self._get_run_key(run_id)
|
|
236
|
-
|
|
216
|
+
|
|
237
217
|
if run_key not in self._approvals:
|
|
238
218
|
return ApprovalStorageResult.success_result(0)
|
|
239
|
-
|
|
219
|
+
|
|
240
220
|
count = len(self._approvals[run_key])
|
|
241
221
|
del self._approvals[run_key]
|
|
242
|
-
|
|
222
|
+
|
|
243
223
|
return ApprovalStorageResult.success_result(count)
|
|
244
224
|
except Exception as e:
|
|
245
225
|
return ApprovalStorageResult.error_result(f"Failed to clear run approvals: {e}")
|
|
@@ -252,7 +232,7 @@ class InMemoryApprovalStorage(ApprovalStorage):
|
|
|
252
232
|
approved_count = 0
|
|
253
233
|
rejected_count = 0
|
|
254
234
|
runs_with_approvals = len(self._approvals)
|
|
255
|
-
|
|
235
|
+
|
|
256
236
|
for run_approvals in self._approvals.values():
|
|
257
237
|
for approval in run_approvals.values():
|
|
258
238
|
total_approvals += 1
|
|
@@ -260,14 +240,14 @@ class InMemoryApprovalStorage(ApprovalStorage):
|
|
|
260
240
|
approved_count += 1
|
|
261
241
|
else:
|
|
262
242
|
rejected_count += 1
|
|
263
|
-
|
|
243
|
+
|
|
264
244
|
stats = {
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
245
|
+
"total_approvals": total_approvals,
|
|
246
|
+
"approved_count": approved_count,
|
|
247
|
+
"rejected_count": rejected_count,
|
|
248
|
+
"runs_with_approvals": runs_with_approvals,
|
|
269
249
|
}
|
|
270
|
-
|
|
250
|
+
|
|
271
251
|
return ApprovalStorageResult.success_result(stats)
|
|
272
252
|
except Exception as e:
|
|
273
253
|
return ApprovalStorageResult.error_result(f"Failed to get stats: {e}")
|
|
@@ -277,18 +257,15 @@ class InMemoryApprovalStorage(ApprovalStorage):
|
|
|
277
257
|
try:
|
|
278
258
|
# Simple operation to test functionality
|
|
279
259
|
await asyncio.sleep(0.001) # Minimal async operation
|
|
280
|
-
|
|
260
|
+
|
|
281
261
|
health_data = {
|
|
282
|
-
|
|
283
|
-
|
|
262
|
+
"healthy": True,
|
|
263
|
+
"latency_ms": 1.0, # Approximate for in-memory
|
|
284
264
|
}
|
|
285
|
-
|
|
265
|
+
|
|
286
266
|
return ApprovalStorageResult.success_result(health_data)
|
|
287
267
|
except Exception as e:
|
|
288
|
-
health_data = {
|
|
289
|
-
'healthy': False,
|
|
290
|
-
'error': str(e)
|
|
291
|
-
}
|
|
268
|
+
health_data = {"healthy": False, "error": str(e)}
|
|
292
269
|
return ApprovalStorageResult.success_result(health_data)
|
|
293
270
|
|
|
294
271
|
async def close(self) -> ApprovalStorageResult:
|
|
@@ -303,4 +280,4 @@ class InMemoryApprovalStorage(ApprovalStorage):
|
|
|
303
280
|
|
|
304
281
|
def create_in_memory_approval_storage() -> InMemoryApprovalStorage:
|
|
305
282
|
"""Create an in-memory approval storage instance."""
|
|
306
|
-
return InMemoryApprovalStorage()
|
|
283
|
+
return InMemoryApprovalStorage()
|
jaf/memory/factory.py
CHANGED
|
@@ -24,7 +24,7 @@ from .types import (
|
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
async def create_memory_provider_from_env(
|
|
27
|
-
external_clients: Optional[Dict[str, Any]] = None
|
|
27
|
+
external_clients: Optional[Dict[str, Any]] = None,
|
|
28
28
|
) -> Result[MemoryProvider, MemoryConnectionError]:
|
|
29
29
|
"""
|
|
30
30
|
Create a memory provider based on environment variables.
|
|
@@ -35,7 +35,7 @@ async def create_memory_provider_from_env(
|
|
|
35
35
|
if memory_type == "memory":
|
|
36
36
|
config = InMemoryConfig(
|
|
37
37
|
max_conversations=int(os.getenv("JAF_MEMORY_MAX_CONVERSATIONS", "1000")),
|
|
38
|
-
max_messages_per_conversation=int(os.getenv("JAF_MEMORY_MAX_MESSAGES", "1000"))
|
|
38
|
+
max_messages_per_conversation=int(os.getenv("JAF_MEMORY_MAX_MESSAGES", "1000")),
|
|
39
39
|
)
|
|
40
40
|
return Success(create_in_memory_provider(config))
|
|
41
41
|
|
|
@@ -47,7 +47,7 @@ async def create_memory_provider_from_env(
|
|
|
47
47
|
"port": int(os.getenv("JAF_REDIS_PORT", "6379")),
|
|
48
48
|
"db": int(os.getenv("JAF_REDIS_DB", "0")),
|
|
49
49
|
"key_prefix": os.getenv("JAF_REDIS_PREFIX", "jaf:memory:"),
|
|
50
|
-
"ttl": int(os.getenv("JAF_REDIS_TTL")) if os.getenv("JAF_REDIS_TTL") else None
|
|
50
|
+
"ttl": int(os.getenv("JAF_REDIS_TTL")) if os.getenv("JAF_REDIS_TTL") else None,
|
|
51
51
|
}
|
|
52
52
|
if redis_password:
|
|
53
53
|
config_data["password"] = redis_password
|
|
@@ -65,7 +65,7 @@ async def create_memory_provider_from_env(
|
|
|
65
65
|
"password": os.getenv("JAF_POSTGRES_PASSWORD"),
|
|
66
66
|
"ssl": os.getenv("JAF_POSTGRES_SSL", "false").lower() == "true",
|
|
67
67
|
"table_name": os.getenv("JAF_POSTGRES_TABLE", "conversations"),
|
|
68
|
-
"max_connections": int(os.getenv("JAF_POSTGRES_MAX_CONNECTIONS", "10"))
|
|
68
|
+
"max_connections": int(os.getenv("JAF_POSTGRES_MAX_CONNECTIONS", "10")),
|
|
69
69
|
}
|
|
70
70
|
if connection_string:
|
|
71
71
|
config_data["connection_string"] = connection_string
|