jaf-py 2.3.1__py3-none-any.whl → 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +15 -0
- jaf/core/agent_tool.py +6 -4
- jaf/core/analytics.py +4 -3
- jaf/core/engine.py +512 -39
- jaf/core/state.py +156 -0
- jaf/core/tools.py +4 -6
- jaf/core/tracing.py +114 -23
- jaf/core/types.py +157 -4
- jaf/memory/approval_storage.py +306 -0
- jaf/memory/types.py +1 -0
- jaf/memory/utils.py +1 -1
- jaf/providers/model.py +436 -13
- jaf/server/__init__.py +2 -0
- jaf/server/server.py +665 -22
- jaf/server/types.py +149 -4
- jaf/utils/__init__.py +50 -0
- jaf/utils/attachments.py +401 -0
- jaf/utils/document_processor.py +561 -0
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/METADATA +128 -120
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/RECORD +24 -19
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/WHEEL +0 -0
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.3.1.dist-info → jaf_py-2.4.2.dist-info}/top_level.txt +0 -0
jaf/core/engine.py
CHANGED
|
@@ -8,6 +8,7 @@ tool calling, and state management while maintaining functional purity.
|
|
|
8
8
|
import asyncio
|
|
9
9
|
import json
|
|
10
10
|
import os
|
|
11
|
+
import time
|
|
11
12
|
from dataclasses import replace, asdict, is_dataclass
|
|
12
13
|
from typing import Any, Dict, List, Optional, TypeVar
|
|
13
14
|
|
|
@@ -18,6 +19,7 @@ from .tool_results import tool_result_to_string
|
|
|
18
19
|
from .types import (
|
|
19
20
|
Agent,
|
|
20
21
|
AgentNotFound,
|
|
22
|
+
ApprovalValue,
|
|
21
23
|
CompletedOutcome,
|
|
22
24
|
ContentRole,
|
|
23
25
|
DecodeError,
|
|
@@ -26,6 +28,8 @@ from .types import (
|
|
|
26
28
|
HandoffEvent,
|
|
27
29
|
HandoffEventData,
|
|
28
30
|
InputGuardrailTripwire,
|
|
31
|
+
InterruptedOutcome,
|
|
32
|
+
Interruption,
|
|
29
33
|
GuardrailEvent,
|
|
30
34
|
GuardrailEventData,
|
|
31
35
|
MemoryEvent,
|
|
@@ -36,8 +40,11 @@ from .types import (
|
|
|
36
40
|
LLMCallEndEventData,
|
|
37
41
|
LLMCallStartEvent,
|
|
38
42
|
LLMCallStartEventData,
|
|
43
|
+
AssistantMessageEvent,
|
|
44
|
+
AssistantMessageEventData,
|
|
39
45
|
MaxTurnsExceeded,
|
|
40
46
|
Message,
|
|
47
|
+
get_text_content,
|
|
41
48
|
ModelBehaviorError,
|
|
42
49
|
OutputGuardrailTripwire,
|
|
43
50
|
RunConfig,
|
|
@@ -47,6 +54,7 @@ from .types import (
|
|
|
47
54
|
RunStartEvent,
|
|
48
55
|
RunStartEventData,
|
|
49
56
|
RunState,
|
|
57
|
+
ToolApprovalInterruption,
|
|
50
58
|
ToolCall,
|
|
51
59
|
ToolCallEndEvent,
|
|
52
60
|
ToolCallEndEventData,
|
|
@@ -78,6 +86,82 @@ def to_event_data(value: Any) -> Any:
|
|
|
78
86
|
Ctx = TypeVar('Ctx')
|
|
79
87
|
Out = TypeVar('Out')
|
|
80
88
|
|
|
89
|
+
async def try_resume_pending_tool_calls(
    state: RunState[Ctx],
    config: RunConfig[Ctx]
) -> Optional[RunResult[Out]]:
    """
    Try to resume pending tool calls if an assistant message contained tool_calls
    and some of those calls have not yet produced tool results.

    Messages are scanned from newest to oldest. For the first assistant message
    that still has tool calls without a matching ``tool`` message after it, the
    missing calls are executed and the run continues from the extended state.

    Args:
        state: Current run state; ``messages``, ``current_agent_name``,
            ``turn_count`` and ``approvals`` are read.
        config: Run configuration; ``agent_registry`` is used to look up the
            current agent.

    Returns:
        - ``None`` when there is nothing to resume, or when resuming failed
          (best-effort: the caller then continues its normal flow).
        - A ``RunResult`` with an ``ErrorOutcome`` when the current agent is not
          in the registry.
        - A ``RunResult`` with an ``InterruptedOutcome`` when any resumed tool
          call reports an interruption.
        - Otherwise the result of ``_run_internal`` on the state extended with
          the new tool results.
    """
    try:
        messages = state.messages
        # Newest first: the most recent assistant turn is the likeliest to be pending.
        for i in reversed(range(len(messages))):
            msg = messages[i]
            # Handle both string and enum roles
            role_str = msg.role.value if hasattr(msg.role, 'value') else str(msg.role)
            if role_str != 'assistant' or not msg.tool_calls:
                continue

            tool_call_ids = {tc.id for tc in msg.tool_calls}

            # Scan forward for tool results tied to these ids
            executed_ids = set()
            for m in messages[i + 1:]:
                # Handle both string and enum roles
                m_role_str = m.role.value if hasattr(m.role, 'value') else str(m.role)
                if m_role_str == 'tool' and m.tool_call_id and m.tool_call_id in tool_call_ids:
                    executed_ids.add(m.tool_call_id)

            pending_tool_calls = [tc for tc in msg.tool_calls if tc.id not in executed_ids]
            if not pending_tool_calls:
                continue  # Continue checking other (older) assistant messages

            current_agent = config.agent_registry.get(state.current_agent_name)
            if not current_agent:
                return RunResult(
                    final_state=state,
                    outcome=ErrorOutcome(error=AgentNotFound(agent_name=state.current_agent_name))
                )

            # Execute pending tool calls
            tool_results = await _execute_tool_calls(
                pending_tool_calls,
                current_agent,
                state,
                config
            )

            # Check for interruptions
            interruptions = [r.get('interruption') for r in tool_results if r.get('interruption')]
            if interruptions:
                # Only results that actually completed are added to the message history.
                completed_results = [r for r in tool_results if not r.get('interruption')]
                interrupted_state = replace(
                    state,
                    messages=list(state.messages) + [r['message'] for r in completed_results],
                    turn_count=state.turn_count,
                    approvals=state.approvals
                )
                return RunResult(
                    final_state=interrupted_state,
                    outcome=InterruptedOutcome(interruptions=interruptions)
                )

            # Continue with normal execution; the turn count is left unchanged here.
            next_state = replace(
                state,
                messages=list(state.messages) + [r['message'] for r in tool_results],
                turn_count=state.turn_count,
                approvals=state.approvals
            )
            return await _run_internal(next_state, config)

    except Exception as e:
        # Best-effort resume: never fail the run from here, but do not hide the
        # error either - report it and let the caller continue its normal flow.
        print(f'[JAF:ENGINE] Failed to resume pending tool calls: {e!r}')

    return None
|
|
164
|
+
|
|
81
165
|
async def run(
|
|
82
166
|
initial_state: RunState[Ctx],
|
|
83
167
|
config: RunConfig[Ctx]
|
|
@@ -100,9 +184,26 @@ async def run(
|
|
|
100
184
|
))))
|
|
101
185
|
|
|
102
186
|
state_with_memory = await _load_conversation_history(initial_state, config)
|
|
187
|
+
|
|
188
|
+
# Load approvals from storage if configured
|
|
189
|
+
if config.approval_storage:
|
|
190
|
+
print(f'[JAF:ENGINE] Loading approvals for runId {state_with_memory.run_id}')
|
|
191
|
+
from .state import load_approvals_into_state
|
|
192
|
+
state_with_memory = await load_approvals_into_state(state_with_memory, config)
|
|
193
|
+
|
|
103
194
|
result = await _run_internal(state_with_memory, config)
|
|
104
195
|
|
|
105
|
-
|
|
196
|
+
# Store conversation history only if this is a final completion of the entire conversation
|
|
197
|
+
# For HITL scenarios, storage happens on interruption to allow resumption
|
|
198
|
+
# We only store on completion if explicitly indicated this is the end of the conversation
|
|
199
|
+
if (config.memory and config.memory.auto_store and config.conversation_id and
|
|
200
|
+
result.outcome.status == 'completed' and getattr(config.memory, 'store_on_completion', True)):
|
|
201
|
+
print(f'[JAF:ENGINE] Storing final completed conversation for {config.conversation_id}')
|
|
202
|
+
await _store_conversation_history(result.final_state, config)
|
|
203
|
+
elif result.outcome.status == 'interrupted':
|
|
204
|
+
print('[JAF:ENGINE] Conversation interrupted - storage already handled during interruption')
|
|
205
|
+
else:
|
|
206
|
+
print(f'[JAF:ENGINE] Skipping memory store - status: {result.outcome.status}, store_on_completion: {getattr(config.memory, "store_on_completion", True) if config.memory else "N/A"}')
|
|
106
207
|
|
|
107
208
|
if config.on_event:
|
|
108
209
|
config.on_event(RunEndEvent(data=to_event_data(RunEndEventData(
|
|
@@ -152,7 +253,46 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
|
|
|
152
253
|
conversation_data = result.data
|
|
153
254
|
if conversation_data:
|
|
154
255
|
max_messages = config.memory.max_messages or len(conversation_data.messages)
|
|
155
|
-
|
|
256
|
+
all_memory_messages = conversation_data.messages[-max_messages:]
|
|
257
|
+
|
|
258
|
+
# Filter out halted messages - they're for audit/database only, not for LLM context
|
|
259
|
+
memory_messages = []
|
|
260
|
+
filtered_count = 0
|
|
261
|
+
|
|
262
|
+
for msg in all_memory_messages:
|
|
263
|
+
if msg.role not in (ContentRole.TOOL, 'tool'):
|
|
264
|
+
memory_messages.append(msg)
|
|
265
|
+
else:
|
|
266
|
+
try:
|
|
267
|
+
content = json.loads(msg.content)
|
|
268
|
+
status = content.get('status')
|
|
269
|
+
# Filter out ALL halted messages (they're for audit only)
|
|
270
|
+
if status == 'halted':
|
|
271
|
+
filtered_count += 1
|
|
272
|
+
continue # Skip this halted message
|
|
273
|
+
else:
|
|
274
|
+
memory_messages.append(msg)
|
|
275
|
+
except (json.JSONDecodeError, TypeError):
|
|
276
|
+
# Keep non-JSON tool messages
|
|
277
|
+
memory_messages.append(msg)
|
|
278
|
+
|
|
279
|
+
# For HITL scenarios, append new messages to memory messages
|
|
280
|
+
# This prevents duplication when resuming from interruptions
|
|
281
|
+
if memory_messages:
|
|
282
|
+
combined_messages = memory_messages + [
|
|
283
|
+
msg for msg in state.messages
|
|
284
|
+
if not any(
|
|
285
|
+
mem_msg.role == msg.role and
|
|
286
|
+
mem_msg.content == msg.content and
|
|
287
|
+
getattr(mem_msg, 'tool_calls', None) == getattr(msg, 'tool_calls', None)
|
|
288
|
+
for mem_msg in memory_messages
|
|
289
|
+
)
|
|
290
|
+
]
|
|
291
|
+
else:
|
|
292
|
+
combined_messages = list(state.messages)
|
|
293
|
+
|
|
294
|
+
# Approvals will be loaded separately via approval storage system
|
|
295
|
+
approvals_map = state.approvals
|
|
156
296
|
|
|
157
297
|
# Calculate turn count efficiently
|
|
158
298
|
memory_assistant_count = sum(1 for msg in memory_messages if msg.role in (ContentRole.ASSISTANT, 'assistant'))
|
|
@@ -172,10 +312,17 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
|
|
|
172
312
|
status='end',
|
|
173
313
|
message_count=len(memory_messages)
|
|
174
314
|
)))
|
|
315
|
+
|
|
316
|
+
if filtered_count > 0:
|
|
317
|
+
print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory, filtered to {len(memory_messages)} for LLM context (removed {filtered_count} halted messages)')
|
|
318
|
+
else:
|
|
319
|
+
print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory')
|
|
320
|
+
|
|
175
321
|
return replace(
|
|
176
322
|
state,
|
|
177
|
-
messages=
|
|
178
|
-
turn_count=turn_count
|
|
323
|
+
messages=combined_messages,
|
|
324
|
+
turn_count=turn_count,
|
|
325
|
+
approvals=approvals_map
|
|
179
326
|
)
|
|
180
327
|
return state
|
|
181
328
|
|
|
@@ -197,12 +344,23 @@ async def _store_conversation_history(state: RunState[Ctx], config: RunConfig[Ct
|
|
|
197
344
|
keep_recent = config.memory.compression_threshold - keep_first
|
|
198
345
|
messages_to_store = messages_to_store[:keep_first] + messages_to_store[-keep_recent:]
|
|
199
346
|
|
|
347
|
+
# Store approval information if any approvals were made
|
|
348
|
+
approval_metadata = {}
|
|
349
|
+
if state.approvals:
|
|
350
|
+
approval_metadata = {
|
|
351
|
+
"approval_count": len(state.approvals),
|
|
352
|
+
"approved_tools": [tool_id for tool_id, approval in state.approvals.items() if approval.approved],
|
|
353
|
+
"rejected_tools": [tool_id for tool_id, approval in state.approvals.items() if not approval.approved],
|
|
354
|
+
"has_approvals": True
|
|
355
|
+
}
|
|
356
|
+
|
|
200
357
|
metadata = {
|
|
201
358
|
"user_id": getattr(state.context, 'user_id', None),
|
|
202
359
|
"trace_id": str(state.trace_id),
|
|
203
360
|
"run_id": str(state.run_id),
|
|
204
361
|
"agent_name": state.current_agent_name,
|
|
205
|
-
"turn_count": state.turn_count
|
|
362
|
+
"turn_count": state.turn_count,
|
|
363
|
+
**approval_metadata
|
|
206
364
|
}
|
|
207
365
|
|
|
208
366
|
result = await config.memory.provider.store_messages(config.conversation_id, messages_to_store, metadata)
|
|
@@ -234,6 +392,11 @@ async def _run_internal(
|
|
|
234
392
|
config: RunConfig[Ctx]
|
|
235
393
|
) -> RunResult[Out]:
|
|
236
394
|
"""Internal run function with recursive execution logic."""
|
|
395
|
+
# Try to resume pending tool calls first
|
|
396
|
+
resumed = await try_resume_pending_tool_calls(state, config)
|
|
397
|
+
if resumed:
|
|
398
|
+
return resumed
|
|
399
|
+
|
|
237
400
|
# Check initial input guardrails on first turn
|
|
238
401
|
if state.turn_count == 0:
|
|
239
402
|
first_user_message = next((m for m in state.messages if m.role == ContentRole.USER or m.role == 'user'), None)
|
|
@@ -242,18 +405,18 @@ async def _run_internal(
|
|
|
242
405
|
if config.on_event:
|
|
243
406
|
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
244
407
|
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
245
|
-
content=first_user_message.content
|
|
408
|
+
content=get_text_content(first_user_message.content)
|
|
246
409
|
)))
|
|
247
410
|
if asyncio.iscoroutinefunction(guardrail):
|
|
248
|
-
result = await guardrail(first_user_message.content)
|
|
411
|
+
result = await guardrail(get_text_content(first_user_message.content))
|
|
249
412
|
else:
|
|
250
|
-
result = guardrail(first_user_message.content)
|
|
413
|
+
result = guardrail(get_text_content(first_user_message.content))
|
|
251
414
|
|
|
252
415
|
if not result.is_valid:
|
|
253
416
|
if config.on_event:
|
|
254
417
|
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
255
418
|
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
256
|
-
content=first_user_message.content,
|
|
419
|
+
content=get_text_content(first_user_message.content),
|
|
257
420
|
is_valid=False,
|
|
258
421
|
error_message=result.error_message
|
|
259
422
|
)))
|
|
@@ -300,8 +463,115 @@ async def _run_internal(
|
|
|
300
463
|
messages=state.messages
|
|
301
464
|
))))
|
|
302
465
|
|
|
303
|
-
# Get completion from model provider
|
|
304
|
-
llm_response
|
|
466
|
+
# Get completion from model provider, prefer streaming if available
|
|
467
|
+
llm_response: Dict[str, Any]
|
|
468
|
+
assistant_event_streamed = False
|
|
469
|
+
|
|
470
|
+
get_stream = getattr(config.model_provider, "get_completion_stream", None)
|
|
471
|
+
if callable(get_stream):
|
|
472
|
+
try:
|
|
473
|
+
aggregated_text = ""
|
|
474
|
+
# Working array of partial tool calls
|
|
475
|
+
partial_tool_calls: List[Dict[str, Any]] = []
|
|
476
|
+
|
|
477
|
+
async for chunk in get_stream(state, current_agent, config): # type: ignore[arg-type]
|
|
478
|
+
# Text deltas
|
|
479
|
+
delta_text = getattr(chunk, "delta", None)
|
|
480
|
+
if delta_text:
|
|
481
|
+
aggregated_text += delta_text
|
|
482
|
+
|
|
483
|
+
# Tool call deltas
|
|
484
|
+
tcd = getattr(chunk, "tool_call_delta", None)
|
|
485
|
+
if tcd is not None:
|
|
486
|
+
idx = getattr(tcd, "index", 0) or 0
|
|
487
|
+
# Ensure slot exists
|
|
488
|
+
while len(partial_tool_calls) <= idx:
|
|
489
|
+
partial_tool_calls.append({
|
|
490
|
+
"id": None,
|
|
491
|
+
"type": "function",
|
|
492
|
+
"function": {"name": None, "arguments": ""}
|
|
493
|
+
})
|
|
494
|
+
target = partial_tool_calls[idx]
|
|
495
|
+
# id
|
|
496
|
+
tc_id = getattr(tcd, "id", None)
|
|
497
|
+
if tc_id:
|
|
498
|
+
target["id"] = tc_id
|
|
499
|
+
# function fields
|
|
500
|
+
fn = getattr(tcd, "function", None)
|
|
501
|
+
if fn is not None:
|
|
502
|
+
fn_name = getattr(fn, "name", None)
|
|
503
|
+
if fn_name:
|
|
504
|
+
target["function"]["name"] = fn_name
|
|
505
|
+
args_delta = getattr(fn, "arguments_delta", None)
|
|
506
|
+
if args_delta:
|
|
507
|
+
target["function"]["arguments"] += args_delta
|
|
508
|
+
|
|
509
|
+
# Emit partial assistant message when something changed
|
|
510
|
+
if delta_text or tcd is not None:
|
|
511
|
+
assistant_event_streamed = True
|
|
512
|
+
# Normalize tool_calls for message
|
|
513
|
+
message_tool_calls = None
|
|
514
|
+
if len(partial_tool_calls) > 0:
|
|
515
|
+
message_tool_calls = []
|
|
516
|
+
for i, tc in enumerate(partial_tool_calls):
|
|
517
|
+
message_tool_calls.append({
|
|
518
|
+
"id": tc["id"] or f"call_{i}",
|
|
519
|
+
"type": "function",
|
|
520
|
+
"function": {
|
|
521
|
+
"name": tc["function"]["name"] or "",
|
|
522
|
+
"arguments": tc["function"]["arguments"]
|
|
523
|
+
}
|
|
524
|
+
})
|
|
525
|
+
|
|
526
|
+
partial_msg = Message(
|
|
527
|
+
role=ContentRole.ASSISTANT,
|
|
528
|
+
content=aggregated_text or "",
|
|
529
|
+
tool_calls=None if not message_tool_calls else [
|
|
530
|
+
ToolCall(
|
|
531
|
+
id=mc["id"],
|
|
532
|
+
type="function",
|
|
533
|
+
function=ToolCallFunction(
|
|
534
|
+
name=mc["function"]["name"],
|
|
535
|
+
arguments=mc["function"]["arguments"],
|
|
536
|
+
),
|
|
537
|
+
) for mc in message_tool_calls
|
|
538
|
+
],
|
|
539
|
+
)
|
|
540
|
+
try:
|
|
541
|
+
if config.on_event:
|
|
542
|
+
config.on_event(AssistantMessageEvent(data=to_event_data(
|
|
543
|
+
AssistantMessageEventData(message=partial_msg)
|
|
544
|
+
)))
|
|
545
|
+
except Exception as _e:
|
|
546
|
+
# Do not fail the run on callback errors
|
|
547
|
+
pass
|
|
548
|
+
|
|
549
|
+
# Build final response object compatible with downstream logic
|
|
550
|
+
final_tool_calls = None
|
|
551
|
+
if len(partial_tool_calls) > 0:
|
|
552
|
+
final_tool_calls = []
|
|
553
|
+
for i, tc in enumerate(partial_tool_calls):
|
|
554
|
+
final_tool_calls.append({
|
|
555
|
+
"id": tc["id"] or f"call_{i}",
|
|
556
|
+
"type": "function",
|
|
557
|
+
"function": {
|
|
558
|
+
"name": tc["function"]["name"] or "",
|
|
559
|
+
"arguments": tc["function"]["arguments"]
|
|
560
|
+
}
|
|
561
|
+
})
|
|
562
|
+
|
|
563
|
+
llm_response = {
|
|
564
|
+
"message": {
|
|
565
|
+
"content": aggregated_text or None,
|
|
566
|
+
"tool_calls": final_tool_calls
|
|
567
|
+
}
|
|
568
|
+
}
|
|
569
|
+
except Exception:
|
|
570
|
+
# Fallback to non-streaming on error
|
|
571
|
+
assistant_event_streamed = False
|
|
572
|
+
llm_response = await config.model_provider.get_completion(state, current_agent, config)
|
|
573
|
+
else:
|
|
574
|
+
llm_response = await config.model_provider.get_completion(state, current_agent, config)
|
|
305
575
|
|
|
306
576
|
# Emit LLM call end event
|
|
307
577
|
if config.on_event:
|
|
@@ -339,6 +609,45 @@ async def _run_internal(
|
|
|
339
609
|
config
|
|
340
610
|
)
|
|
341
611
|
|
|
612
|
+
# Check for interruptions
|
|
613
|
+
interruptions = [r.get('interruption') for r in tool_results if r.get('interruption')]
|
|
614
|
+
if interruptions:
|
|
615
|
+
# Separate completed tool results from interrupted ones
|
|
616
|
+
completed_results = [r for r in tool_results if not r.get('interruption')]
|
|
617
|
+
approval_required_results = [r for r in tool_results if r.get('interruption')]
|
|
618
|
+
|
|
619
|
+
# Add pending approvals to state.approvals
|
|
620
|
+
updated_approvals = dict(state.approvals)
|
|
621
|
+
for interruption in interruptions:
|
|
622
|
+
if interruption.type == 'tool_approval':
|
|
623
|
+
updated_approvals[interruption.tool_call.id] = ApprovalValue(
|
|
624
|
+
status='pending',
|
|
625
|
+
approved=False,
|
|
626
|
+
additional_context={'status': 'pending', 'timestamp': str(int(time.time() * 1000))}
|
|
627
|
+
)
|
|
628
|
+
|
|
629
|
+
# Create state with only completed tool results (for LLM context)
|
|
630
|
+
interrupted_state = replace(
|
|
631
|
+
state,
|
|
632
|
+
messages=new_messages + [r['message'] for r in completed_results],
|
|
633
|
+
turn_count=state.turn_count + 1,
|
|
634
|
+
approvals=updated_approvals
|
|
635
|
+
)
|
|
636
|
+
|
|
637
|
+
# Store conversation state with ALL messages including approval-required (for database records)
|
|
638
|
+
if config.memory and config.memory.auto_store and config.conversation_id:
|
|
639
|
+
print(f'[JAF:ENGINE] Storing conversation state due to interruption for {config.conversation_id}')
|
|
640
|
+
state_for_storage = replace(
|
|
641
|
+
interrupted_state,
|
|
642
|
+
messages=interrupted_state.messages + [r['message'] for r in approval_required_results]
|
|
643
|
+
)
|
|
644
|
+
await _store_conversation_history(state_for_storage, config)
|
|
645
|
+
|
|
646
|
+
return RunResult(
|
|
647
|
+
final_state=interrupted_state,
|
|
648
|
+
outcome=InterruptedOutcome(interruptions=interruptions)
|
|
649
|
+
)
|
|
650
|
+
|
|
342
651
|
# Check for handoffs
|
|
343
652
|
handoff_result = next((r for r in tool_results if r.get('is_handoff')), None)
|
|
344
653
|
if handoff_result:
|
|
@@ -360,40 +669,76 @@ async def _run_internal(
|
|
|
360
669
|
to=target_agent
|
|
361
670
|
))))
|
|
362
671
|
|
|
672
|
+
# Remove any halted messages that are being replaced by actual execution results
|
|
673
|
+
cleaned_new_messages = []
|
|
674
|
+
for msg in new_messages:
|
|
675
|
+
if msg.role not in (ContentRole.TOOL, 'tool'):
|
|
676
|
+
cleaned_new_messages.append(msg)
|
|
677
|
+
else:
|
|
678
|
+
try:
|
|
679
|
+
content = json.loads(msg.content)
|
|
680
|
+
if content.get('status') == 'halted':
|
|
681
|
+
# Remove this halted message if we have a new result for the same tool_call_id
|
|
682
|
+
if not any(result['message'].tool_call_id == msg.tool_call_id for result in tool_results):
|
|
683
|
+
cleaned_new_messages.append(msg)
|
|
684
|
+
else:
|
|
685
|
+
cleaned_new_messages.append(msg)
|
|
686
|
+
except (json.JSONDecodeError, TypeError):
|
|
687
|
+
cleaned_new_messages.append(msg)
|
|
688
|
+
|
|
363
689
|
# Continue with new agent
|
|
364
690
|
next_state = replace(
|
|
365
691
|
state,
|
|
366
|
-
messages=
|
|
692
|
+
messages=cleaned_new_messages + [r['message'] for r in tool_results],
|
|
367
693
|
current_agent_name=target_agent,
|
|
368
|
-
turn_count=state.turn_count + 1
|
|
694
|
+
turn_count=state.turn_count + 1,
|
|
695
|
+
approvals=state.approvals
|
|
369
696
|
)
|
|
370
697
|
|
|
371
698
|
return await _run_internal(next_state, config)
|
|
372
699
|
|
|
700
|
+
# Remove any halted messages that are being replaced by actual execution results
|
|
701
|
+
cleaned_new_messages = []
|
|
702
|
+
for msg in new_messages:
|
|
703
|
+
if msg.role not in (ContentRole.TOOL, 'tool'):
|
|
704
|
+
cleaned_new_messages.append(msg)
|
|
705
|
+
else:
|
|
706
|
+
try:
|
|
707
|
+
content = json.loads(msg.content)
|
|
708
|
+
if content.get('status') == 'halted':
|
|
709
|
+
# Remove this halted message if we have a new result for the same tool_call_id
|
|
710
|
+
if not any(result['message'].tool_call_id == msg.tool_call_id for result in tool_results):
|
|
711
|
+
cleaned_new_messages.append(msg)
|
|
712
|
+
else:
|
|
713
|
+
cleaned_new_messages.append(msg)
|
|
714
|
+
except (json.JSONDecodeError, TypeError):
|
|
715
|
+
cleaned_new_messages.append(msg)
|
|
716
|
+
|
|
373
717
|
# Continue with tool results
|
|
374
718
|
next_state = replace(
|
|
375
719
|
state,
|
|
376
|
-
messages=
|
|
377
|
-
turn_count=state.turn_count + 1
|
|
720
|
+
messages=cleaned_new_messages + [r['message'] for r in tool_results],
|
|
721
|
+
turn_count=state.turn_count + 1,
|
|
722
|
+
approvals=state.approvals
|
|
378
723
|
)
|
|
379
724
|
|
|
380
725
|
return await _run_internal(next_state, config)
|
|
381
726
|
|
|
382
727
|
# Handle text completion
|
|
383
|
-
if assistant_message.content:
|
|
728
|
+
if get_text_content(assistant_message.content):
|
|
384
729
|
if current_agent.output_codec:
|
|
385
730
|
# Parse with output codec
|
|
386
731
|
if config.on_event:
|
|
387
732
|
config.on_event(OutputParseEvent(data=OutputParseEventData(
|
|
388
|
-
content=assistant_message.content,
|
|
733
|
+
content=get_text_content(assistant_message.content),
|
|
389
734
|
status='start'
|
|
390
735
|
)))
|
|
391
736
|
try:
|
|
392
|
-
parsed_content = _try_parse_json(assistant_message.content)
|
|
737
|
+
parsed_content = _try_parse_json(get_text_content(assistant_message.content))
|
|
393
738
|
output_data = current_agent.output_codec.model_validate(parsed_content)
|
|
394
739
|
if config.on_event:
|
|
395
740
|
config.on_event(OutputParseEvent(data=OutputParseEventData(
|
|
396
|
-
content=assistant_message.content,
|
|
741
|
+
content=get_text_content(assistant_message.content),
|
|
397
742
|
status='end',
|
|
398
743
|
parsed_output=output_data
|
|
399
744
|
)))
|
|
@@ -420,26 +765,26 @@ async def _run_internal(
|
|
|
420
765
|
error_message=result.error_message
|
|
421
766
|
)))
|
|
422
767
|
return RunResult(
|
|
423
|
-
final_state=replace(state, messages=new_messages),
|
|
768
|
+
final_state=replace(state, messages=new_messages, approvals=state.approvals),
|
|
424
769
|
outcome=ErrorOutcome(error=OutputGuardrailTripwire(
|
|
425
770
|
reason=result.error_message or "Output guardrail failed"
|
|
426
771
|
))
|
|
427
772
|
)
|
|
428
773
|
|
|
429
774
|
return RunResult(
|
|
430
|
-
final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1),
|
|
775
|
+
final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
|
|
431
776
|
outcome=CompletedOutcome(output=output_data)
|
|
432
777
|
)
|
|
433
778
|
|
|
434
779
|
except ValidationError as e:
|
|
435
780
|
if config.on_event:
|
|
436
781
|
config.on_event(OutputParseEvent(data=OutputParseEventData(
|
|
437
|
-
content=assistant_message.content,
|
|
782
|
+
content=get_text_content(assistant_message.content),
|
|
438
783
|
status='fail',
|
|
439
784
|
error=str(e)
|
|
440
785
|
)))
|
|
441
786
|
return RunResult(
|
|
442
|
-
final_state=replace(state, messages=new_messages),
|
|
787
|
+
final_state=replace(state, messages=new_messages, approvals=state.approvals),
|
|
443
788
|
outcome=ErrorOutcome(error=DecodeError(
|
|
444
789
|
errors=[{'message': str(e), 'details': e.errors()}]
|
|
445
790
|
))
|
|
@@ -451,36 +796,36 @@ async def _run_internal(
|
|
|
451
796
|
if config.on_event:
|
|
452
797
|
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
453
798
|
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
454
|
-
content=assistant_message.content
|
|
799
|
+
content=get_text_content(assistant_message.content)
|
|
455
800
|
)))
|
|
456
801
|
if asyncio.iscoroutinefunction(guardrail):
|
|
457
|
-
result = await guardrail(assistant_message.content)
|
|
802
|
+
result = await guardrail(get_text_content(assistant_message.content))
|
|
458
803
|
else:
|
|
459
|
-
result = guardrail(assistant_message.content)
|
|
804
|
+
result = guardrail(get_text_content(assistant_message.content))
|
|
460
805
|
|
|
461
806
|
if not result.is_valid:
|
|
462
807
|
if config.on_event:
|
|
463
808
|
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
464
809
|
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
465
|
-
content=assistant_message.content,
|
|
810
|
+
content=get_text_content(assistant_message.content),
|
|
466
811
|
is_valid=False,
|
|
467
812
|
error_message=result.error_message
|
|
468
813
|
)))
|
|
469
814
|
return RunResult(
|
|
470
|
-
final_state=replace(state, messages=new_messages),
|
|
815
|
+
final_state=replace(state, messages=new_messages, approvals=state.approvals),
|
|
471
816
|
outcome=ErrorOutcome(error=OutputGuardrailTripwire(
|
|
472
817
|
reason=result.error_message or "Output guardrail failed"
|
|
473
818
|
))
|
|
474
819
|
)
|
|
475
820
|
|
|
476
821
|
return RunResult(
|
|
477
|
-
final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1),
|
|
478
|
-
outcome=CompletedOutcome(output=assistant_message.content)
|
|
822
|
+
final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
|
|
823
|
+
outcome=CompletedOutcome(output=get_text_content(assistant_message.content))
|
|
479
824
|
)
|
|
480
825
|
|
|
481
826
|
# Model produced neither content nor tool calls
|
|
482
827
|
return RunResult(
|
|
483
|
-
final_state=replace(state, messages=new_messages),
|
|
828
|
+
final_state=replace(state, messages=new_messages, approvals=state.approvals),
|
|
484
829
|
outcome=ErrorOutcome(error=ModelBehaviorError(
|
|
485
830
|
detail='Model produced neither content nor tool calls'
|
|
486
831
|
))
|
|
@@ -512,6 +857,7 @@ async def _execute_tool_calls(
|
|
|
512
857
|
"""Execute tool calls and return results."""
|
|
513
858
|
|
|
514
859
|
async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]:
|
|
860
|
+
print(f'[JAF:TOOL-EXEC] Starting execute_single_tool_call for {tool_call.function.name}')
|
|
515
861
|
if config.on_event:
|
|
516
862
|
config.on_event(ToolCallStartEvent(data=to_event_data(ToolCallStartEventData(
|
|
517
863
|
tool_name=tool_call.function.name,
|
|
@@ -531,7 +877,7 @@ async def _execute_tool_calls(
|
|
|
531
877
|
|
|
532
878
|
if not tool:
|
|
533
879
|
error_result = json.dumps({
|
|
534
|
-
'
|
|
880
|
+
'status': 'tool_not_found',
|
|
535
881
|
'message': f'Tool {tool_call.function.name} not found',
|
|
536
882
|
'tool_name': tool_call.function.name,
|
|
537
883
|
})
|
|
@@ -564,7 +910,7 @@ async def _execute_tool_calls(
|
|
|
564
910
|
validated_args = raw_args
|
|
565
911
|
except ValidationError as e:
|
|
566
912
|
error_result = json.dumps({
|
|
567
|
-
'
|
|
913
|
+
'status': 'validation_error',
|
|
568
914
|
'message': f'Invalid arguments for {tool_call.function.name}: {e!s}',
|
|
569
915
|
'tool_name': tool_call.function.name,
|
|
570
916
|
'validation_errors': e.errors()
|
|
@@ -588,16 +934,116 @@ async def _execute_tool_calls(
|
|
|
588
934
|
)
|
|
589
935
|
}
|
|
590
936
|
|
|
937
|
+
# Check if tool needs approval
|
|
938
|
+
needs_approval = False
|
|
939
|
+
approval_func = getattr(tool, 'needs_approval', False)
|
|
940
|
+
if callable(approval_func):
|
|
941
|
+
needs_approval = await approval_func(state.context, validated_args)
|
|
942
|
+
else:
|
|
943
|
+
needs_approval = bool(approval_func)
|
|
944
|
+
|
|
945
|
+
# Check approval status - first by ID, then by signature for cross-session matching
|
|
946
|
+
approval_status = state.approvals.get(tool_call.id)
|
|
947
|
+
if not approval_status:
|
|
948
|
+
signature = f"{tool_call.function.name}:{tool_call.function.arguments}"
|
|
949
|
+
for _, approval in state.approvals.items():
|
|
950
|
+
if approval.additional_context and approval.additional_context.get('signature') == signature:
|
|
951
|
+
approval_status = approval
|
|
952
|
+
break
|
|
953
|
+
|
|
954
|
+
derived_status = None
|
|
955
|
+
if approval_status:
|
|
956
|
+
# Use explicit status if available
|
|
957
|
+
if approval_status.status:
|
|
958
|
+
derived_status = approval_status.status
|
|
959
|
+
# Fall back to approved boolean if status not set
|
|
960
|
+
elif approval_status.approved is True:
|
|
961
|
+
derived_status = 'approved'
|
|
962
|
+
elif approval_status.approved is False:
|
|
963
|
+
if approval_status.additional_context and approval_status.additional_context.get('status') == 'pending':
|
|
964
|
+
derived_status = 'pending'
|
|
965
|
+
else:
|
|
966
|
+
derived_status = 'rejected'
|
|
967
|
+
|
|
968
|
+
is_pending = derived_status == 'pending'
|
|
969
|
+
|
|
970
|
+
# If approval needed and not yet decided, create interruption
|
|
971
|
+
if needs_approval and (approval_status is None or is_pending):
|
|
972
|
+
interruption = ToolApprovalInterruption(
|
|
973
|
+
type='tool_approval',
|
|
974
|
+
tool_call=tool_call,
|
|
975
|
+
agent=agent,
|
|
976
|
+
session_id=str(state.run_id)
|
|
977
|
+
)
|
|
978
|
+
|
|
979
|
+
# Return interrupted result with halted message
|
|
980
|
+
halted_result = json.dumps({
|
|
981
|
+
'status': 'halted',
|
|
982
|
+
'message': f'Tool {tool_call.function.name} requires approval.',
|
|
983
|
+
})
|
|
984
|
+
|
|
985
|
+
return {
|
|
986
|
+
'message': Message(
|
|
987
|
+
role=ContentRole.TOOL,
|
|
988
|
+
content=halted_result,
|
|
989
|
+
tool_call_id=tool_call.id
|
|
990
|
+
),
|
|
991
|
+
'interruption': interruption
|
|
992
|
+
}
|
|
993
|
+
|
|
994
|
+
# If approval was explicitly rejected, return rejection message
|
|
995
|
+
if derived_status == 'rejected':
|
|
996
|
+
rejection_reason = approval_status.additional_context.get('rejection_reason', 'User declined the action') if approval_status.additional_context else 'User declined the action'
|
|
997
|
+
rejection_result = json.dumps({
|
|
998
|
+
'status': 'approval_denied',
|
|
999
|
+
'message': f'Action was not approved. {rejection_reason}. Please ask if you can help with something else or suggest an alternative approach.',
|
|
1000
|
+
'tool_name': tool_call.function.name,
|
|
1001
|
+
'rejection_reason': rejection_reason,
|
|
1002
|
+
'additional_context': approval_status.additional_context if approval_status else None
|
|
1003
|
+
})
|
|
1004
|
+
|
|
1005
|
+
return {
|
|
1006
|
+
'message': Message(
|
|
1007
|
+
role=ContentRole.TOOL,
|
|
1008
|
+
content=rejection_result,
|
|
1009
|
+
tool_call_id=tool_call.id
|
|
1010
|
+
)
|
|
1011
|
+
}
|
|
1012
|
+
|
|
591
1013
|
# Determine timeout for this tool
|
|
592
1014
|
# Priority: tool-specific timeout > RunConfig default > 30 seconds global default
|
|
593
|
-
|
|
1015
|
+
if tool and hasattr(tool, 'schema'):
|
|
1016
|
+
timeout = getattr(tool.schema, 'timeout', None)
|
|
1017
|
+
else:
|
|
1018
|
+
timeout = None
|
|
594
1019
|
if timeout is None:
|
|
595
1020
|
timeout = config.default_tool_timeout if config.default_tool_timeout is not None else 30.0
|
|
596
1021
|
|
|
1022
|
+
# Merge additional context if provided through approval
|
|
1023
|
+
additional_context = approval_status.additional_context if approval_status else None
|
|
1024
|
+
context_with_additional = state.context
|
|
1025
|
+
if additional_context:
|
|
1026
|
+
# Create a copy of context with additional fields from approval
|
|
1027
|
+
if hasattr(state.context, '__dict__'):
|
|
1028
|
+
# For dataclass contexts, add additional context as attributes
|
|
1029
|
+
context_dict = {**state.context.__dict__, **additional_context}
|
|
1030
|
+
context_with_additional = type(state.context)(**{k: v for k, v in context_dict.items() if k in state.context.__dict__})
|
|
1031
|
+
# Add any extra fields as attributes
|
|
1032
|
+
for key, value in additional_context.items():
|
|
1033
|
+
if not hasattr(context_with_additional, key):
|
|
1034
|
+
setattr(context_with_additional, key, value)
|
|
1035
|
+
else:
|
|
1036
|
+
# For dict contexts, merge normally
|
|
1037
|
+
context_with_additional = {**state.context, **additional_context}
|
|
1038
|
+
|
|
1039
|
+
print(f'[JAF:ENGINE] About to execute tool: {tool_call.function.name}')
|
|
1040
|
+
print(f'[JAF:ENGINE] Tool args:', validated_args)
|
|
1041
|
+
print(f'[JAF:ENGINE] Tool context:', state.context)
|
|
1042
|
+
|
|
597
1043
|
# Execute the tool with timeout
|
|
598
1044
|
try:
|
|
599
1045
|
tool_result = await asyncio.wait_for(
|
|
600
|
-
tool.execute(validated_args,
|
|
1046
|
+
tool.execute(validated_args, context_with_additional),
|
|
601
1047
|
timeout=timeout
|
|
602
1048
|
)
|
|
603
1049
|
except asyncio.TimeoutError:
|
|
@@ -629,14 +1075,41 @@ async def _execute_tool_calls(
|
|
|
629
1075
|
# Handle both string and ToolResult formats
|
|
630
1076
|
if isinstance(tool_result, str):
|
|
631
1077
|
result_string = tool_result
|
|
1078
|
+
print(f'[JAF:ENGINE] Tool {tool_call.function.name} returned string:', result_string)
|
|
632
1079
|
else:
|
|
633
1080
|
# It's a ToolResult object
|
|
634
1081
|
result_string = tool_result_to_string(tool_result)
|
|
1082
|
+
print(f'[JAF:ENGINE] Tool {tool_call.function.name} returned ToolResult:', tool_result)
|
|
1083
|
+
print(f'[JAF:ENGINE] Converted to string:', result_string)
|
|
1084
|
+
|
|
1085
|
+
# Wrap tool result with status information for approval context
|
|
1086
|
+
if approval_status and approval_status.additional_context:
|
|
1087
|
+
final_content = json.dumps({
|
|
1088
|
+
'status': 'approved_and_executed',
|
|
1089
|
+
'result': result_string,
|
|
1090
|
+
'tool_name': tool_call.function.name,
|
|
1091
|
+
'approval_context': approval_status.additional_context,
|
|
1092
|
+
'message': 'Tool was approved and executed successfully with additional context.'
|
|
1093
|
+
})
|
|
1094
|
+
elif needs_approval:
|
|
1095
|
+
final_content = json.dumps({
|
|
1096
|
+
'status': 'approved_and_executed',
|
|
1097
|
+
'result': result_string,
|
|
1098
|
+
'tool_name': tool_call.function.name,
|
|
1099
|
+
'message': 'Tool was approved and executed successfully.'
|
|
1100
|
+
})
|
|
1101
|
+
else:
|
|
1102
|
+
final_content = json.dumps({
|
|
1103
|
+
'status': 'executed',
|
|
1104
|
+
'result': result_string,
|
|
1105
|
+
'tool_name': tool_call.function.name,
|
|
1106
|
+
'message': 'Tool executed successfully.'
|
|
1107
|
+
})
|
|
635
1108
|
|
|
636
1109
|
if config.on_event:
|
|
637
1110
|
config.on_event(ToolCallEndEvent(data=to_event_data(ToolCallEndEventData(
|
|
638
1111
|
tool_name=tool_call.function.name,
|
|
639
|
-
result=
|
|
1112
|
+
result=final_content,
|
|
640
1113
|
trace_id=state.trace_id,
|
|
641
1114
|
run_id=state.run_id,
|
|
642
1115
|
tool_result=tool_result,
|
|
@@ -649,7 +1122,7 @@ async def _execute_tool_calls(
|
|
|
649
1122
|
return {
|
|
650
1123
|
'message': Message(
|
|
651
1124
|
role=ContentRole.TOOL,
|
|
652
|
-
content=
|
|
1125
|
+
content=final_content,
|
|
653
1126
|
tool_call_id=tool_call.id
|
|
654
1127
|
),
|
|
655
1128
|
'is_handoff': True,
|
|
@@ -659,14 +1132,14 @@ async def _execute_tool_calls(
|
|
|
659
1132
|
return {
|
|
660
1133
|
'message': Message(
|
|
661
1134
|
role=ContentRole.TOOL,
|
|
662
|
-
content=
|
|
1135
|
+
content=final_content,
|
|
663
1136
|
tool_call_id=tool_call.id
|
|
664
1137
|
)
|
|
665
1138
|
}
|
|
666
1139
|
|
|
667
1140
|
except Exception as error:
|
|
668
1141
|
error_result = json.dumps({
|
|
669
|
-
'
|
|
1142
|
+
'status': 'execution_error',
|
|
670
1143
|
'message': str(error),
|
|
671
1144
|
'tool_name': tool_call.function.name,
|
|
672
1145
|
})
|