jaf-py 2.4.1__py3-none-any.whl → 2.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/__init__.py +15 -0
- jaf/core/agent_tool.py +6 -4
- jaf/core/analytics.py +4 -3
- jaf/core/engine.py +401 -37
- jaf/core/state.py +156 -0
- jaf/core/tracing.py +114 -23
- jaf/core/types.py +113 -3
- jaf/memory/approval_storage.py +306 -0
- jaf/memory/types.py +1 -0
- jaf/memory/utils.py +1 -1
- jaf/providers/model.py +277 -17
- jaf/server/__init__.py +2 -0
- jaf/server/server.py +665 -22
- jaf/server/types.py +149 -4
- jaf/utils/__init__.py +50 -0
- jaf/utils/attachments.py +401 -0
- jaf/utils/document_processor.py +561 -0
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/METADATA +10 -2
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/RECORD +23 -18
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/WHEEL +0 -0
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.4.1.dist-info → jaf_py-2.4.2.dist-info}/top_level.txt +0 -0
jaf/core/engine.py
CHANGED
|
@@ -8,6 +8,7 @@ tool calling, and state management while maintaining functional purity.
|
|
|
8
8
|
import asyncio
|
|
9
9
|
import json
|
|
10
10
|
import os
|
|
11
|
+
import time
|
|
11
12
|
from dataclasses import replace, asdict, is_dataclass
|
|
12
13
|
from typing import Any, Dict, List, Optional, TypeVar
|
|
13
14
|
|
|
@@ -18,6 +19,7 @@ from .tool_results import tool_result_to_string
|
|
|
18
19
|
from .types import (
|
|
19
20
|
Agent,
|
|
20
21
|
AgentNotFound,
|
|
22
|
+
ApprovalValue,
|
|
21
23
|
CompletedOutcome,
|
|
22
24
|
ContentRole,
|
|
23
25
|
DecodeError,
|
|
@@ -26,6 +28,8 @@ from .types import (
|
|
|
26
28
|
HandoffEvent,
|
|
27
29
|
HandoffEventData,
|
|
28
30
|
InputGuardrailTripwire,
|
|
31
|
+
InterruptedOutcome,
|
|
32
|
+
Interruption,
|
|
29
33
|
GuardrailEvent,
|
|
30
34
|
GuardrailEventData,
|
|
31
35
|
MemoryEvent,
|
|
@@ -40,6 +44,7 @@ from .types import (
|
|
|
40
44
|
AssistantMessageEventData,
|
|
41
45
|
MaxTurnsExceeded,
|
|
42
46
|
Message,
|
|
47
|
+
get_text_content,
|
|
43
48
|
ModelBehaviorError,
|
|
44
49
|
OutputGuardrailTripwire,
|
|
45
50
|
RunConfig,
|
|
@@ -49,6 +54,7 @@ from .types import (
|
|
|
49
54
|
RunStartEvent,
|
|
50
55
|
RunStartEventData,
|
|
51
56
|
RunState,
|
|
57
|
+
ToolApprovalInterruption,
|
|
52
58
|
ToolCall,
|
|
53
59
|
ToolCallEndEvent,
|
|
54
60
|
ToolCallEndEventData,
|
|
@@ -80,6 +86,82 @@ def to_event_data(value: Any) -> Any:
|
|
|
80
86
|
Ctx = TypeVar('Ctx')
|
|
81
87
|
Out = TypeVar('Out')
|
|
82
88
|
|
|
89
|
+
async def try_resume_pending_tool_calls(
    state: RunState[Ctx],
    config: RunConfig[Ctx]
) -> Optional[RunResult[Out]]:
    """
    Try to resume pending tool calls if an assistant message contained tool_calls
    and some of those calls have not yet produced tool results.

    Messages are scanned from newest to oldest. For the first assistant message
    that still has unanswered tool calls, those calls are executed and the run
    continues from the resulting state.

    Args:
        state: Current run state; its messages are inspected but not mutated.
        config: Run configuration providing the agent registry.

    Returns:
        - ``None`` when no assistant message has unanswered tool calls, or when
          resuming failed (best-effort: the caller continues its normal flow).
        - A ``RunResult`` with an ``ErrorOutcome`` when the current agent is not
          in the registry.
        - A ``RunResult`` with an ``InterruptedOutcome`` when any executed tool
          call reported an interruption.
        - Otherwise, the result of continuing the run via ``_run_internal``.
    """

    def _role_of(message: Any) -> str:
        # Roles may be enum members or plain strings; normalise to the string form.
        role = message.role
        return role.value if hasattr(role, 'value') else str(role)

    try:
        messages = state.messages
        for i in range(len(messages) - 1, -1, -1):
            msg = messages[i]
            if _role_of(msg) != 'assistant' or not msg.tool_calls:
                continue

            tool_call_ids = {tc.id for tc in msg.tool_calls}

            # Scan forward for tool results tied to these ids
            executed_ids = set()
            for j in range(i + 1, len(messages)):
                m = messages[j]
                if _role_of(m) == 'tool' and m.tool_call_id and m.tool_call_id in tool_call_ids:
                    executed_ids.add(m.tool_call_id)

            pending_tool_calls = [tc for tc in msg.tool_calls if tc.id not in executed_ids]
            if not pending_tool_calls:
                continue  # Continue checking other assistant messages

            current_agent = config.agent_registry.get(state.current_agent_name)
            if not current_agent:
                return RunResult(
                    final_state=state,
                    outcome=ErrorOutcome(error=AgentNotFound(agent_name=state.current_agent_name))
                )

            # Execute pending tool calls
            tool_results = await _execute_tool_calls(
                pending_tool_calls,
                current_agent,
                state,
                config
            )

            # Check for interruptions
            interruptions = [r.get('interruption') for r in tool_results if r.get('interruption')]
            if interruptions:
                # Only results of calls that actually completed are appended.
                completed_results = [r for r in tool_results if not r.get('interruption')]
                interrupted_state = replace(
                    state,
                    messages=list(state.messages) + [r['message'] for r in completed_results],
                    turn_count=state.turn_count,
                    approvals=state.approvals
                )
                return RunResult(
                    final_state=interrupted_state,
                    outcome=InterruptedOutcome(interruptions=interruptions)
                )

            # Continue with normal execution
            next_state = replace(
                state,
                messages=list(state.messages) + [r['message'] for r in tool_results],
                turn_count=state.turn_count,
                approvals=state.approvals
            )
            return await _run_internal(next_state, config)

    except Exception as e:
        # Best-effort resume: fall back to the normal flow, but report the
        # failure instead of discarding it so that it can be diagnosed.
        print(f'[JAF:ENGINE] Failed to resume pending tool calls, continuing normal flow: {e!r}')

    return None
|
|
164
|
+
|
|
83
165
|
async def run(
|
|
84
166
|
initial_state: RunState[Ctx],
|
|
85
167
|
config: RunConfig[Ctx]
|
|
@@ -102,9 +184,26 @@ async def run(
|
|
|
102
184
|
))))
|
|
103
185
|
|
|
104
186
|
state_with_memory = await _load_conversation_history(initial_state, config)
|
|
187
|
+
|
|
188
|
+
# Load approvals from storage if configured
|
|
189
|
+
if config.approval_storage:
|
|
190
|
+
print(f'[JAF:ENGINE] Loading approvals for runId {state_with_memory.run_id}')
|
|
191
|
+
from .state import load_approvals_into_state
|
|
192
|
+
state_with_memory = await load_approvals_into_state(state_with_memory, config)
|
|
193
|
+
|
|
105
194
|
result = await _run_internal(state_with_memory, config)
|
|
106
195
|
|
|
107
|
-
|
|
196
|
+
# Store conversation history only if this is a final completion of the entire conversation
|
|
197
|
+
# For HITL scenarios, storage happens on interruption to allow resumption
|
|
198
|
+
# We only store on completion if explicitly indicated this is the end of the conversation
|
|
199
|
+
if (config.memory and config.memory.auto_store and config.conversation_id and
|
|
200
|
+
result.outcome.status == 'completed' and getattr(config.memory, 'store_on_completion', True)):
|
|
201
|
+
print(f'[JAF:ENGINE] Storing final completed conversation for {config.conversation_id}')
|
|
202
|
+
await _store_conversation_history(result.final_state, config)
|
|
203
|
+
elif result.outcome.status == 'interrupted':
|
|
204
|
+
print('[JAF:ENGINE] Conversation interrupted - storage already handled during interruption')
|
|
205
|
+
else:
|
|
206
|
+
print(f'[JAF:ENGINE] Skipping memory store - status: {result.outcome.status}, store_on_completion: {getattr(config.memory, "store_on_completion", True) if config.memory else "N/A"}')
|
|
108
207
|
|
|
109
208
|
if config.on_event:
|
|
110
209
|
config.on_event(RunEndEvent(data=to_event_data(RunEndEventData(
|
|
@@ -154,7 +253,46 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
|
|
|
154
253
|
conversation_data = result.data
|
|
155
254
|
if conversation_data:
|
|
156
255
|
max_messages = config.memory.max_messages or len(conversation_data.messages)
|
|
157
|
-
|
|
256
|
+
all_memory_messages = conversation_data.messages[-max_messages:]
|
|
257
|
+
|
|
258
|
+
# Filter out halted messages - they're for audit/database only, not for LLM context
|
|
259
|
+
memory_messages = []
|
|
260
|
+
filtered_count = 0
|
|
261
|
+
|
|
262
|
+
for msg in all_memory_messages:
|
|
263
|
+
if msg.role not in (ContentRole.TOOL, 'tool'):
|
|
264
|
+
memory_messages.append(msg)
|
|
265
|
+
else:
|
|
266
|
+
try:
|
|
267
|
+
content = json.loads(msg.content)
|
|
268
|
+
status = content.get('status')
|
|
269
|
+
# Filter out ALL halted messages (they're for audit only)
|
|
270
|
+
if status == 'halted':
|
|
271
|
+
filtered_count += 1
|
|
272
|
+
continue # Skip this halted message
|
|
273
|
+
else:
|
|
274
|
+
memory_messages.append(msg)
|
|
275
|
+
except (json.JSONDecodeError, TypeError):
|
|
276
|
+
# Keep non-JSON tool messages
|
|
277
|
+
memory_messages.append(msg)
|
|
278
|
+
|
|
279
|
+
# For HITL scenarios, append new messages to memory messages
|
|
280
|
+
# This prevents duplication when resuming from interruptions
|
|
281
|
+
if memory_messages:
|
|
282
|
+
combined_messages = memory_messages + [
|
|
283
|
+
msg for msg in state.messages
|
|
284
|
+
if not any(
|
|
285
|
+
mem_msg.role == msg.role and
|
|
286
|
+
mem_msg.content == msg.content and
|
|
287
|
+
getattr(mem_msg, 'tool_calls', None) == getattr(msg, 'tool_calls', None)
|
|
288
|
+
for mem_msg in memory_messages
|
|
289
|
+
)
|
|
290
|
+
]
|
|
291
|
+
else:
|
|
292
|
+
combined_messages = list(state.messages)
|
|
293
|
+
|
|
294
|
+
# Approvals will be loaded separately via approval storage system
|
|
295
|
+
approvals_map = state.approvals
|
|
158
296
|
|
|
159
297
|
# Calculate turn count efficiently
|
|
160
298
|
memory_assistant_count = sum(1 for msg in memory_messages if msg.role in (ContentRole.ASSISTANT, 'assistant'))
|
|
@@ -174,10 +312,17 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
|
|
|
174
312
|
status='end',
|
|
175
313
|
message_count=len(memory_messages)
|
|
176
314
|
)))
|
|
315
|
+
|
|
316
|
+
if filtered_count > 0:
|
|
317
|
+
print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory, filtered to {len(memory_messages)} for LLM context (removed {filtered_count} halted messages)')
|
|
318
|
+
else:
|
|
319
|
+
print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory')
|
|
320
|
+
|
|
177
321
|
return replace(
|
|
178
322
|
state,
|
|
179
|
-
messages=
|
|
180
|
-
turn_count=turn_count
|
|
323
|
+
messages=combined_messages,
|
|
324
|
+
turn_count=turn_count,
|
|
325
|
+
approvals=approvals_map
|
|
181
326
|
)
|
|
182
327
|
return state
|
|
183
328
|
|
|
@@ -199,12 +344,23 @@ async def _store_conversation_history(state: RunState[Ctx], config: RunConfig[Ct
|
|
|
199
344
|
keep_recent = config.memory.compression_threshold - keep_first
|
|
200
345
|
messages_to_store = messages_to_store[:keep_first] + messages_to_store[-keep_recent:]
|
|
201
346
|
|
|
347
|
+
# Store approval information if any approvals were made
|
|
348
|
+
approval_metadata = {}
|
|
349
|
+
if state.approvals:
|
|
350
|
+
approval_metadata = {
|
|
351
|
+
"approval_count": len(state.approvals),
|
|
352
|
+
"approved_tools": [tool_id for tool_id, approval in state.approvals.items() if approval.approved],
|
|
353
|
+
"rejected_tools": [tool_id for tool_id, approval in state.approvals.items() if not approval.approved],
|
|
354
|
+
"has_approvals": True
|
|
355
|
+
}
|
|
356
|
+
|
|
202
357
|
metadata = {
|
|
203
358
|
"user_id": getattr(state.context, 'user_id', None),
|
|
204
359
|
"trace_id": str(state.trace_id),
|
|
205
360
|
"run_id": str(state.run_id),
|
|
206
361
|
"agent_name": state.current_agent_name,
|
|
207
|
-
"turn_count": state.turn_count
|
|
362
|
+
"turn_count": state.turn_count,
|
|
363
|
+
**approval_metadata
|
|
208
364
|
}
|
|
209
365
|
|
|
210
366
|
result = await config.memory.provider.store_messages(config.conversation_id, messages_to_store, metadata)
|
|
@@ -236,6 +392,11 @@ async def _run_internal(
|
|
|
236
392
|
config: RunConfig[Ctx]
|
|
237
393
|
) -> RunResult[Out]:
|
|
238
394
|
"""Internal run function with recursive execution logic."""
|
|
395
|
+
# Try to resume pending tool calls first
|
|
396
|
+
resumed = await try_resume_pending_tool_calls(state, config)
|
|
397
|
+
if resumed:
|
|
398
|
+
return resumed
|
|
399
|
+
|
|
239
400
|
# Check initial input guardrails on first turn
|
|
240
401
|
if state.turn_count == 0:
|
|
241
402
|
first_user_message = next((m for m in state.messages if m.role == ContentRole.USER or m.role == 'user'), None)
|
|
@@ -244,18 +405,18 @@ async def _run_internal(
|
|
|
244
405
|
if config.on_event:
|
|
245
406
|
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
246
407
|
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
247
|
-
content=first_user_message.content
|
|
408
|
+
content=get_text_content(first_user_message.content)
|
|
248
409
|
)))
|
|
249
410
|
if asyncio.iscoroutinefunction(guardrail):
|
|
250
|
-
result = await guardrail(first_user_message.content)
|
|
411
|
+
result = await guardrail(get_text_content(first_user_message.content))
|
|
251
412
|
else:
|
|
252
|
-
result = guardrail(first_user_message.content)
|
|
413
|
+
result = guardrail(get_text_content(first_user_message.content))
|
|
253
414
|
|
|
254
415
|
if not result.is_valid:
|
|
255
416
|
if config.on_event:
|
|
256
417
|
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
257
418
|
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
258
|
-
content=first_user_message.content,
|
|
419
|
+
content=get_text_content(first_user_message.content),
|
|
259
420
|
is_valid=False,
|
|
260
421
|
error_message=result.error_message
|
|
261
422
|
)))
|
|
@@ -448,6 +609,45 @@ async def _run_internal(
|
|
|
448
609
|
config
|
|
449
610
|
)
|
|
450
611
|
|
|
612
|
+
# Check for interruptions
|
|
613
|
+
interruptions = [r.get('interruption') for r in tool_results if r.get('interruption')]
|
|
614
|
+
if interruptions:
|
|
615
|
+
# Separate completed tool results from interrupted ones
|
|
616
|
+
completed_results = [r for r in tool_results if not r.get('interruption')]
|
|
617
|
+
approval_required_results = [r for r in tool_results if r.get('interruption')]
|
|
618
|
+
|
|
619
|
+
# Add pending approvals to state.approvals
|
|
620
|
+
updated_approvals = dict(state.approvals)
|
|
621
|
+
for interruption in interruptions:
|
|
622
|
+
if interruption.type == 'tool_approval':
|
|
623
|
+
updated_approvals[interruption.tool_call.id] = ApprovalValue(
|
|
624
|
+
status='pending',
|
|
625
|
+
approved=False,
|
|
626
|
+
additional_context={'status': 'pending', 'timestamp': str(int(time.time() * 1000))}
|
|
627
|
+
)
|
|
628
|
+
|
|
629
|
+
# Create state with only completed tool results (for LLM context)
|
|
630
|
+
interrupted_state = replace(
|
|
631
|
+
state,
|
|
632
|
+
messages=new_messages + [r['message'] for r in completed_results],
|
|
633
|
+
turn_count=state.turn_count + 1,
|
|
634
|
+
approvals=updated_approvals
|
|
635
|
+
)
|
|
636
|
+
|
|
637
|
+
# Store conversation state with ALL messages including approval-required (for database records)
|
|
638
|
+
if config.memory and config.memory.auto_store and config.conversation_id:
|
|
639
|
+
print(f'[JAF:ENGINE] Storing conversation state due to interruption for {config.conversation_id}')
|
|
640
|
+
state_for_storage = replace(
|
|
641
|
+
interrupted_state,
|
|
642
|
+
messages=interrupted_state.messages + [r['message'] for r in approval_required_results]
|
|
643
|
+
)
|
|
644
|
+
await _store_conversation_history(state_for_storage, config)
|
|
645
|
+
|
|
646
|
+
return RunResult(
|
|
647
|
+
final_state=interrupted_state,
|
|
648
|
+
outcome=InterruptedOutcome(interruptions=interruptions)
|
|
649
|
+
)
|
|
650
|
+
|
|
451
651
|
# Check for handoffs
|
|
452
652
|
handoff_result = next((r for r in tool_results if r.get('is_handoff')), None)
|
|
453
653
|
if handoff_result:
|
|
@@ -469,40 +669,76 @@ async def _run_internal(
|
|
|
469
669
|
to=target_agent
|
|
470
670
|
))))
|
|
471
671
|
|
|
672
|
+
# Remove any halted messages that are being replaced by actual execution results
|
|
673
|
+
cleaned_new_messages = []
|
|
674
|
+
for msg in new_messages:
|
|
675
|
+
if msg.role not in (ContentRole.TOOL, 'tool'):
|
|
676
|
+
cleaned_new_messages.append(msg)
|
|
677
|
+
else:
|
|
678
|
+
try:
|
|
679
|
+
content = json.loads(msg.content)
|
|
680
|
+
if content.get('status') == 'halted':
|
|
681
|
+
# Remove this halted message if we have a new result for the same tool_call_id
|
|
682
|
+
if not any(result['message'].tool_call_id == msg.tool_call_id for result in tool_results):
|
|
683
|
+
cleaned_new_messages.append(msg)
|
|
684
|
+
else:
|
|
685
|
+
cleaned_new_messages.append(msg)
|
|
686
|
+
except (json.JSONDecodeError, TypeError):
|
|
687
|
+
cleaned_new_messages.append(msg)
|
|
688
|
+
|
|
472
689
|
# Continue with new agent
|
|
473
690
|
next_state = replace(
|
|
474
691
|
state,
|
|
475
|
-
messages=
|
|
692
|
+
messages=cleaned_new_messages + [r['message'] for r in tool_results],
|
|
476
693
|
current_agent_name=target_agent,
|
|
477
|
-
turn_count=state.turn_count + 1
|
|
694
|
+
turn_count=state.turn_count + 1,
|
|
695
|
+
approvals=state.approvals
|
|
478
696
|
)
|
|
479
697
|
|
|
480
698
|
return await _run_internal(next_state, config)
|
|
481
699
|
|
|
700
|
+
# Remove any halted messages that are being replaced by actual execution results
|
|
701
|
+
cleaned_new_messages = []
|
|
702
|
+
for msg in new_messages:
|
|
703
|
+
if msg.role not in (ContentRole.TOOL, 'tool'):
|
|
704
|
+
cleaned_new_messages.append(msg)
|
|
705
|
+
else:
|
|
706
|
+
try:
|
|
707
|
+
content = json.loads(msg.content)
|
|
708
|
+
if content.get('status') == 'halted':
|
|
709
|
+
# Remove this halted message if we have a new result for the same tool_call_id
|
|
710
|
+
if not any(result['message'].tool_call_id == msg.tool_call_id for result in tool_results):
|
|
711
|
+
cleaned_new_messages.append(msg)
|
|
712
|
+
else:
|
|
713
|
+
cleaned_new_messages.append(msg)
|
|
714
|
+
except (json.JSONDecodeError, TypeError):
|
|
715
|
+
cleaned_new_messages.append(msg)
|
|
716
|
+
|
|
482
717
|
# Continue with tool results
|
|
483
718
|
next_state = replace(
|
|
484
719
|
state,
|
|
485
|
-
messages=
|
|
486
|
-
turn_count=state.turn_count + 1
|
|
720
|
+
messages=cleaned_new_messages + [r['message'] for r in tool_results],
|
|
721
|
+
turn_count=state.turn_count + 1,
|
|
722
|
+
approvals=state.approvals
|
|
487
723
|
)
|
|
488
724
|
|
|
489
725
|
return await _run_internal(next_state, config)
|
|
490
726
|
|
|
491
727
|
# Handle text completion
|
|
492
|
-
if assistant_message.content:
|
|
728
|
+
if get_text_content(assistant_message.content):
|
|
493
729
|
if current_agent.output_codec:
|
|
494
730
|
# Parse with output codec
|
|
495
731
|
if config.on_event:
|
|
496
732
|
config.on_event(OutputParseEvent(data=OutputParseEventData(
|
|
497
|
-
content=assistant_message.content,
|
|
733
|
+
content=get_text_content(assistant_message.content),
|
|
498
734
|
status='start'
|
|
499
735
|
)))
|
|
500
736
|
try:
|
|
501
|
-
parsed_content = _try_parse_json(assistant_message.content)
|
|
737
|
+
parsed_content = _try_parse_json(get_text_content(assistant_message.content))
|
|
502
738
|
output_data = current_agent.output_codec.model_validate(parsed_content)
|
|
503
739
|
if config.on_event:
|
|
504
740
|
config.on_event(OutputParseEvent(data=OutputParseEventData(
|
|
505
|
-
content=assistant_message.content,
|
|
741
|
+
content=get_text_content(assistant_message.content),
|
|
506
742
|
status='end',
|
|
507
743
|
parsed_output=output_data
|
|
508
744
|
)))
|
|
@@ -529,26 +765,26 @@ async def _run_internal(
|
|
|
529
765
|
error_message=result.error_message
|
|
530
766
|
)))
|
|
531
767
|
return RunResult(
|
|
532
|
-
final_state=replace(state, messages=new_messages),
|
|
768
|
+
final_state=replace(state, messages=new_messages, approvals=state.approvals),
|
|
533
769
|
outcome=ErrorOutcome(error=OutputGuardrailTripwire(
|
|
534
770
|
reason=result.error_message or "Output guardrail failed"
|
|
535
771
|
))
|
|
536
772
|
)
|
|
537
773
|
|
|
538
774
|
return RunResult(
|
|
539
|
-
final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1),
|
|
775
|
+
final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
|
|
540
776
|
outcome=CompletedOutcome(output=output_data)
|
|
541
777
|
)
|
|
542
778
|
|
|
543
779
|
except ValidationError as e:
|
|
544
780
|
if config.on_event:
|
|
545
781
|
config.on_event(OutputParseEvent(data=OutputParseEventData(
|
|
546
|
-
content=assistant_message.content,
|
|
782
|
+
content=get_text_content(assistant_message.content),
|
|
547
783
|
status='fail',
|
|
548
784
|
error=str(e)
|
|
549
785
|
)))
|
|
550
786
|
return RunResult(
|
|
551
|
-
final_state=replace(state, messages=new_messages),
|
|
787
|
+
final_state=replace(state, messages=new_messages, approvals=state.approvals),
|
|
552
788
|
outcome=ErrorOutcome(error=DecodeError(
|
|
553
789
|
errors=[{'message': str(e), 'details': e.errors()}]
|
|
554
790
|
))
|
|
@@ -560,36 +796,36 @@ async def _run_internal(
|
|
|
560
796
|
if config.on_event:
|
|
561
797
|
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
562
798
|
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
563
|
-
content=assistant_message.content
|
|
799
|
+
content=get_text_content(assistant_message.content)
|
|
564
800
|
)))
|
|
565
801
|
if asyncio.iscoroutinefunction(guardrail):
|
|
566
|
-
result = await guardrail(assistant_message.content)
|
|
802
|
+
result = await guardrail(get_text_content(assistant_message.content))
|
|
567
803
|
else:
|
|
568
|
-
result = guardrail(assistant_message.content)
|
|
804
|
+
result = guardrail(get_text_content(assistant_message.content))
|
|
569
805
|
|
|
570
806
|
if not result.is_valid:
|
|
571
807
|
if config.on_event:
|
|
572
808
|
config.on_event(GuardrailEvent(data=GuardrailEventData(
|
|
573
809
|
guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
|
|
574
|
-
content=assistant_message.content,
|
|
810
|
+
content=get_text_content(assistant_message.content),
|
|
575
811
|
is_valid=False,
|
|
576
812
|
error_message=result.error_message
|
|
577
813
|
)))
|
|
578
814
|
return RunResult(
|
|
579
|
-
final_state=replace(state, messages=new_messages),
|
|
815
|
+
final_state=replace(state, messages=new_messages, approvals=state.approvals),
|
|
580
816
|
outcome=ErrorOutcome(error=OutputGuardrailTripwire(
|
|
581
817
|
reason=result.error_message or "Output guardrail failed"
|
|
582
818
|
))
|
|
583
819
|
)
|
|
584
820
|
|
|
585
821
|
return RunResult(
|
|
586
|
-
final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1),
|
|
587
|
-
outcome=CompletedOutcome(output=assistant_message.content)
|
|
822
|
+
final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
|
|
823
|
+
outcome=CompletedOutcome(output=get_text_content(assistant_message.content))
|
|
588
824
|
)
|
|
589
825
|
|
|
590
826
|
# Model produced neither content nor tool calls
|
|
591
827
|
return RunResult(
|
|
592
|
-
final_state=replace(state, messages=new_messages),
|
|
828
|
+
final_state=replace(state, messages=new_messages, approvals=state.approvals),
|
|
593
829
|
outcome=ErrorOutcome(error=ModelBehaviorError(
|
|
594
830
|
detail='Model produced neither content nor tool calls'
|
|
595
831
|
))
|
|
@@ -621,6 +857,7 @@ async def _execute_tool_calls(
|
|
|
621
857
|
"""Execute tool calls and return results."""
|
|
622
858
|
|
|
623
859
|
async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]:
|
|
860
|
+
print(f'[JAF:TOOL-EXEC] Starting execute_single_tool_call for {tool_call.function.name}')
|
|
624
861
|
if config.on_event:
|
|
625
862
|
config.on_event(ToolCallStartEvent(data=to_event_data(ToolCallStartEventData(
|
|
626
863
|
tool_name=tool_call.function.name,
|
|
@@ -640,7 +877,7 @@ async def _execute_tool_calls(
|
|
|
640
877
|
|
|
641
878
|
if not tool:
|
|
642
879
|
error_result = json.dumps({
|
|
643
|
-
'
|
|
880
|
+
'status': 'tool_not_found',
|
|
644
881
|
'message': f'Tool {tool_call.function.name} not found',
|
|
645
882
|
'tool_name': tool_call.function.name,
|
|
646
883
|
})
|
|
@@ -673,7 +910,7 @@ async def _execute_tool_calls(
|
|
|
673
910
|
validated_args = raw_args
|
|
674
911
|
except ValidationError as e:
|
|
675
912
|
error_result = json.dumps({
|
|
676
|
-
'
|
|
913
|
+
'status': 'validation_error',
|
|
677
914
|
'message': f'Invalid arguments for {tool_call.function.name}: {e!s}',
|
|
678
915
|
'tool_name': tool_call.function.name,
|
|
679
916
|
'validation_errors': e.errors()
|
|
@@ -697,16 +934,116 @@ async def _execute_tool_calls(
|
|
|
697
934
|
)
|
|
698
935
|
}
|
|
699
936
|
|
|
937
|
+
# Check if tool needs approval
|
|
938
|
+
needs_approval = False
|
|
939
|
+
approval_func = getattr(tool, 'needs_approval', False)
|
|
940
|
+
if callable(approval_func):
|
|
941
|
+
needs_approval = await approval_func(state.context, validated_args)
|
|
942
|
+
else:
|
|
943
|
+
needs_approval = bool(approval_func)
|
|
944
|
+
|
|
945
|
+
# Check approval status - first by ID, then by signature for cross-session matching
|
|
946
|
+
approval_status = state.approvals.get(tool_call.id)
|
|
947
|
+
if not approval_status:
|
|
948
|
+
signature = f"{tool_call.function.name}:{tool_call.function.arguments}"
|
|
949
|
+
for _, approval in state.approvals.items():
|
|
950
|
+
if approval.additional_context and approval.additional_context.get('signature') == signature:
|
|
951
|
+
approval_status = approval
|
|
952
|
+
break
|
|
953
|
+
|
|
954
|
+
derived_status = None
|
|
955
|
+
if approval_status:
|
|
956
|
+
# Use explicit status if available
|
|
957
|
+
if approval_status.status:
|
|
958
|
+
derived_status = approval_status.status
|
|
959
|
+
# Fall back to approved boolean if status not set
|
|
960
|
+
elif approval_status.approved is True:
|
|
961
|
+
derived_status = 'approved'
|
|
962
|
+
elif approval_status.approved is False:
|
|
963
|
+
if approval_status.additional_context and approval_status.additional_context.get('status') == 'pending':
|
|
964
|
+
derived_status = 'pending'
|
|
965
|
+
else:
|
|
966
|
+
derived_status = 'rejected'
|
|
967
|
+
|
|
968
|
+
is_pending = derived_status == 'pending'
|
|
969
|
+
|
|
970
|
+
# If approval needed and not yet decided, create interruption
|
|
971
|
+
if needs_approval and (approval_status is None or is_pending):
|
|
972
|
+
interruption = ToolApprovalInterruption(
|
|
973
|
+
type='tool_approval',
|
|
974
|
+
tool_call=tool_call,
|
|
975
|
+
agent=agent,
|
|
976
|
+
session_id=str(state.run_id)
|
|
977
|
+
)
|
|
978
|
+
|
|
979
|
+
# Return interrupted result with halted message
|
|
980
|
+
halted_result = json.dumps({
|
|
981
|
+
'status': 'halted',
|
|
982
|
+
'message': f'Tool {tool_call.function.name} requires approval.',
|
|
983
|
+
})
|
|
984
|
+
|
|
985
|
+
return {
|
|
986
|
+
'message': Message(
|
|
987
|
+
role=ContentRole.TOOL,
|
|
988
|
+
content=halted_result,
|
|
989
|
+
tool_call_id=tool_call.id
|
|
990
|
+
),
|
|
991
|
+
'interruption': interruption
|
|
992
|
+
}
|
|
993
|
+
|
|
994
|
+
# If approval was explicitly rejected, return rejection message
|
|
995
|
+
if derived_status == 'rejected':
|
|
996
|
+
rejection_reason = approval_status.additional_context.get('rejection_reason', 'User declined the action') if approval_status.additional_context else 'User declined the action'
|
|
997
|
+
rejection_result = json.dumps({
|
|
998
|
+
'status': 'approval_denied',
|
|
999
|
+
'message': f'Action was not approved. {rejection_reason}. Please ask if you can help with something else or suggest an alternative approach.',
|
|
1000
|
+
'tool_name': tool_call.function.name,
|
|
1001
|
+
'rejection_reason': rejection_reason,
|
|
1002
|
+
'additional_context': approval_status.additional_context if approval_status else None
|
|
1003
|
+
})
|
|
1004
|
+
|
|
1005
|
+
return {
|
|
1006
|
+
'message': Message(
|
|
1007
|
+
role=ContentRole.TOOL,
|
|
1008
|
+
content=rejection_result,
|
|
1009
|
+
tool_call_id=tool_call.id
|
|
1010
|
+
)
|
|
1011
|
+
}
|
|
1012
|
+
|
|
700
1013
|
# Determine timeout for this tool
|
|
701
1014
|
# Priority: tool-specific timeout > RunConfig default > 30 seconds global default
|
|
702
|
-
|
|
1015
|
+
if tool and hasattr(tool, 'schema'):
|
|
1016
|
+
timeout = getattr(tool.schema, 'timeout', None)
|
|
1017
|
+
else:
|
|
1018
|
+
timeout = None
|
|
703
1019
|
if timeout is None:
|
|
704
1020
|
timeout = config.default_tool_timeout if config.default_tool_timeout is not None else 30.0
|
|
705
1021
|
|
|
1022
|
+
# Merge additional context if provided through approval
|
|
1023
|
+
additional_context = approval_status.additional_context if approval_status else None
|
|
1024
|
+
context_with_additional = state.context
|
|
1025
|
+
if additional_context:
|
|
1026
|
+
# Create a copy of context with additional fields from approval
|
|
1027
|
+
if hasattr(state.context, '__dict__'):
|
|
1028
|
+
# For dataclass contexts, add additional context as attributes
|
|
1029
|
+
context_dict = {**state.context.__dict__, **additional_context}
|
|
1030
|
+
context_with_additional = type(state.context)(**{k: v for k, v in context_dict.items() if k in state.context.__dict__})
|
|
1031
|
+
# Add any extra fields as attributes
|
|
1032
|
+
for key, value in additional_context.items():
|
|
1033
|
+
if not hasattr(context_with_additional, key):
|
|
1034
|
+
setattr(context_with_additional, key, value)
|
|
1035
|
+
else:
|
|
1036
|
+
# For dict contexts, merge normally
|
|
1037
|
+
context_with_additional = {**state.context, **additional_context}
|
|
1038
|
+
|
|
1039
|
+
print(f'[JAF:ENGINE] About to execute tool: {tool_call.function.name}')
|
|
1040
|
+
print(f'[JAF:ENGINE] Tool args:', validated_args)
|
|
1041
|
+
print(f'[JAF:ENGINE] Tool context:', state.context)
|
|
1042
|
+
|
|
706
1043
|
# Execute the tool with timeout
|
|
707
1044
|
try:
|
|
708
1045
|
tool_result = await asyncio.wait_for(
|
|
709
|
-
tool.execute(validated_args,
|
|
1046
|
+
tool.execute(validated_args, context_with_additional),
|
|
710
1047
|
timeout=timeout
|
|
711
1048
|
)
|
|
712
1049
|
except asyncio.TimeoutError:
|
|
@@ -738,14 +1075,41 @@ async def _execute_tool_calls(
|
|
|
738
1075
|
# Handle both string and ToolResult formats
|
|
739
1076
|
if isinstance(tool_result, str):
|
|
740
1077
|
result_string = tool_result
|
|
1078
|
+
print(f'[JAF:ENGINE] Tool {tool_call.function.name} returned string:', result_string)
|
|
741
1079
|
else:
|
|
742
1080
|
# It's a ToolResult object
|
|
743
1081
|
result_string = tool_result_to_string(tool_result)
|
|
1082
|
+
print(f'[JAF:ENGINE] Tool {tool_call.function.name} returned ToolResult:', tool_result)
|
|
1083
|
+
print(f'[JAF:ENGINE] Converted to string:', result_string)
|
|
1084
|
+
|
|
1085
|
+
# Wrap tool result with status information for approval context
|
|
1086
|
+
if approval_status and approval_status.additional_context:
|
|
1087
|
+
final_content = json.dumps({
|
|
1088
|
+
'status': 'approved_and_executed',
|
|
1089
|
+
'result': result_string,
|
|
1090
|
+
'tool_name': tool_call.function.name,
|
|
1091
|
+
'approval_context': approval_status.additional_context,
|
|
1092
|
+
'message': 'Tool was approved and executed successfully with additional context.'
|
|
1093
|
+
})
|
|
1094
|
+
elif needs_approval:
|
|
1095
|
+
final_content = json.dumps({
|
|
1096
|
+
'status': 'approved_and_executed',
|
|
1097
|
+
'result': result_string,
|
|
1098
|
+
'tool_name': tool_call.function.name,
|
|
1099
|
+
'message': 'Tool was approved and executed successfully.'
|
|
1100
|
+
})
|
|
1101
|
+
else:
|
|
1102
|
+
final_content = json.dumps({
|
|
1103
|
+
'status': 'executed',
|
|
1104
|
+
'result': result_string,
|
|
1105
|
+
'tool_name': tool_call.function.name,
|
|
1106
|
+
'message': 'Tool executed successfully.'
|
|
1107
|
+
})
|
|
744
1108
|
|
|
745
1109
|
if config.on_event:
|
|
746
1110
|
config.on_event(ToolCallEndEvent(data=to_event_data(ToolCallEndEventData(
|
|
747
1111
|
tool_name=tool_call.function.name,
|
|
748
|
-
result=
|
|
1112
|
+
result=final_content,
|
|
749
1113
|
trace_id=state.trace_id,
|
|
750
1114
|
run_id=state.run_id,
|
|
751
1115
|
tool_result=tool_result,
|
|
@@ -758,7 +1122,7 @@ async def _execute_tool_calls(
|
|
|
758
1122
|
return {
|
|
759
1123
|
'message': Message(
|
|
760
1124
|
role=ContentRole.TOOL,
|
|
761
|
-
content=
|
|
1125
|
+
content=final_content,
|
|
762
1126
|
tool_call_id=tool_call.id
|
|
763
1127
|
),
|
|
764
1128
|
'is_handoff': True,
|
|
@@ -768,14 +1132,14 @@ async def _execute_tool_calls(
|
|
|
768
1132
|
return {
|
|
769
1133
|
'message': Message(
|
|
770
1134
|
role=ContentRole.TOOL,
|
|
771
|
-
content=
|
|
1135
|
+
content=final_content,
|
|
772
1136
|
tool_call_id=tool_call.id
|
|
773
1137
|
)
|
|
774
1138
|
}
|
|
775
1139
|
|
|
776
1140
|
except Exception as error:
|
|
777
1141
|
error_result = json.dumps({
|
|
778
|
-
'
|
|
1142
|
+
'status': 'execution_error',
|
|
779
1143
|
'message': str(error),
|
|
780
1144
|
'tool_name': tool_call.function.name,
|
|
781
1145
|
})
|