jaf-py 2.5.3__py3-none-any.whl → 2.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +1 -1
- jaf/core/engine.py +1 -9
- jaf/core/regeneration.py +392 -0
- jaf/core/tracing.py +1 -1
- jaf/core/types.py +109 -2
- jaf/memory/providers/in_memory.py +174 -1
- jaf/memory/providers/postgres.py +211 -1
- jaf/memory/providers/redis.py +189 -1
- jaf/memory/types.py +35 -1
- jaf/memory/utils.py +2 -0
- jaf/server/server.py +163 -0
- jaf/server/types.py +49 -1
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.4.dist-info}/METADATA +2 -2
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.4.dist-info}/RECORD +18 -17
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.4.dist-info}/WHEEL +0 -0
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.4.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.4.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.5.3.dist-info → jaf_py-2.5.4.dist-info}/top_level.txt +0 -0
jaf/__init__.py
CHANGED
|
@@ -191,7 +191,7 @@ def generate_run_id() -> RunId:
|
|
|
191
191
|
"""Generate a new run ID."""
|
|
192
192
|
return create_run_id(str(uuid.uuid4()))
|
|
193
193
|
|
|
194
|
-
__version__ = "2.5.
|
|
194
|
+
__version__ = "2.5.4"
|
|
195
195
|
__all__ = [
|
|
196
196
|
# Core types and functions
|
|
197
197
|
"TraceId", "RunId", "ValidationResult", "Message", "ModelConfig",
|
jaf/core/engine.py
CHANGED
|
@@ -293,15 +293,7 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
|
|
|
293
293
|
# For HITL scenarios, append new messages to memory messages
|
|
294
294
|
# This prevents duplication when resuming from interruptions
|
|
295
295
|
if memory_messages:
|
|
296
|
-
combined_messages = memory_messages +
|
|
297
|
-
msg for msg in state.messages
|
|
298
|
-
if not any(
|
|
299
|
-
mem_msg.role == msg.role and
|
|
300
|
-
mem_msg.content == msg.content and
|
|
301
|
-
getattr(mem_msg, 'tool_calls', None) == getattr(msg, 'tool_calls', None)
|
|
302
|
-
for mem_msg in memory_messages
|
|
303
|
-
)
|
|
304
|
-
]
|
|
296
|
+
combined_messages = memory_messages + list(state.messages)
|
|
305
297
|
else:
|
|
306
298
|
combined_messages = list(state.messages)
|
|
307
299
|
|
jaf/core/regeneration.py
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Regeneration functionality for the JAF framework.
|
|
3
|
+
|
|
4
|
+
This module implements conversation regeneration where a specific message can be
|
|
5
|
+
regenerated, removing all subsequent messages and creating a new conversation path.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import time
|
|
9
|
+
from dataclasses import replace
|
|
10
|
+
from typing import Any, TypeVar, Optional
|
|
11
|
+
|
|
12
|
+
from .types import (
|
|
13
|
+
RunState, RunConfig, RunResult,
|
|
14
|
+
RegenerationRequest, RegenerationContext,
|
|
15
|
+
MessageId, Message, ErrorOutcome, ModelBehaviorError,
|
|
16
|
+
find_message_index, truncate_messages_after, get_message_by_id,
|
|
17
|
+
generate_run_id, generate_trace_id
|
|
18
|
+
)
|
|
19
|
+
from .engine import run as engine_run
|
|
20
|
+
from ..memory.types import Success, Failure
|
|
21
|
+
|
|
22
|
+
# Generic type parameters used by this module's signatures:
# Ctx is the run-context type, Out the type carried by RunResult.
Ctx, Out = TypeVar('Ctx'), TypeVar('Out')
|
|
24
|
+
|
|
25
|
+
def _role_name(message: Any) -> str:
    """Return ``message.role`` as a lowercase string.

    Roles may be enum members exposing ``.value`` (e.g. ``ContentRole.USER``)
    or plain strings such as ``'assistant'`` / ``'ASSISTANT'``. Normalizing
    them here keeps every role comparison in this module consistent.
    """
    role = message.role
    value = role.value if hasattr(role, 'value') else role
    return str(value).lower()


def _count_assistant_turns(messages: Any) -> int:
    """Count assistant messages; this module uses that count as the turn count."""
    return sum(1 for m in messages if _role_name(m) == 'assistant')


def _find_preceding_user_index(messages: Any, index: int) -> Optional[int]:
    """Return the index of the closest user message before ``index``, or None."""
    for i in range(index - 1, -1, -1):
        if _role_name(messages[i]) == 'user':
            return i
    return None


def _error_result(
    agent_name: str,
    context: Any,
    detail: str,
    messages: Optional[list] = None
) -> RunResult:
    """Build a failed RunResult carrying ``detail`` as a ModelBehaviorError.

    The final state gets a fresh run/trace id, the given messages (empty when
    omitted) and a turn count derived from those messages.
    """
    state_messages = [] if messages is None else messages
    return RunResult(
        final_state=RunState(
            run_id=generate_run_id(),
            trace_id=generate_trace_id(),
            messages=state_messages,
            current_agent_name=agent_name,
            context=context,
            turn_count=_count_assistant_turns(state_messages),
        ),
        outcome=ErrorOutcome(error=ModelBehaviorError(detail=detail)),
    )


def _json_default(obj: Any) -> Any:
    """``json.dumps`` fallback: ISO-format dates, expose ``__dict__``, else ``str()``."""
    import datetime

    if isinstance(obj, datetime.date):  # also matches datetime.datetime
        return obj.isoformat()
    if hasattr(obj, '__dict__'):
        return obj.__dict__
    return str(obj)


def _serialize_metadata(metadata: dict, fallback: dict) -> dict:
    """Round-trip ``metadata`` through JSON so it only holds JSON-safe values.

    Best effort: if serialization fails, a warning is printed and a copy of
    ``fallback`` is returned instead of raising.
    """
    import json

    try:
        return json.loads(json.dumps(metadata, default=_json_default))
    except Exception as e:  # deliberately broad: metadata must never abort regeneration
        print(f"[JAF:REGENERATION] Warning: Metadata serialization failed: {e}")
        return dict(fallback)


def _truncate_for_regeneration(
    messages: list,
    regenerate_index: int,
    request_context: Optional[dict]
) -> list:
    """Return the conversation prefix the engine should continue from.

    * "pure" regeneration: the target is an assistant message with an earlier
      user message and no replacement text was supplied. Everything after that
      user message is dropped, so the user turn is answered again.
    * "edit" regeneration (all other cases): everything from the target
      message onward is dropped. If ``request_context["replace_user_message"]``
      is set, it is appended as a new user message.
    """
    replacement = request_context.get("replace_user_message") if request_context else None

    user_index = None
    if not replacement and _role_name(messages[regenerate_index]) == 'assistant':
        user_index = _find_preceding_user_index(messages, regenerate_index)

    if user_index is not None:
        print("[JAF:REGENERATION] Detected regeneration type: pure")
        # Keep the user message; drop the old reply and any tool calls/outputs after it.
        truncated = messages[:user_index + 1]
        print(f"[JAF:REGENERATION] Pure regeneration: truncated to user message at index {user_index}")
        return truncated

    print("[JAF:REGENERATION] Detected regeneration type: edit")
    truncated = messages[:regenerate_index]
    if replacement:
        from .types import ContentRole
        truncated.append(Message(role=ContentRole.USER, content=replacement))
        print(f"[JAF:REGENERATION] Edit regeneration: replaced user query with: {replacement}")
    return truncated


async def _preserve_regeneration_metadata(
    provider: Any,
    conversation_id: str,
    message_id: MessageId,
    regeneration_context: RegenerationContext,
    final_messages: Any
) -> None:
    """Record the regeneration point and re-store the conversation with it.

    This is best effort. Every failure is printed as a warning and never
    raised, so an already-completed regeneration result still reaches the
    caller.
    """
    try:
        print("[JAF:REGENERATION] Marking regeneration point after successful regeneration")

        current_conv_result = await provider.get_conversation(conversation_id)
        current = getattr(current_conv_result, 'data', None)
        print(f"[JAF:REGENERATION] Retrieved conversation for preservation: {current is not None}")
        if not current:
            print("[JAF:REGENERATION] No conversation data found for preservation")
            return

        existing_points = (current.metadata or {}).get('regeneration_points', [])
        print(f"[JAF:REGENERATION] Found {len(existing_points)} regeneration points in metadata before marking")

        mark_result = await provider.mark_regeneration_point(
            conversation_id,
            message_id,
            {
                "regeneration_id": regeneration_context.regeneration_id,
                "original_message_count": regeneration_context.original_message_count,
                "truncated_at_index": regeneration_context.truncated_at_index,
                "timestamp": regeneration_context.timestamp,
            },
        )
        if isinstance(mark_result, Failure):
            print(f"[JAF:REGENERATION] Warning: Failed to mark regeneration point: {mark_result.error}")
        else:
            print("[JAF:REGENERATION] Successfully marked regeneration point")

        # Re-read so the stored metadata includes the point just marked.
        updated_conv_result = await provider.get_conversation(conversation_id)
        updated = getattr(updated_conv_result, 'data', None)
        if not updated:
            print("[JAF:REGENERATION] No conversation data found for preservation")
            return

        updated_metadata = updated.metadata or {}
        updated_points = updated_metadata.get('regeneration_points', [])
        print(f"[JAF:REGENERATION] Found {len(updated_points)} regeneration points after marking")

        final_metadata = {
            **updated_metadata,
            'regeneration_points': updated_points,
            'regeneration_count': len(updated_points),
            'last_regeneration': updated_points[-1] if updated_points else None,
            'regeneration_preserved': True,
            'final_preservation_timestamp': int(time.time() * 1000),
        }

        store_result = await provider.store_messages(conversation_id, final_messages, final_metadata)
        if isinstance(store_result, Failure):
            # Previously ignored; surface it so lost metadata is visible.
            print(f"[JAF:REGENERATION] Warning: Failed to store preserved regeneration metadata: {store_result.error}")
            return
        print(f"[JAF:REGENERATION] Final preservation completed with {len(updated_points)} regeneration points")

    except Exception as e:  # best effort: never fail a completed regeneration
        print(f"[JAF:REGENERATION] Warning: Failed to preserve regeneration points: {e}")
        import traceback
        traceback.print_exc()


async def regenerate_conversation(
    regeneration_request: RegenerationRequest,
    config: RunConfig[Ctx],
    context: Ctx,
    agent_name: str
) -> RunResult[Out]:
    """
    Regenerate a conversation from a specific message ID.

    Steps:
    1. Load the full conversation from memory.
    2. Locate the message to regenerate from.
    3. Truncate the conversation at that point (see ``_truncate_for_regeneration``).
    4. Overwrite memory with the truncated conversation.
    5. Run the engine with auto-storing memory enabled.
    6. On completion, record the regeneration point in the conversation metadata.

    Args:
        regeneration_request: Carries ``conversation_id``, ``message_id`` and
            optional ``context`` (``"replace_user_message"`` triggers an edit).
        config: The run configuration; must have a memory provider.
        context: The context for the regeneration run.
        agent_name: The name of the agent to use for regeneration.

    Returns:
        The engine's RunResult. Errors are returned (not raised) as a RunResult
        with an ``ErrorOutcome`` wrapping a ``ModelBehaviorError``.
    """
    # NOTE(review): config.conversation_id is required here even though only
    # regeneration_request.conversation_id is used below; kept for compatibility.
    if not config.memory or not config.memory.provider or not config.conversation_id:
        return _error_result(
            agent_name, context,
            "Regeneration requires memory provider and conversation_id to be configured",
        )

    provider = config.memory.provider
    conversation_id = regeneration_request.conversation_id
    message_id = regeneration_request.message_id

    conversation_result = await provider.get_conversation(conversation_id)
    if isinstance(conversation_result, Failure):
        return _error_result(agent_name, context, f"Failed to load conversation: {conversation_result.error}")

    conversation_memory = conversation_result.data
    if not conversation_memory:
        return _error_result(agent_name, context, f"Conversation {conversation_id} not found")

    # Stored messages may be a tuple; work on a list.
    original_messages = list(conversation_memory.messages)

    regenerate_index = find_message_index(original_messages, message_id)
    if regenerate_index is None:
        return _error_result(
            agent_name, context,
            f"Message {message_id} not found in conversation",
            original_messages,
        )

    truncated_messages = _truncate_for_regeneration(
        original_messages, regenerate_index, regeneration_request.context
    )
    print(f"[JAF:REGENERATION] Truncated conversation to {len(truncated_messages)} messages")
    print(f"[JAF:REGENERATION] About to store {len(truncated_messages)} truncated messages to memory")

    truncated_turn_count = _count_assistant_turns(truncated_messages)
    regeneration_fields = {
        "regeneration_truncated": True,
        "regeneration_point": str(message_id),
        "original_message_count": len(original_messages),
        "truncated_at_index": regenerate_index,
        "turn_count": truncated_turn_count,
    }
    metadata = _serialize_metadata(
        {**(conversation_memory.metadata or {}), **regeneration_fields},
        fallback=regeneration_fields,
    )

    store_result = await provider.store_messages(conversation_id, truncated_messages, metadata)
    print(f"[JAF:REGENERATION] Store result type: {type(store_result)}")
    if isinstance(store_result, Failure):
        print(f"[JAF:REGENERATION] Store failed with error: {store_result.error}")
        return _error_result(
            agent_name, context,
            f"Failed to store truncated conversation: {store_result.error}",
            original_messages,
        )
    print("[JAF:REGENERATION] Store successful, proceeding to engine execution")

    # One clock read so the regeneration id and timestamp always agree.
    now_ms = int(time.time() * 1000)
    regeneration_context = RegenerationContext(
        original_message_count=len(original_messages),
        truncated_at_index=regenerate_index,
        regenerated_message_id=message_id,
        regeneration_id=f"regen_{now_ms}_{message_id}",
        timestamp=now_ms,
    )

    print(f"[JAF:REGENERATION] Using provided context: {type(context).__name__}")

    # messages=[] on purpose: the engine's _load_conversation_history prepends
    # the conversation loaded from memory (just overwritten with the truncated
    # messages) to state.messages. Passing the truncated messages here as well
    # would presumably duplicate them.
    initial_state = RunState(
        run_id=generate_run_id(),
        trace_id=generate_trace_id(),
        messages=[],
        current_agent_name=agent_name,
        context=context,
        turn_count=truncated_turn_count,
        approvals={},  # reset approvals for regeneration
    )

    print(f"[JAF:REGENERATION] Starting regeneration from message {message_id}")
    print(f"[JAF:REGENERATION] Original messages: {len(original_messages)}, Truncated to: {len(truncated_messages)}")
    print(f"[JAF:REGENERATION] Regeneration context: {regeneration_context}")

    # Force storing so the regenerated path is persisted under this conversation.
    regeneration_config = replace(
        config,
        conversation_id=conversation_id,
        memory=replace(config.memory, auto_store=True, store_on_completion=True),
    )

    print(f"[JAF:REGENERATION] About to execute engine with {len(truncated_messages)} messages")
    print(f"[JAF:REGENERATION] Final message: {truncated_messages[-1] if truncated_messages else 'None'}")

    result = await engine_run(initial_state, regeneration_config)

    print(f"[JAF:REGENERATION] Regeneration completed with status: {result.outcome.status}")
    final_messages = getattr(getattr(result, 'final_state', None), 'messages', None)
    if final_messages is not None:
        print(f"[JAF:REGENERATION] Final state has {len(final_messages)} messages")
        print(f"[JAF:REGENERATION] Found {_count_assistant_turns(final_messages)} assistant messages in result")

    if result.outcome.status == 'completed':
        await _preserve_regeneration_metadata(
            provider, conversation_id, message_id, regeneration_context, result.final_state.messages
        )

    return result
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
async def get_regeneration_points(
    conversation_id: str,
    config: RunConfig[Ctx]
) -> Optional[list]:
    """
    Get all regeneration points recorded for a conversation.

    Args:
        conversation_id: The conversation ID.
        config: The run configuration; its memory provider is queried.

    Returns:
        ``None`` when no memory provider is configured. Otherwise a new list
        (a copy of the ``'regeneration_points'`` metadata entry). The list is
        empty when the conversation cannot be loaded, is missing or has no
        points. Lookup errors are printed, never raised.
    """
    if not config.memory or not config.memory.provider:
        return None

    try:
        conversation_result = await config.memory.provider.get_conversation(conversation_id)

        if isinstance(conversation_result, Failure):
            # Report the real error instead of a misleading "no data" message.
            print(f"[JAF:REGENERATION] Failed to load conversation {conversation_id}: {conversation_result.error}")
            return []

        conversation = getattr(conversation_result, 'data', None)
        if not conversation:
            print(f"[JAF:REGENERATION] No conversation data found for {conversation_id}")
            return []

        metadata = conversation.metadata or {}
        # Copy so callers cannot mutate the provider's metadata in place.
        regeneration_points = list(metadata.get('regeneration_points') or [])
        print(f"[JAF:REGENERATION] Retrieved {len(regeneration_points)} regeneration points for {conversation_id}")
        return regeneration_points

    except Exception as e:  # best effort: a lookup failure yields an empty list
        print(f"[JAF:REGENERATION] Failed to get regeneration points: {e}")
        return []
|
jaf/core/tracing.py
CHANGED
jaf/core/types.py
CHANGED
|
@@ -94,6 +94,11 @@ class RunId(str):
|
|
|
94
94
|
def __new__(cls, value: str) -> 'RunId':
|
|
95
95
|
return str.__new__(cls, value)
|
|
96
96
|
|
|
97
|
+
class MessageId(str):
    """
    Branded ``str`` subclass for message identifiers.

    At runtime it behaves exactly like ``str``. The distinct type only lets
    annotations tell message IDs apart from other strings.
    """

    def __new__(cls, value: str) -> 'MessageId':
        return super().__new__(cls, value)
|
|
101
|
+
|
|
97
102
|
def create_trace_id(id_str: str) -> TraceId:
    """Wrap a raw string in the ``TraceId`` branded type."""
    trace_id = TraceId(id_str)
    return trace_id
|
|
@@ -102,6 +107,36 @@ def create_run_id(id_str: str) -> RunId:
|
|
|
102
107
|
"""Create a RunId from a string."""
|
|
103
108
|
return RunId(id_str)
|
|
104
109
|
|
|
110
|
+
def create_message_id(id_str: Union[str, MessageId]) -> MessageId:
    """
    Create a MessageId from a string, or validate and return an existing one.

    Args:
        id_str: A string to convert to a MessageId, or an existing MessageId.

    Returns:
        MessageId: For a plain string, a new MessageId of the stripped text.
        An existing, non-blank MessageId is returned unchanged (same object).

    Raises:
        ValueError: If the input is None, not a string, or empty/whitespace.
            This applies to existing MessageId instances too, which previously
            bypassed the emptiness check.
    """
    if id_str is None:
        raise ValueError("Message ID cannot be None")

    if not isinstance(id_str, str):  # MessageId is a str subclass, so it passes here
        raise ValueError(f"Message ID must be a string or MessageId, got {type(id_str)}")

    stripped = id_str.strip()
    if not stripped:
        raise ValueError("Message ID cannot be empty or whitespace")

    # Preserve identity of already-branded IDs; only plain strings are normalized.
    if isinstance(id_str, MessageId):
        return id_str

    return MessageId(stripped)
|
|
139
|
+
|
|
105
140
|
def generate_run_id() -> RunId:
|
|
106
141
|
"""Generate a new unique run ID."""
|
|
107
142
|
import time
|
|
@@ -114,6 +149,12 @@ def generate_trace_id() -> TraceId:
|
|
|
114
149
|
import uuid
|
|
115
150
|
return TraceId(f"trace_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}")
|
|
116
151
|
|
|
152
|
+
def generate_message_id() -> MessageId:
    """
    Create a new, practically unique message ID.

    Format: ``msg_<unix epoch milliseconds>_<first 8 hex chars of a uuid4>``.
    """
    import time
    import uuid

    epoch_ms = int(time.time() * 1000)
    random_suffix = uuid.uuid4().hex[:8]
    return MessageId("msg_" + str(epoch_ms) + "_" + random_suffix)
|
|
157
|
+
|
|
117
158
|
# Generic type parameters: Ctx for run contexts, Out for run outputs.
Ctx, Out = TypeVar('Ctx'), TypeVar('Out')
|
|
@@ -180,12 +221,16 @@ class Message:
|
|
|
180
221
|
- Direct access to .content returns the original string when created with string
|
|
181
222
|
- Use .text_content property for guaranteed string access in all cases
|
|
182
223
|
- Use get_text_content() function to extract text from any content type
|
|
224
|
+
- message_id is optional for backward compatibility
|
|
183
225
|
|
|
184
226
|
Examples:
|
|
185
227
|
# Original usage - still works exactly the same
|
|
186
228
|
msg = Message(role='user', content='Hello')
|
|
187
229
|
text = msg.content # Returns 'Hello' as string
|
|
188
230
|
|
|
231
|
+
# New usage with message ID
|
|
232
|
+
msg = Message(role='user', content='Hello', message_id='msg_123')
|
|
233
|
+
|
|
189
234
|
# Guaranteed string access (recommended for new code)
|
|
190
235
|
text = msg.text_content # Always returns string
|
|
191
236
|
|
|
@@ -197,6 +242,27 @@ class Message:
|
|
|
197
242
|
attachments: Optional[List[Attachment]] = None
|
|
198
243
|
tool_call_id: Optional[str] = None
|
|
199
244
|
tool_calls: Optional[List[ToolCall]] = None
|
|
245
|
+
message_id: Optional[MessageId] = None # Optional for backward compatibility
|
|
246
|
+
|
|
247
|
+
def __post_init__(self):
    """
    Fill in a fresh message ID when none was supplied.

    Runs once, right after the dataclass-generated ``__init__``.
    - A caller-supplied ``message_id`` is left untouched.
    - Only ``None`` is replaced, with the result of ``generate_message_id()``.

    The field is written through ``object.__setattr__`` instead of plain
    attribute assignment, so this also works if the dataclass is frozen.
    Existing callers that pass ``message_id=None`` (or omit it) keep working,
    and the field ends up non-None after construction.
    """
    if self.message_id is not None:
        return
    object.__setattr__(self, 'message_id', generate_message_id())
|
|
200
266
|
|
|
201
267
|
@property
|
|
202
268
|
def text_content(self) -> str:
|
|
@@ -210,7 +276,8 @@ class Message:
|
|
|
210
276
|
content: str,
|
|
211
277
|
attachments: Optional[List[Attachment]] = None,
|
|
212
278
|
tool_call_id: Optional[str] = None,
|
|
213
|
-
tool_calls: Optional[List[ToolCall]] = None
|
|
279
|
+
tool_calls: Optional[List[ToolCall]] = None,
|
|
280
|
+
message_id: Optional[MessageId] = None
|
|
214
281
|
) -> 'Message':
|
|
215
282
|
"""Create a message with string content and optional attachments."""
|
|
216
283
|
return cls(
|
|
@@ -218,7 +285,8 @@ class Message:
|
|
|
218
285
|
content=content,
|
|
219
286
|
attachments=attachments,
|
|
220
287
|
tool_call_id=tool_call_id,
|
|
221
|
-
tool_calls=tool_calls
|
|
288
|
+
tool_calls=tool_calls,
|
|
289
|
+
message_id=message_id
|
|
222
290
|
)
|
|
223
291
|
|
|
224
292
|
def get_text_content(content: Union[str, List[MessageContentPart]]) -> str:
|
|
@@ -824,3 +892,42 @@ class RunConfig(Generic[Ctx]):
|
|
|
824
892
|
default_fast_model: Optional[str] = None # Default model for fast operations like guardrails
|
|
825
893
|
default_tool_timeout: Optional[float] = 300.0 # Default timeout for tool execution in seconds
|
|
826
894
|
approval_storage: Optional['ApprovalStorage'] = None # Storage for approval decisions
|
|
895
|
+
|
|
896
|
+
# Regeneration types for conversation management
|
|
897
|
+
@dataclass(frozen=True)
class RegenerationRequest:
    """
    Immutable request to regenerate a conversation starting at one message.

    Fields:
        conversation_id: ID of the stored conversation to regenerate.
        message_id: ID of the message to regenerate from.
        context: Optional extra options (defaults to None). The regeneration
            flow reads ``"replace_user_message"`` from it to supply an edited
            user query.
    """
    conversation_id: str
    message_id: MessageId
    context: Optional[Dict[str, Any]] = None
|
|
903
|
+
|
|
904
|
+
@dataclass(frozen=True)
class RegenerationContext:
    """
    Immutable record describing one regeneration operation.

    Fields:
        original_message_count: Number of messages before truncation.
        truncated_at_index: Index of the message regeneration started from.
        regenerated_message_id: ID of that message.
        regeneration_id: Unique ID of this regeneration operation.
        timestamp: Unix timestamp in milliseconds.
    """
    original_message_count: int
    truncated_at_index: int
    regenerated_message_id: MessageId
    regeneration_id: str
    timestamp: int
|
|
912
|
+
|
|
913
|
+
# Message utility functions
|
|
914
|
+
def find_message_index(messages: List[Message], message_id: MessageId) -> Optional[int]:
    """Return the position of the first message whose ``message_id`` matches, else ``None``."""
    matches = (
        position
        for position, candidate in enumerate(messages)
        if candidate.message_id == message_id
    )
    return next(matches, None)
|
|
920
|
+
|
|
921
|
+
def truncate_messages_after(messages: List[Message], message_id: MessageId) -> List[Message]:
    """
    Return the messages that come strictly before ``message_id``.

    The message carrying ``message_id`` (its first occurrence) and everything
    after it are dropped.

    If no message has that ID, a shallow copy of ``messages`` is returned.
    Either way the result is a fresh sequence, so mutating it never affects
    the caller's list. Previously the not-found case returned the input object
    itself.
    """
    index = find_message_index(messages, message_id)
    if index is None:
        return messages[:]  # copy, consistent with the slicing branch below
    return messages[:index]
|
|
927
|
+
|
|
928
|
+
def get_message_by_id(messages: List[Message], message_id: MessageId) -> Optional[Message]:
    """Return the first message carrying ``message_id``, or ``None`` if there is none."""
    return next(
        (candidate for candidate in messages if candidate.message_id == message_id),
        None,
    )
|