jaf-py 2.5.10__py3-none-any.whl → 2.5.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +154 -57
- jaf/a2a/__init__.py +42 -21
- jaf/a2a/agent.py +79 -126
- jaf/a2a/agent_card.py +87 -78
- jaf/a2a/client.py +30 -66
- jaf/a2a/examples/client_example.py +12 -12
- jaf/a2a/examples/integration_example.py +38 -47
- jaf/a2a/examples/server_example.py +56 -53
- jaf/a2a/memory/__init__.py +0 -4
- jaf/a2a/memory/cleanup.py +28 -21
- jaf/a2a/memory/factory.py +155 -133
- jaf/a2a/memory/providers/composite.py +21 -26
- jaf/a2a/memory/providers/in_memory.py +89 -83
- jaf/a2a/memory/providers/postgres.py +117 -115
- jaf/a2a/memory/providers/redis.py +128 -121
- jaf/a2a/memory/serialization.py +77 -87
- jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
- jaf/a2a/memory/tests/test_cleanup.py +211 -94
- jaf/a2a/memory/tests/test_serialization.py +73 -68
- jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
- jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
- jaf/a2a/memory/types.py +91 -53
- jaf/a2a/protocol.py +95 -125
- jaf/a2a/server.py +90 -118
- jaf/a2a/standalone_client.py +30 -43
- jaf/a2a/tests/__init__.py +16 -33
- jaf/a2a/tests/run_tests.py +17 -53
- jaf/a2a/tests/test_agent.py +40 -140
- jaf/a2a/tests/test_client.py +54 -117
- jaf/a2a/tests/test_integration.py +28 -82
- jaf/a2a/tests/test_protocol.py +54 -139
- jaf/a2a/tests/test_types.py +50 -136
- jaf/a2a/types.py +58 -34
- jaf/cli.py +21 -41
- jaf/core/__init__.py +7 -1
- jaf/core/agent_tool.py +93 -72
- jaf/core/analytics.py +257 -207
- jaf/core/checkpoint.py +223 -0
- jaf/core/composition.py +249 -235
- jaf/core/engine.py +817 -519
- jaf/core/errors.py +55 -42
- jaf/core/guardrails.py +276 -202
- jaf/core/handoff.py +47 -31
- jaf/core/parallel_agents.py +69 -75
- jaf/core/performance.py +75 -73
- jaf/core/proxy.py +43 -44
- jaf/core/proxy_helpers.py +24 -27
- jaf/core/regeneration.py +220 -129
- jaf/core/state.py +68 -66
- jaf/core/streaming.py +115 -108
- jaf/core/tool_results.py +111 -101
- jaf/core/tools.py +114 -116
- jaf/core/tracing.py +310 -210
- jaf/core/types.py +403 -151
- jaf/core/workflows.py +209 -168
- jaf/exceptions.py +46 -38
- jaf/memory/__init__.py +1 -6
- jaf/memory/approval_storage.py +54 -77
- jaf/memory/factory.py +4 -4
- jaf/memory/providers/in_memory.py +216 -180
- jaf/memory/providers/postgres.py +216 -146
- jaf/memory/providers/redis.py +173 -116
- jaf/memory/types.py +70 -51
- jaf/memory/utils.py +36 -34
- jaf/plugins/__init__.py +12 -12
- jaf/plugins/base.py +105 -96
- jaf/policies/__init__.py +0 -1
- jaf/policies/handoff.py +37 -46
- jaf/policies/validation.py +76 -52
- jaf/providers/__init__.py +6 -3
- jaf/providers/mcp.py +97 -51
- jaf/providers/model.py +475 -283
- jaf/server/__init__.py +1 -1
- jaf/server/main.py +7 -11
- jaf/server/server.py +514 -359
- jaf/server/types.py +208 -52
- jaf/utils/__init__.py +17 -18
- jaf/utils/attachments.py +111 -116
- jaf/utils/document_processor.py +175 -174
- jaf/visualization/__init__.py +1 -1
- jaf/visualization/example.py +111 -110
- jaf/visualization/functional_core.py +46 -71
- jaf/visualization/graphviz.py +154 -189
- jaf/visualization/imperative_shell.py +7 -16
- jaf/visualization/types.py +8 -4
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/METADATA +2 -2
- jaf_py-2.5.12.dist-info/RECORD +97 -0
- jaf_py-2.5.10.dist-info/RECORD +0 -96
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/WHEEL +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.5.10.dist-info → jaf_py-2.5.12.dist-info}/top_level.txt +0 -0
jaf/memory/providers/redis.py
CHANGED
|
@@ -24,10 +24,12 @@ from ..types import (
|
|
|
24
24
|
|
|
25
25
|
try:
|
|
26
26
|
import redis.asyncio as redis
|
|
27
|
+
|
|
27
28
|
RedisClient = redis.Redis
|
|
28
29
|
except ImportError:
|
|
29
30
|
RedisClient = Any
|
|
30
31
|
|
|
32
|
+
|
|
31
33
|
class RedisProvider(MemoryProvider):
|
|
32
34
|
"""
|
|
33
35
|
Redis implementation of MemoryProvider.
|
|
@@ -45,18 +47,20 @@ class RedisProvider(MemoryProvider):
|
|
|
45
47
|
def _serialize(self, conversation: ConversationMemory) -> str:
|
|
46
48
|
"""Serialize conversation using shared utilities."""
|
|
47
49
|
from ..utils import serialize_conversation_for_json
|
|
50
|
+
|
|
48
51
|
return serialize_conversation_for_json(conversation)
|
|
49
52
|
|
|
50
53
|
def _deserialize(self, data: str) -> ConversationMemory:
|
|
51
54
|
"""Deserialize conversation using shared utilities."""
|
|
52
55
|
from ..utils import deserialize_conversation_from_json
|
|
56
|
+
|
|
53
57
|
return deserialize_conversation_from_json(data)
|
|
54
58
|
|
|
55
59
|
async def store_messages(
|
|
56
60
|
self,
|
|
57
61
|
conversation_id: str,
|
|
58
62
|
messages: List[Message],
|
|
59
|
-
metadata: Optional[Dict[str, Any]] = None
|
|
63
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
60
64
|
) -> Result[None, MemoryStorageError]:
|
|
61
65
|
try:
|
|
62
66
|
now = datetime.now()
|
|
@@ -69,18 +73,21 @@ class RedisProvider(MemoryProvider):
|
|
|
69
73
|
"updated_at": now,
|
|
70
74
|
"total_messages": len(messages),
|
|
71
75
|
"last_activity": now,
|
|
72
|
-
**(metadata or {})
|
|
73
|
-
}
|
|
76
|
+
**(metadata or {}),
|
|
77
|
+
},
|
|
74
78
|
)
|
|
75
79
|
key = self._get_key(conversation_id)
|
|
76
80
|
await self.redis_client.set(key, self._serialize(conversation), ex=self.config.ttl)
|
|
77
81
|
return Success(None)
|
|
78
82
|
except Exception as e:
|
|
79
|
-
return Failure(
|
|
83
|
+
return Failure(
|
|
84
|
+
MemoryStorageError(
|
|
85
|
+
operation="store_messages", provider="Redis", message=str(e), cause=e
|
|
86
|
+
)
|
|
87
|
+
)
|
|
80
88
|
|
|
81
89
|
async def get_conversation(
|
|
82
|
-
self,
|
|
83
|
-
conversation_id: str
|
|
90
|
+
self, conversation_id: str
|
|
84
91
|
) -> Result[Optional[ConversationMemory], MemoryStorageError]:
|
|
85
92
|
try:
|
|
86
93
|
key = self._get_key(conversation_id)
|
|
@@ -94,13 +101,17 @@ class RedisProvider(MemoryProvider):
|
|
|
94
101
|
await self.redis_client.set(key, self._serialize(conversation), ex=self.config.ttl)
|
|
95
102
|
return Success(conversation)
|
|
96
103
|
except Exception as e:
|
|
97
|
-
return Failure(
|
|
104
|
+
return Failure(
|
|
105
|
+
MemoryStorageError(
|
|
106
|
+
operation="get_conversation", provider="Redis", message=str(e), cause=e
|
|
107
|
+
)
|
|
108
|
+
)
|
|
98
109
|
|
|
99
110
|
async def append_messages(
|
|
100
111
|
self,
|
|
101
112
|
conversation_id: str,
|
|
102
113
|
messages: List[Message],
|
|
103
|
-
metadata: Optional[Dict[str, Any]] = None
|
|
114
|
+
metadata: Optional[Dict[str, Any]] = None,
|
|
104
115
|
) -> Result[None, Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
105
116
|
result = await self.get_conversation(conversation_id)
|
|
106
117
|
if isinstance(result, Failure):
|
|
@@ -108,7 +119,13 @@ class RedisProvider(MemoryProvider):
|
|
|
108
119
|
|
|
109
120
|
existing = result.data
|
|
110
121
|
if not existing:
|
|
111
|
-
return Failure(
|
|
122
|
+
return Failure(
|
|
123
|
+
MemoryNotFoundError(
|
|
124
|
+
conversation_id=conversation_id,
|
|
125
|
+
provider="Redis",
|
|
126
|
+
message=f"Conversation {conversation_id} not found",
|
|
127
|
+
)
|
|
128
|
+
)
|
|
112
129
|
|
|
113
130
|
# Convert tuple back to list, append new messages, then store
|
|
114
131
|
all_messages = list(existing.messages) + messages
|
|
@@ -117,14 +134,13 @@ class RedisProvider(MemoryProvider):
|
|
|
117
134
|
"updated_at": datetime.now(),
|
|
118
135
|
"last_activity": datetime.now(),
|
|
119
136
|
"total_messages": len(all_messages),
|
|
120
|
-
**(metadata or {})
|
|
137
|
+
**(metadata or {}),
|
|
121
138
|
}
|
|
122
139
|
|
|
123
140
|
return await self.store_messages(conversation_id, all_messages, updated_metadata)
|
|
124
141
|
|
|
125
142
|
async def find_conversations(
|
|
126
|
-
self,
|
|
127
|
-
query: MemoryQuery
|
|
143
|
+
self, query: MemoryQuery
|
|
128
144
|
) -> Result[List[ConversationMemory], MemoryStorageError]:
|
|
129
145
|
try:
|
|
130
146
|
keys = await self.redis_client.keys(f"{self.config.key_prefix}*")
|
|
@@ -144,12 +160,14 @@ class RedisProvider(MemoryProvider):
|
|
|
144
160
|
conversations.append(conv)
|
|
145
161
|
return Success(conversations)
|
|
146
162
|
except Exception as e:
|
|
147
|
-
return Failure(
|
|
163
|
+
return Failure(
|
|
164
|
+
MemoryStorageError(
|
|
165
|
+
operation="find_conversations", provider="Redis", message=str(e), cause=e
|
|
166
|
+
)
|
|
167
|
+
)
|
|
148
168
|
|
|
149
169
|
async def get_recent_messages(
|
|
150
|
-
self,
|
|
151
|
-
conversation_id: str,
|
|
152
|
-
limit: int = 50
|
|
170
|
+
self, conversation_id: str, limit: int = 50
|
|
153
171
|
) -> Result[List[Message], Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
154
172
|
result = await self.get_conversation(conversation_id)
|
|
155
173
|
if isinstance(result, Failure):
|
|
@@ -157,36 +175,47 @@ class RedisProvider(MemoryProvider):
|
|
|
157
175
|
|
|
158
176
|
conversation = result.data
|
|
159
177
|
if not conversation:
|
|
160
|
-
return Failure(
|
|
178
|
+
return Failure(
|
|
179
|
+
MemoryNotFoundError(
|
|
180
|
+
conversation_id=conversation_id,
|
|
181
|
+
provider="Redis",
|
|
182
|
+
message=f"Conversation {conversation_id} not found",
|
|
183
|
+
)
|
|
184
|
+
)
|
|
161
185
|
|
|
162
186
|
return Success(conversation.messages[-limit:])
|
|
163
187
|
|
|
164
|
-
async def delete_conversation(
|
|
165
|
-
self,
|
|
166
|
-
conversation_id: str
|
|
167
|
-
) -> Result[bool, MemoryStorageError]:
|
|
188
|
+
async def delete_conversation(self, conversation_id: str) -> Result[bool, MemoryStorageError]:
|
|
168
189
|
try:
|
|
169
190
|
deleted = await self.redis_client.delete(self._get_key(conversation_id))
|
|
170
191
|
return Success(deleted > 0)
|
|
171
192
|
except Exception as e:
|
|
172
|
-
return Failure(
|
|
193
|
+
return Failure(
|
|
194
|
+
MemoryStorageError(
|
|
195
|
+
operation="delete_conversation", provider="Redis", message=str(e), cause=e
|
|
196
|
+
)
|
|
197
|
+
)
|
|
173
198
|
|
|
174
|
-
async def clear_user_conversations(
|
|
175
|
-
self,
|
|
176
|
-
user_id: str
|
|
177
|
-
) -> Result[int, MemoryStorageError]:
|
|
199
|
+
async def clear_user_conversations(self, user_id: str) -> Result[int, MemoryStorageError]:
|
|
178
200
|
# This is inefficient in Redis, consider a different approach for production
|
|
179
|
-
return Failure(
|
|
201
|
+
return Failure(
|
|
202
|
+
MemoryStorageError(
|
|
203
|
+
operation="clear_user_conversations",
|
|
204
|
+
provider="Redis",
|
|
205
|
+
message="clear_user_conversations not efficiently supported",
|
|
206
|
+
)
|
|
207
|
+
)
|
|
180
208
|
|
|
181
209
|
async def get_stats(
|
|
182
|
-
self,
|
|
183
|
-
user_id: Optional[str] = None
|
|
210
|
+
self, user_id: Optional[str] = None
|
|
184
211
|
) -> Result[Dict[str, Any], MemoryStorageError]:
|
|
185
212
|
try:
|
|
186
213
|
keys = await self.redis_client.keys(f"{self.config.key_prefix}*")
|
|
187
214
|
return Success({"total_conversations": len(keys)})
|
|
188
215
|
except Exception as e:
|
|
189
|
-
return Failure(
|
|
216
|
+
return Failure(
|
|
217
|
+
MemoryStorageError(operation="get_stats", provider="Redis", message=str(e), cause=e)
|
|
218
|
+
)
|
|
190
219
|
|
|
191
220
|
async def health_check(self) -> Result[Dict[str, Any], MemoryConnectionError]:
|
|
192
221
|
start_time = datetime.now()
|
|
@@ -194,21 +223,23 @@ class RedisProvider(MemoryProvider):
|
|
|
194
223
|
await self.redis_client.ping()
|
|
195
224
|
latency_ms = (datetime.now() - start_time).total_seconds() * 1000
|
|
196
225
|
db_size = await self.redis_client.dbsize()
|
|
197
|
-
return Success(
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
"db_size": db_size
|
|
226
|
+
return Success(
|
|
227
|
+
{
|
|
228
|
+
"healthy": True,
|
|
229
|
+
"latency_ms": latency_ms,
|
|
230
|
+
"provider": "Redis",
|
|
231
|
+
"details": {"db_size": db_size},
|
|
203
232
|
}
|
|
204
|
-
|
|
233
|
+
)
|
|
205
234
|
except Exception as e:
|
|
206
|
-
return Failure(
|
|
235
|
+
return Failure(
|
|
236
|
+
MemoryConnectionError(
|
|
237
|
+
provider="Redis", message="Redis health check failed", cause=e
|
|
238
|
+
)
|
|
239
|
+
)
|
|
207
240
|
|
|
208
241
|
async def truncate_conversation_after(
|
|
209
|
-
self,
|
|
210
|
-
conversation_id: str,
|
|
211
|
-
message_id: MessageId
|
|
242
|
+
self, conversation_id: str, message_id: MessageId
|
|
212
243
|
) -> Result[int, Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
213
244
|
"""
|
|
214
245
|
Truncate conversation after (and including) the specified message ID.
|
|
@@ -219,27 +250,29 @@ class RedisProvider(MemoryProvider):
|
|
|
219
250
|
conv_result = await self.get_conversation(conversation_id)
|
|
220
251
|
if isinstance(conv_result, Failure):
|
|
221
252
|
return conv_result
|
|
222
|
-
|
|
253
|
+
|
|
223
254
|
if not conv_result.data:
|
|
224
|
-
return Failure(
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
255
|
+
return Failure(
|
|
256
|
+
MemoryNotFoundError(
|
|
257
|
+
message=f"Conversation {conversation_id} not found",
|
|
258
|
+
provider="Redis",
|
|
259
|
+
conversation_id=conversation_id,
|
|
260
|
+
)
|
|
261
|
+
)
|
|
262
|
+
|
|
230
263
|
conversation = conv_result.data
|
|
231
264
|
messages = list(conversation.messages)
|
|
232
265
|
truncate_index = find_message_index(messages, message_id)
|
|
233
|
-
|
|
266
|
+
|
|
234
267
|
if truncate_index is None:
|
|
235
268
|
# Message not found, nothing to truncate
|
|
236
269
|
return Success(0)
|
|
237
|
-
|
|
270
|
+
|
|
238
271
|
# Truncate messages from the found index onwards
|
|
239
272
|
original_count = len(messages)
|
|
240
273
|
truncated_messages = messages[:truncate_index]
|
|
241
274
|
removed_count = original_count - len(truncated_messages)
|
|
242
|
-
|
|
275
|
+
|
|
243
276
|
# Update conversation with truncated messages
|
|
244
277
|
now = datetime.now()
|
|
245
278
|
updated_metadata = {
|
|
@@ -249,35 +282,39 @@ class RedisProvider(MemoryProvider):
|
|
|
249
282
|
"total_messages": len(truncated_messages),
|
|
250
283
|
"regeneration_truncated": True,
|
|
251
284
|
"truncated_at": now.isoformat(),
|
|
252
|
-
"messages_removed": removed_count
|
|
285
|
+
"messages_removed": removed_count,
|
|
253
286
|
}
|
|
254
|
-
|
|
287
|
+
|
|
255
288
|
# Store updated conversation
|
|
256
289
|
updated_conversation = ConversationMemory(
|
|
257
290
|
conversation_id=conversation_id,
|
|
258
291
|
user_id=conversation.user_id,
|
|
259
292
|
messages=truncated_messages,
|
|
260
|
-
metadata=updated_metadata
|
|
293
|
+
metadata=updated_metadata,
|
|
261
294
|
)
|
|
262
|
-
|
|
295
|
+
|
|
263
296
|
key = self._get_key(conversation_id)
|
|
264
|
-
await self.redis_client.set(
|
|
265
|
-
|
|
266
|
-
|
|
297
|
+
await self.redis_client.set(
|
|
298
|
+
key, self._serialize(updated_conversation), ex=self.config.ttl
|
|
299
|
+
)
|
|
300
|
+
|
|
301
|
+
print(
|
|
302
|
+
f"[MEMORY:Redis] Truncated conversation {conversation_id}: removed {removed_count} messages after message {message_id}"
|
|
303
|
+
)
|
|
267
304
|
return Success(removed_count)
|
|
268
|
-
|
|
305
|
+
|
|
269
306
|
except Exception as e:
|
|
270
|
-
return Failure(
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
307
|
+
return Failure(
|
|
308
|
+
MemoryStorageError(
|
|
309
|
+
message=f"Failed to truncate conversation: {e}",
|
|
310
|
+
provider="Redis",
|
|
311
|
+
operation="truncate_conversation_after",
|
|
312
|
+
cause=e,
|
|
313
|
+
)
|
|
314
|
+
)
|
|
276
315
|
|
|
277
316
|
async def get_conversation_until_message(
|
|
278
|
-
self,
|
|
279
|
-
conversation_id: str,
|
|
280
|
-
message_id: MessageId
|
|
317
|
+
self, conversation_id: str, message_id: MessageId
|
|
281
318
|
) -> Result[Optional[ConversationMemory], Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
282
319
|
"""
|
|
283
320
|
Get conversation history up to (but not including) the specified message ID.
|
|
@@ -288,22 +325,24 @@ class RedisProvider(MemoryProvider):
|
|
|
288
325
|
conv_result = await self.get_conversation(conversation_id)
|
|
289
326
|
if isinstance(conv_result, Failure):
|
|
290
327
|
return conv_result
|
|
291
|
-
|
|
328
|
+
|
|
292
329
|
if not conv_result.data:
|
|
293
330
|
return Success(None)
|
|
294
|
-
|
|
331
|
+
|
|
295
332
|
conversation = conv_result.data
|
|
296
333
|
messages = list(conversation.messages)
|
|
297
334
|
until_index = find_message_index(messages, message_id)
|
|
298
|
-
|
|
335
|
+
|
|
299
336
|
if until_index is None:
|
|
300
337
|
# Message not found, return None as lightweight indicator
|
|
301
|
-
print(
|
|
338
|
+
print(
|
|
339
|
+
f"[MEMORY:Redis] Message {message_id} not found in conversation {conversation_id}"
|
|
340
|
+
)
|
|
302
341
|
return Success(None)
|
|
303
|
-
|
|
342
|
+
|
|
304
343
|
# Return conversation up to (but not including) the specified message
|
|
305
344
|
truncated_messages = messages[:until_index]
|
|
306
|
-
|
|
345
|
+
|
|
307
346
|
# Create a copy of the conversation with truncated messages
|
|
308
347
|
truncated_conversation = ConversationMemory(
|
|
309
348
|
conversation_id=conversation.conversation_id,
|
|
@@ -314,26 +353,27 @@ class RedisProvider(MemoryProvider):
|
|
|
314
353
|
"truncated_for_regeneration": True,
|
|
315
354
|
"truncated_until_message": str(message_id),
|
|
316
355
|
"original_message_count": len(messages),
|
|
317
|
-
"truncated_message_count": len(truncated_messages)
|
|
318
|
-
}
|
|
356
|
+
"truncated_message_count": len(truncated_messages),
|
|
357
|
+
},
|
|
358
|
+
)
|
|
359
|
+
|
|
360
|
+
print(
|
|
361
|
+
f"[MEMORY:Redis] Retrieved conversation {conversation_id} until message {message_id}: {len(truncated_messages)} messages"
|
|
319
362
|
)
|
|
320
|
-
|
|
321
|
-
print(f"[MEMORY:Redis] Retrieved conversation {conversation_id} until message {message_id}: {len(truncated_messages)} messages")
|
|
322
363
|
return Success(truncated_conversation)
|
|
323
|
-
|
|
364
|
+
|
|
324
365
|
except Exception as e:
|
|
325
|
-
return Failure(
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
366
|
+
return Failure(
|
|
367
|
+
MemoryStorageError(
|
|
368
|
+
message=f"Failed to get conversation until message: {e}",
|
|
369
|
+
provider="Redis",
|
|
370
|
+
operation="get_conversation_until_message",
|
|
371
|
+
cause=e,
|
|
372
|
+
)
|
|
373
|
+
)
|
|
331
374
|
|
|
332
375
|
async def mark_regeneration_point(
|
|
333
|
-
self,
|
|
334
|
-
conversation_id: str,
|
|
335
|
-
message_id: MessageId,
|
|
336
|
-
regeneration_metadata: Dict[str, Any]
|
|
376
|
+
self, conversation_id: str, message_id: MessageId, regeneration_metadata: Dict[str, Any]
|
|
337
377
|
) -> Result[None, Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
338
378
|
"""
|
|
339
379
|
Mark a regeneration point in the conversation for audit purposes.
|
|
@@ -343,64 +383,79 @@ class RedisProvider(MemoryProvider):
|
|
|
343
383
|
conv_result = await self.get_conversation(conversation_id)
|
|
344
384
|
if isinstance(conv_result, Failure):
|
|
345
385
|
return conv_result
|
|
346
|
-
|
|
386
|
+
|
|
347
387
|
if not conv_result.data:
|
|
348
|
-
return Failure(
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
388
|
+
return Failure(
|
|
389
|
+
MemoryNotFoundError(
|
|
390
|
+
message=f"Conversation {conversation_id} not found",
|
|
391
|
+
provider="Redis",
|
|
392
|
+
conversation_id=conversation_id,
|
|
393
|
+
)
|
|
394
|
+
)
|
|
395
|
+
|
|
354
396
|
conversation = conv_result.data
|
|
355
|
-
|
|
397
|
+
|
|
356
398
|
# Add regeneration point to metadata
|
|
357
399
|
regeneration_points = conversation.metadata.get("regeneration_points", [])
|
|
358
400
|
regeneration_point = {
|
|
359
401
|
"message_id": str(message_id),
|
|
360
402
|
"timestamp": datetime.now().isoformat(),
|
|
361
|
-
**regeneration_metadata
|
|
403
|
+
**regeneration_metadata,
|
|
362
404
|
}
|
|
363
405
|
regeneration_points.append(regeneration_point)
|
|
364
|
-
|
|
406
|
+
|
|
365
407
|
# Update conversation metadata
|
|
366
408
|
updated_metadata = {
|
|
367
409
|
**conversation.metadata,
|
|
368
410
|
"regeneration_points": regeneration_points,
|
|
369
411
|
"last_regeneration": regeneration_point,
|
|
370
412
|
"updated_at": datetime.now(),
|
|
371
|
-
"regeneration_count": len(regeneration_points)
|
|
413
|
+
"regeneration_count": len(regeneration_points),
|
|
372
414
|
}
|
|
373
|
-
|
|
415
|
+
|
|
374
416
|
# Store updated conversation
|
|
375
417
|
updated_conversation = ConversationMemory(
|
|
376
418
|
conversation_id=conversation.conversation_id,
|
|
377
419
|
user_id=conversation.user_id,
|
|
378
420
|
messages=conversation.messages,
|
|
379
|
-
metadata=updated_metadata
|
|
421
|
+
metadata=updated_metadata,
|
|
380
422
|
)
|
|
381
|
-
|
|
423
|
+
|
|
382
424
|
key = self._get_key(conversation_id)
|
|
383
|
-
await self.redis_client.set(
|
|
384
|
-
|
|
385
|
-
|
|
425
|
+
await self.redis_client.set(
|
|
426
|
+
key, self._serialize(updated_conversation), ex=self.config.ttl
|
|
427
|
+
)
|
|
428
|
+
|
|
429
|
+
print(
|
|
430
|
+
f"[MEMORY:Redis] Marked regeneration point for conversation {conversation_id} at message {message_id}"
|
|
431
|
+
)
|
|
386
432
|
return Success(None)
|
|
387
|
-
|
|
433
|
+
|
|
388
434
|
except Exception as e:
|
|
389
|
-
return Failure(
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
435
|
+
return Failure(
|
|
436
|
+
MemoryStorageError(
|
|
437
|
+
message=f"Failed to mark regeneration point: {e}",
|
|
438
|
+
provider="Redis",
|
|
439
|
+
operation="mark_regeneration_point",
|
|
440
|
+
cause=e,
|
|
441
|
+
)
|
|
442
|
+
)
|
|
395
443
|
|
|
396
444
|
async def close(self) -> Result[None, MemoryConnectionError]:
|
|
397
445
|
try:
|
|
398
446
|
await self.redis_client.aclose()
|
|
399
447
|
return Success(None)
|
|
400
448
|
except Exception as e:
|
|
401
|
-
return Failure(
|
|
449
|
+
return Failure(
|
|
450
|
+
MemoryConnectionError(
|
|
451
|
+
provider="Redis", message="Failed to close Redis connection", cause=e
|
|
452
|
+
)
|
|
453
|
+
)
|
|
454
|
+
|
|
402
455
|
|
|
403
|
-
async def create_redis_provider(
|
|
456
|
+
async def create_redis_provider(
|
|
457
|
+
config: RedisConfig,
|
|
458
|
+
) -> Result[RedisProvider, MemoryConnectionError]:
|
|
404
459
|
try:
|
|
405
460
|
# These will be passed to the Redis client constructor
|
|
406
461
|
# and will override any values parsed from the URL.
|
|
@@ -421,4 +476,6 @@ async def create_redis_provider(config: RedisConfig) -> Result[RedisProvider, Me
|
|
|
421
476
|
await redis_client.ping()
|
|
422
477
|
return Success(RedisProvider(config, redis_client))
|
|
423
478
|
except Exception as e:
|
|
424
|
-
return Failure(
|
|
479
|
+
return Failure(
|
|
480
|
+
MemoryConnectionError(provider="Redis", message="Failed to connect to Redis", cause=e)
|
|
481
|
+
)
|