jaf-py 2.5.10__py3-none-any.whl → 2.5.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92) hide show
  1. jaf/__init__.py +154 -57
  2. jaf/a2a/__init__.py +42 -21
  3. jaf/a2a/agent.py +79 -126
  4. jaf/a2a/agent_card.py +87 -78
  5. jaf/a2a/client.py +30 -66
  6. jaf/a2a/examples/client_example.py +12 -12
  7. jaf/a2a/examples/integration_example.py +38 -47
  8. jaf/a2a/examples/server_example.py +56 -53
  9. jaf/a2a/memory/__init__.py +0 -4
  10. jaf/a2a/memory/cleanup.py +28 -21
  11. jaf/a2a/memory/factory.py +155 -133
  12. jaf/a2a/memory/providers/composite.py +21 -26
  13. jaf/a2a/memory/providers/in_memory.py +89 -83
  14. jaf/a2a/memory/providers/postgres.py +117 -115
  15. jaf/a2a/memory/providers/redis.py +128 -121
  16. jaf/a2a/memory/serialization.py +77 -87
  17. jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
  18. jaf/a2a/memory/tests/test_cleanup.py +211 -94
  19. jaf/a2a/memory/tests/test_serialization.py +73 -68
  20. jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
  21. jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
  22. jaf/a2a/memory/types.py +91 -53
  23. jaf/a2a/protocol.py +95 -125
  24. jaf/a2a/server.py +90 -118
  25. jaf/a2a/standalone_client.py +30 -43
  26. jaf/a2a/tests/__init__.py +16 -33
  27. jaf/a2a/tests/run_tests.py +17 -53
  28. jaf/a2a/tests/test_agent.py +40 -140
  29. jaf/a2a/tests/test_client.py +54 -117
  30. jaf/a2a/tests/test_integration.py +28 -82
  31. jaf/a2a/tests/test_protocol.py +54 -139
  32. jaf/a2a/tests/test_types.py +50 -136
  33. jaf/a2a/types.py +58 -34
  34. jaf/cli.py +21 -41
  35. jaf/core/__init__.py +7 -1
  36. jaf/core/agent_tool.py +93 -72
  37. jaf/core/analytics.py +257 -207
  38. jaf/core/checkpoint.py +223 -0
  39. jaf/core/composition.py +249 -235
  40. jaf/core/engine.py +817 -519
  41. jaf/core/errors.py +55 -42
  42. jaf/core/guardrails.py +276 -202
  43. jaf/core/handoff.py +47 -31
  44. jaf/core/parallel_agents.py +69 -75
  45. jaf/core/performance.py +75 -73
  46. jaf/core/proxy.py +43 -44
  47. jaf/core/proxy_helpers.py +24 -27
  48. jaf/core/regeneration.py +220 -129
  49. jaf/core/state.py +68 -66
  50. jaf/core/streaming.py +115 -108
  51. jaf/core/tool_results.py +111 -101
  52. jaf/core/tools.py +114 -116
  53. jaf/core/tracing.py +269 -210
  54. jaf/core/types.py +371 -151
  55. jaf/core/workflows.py +209 -168
  56. jaf/exceptions.py +46 -38
  57. jaf/memory/__init__.py +1 -6
  58. jaf/memory/approval_storage.py +54 -77
  59. jaf/memory/factory.py +4 -4
  60. jaf/memory/providers/in_memory.py +216 -180
  61. jaf/memory/providers/postgres.py +216 -146
  62. jaf/memory/providers/redis.py +173 -116
  63. jaf/memory/types.py +70 -51
  64. jaf/memory/utils.py +36 -34
  65. jaf/plugins/__init__.py +12 -12
  66. jaf/plugins/base.py +105 -96
  67. jaf/policies/__init__.py +0 -1
  68. jaf/policies/handoff.py +37 -46
  69. jaf/policies/validation.py +76 -52
  70. jaf/providers/__init__.py +6 -3
  71. jaf/providers/mcp.py +97 -51
  72. jaf/providers/model.py +360 -279
  73. jaf/server/__init__.py +1 -1
  74. jaf/server/main.py +7 -11
  75. jaf/server/server.py +514 -359
  76. jaf/server/types.py +208 -52
  77. jaf/utils/__init__.py +17 -18
  78. jaf/utils/attachments.py +111 -116
  79. jaf/utils/document_processor.py +175 -174
  80. jaf/visualization/__init__.py +1 -1
  81. jaf/visualization/example.py +111 -110
  82. jaf/visualization/functional_core.py +46 -71
  83. jaf/visualization/graphviz.py +154 -189
  84. jaf/visualization/imperative_shell.py +7 -16
  85. jaf/visualization/types.py +8 -4
  86. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
  87. jaf_py-2.5.11.dist-info/RECORD +97 -0
  88. jaf_py-2.5.10.dist-info/RECORD +0 -96
  89. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
  90. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
  91. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
  92. {jaf_py-2.5.10.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
@@ -26,10 +26,12 @@ from ..utils import prepare_message_list_for_db
26
26
 
27
27
  try:
28
28
  import asyncpg
29
+
29
30
  PostgresClient = Union[asyncpg.Connection, asyncpg.Pool]
30
31
  except ImportError:
31
32
  PostgresClient = Any
32
33
 
34
+
33
35
  class PostgresProvider(MemoryProvider):
34
36
  """
35
37
  PostgreSQL implementation of MemoryProvider.
@@ -40,19 +42,19 @@ class PostgresProvider(MemoryProvider):
40
42
  self.client = client
41
43
 
42
44
  async def _db_fetch(self, query: str, *args):
43
- if hasattr(self.client, 'fetch'): # Pool
45
+ if hasattr(self.client, "fetch"): # Pool
44
46
  return await self.client.fetch(query, *args)
45
- else: # Connection
47
+ else: # Connection
46
48
  return await self.client.fetch(query, *args)
47
49
 
48
50
  async def _db_fetchrow(self, query: str, *args):
49
- if hasattr(self.client, 'fetchrow'):
51
+ if hasattr(self.client, "fetchrow"):
50
52
  return await self.client.fetchrow(query, *args)
51
53
  else:
52
54
  return await self.client.fetchrow(query, *args)
53
55
 
54
56
  async def _db_execute(self, query: str, *args) -> str:
55
- if hasattr(self.client, 'execute'):
57
+ if hasattr(self.client, "execute"):
56
58
  return await self.client.execute(query, *args)
57
59
  else:
58
60
  return await self.client.execute(query, *args)
@@ -61,21 +63,21 @@ class PostgresProvider(MemoryProvider):
61
63
  """Convert database row to ConversationMemory using shared utilities."""
62
64
  from ..utils import extract_messages_from_db_row, validate_conversation_metadata
63
65
 
64
- messages = extract_messages_from_db_row(row['messages'])
65
- metadata = validate_conversation_metadata(json.loads(row['metadata']))
66
+ messages = extract_messages_from_db_row(row["messages"])
67
+ metadata = validate_conversation_metadata(json.loads(row["metadata"]))
66
68
 
67
69
  return ConversationMemory(
68
- conversation_id=row['conversation_id'],
69
- user_id=row['user_id'],
70
+ conversation_id=row["conversation_id"],
71
+ user_id=row["user_id"],
70
72
  messages=messages,
71
- metadata=metadata
73
+ metadata=metadata,
72
74
  )
73
75
 
74
76
  async def store_messages(
75
77
  self,
76
78
  conversation_id: str,
77
79
  messages: List[Message],
78
- metadata: Optional[Dict[str, Any]] = None
80
+ metadata: Optional[Dict[str, Any]] = None,
79
81
  ) -> Result[None, MemoryStorageError]:
80
82
  try:
81
83
  now = datetime.now()
@@ -106,34 +108,48 @@ class PostgresProvider(MemoryProvider):
106
108
  current_metadata.get("user_id"),
107
109
  prepare_message_list_for_db(messages),
108
110
  json.dumps(insert_metadata),
109
- json.dumps(update_metadata)
111
+ json.dumps(update_metadata),
110
112
  )
111
113
  return Success(None)
112
114
  except Exception as e:
113
- return Failure(MemoryStorageError(operation="store_messages", provider="Postgres", message=str(e), cause=e))
115
+ return Failure(
116
+ MemoryStorageError(
117
+ operation="store_messages", provider="Postgres", message=str(e), cause=e
118
+ )
119
+ )
114
120
 
115
121
  async def get_conversation(
116
- self,
117
- conversation_id: str
122
+ self, conversation_id: str
118
123
  ) -> Result[Optional[ConversationMemory], MemoryStorageError]:
119
124
  try:
120
- row = await self._db_fetchrow(f"SELECT * FROM {self.config.table_name} WHERE conversation_id = $1", conversation_id)
125
+ row = await self._db_fetchrow(
126
+ f"SELECT * FROM {self.config.table_name} WHERE conversation_id = $1",
127
+ conversation_id,
128
+ )
121
129
  if not row:
122
130
  return Success(None)
123
131
 
124
132
  # Update last activity
125
133
  timestamp = datetime.now().isoformat()
126
- await self._db_execute(f"UPDATE {self.config.table_name} SET metadata = metadata || $2 WHERE conversation_id = $1", conversation_id, f'{{"last_activity": "{timestamp}"}}')
134
+ await self._db_execute(
135
+ f"UPDATE {self.config.table_name} SET metadata = metadata || $2 WHERE conversation_id = $1",
136
+ conversation_id,
137
+ f'{{"last_activity": "{timestamp}"}}',
138
+ )
127
139
 
128
140
  return Success(self._row_to_conversation(row))
129
141
  except Exception as e:
130
- return Failure(MemoryStorageError(operation="get_conversation", provider="Postgres", message=str(e), cause=e))
142
+ return Failure(
143
+ MemoryStorageError(
144
+ operation="get_conversation", provider="Postgres", message=str(e), cause=e
145
+ )
146
+ )
131
147
 
132
148
  async def append_messages(
133
149
  self,
134
150
  conversation_id: str,
135
151
  messages: List[Message],
136
- metadata: Optional[Dict[str, Any]] = None
152
+ metadata: Optional[Dict[str, Any]] = None,
137
153
  ) -> Result[None, Union[MemoryNotFoundError, MemoryStorageError]]:
138
154
  try:
139
155
  # First, check if the conversation exists to provide a proper MemoryNotFoundError
@@ -142,7 +158,9 @@ class PostgresProvider(MemoryProvider):
142
158
  check_query = f"SELECT 1 FROM {self.config.table_name} WHERE conversation_id = $1"
143
159
  exists = await self._db_fetchrow(check_query, conversation_id)
144
160
  if not exists:
145
- return Failure(MemoryNotFoundError(conversation_id=conversation_id, provider="Postgres"))
161
+ return Failure(
162
+ MemoryNotFoundError(conversation_id=conversation_id, provider="Postgres")
163
+ )
146
164
 
147
165
  now = datetime.now()
148
166
  update_metadata = metadata or {}
@@ -162,29 +180,31 @@ class PostgresProvider(MemoryProvider):
162
180
  new_messages_json = prepare_message_list_for_db(messages)
163
181
 
164
182
  await self._db_execute(
165
- query,
166
- new_messages_json,
167
- json.dumps(update_metadata),
168
- conversation_id
183
+ query, new_messages_json, json.dumps(update_metadata), conversation_id
169
184
  )
170
185
  return Success(None)
171
186
  except Exception as e:
172
- return Failure(MemoryStorageError(operation="append_messages", provider="Postgres", message=str(e), cause=e))
187
+ return Failure(
188
+ MemoryStorageError(
189
+ operation="append_messages", provider="Postgres", message=str(e), cause=e
190
+ )
191
+ )
173
192
 
174
193
  async def find_conversations(
175
- self,
176
- query: MemoryQuery
194
+ self, query: MemoryQuery
177
195
  ) -> Result[List[ConversationMemory], MemoryStorageError]:
178
196
  try:
179
197
  rows = await self._db_fetch(f"SELECT * FROM {self.config.table_name}")
180
198
  return Success([self._row_to_conversation(row) for row in rows])
181
199
  except Exception as e:
182
- return Failure(MemoryStorageError(operation="find_conversations", provider="Postgres", message=str(e), cause=e))
200
+ return Failure(
201
+ MemoryStorageError(
202
+ operation="find_conversations", provider="Postgres", message=str(e), cause=e
203
+ )
204
+ )
183
205
 
184
206
  async def get_recent_messages(
185
- self,
186
- conversation_id: str,
187
- limit: int = 50
207
+ self, conversation_id: str, limit: int = 50
188
208
  ) -> Result[List[Message], Union[MemoryNotFoundError, MemoryStorageError]]:
189
209
  result = await self.get_conversation(conversation_id)
190
210
  if isinstance(result, Failure):
@@ -192,57 +212,73 @@ class PostgresProvider(MemoryProvider):
192
212
 
193
213
  conversation = result.data
194
214
  if not conversation:
195
- return Failure(MemoryNotFoundError(conversation_id=conversation_id, provider="Postgres", message=f"Conversation {conversation_id} not found"))
215
+ return Failure(
216
+ MemoryNotFoundError(
217
+ conversation_id=conversation_id,
218
+ provider="Postgres",
219
+ message=f"Conversation {conversation_id} not found",
220
+ )
221
+ )
196
222
 
197
223
  return Success(conversation.messages[-limit:])
198
224
 
199
- async def delete_conversation(
200
- self,
201
- conversation_id: str
202
- ) -> Result[bool, MemoryStorageError]:
225
+ async def delete_conversation(self, conversation_id: str) -> Result[bool, MemoryStorageError]:
203
226
  try:
204
- result = await self._db_execute(f"DELETE FROM {self.config.table_name} WHERE conversation_id = $1", conversation_id)
205
- return Success('DELETE 1' in result)
227
+ result = await self._db_execute(
228
+ f"DELETE FROM {self.config.table_name} WHERE conversation_id = $1", conversation_id
229
+ )
230
+ return Success("DELETE 1" in result)
206
231
  except Exception as e:
207
- return Failure(MemoryStorageError(operation="delete_conversation", provider="Postgres", message=str(e), cause=e))
232
+ return Failure(
233
+ MemoryStorageError(
234
+ operation="delete_conversation", provider="Postgres", message=str(e), cause=e
235
+ )
236
+ )
208
237
 
209
- async def clear_user_conversations(
210
- self,
211
- user_id: str
212
- ) -> Result[int, MemoryStorageError]:
238
+ async def clear_user_conversations(self, user_id: str) -> Result[int, MemoryStorageError]:
213
239
  try:
214
- result = await self._db_execute(f"DELETE FROM {self.config.table_name} WHERE user_id = $1", user_id)
215
- return Success(int(result.split(' ')[1]))
240
+ result = await self._db_execute(
241
+ f"DELETE FROM {self.config.table_name} WHERE user_id = $1", user_id
242
+ )
243
+ return Success(int(result.split(" ")[1]))
216
244
  except Exception as e:
217
- return Failure(MemoryStorageError(operation="clear_user_conversations", provider="Postgres", message=str(e), cause=e))
245
+ return Failure(
246
+ MemoryStorageError(
247
+ operation="clear_user_conversations",
248
+ provider="Postgres",
249
+ message=str(e),
250
+ cause=e,
251
+ )
252
+ )
218
253
 
219
254
  async def get_stats(
220
- self,
221
- user_id: Optional[str] = None
255
+ self, user_id: Optional[str] = None
222
256
  ) -> Result[Dict[str, Any], MemoryStorageError]:
223
257
  try:
224
258
  row = await self._db_fetchrow(f"SELECT COUNT(*) as count FROM {self.config.table_name}")
225
- return Success({"total_conversations": row['count']})
259
+ return Success({"total_conversations": row["count"]})
226
260
  except Exception as e:
227
- return Failure(MemoryStorageError(operation="get_stats", provider="Postgres", message=str(e), cause=e))
261
+ return Failure(
262
+ MemoryStorageError(
263
+ operation="get_stats", provider="Postgres", message=str(e), cause=e
264
+ )
265
+ )
228
266
 
229
267
  async def health_check(self) -> Result[Dict[str, Any], MemoryConnectionError]:
230
268
  start_time = datetime.now()
231
269
  try:
232
270
  await self._db_fetch("SELECT 1")
233
271
  latency_ms = (datetime.now() - start_time).total_seconds() * 1000
234
- return Success({
235
- "healthy": True,
236
- "provider": "Postgres",
237
- "latency_ms": latency_ms
238
- })
272
+ return Success({"healthy": True, "provider": "Postgres", "latency_ms": latency_ms})
239
273
  except Exception as e:
240
- return Failure(MemoryConnectionError(provider="Postgres", message="Postgres health check failed", cause=e))
274
+ return Failure(
275
+ MemoryConnectionError(
276
+ provider="Postgres", message="Postgres health check failed", cause=e
277
+ )
278
+ )
241
279
 
242
280
  async def truncate_conversation_after(
243
- self,
244
- conversation_id: str,
245
- message_id: MessageId
281
+ self, conversation_id: str, message_id: MessageId
246
282
  ) -> Result[int, Union[MemoryNotFoundError, MemoryStorageError]]:
247
283
  """
248
284
  Truncate conversation after (and including) the specified message ID.
@@ -253,30 +289,32 @@ class PostgresProvider(MemoryProvider):
253
289
  conv_result = await self.get_conversation(conversation_id)
254
290
  if isinstance(conv_result, Failure):
255
291
  return conv_result
256
-
292
+
257
293
  if not conv_result.data:
258
- return Failure(MemoryNotFoundError(
259
- message=f"Conversation {conversation_id} not found",
260
- provider="Postgres",
261
- conversation_id=conversation_id
262
- ))
263
-
294
+ return Failure(
295
+ MemoryNotFoundError(
296
+ message=f"Conversation {conversation_id} not found",
297
+ provider="Postgres",
298
+ conversation_id=conversation_id,
299
+ )
300
+ )
301
+
264
302
  conversation = conv_result.data
265
303
  messages = list(conversation.messages)
266
304
  truncate_index = find_message_index(messages, message_id)
267
-
305
+
268
306
  if truncate_index is None:
269
307
  # Message not found, nothing to truncate
270
308
  return Success(0)
271
-
309
+
272
310
  # Truncate messages from the found index onwards
273
311
  original_count = len(messages)
274
312
  truncated_messages = messages[:truncate_index]
275
313
  removed_count = original_count - len(truncated_messages)
276
-
314
+
277
315
  # Update conversation with truncated messages
278
316
  now = datetime.now()
279
-
317
+
280
318
  # Convert any datetime objects in existing metadata to ISO strings
281
319
  serializable_metadata = {}
282
320
  for key, value in conversation.metadata.items():
@@ -284,7 +322,7 @@ class PostgresProvider(MemoryProvider):
284
322
  serializable_metadata[key] = value.isoformat()
285
323
  else:
286
324
  serializable_metadata[key] = value
287
-
325
+
288
326
  updated_metadata = {
289
327
  **serializable_metadata,
290
328
  "updated_at": now.isoformat(),
@@ -292,41 +330,44 @@ class PostgresProvider(MemoryProvider):
292
330
  "total_messages": len(truncated_messages),
293
331
  "regeneration_truncated": True,
294
332
  "truncated_at": now.isoformat(),
295
- "messages_removed": removed_count
333
+ "messages_removed": removed_count,
296
334
  }
297
-
335
+
298
336
  # Update in database
299
337
  query = f"""
300
338
  UPDATE {self.config.table_name}
301
339
  SET messages = $1::jsonb, metadata = $2::jsonb
302
340
  WHERE conversation_id = $3
303
341
  """
304
-
342
+
305
343
  await self._db_execute(
306
344
  query,
307
345
  prepare_message_list_for_db(truncated_messages),
308
346
  json.dumps(updated_metadata),
309
- conversation_id
347
+ conversation_id,
348
+ )
349
+
350
+ print(
351
+ f"[MEMORY:Postgres] Truncated conversation {conversation_id}: removed {removed_count} messages after message {message_id}"
310
352
  )
311
-
312
- print(f"[MEMORY:Postgres] Truncated conversation {conversation_id}: removed {removed_count} messages after message {message_id}")
313
353
  return Success(removed_count)
314
-
354
+
315
355
  except Exception as e:
316
356
  print(f"[MEMORY:Postgres] DEBUG: Exception in truncate_conversation_after: {e}")
317
357
  import traceback
358
+
318
359
  traceback.print_exc()
319
- return Failure(MemoryStorageError(
320
- message=f"Failed to truncate conversation: {e}",
321
- provider="Postgres",
322
- operation="truncate_conversation_after",
323
- cause=e
324
- ))
360
+ return Failure(
361
+ MemoryStorageError(
362
+ message=f"Failed to truncate conversation: {e}",
363
+ provider="Postgres",
364
+ operation="truncate_conversation_after",
365
+ cause=e,
366
+ )
367
+ )
325
368
 
326
369
  async def get_conversation_until_message(
327
- self,
328
- conversation_id: str,
329
- message_id: MessageId
370
+ self, conversation_id: str, message_id: MessageId
330
371
  ) -> Result[Optional[ConversationMemory], Union[MemoryNotFoundError, MemoryStorageError]]:
331
372
  """
332
373
  Get conversation history up to (but not including) the specified message ID.
@@ -337,22 +378,24 @@ class PostgresProvider(MemoryProvider):
337
378
  conv_result = await self.get_conversation(conversation_id)
338
379
  if isinstance(conv_result, Failure):
339
380
  return conv_result
340
-
381
+
341
382
  if not conv_result.data:
342
383
  return Success(None)
343
-
384
+
344
385
  conversation = conv_result.data
345
386
  messages = list(conversation.messages)
346
387
  until_index = find_message_index(messages, message_id)
347
-
388
+
348
389
  if until_index is None:
349
390
  # Message not found, return None as lightweight indicator
350
- print(f"[MEMORY:Postgres] Message {message_id} not found in conversation {conversation_id}")
391
+ print(
392
+ f"[MEMORY:Postgres] Message {message_id} not found in conversation {conversation_id}"
393
+ )
351
394
  return Success(None)
352
-
395
+
353
396
  # Return conversation up to (but not including) the specified message
354
397
  truncated_messages = messages[:until_index]
355
-
398
+
356
399
  # Create a copy of the conversation with truncated messages
357
400
  truncated_conversation = ConversationMemory(
358
401
  conversation_id=conversation.conversation_id,
@@ -363,26 +406,27 @@ class PostgresProvider(MemoryProvider):
363
406
  "truncated_for_regeneration": True,
364
407
  "truncated_until_message": str(message_id),
365
408
  "original_message_count": len(messages),
366
- "truncated_message_count": len(truncated_messages)
367
- }
409
+ "truncated_message_count": len(truncated_messages),
410
+ },
411
+ )
412
+
413
+ print(
414
+ f"[MEMORY:Postgres] Retrieved conversation {conversation_id} until message {message_id}: {len(truncated_messages)} messages"
368
415
  )
369
-
370
- print(f"[MEMORY:Postgres] Retrieved conversation {conversation_id} until message {message_id}: {len(truncated_messages)} messages")
371
416
  return Success(truncated_conversation)
372
-
417
+
373
418
  except Exception as e:
374
- return Failure(MemoryStorageError(
375
- message=f"Failed to get conversation until message: {e}",
376
- provider="Postgres",
377
- operation="get_conversation_until_message",
378
- cause=e
379
- ))
419
+ return Failure(
420
+ MemoryStorageError(
421
+ message=f"Failed to get conversation until message: {e}",
422
+ provider="Postgres",
423
+ operation="get_conversation_until_message",
424
+ cause=e,
425
+ )
426
+ )
380
427
 
381
428
  async def mark_regeneration_point(
382
- self,
383
- conversation_id: str,
384
- message_id: MessageId,
385
- regeneration_metadata: Dict[str, Any]
429
+ self, conversation_id: str, message_id: MessageId, regeneration_metadata: Dict[str, Any]
386
430
  ) -> Result[None, Union[MemoryNotFoundError, MemoryStorageError]]:
387
431
  """
388
432
  Mark a regeneration point in the conversation for audit purposes.
@@ -392,77 +436,100 @@ class PostgresProvider(MemoryProvider):
392
436
  conv_result = await self.get_conversation(conversation_id)
393
437
  if isinstance(conv_result, Failure):
394
438
  return conv_result
395
-
439
+
396
440
  if not conv_result.data:
397
- return Failure(MemoryNotFoundError(
398
- message=f"Conversation {conversation_id} not found",
399
- provider="Postgres",
400
- conversation_id=conversation_id
401
- ))
402
-
441
+ return Failure(
442
+ MemoryNotFoundError(
443
+ message=f"Conversation {conversation_id} not found",
444
+ provider="Postgres",
445
+ conversation_id=conversation_id,
446
+ )
447
+ )
448
+
403
449
  conversation = conv_result.data
404
-
450
+
405
451
  # Add regeneration point to metadata
406
452
  regeneration_points = conversation.metadata.get("regeneration_points", [])
407
453
  regeneration_point = {
408
454
  "message_id": str(message_id),
409
455
  "timestamp": datetime.now().isoformat(),
410
- **regeneration_metadata
456
+ **regeneration_metadata,
411
457
  }
412
458
  regeneration_points.append(regeneration_point)
413
-
459
+
414
460
  # Update conversation metadata
415
461
  updated_metadata = {
416
462
  **conversation.metadata,
417
463
  "regeneration_points": regeneration_points,
418
464
  "last_regeneration": regeneration_point,
419
465
  "updated_at": datetime.now().isoformat(),
420
- "regeneration_count": len(regeneration_points)
466
+ "regeneration_count": len(regeneration_points),
421
467
  }
422
-
468
+
423
469
  # Update in database using JSONB merge
424
470
  query = f"""
425
471
  UPDATE {self.config.table_name}
426
472
  SET metadata = metadata || $1::jsonb
427
473
  WHERE conversation_id = $2
428
474
  """
429
-
475
+
430
476
  await self._db_execute(
431
477
  query,
432
- json.dumps({
433
- "regeneration_points": regeneration_points,
434
- "last_regeneration": regeneration_point,
435
- "updated_at": updated_metadata["updated_at"],
436
- "regeneration_count": len(regeneration_points)
437
- }),
438
- conversation_id
478
+ json.dumps(
479
+ {
480
+ "regeneration_points": regeneration_points,
481
+ "last_regeneration": regeneration_point,
482
+ "updated_at": updated_metadata["updated_at"],
483
+ "regeneration_count": len(regeneration_points),
484
+ }
485
+ ),
486
+ conversation_id,
487
+ )
488
+
489
+ print(
490
+ f"[MEMORY:Postgres] Marked regeneration point for conversation {conversation_id} at message {message_id}"
439
491
  )
440
-
441
- print(f"[MEMORY:Postgres] Marked regeneration point for conversation {conversation_id} at message {message_id}")
442
492
  return Success(None)
443
-
493
+
444
494
  except Exception as e:
445
- return Failure(MemoryStorageError(
446
- message=f"Failed to mark regeneration point: {e}",
447
- provider="Postgres",
448
- operation="mark_regeneration_point",
449
- cause=e
450
- ))
495
+ return Failure(
496
+ MemoryStorageError(
497
+ message=f"Failed to mark regeneration point: {e}",
498
+ provider="Postgres",
499
+ operation="mark_regeneration_point",
500
+ cause=e,
501
+ )
502
+ )
451
503
 
452
504
  async def close(self) -> Result[None, MemoryConnectionError]:
453
505
  try:
454
- if hasattr(self.client, 'close'):
506
+ if hasattr(self.client, "close"):
455
507
  await self.client.close()
456
508
  return Success(None)
457
509
  except Exception as e:
458
- return Failure(MemoryConnectionError(provider="Postgres", message="Failed to close Postgres connection", cause=e))
510
+ return Failure(
511
+ MemoryConnectionError(
512
+ provider="Postgres", message="Failed to close Postgres connection", cause=e
513
+ )
514
+ )
459
515
 
460
- async def create_postgres_provider(config: PostgresConfig) -> Result[PostgresProvider, MemoryConnectionError]:
516
+
517
+ async def create_postgres_provider(
518
+ config: PostgresConfig,
519
+ ) -> Result[PostgresProvider, MemoryConnectionError]:
461
520
  try:
462
521
  # Connect to the default 'postgres' database to check if the target database exists
463
522
  try:
464
- conn = await asyncpg.connect(user=config.username, password=config.password, host=config.host, port=config.port, database='postgres')
465
- db_exists = await conn.fetchval("SELECT 1 FROM pg_database WHERE datname = $1", config.database)
523
+ conn = await asyncpg.connect(
524
+ user=config.username,
525
+ password=config.password,
526
+ host=config.host,
527
+ port=config.port,
528
+ database="postgres",
529
+ )
530
+ db_exists = await conn.fetchval(
531
+ "SELECT 1 FROM pg_database WHERE datname = $1", config.database
532
+ )
466
533
  if not db_exists:
467
534
  await conn.execute(f'CREATE DATABASE "{config.database}"')
468
535
  await conn.close()
@@ -474,9 +541,7 @@ async def create_postgres_provider(config: PostgresConfig) -> Result[PostgresPro
474
541
  # Now connect to the target database using connection pool with max_connections
475
542
  if config.connection_string:
476
543
  client = await asyncpg.create_pool(
477
- dsn=config.connection_string,
478
- min_size=1,
479
- max_size=config.max_connections
544
+ dsn=config.connection_string, min_size=1, max_size=config.max_connections
480
545
  )
481
546
  else:
482
547
  client = await asyncpg.create_pool(
@@ -486,7 +551,7 @@ async def create_postgres_provider(config: PostgresConfig) -> Result[PostgresPro
486
551
  password=config.password,
487
552
  database=config.database,
488
553
  min_size=1,
489
- max_size=config.max_connections
554
+ max_size=config.max_connections,
490
555
  )
491
556
 
492
557
  table_name = config.table_name or "conversations"
@@ -507,8 +572,13 @@ async def create_postgres_provider(config: PostgresConfig) -> Result[PostgresPro
507
572
  provider_config = config
508
573
  if not provider_config.table_name:
509
574
  from dataclasses import replace
575
+
510
576
  provider_config = replace(config, table_name=table_name)
511
577
 
512
578
  return Success(PostgresProvider(provider_config, client))
513
579
  except Exception as e:
514
- return Failure(MemoryConnectionError(provider="Postgres", message="Failed to connect to PostgreSQL", cause=e))
580
+ return Failure(
581
+ MemoryConnectionError(
582
+ provider="Postgres", message="Failed to connect to PostgreSQL", cause=e
583
+ )
584
+ )