jaf-py 2.4.1__py3-none-any.whl → 2.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaf/core/engine.py CHANGED
@@ -8,6 +8,7 @@ tool calling, and state management while maintaining functional purity.
8
8
  import asyncio
9
9
  import json
10
10
  import os
11
+ import time
11
12
  from dataclasses import replace, asdict, is_dataclass
12
13
  from typing import Any, Dict, List, Optional, TypeVar
13
14
 
@@ -18,6 +19,7 @@ from .tool_results import tool_result_to_string
18
19
  from .types import (
19
20
  Agent,
20
21
  AgentNotFound,
22
+ ApprovalValue,
21
23
  CompletedOutcome,
22
24
  ContentRole,
23
25
  DecodeError,
@@ -26,6 +28,8 @@ from .types import (
26
28
  HandoffEvent,
27
29
  HandoffEventData,
28
30
  InputGuardrailTripwire,
31
+ InterruptedOutcome,
32
+ Interruption,
29
33
  GuardrailEvent,
30
34
  GuardrailEventData,
31
35
  MemoryEvent,
@@ -40,6 +44,7 @@ from .types import (
40
44
  AssistantMessageEventData,
41
45
  MaxTurnsExceeded,
42
46
  Message,
47
+ get_text_content,
43
48
  ModelBehaviorError,
44
49
  OutputGuardrailTripwire,
45
50
  RunConfig,
@@ -49,6 +54,7 @@ from .types import (
49
54
  RunStartEvent,
50
55
  RunStartEventData,
51
56
  RunState,
57
+ ToolApprovalInterruption,
52
58
  ToolCall,
53
59
  ToolCallEndEvent,
54
60
  ToolCallEndEventData,
@@ -80,6 +86,82 @@ def to_event_data(value: Any) -> Any:
80
86
  Ctx = TypeVar('Ctx')
81
87
  Out = TypeVar('Out')
82
88
 
89
def _role_value(role: Any) -> str:
    """Return a message role as a plain string, accepting enum members or raw strings."""
    return role.value if hasattr(role, 'value') else str(role)


def _find_pending_tool_calls(messages: List[Any]) -> List[Any]:
    """
    Return the unanswered tool calls of the newest assistant message that has any.

    Messages are walked from newest to oldest. For every assistant message with
    ``tool_calls``, a call counts as answered when some *later* ``tool`` message
    carries its ``tool_call_id``. The first such assistant message (newest first)
    with at least one unanswered call wins. An empty list means nothing is pending.
    """
    for index in range(len(messages) - 1, -1, -1):
        msg = messages[index]
        if _role_value(msg.role) != 'assistant' or not msg.tool_calls:
            continue

        requested_ids = {tc.id for tc in msg.tool_calls}
        answered_ids = {
            later.tool_call_id
            for later in messages[index + 1:]
            if _role_value(later.role) == 'tool'
            and later.tool_call_id
            and later.tool_call_id in requested_ids
        }
        pending = [tc for tc in msg.tool_calls if tc.id not in answered_ids]
        if pending:
            return pending
        # Every call of this message was answered; keep checking older assistant messages.
    return []


async def try_resume_pending_tool_calls(
    state: RunState[Ctx],
    config: RunConfig[Ctx]
) -> Optional[RunResult[Out]]:
    """
    Resume tool calls that an earlier assistant message requested but never answered.

    The pending calls are found by ``_find_pending_tool_calls`` (newest assistant
    message with unanswered ``tool_calls``). Then:

    * If the current agent is missing from ``config.agent_registry``, an
      ``ErrorOutcome(AgentNotFound)`` result is returned.
    * The pending calls are executed via ``_execute_tool_calls``.
    * If any result carries an ``interruption``:
      - the completed results are appended to the messages;
      - each ``tool_approval`` interruption is recorded as pending in ``approvals``;
      - the conversation (including the halted placeholder messages) is stored
        when memory auto-store is configured;
      - an ``InterruptedOutcome`` is returned.
      This mirrors the interruption handling of the main tool-execution path.
    * Otherwise all tool results are appended and the run continues through
      ``_run_internal``.

    Returns:
        A ``RunResult``, or ``None`` when there is nothing to resume. ``None`` is
        also returned when locating or executing the pending calls raised: resume
        is best-effort, so the error is logged and the caller proceeds with its
        normal flow.

    Exceptions raised by the continued run (``_run_internal``) are NOT swallowed.
    They propagate exactly as they would from any other turn.
    """
    try:
        pending_tool_calls = _find_pending_tool_calls(state.messages)
        if not pending_tool_calls:
            return None

        current_agent = config.agent_registry.get(state.current_agent_name)
        if not current_agent:
            return RunResult(
                final_state=state,
                outcome=ErrorOutcome(error=AgentNotFound(agent_name=state.current_agent_name))
            )

        tool_results = await _execute_tool_calls(
            pending_tool_calls,
            current_agent,
            state,
            config
        )
    except Exception as e:
        # Best-effort resume: report the failure and let the caller continue normally.
        print(f'[JAF:ENGINE] Could not resume pending tool calls, continuing normal flow: {e!r}')
        return None

    # NOTE(review): results are appended at the end of the message list. That places
    # them right after the requesting assistant message only when it is the newest
    # message — confirm providers accept the ordering when older calls are resumed.
    interruptions = [r.get('interruption') for r in tool_results if r.get('interruption')]
    if interruptions:
        completed_results = [r for r in tool_results if not r.get('interruption')]
        interrupted_results = [r for r in tool_results if r.get('interruption')]

        # Record approval-gated calls as pending, as the main interruption path does.
        updated_approvals = dict(state.approvals)
        for interruption in interruptions:
            if interruption.type == 'tool_approval':
                updated_approvals[interruption.tool_call.id] = ApprovalValue(
                    status='pending',
                    approved=False,
                    additional_context={'status': 'pending', 'timestamp': str(int(time.time() * 1000))}
                )

        interrupted_state = replace(
            state,
            messages=list(state.messages) + [r['message'] for r in completed_results],
            approvals=updated_approvals
        )

        # run() skips storage for interrupted outcomes, relying on it happening here.
        # Persisting also keeps already-completed results from being re-executed
        # on the next resume.
        if config.memory and config.memory.auto_store and config.conversation_id:
            print(f'[JAF:ENGINE] Storing conversation state due to interruption for {config.conversation_id}')
            state_for_storage = replace(
                interrupted_state,
                messages=interrupted_state.messages + [r['message'] for r in interrupted_results]
            )
            await _store_conversation_history(state_for_storage, config)

        return RunResult(
            final_state=interrupted_state,
            outcome=InterruptedOutcome(interruptions=interruptions)
        )

    next_state = replace(
        state,
        messages=list(state.messages) + [r['message'] for r in tool_results]
    )
    # Deliberately outside the try: a failure in the continued run must surface,
    # not be mistaken for a failed resume (which would make the caller redo the turn).
    return await _run_internal(next_state, config)
164
+
83
165
  async def run(
84
166
  initial_state: RunState[Ctx],
85
167
  config: RunConfig[Ctx]
@@ -102,9 +184,26 @@ async def run(
102
184
  ))))
103
185
 
104
186
  state_with_memory = await _load_conversation_history(initial_state, config)
187
+
188
+ # Load approvals from storage if configured
189
+ if config.approval_storage:
190
+ print(f'[JAF:ENGINE] Loading approvals for runId {state_with_memory.run_id}')
191
+ from .state import load_approvals_into_state
192
+ state_with_memory = await load_approvals_into_state(state_with_memory, config)
193
+
105
194
  result = await _run_internal(state_with_memory, config)
106
195
 
107
- await _store_conversation_history(result.final_state, config)
196
+ # Store conversation history only if this is a final completion of the entire conversation
197
+ # For HITL scenarios, storage happens on interruption to allow resumption
198
+ # We only store on completion if explicitly indicated this is the end of the conversation
199
+ if (config.memory and config.memory.auto_store and config.conversation_id and
200
+ result.outcome.status == 'completed' and getattr(config.memory, 'store_on_completion', True)):
201
+ print(f'[JAF:ENGINE] Storing final completed conversation for {config.conversation_id}')
202
+ await _store_conversation_history(result.final_state, config)
203
+ elif result.outcome.status == 'interrupted':
204
+ print('[JAF:ENGINE] Conversation interrupted - storage already handled during interruption')
205
+ else:
206
+ print(f'[JAF:ENGINE] Skipping memory store - status: {result.outcome.status}, store_on_completion: {getattr(config.memory, "store_on_completion", True) if config.memory else "N/A"}')
108
207
 
109
208
  if config.on_event:
110
209
  config.on_event(RunEndEvent(data=to_event_data(RunEndEventData(
@@ -154,7 +253,46 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
154
253
  conversation_data = result.data
155
254
  if conversation_data:
156
255
  max_messages = config.memory.max_messages or len(conversation_data.messages)
157
- memory_messages = conversation_data.messages[-max_messages:]
256
+ all_memory_messages = conversation_data.messages[-max_messages:]
257
+
258
+ # Filter out halted messages - they're for audit/database only, not for LLM context
259
+ memory_messages = []
260
+ filtered_count = 0
261
+
262
+ for msg in all_memory_messages:
263
+ if msg.role not in (ContentRole.TOOL, 'tool'):
264
+ memory_messages.append(msg)
265
+ else:
266
+ try:
267
+ content = json.loads(msg.content)
268
+ status = content.get('status')
269
+ # Filter out ALL halted messages (they're for audit only)
270
+ if status == 'halted':
271
+ filtered_count += 1
272
+ continue # Skip this halted message
273
+ else:
274
+ memory_messages.append(msg)
275
+ except (json.JSONDecodeError, TypeError):
276
+ # Keep non-JSON tool messages
277
+ memory_messages.append(msg)
278
+
279
+ # For HITL scenarios, append new messages to memory messages
280
+ # This prevents duplication when resuming from interruptions
281
+ if memory_messages:
282
+ combined_messages = memory_messages + [
283
+ msg for msg in state.messages
284
+ if not any(
285
+ mem_msg.role == msg.role and
286
+ mem_msg.content == msg.content and
287
+ getattr(mem_msg, 'tool_calls', None) == getattr(msg, 'tool_calls', None)
288
+ for mem_msg in memory_messages
289
+ )
290
+ ]
291
+ else:
292
+ combined_messages = list(state.messages)
293
+
294
+ # Approvals will be loaded separately via approval storage system
295
+ approvals_map = state.approvals
158
296
 
159
297
  # Calculate turn count efficiently
160
298
  memory_assistant_count = sum(1 for msg in memory_messages if msg.role in (ContentRole.ASSISTANT, 'assistant'))
@@ -174,10 +312,17 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
174
312
  status='end',
175
313
  message_count=len(memory_messages)
176
314
  )))
315
+
316
+ if filtered_count > 0:
317
+ print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory, filtered to {len(memory_messages)} for LLM context (removed {filtered_count} halted messages)')
318
+ else:
319
+ print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory')
320
+
177
321
  return replace(
178
322
  state,
179
- messages=list(memory_messages) + list(state.messages),
180
- turn_count=turn_count
323
+ messages=combined_messages,
324
+ turn_count=turn_count,
325
+ approvals=approvals_map
181
326
  )
182
327
  return state
183
328
 
@@ -199,12 +344,23 @@ async def _store_conversation_history(state: RunState[Ctx], config: RunConfig[Ct
199
344
  keep_recent = config.memory.compression_threshold - keep_first
200
345
  messages_to_store = messages_to_store[:keep_first] + messages_to_store[-keep_recent:]
201
346
 
347
+ # Store approval information if any approvals were made
348
+ approval_metadata = {}
349
+ if state.approvals:
350
+ approval_metadata = {
351
+ "approval_count": len(state.approvals),
352
+ "approved_tools": [tool_id for tool_id, approval in state.approvals.items() if approval.approved],
353
+ "rejected_tools": [tool_id for tool_id, approval in state.approvals.items() if not approval.approved],
354
+ "has_approvals": True
355
+ }
356
+
202
357
  metadata = {
203
358
  "user_id": getattr(state.context, 'user_id', None),
204
359
  "trace_id": str(state.trace_id),
205
360
  "run_id": str(state.run_id),
206
361
  "agent_name": state.current_agent_name,
207
- "turn_count": state.turn_count
362
+ "turn_count": state.turn_count,
363
+ **approval_metadata
208
364
  }
209
365
 
210
366
  result = await config.memory.provider.store_messages(config.conversation_id, messages_to_store, metadata)
@@ -236,6 +392,11 @@ async def _run_internal(
236
392
  config: RunConfig[Ctx]
237
393
  ) -> RunResult[Out]:
238
394
  """Internal run function with recursive execution logic."""
395
+ # Try to resume pending tool calls first
396
+ resumed = await try_resume_pending_tool_calls(state, config)
397
+ if resumed:
398
+ return resumed
399
+
239
400
  # Check initial input guardrails on first turn
240
401
  if state.turn_count == 0:
241
402
  first_user_message = next((m for m in state.messages if m.role == ContentRole.USER or m.role == 'user'), None)
@@ -244,18 +405,18 @@ async def _run_internal(
244
405
  if config.on_event:
245
406
  config.on_event(GuardrailEvent(data=GuardrailEventData(
246
407
  guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
247
- content=first_user_message.content
408
+ content=get_text_content(first_user_message.content)
248
409
  )))
249
410
  if asyncio.iscoroutinefunction(guardrail):
250
- result = await guardrail(first_user_message.content)
411
+ result = await guardrail(get_text_content(first_user_message.content))
251
412
  else:
252
- result = guardrail(first_user_message.content)
413
+ result = guardrail(get_text_content(first_user_message.content))
253
414
 
254
415
  if not result.is_valid:
255
416
  if config.on_event:
256
417
  config.on_event(GuardrailEvent(data=GuardrailEventData(
257
418
  guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
258
- content=first_user_message.content,
419
+ content=get_text_content(first_user_message.content),
259
420
  is_valid=False,
260
421
  error_message=result.error_message
261
422
  )))
@@ -448,6 +609,45 @@ async def _run_internal(
448
609
  config
449
610
  )
450
611
 
612
+ # Check for interruptions
613
+ interruptions = [r.get('interruption') for r in tool_results if r.get('interruption')]
614
+ if interruptions:
615
+ # Separate completed tool results from interrupted ones
616
+ completed_results = [r for r in tool_results if not r.get('interruption')]
617
+ approval_required_results = [r for r in tool_results if r.get('interruption')]
618
+
619
+ # Add pending approvals to state.approvals
620
+ updated_approvals = dict(state.approvals)
621
+ for interruption in interruptions:
622
+ if interruption.type == 'tool_approval':
623
+ updated_approvals[interruption.tool_call.id] = ApprovalValue(
624
+ status='pending',
625
+ approved=False,
626
+ additional_context={'status': 'pending', 'timestamp': str(int(time.time() * 1000))}
627
+ )
628
+
629
+ # Create state with only completed tool results (for LLM context)
630
+ interrupted_state = replace(
631
+ state,
632
+ messages=new_messages + [r['message'] for r in completed_results],
633
+ turn_count=state.turn_count + 1,
634
+ approvals=updated_approvals
635
+ )
636
+
637
+ # Store conversation state with ALL messages including approval-required (for database records)
638
+ if config.memory and config.memory.auto_store and config.conversation_id:
639
+ print(f'[JAF:ENGINE] Storing conversation state due to interruption for {config.conversation_id}')
640
+ state_for_storage = replace(
641
+ interrupted_state,
642
+ messages=interrupted_state.messages + [r['message'] for r in approval_required_results]
643
+ )
644
+ await _store_conversation_history(state_for_storage, config)
645
+
646
+ return RunResult(
647
+ final_state=interrupted_state,
648
+ outcome=InterruptedOutcome(interruptions=interruptions)
649
+ )
650
+
451
651
  # Check for handoffs
452
652
  handoff_result = next((r for r in tool_results if r.get('is_handoff')), None)
453
653
  if handoff_result:
@@ -469,40 +669,76 @@ async def _run_internal(
469
669
  to=target_agent
470
670
  ))))
471
671
 
672
+ # Remove any halted messages that are being replaced by actual execution results
673
+ cleaned_new_messages = []
674
+ for msg in new_messages:
675
+ if msg.role not in (ContentRole.TOOL, 'tool'):
676
+ cleaned_new_messages.append(msg)
677
+ else:
678
+ try:
679
+ content = json.loads(msg.content)
680
+ if content.get('status') == 'halted':
681
+ # Remove this halted message if we have a new result for the same tool_call_id
682
+ if not any(result['message'].tool_call_id == msg.tool_call_id for result in tool_results):
683
+ cleaned_new_messages.append(msg)
684
+ else:
685
+ cleaned_new_messages.append(msg)
686
+ except (json.JSONDecodeError, TypeError):
687
+ cleaned_new_messages.append(msg)
688
+
472
689
  # Continue with new agent
473
690
  next_state = replace(
474
691
  state,
475
- messages=new_messages + [r['message'] for r in tool_results],
692
+ messages=cleaned_new_messages + [r['message'] for r in tool_results],
476
693
  current_agent_name=target_agent,
477
- turn_count=state.turn_count + 1
694
+ turn_count=state.turn_count + 1,
695
+ approvals=state.approvals
478
696
  )
479
697
 
480
698
  return await _run_internal(next_state, config)
481
699
 
700
+ # Remove any halted messages that are being replaced by actual execution results
701
+ cleaned_new_messages = []
702
+ for msg in new_messages:
703
+ if msg.role not in (ContentRole.TOOL, 'tool'):
704
+ cleaned_new_messages.append(msg)
705
+ else:
706
+ try:
707
+ content = json.loads(msg.content)
708
+ if content.get('status') == 'halted':
709
+ # Remove this halted message if we have a new result for the same tool_call_id
710
+ if not any(result['message'].tool_call_id == msg.tool_call_id for result in tool_results):
711
+ cleaned_new_messages.append(msg)
712
+ else:
713
+ cleaned_new_messages.append(msg)
714
+ except (json.JSONDecodeError, TypeError):
715
+ cleaned_new_messages.append(msg)
716
+
482
717
  # Continue with tool results
483
718
  next_state = replace(
484
719
  state,
485
- messages=new_messages + [r['message'] for r in tool_results],
486
- turn_count=state.turn_count + 1
720
+ messages=cleaned_new_messages + [r['message'] for r in tool_results],
721
+ turn_count=state.turn_count + 1,
722
+ approvals=state.approvals
487
723
  )
488
724
 
489
725
  return await _run_internal(next_state, config)
490
726
 
491
727
  # Handle text completion
492
- if assistant_message.content:
728
+ if get_text_content(assistant_message.content):
493
729
  if current_agent.output_codec:
494
730
  # Parse with output codec
495
731
  if config.on_event:
496
732
  config.on_event(OutputParseEvent(data=OutputParseEventData(
497
- content=assistant_message.content,
733
+ content=get_text_content(assistant_message.content),
498
734
  status='start'
499
735
  )))
500
736
  try:
501
- parsed_content = _try_parse_json(assistant_message.content)
737
+ parsed_content = _try_parse_json(get_text_content(assistant_message.content))
502
738
  output_data = current_agent.output_codec.model_validate(parsed_content)
503
739
  if config.on_event:
504
740
  config.on_event(OutputParseEvent(data=OutputParseEventData(
505
- content=assistant_message.content,
741
+ content=get_text_content(assistant_message.content),
506
742
  status='end',
507
743
  parsed_output=output_data
508
744
  )))
@@ -529,26 +765,26 @@ async def _run_internal(
529
765
  error_message=result.error_message
530
766
  )))
531
767
  return RunResult(
532
- final_state=replace(state, messages=new_messages),
768
+ final_state=replace(state, messages=new_messages, approvals=state.approvals),
533
769
  outcome=ErrorOutcome(error=OutputGuardrailTripwire(
534
770
  reason=result.error_message or "Output guardrail failed"
535
771
  ))
536
772
  )
537
773
 
538
774
  return RunResult(
539
- final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1),
775
+ final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
540
776
  outcome=CompletedOutcome(output=output_data)
541
777
  )
542
778
 
543
779
  except ValidationError as e:
544
780
  if config.on_event:
545
781
  config.on_event(OutputParseEvent(data=OutputParseEventData(
546
- content=assistant_message.content,
782
+ content=get_text_content(assistant_message.content),
547
783
  status='fail',
548
784
  error=str(e)
549
785
  )))
550
786
  return RunResult(
551
- final_state=replace(state, messages=new_messages),
787
+ final_state=replace(state, messages=new_messages, approvals=state.approvals),
552
788
  outcome=ErrorOutcome(error=DecodeError(
553
789
  errors=[{'message': str(e), 'details': e.errors()}]
554
790
  ))
@@ -560,36 +796,36 @@ async def _run_internal(
560
796
  if config.on_event:
561
797
  config.on_event(GuardrailEvent(data=GuardrailEventData(
562
798
  guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
563
- content=assistant_message.content
799
+ content=get_text_content(assistant_message.content)
564
800
  )))
565
801
  if asyncio.iscoroutinefunction(guardrail):
566
- result = await guardrail(assistant_message.content)
802
+ result = await guardrail(get_text_content(assistant_message.content))
567
803
  else:
568
- result = guardrail(assistant_message.content)
804
+ result = guardrail(get_text_content(assistant_message.content))
569
805
 
570
806
  if not result.is_valid:
571
807
  if config.on_event:
572
808
  config.on_event(GuardrailEvent(data=GuardrailEventData(
573
809
  guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
574
- content=assistant_message.content,
810
+ content=get_text_content(assistant_message.content),
575
811
  is_valid=False,
576
812
  error_message=result.error_message
577
813
  )))
578
814
  return RunResult(
579
- final_state=replace(state, messages=new_messages),
815
+ final_state=replace(state, messages=new_messages, approvals=state.approvals),
580
816
  outcome=ErrorOutcome(error=OutputGuardrailTripwire(
581
817
  reason=result.error_message or "Output guardrail failed"
582
818
  ))
583
819
  )
584
820
 
585
821
  return RunResult(
586
- final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1),
587
- outcome=CompletedOutcome(output=assistant_message.content)
822
+ final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
823
+ outcome=CompletedOutcome(output=get_text_content(assistant_message.content))
588
824
  )
589
825
 
590
826
  # Model produced neither content nor tool calls
591
827
  return RunResult(
592
- final_state=replace(state, messages=new_messages),
828
+ final_state=replace(state, messages=new_messages, approvals=state.approvals),
593
829
  outcome=ErrorOutcome(error=ModelBehaviorError(
594
830
  detail='Model produced neither content nor tool calls'
595
831
  ))
@@ -621,6 +857,7 @@ async def _execute_tool_calls(
621
857
  """Execute tool calls and return results."""
622
858
 
623
859
  async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]:
860
+ print(f'[JAF:TOOL-EXEC] Starting execute_single_tool_call for {tool_call.function.name}')
624
861
  if config.on_event:
625
862
  config.on_event(ToolCallStartEvent(data=to_event_data(ToolCallStartEventData(
626
863
  tool_name=tool_call.function.name,
@@ -640,7 +877,7 @@ async def _execute_tool_calls(
640
877
 
641
878
  if not tool:
642
879
  error_result = json.dumps({
643
- 'error': 'tool_not_found',
880
+ 'status': 'tool_not_found',
644
881
  'message': f'Tool {tool_call.function.name} not found',
645
882
  'tool_name': tool_call.function.name,
646
883
  })
@@ -673,7 +910,7 @@ async def _execute_tool_calls(
673
910
  validated_args = raw_args
674
911
  except ValidationError as e:
675
912
  error_result = json.dumps({
676
- 'error': 'validation_error',
913
+ 'status': 'validation_error',
677
914
  'message': f'Invalid arguments for {tool_call.function.name}: {e!s}',
678
915
  'tool_name': tool_call.function.name,
679
916
  'validation_errors': e.errors()
@@ -697,16 +934,116 @@ async def _execute_tool_calls(
697
934
  )
698
935
  }
699
936
 
937
+ # Check if tool needs approval
938
+ needs_approval = False
939
+ approval_func = getattr(tool, 'needs_approval', False)
940
+ if callable(approval_func):
941
+ needs_approval = await approval_func(state.context, validated_args)
942
+ else:
943
+ needs_approval = bool(approval_func)
944
+
945
+ # Check approval status - first by ID, then by signature for cross-session matching
946
+ approval_status = state.approvals.get(tool_call.id)
947
+ if not approval_status:
948
+ signature = f"{tool_call.function.name}:{tool_call.function.arguments}"
949
+ for _, approval in state.approvals.items():
950
+ if approval.additional_context and approval.additional_context.get('signature') == signature:
951
+ approval_status = approval
952
+ break
953
+
954
+ derived_status = None
955
+ if approval_status:
956
+ # Use explicit status if available
957
+ if approval_status.status:
958
+ derived_status = approval_status.status
959
+ # Fall back to approved boolean if status not set
960
+ elif approval_status.approved is True:
961
+ derived_status = 'approved'
962
+ elif approval_status.approved is False:
963
+ if approval_status.additional_context and approval_status.additional_context.get('status') == 'pending':
964
+ derived_status = 'pending'
965
+ else:
966
+ derived_status = 'rejected'
967
+
968
+ is_pending = derived_status == 'pending'
969
+
970
+ # If approval needed and not yet decided, create interruption
971
+ if needs_approval and (approval_status is None or is_pending):
972
+ interruption = ToolApprovalInterruption(
973
+ type='tool_approval',
974
+ tool_call=tool_call,
975
+ agent=agent,
976
+ session_id=str(state.run_id)
977
+ )
978
+
979
+ # Return interrupted result with halted message
980
+ halted_result = json.dumps({
981
+ 'status': 'halted',
982
+ 'message': f'Tool {tool_call.function.name} requires approval.',
983
+ })
984
+
985
+ return {
986
+ 'message': Message(
987
+ role=ContentRole.TOOL,
988
+ content=halted_result,
989
+ tool_call_id=tool_call.id
990
+ ),
991
+ 'interruption': interruption
992
+ }
993
+
994
+ # If approval was explicitly rejected, return rejection message
995
+ if derived_status == 'rejected':
996
+ rejection_reason = approval_status.additional_context.get('rejection_reason', 'User declined the action') if approval_status.additional_context else 'User declined the action'
997
+ rejection_result = json.dumps({
998
+ 'status': 'approval_denied',
999
+ 'message': f'Action was not approved. {rejection_reason}. Please ask if you can help with something else or suggest an alternative approach.',
1000
+ 'tool_name': tool_call.function.name,
1001
+ 'rejection_reason': rejection_reason,
1002
+ 'additional_context': approval_status.additional_context if approval_status else None
1003
+ })
1004
+
1005
+ return {
1006
+ 'message': Message(
1007
+ role=ContentRole.TOOL,
1008
+ content=rejection_result,
1009
+ tool_call_id=tool_call.id
1010
+ )
1011
+ }
1012
+
700
1013
  # Determine timeout for this tool
701
1014
  # Priority: tool-specific timeout > RunConfig default > 30 seconds global default
702
- timeout = getattr(tool.schema, 'timeout', None)
1015
+ if tool and hasattr(tool, 'schema'):
1016
+ timeout = getattr(tool.schema, 'timeout', None)
1017
+ else:
1018
+ timeout = None
703
1019
  if timeout is None:
704
1020
  timeout = config.default_tool_timeout if config.default_tool_timeout is not None else 30.0
705
1021
 
1022
+ # Merge additional context if provided through approval
1023
+ additional_context = approval_status.additional_context if approval_status else None
1024
+ context_with_additional = state.context
1025
+ if additional_context:
1026
+ # Create a copy of context with additional fields from approval
1027
+ if hasattr(state.context, '__dict__'):
1028
+ # For dataclass contexts, add additional context as attributes
1029
+ context_dict = {**state.context.__dict__, **additional_context}
1030
+ context_with_additional = type(state.context)(**{k: v for k, v in context_dict.items() if k in state.context.__dict__})
1031
+ # Add any extra fields as attributes
1032
+ for key, value in additional_context.items():
1033
+ if not hasattr(context_with_additional, key):
1034
+ setattr(context_with_additional, key, value)
1035
+ else:
1036
+ # For dict contexts, merge normally
1037
+ context_with_additional = {**state.context, **additional_context}
1038
+
1039
+ print(f'[JAF:ENGINE] About to execute tool: {tool_call.function.name}')
1040
+ print(f'[JAF:ENGINE] Tool args:', validated_args)
1041
+ print(f'[JAF:ENGINE] Tool context:', state.context)
1042
+
706
1043
  # Execute the tool with timeout
707
1044
  try:
708
1045
  tool_result = await asyncio.wait_for(
709
- tool.execute(validated_args, state.context),
1046
+ tool.execute(validated_args, context_with_additional),
710
1047
  timeout=timeout
711
1048
  )
712
1049
  except asyncio.TimeoutError:
@@ -738,14 +1075,41 @@ async def _execute_tool_calls(
738
1075
  # Handle both string and ToolResult formats
739
1076
  if isinstance(tool_result, str):
740
1077
  result_string = tool_result
1078
+ print(f'[JAF:ENGINE] Tool {tool_call.function.name} returned string:', result_string)
741
1079
  else:
742
1080
  # It's a ToolResult object
743
1081
  result_string = tool_result_to_string(tool_result)
1082
+ print(f'[JAF:ENGINE] Tool {tool_call.function.name} returned ToolResult:', tool_result)
1083
+ print(f'[JAF:ENGINE] Converted to string:', result_string)
1084
+
1085
+ # Wrap tool result with status information for approval context
1086
+ if approval_status and approval_status.additional_context:
1087
+ final_content = json.dumps({
1088
+ 'status': 'approved_and_executed',
1089
+ 'result': result_string,
1090
+ 'tool_name': tool_call.function.name,
1091
+ 'approval_context': approval_status.additional_context,
1092
+ 'message': 'Tool was approved and executed successfully with additional context.'
1093
+ })
1094
+ elif needs_approval:
1095
+ final_content = json.dumps({
1096
+ 'status': 'approved_and_executed',
1097
+ 'result': result_string,
1098
+ 'tool_name': tool_call.function.name,
1099
+ 'message': 'Tool was approved and executed successfully.'
1100
+ })
1101
+ else:
1102
+ final_content = json.dumps({
1103
+ 'status': 'executed',
1104
+ 'result': result_string,
1105
+ 'tool_name': tool_call.function.name,
1106
+ 'message': 'Tool executed successfully.'
1107
+ })
744
1108
 
745
1109
  if config.on_event:
746
1110
  config.on_event(ToolCallEndEvent(data=to_event_data(ToolCallEndEventData(
747
1111
  tool_name=tool_call.function.name,
748
- result=result_string,
1112
+ result=final_content,
749
1113
  trace_id=state.trace_id,
750
1114
  run_id=state.run_id,
751
1115
  tool_result=tool_result,
@@ -758,7 +1122,7 @@ async def _execute_tool_calls(
758
1122
  return {
759
1123
  'message': Message(
760
1124
  role=ContentRole.TOOL,
761
- content=result_string,
1125
+ content=final_content,
762
1126
  tool_call_id=tool_call.id
763
1127
  ),
764
1128
  'is_handoff': True,
@@ -768,14 +1132,14 @@ async def _execute_tool_calls(
768
1132
  return {
769
1133
  'message': Message(
770
1134
  role=ContentRole.TOOL,
771
- content=result_string,
1135
+ content=final_content,
772
1136
  tool_call_id=tool_call.id
773
1137
  )
774
1138
  }
775
1139
 
776
1140
  except Exception as error:
777
1141
  error_result = json.dumps({
778
- 'error': 'execution_error',
1142
+ 'status': 'execution_error',
779
1143
  'message': str(error),
780
1144
  'tool_name': tool_call.function.name,
781
1145
  })