jaf-py 2.5.9__py3-none-any.whl → 2.5.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. jaf/__init__.py +154 -57
  2. jaf/a2a/__init__.py +42 -21
  3. jaf/a2a/agent.py +79 -126
  4. jaf/a2a/agent_card.py +87 -78
  5. jaf/a2a/client.py +30 -66
  6. jaf/a2a/examples/client_example.py +12 -12
  7. jaf/a2a/examples/integration_example.py +38 -47
  8. jaf/a2a/examples/server_example.py +56 -53
  9. jaf/a2a/memory/__init__.py +0 -4
  10. jaf/a2a/memory/cleanup.py +28 -21
  11. jaf/a2a/memory/factory.py +155 -133
  12. jaf/a2a/memory/providers/composite.py +21 -26
  13. jaf/a2a/memory/providers/in_memory.py +89 -83
  14. jaf/a2a/memory/providers/postgres.py +117 -115
  15. jaf/a2a/memory/providers/redis.py +128 -121
  16. jaf/a2a/memory/serialization.py +77 -87
  17. jaf/a2a/memory/tests/run_comprehensive_tests.py +112 -83
  18. jaf/a2a/memory/tests/test_cleanup.py +211 -94
  19. jaf/a2a/memory/tests/test_serialization.py +73 -68
  20. jaf/a2a/memory/tests/test_stress_concurrency.py +186 -133
  21. jaf/a2a/memory/tests/test_task_lifecycle.py +138 -120
  22. jaf/a2a/memory/types.py +91 -53
  23. jaf/a2a/protocol.py +95 -125
  24. jaf/a2a/server.py +90 -118
  25. jaf/a2a/standalone_client.py +30 -43
  26. jaf/a2a/tests/__init__.py +16 -33
  27. jaf/a2a/tests/run_tests.py +17 -53
  28. jaf/a2a/tests/test_agent.py +40 -140
  29. jaf/a2a/tests/test_client.py +54 -117
  30. jaf/a2a/tests/test_integration.py +28 -82
  31. jaf/a2a/tests/test_protocol.py +54 -139
  32. jaf/a2a/tests/test_types.py +50 -136
  33. jaf/a2a/types.py +58 -34
  34. jaf/cli.py +21 -41
  35. jaf/core/__init__.py +7 -1
  36. jaf/core/agent_tool.py +93 -72
  37. jaf/core/analytics.py +257 -207
  38. jaf/core/checkpoint.py +223 -0
  39. jaf/core/composition.py +249 -235
  40. jaf/core/engine.py +817 -519
  41. jaf/core/errors.py +55 -42
  42. jaf/core/guardrails.py +276 -202
  43. jaf/core/handoff.py +47 -31
  44. jaf/core/parallel_agents.py +69 -75
  45. jaf/core/performance.py +75 -73
  46. jaf/core/proxy.py +43 -44
  47. jaf/core/proxy_helpers.py +24 -27
  48. jaf/core/regeneration.py +220 -129
  49. jaf/core/state.py +68 -66
  50. jaf/core/streaming.py +115 -108
  51. jaf/core/tool_results.py +111 -101
  52. jaf/core/tools.py +114 -116
  53. jaf/core/tracing.py +269 -210
  54. jaf/core/types.py +371 -151
  55. jaf/core/workflows.py +209 -168
  56. jaf/exceptions.py +46 -38
  57. jaf/memory/__init__.py +1 -6
  58. jaf/memory/approval_storage.py +54 -77
  59. jaf/memory/factory.py +4 -4
  60. jaf/memory/providers/in_memory.py +216 -180
  61. jaf/memory/providers/postgres.py +216 -146
  62. jaf/memory/providers/redis.py +173 -116
  63. jaf/memory/types.py +70 -51
  64. jaf/memory/utils.py +36 -34
  65. jaf/plugins/__init__.py +12 -12
  66. jaf/plugins/base.py +105 -96
  67. jaf/policies/__init__.py +0 -1
  68. jaf/policies/handoff.py +37 -46
  69. jaf/policies/validation.py +76 -52
  70. jaf/providers/__init__.py +6 -3
  71. jaf/providers/mcp.py +97 -51
  72. jaf/providers/model.py +361 -280
  73. jaf/server/__init__.py +1 -1
  74. jaf/server/main.py +7 -11
  75. jaf/server/server.py +514 -359
  76. jaf/server/types.py +208 -52
  77. jaf/utils/__init__.py +17 -18
  78. jaf/utils/attachments.py +111 -116
  79. jaf/utils/document_processor.py +175 -174
  80. jaf/visualization/__init__.py +1 -1
  81. jaf/visualization/example.py +111 -110
  82. jaf/visualization/functional_core.py +46 -71
  83. jaf/visualization/graphviz.py +154 -189
  84. jaf/visualization/imperative_shell.py +7 -16
  85. jaf/visualization/types.py +8 -4
  86. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/METADATA +2 -2
  87. jaf_py-2.5.11.dist-info/RECORD +97 -0
  88. jaf_py-2.5.9.dist-info/RECORD +0 -96
  89. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/WHEEL +0 -0
  90. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/entry_points.txt +0 -0
  91. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/licenses/LICENSE +0 -0
  92. {jaf_py-2.5.9.dist-info → jaf_py-2.5.11.dist-info}/top_level.txt +0 -0
jaf/core/engine.py CHANGED
@@ -78,12 +78,12 @@ from .guardrails import (
78
78
  def to_event_data(value: Any) -> Any:
79
79
  """
80
80
  Resilient serializer helper for event payloads.
81
-
81
+
82
82
  Converts various types to event-compatible data:
83
83
  - dataclasses: uses asdict()
84
84
  - Pydantic BaseModel: uses model_dump()
85
85
  - other types: returns as-is
86
-
86
+
87
87
  This prevents TypeError when serializing nested Pydantic models or non-dataclass types.
88
88
  """
89
89
  if is_dataclass(value):
@@ -94,12 +94,12 @@ def to_event_data(value: Any) -> Any:
94
94
  return value
95
95
 
96
96
 
97
- Ctx = TypeVar('Ctx')
98
- Out = TypeVar('Out')
97
+ Ctx = TypeVar("Ctx")
98
+ Out = TypeVar("Out")
99
+
99
100
 
100
101
  async def try_resume_pending_tool_calls(
101
- state: RunState[Ctx],
102
- config: RunConfig[Ctx]
102
+ state: RunState[Ctx], config: RunConfig[Ctx]
103
103
  ) -> Optional[RunResult[Out]]:
104
104
  """
105
105
  Try to resume pending tool calls if the last assistant message contained tool_calls
@@ -110,157 +110,192 @@ async def try_resume_pending_tool_calls(
110
110
  for i in range(len(messages) - 1, -1, -1):
111
111
  msg = messages[i]
112
112
  # Handle both string and enum roles
113
- role_str = msg.role.value if hasattr(msg.role, 'value') else str(msg.role)
114
- if role_str == 'assistant' and msg.tool_calls:
113
+ role_str = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
114
+ if role_str == "assistant" and msg.tool_calls:
115
115
  tool_call_ids = {tc.id for tc in msg.tool_calls}
116
-
116
+
117
117
  # Scan forward for tool results tied to these ids
118
118
  executed_ids = set()
119
119
  for j in range(i + 1, len(messages)):
120
120
  m = messages[j]
121
121
  # Handle both string and enum roles
122
- m_role_str = m.role.value if hasattr(m.role, 'value') else str(m.role)
123
- if m_role_str == 'tool' and m.tool_call_id and m.tool_call_id in tool_call_ids:
122
+ m_role_str = m.role.value if hasattr(m.role, "value") else str(m.role)
123
+ if m_role_str == "tool" and m.tool_call_id and m.tool_call_id in tool_call_ids:
124
124
  executed_ids.add(m.tool_call_id)
125
-
125
+
126
126
  pending_tool_calls = [tc for tc in msg.tool_calls if tc.id not in executed_ids]
127
-
127
+
128
128
  if not pending_tool_calls:
129
129
  continue # Continue checking other assistant messages
130
-
130
+
131
131
  current_agent = config.agent_registry.get(state.current_agent_name)
132
132
  if not current_agent:
133
133
  return RunResult(
134
134
  final_state=state,
135
- outcome=ErrorOutcome(error=AgentNotFound(agent_name=state.current_agent_name))
135
+ outcome=ErrorOutcome(
136
+ error=AgentNotFound(agent_name=state.current_agent_name)
137
+ ),
136
138
  )
137
-
139
+
138
140
  # Execute pending tool calls
139
141
  tool_results = await _execute_tool_calls(
140
- pending_tool_calls,
141
- current_agent,
142
- state,
143
- config
142
+ pending_tool_calls, current_agent, state, config
144
143
  )
145
-
144
+
146
145
  # Check for interruptions
147
- interruptions = [r.get('interruption') for r in tool_results if r.get('interruption')]
146
+ interruptions = [
147
+ r.get("interruption") for r in tool_results if r.get("interruption")
148
+ ]
148
149
  if interruptions:
149
- completed_results = [r for r in tool_results if not r.get('interruption')]
150
+ completed_results = [r for r in tool_results if not r.get("interruption")]
150
151
  interrupted_state = replace(
151
152
  state,
152
- messages=list(state.messages) + [r['message'] for r in completed_results],
153
+ messages=list(state.messages) + [r["message"] for r in completed_results],
153
154
  turn_count=state.turn_count,
154
- approvals=state.approvals
155
+ approvals=state.approvals,
155
156
  )
156
157
  return RunResult(
157
158
  final_state=interrupted_state,
158
- outcome=InterruptedOutcome(interruptions=interruptions)
159
+ outcome=InterruptedOutcome(interruptions=interruptions),
159
160
  )
160
-
161
+
161
162
  # Continue with normal execution
162
163
  next_state = replace(
163
164
  state,
164
- messages=list(state.messages) + [r['message'] for r in tool_results],
165
+ messages=list(state.messages) + [r["message"] for r in tool_results],
165
166
  turn_count=state.turn_count,
166
- approvals=state.approvals
167
+ approvals=state.approvals,
167
168
  )
168
169
  return await _run_internal(next_state, config)
169
-
170
+
170
171
  except Exception as e:
171
172
  # Best-effort resume; ignore and continue normal flow
172
173
  pass
173
-
174
+
174
175
  return None
175
176
 
176
- async def run(
177
- initial_state: RunState[Ctx],
178
- config: RunConfig[Ctx]
179
- ) -> RunResult[Out]:
177
+
178
+ async def run(initial_state: RunState[Ctx], config: RunConfig[Ctx]) -> RunResult[Out]:
180
179
  """
181
180
  Main execution function for running agents.
182
181
  """
183
182
  try:
184
183
  # Set the current RunConfig in context for agent tools
185
184
  from .agent_tool import set_current_run_config
185
+
186
186
  set_current_run_config(config)
187
-
187
+
188
188
  state_with_memory = await _load_conversation_history(initial_state, config)
189
-
189
+
190
190
  # Emit RunStartEvent AFTER loading conversation history so we have complete context
191
191
  if config.on_event:
192
- config.on_event(RunStartEvent(data=to_event_data(RunStartEventData(
193
- run_id=initial_state.run_id,
194
- trace_id=initial_state.trace_id,
195
- session_id=config.conversation_id,
196
- context=state_with_memory.context,
197
- messages=state_with_memory.messages, # Now includes full conversation history
198
- agent_name=state_with_memory.current_agent_name
199
- ))))
200
-
192
+ config.on_event(
193
+ RunStartEvent(
194
+ data=to_event_data(
195
+ RunStartEventData(
196
+ run_id=initial_state.run_id,
197
+ trace_id=initial_state.trace_id,
198
+ session_id=config.conversation_id,
199
+ context=state_with_memory.context,
200
+ messages=state_with_memory.messages, # Now includes full conversation history
201
+ agent_name=state_with_memory.current_agent_name,
202
+ )
203
+ )
204
+ )
205
+ )
206
+
201
207
  # Load approvals from storage if configured
202
208
  if config.approval_storage:
203
- print(f'[JAF:ENGINE] Loading approvals for runId {state_with_memory.run_id}')
209
+ print(f"[JAF:ENGINE] Loading approvals for runId {state_with_memory.run_id}")
204
210
  from .state import load_approvals_into_state
211
+
205
212
  state_with_memory = await load_approvals_into_state(state_with_memory, config)
206
-
213
+
207
214
  result = await _run_internal(state_with_memory, config)
208
215
 
209
216
  # Store conversation history only if this is a final completion of the entire conversation
210
217
  # For HITL scenarios, storage happens on interruption to allow resumption
211
218
  # We only store on completion if explicitly indicated this is the end of the conversation
212
- if (config.memory and config.memory.auto_store and config.conversation_id and
213
- result.outcome.status == 'completed' and getattr(config.memory, 'store_on_completion', True)):
214
- print(f'[JAF:ENGINE] Storing final completed conversation for {config.conversation_id}')
219
+ if (
220
+ config.memory
221
+ and config.memory.auto_store
222
+ and config.conversation_id
223
+ and result.outcome.status == "completed"
224
+ and getattr(config.memory, "store_on_completion", True)
225
+ ):
226
+ print(f"[JAF:ENGINE] Storing final completed conversation for {config.conversation_id}")
215
227
  await _store_conversation_history(result.final_state, config)
216
- elif result.outcome.status == 'interrupted':
217
- print('[JAF:ENGINE] Conversation interrupted - storage already handled during interruption')
228
+ elif result.outcome.status == "interrupted":
229
+ print(
230
+ "[JAF:ENGINE] Conversation interrupted - storage already handled during interruption"
231
+ )
218
232
  else:
219
- print(f'[JAF:ENGINE] Skipping memory store - status: {result.outcome.status}, store_on_completion: {getattr(config.memory, "store_on_completion", True) if config.memory else "N/A"}')
233
+ print(
234
+ f"[JAF:ENGINE] Skipping memory store - status: {result.outcome.status}, store_on_completion: {getattr(config.memory, 'store_on_completion', True) if config.memory else 'N/A'}"
235
+ )
220
236
 
221
237
  if config.on_event:
222
- config.on_event(RunEndEvent(data=to_event_data(RunEndEventData(
223
- outcome=result.outcome,
224
- trace_id=initial_state.trace_id,
225
- run_id=initial_state.run_id
226
- ))))
238
+ config.on_event(
239
+ RunEndEvent(
240
+ data=to_event_data(
241
+ RunEndEventData(
242
+ outcome=result.outcome,
243
+ trace_id=initial_state.trace_id,
244
+ run_id=initial_state.run_id,
245
+ )
246
+ )
247
+ )
248
+ )
227
249
 
228
250
  return result
229
251
  except Exception as error:
230
252
  error_result = RunResult(
231
253
  final_state=initial_state,
232
- outcome=ErrorOutcome(error=ModelBehaviorError(detail=str(error)))
254
+ outcome=ErrorOutcome(error=ModelBehaviorError(detail=str(error))),
233
255
  )
234
256
  if config.on_event:
235
- config.on_event(RunEndEvent(data=to_event_data(RunEndEventData(
236
- outcome=error_result.outcome,
237
- trace_id=initial_state.trace_id,
238
- run_id=initial_state.run_id
239
- ))))
257
+ config.on_event(
258
+ RunEndEvent(
259
+ data=to_event_data(
260
+ RunEndEventData(
261
+ outcome=error_result.outcome,
262
+ trace_id=initial_state.trace_id,
263
+ run_id=initial_state.run_id,
264
+ )
265
+ )
266
+ )
267
+ )
240
268
  return error_result
241
269
 
270
+
242
271
  async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx]) -> RunState[Ctx]:
243
272
  """Load conversation history from memory provider."""
244
273
  if not (config.memory and config.memory.provider and config.conversation_id):
245
274
  return state
246
275
 
247
276
  if config.on_event:
248
- config.on_event(MemoryEvent(data=MemoryEventData(
249
- operation='load',
250
- conversation_id=config.conversation_id,
251
- status='start'
252
- )))
277
+ config.on_event(
278
+ MemoryEvent(
279
+ data=MemoryEventData(
280
+ operation="load", conversation_id=config.conversation_id, status="start"
281
+ )
282
+ )
283
+ )
253
284
 
254
285
  result = await config.memory.provider.get_conversation(config.conversation_id)
255
286
  if isinstance(result, Failure):
256
287
  print(f"[JAF:ENGINE] Warning: Failed to load conversation: {result.error}")
257
288
  if config.on_event:
258
- config.on_event(MemoryEvent(data=MemoryEventData(
259
- operation='load',
260
- conversation_id=config.conversation_id,
261
- status='fail',
262
- error=str(result.error)
263
- )))
289
+ config.on_event(
290
+ MemoryEvent(
291
+ data=MemoryEventData(
292
+ operation="load",
293
+ conversation_id=config.conversation_id,
294
+ status="fail",
295
+ error=str(result.error),
296
+ )
297
+ )
298
+ )
264
299
  return state
265
300
 
266
301
  conversation_data = result.data
@@ -271,17 +306,17 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
271
306
  # Filter out halted messages - they're for audit/database only, not for LLM context
272
307
  memory_messages = []
273
308
  filtered_count = 0
274
-
309
+
275
310
  for msg in all_memory_messages:
276
- if msg.role not in (ContentRole.TOOL, 'tool'):
311
+ if msg.role not in (ContentRole.TOOL, "tool"):
277
312
  memory_messages.append(msg)
278
313
  else:
279
314
  try:
280
315
  content = json.loads(msg.content)
281
- status = content.get('status')
282
- hitl_status = content.get('hitl_status')
316
+ status = content.get("status")
317
+ hitl_status = content.get("hitl_status")
283
318
  # Filter out ALL halted/pending approval messages (they're for audit only)
284
- if status == 'halted' or hitl_status == 'pending_approval':
319
+ if status == "halted" or hitl_status == "pending_approval":
285
320
  filtered_count += 1
286
321
  continue # Skip this halted message
287
322
  else:
@@ -301,8 +336,12 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
301
336
  approvals_map = state.approvals
302
337
 
303
338
  # Calculate turn count efficiently
304
- memory_assistant_count = sum(1 for msg in memory_messages if msg.role in (ContentRole.ASSISTANT, 'assistant'))
305
- current_assistant_count = sum(1 for msg in state.messages if msg.role in (ContentRole.ASSISTANT, 'assistant'))
339
+ memory_assistant_count = sum(
340
+ 1 for msg in memory_messages if msg.role in (ContentRole.ASSISTANT, "assistant")
341
+ )
342
+ current_assistant_count = sum(
343
+ 1 for msg in state.messages if msg.role in (ContentRole.ASSISTANT, "assistant")
344
+ )
306
345
  calculated_turn_count = memory_assistant_count + current_assistant_count
307
346
 
308
347
  # Use metadata turn_count if available, otherwise calculate from messages
@@ -312,40 +351,54 @@ async def _load_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx
312
351
  turn_count = max(metadata_turn_count, calculated_turn_count)
313
352
 
314
353
  if config.on_event:
315
- config.on_event(MemoryEvent(data=MemoryEventData(
316
- operation='load',
317
- conversation_id=config.conversation_id,
318
- status='end',
319
- message_count=len(memory_messages)
320
- )))
354
+ config.on_event(
355
+ MemoryEvent(
356
+ data=MemoryEventData(
357
+ operation="load",
358
+ conversation_id=config.conversation_id,
359
+ status="end",
360
+ message_count=len(memory_messages),
361
+ )
362
+ )
363
+ )
321
364
 
322
365
  if filtered_count > 0:
323
- print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory, filtered to {len(memory_messages)} for LLM context (removed {filtered_count} halted messages)')
366
+ print(
367
+ f"[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory, filtered to {len(memory_messages)} for LLM context (removed {filtered_count} halted messages)"
368
+ )
324
369
  else:
325
- print(f'[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory')
370
+ print(f"[JAF:MEMORY] Loaded {len(all_memory_messages)} messages from memory")
326
371
 
327
372
  return replace(
328
- state,
329
- messages=combined_messages,
330
- turn_count=turn_count,
331
- approvals=approvals_map
373
+ state, messages=combined_messages, turn_count=turn_count, approvals=approvals_map
332
374
  )
333
375
  return state
334
376
 
377
+
335
378
  async def _store_conversation_history(state: RunState[Ctx], config: RunConfig[Ctx]):
336
379
  """Store conversation history to memory provider."""
337
- if not (config.memory and config.memory.provider and config.conversation_id and config.memory.auto_store):
380
+ if not (
381
+ config.memory
382
+ and config.memory.provider
383
+ and config.conversation_id
384
+ and config.memory.auto_store
385
+ ):
338
386
  return
339
387
 
340
388
  if config.on_event:
341
- config.on_event(MemoryEvent(data=MemoryEventData(
342
- operation='store',
343
- conversation_id=config.conversation_id,
344
- status='start'
345
- )))
389
+ config.on_event(
390
+ MemoryEvent(
391
+ data=MemoryEventData(
392
+ operation="store", conversation_id=config.conversation_id, status="start"
393
+ )
394
+ )
395
+ )
346
396
 
347
397
  messages_to_store = list(state.messages)
348
- if config.memory.compression_threshold and len(messages_to_store) > config.memory.compression_threshold:
398
+ if (
399
+ config.memory.compression_threshold
400
+ and len(messages_to_store) > config.memory.compression_threshold
401
+ ):
349
402
  keep_first = int(config.memory.compression_threshold * 0.2)
350
403
  keep_recent = config.memory.compression_threshold - keep_first
351
404
  messages_to_store = messages_to_store[:keep_first] + messages_to_store[-keep_recent:]
@@ -355,60 +408,72 @@ async def _store_conversation_history(state: RunState[Ctx], config: RunConfig[Ct
355
408
  if state.approvals:
356
409
  approval_metadata = {
357
410
  "approval_count": len(state.approvals),
358
- "approved_tools": [tool_id for tool_id, approval in state.approvals.items() if approval.approved],
359
- "rejected_tools": [tool_id for tool_id, approval in state.approvals.items() if not approval.approved],
360
- "has_approvals": True
411
+ "approved_tools": [
412
+ tool_id for tool_id, approval in state.approvals.items() if approval.approved
413
+ ],
414
+ "rejected_tools": [
415
+ tool_id for tool_id, approval in state.approvals.items() if not approval.approved
416
+ ],
417
+ "has_approvals": True,
361
418
  }
362
-
419
+
363
420
  metadata = {
364
- "user_id": getattr(state.context, 'user_id', None),
421
+ "user_id": getattr(state.context, "user_id", None),
365
422
  "trace_id": str(state.trace_id),
366
423
  "run_id": str(state.run_id),
367
424
  "agent_name": state.current_agent_name,
368
425
  "turn_count": state.turn_count,
369
- **approval_metadata
426
+ **approval_metadata,
370
427
  }
371
428
 
372
- result = await config.memory.provider.store_messages(config.conversation_id, messages_to_store, metadata)
429
+ result = await config.memory.provider.store_messages(
430
+ config.conversation_id, messages_to_store, metadata
431
+ )
373
432
 
374
433
  if isinstance(result, Failure):
375
434
  print(f"[JAF:ENGINE] Warning: Failed to store conversation: {result.error}")
376
435
  if config.on_event:
377
- config.on_event(MemoryEvent(data=MemoryEventData(
378
- operation='store',
379
- conversation_id=config.conversation_id,
380
- status='fail',
381
- error=str(result.error)
382
- )))
436
+ config.on_event(
437
+ MemoryEvent(
438
+ data=MemoryEventData(
439
+ operation="store",
440
+ conversation_id=config.conversation_id,
441
+ status="fail",
442
+ error=str(result.error),
443
+ )
444
+ )
445
+ )
383
446
  else:
384
- print(f"[JAF:ENGINE] Stored {len(messages_to_store)} messages for conversation {config.conversation_id}")
447
+ print(
448
+ f"[JAF:ENGINE] Stored {len(messages_to_store)} messages for conversation {config.conversation_id}"
449
+ )
385
450
  if config.on_event:
386
- config.on_event(MemoryEvent(data=MemoryEventData(
387
- operation='store',
388
- conversation_id=config.conversation_id,
389
- status='end',
390
- message_count=len(messages_to_store)
391
- )))
451
+ config.on_event(
452
+ MemoryEvent(
453
+ data=MemoryEventData(
454
+ operation="store",
455
+ conversation_id=config.conversation_id,
456
+ status="end",
457
+ message_count=len(messages_to_store),
458
+ )
459
+ )
460
+ )
392
461
 
393
462
  # Removed verbose logging for performance
394
463
 
395
464
 
396
- async def _run_internal(
397
- state: RunState[Ctx],
398
- config: RunConfig[Ctx]
399
- ) -> RunResult[Out]:
465
+ async def _run_internal(state: RunState[Ctx], config: RunConfig[Ctx]) -> RunResult[Out]:
400
466
  """Internal run function with recursive execution logic."""
401
467
  # Try to resume pending tool calls first
402
468
  resumed = await try_resume_pending_tool_calls(state, config)
403
469
  if resumed:
404
470
  return resumed
405
-
471
+
406
472
  # Check max turns
407
473
  max_turns = config.max_turns or 50
408
474
  if state.turn_count >= max_turns:
409
475
  return RunResult(
410
- final_state=state,
411
- outcome=ErrorOutcome(error=MaxTurnsExceeded(turns=state.turn_count))
476
+ final_state=state, outcome=ErrorOutcome(error=MaxTurnsExceeded(turns=state.turn_count))
412
477
  )
413
478
 
414
479
  # Get current agent
@@ -416,30 +481,35 @@ async def _run_internal(
416
481
  if not current_agent:
417
482
  return RunResult(
418
483
  final_state=state,
419
- outcome=ErrorOutcome(error=AgentNotFound(agent_name=state.current_agent_name))
484
+ outcome=ErrorOutcome(error=AgentNotFound(agent_name=state.current_agent_name)),
420
485
  )
421
486
 
422
487
  # Determine if agent has advanced guardrails configuration
423
488
  has_advanced_guardrails = bool(
424
- current_agent.advanced_config and
425
- current_agent.advanced_config.guardrails and
426
- (current_agent.advanced_config.guardrails.input_prompt or
427
- current_agent.advanced_config.guardrails.output_prompt or
428
- current_agent.advanced_config.guardrails.require_citations)
489
+ current_agent.advanced_config
490
+ and current_agent.advanced_config.guardrails
491
+ and (
492
+ current_agent.advanced_config.guardrails.input_prompt
493
+ or current_agent.advanced_config.guardrails.output_prompt
494
+ or current_agent.advanced_config.guardrails.require_citations
495
+ )
496
+ )
497
+
498
+ print(
499
+ "[JAF:ENGINE] Debug guardrails setup:",
500
+ {
501
+ "agent_name": current_agent.name,
502
+ "has_advanced_config": bool(current_agent.advanced_config),
503
+ "has_advanced_guardrails": has_advanced_guardrails,
504
+ "initial_input_guardrails": len(config.initial_input_guardrails or []),
505
+ "final_output_guardrails": len(config.final_output_guardrails or []),
506
+ },
429
507
  )
430
-
431
- print('[JAF:ENGINE] Debug guardrails setup:', {
432
- 'agent_name': current_agent.name,
433
- 'has_advanced_config': bool(current_agent.advanced_config),
434
- 'has_advanced_guardrails': has_advanced_guardrails,
435
- 'initial_input_guardrails': len(config.initial_input_guardrails or []),
436
- 'final_output_guardrails': len(config.final_output_guardrails or [])
437
- })
438
508
 
439
509
  # Build effective guardrails
440
510
  effective_input_guardrails: List[Guardrail] = []
441
511
  effective_output_guardrails: List[Guardrail] = []
442
-
512
+
443
513
  if has_advanced_guardrails:
444
514
  result = await build_effective_guardrails(current_agent, config)
445
515
  effective_input_guardrails, effective_output_guardrails = result
@@ -448,35 +518,48 @@ async def _run_internal(
448
518
  effective_output_guardrails = list(config.final_output_guardrails or [])
449
519
 
450
520
  # Execute input guardrails on first turn
451
- input_guardrails_to_run = (effective_input_guardrails
452
- if state.turn_count == 0 and effective_input_guardrails
453
- else [])
454
-
455
- print('[JAF:ENGINE] Input guardrails to run:', {
456
- 'turn_count': state.turn_count,
457
- 'effective_input_length': len(effective_input_guardrails),
458
- 'input_guardrails_to_run_length': len(input_guardrails_to_run),
459
- 'has_advanced_guardrails': has_advanced_guardrails
460
- })
521
+ input_guardrails_to_run = (
522
+ effective_input_guardrails if state.turn_count == 0 and effective_input_guardrails else []
523
+ )
524
+
525
+ print(
526
+ "[JAF:ENGINE] Input guardrails to run:",
527
+ {
528
+ "turn_count": state.turn_count,
529
+ "effective_input_length": len(effective_input_guardrails),
530
+ "input_guardrails_to_run_length": len(input_guardrails_to_run),
531
+ "has_advanced_guardrails": has_advanced_guardrails,
532
+ },
533
+ )
461
534
 
462
535
  if input_guardrails_to_run and state.turn_count == 0:
463
- first_user_message = next((m for m in state.messages if m.role == ContentRole.USER or m.role == 'user'), None)
536
+ first_user_message = next(
537
+ (m for m in state.messages if m.role == ContentRole.USER or m.role == "user"), None
538
+ )
464
539
  if first_user_message:
465
540
  if has_advanced_guardrails:
466
- execution_mode = (current_agent.advanced_config.guardrails.execution_mode
467
- if current_agent.advanced_config and current_agent.advanced_config.guardrails
468
- else 'parallel')
469
-
470
- if execution_mode == 'sequential':
541
+ execution_mode = (
542
+ current_agent.advanced_config.guardrails.execution_mode
543
+ if current_agent.advanced_config and current_agent.advanced_config.guardrails
544
+ else "parallel"
545
+ )
546
+
547
+ if execution_mode == "sequential":
471
548
  guardrail_result = await execute_input_guardrails_sequential(
472
549
  input_guardrails_to_run, first_user_message, config
473
550
  )
474
551
  if not guardrail_result.is_valid:
475
552
  return RunResult(
476
553
  final_state=state,
477
- outcome=ErrorOutcome(error=InputGuardrailTripwire(
478
- reason=getattr(guardrail_result, 'error_message', 'Input guardrail violation')
479
- ))
554
+ outcome=ErrorOutcome(
555
+ error=InputGuardrailTripwire(
556
+ reason=getattr(
557
+ guardrail_result,
558
+ "error_message",
559
+ "Input guardrail violation",
560
+ )
561
+ )
562
+ ),
480
563
  )
481
564
  else:
482
565
  # Parallel execution with LLM call overlap
@@ -484,22 +567,40 @@ async def _run_internal(
484
567
  input_guardrails_to_run, first_user_message, config
485
568
  )
486
569
  if not guardrail_result.is_valid:
487
- print(f"🚨 Input guardrail violation: {getattr(guardrail_result, 'error_message', 'Unknown violation')}")
570
+ print(
571
+ f"🚨 Input guardrail violation: {getattr(guardrail_result, 'error_message', 'Unknown violation')}"
572
+ )
488
573
  return RunResult(
489
574
  final_state=state,
490
- outcome=ErrorOutcome(error=InputGuardrailTripwire(
491
- reason=getattr(guardrail_result, 'error_message', 'Input guardrail violation')
492
- ))
575
+ outcome=ErrorOutcome(
576
+ error=InputGuardrailTripwire(
577
+ reason=getattr(
578
+ guardrail_result,
579
+ "error_message",
580
+ "Input guardrail violation",
581
+ )
582
+ )
583
+ ),
493
584
  )
494
585
  else:
495
586
  # Legacy guardrails path
496
- print('[JAF:ENGINE] Using LEGACY guardrails path with', len(input_guardrails_to_run), 'guardrails')
587
+ print(
588
+ "[JAF:ENGINE] Using LEGACY guardrails path with",
589
+ len(input_guardrails_to_run),
590
+ "guardrails",
591
+ )
497
592
  for guardrail in input_guardrails_to_run:
498
593
  if config.on_event:
499
- config.on_event(GuardrailEvent(data=GuardrailEventData(
500
- guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
501
- content=get_text_content(first_user_message.content)
502
- )))
594
+ config.on_event(
595
+ GuardrailEvent(
596
+ data=GuardrailEventData(
597
+ guardrail_name=getattr(
598
+ guardrail, "__name__", "unknown_guardrail"
599
+ ),
600
+ content=get_text_content(first_user_message.content),
601
+ )
602
+ )
603
+ )
503
604
  if asyncio.iscoroutinefunction(guardrail):
504
605
  result = await guardrail(get_text_content(first_user_message.content))
505
606
  else:
@@ -507,24 +608,34 @@ async def _run_internal(
507
608
 
508
609
  if not result.is_valid:
509
610
  if config.on_event:
510
- config.on_event(GuardrailViolationEvent(data=GuardrailViolationEventData(
511
- stage='input',
512
- reason=getattr(result, 'error_message', 'Input guardrail failed')
513
- )))
611
+ config.on_event(
612
+ GuardrailViolationEvent(
613
+ data=GuardrailViolationEventData(
614
+ stage="input",
615
+ reason=getattr(
616
+ result, "error_message", "Input guardrail failed"
617
+ ),
618
+ )
619
+ )
620
+ )
514
621
  return RunResult(
515
622
  final_state=state,
516
- outcome=ErrorOutcome(error=InputGuardrailTripwire(
517
- reason=getattr(result, 'error_message', 'Input guardrail failed')
518
- ))
623
+ outcome=ErrorOutcome(
624
+ error=InputGuardrailTripwire(
625
+ reason=getattr(
626
+ result, "error_message", "Input guardrail failed"
627
+ )
628
+ )
629
+ ),
519
630
  )
520
631
 
521
632
  # Agent debugging logs removed for performance
522
633
 
523
634
  # Get model name
524
635
  model = (
525
- config.model_override or
526
- (current_agent.model_config.name if current_agent.model_config else None) or
527
- "gpt-4o"
636
+ config.model_override
637
+ or (current_agent.model_config.name if current_agent.model_config else None)
638
+ or "gpt-4o"
528
639
  )
529
640
 
530
641
  # Apply before_llm_call callback if provided
@@ -540,24 +651,30 @@ async def _run_internal(
540
651
 
541
652
  # Emit LLM call start event
542
653
  if config.on_event:
543
- config.on_event(LLMCallStartEvent(data=to_event_data(LLMCallStartEventData(
544
- agent_name=current_agent.name,
545
- model=model,
546
- trace_id=state.trace_id,
547
- run_id=state.run_id,
548
- context=state.context,
549
- messages=state.messages
550
- ))))
654
+ config.on_event(
655
+ LLMCallStartEvent(
656
+ data=to_event_data(
657
+ LLMCallStartEventData(
658
+ agent_name=current_agent.name,
659
+ model=model,
660
+ trace_id=state.trace_id,
661
+ run_id=state.run_id,
662
+ context=state.context,
663
+ messages=state.messages,
664
+ )
665
+ )
666
+ )
667
+ )
551
668
 
552
669
  # Retry logic for empty LLM responses
553
670
  llm_response: Dict[str, Any]
554
671
  assistant_event_streamed = False
555
-
672
+
556
673
  for retry_attempt in range(config.max_empty_response_retries + 1):
557
674
  # Get completion from model provider
558
675
  # Check if streaming should be used based on configuration and availability
559
676
  get_stream = getattr(config.model_provider, "get_completion_stream", None)
560
- use_streaming = (config.prefer_streaming != False and callable(get_stream))
677
+ use_streaming = config.prefer_streaming != False and callable(get_stream)
561
678
 
562
679
  if use_streaming:
563
680
  try:
@@ -577,11 +694,13 @@ async def _run_internal(
577
694
  idx = getattr(tcd, "index", 0) or 0
578
695
  # Ensure slot exists
579
696
  while len(partial_tool_calls) <= idx:
580
- partial_tool_calls.append({
581
- "id": None,
582
- "type": "function",
583
- "function": {"name": None, "arguments": ""}
584
- })
697
+ partial_tool_calls.append(
698
+ {
699
+ "id": None,
700
+ "type": "function",
701
+ "function": {"name": None, "arguments": ""},
702
+ }
703
+ )
585
704
  target = partial_tool_calls[idx]
586
705
  # id
587
706
  tc_id = getattr(tcd, "id", None)
@@ -608,34 +727,45 @@ async def _run_internal(
608
727
  arguments = tc["function"]["arguments"]
609
728
  if isinstance(arguments, str):
610
729
  arguments = _normalize_tool_call_arguments(arguments)
611
- message_tool_calls.append({
612
- "id": tc["id"] or f"call_{i}",
613
- "type": "function",
614
- "function": {
615
- "name": tc["function"]["name"] or "",
616
- "arguments": arguments
730
+ message_tool_calls.append(
731
+ {
732
+ "id": tc["id"] or f"call_{i}",
733
+ "type": "function",
734
+ "function": {
735
+ "name": tc["function"]["name"] or "",
736
+ "arguments": arguments,
737
+ },
617
738
  }
618
- })
739
+ )
619
740
 
620
741
  partial_msg = Message(
621
742
  role=ContentRole.ASSISTANT,
622
743
  content=aggregated_text or "",
623
- tool_calls=None if not message_tool_calls else [
744
+ tool_calls=None
745
+ if not message_tool_calls
746
+ else [
624
747
  ToolCall(
625
748
  id=mc["id"],
626
749
  type="function",
627
750
  function=ToolCallFunction(
628
751
  name=mc["function"]["name"],
629
- arguments=_normalize_tool_call_arguments(mc["function"]["arguments"])
752
+ arguments=_normalize_tool_call_arguments(
753
+ mc["function"]["arguments"]
754
+ ),
630
755
  ),
631
- ) for mc in message_tool_calls
756
+ )
757
+ for mc in message_tool_calls
632
758
  ],
633
759
  )
634
760
  try:
635
761
  if config.on_event:
636
- config.on_event(AssistantMessageEvent(data=to_event_data(
637
- AssistantMessageEventData(message=partial_msg)
638
- )))
762
+ config.on_event(
763
+ AssistantMessageEvent(
764
+ data=to_event_data(
765
+ AssistantMessageEventData(message=partial_msg)
766
+ )
767
+ )
768
+ )
639
769
  except Exception as _e:
640
770
  # Do not fail the run on callback errors
641
771
  pass
@@ -648,50 +778,61 @@ async def _run_internal(
648
778
  arguments = tc["function"]["arguments"]
649
779
  if isinstance(arguments, str):
650
780
  arguments = _normalize_tool_call_arguments(arguments)
651
- final_tool_calls.append({
652
- "id": tc["id"] or f"call_{i}",
653
- "type": "function",
654
- "function": {
655
- "name": tc["function"]["name"] or "",
656
- "arguments": arguments
781
+ final_tool_calls.append(
782
+ {
783
+ "id": tc["id"] or f"call_{i}",
784
+ "type": "function",
785
+ "function": {
786
+ "name": tc["function"]["name"] or "",
787
+ "arguments": arguments,
788
+ },
657
789
  }
658
- })
790
+ )
659
791
 
660
792
  llm_response = {
661
- "message": {
662
- "content": aggregated_text or None,
663
- "tool_calls": final_tool_calls
664
- }
793
+ "message": {"content": aggregated_text or None, "tool_calls": final_tool_calls}
665
794
  }
666
795
  except Exception:
667
796
  # Fallback to non-streaming on error
668
797
  assistant_event_streamed = False
669
- llm_response = await config.model_provider.get_completion(state, current_agent, config)
798
+ llm_response = await config.model_provider.get_completion(
799
+ state, current_agent, config
800
+ )
670
801
  else:
671
802
  llm_response = await config.model_provider.get_completion(state, current_agent, config)
672
-
803
+
673
804
  # Check if response has meaningful content
674
- has_content = llm_response.get('message', {}).get('content')
675
- has_tool_calls = llm_response.get('message', {}).get('tool_calls')
676
-
805
+ has_content = llm_response.get("message", {}).get("content")
806
+ has_tool_calls = llm_response.get("message", {}).get("tool_calls")
807
+
677
808
  # If we got a valid response, break out of retry loop
678
809
  if has_content or has_tool_calls:
679
810
  break
680
-
811
+
681
812
  # If this is not the last attempt, retry with exponential backoff
682
813
  if retry_attempt < config.max_empty_response_retries:
683
- delay = config.empty_response_retry_delay * (2 ** retry_attempt)
814
+ delay = config.empty_response_retry_delay * (2**retry_attempt)
684
815
  if config.log_empty_responses:
685
- print(f"[JAF:ENGINE] Empty LLM response on attempt {retry_attempt + 1}/{config.max_empty_response_retries + 1}, retrying in {delay:.1f}s...")
686
- print(f"[JAF:ENGINE] Response had message: {bool(llm_response.get('message'))}, content: {bool(has_content)}, tool_calls: {bool(has_tool_calls)}")
816
+ print(
817
+ f"[JAF:ENGINE] Empty LLM response on attempt {retry_attempt + 1}/{config.max_empty_response_retries + 1}, retrying in {delay:.1f}s..."
818
+ )
819
+ print(
820
+ f"[JAF:ENGINE] Response had message: {bool(llm_response.get('message'))}, content: {bool(has_content)}, tool_calls: {bool(has_tool_calls)}"
821
+ )
687
822
  await asyncio.sleep(delay)
688
823
  else:
689
824
  # Last attempt failed, log detailed diagnostic info
690
825
  if config.log_empty_responses:
691
- print(f"[JAF:ENGINE] Empty LLM response after {config.max_empty_response_retries + 1} attempts")
826
+ print(
827
+ f"[JAF:ENGINE] Empty LLM response after {config.max_empty_response_retries + 1} attempts"
828
+ )
692
829
  print(f"[JAF:ENGINE] Agent: {current_agent.name}, Model: {model}")
693
- print(f"[JAF:ENGINE] Message count: {len(state.messages)}, Turn: {state.turn_count}")
694
- print(f"[JAF:ENGINE] Response structure: {json.dumps(llm_response, indent=2)[:1000]}")
830
+ print(
831
+ f"[JAF:ENGINE] Message count: {len(state.messages)}, Turn: {state.turn_count}"
832
+ )
833
+ print(
834
+ f"[JAF:ENGINE] Response structure: {json.dumps(llm_response, indent=2)[:1000]}"
835
+ )
695
836
 
696
837
  # Apply after_llm_call callback if provided
697
838
  if config.after_llm_call:
@@ -706,30 +847,34 @@ async def _run_internal(
706
847
 
707
848
  # Emit LLM call end event
708
849
  if config.on_event:
709
- config.on_event(LLMCallEndEvent(data=to_event_data(LLMCallEndEventData(
710
- choice=llm_response,
711
- trace_id=state.trace_id,
712
- run_id=state.run_id,
713
- usage=llm_response.get("usage")
714
- ))))
850
+ config.on_event(
851
+ LLMCallEndEvent(
852
+ data=to_event_data(
853
+ LLMCallEndEventData(
854
+ choice=llm_response,
855
+ trace_id=state.trace_id,
856
+ run_id=state.run_id,
857
+ usage=llm_response.get("usage"),
858
+ )
859
+ )
860
+ )
861
+ )
715
862
 
716
863
  # Check if response has message
717
- if not llm_response.get('message'):
864
+ if not llm_response.get("message"):
718
865
  if config.log_empty_responses:
719
866
  print(f"[JAF:ENGINE] ERROR: No message in LLM response")
720
867
  print(f"[JAF:ENGINE] Response structure: {json.dumps(llm_response, indent=2)[:500]}")
721
868
  return RunResult(
722
869
  final_state=state,
723
- outcome=ErrorOutcome(error=ModelBehaviorError(
724
- detail='No message in model response'
725
- ))
870
+ outcome=ErrorOutcome(error=ModelBehaviorError(detail="No message in model response")),
726
871
  )
727
872
 
728
873
  # Create assistant message
729
874
  assistant_message = Message(
730
875
  role=ContentRole.ASSISTANT,
731
- content=llm_response['message'].get('content') or '',
732
- tool_calls=_convert_tool_calls(llm_response['message'].get('tool_calls'))
876
+ content=llm_response["message"].get("content") or "",
877
+ tool_calls=_convert_tool_calls(llm_response["message"].get("tool_calls")),
733
878
  )
734
879
 
735
880
  new_messages = list(state.messages) + [assistant_message]
@@ -737,83 +882,97 @@ async def _run_internal(
737
882
  # Handle tool calls
738
883
  if assistant_message.tool_calls:
739
884
  tool_results = await _execute_tool_calls(
740
- assistant_message.tool_calls,
741
- current_agent,
742
- state,
743
- config
885
+ assistant_message.tool_calls, current_agent, state, config
744
886
  )
745
887
 
746
888
  # Check for interruptions
747
- interruptions = [r.get('interruption') for r in tool_results if r.get('interruption')]
889
+ interruptions = [r.get("interruption") for r in tool_results if r.get("interruption")]
748
890
  if interruptions:
749
891
  # Separate completed tool results from interrupted ones
750
- completed_results = [r for r in tool_results if not r.get('interruption')]
751
- approval_required_results = [r for r in tool_results if r.get('interruption')]
752
-
892
+ completed_results = [r for r in tool_results if not r.get("interruption")]
893
+ approval_required_results = [r for r in tool_results if r.get("interruption")]
894
+
753
895
  # Add pending approvals to state.approvals
754
896
  updated_approvals = dict(state.approvals)
755
897
  for interruption in interruptions:
756
- if interruption.type == 'tool_approval':
898
+ if interruption.type == "tool_approval":
757
899
  updated_approvals[interruption.tool_call.id] = ApprovalValue(
758
- status='pending',
900
+ status="pending",
759
901
  approved=False,
760
- additional_context={'status': 'pending', 'timestamp': str(int(time.time() * 1000))}
902
+ additional_context={
903
+ "status": "pending",
904
+ "timestamp": str(int(time.time() * 1000)),
905
+ },
761
906
  )
762
907
 
763
908
  # Create state with only completed tool results (for LLM context)
764
909
  interrupted_state = replace(
765
910
  state,
766
- messages=new_messages + [r['message'] for r in completed_results],
911
+ messages=new_messages + [r["message"] for r in completed_results],
767
912
  turn_count=state.turn_count + 1,
768
- approvals=updated_approvals
913
+ approvals=updated_approvals,
769
914
  )
770
-
915
+
771
916
  # Store conversation state with ALL messages including approval-required (for database records)
772
917
  if config.memory and config.memory.auto_store and config.conversation_id:
773
- print(f'[JAF:ENGINE] Storing conversation state due to interruption for {config.conversation_id}')
918
+ print(
919
+ f"[JAF:ENGINE] Storing conversation state due to interruption for {config.conversation_id}"
920
+ )
774
921
  state_for_storage = replace(
775
922
  interrupted_state,
776
- messages=interrupted_state.messages + [r['message'] for r in approval_required_results]
923
+ messages=interrupted_state.messages
924
+ + [r["message"] for r in approval_required_results],
777
925
  )
778
926
  await _store_conversation_history(state_for_storage, config)
779
-
927
+
780
928
  return RunResult(
781
929
  final_state=interrupted_state,
782
- outcome=InterruptedOutcome(interruptions=interruptions)
930
+ outcome=InterruptedOutcome(interruptions=interruptions),
783
931
  )
784
932
 
785
933
  # Check for handoffs
786
- handoff_result = next((r for r in tool_results if r.get('is_handoff')), None)
934
+ handoff_result = next((r for r in tool_results if r.get("is_handoff")), None)
787
935
  if handoff_result:
788
- target_agent = handoff_result['target_agent']
936
+ target_agent = handoff_result["target_agent"]
789
937
 
790
938
  # Validate handoff permission
791
939
  if not current_agent.handoffs or target_agent not in current_agent.handoffs:
792
940
  return RunResult(
793
941
  final_state=replace(state, messages=new_messages),
794
- outcome=ErrorOutcome(error=HandoffError(
795
- detail=f"Agent {current_agent.name} cannot handoff to {target_agent}"
796
- ))
942
+ outcome=ErrorOutcome(
943
+ error=HandoffError(
944
+ detail=f"Agent {current_agent.name} cannot handoff to {target_agent}"
945
+ )
946
+ ),
797
947
  )
798
948
 
799
949
  # Emit handoff event
800
950
  if config.on_event:
801
- config.on_event(HandoffEvent(data=to_event_data(HandoffEventData(
802
- from_=current_agent.name,
803
- to=target_agent
804
- ))))
951
+ config.on_event(
952
+ HandoffEvent(
953
+ data=to_event_data(
954
+ HandoffEventData(from_=current_agent.name, to=target_agent)
955
+ )
956
+ )
957
+ )
805
958
 
806
959
  # Remove any halted messages that are being replaced by actual execution results
807
960
  cleaned_new_messages = []
808
961
  for msg in new_messages:
809
- if msg.role not in (ContentRole.TOOL, 'tool'):
962
+ if msg.role not in (ContentRole.TOOL, "tool"):
810
963
  cleaned_new_messages.append(msg)
811
964
  else:
812
965
  try:
813
966
  content = json.loads(msg.content)
814
- if content.get('status') == 'halted' or content.get('hitl_status') == 'pending_approval':
967
+ if (
968
+ content.get("status") == "halted"
969
+ or content.get("hitl_status") == "pending_approval"
970
+ ):
815
971
  # Remove this halted message if we have a new result for the same tool_call_id
816
- if not any(result['message'].tool_call_id == msg.tool_call_id for result in tool_results):
972
+ if not any(
973
+ result["message"].tool_call_id == msg.tool_call_id
974
+ for result in tool_results
975
+ ):
817
976
  cleaned_new_messages.append(msg)
818
977
  else:
819
978
  cleaned_new_messages.append(msg)
@@ -823,10 +982,10 @@ async def _run_internal(
823
982
  # Continue with new agent
824
983
  next_state = replace(
825
984
  state,
826
- messages=cleaned_new_messages + [r['message'] for r in tool_results],
985
+ messages=cleaned_new_messages + [r["message"] for r in tool_results],
827
986
  current_agent_name=target_agent,
828
987
  turn_count=state.turn_count + 1,
829
- approvals=state.approvals
988
+ approvals=state.approvals,
830
989
  )
831
990
 
832
991
  return await _run_internal(next_state, config)
@@ -834,14 +993,20 @@ async def _run_internal(
834
993
  # Remove any halted messages that are being replaced by actual execution results
835
994
  cleaned_new_messages = []
836
995
  for msg in new_messages:
837
- if msg.role not in (ContentRole.TOOL, 'tool'):
996
+ if msg.role not in (ContentRole.TOOL, "tool"):
838
997
  cleaned_new_messages.append(msg)
839
998
  else:
840
999
  try:
841
1000
  content = json.loads(msg.content)
842
- if content.get('status') == 'halted' or content.get('hitl_status') == 'pending_approval':
1001
+ if (
1002
+ content.get("status") == "halted"
1003
+ or content.get("hitl_status") == "pending_approval"
1004
+ ):
843
1005
  # Remove this halted message if we have a new result for the same tool_call_id
844
- if not any(result['message'].tool_call_id == msg.tool_call_id for result in tool_results):
1006
+ if not any(
1007
+ result["message"].tool_call_id == msg.tool_call_id
1008
+ for result in tool_results
1009
+ ):
845
1010
  cleaned_new_messages.append(msg)
846
1011
  else:
847
1012
  cleaned_new_messages.append(msg)
@@ -851,9 +1016,9 @@ async def _run_internal(
851
1016
  # Continue with tool results
852
1017
  next_state = replace(
853
1018
  state,
854
- messages=cleaned_new_messages + [r['message'] for r in tool_results],
1019
+ messages=cleaned_new_messages + [r["message"] for r in tool_results],
855
1020
  turn_count=state.turn_count + 1,
856
- approvals=state.approvals
1021
+ approvals=state.approvals,
857
1022
  )
858
1023
 
859
1024
  return await _run_internal(next_state, config)
@@ -863,19 +1028,26 @@ async def _run_internal(
863
1028
  if current_agent.output_codec:
864
1029
  # Parse with output codec
865
1030
  if config.on_event:
866
- config.on_event(OutputParseEvent(data=OutputParseEventData(
867
- content=get_text_content(assistant_message.content),
868
- status='start'
869
- )))
1031
+ config.on_event(
1032
+ OutputParseEvent(
1033
+ data=OutputParseEventData(
1034
+ content=get_text_content(assistant_message.content), status="start"
1035
+ )
1036
+ )
1037
+ )
870
1038
  try:
871
1039
  parsed_content = _try_parse_json(get_text_content(assistant_message.content))
872
1040
  output_data = current_agent.output_codec.model_validate(parsed_content)
873
1041
  if config.on_event:
874
- config.on_event(OutputParseEvent(data=OutputParseEventData(
875
- content=get_text_content(assistant_message.content),
876
- status='end',
877
- parsed_output=output_data
878
- )))
1042
+ config.on_event(
1043
+ OutputParseEvent(
1044
+ data=OutputParseEventData(
1045
+ content=get_text_content(assistant_message.content),
1046
+ status="end",
1047
+ parsed_output=output_data,
1048
+ )
1049
+ )
1050
+ )
879
1051
 
880
1052
  # Check final output guardrails
881
1053
  if has_advanced_guardrails:
@@ -886,19 +1058,31 @@ async def _run_internal(
886
1058
  if not output_guardrail_result.is_valid:
887
1059
  return RunResult(
888
1060
  final_state=replace(state, messages=new_messages),
889
- outcome=ErrorOutcome(error=OutputGuardrailTripwire(
890
- reason=getattr(output_guardrail_result, 'error_message', 'Output guardrail violation')
891
- ))
1061
+ outcome=ErrorOutcome(
1062
+ error=OutputGuardrailTripwire(
1063
+ reason=getattr(
1064
+ output_guardrail_result,
1065
+ "error_message",
1066
+ "Output guardrail violation",
1067
+ )
1068
+ )
1069
+ ),
892
1070
  )
893
1071
  else:
894
1072
  # Legacy system
895
1073
  if effective_output_guardrails:
896
1074
  for guardrail in effective_output_guardrails:
897
1075
  if config.on_event:
898
- config.on_event(GuardrailEvent(data=GuardrailEventData(
899
- guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
900
- content=output_data
901
- )))
1076
+ config.on_event(
1077
+ GuardrailEvent(
1078
+ data=GuardrailEventData(
1079
+ guardrail_name=getattr(
1080
+ guardrail, "__name__", "unknown_guardrail"
1081
+ ),
1082
+ content=output_data,
1083
+ )
1084
+ )
1085
+ )
902
1086
  if asyncio.iscoroutinefunction(guardrail):
903
1087
  result = await guardrail(output_data)
904
1088
  else:
@@ -906,34 +1090,55 @@ async def _run_internal(
906
1090
 
907
1091
  if not result.is_valid:
908
1092
  if config.on_event:
909
- config.on_event(GuardrailViolationEvent(data=GuardrailViolationEventData(
910
- stage='output',
911
- reason=getattr(result, 'error_message', 'Output guardrail failed')
912
- )))
1093
+ config.on_event(
1094
+ GuardrailViolationEvent(
1095
+ data=GuardrailViolationEventData(
1096
+ stage="output",
1097
+ reason=getattr(
1098
+ result, "error_message", "Output guardrail failed"
1099
+ ),
1100
+ )
1101
+ )
1102
+ )
913
1103
  return RunResult(
914
- final_state=replace(state, messages=new_messages, approvals=state.approvals),
915
- outcome=ErrorOutcome(error=OutputGuardrailTripwire(
916
- reason=getattr(result, 'error_message', 'Output guardrail failed')
917
- ))
1104
+ final_state=replace(
1105
+ state, messages=new_messages, approvals=state.approvals
1106
+ ),
1107
+ outcome=ErrorOutcome(
1108
+ error=OutputGuardrailTripwire(
1109
+ reason=getattr(
1110
+ result, "error_message", "Output guardrail failed"
1111
+ )
1112
+ )
1113
+ ),
918
1114
  )
919
1115
 
920
1116
  return RunResult(
921
- final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
922
- outcome=CompletedOutcome(output=output_data)
1117
+ final_state=replace(
1118
+ state,
1119
+ messages=new_messages,
1120
+ turn_count=state.turn_count + 1,
1121
+ approvals=state.approvals,
1122
+ ),
1123
+ outcome=CompletedOutcome(output=output_data),
923
1124
  )
924
1125
 
925
1126
  except ValidationError as e:
926
1127
  if config.on_event:
927
- config.on_event(OutputParseEvent(data=OutputParseEventData(
928
- content=get_text_content(assistant_message.content),
929
- status='fail',
930
- error=str(e)
931
- )))
1128
+ config.on_event(
1129
+ OutputParseEvent(
1130
+ data=OutputParseEventData(
1131
+ content=get_text_content(assistant_message.content),
1132
+ status="fail",
1133
+ error=str(e),
1134
+ )
1135
+ )
1136
+ )
932
1137
  return RunResult(
933
1138
  final_state=replace(state, messages=new_messages, approvals=state.approvals),
934
- outcome=ErrorOutcome(error=DecodeError(
935
- errors=[{'message': str(e), 'details': e.errors()}]
936
- ))
1139
+ outcome=ErrorOutcome(
1140
+ error=DecodeError(errors=[{"message": str(e), "details": e.errors()}])
1141
+ ),
937
1142
  )
938
1143
  else:
939
1144
  # No output codec, return content as string
@@ -945,19 +1150,31 @@ async def _run_internal(
945
1150
  if not output_guardrail_result.is_valid:
946
1151
  return RunResult(
947
1152
  final_state=replace(state, messages=new_messages),
948
- outcome=ErrorOutcome(error=OutputGuardrailTripwire(
949
- reason=getattr(output_guardrail_result, 'error_message', 'Output guardrail violation')
950
- ))
1153
+ outcome=ErrorOutcome(
1154
+ error=OutputGuardrailTripwire(
1155
+ reason=getattr(
1156
+ output_guardrail_result,
1157
+ "error_message",
1158
+ "Output guardrail violation",
1159
+ )
1160
+ )
1161
+ ),
951
1162
  )
952
1163
  else:
953
1164
  # Legacy system
954
1165
  if effective_output_guardrails:
955
1166
  for guardrail in effective_output_guardrails:
956
1167
  if config.on_event:
957
- config.on_event(GuardrailEvent(data=GuardrailEventData(
958
- guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
959
- content=get_text_content(assistant_message.content)
960
- )))
1168
+ config.on_event(
1169
+ GuardrailEvent(
1170
+ data=GuardrailEventData(
1171
+ guardrail_name=getattr(
1172
+ guardrail, "__name__", "unknown_guardrail"
1173
+ ),
1174
+ content=get_text_content(assistant_message.content),
1175
+ )
1176
+ )
1177
+ )
961
1178
  if asyncio.iscoroutinefunction(guardrail):
962
1179
  result = await guardrail(get_text_content(assistant_message.content))
963
1180
  else:
@@ -965,30 +1182,48 @@ async def _run_internal(
965
1182
 
966
1183
  if not result.is_valid:
967
1184
  if config.on_event:
968
- config.on_event(GuardrailViolationEvent(data=GuardrailViolationEventData(
969
- stage='output',
970
- reason=getattr(result, 'error_message', 'Output guardrail failed')
971
- )))
1185
+ config.on_event(
1186
+ GuardrailViolationEvent(
1187
+ data=GuardrailViolationEventData(
1188
+ stage="output",
1189
+ reason=getattr(
1190
+ result, "error_message", "Output guardrail failed"
1191
+ ),
1192
+ )
1193
+ )
1194
+ )
972
1195
  return RunResult(
973
- final_state=replace(state, messages=new_messages, approvals=state.approvals),
974
- outcome=ErrorOutcome(error=OutputGuardrailTripwire(
975
- reason=getattr(result, 'error_message', 'Output guardrail failed')
976
- ))
1196
+ final_state=replace(
1197
+ state, messages=new_messages, approvals=state.approvals
1198
+ ),
1199
+ outcome=ErrorOutcome(
1200
+ error=OutputGuardrailTripwire(
1201
+ reason=getattr(
1202
+ result, "error_message", "Output guardrail failed"
1203
+ )
1204
+ )
1205
+ ),
977
1206
  )
978
1207
 
979
1208
  return RunResult(
980
- final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
981
- outcome=CompletedOutcome(output=get_text_content(assistant_message.content))
1209
+ final_state=replace(
1210
+ state,
1211
+ messages=new_messages,
1212
+ turn_count=state.turn_count + 1,
1213
+ approvals=state.approvals,
1214
+ ),
1215
+ outcome=CompletedOutcome(output=get_text_content(assistant_message.content)),
982
1216
  )
983
1217
 
984
1218
  # Model produced neither content nor tool calls
985
1219
  return RunResult(
986
1220
  final_state=replace(state, messages=new_messages, approvals=state.approvals),
987
- outcome=ErrorOutcome(error=ModelBehaviorError(
988
- detail='Model produced neither content nor tool calls'
989
- ))
1221
+ outcome=ErrorOutcome(
1222
+ error=ModelBehaviorError(detail="Model produced neither content nor tool calls")
1223
+ ),
990
1224
  )
991
1225
 
1226
+
992
1227
  def _convert_tool_calls(tool_calls: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolCall]]:
993
1228
  """Convert API tool calls to internal format."""
994
1229
  if not tool_calls:
@@ -996,12 +1231,12 @@ def _convert_tool_calls(tool_calls: Optional[List[Dict[str, Any]]]) -> Optional[
996
1231
 
997
1232
  return [
998
1233
  ToolCall(
999
- id=tc['id'],
1000
- type='function',
1234
+ id=tc["id"],
1235
+ type="function",
1001
1236
  function=ToolCallFunction(
1002
- name=tc['function']['name'],
1003
- arguments=_normalize_tool_call_arguments(tc['function']['arguments'])
1004
- )
1237
+ name=tc["function"]["name"],
1238
+ arguments=_normalize_tool_call_arguments(tc["function"]["arguments"]),
1239
+ ),
1005
1240
  )
1006
1241
  for tc in tool_calls
1007
1242
  ]
@@ -1027,24 +1262,28 @@ def _normalize_tool_call_arguments(arguments: Any) -> Any:
1027
1262
 
1028
1263
  return arguments
1029
1264
 
1265
+
1030
1266
  async def _execute_tool_calls(
1031
- tool_calls: List[ToolCall],
1032
- agent: Agent[Ctx, Any],
1033
- state: RunState[Ctx],
1034
- config: RunConfig[Ctx]
1267
+ tool_calls: List[ToolCall], agent: Agent[Ctx, Any], state: RunState[Ctx], config: RunConfig[Ctx]
1035
1268
  ) -> List[Dict[str, Any]]:
1036
1269
  """Execute tool calls and return results."""
1037
1270
 
1038
1271
  async def execute_single_tool_call(tool_call: ToolCall) -> Dict[str, Any]:
1039
- print(f'[JAF:TOOL-EXEC] Starting execute_single_tool_call for {tool_call.function.name}')
1272
+ print(f"[JAF:TOOL-EXEC] Starting execute_single_tool_call for {tool_call.function.name}")
1040
1273
  if config.on_event:
1041
- config.on_event(ToolCallStartEvent(data=to_event_data(ToolCallStartEventData(
1042
- tool_name=tool_call.function.name,
1043
- args=_try_parse_json(tool_call.function.arguments),
1044
- trace_id=state.trace_id,
1045
- run_id=state.run_id,
1046
- call_id=tool_call.id
1047
- ))))
1274
+ config.on_event(
1275
+ ToolCallStartEvent(
1276
+ data=to_event_data(
1277
+ ToolCallStartEventData(
1278
+ tool_name=tool_call.function.name,
1279
+ args=_try_parse_json(tool_call.function.arguments),
1280
+ trace_id=state.trace_id,
1281
+ run_id=state.run_id,
1282
+ call_id=tool_call.id,
1283
+ )
1284
+ )
1285
+ )
1286
+ )
1048
1287
 
1049
1288
  try:
1050
1289
  # Find the tool
@@ -1056,28 +1295,34 @@ async def _execute_tool_calls(
1056
1295
  break
1057
1296
 
1058
1297
  if not tool:
1059
- error_result = json.dumps({
1060
- 'hitl_status': 'tool_not_found', # HITL workflow status
1061
- 'message': f'Tool {tool_call.function.name} not found',
1062
- 'tool_name': tool_call.function.name,
1063
- })
1298
+ error_result = json.dumps(
1299
+ {
1300
+ "hitl_status": "tool_not_found", # HITL workflow status
1301
+ "message": f"Tool {tool_call.function.name} not found",
1302
+ "tool_name": tool_call.function.name,
1303
+ }
1304
+ )
1064
1305
 
1065
1306
  if config.on_event:
1066
- config.on_event(ToolCallEndEvent(data=to_event_data(ToolCallEndEventData(
1067
- tool_name=tool_call.function.name,
1068
- result=error_result,
1069
- trace_id=state.trace_id,
1070
- run_id=state.run_id,
1071
- execution_status='error', # Tool execution failed
1072
- tool_result={'error': 'tool_not_found'},
1073
- call_id=tool_call.id
1074
- ))))
1307
+ config.on_event(
1308
+ ToolCallEndEvent(
1309
+ data=to_event_data(
1310
+ ToolCallEndEventData(
1311
+ tool_name=tool_call.function.name,
1312
+ result=error_result,
1313
+ trace_id=state.trace_id,
1314
+ run_id=state.run_id,
1315
+ execution_status="error", # Tool execution failed
1316
+ tool_result={"error": "tool_not_found"},
1317
+ call_id=tool_call.id,
1318
+ )
1319
+ )
1320
+ )
1321
+ )
1075
1322
 
1076
1323
  return {
1077
- 'message': Message(
1078
- role=ContentRole.TOOL,
1079
- content=error_result,
1080
- tool_call_id=tool_call.id
1324
+ "message": Message(
1325
+ role=ContentRole.TOOL, content=error_result, tool_call_id=tool_call.id
1081
1326
  )
1082
1327
  }
1083
1328
 
@@ -1085,54 +1330,66 @@ async def _execute_tool_calls(
1085
1330
  raw_args = _try_parse_json(tool_call.function.arguments)
1086
1331
  try:
1087
1332
  # Assuming the tool schema parameters is a Pydantic model
1088
- if hasattr(tool.schema.parameters, 'model_validate'):
1333
+ if hasattr(tool.schema.parameters, "model_validate"):
1089
1334
  validated_args = tool.schema.parameters.model_validate(raw_args)
1090
1335
  else:
1091
1336
  validated_args = raw_args
1092
1337
  except ValidationError as e:
1093
- error_result = json.dumps({
1094
- 'hitl_status': 'validation_error', # HITL workflow status
1095
- 'message': f'Invalid arguments for {tool_call.function.name}: {e!s}',
1096
- 'tool_name': tool_call.function.name,
1097
- 'validation_errors': e.errors()
1098
- })
1338
+ error_result = json.dumps(
1339
+ {
1340
+ "hitl_status": "validation_error", # HITL workflow status
1341
+ "message": f"Invalid arguments for {tool_call.function.name}: {e!s}",
1342
+ "tool_name": tool_call.function.name,
1343
+ "validation_errors": e.errors(),
1344
+ }
1345
+ )
1099
1346
 
1100
1347
  if config.on_event:
1101
- config.on_event(ToolCallEndEvent(data=to_event_data(ToolCallEndEventData(
1102
- tool_name=tool_call.function.name,
1103
- result=error_result,
1104
- trace_id=state.trace_id,
1105
- run_id=state.run_id,
1106
- execution_status='error', # Tool execution failed due to validation
1107
- tool_result={'error': 'validation_error', 'details': e.errors()},
1108
- call_id=tool_call.id
1109
- ))))
1348
+ config.on_event(
1349
+ ToolCallEndEvent(
1350
+ data=to_event_data(
1351
+ ToolCallEndEventData(
1352
+ tool_name=tool_call.function.name,
1353
+ result=error_result,
1354
+ trace_id=state.trace_id,
1355
+ run_id=state.run_id,
1356
+ execution_status="error", # Tool execution failed due to validation
1357
+ tool_result={
1358
+ "error": "validation_error",
1359
+ "details": e.errors(),
1360
+ },
1361
+ call_id=tool_call.id,
1362
+ )
1363
+ )
1364
+ )
1365
+ )
1110
1366
 
1111
1367
  return {
1112
- 'message': Message(
1113
- role=ContentRole.TOOL,
1114
- content=error_result,
1115
- tool_call_id=tool_call.id
1368
+ "message": Message(
1369
+ role=ContentRole.TOOL, content=error_result, tool_call_id=tool_call.id
1116
1370
  )
1117
1371
  }
1118
1372
 
1119
1373
  # Check if tool needs approval
1120
1374
  needs_approval = False
1121
- approval_func = getattr(tool, 'needs_approval', False)
1375
+ approval_func = getattr(tool, "needs_approval", False)
1122
1376
  if callable(approval_func):
1123
1377
  needs_approval = await approval_func(state.context, validated_args)
1124
1378
  else:
1125
1379
  needs_approval = bool(approval_func)
1126
-
1380
+
1127
1381
  # Check approval status - first by ID, then by signature for cross-session matching
1128
1382
  approval_status = state.approvals.get(tool_call.id)
1129
1383
  if not approval_status:
1130
1384
  signature = f"{tool_call.function.name}:{tool_call.function.arguments}"
1131
1385
  for _, approval in state.approvals.items():
1132
- if approval.additional_context and approval.additional_context.get('signature') == signature:
1386
+ if (
1387
+ approval.additional_context
1388
+ and approval.additional_context.get("signature") == signature
1389
+ ):
1133
1390
  approval_status = approval
1134
1391
  break
1135
-
1392
+
1136
1393
  derived_status = None
1137
1394
  if approval_status:
1138
1395
  # Use explicit status if available
@@ -1140,76 +1397,93 @@ async def _execute_tool_calls(
1140
1397
  derived_status = approval_status.status
1141
1398
  # Fall back to approved boolean if status not set
1142
1399
  elif approval_status.approved is True:
1143
- derived_status = 'approved'
1400
+ derived_status = "approved"
1144
1401
  elif approval_status.approved is False:
1145
- if approval_status.additional_context and approval_status.additional_context.get('status') == 'pending':
1146
- derived_status = 'pending'
1402
+ if (
1403
+ approval_status.additional_context
1404
+ and approval_status.additional_context.get("status") == "pending"
1405
+ ):
1406
+ derived_status = "pending"
1147
1407
  else:
1148
- derived_status = 'rejected'
1408
+ derived_status = "rejected"
1149
1409
 
1150
- is_pending = derived_status == 'pending'
1410
+ is_pending = derived_status == "pending"
1151
1411
 
1152
1412
  # If approval needed and not yet decided, create interruption
1153
1413
  if needs_approval and (approval_status is None or is_pending):
1154
1414
  interruption = ToolApprovalInterruption(
1155
- type='tool_approval',
1415
+ type="tool_approval",
1156
1416
  tool_call=tool_call,
1157
1417
  agent=agent,
1158
- session_id=str(state.run_id)
1418
+ session_id=str(state.run_id),
1159
1419
  )
1160
-
1420
+
1161
1421
  # Return interrupted result with halted message
1162
- halted_result = json.dumps({
1163
- 'hitl_status': 'pending_approval', # HITL workflow status: waiting for approval
1164
- 'message': f'Tool {tool_call.function.name} requires approval.',
1165
- })
1166
-
1422
+ halted_result = json.dumps(
1423
+ {
1424
+ "hitl_status": "pending_approval", # HITL workflow status: waiting for approval
1425
+ "message": f"Tool {tool_call.function.name} requires approval.",
1426
+ }
1427
+ )
1428
+
1167
1429
  return {
1168
- 'message': Message(
1169
- role=ContentRole.TOOL,
1170
- content=halted_result,
1171
- tool_call_id=tool_call.id
1430
+ "message": Message(
1431
+ role=ContentRole.TOOL, content=halted_result, tool_call_id=tool_call.id
1172
1432
  ),
1173
- 'interruption': interruption
1433
+ "interruption": interruption,
1174
1434
  }
1175
1435
 
1176
1436
  # If approval was explicitly rejected, return rejection message
1177
- if derived_status == 'rejected':
1178
- rejection_reason = approval_status.additional_context.get('rejection_reason', 'User declined the action') if approval_status.additional_context else 'User declined the action'
1179
- rejection_result = json.dumps({
1180
- 'hitl_status': 'rejected', # HITL workflow status: user rejected the action
1181
- 'message': f'Action was not approved. {rejection_reason}. Please ask if you can help with something else or suggest an alternative approach.',
1182
- 'tool_name': tool_call.function.name,
1183
- 'rejection_reason': rejection_reason,
1184
- 'additional_context': approval_status.additional_context if approval_status else None
1185
- })
1186
-
1437
+ if derived_status == "rejected":
1438
+ rejection_reason = (
1439
+ approval_status.additional_context.get(
1440
+ "rejection_reason", "User declined the action"
1441
+ )
1442
+ if approval_status.additional_context
1443
+ else "User declined the action"
1444
+ )
1445
+ rejection_result = json.dumps(
1446
+ {
1447
+ "hitl_status": "rejected", # HITL workflow status: user rejected the action
1448
+ "message": f"Action was not approved. {rejection_reason}. Please ask if you can help with something else or suggest an alternative approach.",
1449
+ "tool_name": tool_call.function.name,
1450
+ "rejection_reason": rejection_reason,
1451
+ "additional_context": approval_status.additional_context
1452
+ if approval_status
1453
+ else None,
1454
+ }
1455
+ )
1456
+
1187
1457
  return {
1188
- 'message': Message(
1189
- role=ContentRole.TOOL,
1190
- content=rejection_result,
1191
- tool_call_id=tool_call.id
1458
+ "message": Message(
1459
+ role=ContentRole.TOOL, content=rejection_result, tool_call_id=tool_call.id
1192
1460
  )
1193
1461
  }
1194
1462
 
1195
1463
  # Determine timeout for this tool
1196
1464
  # Priority: tool-specific timeout > RunConfig default > 30 seconds global default
1197
- if tool and hasattr(tool, 'schema'):
1198
- timeout = getattr(tool.schema, 'timeout', None)
1465
+ if tool and hasattr(tool, "schema"):
1466
+ timeout = getattr(tool.schema, "timeout", None)
1199
1467
  else:
1200
1468
  timeout = None
1201
1469
  if timeout is None:
1202
- timeout = config.default_tool_timeout if config.default_tool_timeout is not None else 300.0
1470
+ timeout = (
1471
+ config.default_tool_timeout
1472
+ if config.default_tool_timeout is not None
1473
+ else 300.0
1474
+ )
1203
1475
 
1204
1476
  # Merge additional context if provided through approval
1205
1477
  additional_context = approval_status.additional_context if approval_status else None
1206
1478
  context_with_additional = state.context
1207
1479
  if additional_context:
1208
1480
  # Create a copy of context with additional fields from approval
1209
- if hasattr(state.context, '__dict__'):
1481
+ if hasattr(state.context, "__dict__"):
1210
1482
  # For dataclass contexts, add additional context as attributes
1211
1483
  context_dict = {**state.context.__dict__, **additional_context}
1212
- context_with_additional = type(state.context)(**{k: v for k, v in context_dict.items() if k in state.context.__dict__})
1484
+ context_with_additional = type(state.context)(
1485
+ **{k: v for k, v in context_dict.items() if k in state.context.__dict__}
1486
+ )
1213
1487
  # Add any extra fields as attributes
1214
1488
  for key, value in additional_context.items():
1215
1489
  if not hasattr(context_with_additional, key):
@@ -1217,143 +1491,167 @@ async def _execute_tool_calls(
1217
1491
  else:
1218
1492
  # For dict contexts, merge normally
1219
1493
  context_with_additional = {**state.context, **additional_context}
1220
-
1221
- print(f'[JAF:ENGINE] About to execute tool: {tool_call.function.name}')
1222
- print(f'[JAF:ENGINE] Tool args:', validated_args)
1223
- print(f'[JAF:ENGINE] Tool context:', state.context)
1224
-
1494
+
1495
+ print(f"[JAF:ENGINE] About to execute tool: {tool_call.function.name}")
1496
+ print(f"[JAF:ENGINE] Tool args:", validated_args)
1497
+ print(f"[JAF:ENGINE] Tool context:", state.context)
1498
+
1225
1499
  # Execute the tool with timeout
1226
1500
  try:
1227
1501
  tool_result = await asyncio.wait_for(
1228
- tool.execute(validated_args, context_with_additional),
1229
- timeout=timeout
1502
+ tool.execute(validated_args, context_with_additional), timeout=timeout
1230
1503
  )
1231
1504
  except asyncio.TimeoutError:
1232
- timeout_error_result = json.dumps({
1233
- 'hitl_status': 'execution_timeout', # HITL workflow status
1234
- 'message': f'Tool {tool_call.function.name} timed out after {timeout} seconds',
1235
- 'tool_name': tool_call.function.name,
1236
- 'timeout_seconds': timeout
1237
- })
1505
+ timeout_error_result = json.dumps(
1506
+ {
1507
+ "hitl_status": "execution_timeout", # HITL workflow status
1508
+ "message": f"Tool {tool_call.function.name} timed out after {timeout} seconds",
1509
+ "tool_name": tool_call.function.name,
1510
+ "timeout_seconds": timeout,
1511
+ }
1512
+ )
1238
1513
 
1239
1514
  if config.on_event:
1240
- config.on_event(ToolCallEndEvent(data=to_event_data(ToolCallEndEventData(
1241
- tool_name=tool_call.function.name,
1242
- result=timeout_error_result,
1243
- trace_id=state.trace_id,
1244
- run_id=state.run_id,
1245
- execution_status='timeout', # Tool execution timed out
1246
- tool_result={'error': 'timeout'},
1247
- call_id=tool_call.id
1248
- ))))
1515
+ config.on_event(
1516
+ ToolCallEndEvent(
1517
+ data=to_event_data(
1518
+ ToolCallEndEventData(
1519
+ tool_name=tool_call.function.name,
1520
+ result=timeout_error_result,
1521
+ trace_id=state.trace_id,
1522
+ run_id=state.run_id,
1523
+ execution_status="timeout", # Tool execution timed out
1524
+ tool_result={"error": "timeout"},
1525
+ call_id=tool_call.id,
1526
+ )
1527
+ )
1528
+ )
1529
+ )
1249
1530
 
1250
1531
  return {
1251
- 'message': Message(
1532
+ "message": Message(
1252
1533
  role=ContentRole.TOOL,
1253
1534
  content=timeout_error_result,
1254
- tool_call_id=tool_call.id
1535
+ tool_call_id=tool_call.id,
1255
1536
  )
1256
1537
  }
1257
1538
 
1258
1539
  # Handle both string and ToolResult formats
1259
1540
  if isinstance(tool_result, str):
1260
1541
  result_string = tool_result
1261
- print(f'[JAF:ENGINE] Tool {tool_call.function.name} returned string:', result_string)
1542
+ print(
1543
+ f"[JAF:ENGINE] Tool {tool_call.function.name} returned string:", result_string
1544
+ )
1262
1545
  else:
1263
1546
  # It's a ToolResult object
1264
1547
  result_string = tool_result_to_string(tool_result)
1265
- print(f'[JAF:ENGINE] Tool {tool_call.function.name} returned ToolResult:', tool_result)
1266
- print(f'[JAF:ENGINE] Converted to string:', result_string)
1548
+ print(
1549
+ f"[JAF:ENGINE] Tool {tool_call.function.name} returned ToolResult:", tool_result
1550
+ )
1551
+ print(f"[JAF:ENGINE] Converted to string:", result_string)
1267
1552
 
1268
1553
  # Wrap tool result with status information for approval context
1269
1554
  if approval_status and approval_status.additional_context:
1270
- final_content = json.dumps({
1271
- 'hitl_status': 'approved_and_executed', # HITL workflow status: approved by user and executed
1272
- 'result': result_string,
1273
- 'tool_name': tool_call.function.name,
1274
- 'approval_context': approval_status.additional_context,
1275
- 'message': 'Tool was approved and executed successfully with additional context.'
1276
- })
1555
+ final_content = json.dumps(
1556
+ {
1557
+ "hitl_status": "approved_and_executed", # HITL workflow status: approved by user and executed
1558
+ "result": result_string,
1559
+ "tool_name": tool_call.function.name,
1560
+ "approval_context": approval_status.additional_context,
1561
+ "message": "Tool was approved and executed successfully with additional context.",
1562
+ }
1563
+ )
1277
1564
  elif needs_approval:
1278
- final_content = json.dumps({
1279
- 'hitl_status': 'approved_and_executed', # HITL workflow status: approved by user and executed
1280
- 'result': result_string,
1281
- 'tool_name': tool_call.function.name,
1282
- 'message': 'Tool was approved and executed successfully.'
1283
- })
1565
+ final_content = json.dumps(
1566
+ {
1567
+ "hitl_status": "approved_and_executed", # HITL workflow status: approved by user and executed
1568
+ "result": result_string,
1569
+ "tool_name": tool_call.function.name,
1570
+ "message": "Tool was approved and executed successfully.",
1571
+ }
1572
+ )
1284
1573
  else:
1285
- final_content = json.dumps({
1286
- 'hitl_status': 'executed', # HITL workflow status: executed normally (no approval needed)
1287
- 'result': result_string,
1288
- 'tool_name': tool_call.function.name,
1289
- 'message': 'Tool executed successfully.'
1290
- })
1574
+ final_content = json.dumps(
1575
+ {
1576
+ "hitl_status": "executed", # HITL workflow status: executed normally (no approval needed)
1577
+ "result": result_string,
1578
+ "tool_name": tool_call.function.name,
1579
+ "message": "Tool executed successfully.",
1580
+ }
1581
+ )
1291
1582
 
1292
1583
  if config.on_event:
1293
- config.on_event(ToolCallEndEvent(data=to_event_data(ToolCallEndEventData(
1294
- tool_name=tool_call.function.name,
1295
- result=final_content,
1296
- trace_id=state.trace_id,
1297
- run_id=state.run_id,
1298
- tool_result=tool_result,
1299
- execution_status='success', # Tool execution succeeded
1300
- call_id=tool_call.id
1301
- ))))
1584
+ config.on_event(
1585
+ ToolCallEndEvent(
1586
+ data=to_event_data(
1587
+ ToolCallEndEventData(
1588
+ tool_name=tool_call.function.name,
1589
+ result=final_content,
1590
+ trace_id=state.trace_id,
1591
+ run_id=state.run_id,
1592
+ tool_result=tool_result,
1593
+ execution_status="success", # Tool execution succeeded
1594
+ call_id=tool_call.id,
1595
+ )
1596
+ )
1597
+ )
1598
+ )
1302
1599
 
1303
1600
  # Check for handoff
1304
1601
  handoff_check = _try_parse_json(result_string)
1305
- if isinstance(handoff_check, dict) and 'handoff_to' in handoff_check:
1602
+ if isinstance(handoff_check, dict) and "handoff_to" in handoff_check:
1306
1603
  return {
1307
- 'message': Message(
1308
- role=ContentRole.TOOL,
1309
- content=final_content,
1310
- tool_call_id=tool_call.id
1604
+ "message": Message(
1605
+ role=ContentRole.TOOL, content=final_content, tool_call_id=tool_call.id
1311
1606
  ),
1312
- 'is_handoff': True,
1313
- 'target_agent': handoff_check['handoff_to']
1607
+ "is_handoff": True,
1608
+ "target_agent": handoff_check["handoff_to"],
1314
1609
  }
1315
1610
 
1316
1611
  return {
1317
- 'message': Message(
1318
- role=ContentRole.TOOL,
1319
- content=final_content,
1320
- tool_call_id=tool_call.id
1612
+ "message": Message(
1613
+ role=ContentRole.TOOL, content=final_content, tool_call_id=tool_call.id
1321
1614
  )
1322
1615
  }
1323
1616
 
1324
1617
  except Exception as error:
1325
- error_result = json.dumps({
1326
- 'hitl_status': 'execution_error', # HITL workflow status
1327
- 'message': str(error),
1328
- 'tool_name': tool_call.function.name,
1329
- })
1618
+ error_result = json.dumps(
1619
+ {
1620
+ "hitl_status": "execution_error", # HITL workflow status
1621
+ "message": str(error),
1622
+ "tool_name": tool_call.function.name,
1623
+ }
1624
+ )
1330
1625
 
1331
1626
  if config.on_event:
1332
- config.on_event(ToolCallEndEvent(data=to_event_data(ToolCallEndEventData(
1333
- tool_name=tool_call.function.name,
1334
- result=error_result,
1335
- trace_id=state.trace_id,
1336
- run_id=state.run_id,
1337
- execution_status='error', # Tool execution failed with exception
1338
- tool_result={'error': 'execution_error', 'detail': str(error)},
1339
- call_id=tool_call.id
1340
- ))))
1627
+ config.on_event(
1628
+ ToolCallEndEvent(
1629
+ data=to_event_data(
1630
+ ToolCallEndEventData(
1631
+ tool_name=tool_call.function.name,
1632
+ result=error_result,
1633
+ trace_id=state.trace_id,
1634
+ run_id=state.run_id,
1635
+ execution_status="error", # Tool execution failed with exception
1636
+ tool_result={"error": "execution_error", "detail": str(error)},
1637
+ call_id=tool_call.id,
1638
+ )
1639
+ )
1640
+ )
1641
+ )
1341
1642
 
1342
1643
  return {
1343
- 'message': Message(
1344
- role=ContentRole.TOOL,
1345
- content=error_result,
1346
- tool_call_id=tool_call.id
1644
+ "message": Message(
1645
+ role=ContentRole.TOOL, content=error_result, tool_call_id=tool_call.id
1347
1646
  )
1348
1647
  }
1349
1648
 
1350
1649
  # Execute all tool calls in parallel
1351
- results = await asyncio.gather(*[
1352
- execute_single_tool_call(tc) for tc in tool_calls
1353
- ])
1650
+ results = await asyncio.gather(*[execute_single_tool_call(tc) for tc in tool_calls])
1354
1651
 
1355
1652
  return results
1356
1653
 
1654
+
1357
1655
  def _try_parse_json(text: str) -> Any:
1358
1656
  """Try to parse JSON, return original string if it fails."""
1359
1657
  if not text or not isinstance(text, str):