jaf-py 2.5.3__py3-none-any.whl → 2.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +1 -1
- jaf/core/engine.py +159 -117
- jaf/core/regeneration.py +392 -0
- jaf/core/tracing.py +1 -1
- jaf/core/types.py +115 -2
- jaf/memory/providers/in_memory.py +174 -1
- jaf/memory/providers/postgres.py +211 -1
- jaf/memory/providers/redis.py +189 -1
- jaf/memory/types.py +35 -1
- jaf/memory/utils.py +2 -0
- jaf/server/server.py +163 -0
- jaf/server/types.py +49 -1
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.5.dist-info}/METADATA +2 -2
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.5.dist-info}/RECORD +18 -17
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.5.dist-info}/WHEEL +0 -0
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.5.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.5.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.5.dist-info}/top_level.txt +0 -0
|
@@ -11,7 +11,7 @@ from collections import OrderedDict
|
|
|
11
11
|
from datetime import datetime
|
|
12
12
|
from typing import Any, Dict, List, Optional, Union
|
|
13
13
|
|
|
14
|
-
from ...core.types import Message
|
|
14
|
+
from ...core.types import Message, MessageId, find_message_index
|
|
15
15
|
from ..types import (
|
|
16
16
|
ConversationMemory,
|
|
17
17
|
Failure,
|
|
@@ -307,6 +307,179 @@ class InMemoryProvider(MemoryProvider):
|
|
|
307
307
|
cause=e
|
|
308
308
|
))
|
|
309
309
|
|
|
310
|
+
async def truncate_conversation_after(
|
|
311
|
+
self,
|
|
312
|
+
conversation_id: str,
|
|
313
|
+
message_id: MessageId
|
|
314
|
+
) -> Result[int, Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
315
|
+
"""
|
|
316
|
+
Truncate conversation after (and including) the specified message ID.
|
|
317
|
+
Returns the number of messages removed.
|
|
318
|
+
"""
|
|
319
|
+
async with self._lock:
|
|
320
|
+
try:
|
|
321
|
+
conversation = self._conversations.get(conversation_id)
|
|
322
|
+
if conversation is None:
|
|
323
|
+
return Failure(MemoryNotFoundError(
|
|
324
|
+
message=f"Conversation {conversation_id} not found",
|
|
325
|
+
provider="InMemory",
|
|
326
|
+
conversation_id=conversation_id
|
|
327
|
+
))
|
|
328
|
+
|
|
329
|
+
messages = list(conversation.messages)
|
|
330
|
+
truncate_index = find_message_index(messages, message_id)
|
|
331
|
+
|
|
332
|
+
if truncate_index is None:
|
|
333
|
+
# Message not found, nothing to truncate
|
|
334
|
+
return Success(0)
|
|
335
|
+
|
|
336
|
+
# Truncate messages from the found index onwards
|
|
337
|
+
original_count = len(messages)
|
|
338
|
+
truncated_messages = messages[:truncate_index]
|
|
339
|
+
removed_count = original_count - len(truncated_messages)
|
|
340
|
+
|
|
341
|
+
# Update conversation with truncated messages
|
|
342
|
+
now = datetime.now()
|
|
343
|
+
updated_metadata = {
|
|
344
|
+
**conversation.metadata,
|
|
345
|
+
"updated_at": now,
|
|
346
|
+
"last_activity": now,
|
|
347
|
+
"total_messages": len(truncated_messages),
|
|
348
|
+
"regeneration_truncated": True,
|
|
349
|
+
"truncated_at": now.isoformat(),
|
|
350
|
+
"messages_removed": removed_count
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
updated_conversation = ConversationMemory(
|
|
354
|
+
conversation_id=conversation_id,
|
|
355
|
+
user_id=conversation.user_id,
|
|
356
|
+
messages=truncated_messages,
|
|
357
|
+
metadata=updated_metadata
|
|
358
|
+
)
|
|
359
|
+
|
|
360
|
+
self._conversations[conversation_id] = updated_conversation
|
|
361
|
+
self._conversations.move_to_end(conversation_id)
|
|
362
|
+
|
|
363
|
+
print(f"[MEMORY:InMemory] Truncated conversation {conversation_id}: removed {removed_count} messages after message {message_id}")
|
|
364
|
+
return Success(removed_count)
|
|
365
|
+
|
|
366
|
+
except Exception as e:
|
|
367
|
+
return Failure(MemoryStorageError(
|
|
368
|
+
message=f"Failed to truncate conversation: {e}",
|
|
369
|
+
provider="InMemory",
|
|
370
|
+
operation="truncate_conversation_after",
|
|
371
|
+
cause=e
|
|
372
|
+
))
|
|
373
|
+
|
|
374
|
+
async def get_conversation_until_message(
|
|
375
|
+
self,
|
|
376
|
+
conversation_id: str,
|
|
377
|
+
message_id: MessageId
|
|
378
|
+
) -> Result[Optional[ConversationMemory], Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
379
|
+
"""
|
|
380
|
+
Get conversation history up to (but not including) the specified message ID.
|
|
381
|
+
Useful for regeneration scenarios.
|
|
382
|
+
"""
|
|
383
|
+
async with self._lock:
|
|
384
|
+
try:
|
|
385
|
+
conversation = self._conversations.get(conversation_id)
|
|
386
|
+
if conversation is None:
|
|
387
|
+
return Success(None)
|
|
388
|
+
|
|
389
|
+
messages = list(conversation.messages)
|
|
390
|
+
until_index = find_message_index(messages, message_id)
|
|
391
|
+
|
|
392
|
+
if until_index is None:
|
|
393
|
+
# Message not found, return None as lightweight indicator
|
|
394
|
+
print(f"[MEMORY:InMemory] Message {message_id} not found in conversation {conversation_id}")
|
|
395
|
+
return Success(None)
|
|
396
|
+
|
|
397
|
+
# Return conversation up to (but not including) the specified message
|
|
398
|
+
truncated_messages = messages[:until_index]
|
|
399
|
+
|
|
400
|
+
# Create a copy of the conversation with truncated messages
|
|
401
|
+
truncated_conversation = ConversationMemory(
|
|
402
|
+
conversation_id=conversation.conversation_id,
|
|
403
|
+
user_id=conversation.user_id,
|
|
404
|
+
messages=truncated_messages,
|
|
405
|
+
metadata={
|
|
406
|
+
**conversation.metadata,
|
|
407
|
+
"truncated_for_regeneration": True,
|
|
408
|
+
"truncated_until_message": str(message_id),
|
|
409
|
+
"original_message_count": len(messages),
|
|
410
|
+
"truncated_message_count": len(truncated_messages)
|
|
411
|
+
}
|
|
412
|
+
)
|
|
413
|
+
|
|
414
|
+
print(f"[MEMORY:InMemory] Retrieved conversation {conversation_id} until message {message_id}: {len(truncated_messages)} messages (found at index {until_index})")
|
|
415
|
+
return Success(truncated_conversation)
|
|
416
|
+
|
|
417
|
+
except Exception as e:
|
|
418
|
+
return Failure(MemoryStorageError(
|
|
419
|
+
message=f"Failed to get conversation until message: {e}",
|
|
420
|
+
provider="InMemory",
|
|
421
|
+
operation="get_conversation_until_message",
|
|
422
|
+
cause=e
|
|
423
|
+
))
|
|
424
|
+
|
|
425
|
+
async def mark_regeneration_point(
|
|
426
|
+
self,
|
|
427
|
+
conversation_id: str,
|
|
428
|
+
message_id: MessageId,
|
|
429
|
+
regeneration_metadata: Dict[str, Any]
|
|
430
|
+
) -> Result[None, Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
431
|
+
"""
|
|
432
|
+
Mark a regeneration point in the conversation for audit purposes.
|
|
433
|
+
"""
|
|
434
|
+
async with self._lock:
|
|
435
|
+
try:
|
|
436
|
+
conversation = self._conversations.get(conversation_id)
|
|
437
|
+
if conversation is None:
|
|
438
|
+
return Failure(MemoryNotFoundError(
|
|
439
|
+
message=f"Conversation {conversation_id} not found",
|
|
440
|
+
provider="InMemory",
|
|
441
|
+
conversation_id=conversation_id
|
|
442
|
+
))
|
|
443
|
+
|
|
444
|
+
# Add regeneration point to metadata
|
|
445
|
+
regeneration_points = conversation.metadata.get("regeneration_points", [])
|
|
446
|
+
regeneration_point = {
|
|
447
|
+
"message_id": str(message_id),
|
|
448
|
+
"timestamp": datetime.now().isoformat(),
|
|
449
|
+
**regeneration_metadata
|
|
450
|
+
}
|
|
451
|
+
regeneration_points.append(regeneration_point)
|
|
452
|
+
|
|
453
|
+
# Update conversation metadata
|
|
454
|
+
updated_metadata = {
|
|
455
|
+
**conversation.metadata,
|
|
456
|
+
"regeneration_points": regeneration_points,
|
|
457
|
+
"last_regeneration": regeneration_point,
|
|
458
|
+
"updated_at": datetime.now(),
|
|
459
|
+
"regeneration_count": len(regeneration_points)
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
updated_conversation = ConversationMemory(
|
|
463
|
+
conversation_id=conversation.conversation_id,
|
|
464
|
+
user_id=conversation.user_id,
|
|
465
|
+
messages=conversation.messages,
|
|
466
|
+
metadata=updated_metadata
|
|
467
|
+
)
|
|
468
|
+
|
|
469
|
+
self._conversations[conversation_id] = updated_conversation
|
|
470
|
+
self._conversations.move_to_end(conversation_id)
|
|
471
|
+
|
|
472
|
+
print(f"[MEMORY:InMemory] Marked regeneration point for conversation {conversation_id} at message {message_id}")
|
|
473
|
+
return Success(None)
|
|
474
|
+
|
|
475
|
+
except Exception as e:
|
|
476
|
+
return Failure(MemoryStorageError(
|
|
477
|
+
message=f"Failed to mark regeneration point: {e}",
|
|
478
|
+
provider="InMemory",
|
|
479
|
+
operation="mark_regeneration_point",
|
|
480
|
+
cause=e
|
|
481
|
+
))
|
|
482
|
+
|
|
310
483
|
async def close(self) -> Result[None, MemoryConnectionError]:
|
|
311
484
|
"""Close/cleanup the provider."""
|
|
312
485
|
async with self._lock:
|
jaf/memory/providers/postgres.py
CHANGED
|
@@ -9,7 +9,7 @@ import json
|
|
|
9
9
|
from datetime import datetime
|
|
10
10
|
from typing import Any, Dict, List, Optional, Union
|
|
11
11
|
|
|
12
|
-
from ...core.types import Message
|
|
12
|
+
from ...core.types import Message, MessageId, find_message_index
|
|
13
13
|
from ..types import (
|
|
14
14
|
ConversationMemory,
|
|
15
15
|
Failure,
|
|
@@ -239,6 +239,216 @@ class PostgresProvider(MemoryProvider):
|
|
|
239
239
|
except Exception as e:
|
|
240
240
|
return Failure(MemoryConnectionError(provider="Postgres", message="Postgres health check failed", cause=e))
|
|
241
241
|
|
|
242
|
+
async def truncate_conversation_after(
|
|
243
|
+
self,
|
|
244
|
+
conversation_id: str,
|
|
245
|
+
message_id: MessageId
|
|
246
|
+
) -> Result[int, Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
247
|
+
"""
|
|
248
|
+
Truncate conversation after (and including) the specified message ID.
|
|
249
|
+
Returns the number of messages removed.
|
|
250
|
+
"""
|
|
251
|
+
try:
|
|
252
|
+
# Get the conversation
|
|
253
|
+
conv_result = await self.get_conversation(conversation_id)
|
|
254
|
+
if isinstance(conv_result, Failure):
|
|
255
|
+
return conv_result
|
|
256
|
+
|
|
257
|
+
if not conv_result.data:
|
|
258
|
+
return Failure(MemoryNotFoundError(
|
|
259
|
+
message=f"Conversation {conversation_id} not found",
|
|
260
|
+
provider="Postgres",
|
|
261
|
+
conversation_id=conversation_id
|
|
262
|
+
))
|
|
263
|
+
|
|
264
|
+
conversation = conv_result.data
|
|
265
|
+
messages = list(conversation.messages)
|
|
266
|
+
truncate_index = find_message_index(messages, message_id)
|
|
267
|
+
|
|
268
|
+
if truncate_index is None:
|
|
269
|
+
# Message not found, nothing to truncate
|
|
270
|
+
return Success(0)
|
|
271
|
+
|
|
272
|
+
# Truncate messages from the found index onwards
|
|
273
|
+
original_count = len(messages)
|
|
274
|
+
truncated_messages = messages[:truncate_index]
|
|
275
|
+
removed_count = original_count - len(truncated_messages)
|
|
276
|
+
|
|
277
|
+
# Update conversation with truncated messages
|
|
278
|
+
now = datetime.now()
|
|
279
|
+
|
|
280
|
+
# Convert any datetime objects in existing metadata to ISO strings
|
|
281
|
+
serializable_metadata = {}
|
|
282
|
+
for key, value in conversation.metadata.items():
|
|
283
|
+
if isinstance(value, datetime):
|
|
284
|
+
serializable_metadata[key] = value.isoformat()
|
|
285
|
+
else:
|
|
286
|
+
serializable_metadata[key] = value
|
|
287
|
+
|
|
288
|
+
updated_metadata = {
|
|
289
|
+
**serializable_metadata,
|
|
290
|
+
"updated_at": now.isoformat(),
|
|
291
|
+
"last_activity": now.isoformat(),
|
|
292
|
+
"total_messages": len(truncated_messages),
|
|
293
|
+
"regeneration_truncated": True,
|
|
294
|
+
"truncated_at": now.isoformat(),
|
|
295
|
+
"messages_removed": removed_count
|
|
296
|
+
}
|
|
297
|
+
|
|
298
|
+
# Update in database
|
|
299
|
+
query = f"""
|
|
300
|
+
UPDATE {self.config.table_name}
|
|
301
|
+
SET messages = $1::jsonb, metadata = $2::jsonb
|
|
302
|
+
WHERE conversation_id = $3
|
|
303
|
+
"""
|
|
304
|
+
|
|
305
|
+
await self._db_execute(
|
|
306
|
+
query,
|
|
307
|
+
prepare_message_list_for_db(truncated_messages),
|
|
308
|
+
json.dumps(updated_metadata),
|
|
309
|
+
conversation_id
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
print(f"[MEMORY:Postgres] Truncated conversation {conversation_id}: removed {removed_count} messages after message {message_id}")
|
|
313
|
+
return Success(removed_count)
|
|
314
|
+
|
|
315
|
+
except Exception as e:
|
|
316
|
+
print(f"[MEMORY:Postgres] DEBUG: Exception in truncate_conversation_after: {e}")
|
|
317
|
+
import traceback
|
|
318
|
+
traceback.print_exc()
|
|
319
|
+
return Failure(MemoryStorageError(
|
|
320
|
+
message=f"Failed to truncate conversation: {e}",
|
|
321
|
+
provider="Postgres",
|
|
322
|
+
operation="truncate_conversation_after",
|
|
323
|
+
cause=e
|
|
324
|
+
))
|
|
325
|
+
|
|
326
|
+
async def get_conversation_until_message(
|
|
327
|
+
self,
|
|
328
|
+
conversation_id: str,
|
|
329
|
+
message_id: MessageId
|
|
330
|
+
) -> Result[Optional[ConversationMemory], Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
331
|
+
"""
|
|
332
|
+
Get conversation history up to (but not including) the specified message ID.
|
|
333
|
+
Useful for regeneration scenarios.
|
|
334
|
+
"""
|
|
335
|
+
try:
|
|
336
|
+
# Get the conversation
|
|
337
|
+
conv_result = await self.get_conversation(conversation_id)
|
|
338
|
+
if isinstance(conv_result, Failure):
|
|
339
|
+
return conv_result
|
|
340
|
+
|
|
341
|
+
if not conv_result.data:
|
|
342
|
+
return Success(None)
|
|
343
|
+
|
|
344
|
+
conversation = conv_result.data
|
|
345
|
+
messages = list(conversation.messages)
|
|
346
|
+
until_index = find_message_index(messages, message_id)
|
|
347
|
+
|
|
348
|
+
if until_index is None:
|
|
349
|
+
# Message not found, return None as lightweight indicator
|
|
350
|
+
print(f"[MEMORY:Postgres] Message {message_id} not found in conversation {conversation_id}")
|
|
351
|
+
return Success(None)
|
|
352
|
+
|
|
353
|
+
# Return conversation up to (but not including) the specified message
|
|
354
|
+
truncated_messages = messages[:until_index]
|
|
355
|
+
|
|
356
|
+
# Create a copy of the conversation with truncated messages
|
|
357
|
+
truncated_conversation = ConversationMemory(
|
|
358
|
+
conversation_id=conversation.conversation_id,
|
|
359
|
+
user_id=conversation.user_id,
|
|
360
|
+
messages=truncated_messages,
|
|
361
|
+
metadata={
|
|
362
|
+
**conversation.metadata,
|
|
363
|
+
"truncated_for_regeneration": True,
|
|
364
|
+
"truncated_until_message": str(message_id),
|
|
365
|
+
"original_message_count": len(messages),
|
|
366
|
+
"truncated_message_count": len(truncated_messages)
|
|
367
|
+
}
|
|
368
|
+
)
|
|
369
|
+
|
|
370
|
+
print(f"[MEMORY:Postgres] Retrieved conversation {conversation_id} until message {message_id}: {len(truncated_messages)} messages")
|
|
371
|
+
return Success(truncated_conversation)
|
|
372
|
+
|
|
373
|
+
except Exception as e:
|
|
374
|
+
return Failure(MemoryStorageError(
|
|
375
|
+
message=f"Failed to get conversation until message: {e}",
|
|
376
|
+
provider="Postgres",
|
|
377
|
+
operation="get_conversation_until_message",
|
|
378
|
+
cause=e
|
|
379
|
+
))
|
|
380
|
+
|
|
381
|
+
async def mark_regeneration_point(
|
|
382
|
+
self,
|
|
383
|
+
conversation_id: str,
|
|
384
|
+
message_id: MessageId,
|
|
385
|
+
regeneration_metadata: Dict[str, Any]
|
|
386
|
+
) -> Result[None, Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
387
|
+
"""
|
|
388
|
+
Mark a regeneration point in the conversation for audit purposes.
|
|
389
|
+
"""
|
|
390
|
+
try:
|
|
391
|
+
# Get the conversation
|
|
392
|
+
conv_result = await self.get_conversation(conversation_id)
|
|
393
|
+
if isinstance(conv_result, Failure):
|
|
394
|
+
return conv_result
|
|
395
|
+
|
|
396
|
+
if not conv_result.data:
|
|
397
|
+
return Failure(MemoryNotFoundError(
|
|
398
|
+
message=f"Conversation {conversation_id} not found",
|
|
399
|
+
provider="Postgres",
|
|
400
|
+
conversation_id=conversation_id
|
|
401
|
+
))
|
|
402
|
+
|
|
403
|
+
conversation = conv_result.data
|
|
404
|
+
|
|
405
|
+
# Add regeneration point to metadata
|
|
406
|
+
regeneration_points = conversation.metadata.get("regeneration_points", [])
|
|
407
|
+
regeneration_point = {
|
|
408
|
+
"message_id": str(message_id),
|
|
409
|
+
"timestamp": datetime.now().isoformat(),
|
|
410
|
+
**regeneration_metadata
|
|
411
|
+
}
|
|
412
|
+
regeneration_points.append(regeneration_point)
|
|
413
|
+
|
|
414
|
+
# Update conversation metadata
|
|
415
|
+
updated_metadata = {
|
|
416
|
+
**conversation.metadata,
|
|
417
|
+
"regeneration_points": regeneration_points,
|
|
418
|
+
"last_regeneration": regeneration_point,
|
|
419
|
+
"updated_at": datetime.now().isoformat(),
|
|
420
|
+
"regeneration_count": len(regeneration_points)
|
|
421
|
+
}
|
|
422
|
+
|
|
423
|
+
# Update in database using JSONB merge
|
|
424
|
+
query = f"""
|
|
425
|
+
UPDATE {self.config.table_name}
|
|
426
|
+
SET metadata = metadata || $1::jsonb
|
|
427
|
+
WHERE conversation_id = $2
|
|
428
|
+
"""
|
|
429
|
+
|
|
430
|
+
await self._db_execute(
|
|
431
|
+
query,
|
|
432
|
+
json.dumps({
|
|
433
|
+
"regeneration_points": regeneration_points,
|
|
434
|
+
"last_regeneration": regeneration_point,
|
|
435
|
+
"updated_at": updated_metadata["updated_at"],
|
|
436
|
+
"regeneration_count": len(regeneration_points)
|
|
437
|
+
}),
|
|
438
|
+
conversation_id
|
|
439
|
+
)
|
|
440
|
+
|
|
441
|
+
print(f"[MEMORY:Postgres] Marked regeneration point for conversation {conversation_id} at message {message_id}")
|
|
442
|
+
return Success(None)
|
|
443
|
+
|
|
444
|
+
except Exception as e:
|
|
445
|
+
return Failure(MemoryStorageError(
|
|
446
|
+
message=f"Failed to mark regeneration point: {e}",
|
|
447
|
+
provider="Postgres",
|
|
448
|
+
operation="mark_regeneration_point",
|
|
449
|
+
cause=e
|
|
450
|
+
))
|
|
451
|
+
|
|
242
452
|
async def close(self) -> Result[None, MemoryConnectionError]:
|
|
243
453
|
try:
|
|
244
454
|
if hasattr(self.client, 'close'):
|
jaf/memory/providers/redis.py
CHANGED
|
@@ -8,7 +8,7 @@ Best for production environments with shared state and persistence across restar
|
|
|
8
8
|
from datetime import datetime
|
|
9
9
|
from typing import Any, Dict, List, Optional, Union
|
|
10
10
|
|
|
11
|
-
from ...core.types import Message
|
|
11
|
+
from ...core.types import Message, MessageId, find_message_index
|
|
12
12
|
from ..types import (
|
|
13
13
|
ConversationMemory,
|
|
14
14
|
Failure,
|
|
@@ -205,6 +205,194 @@ class RedisProvider(MemoryProvider):
|
|
|
205
205
|
except Exception as e:
|
|
206
206
|
return Failure(MemoryConnectionError(provider="Redis", message="Redis health check failed", cause=e))
|
|
207
207
|
|
|
208
|
+
async def truncate_conversation_after(
|
|
209
|
+
self,
|
|
210
|
+
conversation_id: str,
|
|
211
|
+
message_id: MessageId
|
|
212
|
+
) -> Result[int, Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
213
|
+
"""
|
|
214
|
+
Truncate conversation after (and including) the specified message ID.
|
|
215
|
+
Returns the number of messages removed.
|
|
216
|
+
"""
|
|
217
|
+
try:
|
|
218
|
+
# Get the conversation
|
|
219
|
+
conv_result = await self.get_conversation(conversation_id)
|
|
220
|
+
if isinstance(conv_result, Failure):
|
|
221
|
+
return conv_result
|
|
222
|
+
|
|
223
|
+
if not conv_result.data:
|
|
224
|
+
return Failure(MemoryNotFoundError(
|
|
225
|
+
message=f"Conversation {conversation_id} not found",
|
|
226
|
+
provider="Redis",
|
|
227
|
+
conversation_id=conversation_id
|
|
228
|
+
))
|
|
229
|
+
|
|
230
|
+
conversation = conv_result.data
|
|
231
|
+
messages = list(conversation.messages)
|
|
232
|
+
truncate_index = find_message_index(messages, message_id)
|
|
233
|
+
|
|
234
|
+
if truncate_index is None:
|
|
235
|
+
# Message not found, nothing to truncate
|
|
236
|
+
return Success(0)
|
|
237
|
+
|
|
238
|
+
# Truncate messages from the found index onwards
|
|
239
|
+
original_count = len(messages)
|
|
240
|
+
truncated_messages = messages[:truncate_index]
|
|
241
|
+
removed_count = original_count - len(truncated_messages)
|
|
242
|
+
|
|
243
|
+
# Update conversation with truncated messages
|
|
244
|
+
now = datetime.now()
|
|
245
|
+
updated_metadata = {
|
|
246
|
+
**conversation.metadata,
|
|
247
|
+
"updated_at": now,
|
|
248
|
+
"last_activity": now,
|
|
249
|
+
"total_messages": len(truncated_messages),
|
|
250
|
+
"regeneration_truncated": True,
|
|
251
|
+
"truncated_at": now.isoformat(),
|
|
252
|
+
"messages_removed": removed_count
|
|
253
|
+
}
|
|
254
|
+
|
|
255
|
+
# Store updated conversation
|
|
256
|
+
updated_conversation = ConversationMemory(
|
|
257
|
+
conversation_id=conversation_id,
|
|
258
|
+
user_id=conversation.user_id,
|
|
259
|
+
messages=truncated_messages,
|
|
260
|
+
metadata=updated_metadata
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
key = self._get_key(conversation_id)
|
|
264
|
+
await self.redis_client.set(key, self._serialize(updated_conversation), ex=self.config.ttl)
|
|
265
|
+
|
|
266
|
+
print(f"[MEMORY:Redis] Truncated conversation {conversation_id}: removed {removed_count} messages after message {message_id}")
|
|
267
|
+
return Success(removed_count)
|
|
268
|
+
|
|
269
|
+
except Exception as e:
|
|
270
|
+
return Failure(MemoryStorageError(
|
|
271
|
+
message=f"Failed to truncate conversation: {e}",
|
|
272
|
+
provider="Redis",
|
|
273
|
+
operation="truncate_conversation_after",
|
|
274
|
+
cause=e
|
|
275
|
+
))
|
|
276
|
+
|
|
277
|
+
async def get_conversation_until_message(
|
|
278
|
+
self,
|
|
279
|
+
conversation_id: str,
|
|
280
|
+
message_id: MessageId
|
|
281
|
+
) -> Result[Optional[ConversationMemory], Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
282
|
+
"""
|
|
283
|
+
Get conversation history up to (but not including) the specified message ID.
|
|
284
|
+
Useful for regeneration scenarios.
|
|
285
|
+
"""
|
|
286
|
+
try:
|
|
287
|
+
# Get the conversation
|
|
288
|
+
conv_result = await self.get_conversation(conversation_id)
|
|
289
|
+
if isinstance(conv_result, Failure):
|
|
290
|
+
return conv_result
|
|
291
|
+
|
|
292
|
+
if not conv_result.data:
|
|
293
|
+
return Success(None)
|
|
294
|
+
|
|
295
|
+
conversation = conv_result.data
|
|
296
|
+
messages = list(conversation.messages)
|
|
297
|
+
until_index = find_message_index(messages, message_id)
|
|
298
|
+
|
|
299
|
+
if until_index is None:
|
|
300
|
+
# Message not found, return None as lightweight indicator
|
|
301
|
+
print(f"[MEMORY:Redis] Message {message_id} not found in conversation {conversation_id}")
|
|
302
|
+
return Success(None)
|
|
303
|
+
|
|
304
|
+
# Return conversation up to (but not including) the specified message
|
|
305
|
+
truncated_messages = messages[:until_index]
|
|
306
|
+
|
|
307
|
+
# Create a copy of the conversation with truncated messages
|
|
308
|
+
truncated_conversation = ConversationMemory(
|
|
309
|
+
conversation_id=conversation.conversation_id,
|
|
310
|
+
user_id=conversation.user_id,
|
|
311
|
+
messages=truncated_messages,
|
|
312
|
+
metadata={
|
|
313
|
+
**conversation.metadata,
|
|
314
|
+
"truncated_for_regeneration": True,
|
|
315
|
+
"truncated_until_message": str(message_id),
|
|
316
|
+
"original_message_count": len(messages),
|
|
317
|
+
"truncated_message_count": len(truncated_messages)
|
|
318
|
+
}
|
|
319
|
+
)
|
|
320
|
+
|
|
321
|
+
print(f"[MEMORY:Redis] Retrieved conversation {conversation_id} until message {message_id}: {len(truncated_messages)} messages")
|
|
322
|
+
return Success(truncated_conversation)
|
|
323
|
+
|
|
324
|
+
except Exception as e:
|
|
325
|
+
return Failure(MemoryStorageError(
|
|
326
|
+
message=f"Failed to get conversation until message: {e}",
|
|
327
|
+
provider="Redis",
|
|
328
|
+
operation="get_conversation_until_message",
|
|
329
|
+
cause=e
|
|
330
|
+
))
|
|
331
|
+
|
|
332
|
+
async def mark_regeneration_point(
|
|
333
|
+
self,
|
|
334
|
+
conversation_id: str,
|
|
335
|
+
message_id: MessageId,
|
|
336
|
+
regeneration_metadata: Dict[str, Any]
|
|
337
|
+
) -> Result[None, Union[MemoryNotFoundError, MemoryStorageError]]:
|
|
338
|
+
"""
|
|
339
|
+
Mark a regeneration point in the conversation for audit purposes.
|
|
340
|
+
"""
|
|
341
|
+
try:
|
|
342
|
+
# Get the conversation
|
|
343
|
+
conv_result = await self.get_conversation(conversation_id)
|
|
344
|
+
if isinstance(conv_result, Failure):
|
|
345
|
+
return conv_result
|
|
346
|
+
|
|
347
|
+
if not conv_result.data:
|
|
348
|
+
return Failure(MemoryNotFoundError(
|
|
349
|
+
message=f"Conversation {conversation_id} not found",
|
|
350
|
+
provider="Redis",
|
|
351
|
+
conversation_id=conversation_id
|
|
352
|
+
))
|
|
353
|
+
|
|
354
|
+
conversation = conv_result.data
|
|
355
|
+
|
|
356
|
+
# Add regeneration point to metadata
|
|
357
|
+
regeneration_points = conversation.metadata.get("regeneration_points", [])
|
|
358
|
+
regeneration_point = {
|
|
359
|
+
"message_id": str(message_id),
|
|
360
|
+
"timestamp": datetime.now().isoformat(),
|
|
361
|
+
**regeneration_metadata
|
|
362
|
+
}
|
|
363
|
+
regeneration_points.append(regeneration_point)
|
|
364
|
+
|
|
365
|
+
# Update conversation metadata
|
|
366
|
+
updated_metadata = {
|
|
367
|
+
**conversation.metadata,
|
|
368
|
+
"regeneration_points": regeneration_points,
|
|
369
|
+
"last_regeneration": regeneration_point,
|
|
370
|
+
"updated_at": datetime.now(),
|
|
371
|
+
"regeneration_count": len(regeneration_points)
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
# Store updated conversation
|
|
375
|
+
updated_conversation = ConversationMemory(
|
|
376
|
+
conversation_id=conversation.conversation_id,
|
|
377
|
+
user_id=conversation.user_id,
|
|
378
|
+
messages=conversation.messages,
|
|
379
|
+
metadata=updated_metadata
|
|
380
|
+
)
|
|
381
|
+
|
|
382
|
+
key = self._get_key(conversation_id)
|
|
383
|
+
await self.redis_client.set(key, self._serialize(updated_conversation), ex=self.config.ttl)
|
|
384
|
+
|
|
385
|
+
print(f"[MEMORY:Redis] Marked regeneration point for conversation {conversation_id} at message {message_id}")
|
|
386
|
+
return Success(None)
|
|
387
|
+
|
|
388
|
+
except Exception as e:
|
|
389
|
+
return Failure(MemoryStorageError(
|
|
390
|
+
message=f"Failed to mark regeneration point: {e}",
|
|
391
|
+
provider="Redis",
|
|
392
|
+
operation="mark_regeneration_point",
|
|
393
|
+
cause=e
|
|
394
|
+
))
|
|
395
|
+
|
|
208
396
|
async def close(self) -> Result[None, MemoryConnectionError]:
|
|
209
397
|
try:
|
|
210
398
|
await self.redis_client.aclose()
|
jaf/memory/types.py
CHANGED
|
@@ -11,7 +11,7 @@ from typing import Any, Dict, Generic, List, Optional, Protocol, TypeVar, Union
|
|
|
11
11
|
|
|
12
12
|
from pydantic import BaseModel, Field
|
|
13
13
|
|
|
14
|
-
from ..core.types import Message, TraceId
|
|
14
|
+
from ..core.types import Message, TraceId, MessageId
|
|
15
15
|
|
|
16
16
|
# Generic Result type for functional error handling
|
|
17
17
|
T = TypeVar('T')
|
|
@@ -135,6 +135,40 @@ class MemoryProvider(Protocol):
|
|
|
135
135
|
"""Close/cleanup the provider."""
|
|
136
136
|
...
|
|
137
137
|
|
|
138
|
+
# Regeneration support methods
|
|
139
|
+
async def truncate_conversation_after(
|
|
140
|
+
self,
|
|
141
|
+
conversation_id: str,
|
|
142
|
+
message_id: MessageId
|
|
143
|
+
) -> Result[int, Union['MemoryNotFoundError', 'MemoryStorageError']]:
|
|
144
|
+
"""
|
|
145
|
+
Truncate conversation after (and including) the specified message ID.
|
|
146
|
+
Returns the number of messages removed.
|
|
147
|
+
"""
|
|
148
|
+
...
|
|
149
|
+
|
|
150
|
+
async def get_conversation_until_message(
|
|
151
|
+
self,
|
|
152
|
+
conversation_id: str,
|
|
153
|
+
message_id: MessageId
|
|
154
|
+
) -> Result[Optional[ConversationMemory], Union['MemoryNotFoundError', 'MemoryStorageError']]:
|
|
155
|
+
"""
|
|
156
|
+
Get conversation history up to (but not including) the specified message ID.
|
|
157
|
+
Useful for regeneration scenarios.
|
|
158
|
+
"""
|
|
159
|
+
...
|
|
160
|
+
|
|
161
|
+
async def mark_regeneration_point(
|
|
162
|
+
self,
|
|
163
|
+
conversation_id: str,
|
|
164
|
+
message_id: MessageId,
|
|
165
|
+
regeneration_metadata: Dict[str, Any]
|
|
166
|
+
) -> Result[None, Union['MemoryNotFoundError', 'MemoryStorageError']]:
|
|
167
|
+
"""
|
|
168
|
+
Mark a regeneration point in the conversation for audit purposes.
|
|
169
|
+
"""
|
|
170
|
+
...
|
|
171
|
+
|
|
138
172
|
# Configuration models using Pydantic for validation
|
|
139
173
|
|
|
140
174
|
class InMemoryConfig(BaseModel):
|
jaf/memory/utils.py
CHANGED
|
@@ -22,6 +22,7 @@ def serialize_message(msg: Message) -> dict:
|
|
|
22
22
|
return {
|
|
23
23
|
"role": msg.role,
|
|
24
24
|
"content": msg.content,
|
|
25
|
+
"message_id": msg.message_id,
|
|
25
26
|
"tool_call_id": msg.tool_call_id,
|
|
26
27
|
"tool_calls": [
|
|
27
28
|
{
|
|
@@ -58,6 +59,7 @@ def deserialize_message(msg_data: dict) -> Message:
|
|
|
58
59
|
return Message(
|
|
59
60
|
role=msg_data["role"],
|
|
60
61
|
content=msg_data["content"],
|
|
62
|
+
message_id=msg_data.get("message_id"),
|
|
61
63
|
tool_call_id=msg_data.get("tool_call_id"),
|
|
62
64
|
tool_calls=tool_calls
|
|
63
65
|
)
|