jaf-py 2.4.4__py3-none-any.whl → 2.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
jaf/core/__init__.py CHANGED
@@ -13,6 +13,14 @@ from .agent_tool import (
13
13
  get_current_run_config,
14
14
  set_current_run_config,
15
15
  )
16
+ from .parallel_agents import (
17
+ ParallelAgentGroup,
18
+ ParallelExecutionConfig,
19
+ create_parallel_agents_tool,
20
+ create_simple_parallel_tool,
21
+ create_language_specialists_tool,
22
+ create_domain_experts_tool,
23
+ )
16
24
  from .proxy import ProxyConfig, ProxyAuth, create_proxy_config, get_default_proxy_config
17
25
 
18
26
  __all__ = [
@@ -23,6 +31,8 @@ __all__ = [
23
31
  "Message",
24
32
  "ModelConfig",
25
33
  "ModelProvider",
34
+ "ParallelAgentGroup",
35
+ "ParallelExecutionConfig",
26
36
  "ProxyAuth",
27
37
  "ProxyConfig",
28
38
  "RunConfig",
@@ -41,9 +51,13 @@ __all__ = [
41
51
  "create_agent_tool",
42
52
  "create_conditional_enabler",
43
53
  "create_default_output_extractor",
54
+ "create_domain_experts_tool",
44
55
  "create_json_output_extractor",
56
+ "create_language_specialists_tool",
57
+ "create_parallel_agents_tool",
45
58
  "create_proxy_config",
46
59
  "create_run_id",
60
+ "create_simple_parallel_tool",
47
61
  "create_trace_id",
48
62
  "get_current_run_config",
49
63
  "get_default_proxy_config",
jaf/core/engine.py CHANGED
@@ -32,6 +32,8 @@ from .types import (
32
32
  Interruption,
33
33
  GuardrailEvent,
34
34
  GuardrailEventData,
35
+ GuardrailViolationEvent,
36
+ GuardrailViolationEventData,
35
37
  MemoryEvent,
36
38
  MemoryEventData,
37
39
  OutputParseEvent,
@@ -61,6 +63,15 @@ from .types import (
61
63
  ToolCallFunction,
62
64
  ToolCallStartEvent,
63
65
  ToolCallStartEventData,
66
+ Guardrail,
67
+ ValidValidationResult,
68
+ InvalidValidationResult,
69
+ )
70
+ from .guardrails import (
71
+ build_effective_guardrails,
72
+ execute_input_guardrails_sequential,
73
+ execute_input_guardrails_parallel,
74
+ execute_output_guardrails,
64
75
  )
65
76
 
66
77
 
@@ -399,36 +410,6 @@ async def _run_internal(
399
410
  if resumed:
400
411
  return resumed
401
412
 
402
- # Check initial input guardrails on first turn
403
- if state.turn_count == 0:
404
- first_user_message = next((m for m in state.messages if m.role == ContentRole.USER or m.role == 'user'), None)
405
- if first_user_message and config.initial_input_guardrails:
406
- for guardrail in config.initial_input_guardrails:
407
- if config.on_event:
408
- config.on_event(GuardrailEvent(data=GuardrailEventData(
409
- guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
410
- content=get_text_content(first_user_message.content)
411
- )))
412
- if asyncio.iscoroutinefunction(guardrail):
413
- result = await guardrail(get_text_content(first_user_message.content))
414
- else:
415
- result = guardrail(get_text_content(first_user_message.content))
416
-
417
- if not result.is_valid:
418
- if config.on_event:
419
- config.on_event(GuardrailEvent(data=GuardrailEventData(
420
- guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
421
- content=get_text_content(first_user_message.content),
422
- is_valid=False,
423
- error_message=result.error_message
424
- )))
425
- return RunResult(
426
- final_state=state,
427
- outcome=ErrorOutcome(error=InputGuardrailTripwire(
428
- reason=result.error_message or "Input guardrail failed"
429
- ))
430
- )
431
-
432
413
  # Check max turns
433
414
  max_turns = config.max_turns or 50
434
415
  if state.turn_count >= max_turns:
@@ -445,6 +426,105 @@ async def _run_internal(
445
426
  outcome=ErrorOutcome(error=AgentNotFound(agent_name=state.current_agent_name))
446
427
  )
447
428
 
429
+ # Determine if agent has advanced guardrails configuration
430
+ has_advanced_guardrails = bool(
431
+ current_agent.advanced_config and
432
+ current_agent.advanced_config.guardrails and
433
+ (current_agent.advanced_config.guardrails.input_prompt or
434
+ current_agent.advanced_config.guardrails.output_prompt or
435
+ current_agent.advanced_config.guardrails.require_citations)
436
+ )
437
+
438
+ print('[JAF:ENGINE] Debug guardrails setup:', {
439
+ 'agent_name': current_agent.name,
440
+ 'has_advanced_config': bool(current_agent.advanced_config),
441
+ 'has_advanced_guardrails': has_advanced_guardrails,
442
+ 'initial_input_guardrails': len(config.initial_input_guardrails or []),
443
+ 'final_output_guardrails': len(config.final_output_guardrails or [])
444
+ })
445
+
446
+ # Build effective guardrails
447
+ effective_input_guardrails: List[Guardrail] = []
448
+ effective_output_guardrails: List[Guardrail] = []
449
+
450
+ if has_advanced_guardrails:
451
+ result = await build_effective_guardrails(current_agent, config)
452
+ effective_input_guardrails, effective_output_guardrails = result
453
+ else:
454
+ effective_input_guardrails = list(config.initial_input_guardrails or [])
455
+ effective_output_guardrails = list(config.final_output_guardrails or [])
456
+
457
+ # Execute input guardrails on first turn
458
+ input_guardrails_to_run = (effective_input_guardrails
459
+ if state.turn_count == 0 and effective_input_guardrails
460
+ else [])
461
+
462
+ print('[JAF:ENGINE] Input guardrails to run:', {
463
+ 'turn_count': state.turn_count,
464
+ 'effective_input_length': len(effective_input_guardrails),
465
+ 'input_guardrails_to_run_length': len(input_guardrails_to_run),
466
+ 'has_advanced_guardrails': has_advanced_guardrails
467
+ })
468
+
469
+ if input_guardrails_to_run and state.turn_count == 0:
470
+ first_user_message = next((m for m in state.messages if m.role == ContentRole.USER or m.role == 'user'), None)
471
+ if first_user_message:
472
+ if has_advanced_guardrails:
473
+ execution_mode = (current_agent.advanced_config.guardrails.execution_mode
474
+ if current_agent.advanced_config and current_agent.advanced_config.guardrails
475
+ else 'parallel')
476
+
477
+ if execution_mode == 'sequential':
478
+ guardrail_result = await execute_input_guardrails_sequential(
479
+ input_guardrails_to_run, first_user_message, config
480
+ )
481
+ if not guardrail_result.is_valid:
482
+ return RunResult(
483
+ final_state=state,
484
+ outcome=ErrorOutcome(error=InputGuardrailTripwire(
485
+ reason=getattr(guardrail_result, 'error_message', 'Input guardrail violation')
486
+ ))
487
+ )
488
+ else:
489
+ # Parallel execution with LLM call overlap
490
+ guardrail_result = await execute_input_guardrails_parallel(
491
+ input_guardrails_to_run, first_user_message, config
492
+ )
493
+ if not guardrail_result.is_valid:
494
+ print(f"🚨 Input guardrail violation: {getattr(guardrail_result, 'error_message', 'Unknown violation')}")
495
+ return RunResult(
496
+ final_state=state,
497
+ outcome=ErrorOutcome(error=InputGuardrailTripwire(
498
+ reason=getattr(guardrail_result, 'error_message', 'Input guardrail violation')
499
+ ))
500
+ )
501
+ else:
502
+ # Legacy guardrails path
503
+ print('[JAF:ENGINE] Using LEGACY guardrails path with', len(input_guardrails_to_run), 'guardrails')
504
+ for guardrail in input_guardrails_to_run:
505
+ if config.on_event:
506
+ config.on_event(GuardrailEvent(data=GuardrailEventData(
507
+ guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
508
+ content=get_text_content(first_user_message.content)
509
+ )))
510
+ if asyncio.iscoroutinefunction(guardrail):
511
+ result = await guardrail(get_text_content(first_user_message.content))
512
+ else:
513
+ result = guardrail(get_text_content(first_user_message.content))
514
+
515
+ if not result.is_valid:
516
+ if config.on_event:
517
+ config.on_event(GuardrailViolationEvent(data=GuardrailViolationEventData(
518
+ stage='input',
519
+ reason=getattr(result, 'error_message', 'Input guardrail failed')
520
+ )))
521
+ return RunResult(
522
+ final_state=state,
523
+ outcome=ErrorOutcome(error=InputGuardrailTripwire(
524
+ reason=getattr(result, 'error_message', 'Input guardrail failed')
525
+ ))
526
+ )
527
+
448
528
  # Agent debugging logs removed for performance
449
529
 
450
530
  # Get model name
@@ -516,12 +596,15 @@ async def _run_internal(
516
596
  if len(partial_tool_calls) > 0:
517
597
  message_tool_calls = []
518
598
  for i, tc in enumerate(partial_tool_calls):
599
+ arguments = tc["function"]["arguments"]
600
+ if isinstance(arguments, str):
601
+ arguments = _normalize_tool_call_arguments(arguments)
519
602
  message_tool_calls.append({
520
603
  "id": tc["id"] or f"call_{i}",
521
604
  "type": "function",
522
605
  "function": {
523
606
  "name": tc["function"]["name"] or "",
524
- "arguments": tc["function"]["arguments"]
607
+ "arguments": arguments
525
608
  }
526
609
  })
527
610
 
@@ -534,7 +617,7 @@ async def _run_internal(
534
617
  type="function",
535
618
  function=ToolCallFunction(
536
619
  name=mc["function"]["name"],
537
- arguments=mc["function"]["arguments"],
620
+ arguments=_normalize_tool_call_arguments(mc["function"]["arguments"])
538
621
  ),
539
622
  ) for mc in message_tool_calls
540
623
  ],
@@ -553,12 +636,15 @@ async def _run_internal(
553
636
  if len(partial_tool_calls) > 0:
554
637
  final_tool_calls = []
555
638
  for i, tc in enumerate(partial_tool_calls):
639
+ arguments = tc["function"]["arguments"]
640
+ if isinstance(arguments, str):
641
+ arguments = _normalize_tool_call_arguments(arguments)
556
642
  final_tool_calls.append({
557
643
  "id": tc["id"] or f"call_{i}",
558
644
  "type": "function",
559
645
  "function": {
560
646
  "name": tc["function"]["name"] or "",
561
- "arguments": tc["function"]["arguments"]
647
+ "arguments": arguments
562
648
  }
563
649
  })
564
650
 
@@ -746,13 +832,27 @@ async def _run_internal(
746
832
  )))
747
833
 
748
834
  # Check final output guardrails
749
- if config.final_output_guardrails:
750
- for guardrail in config.final_output_guardrails:
751
- if config.on_event:
752
- config.on_event(GuardrailEvent(data=GuardrailEventData(
753
- guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
754
- content=output_data
755
- )))
835
+ if has_advanced_guardrails:
836
+ # Use new advanced system
837
+ output_guardrail_result = await execute_output_guardrails(
838
+ effective_output_guardrails, output_data, config
839
+ )
840
+ if not output_guardrail_result.is_valid:
841
+ return RunResult(
842
+ final_state=replace(state, messages=new_messages),
843
+ outcome=ErrorOutcome(error=OutputGuardrailTripwire(
844
+ reason=getattr(output_guardrail_result, 'error_message', 'Output guardrail violation')
845
+ ))
846
+ )
847
+ else:
848
+ # Legacy system
849
+ if effective_output_guardrails:
850
+ for guardrail in effective_output_guardrails:
851
+ if config.on_event:
852
+ config.on_event(GuardrailEvent(data=GuardrailEventData(
853
+ guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
854
+ content=output_data
855
+ )))
756
856
  if asyncio.iscoroutinefunction(guardrail):
757
857
  result = await guardrail(output_data)
758
858
  else:
@@ -760,16 +860,14 @@ async def _run_internal(
760
860
 
761
861
  if not result.is_valid:
762
862
  if config.on_event:
763
- config.on_event(GuardrailEvent(data=GuardrailEventData(
764
- guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
765
- content=output_data,
766
- is_valid=False,
767
- error_message=result.error_message
863
+ config.on_event(GuardrailViolationEvent(data=GuardrailViolationEventData(
864
+ stage='output',
865
+ reason=getattr(result, 'error_message', 'Output guardrail failed')
768
866
  )))
769
867
  return RunResult(
770
868
  final_state=replace(state, messages=new_messages, approvals=state.approvals),
771
869
  outcome=ErrorOutcome(error=OutputGuardrailTripwire(
772
- reason=result.error_message or "Output guardrail failed"
870
+ reason=getattr(result, 'error_message', 'Output guardrail failed')
773
871
  ))
774
872
  )
775
873
 
@@ -793,32 +891,44 @@ async def _run_internal(
793
891
  )
794
892
  else:
795
893
  # No output codec, return content as string
796
- if config.final_output_guardrails:
797
- for guardrail in config.final_output_guardrails:
798
- if config.on_event:
799
- config.on_event(GuardrailEvent(data=GuardrailEventData(
800
- guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
801
- content=get_text_content(assistant_message.content)
802
- )))
803
- if asyncio.iscoroutinefunction(guardrail):
804
- result = await guardrail(get_text_content(assistant_message.content))
805
- else:
806
- result = guardrail(get_text_content(assistant_message.content))
807
-
808
- if not result.is_valid:
894
+ if has_advanced_guardrails:
895
+ # Use new advanced system
896
+ output_guardrail_result = await execute_output_guardrails(
897
+ effective_output_guardrails, get_text_content(assistant_message.content), config
898
+ )
899
+ if not output_guardrail_result.is_valid:
900
+ return RunResult(
901
+ final_state=replace(state, messages=new_messages),
902
+ outcome=ErrorOutcome(error=OutputGuardrailTripwire(
903
+ reason=getattr(output_guardrail_result, 'error_message', 'Output guardrail violation')
904
+ ))
905
+ )
906
+ else:
907
+ # Legacy system
908
+ if effective_output_guardrails:
909
+ for guardrail in effective_output_guardrails:
809
910
  if config.on_event:
810
911
  config.on_event(GuardrailEvent(data=GuardrailEventData(
811
912
  guardrail_name=getattr(guardrail, '__name__', 'unknown_guardrail'),
812
- content=get_text_content(assistant_message.content),
813
- is_valid=False,
814
- error_message=result.error_message
913
+ content=get_text_content(assistant_message.content)
815
914
  )))
816
- return RunResult(
817
- final_state=replace(state, messages=new_messages, approvals=state.approvals),
818
- outcome=ErrorOutcome(error=OutputGuardrailTripwire(
819
- reason=result.error_message or "Output guardrail failed"
820
- ))
821
- )
915
+ if asyncio.iscoroutinefunction(guardrail):
916
+ result = await guardrail(get_text_content(assistant_message.content))
917
+ else:
918
+ result = guardrail(get_text_content(assistant_message.content))
919
+
920
+ if not result.is_valid:
921
+ if config.on_event:
922
+ config.on_event(GuardrailViolationEvent(data=GuardrailViolationEventData(
923
+ stage='output',
924
+ reason=getattr(result, 'error_message', 'Output guardrail failed')
925
+ )))
926
+ return RunResult(
927
+ final_state=replace(state, messages=new_messages, approvals=state.approvals),
928
+ outcome=ErrorOutcome(error=OutputGuardrailTripwire(
929
+ reason=getattr(result, 'error_message', 'Output guardrail failed')
930
+ ))
931
+ )
822
932
 
823
933
  return RunResult(
824
934
  final_state=replace(state, messages=new_messages, turn_count=state.turn_count + 1, approvals=state.approvals),
@@ -844,12 +954,33 @@ def _convert_tool_calls(tool_calls: Optional[List[Dict[str, Any]]]) -> Optional[
844
954
  type='function',
845
955
  function=ToolCallFunction(
846
956
  name=tc['function']['name'],
847
- arguments=tc['function']['arguments']
957
+ arguments=_normalize_tool_call_arguments(tc['function']['arguments'])
848
958
  )
849
959
  )
850
960
  for tc in tool_calls
851
961
  ]
852
962
 
963
+
964
+ def _normalize_tool_call_arguments(arguments: Any) -> Any:
965
+ """Strip trailing streaming artifacts so arguments remain valid JSON strings."""
966
+ if not arguments or not isinstance(arguments, str):
967
+ return arguments
968
+
969
+ decoder = json.JSONDecoder()
970
+ try:
971
+ obj, end = decoder.raw_decode(arguments)
972
+ except json.JSONDecodeError:
973
+ return arguments
974
+
975
+ remainder = arguments[end:].strip()
976
+ if remainder:
977
+ try:
978
+ return json.dumps(obj)
979
+ except (TypeError, ValueError):
980
+ return arguments
981
+
982
+ return arguments
983
+
853
984
  async def _execute_tool_calls(
854
985
  tool_calls: List[ToolCall],
855
986
  agent: Agent[Ctx, Any],
@@ -865,7 +996,8 @@ async def _execute_tool_calls(
865
996
  tool_name=tool_call.function.name,
866
997
  args=_try_parse_json(tool_call.function.arguments),
867
998
  trace_id=state.trace_id,
868
- run_id=state.run_id
999
+ run_id=state.run_id,
1000
+ call_id=tool_call.id
869
1001
  ))))
870
1002
 
871
1003
  try:
@@ -891,7 +1023,8 @@ async def _execute_tool_calls(
891
1023
  trace_id=state.trace_id,
892
1024
  run_id=state.run_id,
893
1025
  status='error',
894
- tool_result={'error': 'tool_not_found'}
1026
+ tool_result={'error': 'tool_not_found'},
1027
+ call_id=tool_call.id
895
1028
  ))))
896
1029
 
897
1030
  return {
@@ -925,7 +1058,8 @@ async def _execute_tool_calls(
925
1058
  trace_id=state.trace_id,
926
1059
  run_id=state.run_id,
927
1060
  status='error',
928
- tool_result={'error': 'validation_error', 'details': e.errors()}
1061
+ tool_result={'error': 'validation_error', 'details': e.errors()},
1062
+ call_id=tool_call.id
929
1063
  ))))
930
1064
 
931
1065
  return {
@@ -1019,7 +1153,7 @@ async def _execute_tool_calls(
1019
1153
  else:
1020
1154
  timeout = None
1021
1155
  if timeout is None:
1022
- timeout = config.default_tool_timeout if config.default_tool_timeout is not None else 30.0
1156
+ timeout = config.default_tool_timeout if config.default_tool_timeout is not None else 300.0
1023
1157
 
1024
1158
  # Merge additional context if provided through approval
1025
1159
  additional_context = approval_status.additional_context if approval_status else None
@@ -1063,7 +1197,8 @@ async def _execute_tool_calls(
1063
1197
  trace_id=state.trace_id,
1064
1198
  run_id=state.run_id,
1065
1199
  status='timeout',
1066
- tool_result={'error': 'timeout_error'}
1200
+ tool_result={'error': 'timeout_error'},
1201
+ call_id=tool_call.id
1067
1202
  ))))
1068
1203
 
1069
1204
  return {
@@ -1115,7 +1250,8 @@ async def _execute_tool_calls(
1115
1250
  trace_id=state.trace_id,
1116
1251
  run_id=state.run_id,
1117
1252
  tool_result=tool_result,
1118
- status='success'
1253
+ status='success',
1254
+ call_id=tool_call.id
1119
1255
  ))))
1120
1256
 
1121
1257
  # Check for handoff
@@ -1153,7 +1289,8 @@ async def _execute_tool_calls(
1153
1289
  trace_id=state.trace_id,
1154
1290
  run_id=state.run_id,
1155
1291
  status='error',
1156
- tool_result={'error': 'execution_error', 'detail': str(error)}
1292
+ tool_result={'error': 'execution_error', 'detail': str(error)},
1293
+ call_id=tool_call.id
1157
1294
  ))))
1158
1295
 
1159
1296
  return {