jaf-py 2.3.1__py3-none-any.whl → 2.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaf/core/engine.py CHANGED
@@ -8,6 +8,7 @@ tool calling, and state management while maintaining functional purity.
8
8
  import asyncio
9
9
  import json
10
10
  import os
11
+ import time
11
12
  from dataclasses import replace, asdict, is_dataclass
12
13
  from typing import Any, Dict, List, Optional, TypeVar
13
14
 
@@ -18,6 +19,7 @@ from .tool_results import tool_result_to_string
18
19
  from .types import (
19
20
  Agent,
20
21
  AgentNotFound,
22
+ ApprovalValue,
21
23
  CompletedOutcome,
22
24
  ContentRole,
23
25
  DecodeError,
@@ -26,6 +28,8 @@ from .types import (
26
28
  HandoffEvent,
27
29
  HandoffEventData,
28
30
  InputGuardrailTripwire,
31
+ InterruptedOutcome,
32
+ Interruption,
29
33
  GuardrailEvent,
30
34
  GuardrailEventData,
31
35
  MemoryEvent,
@@ -36,8 +40,11 @@ from .types import (
36
40
  LLMCallEndEventData,
37
41
  LLMCallStartEvent,
38
42
  LLMCallStartEventData,
43
+ AssistantMessageEvent,
44
+ AssistantMessageEventData,
39
45
  MaxTurnsExceeded,
40
46
  Message,
47
+ get_text_content,
41
48
  ModelBehaviorError,
42
49
  OutputGuardrailTripwire,
43
50
  RunConfig,
@@ -47,6 +54,7 @@ from .types import (
47
54
  RunStartEvent,
48
55
  RunStartEventData,
49
56
  RunState,
57
+ ToolApprovalInterruption,
50
58
  ToolCall,
51
59
  ToolCallEndEvent,
52
60
  ToolCallEndEventData,
@@ -78,6 +86,82 @@ def to_event_data(value: Any) -> Any:
78
86
  Ctx = TypeVar('Ctx')
79
87
  Out = TypeVar('Out')
80
88
 
89
def _find_unanswered_tool_calls(messages: List[Message]) -> List[ToolCall]:
    """
    Return the unanswered tool calls of the most recent assistant message
    that still has any, or an empty list when every tool call has a result.

    Messages are walked newest-to-oldest in a single pass. When an assistant
    message is reached, the tool-result ids collected so far are exactly the
    results that appear *after* it, so a per-message forward rescan (O(n^2)
    over the history) is unnecessary.
    """
    answered_ids = set()
    for msg in reversed(messages):
        # Roles may be ContentRole enum members or plain strings.
        role = msg.role.value if hasattr(msg.role, 'value') else str(msg.role)
        if role == 'tool':
            if msg.tool_call_id:
                answered_ids.add(msg.tool_call_id)
        elif role == 'assistant' and msg.tool_calls:
            pending = [tc for tc in msg.tool_calls if tc.id not in answered_ids]
            if pending:
                return pending
    return []


async def try_resume_pending_tool_calls(
    state: RunState[Ctx],
    config: RunConfig[Ctx]
) -> Optional[RunResult[Out]]:
    """
    Resume tool calls that were issued by the model but never answered.

    Finds the most recent assistant message (not necessarily the last one)
    whose ``tool_calls`` lack matching ``tool`` result messages. It executes
    those calls with the current agent and then:

    - returns a RunResult with an ``InterruptedOutcome`` if any call still
      needs approval; only the completed results are appended to the state;
    - otherwise appends all results and continues via ``_run_internal``.

    Returns:
        ``None`` when there is nothing to resume, or when detection fails.
        In that case the caller proceeds with its normal flow. If the
        current agent is missing from the registry, returns a RunResult
        with an ``AgentNotFound`` error outcome.

    Errors raised while executing tools or continuing the run propagate to
    the caller. They are deliberately not swallowed: tools may already have
    run, and silently falling back would discard their results and replay
    the turn.
    """
    try:
        pending_tool_calls = _find_unanswered_tool_calls(state.messages)
    except Exception as error:
        # Best-effort detection only: malformed history must not block the run.
        print(f'[JAF:ENGINE] Skipping pending tool-call resume: {error}')
        return None

    if not pending_tool_calls:
        return None

    current_agent = config.agent_registry.get(state.current_agent_name)
    if not current_agent:
        return RunResult(
            final_state=state,
            outcome=ErrorOutcome(error=AgentNotFound(agent_name=state.current_agent_name))
        )

    tool_results = await _execute_tool_calls(
        pending_tool_calls,
        current_agent,
        state,
        config
    )

    # NOTE(review): unlike the interruption branch in _run_internal, this path
    # neither records 'pending' ApprovalValue entries nor stores memory, and it
    # does not act on 'is_handoff' results. Confirm whether that is intended.
    interruptions = [r['interruption'] for r in tool_results if r.get('interruption')]
    if interruptions:
        completed_results = [r for r in tool_results if not r.get('interruption')]
        interrupted_state = replace(
            state,
            messages=list(state.messages) + [r['message'] for r in completed_results],
            approvals=state.approvals
        )
        return RunResult(
            final_state=interrupted_state,
            outcome=InterruptedOutcome(interruptions=interruptions)
        )

    next_state = replace(
        state,
        messages=list(state.messages) + [r['message'] for r in tool_results],
        approvals=state.approvals
    )
    # Kept outside any try/except so failures in later turns surface to the caller.
    return await _run_internal(next_state, config)
164
+
81
165
  async def run(
82
166
  initial_state: RunState[Ctx],
83
167
  config: RunConfig[Ctx]
@@ -100,9 +184,26 @@ async def run(
100
184
  ))))
101
185
 
102
186
  state_with_memory = await _load_conversation_history(initial_state, config)
187
+
188
+ # Load approvals from storage if configured
189
+ if config.approval_storage:
190
+ print(f'[JAF:ENGINE] Loading approvals for runId {state_with_memory.run_id}')
191
+ from .state import load_approvals_into_state
192
+ state_with_memory = await load_approvals_into_state(state_with_memory, config)
193
+
103
194
  result = await _run_internal(state_with_memory, config)
104
195
 
105
- await _store_conversation_history(result.final_state, config)
196
+ # Store conversation history only if this is a final completion of the entire conversation
197
+ # For HITL scenarios, storage happens on interruption to allow resumption
198
+ # We only store on completion if explicitly indicated this is the end of the conversation
199
+ if (config.memory and config.memory.auto_store and config.conversation_id and
200
+ result.outcome.status == 'completed' and getattr(config.memory, 'store_on_completion', True)):
201
+ print(f'[JAF:ENGINE] Storing final completed conversation for {config.conversation_id}')
202
+ await _store_conversation_history(result.final_state, config)
203
+ elif result.outcome.status == 'interrupted':
204
+ print('[JAF:ENGINE] Conversation interrupted - storage already handled during interruption')
205
+ else:
206
+ print(f'[JAF:ENGINE] Skipping memory store - status: {result.outcome.status}, store_on_completion: {getattr(config.memory, "store_on_completion", True) if config.memory else "N/A"}')
106
207
 
107
208
  if config.on_event:
108
209
  config.on_event(RunEndEvent(data=to_event_data(RunEndEventData(
@@ -152,7 +253,46 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
152
253
  conversation_data = result.data
153
254
  if conversation_data:
154
255
  max_messages = config.memory.max_messages or len(conversation_data.messages)
155
- memory_messages = conversation_data.messages[-max_messages:]
256
+ all_memory_messages = conversation_data.messages[-max_messages:]
257
+
258
+ # Filter out halted messages - they're for audit/database only, not for LLM context
259
+ memory_messages = []
260
+ filtered_count = 0
261
+
262
+ for msg in all_memory_messages:
263
+ if msg.role not in (ContentRole.TOOL, 'tool'):
264
+ memory_messages.append(msg)
265
+ else:
266
+ try:
267
+ content = json.loads(msg.content)
268
+ status = content.get('status')
269
+ # Filter out ALL halted messages (they're for audit only)
270
+ if status == 'halted':
271
+ filtered_count += 1
272
+ continue # Skip this halted message
273
+ else:
274
+ memory_messages.append(msg)
275
+ except (json.JSONDecodeError, TypeError):
276
+ # Keep non-JSON tool messages
277
+ memory_messages.append(msg)
278
+
279
+ # For HITL scenarios, append new messages to memory messages
280
+ # This prevents duplication when resuming from interruptions
281
+ if memory_messages:
282
+ combined_messages = memory_messages + [
283
+ msg for msg in state.messages
284
+ if not any(
285
+ mem_msg.role == msg.role and
286
+ mem_msg.content == msg.content and
287
+ getattr(mem_msg, 'tool_calls', None) == getattr(msg, 'tool_calls', None)
288
+ for mem_msg in memory_messages
289
+ )
290
+ ]
291
+ else:
292
+ combined_messages = list(state.messages)
293
+
294
+ # Approvals will be loaded separately via approval storage system
295
+ approvals_map = state.approvals
156
296
 
157
297
  # Calculate turn count efficiently
158
298
  memory_assistant_count = sum(1 for msg in memory_messages if msg.role in (ContentRole.ASSISTANT, 'assistant'))
@@ -172,10 +312,17 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
172
312
  status='end',
173
313
  message_count=len(memory_messages)
174
314
  )))
315
+
316
+ if filtered_count > 0:
317
+ print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory, filtered to {len(memory_messages)} for LLM context (removed {filtered_count} halted messages)')
318
+ else:
319
+ print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory')
320
+
175
321
  return replace(
176
322
  state,
177
- messages=list(memory_messages) + list(state.messages),
178
- turn_count=turn_count
323
+ messages=combined_messages,
324
+ turn_count=turn_count,
325
+ approvals=approvals_map
179
326
  )
180
327
  return state
181
328
 
@@ -197,12 +344,23 @@ async def _store_conversation_history(state: RunState[Ctx], config: RunConfig[Ct
197
344
  keep_recent = config.memory.compression_threshold - keep_first
198
345
  messages_to_store = messages_to_store[:keep_first] + messages_to_store[-keep_recent:]
199
346
 
347
+ # Store approval information if any approvals were made
348
+ approval_metadata = {}
349
+ if state.approvals:
350
+ approval_metadata = {
351
+ "approval_count": len(state.approvals),
352
+ "approved_tools": [tool_id for tool_id, approval in state.approvals.items() if approval.approved],
353
+ "rejected_tools": [tool_id for tool_id, approval in state.approvals.items() if not approval.approved],
354
+ "has_approvals": True
355
+ }
356
+
200
357
  metadata = {
201
358
  "user_id": getattr(state.context, 'user_id', None),
202
359
  "trace_id": str(state.trace_id),
203
360
  "run_id": str(state.run_id),
204
361
  "agent_name": state.current_agent_name,
205
- "turn_count": state.turn_count
362
+ "turn_count": state.turn_count,
363
+ **approval_metadata
206
364
  }
207
365
 
208
366
  result = await config.memory.provider.store_messages(config.conversation_id, messages_to_store, metadata)
@@ -234,6 +392,11 @@ async def _run_internal(
234
392
  config: RunConfig[Ctx]
235
393
  ) -> RunResult[Out]:
236
394
  """Internal run function with recursive execution logic."""
395
+ # Try to resume pending tool calls first
396
+ resumed = await try_resume_pending_tool_calls(state, config)
397
+ if resumed:
398
+ return resumed
399
+
237
400
  # Check initial input guardrails on first turn
238
401
  if state.turn_count == 0:
239
402
  first_user_message = next((m for m in state.messages if m.role == ContentRole.USER or m.role == 'user'), None)
@@ -242,18 +405,18 @@ async def _run_internal(
242
405
  if config.on_event:
243
406
  config.on_event(GuardrailEvent(data=GuardrailEventData(
244
407
  guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
245
- content=first_user_message.content
408
+ content=get_text_content(first_user_message.content)
246
409
  )))
247
410
  if asyncio.iscoroutinefunction(guardrail):
248
- result = await guardrail(first_user_message.content)
411
+ result = await guardrail(get_text_content(first_user_message.content))
249
412
  else:
250
- result = guardrail(first_user_message.content)
413
+ result = guardrail(get_text_content(first_user_message.content))
251
414
 
252
415
  if not result.is_valid:
253
416
  if config.on_event:
254
417
  config.on_event(GuardrailEvent(data=GuardrailEventData(
255
418
  guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
256
- content=first_user_message.content,
419
+ content=get_text_content(first_user_message.content),
257
420
  is_valid=False,
258
421
  error_message=result.error_message
259
422
  )))
@@ -300,8 +463,115 @@ async def _run_internal(
300
463
  messages=state.messages
301
464
  ))))
302
465
 
303
- # Get completion from model provider
304
- llm_response = await config.model_provider.get_completion(state, current_agent, config)
466
+ # Get completion from model provider, prefer streaming if available
467
+ llm_response: Dict[str, Any]
468
+ assistant_event_streamed = False
469
+
470
+ get_stream = getattr(config.model_provider, "get_completion_stream", None)
471
+ if callable(get_stream):
472
+ try:
473
+ aggregated_text = ""
474
+ # Working array of partial tool calls
475
+ partial_tool_calls: List[Dict[str, Any]] = []
476
+
477
+ async for chunk in get_stream(state, current_agent, config): # type: ignore[arg-type]
478
+ # Text deltas
479
+ delta_text = getattr(chunk, "delta", None)
480
+ if delta_text:
481
+ aggregated_text += delta_text
482
+
483
+ # Tool call deltas
484
+ tcd = getattr(chunk, "tool_call_delta", None)
485
+ if tcd is not None:
486
+ idx = getattr(tcd, "index", 0) or 0
487
+ # Ensure slot exists
488
+ while len(partial_tool_calls) <= idx:
489
+ partial_tool_calls.append({
490
+ "id": None,
491
+ "type": "function",
492
+ "function": {"name": None, "arguments": ""}
493
+ })
494
+ target = partial_tool_calls[idx]
495
+ # id
496
+ tc_id = getattr(tcd, "id", None)
497
+ if tc_id:
498
+ target["id"] = tc_id
499
+ # function fields
500
+ fn = getattr(tcd, "function", None)
501
+ if fn is not None:
502
+ fn_name = getattr(fn, "name", None)
503
+ if fn_name:
504
+ target["function"]["name"] = fn_name
505
+ args_delta = getattr(fn, "arguments_delta", None)
506
+ if args_delta:
507
+ target["function"]["arguments"] += args_delta
508
+
509
+ # Emit partial assistant message when something changed
510
+ if delta_text or tcd is not None:
511
+ assistant_event_streamed = True
512
+ # Normalize tool_calls for message
513
+ message_tool_calls = None
514
+ if len(partial_tool_calls) > 0:
515
+ message_tool_calls = []
516
+ for i, tc in enumerate(partial_tool_calls):
517
+ message_tool_calls.append({
518
+ "id": tc["id"] or f"call_{i}",
519
+ "type": "function",
520
+ "function": {
521
+ "name": tc["function"]["name"] or "",
522
+ "arguments": tc["function"]["arguments"]
523
+ }
524
+ })
525
+
526
+ partial_msg = Message(
527
+ role=ContentRole.ASSISTANT,
528
+ content=aggregated_text or "",
529
+ tool_calls=None if not message_tool_calls else [
530
+ ToolCall(
531
+ id=mc["id"],
532
+ type="function",
533
+ function=ToolCallFunction(
534
+ name=mc["function"]["name"],
535
+ arguments=mc["function"]["arguments"],
536
+ ),
537
+ ) for mc in message_tool_calls
538
+ ],
539
+ )
540
+ try:
541
+ if config.on_event:
542
+ config.on_event(AssistantMessageEvent(data=to_event_data(
543
+ AssistantMessageEventData(message=partial_msg)
544
+ )))
545
+ except Exception as _e:
546
+ # Do not fail the run on callback errors
547
+ pass
548
+
549
+ # Build final response object compatible with downstream logic
550
+ final_tool_calls = None
551
+ if len(partial_tool_calls) > 0:
552
+ final_tool_calls = []
553
+ for i, tc in enumerate(partial_tool_calls):
554
+ final_tool_calls.append({
555
+ "id": tc["id"] or f"call_{i}",
556
+ "type": "function",
557
+ "function": {
558
+ "name": tc["function"]["name"] or "",
559
+ "arguments": tc["function"]["arguments"]
560
+ }
561
+ })
562
+
563
+ llm_response = {
564
+ "message": {
565
+ "content": aggregated_text or None,
566
+ "tool_calls": final_tool_calls
567
+ }
568
+ }
569
+ except Exception:
570
+ # Fallback to non-streaming on error
571
+ assistant_event_streamed = False
572
+ llm_response = await config.model_provider.get_completion(state, current_agent, config)
573
+ else:
574
+ llm_response = await config.model_provider.get_completion(state, current_agent, config)
305
575
 
306
576
  # Emit LLM call end event
307
577
  if config.on_event:
@@ -339,6 +609,45 @@ async def _run_internal(
339
609
  config
340
610
  )
341
611
 
612
+ # Check for interruptions
613
+ interruptions = [r.get('interruption') for r in tool_results if r.get('interruption')]
614
+ if interruptions:
615
+ # Separate completed tool results from interrupted ones
616
+ completed_results = [r for r in tool_results if not r.get('interruption')]
617
+ approval_required_results = [r for r in tool_results if r.get('interruption')]
618
+
619
+ # Add pending approvals to state.approvals
620
+ updated_approvals = dict(state.approvals)
621
+ for interruption in interruptions:
622
+ if interruption.type == 'tool_approval':
623
+ updated_approvals[interruption.tool_call.id] = ApprovalValue(
624
+ status='pending',
625
+ approved=False,
626
+ additional_context={'status': 'pending', 'timestamp': str(int(time.time() * 1000))}
627
+ )
628
+
629
+ # Create state with only completed tool results (for LLM context)
630
+ interrupted_state = replace(
631
+ state,
632
+ messages=new_messages + [r['message'] for r in completed_results],
633
+ turn_count=state.turn_count + 1,
634
+ approvals=updated_approvals
635
+ )
636
+
637
+ # Store conversation state with ALL messages including approval-required (for database records)
638
+ if config.memory and config.memory.auto_store and config.conversation_id:
639
+ print(f'[JAF:ENGINE] Storing conversation state due to interruption for {config.conversation_id}')
640
+ state_for_storage = replace(
641
+ interrupted_state,
642
+ messages=interrupted_state.messages + [r['message'] for r in approval_required_results]
643
+ )
644
+ await _store_conversation_history(state_for_storage, config)
645
+
646
+ return RunResult(
647
+ final_state=interrupted_state,
648
+ outcome=InterruptedOutcome(interruptions=interruptions)
649
+ )
650
+
342
651
  # Check for handoffs
343
652
  handoff_result = next((r for r in tool_results if r.get('is_handoff')), None)
344
653
  if handoff_result:
@@ -360,40 +669,76 @@ async def _run_internal(
360
669
  to=target_agent
361
670
  ))))
362
671
 
672
+ # Remove any halted messages that are being replaced by actual execution results
673
+ cleaned_new_messages = []
674
+ for msg in new_messages:
675
+ if msg.role not in (ContentRole.TOOL, 'tool'):
676
+ cleaned_new_messages.append(msg)
677
+ else:
678
+ try:
679
+ content = json.loads(msg.content)
680
+ if content.get('status') == 'halted':
681
+ # Remove this halted message if we have a new result for the same tool_call_id
682
+ if not any(result['message'].tool_call_id == msg.tool_call_id for result in tool_results):
683
+ cleaned_new_messages.append(msg)
684
+ else:
685
+ cleaned_new_messages.append(msg)
686
+ except (json.JSONDecodeError, TypeError):
687
+ cleaned_new_messages.append(msg)
688
+
363
689
  # Continue with new agent
364
690
  next_state = replace(
365
691
  state,
366
- messages=new_messages + [r['message'] for r in tool_results],
692
+ messages=cleaned_new_messages + [r['message'] for r in tool_results],
367
693
  current_agent_name=target_agent,
368
- turn_count=state.turn_count + 1
694
+ turn_count=state.turn_count + 1,
695
+ approvals=state.approvals
369
696
  )
370
697
 
371
698
  return await _run_internal(next_state, config)
372
699
 
700
+ # Remove any halted messages that are being replaced by actual execution results
701
+ cleaned_new_messages = []
702
+ for msg in new_messages:
703
+ if msg.role not in (ContentRole.TOOL, 'tool'):
704
+ cleaned_new_messages.append(msg)
705
+ else:
706
+ try:
707
+ content = json.loads(msg.content)
708
+ if content.get('status') == 'halted':
709
+ # Remove this halted message if we have a new result for the same tool_call_id
710
+ if not any(result['message'].tool_call_id == msg.tool_call_id for result in tool_results):
711
+ cleaned_new_messages.append(msg)
712
+ else:
713
+ cleaned_new_messages.append(msg)
714
+ except (json.JSONDecodeError, TypeError):
715
+ cleaned_new_messages.append(msg)
716
+
373
717
  # Continue with tool results
374
718
  next_state = replace(
375
719
  state,
376
- messages=new_messages + [r['message'] for r in tool_results],
377
- turn_count=state.turn_count + 1
720
+ messages=cleaned_new_messages + [r['message'] for r in tool_results],
721
+ turn_count=state.turn_count + 1,
722
+ approvals=state.approvals
378
723
  )
379
724
 
380
725
  return await _run_internal(next_state, config)
381
726
 
382
727
  # Handle text completion
383
- if assistant_message.content:
728
+ if get_text_content(assistant_message.content):
384
729
  if current_agent.output_codec:
385
730
  # Parse with output codec
386
731
  if config.on_event:
387
732
  config.on_event(OutputParseEvent(data=OutputParseEventData(
388
- content=assistant_message.content,
733
+ content=get_text_content(assistant_message.content),
389
734
  status='start'
390
735
  )))
391
736
  try:
392
- parsed_content = _try_parse_json(assistant_message.content)
737
+ parsed_content = _try_parse_json(get_text_content(assistant_message.content))
393
738
  output_data = current_agent.output_codec.model_validate(parsed_content)
394
739
  if config.on_event:
395
740
  config.on_event(OutputParseEvent(data=OutputParseEventData(
396
- content=assistant_message.content,
741
+ content=get_text_content(assistant_message.content),
397
742
  status='end',
398
743
  parsed_output=output_data
399
744
  )))
@@ -420,26 +765,26 @@ async def _run_internal(
420
765
  error_message=result.error_message
421
766
  )))
422
767
  return RunResult(
423
- final_state=replace(state, messages=new_messages),
768
+ final_state=replace(state, messages=new_messages, approvals=state.approvals),
424
769
  outcome=ErrorOutcome(error=OutputGuardrailTripwire(
425
770
  reason=result.error_message or "Output guardrail failed"
426
771
  ))
427
772
  )
428
773
 
429
774
  return RunResult(
430
- final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1),
775
+ final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
431
776
  outcome=CompletedOutcome(output=output_data)
432
777
  )
433
778
 
434
779
  except ValidationError as e:
435
780
  if config.on_event:
436
781
  config.on_event(OutputParseEvent(data=OutputParseEventData(
437
- content=assistant_message.content,
782
+ content=get_text_content(assistant_message.content),
438
783
  status='fail',
439
784
  error=str(e)
440
785
  )))
441
786
  return RunResult(
442
- final_state=replace(state, messages=new_messages),
787
+ final_state=replace(state, messages=new_messages, approvals=state.approvals),
443
788
  outcome=ErrorOutcome(error=DecodeError(
444
789
  errors=[{'message': str(e), 'details': e.errors()}]
445
790
  ))
@@ -451,36 +796,36 @@ async def _run_internal(
451
796
  if config.on_event:
452
797
  config.on_event(GuardrailEvent(data=GuardrailEventData(
453
798
  guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
454
- content=assistant_message.content
799
+ content=get_text_content(assistant_message.content)
455
800
  )))
456
801
  if asyncio.iscoroutinefunction(guardrail):
457
- result = await guardrail(assistant_message.content)
802
+ result = await guardrail(get_text_content(assistant_message.content))
458
803
  else:
459
- result = guardrail(assistant_message.content)
804
+ result = guardrail(get_text_content(assistant_message.content))
460
805
 
461
806
  if not result.is_valid:
462
807
  if config.on_event:
463
808
  config.on_event(GuardrailEvent(data=GuardrailEventData(
464
809
  guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
465
- content=assistant_message.content,
810
+ content=get_text_content(assistant_message.content),
466
811
  is_valid=False,
467
812
  error_message=result.error_message
468
813
  )))
469
814
  return RunResult(
470
- final_state=replace(state, messages=new_messages),
815
+ final_state=replace(state, messages=new_messages, approvals=state.approvals),
471
816
  outcome=ErrorOutcome(error=OutputGuardrailTripwire(
472
817
  reason=result.error_message or "Output guardrail failed"
473
818
  ))
474
819
  )
475
820
 
476
821
  return RunResult(
477
- final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1),
478
- outcome=CompletedOutcome(output=assistant_message.content)
822
+ final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
823
+ outcome=CompletedOutcome(output=get_text_content(assistant_message.content))
479
824
  )
480
825
 
481
826
  # Model produced neither content nor tool calls
482
827
  return RunResult(
483
- final_state=replace(state, messages=new_messages),
828
+ final_state=replace(state, messages=new_messages, approvals=state.approvals),
484
829
  outcome=ErrorOutcome(error=ModelBehaviorError(
485
830
  detail='Model produced neither content nor tool calls'
486
831
  ))
@@ -512,6 +857,7 @@ async def _execute_tool_calls(
512
857
  """Execute tool calls and return results."""
513
858
 
514
859
  async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]:
860
+ print(f'[JAF:TOOL-EXEC] Starting execute_single_tool_call for {tool_call.function.name}')
515
861
  if config.on_event:
516
862
  config.on_event(ToolCallStartEvent(data=to_event_data(ToolCallStartEventData(
517
863
  tool_name=tool_call.function.name,
@@ -531,7 +877,7 @@ async def _execute_tool_calls(
531
877
 
532
878
  if not tool:
533
879
  error_result = json.dumps({
534
- 'error': 'tool_not_found',
880
+ 'status': 'tool_not_found',
535
881
  'message': f'Tool {tool_call.function.name} not found',
536
882
  'tool_name': tool_call.function.name,
537
883
  })
@@ -564,7 +910,7 @@ async def _execute_tool_calls(
564
910
  validated_args = raw_args
565
911
  except ValidationError as e:
566
912
  error_result = json.dumps({
567
- 'error': 'validation_error',
913
+ 'status': 'validation_error',
568
914
  'message': f'Invalid arguments for {tool_call.function.name}: {e!s}',
569
915
  'tool_name': tool_call.function.name,
570
916
  'validation_errors': e.errors()
@@ -588,16 +934,116 @@ async def _execute_tool_calls(
588
934
  )
589
935
  }
590
936
 
937
+ # Check if tool needs approval
938
+ needs_approval = False
939
+ approval_func = getattr(tool, 'needs_approval', False)
940
+ if callable(approval_func):
941
+ needs_approval = await approval_func(state.context, validated_args)
942
+ else:
943
+ needs_approval = bool(approval_func)
944
+
945
+ # Check approval status - first by ID, then by signature for cross-session matching
946
+ approval_status = state.approvals.get(tool_call.id)
947
+ if not approval_status:
948
+ signature = f"{tool_call.function.name}:{tool_call.function.arguments}"
949
+ for _, approval in state.approvals.items():
950
+ if approval.additional_context and approval.additional_context.get('signature') == signature:
951
+ approval_status = approval
952
+ break
953
+
954
+ derived_status = None
955
+ if approval_status:
956
+ # Use explicit status if available
957
+ if approval_status.status:
958
+ derived_status = approval_status.status
959
+ # Fall back to approved boolean if status not set
960
+ elif approval_status.approved is True:
961
+ derived_status = 'approved'
962
+ elif approval_status.approved is False:
963
+ if approval_status.additional_context and approval_status.additional_context.get('status') == 'pending':
964
+ derived_status = 'pending'
965
+ else:
966
+ derived_status = 'rejected'
967
+
968
+ is_pending = derived_status == 'pending'
969
+
970
+ # If approval needed and not yet decided, create interruption
971
+ if needs_approval and (approval_status is None or is_pending):
972
+ interruption = ToolApprovalInterruption(
973
+ type='tool_approval',
974
+ tool_call=tool_call,
975
+ agent=agent,
976
+ session_id=str(state.run_id)
977
+ )
978
+
979
+ # Return interrupted result with halted message
980
+ halted_result = json.dumps({
981
+ 'status': 'halted',
982
+ 'message': f'Tool {tool_call.function.name} requires approval.',
983
+ })
984
+
985
+ return {
986
+ 'message': Message(
987
+ role=ContentRole.TOOL,
988
+ content=halted_result,
989
+ tool_call_id=tool_call.id
990
+ ),
991
+ 'interruption': interruption
992
+ }
993
+
994
+ # If approval was explicitly rejected, return rejection message
995
+ if derived_status == 'rejected':
996
+ rejection_reason = approval_status.additional_context.get('rejection_reason', 'User declined the action') if approval_status.additional_context else 'User declined the action'
997
+ rejection_result = json.dumps({
998
+ 'status': 'approval_denied',
999
+ 'message': f'Action was not approved. {rejection_reason}. Please ask if you can help with something else or suggest an alternative approach.',
1000
+ 'tool_name': tool_call.function.name,
1001
+ 'rejection_reason': rejection_reason,
1002
+ 'additional_context': approval_status.additional_context if approval_status else None
1003
+ })
1004
+
1005
+ return {
1006
+ 'message': Message(
1007
+ role=ContentRole.TOOL,
1008
+ content=rejection_result,
1009
+ tool_call_id=tool_call.id
1010
+ )
1011
+ }
1012
+
591
1013
  # Determine timeout for this tool
592
1014
  # Priority: tool-specific timeout > RunConfig default > 30 seconds global default
593
- timeout = getattr(tool.schema, 'timeout', None)
1015
+ if tool and hasattr(tool, 'schema'):
1016
+ timeout = getattr(tool.schema, 'timeout', None)
1017
+ else:
1018
+ timeout = None
594
1019
  if timeout is None:
595
1020
  timeout = config.default_tool_timeout if config.default_tool_timeout is not None else 30.0
596
1021
 
1022
+ # Merge additional context if provided through approval
1023
+ additional_context = approval_status.additional_context if approval_status else None
1024
+ context_with_additional = state.context
1025
+ if additional_context:
1026
+ # Create a copy of context with additional fields from approval
1027
+ if hasattr(state.context, '__dict__'):
1028
+ # For dataclass contexts, add additional context as attributes
1029
+ context_dict = {**state.context.__dict__, **additional_context}
1030
+ context_with_additional = type(state.context)(**{k: v for k, v in context_dict.items() if k in state.context.__dict__})
1031
+ # Add any extra fields as attributes
1032
+ for key, value in additional_context.items():
1033
+ if not hasattr(context_with_additional, key):
1034
+ setattr(context_with_additional, key, value)
1035
+ else:
1036
+ # For dict contexts, merge normally
1037
+ context_with_additional = {**state.context, **additional_context}
1038
+
1039
+ print(f'[JAF:ENGINE] About to execute tool: {tool_call.function.name}')
1040
+ print(f'[JAF:ENGINE] Tool args:', validated_args)
1041
+ print(f'[JAF:ENGINE] Tool context:', state.context)
1042
+
597
1043
  # Execute the tool with timeout
598
1044
  try:
599
1045
  tool_result = await asyncio.wait_for(
600
- tool.execute(validated_args, state.context),
1046
+ tool.execute(validated_args, context_with_additional),
601
1047
  timeout=timeout
602
1048
  )
603
1049
  except asyncio.TimeoutError:
@@ -629,14 +1075,41 @@ async def _execute_tool_calls(
629
1075
  # Handle both string and ToolResult formats
630
1076
  if isinstance(tool_result, str):
631
1077
  result_string = tool_result
1078
+ print(f'[JAF:ENGINE] Tool {tool_call.function.name} returned string:', result_string)
632
1079
  else:
633
1080
  # It's a ToolResult object
634
1081
  result_string = tool_result_to_string(tool_result)
1082
+ print(f'[JAF:ENGINE] Tool {tool_call.function.name} returned ToolResult:', tool_result)
1083
+ print(f'[JAF:ENGINE] Converted to string:', result_string)
1084
+
1085
+ # Wrap tool result with status information for approval context
1086
+ if approval_status and approval_status.additional_context:
1087
+ final_content = json.dumps({
1088
+ 'status': 'approved_and_executed',
1089
+ 'result': result_string,
1090
+ 'tool_name': tool_call.function.name,
1091
+ 'approval_context': approval_status.additional_context,
1092
+ 'message': 'Tool was approved and executed successfully with additional context.'
1093
+ })
1094
+ elif needs_approval:
1095
+ final_content = json.dumps({
1096
+ 'status': 'approved_and_executed',
1097
+ 'result': result_string,
1098
+ 'tool_name': tool_call.function.name,
1099
+ 'message': 'Tool was approved and executed successfully.'
1100
+ })
1101
+ else:
1102
+ final_content = json.dumps({
1103
+ 'status': 'executed',
1104
+ 'result': result_string,
1105
+ 'tool_name': tool_call.function.name,
1106
+ 'message': 'Tool executed successfully.'
1107
+ })
635
1108
 
636
1109
  if config.on_event:
637
1110
  config.on_event(ToolCallEndEvent(data=to_event_data(ToolCallEndEventData(
638
1111
  tool_name=tool_call.function.name,
639
- result=result_string,
1112
+ result=final_content,
640
1113
  trace_id=state.trace_id,
641
1114
  run_id=state.run_id,
642
1115
  tool_result=tool_result,
@@ -649,7 +1122,7 @@ async def _execute_tool_calls(
649
1122
  return {
650
1123
  'message': Message(
651
1124
  role=ContentRole.TOOL,
652
- content=result_string,
1125
+ content=final_content,
653
1126
  tool_call_id=tool_call.id
654
1127
  ),
655
1128
  'is_handoff': True,
@@ -659,14 +1132,14 @@ async def _execute_tool_calls(
659
1132
  return {
660
1133
  'message': Message(
661
1134
  role=ContentRole.TOOL,
662
- content=result_string,
1135
+ content=final_content,
663
1136
  tool_call_id=tool_call.id
664
1137
  )
665
1138
  }
666
1139
 
667
1140
  except Exception as error:
668
1141
  error_result = json.dumps({
669
- 'error': 'execution_error',
1142
+ 'status': 'execution_error',
670
1143
  'message': str(error),
671
1144
  'tool_name': tool_call.function.name,
672
1145
  })