jaf-py 2.4.3__tar.gz → 2.4.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. {jaf_py-2.4.3 → jaf_py-2.4.5}/PKG-INFO +1 -1
  2. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/__init__.py +14 -0
  3. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/engine.py +50 -15
  4. jaf_py-2.4.5/jaf/core/parallel_agents.py +339 -0
  5. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/streaming.py +42 -17
  6. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/tracing.py +73 -16
  7. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/types.py +7 -4
  8. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/server/server.py +2 -2
  9. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf_py.egg-info/PKG-INFO +1 -1
  10. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf_py.egg-info/SOURCES.txt +1 -0
  11. {jaf_py-2.4.3 → jaf_py-2.4.5}/pyproject.toml +1 -1
  12. {jaf_py-2.4.3 → jaf_py-2.4.5}/LICENSE +0 -0
  13. {jaf_py-2.4.3 → jaf_py-2.4.5}/README.md +0 -0
  14. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/__init__.py +0 -0
  15. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/__init__.py +0 -0
  16. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/agent.py +0 -0
  17. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/agent_card.py +0 -0
  18. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/client.py +0 -0
  19. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/examples/__init__.py +0 -0
  20. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/examples/client_example.py +0 -0
  21. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/examples/integration_example.py +0 -0
  22. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/examples/rag_demo/__init__.py +0 -0
  23. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/examples/server_demo/__init__.py +0 -0
  24. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/examples/server_example.py +0 -0
  25. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/__init__.py +0 -0
  26. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/cleanup.py +0 -0
  27. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/factory.py +0 -0
  28. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/providers/__init__.py +0 -0
  29. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/providers/composite.py +0 -0
  30. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/providers/in_memory.py +0 -0
  31. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/providers/postgres.py +0 -0
  32. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/providers/redis.py +0 -0
  33. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/serialization.py +0 -0
  34. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/tests/__init__.py +0 -0
  35. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/tests/run_comprehensive_tests.py +0 -0
  36. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/tests/test_cleanup.py +0 -0
  37. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/tests/test_serialization.py +0 -0
  38. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/tests/test_stress_concurrency.py +0 -0
  39. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/tests/test_task_lifecycle.py +0 -0
  40. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/memory/types.py +0 -0
  41. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/protocol.py +0 -0
  42. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/server.py +0 -0
  43. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/standalone_client.py +0 -0
  44. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/tests/__init__.py +0 -0
  45. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/tests/run_tests.py +0 -0
  46. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/tests/test_agent.py +0 -0
  47. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/tests/test_client.py +0 -0
  48. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/tests/test_integration.py +0 -0
  49. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/tests/test_protocol.py +0 -0
  50. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/tests/test_types.py +0 -0
  51. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/a2a/types.py +0 -0
  52. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/cli.py +0 -0
  53. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/agent_tool.py +0 -0
  54. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/analytics.py +0 -0
  55. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/composition.py +0 -0
  56. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/errors.py +0 -0
  57. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/performance.py +0 -0
  58. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/proxy.py +0 -0
  59. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/proxy_helpers.py +0 -0
  60. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/state.py +0 -0
  61. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/tool_results.py +0 -0
  62. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/tools.py +0 -0
  63. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/core/workflows.py +0 -0
  64. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/exceptions.py +0 -0
  65. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/memory/__init__.py +0 -0
  66. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/memory/approval_storage.py +0 -0
  67. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/memory/factory.py +0 -0
  68. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/memory/providers/__init__.py +0 -0
  69. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/memory/providers/in_memory.py +0 -0
  70. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/memory/providers/postgres.py +0 -0
  71. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/memory/providers/redis.py +0 -0
  72. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/memory/types.py +0 -0
  73. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/memory/utils.py +0 -0
  74. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/plugins/__init__.py +0 -0
  75. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/plugins/base.py +0 -0
  76. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/policies/__init__.py +0 -0
  77. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/policies/handoff.py +0 -0
  78. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/policies/validation.py +0 -0
  79. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/providers/__init__.py +0 -0
  80. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/providers/mcp.py +0 -0
  81. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/providers/model.py +0 -0
  82. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/server/__init__.py +0 -0
  83. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/server/main.py +0 -0
  84. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/server/types.py +0 -0
  85. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/utils/__init__.py +0 -0
  86. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/utils/attachments.py +0 -0
  87. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/utils/document_processor.py +0 -0
  88. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/visualization/__init__.py +0 -0
  89. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/visualization/example.py +0 -0
  90. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/visualization/functional_core.py +0 -0
  91. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/visualization/graphviz.py +0 -0
  92. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/visualization/imperative_shell.py +0 -0
  93. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf/visualization/types.py +0 -0
  94. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf_py.egg-info/dependency_links.txt +0 -0
  95. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf_py.egg-info/entry_points.txt +0 -0
  96. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf_py.egg-info/requires.txt +0 -0
  97. {jaf_py-2.4.3 → jaf_py-2.4.5}/jaf_py.egg-info/top_level.txt +0 -0
  98. {jaf_py-2.4.3 → jaf_py-2.4.5}/setup.cfg +0 -0
  99. {jaf_py-2.4.3 → jaf_py-2.4.5}/setup.py +0 -0
  100. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_a2a_deep.py +0 -0
  101. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_a2a_examples.py +0 -0
  102. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_api_reference_examples.py +0 -0
  103. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_attachments.py +0 -0
  104. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_callback_system_examples.py +0 -0
  105. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_coffee_tool.py +0 -0
  106. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_conversation_id_fix.py +0 -0
  107. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_deployment_examples.py +0 -0
  108. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_docs_code_examples.py +0 -0
  109. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_engine.py +0 -0
  110. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_engine_manual.py +0 -0
  111. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_error_handling_examples.py +0 -0
  112. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_getting_started_examples.py +0 -0
  113. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_math_tool.py +0 -0
  114. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_mcp_comprehensive.py +0 -0
  115. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_mcp_docs.py +0 -0
  116. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_mcp_real_functionality.py +0 -0
  117. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_mcp_transports.py +0 -0
  118. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_memory_system_examples.py +0 -0
  119. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_model_providers_examples.py +0 -0
  120. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_property_based.py +0 -0
  121. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_proxy_simple.py +0 -0
  122. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_redis_fixes.py +0 -0
  123. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_redis_memory.py +0 -0
  124. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_server_api_examples.py +0 -0
  125. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_session_continuity.py +0 -0
  126. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_streamable_http_mcp_example.py +0 -0
  127. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_timeout_functionality.py +0 -0
  128. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_tool_integration.py +0 -0
  129. {jaf_py-2.4.3 → jaf_py-2.4.5}/tests/test_validation.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: jaf-py
3
- Version: 2.4.3
3
+ Version: 2.4.5
4
4
  Summary: A purely functional agent framework with immutable state and composable tools - Python implementation
5
5
  Author: JAF Contributors
6
6
  Maintainer: JAF Contributors
@@ -13,6 +13,14 @@ from .agent_tool import (
13
13
  get_current_run_config,
14
14
  set_current_run_config,
15
15
  )
16
+ from .parallel_agents import (
17
+ ParallelAgentGroup,
18
+ ParallelExecutionConfig,
19
+ create_parallel_agents_tool,
20
+ create_simple_parallel_tool,
21
+ create_language_specialists_tool,
22
+ create_domain_experts_tool,
23
+ )
16
24
  from .proxy import ProxyConfig, ProxyAuth, create_proxy_config, get_default_proxy_config
17
25
 
18
26
  __all__ = [
@@ -23,6 +31,8 @@ __all__ = [
23
31
  "Message",
24
32
  "ModelConfig",
25
33
  "ModelProvider",
34
+ "ParallelAgentGroup",
35
+ "ParallelExecutionConfig",
26
36
  "ProxyAuth",
27
37
  "ProxyConfig",
28
38
  "RunConfig",
@@ -41,9 +51,13 @@ __all__ = [
41
51
  "create_agent_tool",
42
52
  "create_conditional_enabler",
43
53
  "create_default_output_extractor",
54
+ "create_domain_experts_tool",
44
55
  "create_json_output_extractor",
56
+ "create_language_specialists_tool",
57
+ "create_parallel_agents_tool",
45
58
  "create_proxy_config",
46
59
  "create_run_id",
60
+ "create_simple_parallel_tool",
47
61
  "create_trace_id",
48
62
  "get_current_run_config",
49
63
  "get_default_proxy_config",
@@ -174,16 +174,18 @@ async def run(
174
174
  from .agent_tool import set_current_run_config
175
175
  set_current_run_config(config)
176
176
 
177
+ state_with_memory = await _load_conversation_history(initial_state, config)
178
+
179
+ # Emit RunStartEvent AFTER loading conversation history so we have complete context
177
180
  if config.on_event:
178
181
  config.on_event(RunStartEvent(data=to_event_data(RunStartEventData(
179
182
  run_id=initial_state.run_id,
180
183
  trace_id=initial_state.trace_id,
181
184
  session_id=config.conversation_id,
182
- context=initial_state.context,
183
- messages=initial_state.messages
185
+ context=state_with_memory.context,
186
+ messages=state_with_memory.messages, # Now includes full conversation history
187
+ agent_name=state_with_memory.current_agent_name
184
188
  ))))
185
-
186
- state_with_memory = await _load_conversation_history(initial_state, config)
187
189
 
188
190
  # Load approvals from storage if configured
189
191
  if config.approval_storage:
@@ -514,12 +516,15 @@ async def _run_internal(
514
516
  if len(partial_tool_calls) > 0:
515
517
  message_tool_calls = []
516
518
  for i, tc in enumerate(partial_tool_calls):
519
+ arguments = tc["function"]["arguments"]
520
+ if isinstance(arguments, str):
521
+ arguments = _normalize_tool_call_arguments(arguments)
517
522
  message_tool_calls.append({
518
523
  "id": tc["id"] or f"call_{i}",
519
524
  "type": "function",
520
525
  "function": {
521
526
  "name": tc["function"]["name"] or "",
522
- "arguments": tc["function"]["arguments"]
527
+ "arguments": arguments
523
528
  }
524
529
  })
525
530
 
@@ -532,7 +537,7 @@ async def _run_internal(
532
537
  type="function",
533
538
  function=ToolCallFunction(
534
539
  name=mc["function"]["name"],
535
- arguments=mc["function"]["arguments"],
540
+ arguments=_normalize_tool_call_arguments(mc["function"]["arguments"])
536
541
  ),
537
542
  ) for mc in message_tool_calls
538
543
  ],
@@ -551,12 +556,15 @@ async def _run_internal(
551
556
  if len(partial_tool_calls) > 0:
552
557
  final_tool_calls = []
553
558
  for i, tc in enumerate(partial_tool_calls):
559
+ arguments = tc["function"]["arguments"]
560
+ if isinstance(arguments, str):
561
+ arguments = _normalize_tool_call_arguments(arguments)
554
562
  final_tool_calls.append({
555
563
  "id": tc["id"] or f"call_{i}",
556
564
  "type": "function",
557
565
  "function": {
558
566
  "name": tc["function"]["name"] or "",
559
- "arguments": tc["function"]["arguments"]
567
+ "arguments": arguments
560
568
  }
561
569
  })
562
570
 
@@ -842,12 +850,33 @@ def _convert_tool_calls(tool_calls: Optional[List[Dict[str, Any]]]) -> Optional[
842
850
  type='function',
843
851
  function=ToolCallFunction(
844
852
  name=tc['function']['name'],
845
- arguments=tc['function']['arguments']
853
+ arguments=_normalize_tool_call_arguments(tc['function']['arguments'])
846
854
  )
847
855
  )
848
856
  for tc in tool_calls
849
857
  ]
850
858
 
859
+
860
+ def _normalize_tool_call_arguments(arguments: Any) -> Any:
861
+ """Strip trailing streaming artifacts so arguments remain valid JSON strings."""
862
+ if not arguments or not isinstance(arguments, str):
863
+ return arguments
864
+
865
+ decoder = json.JSONDecoder()
866
+ try:
867
+ obj, end = decoder.raw_decode(arguments)
868
+ except json.JSONDecodeError:
869
+ return arguments
870
+
871
+ remainder = arguments[end:].strip()
872
+ if remainder:
873
+ try:
874
+ return json.dumps(obj)
875
+ except (TypeError, ValueError):
876
+ return arguments
877
+
878
+ return arguments
879
+
851
880
  async def _execute_tool_calls(
852
881
  tool_calls: List[ToolCall],
853
882
  agent: Agent[Ctx, Any],
@@ -863,7 +892,8 @@ async def _execute_tool_calls(
863
892
  tool_name=tool_call.function.name,
864
893
  args=_try_parse_json(tool_call.function.arguments),
865
894
  trace_id=state.trace_id,
866
- run_id=state.run_id
895
+ run_id=state.run_id,
896
+ call_id=tool_call.id
867
897
  ))))
868
898
 
869
899
  try:
@@ -889,7 +919,8 @@ async def _execute_tool_calls(
889
919
  trace_id=state.trace_id,
890
920
  run_id=state.run_id,
891
921
  status='error',
892
- tool_result={'error': 'tool_not_found'}
922
+ tool_result={'error': 'tool_not_found'},
923
+ call_id=tool_call.id
893
924
  ))))
894
925
 
895
926
  return {
@@ -923,7 +954,8 @@ async def _execute_tool_calls(
923
954
  trace_id=state.trace_id,
924
955
  run_id=state.run_id,
925
956
  status='error',
926
- tool_result={'error': 'validation_error', 'details': e.errors()}
957
+ tool_result={'error': 'validation_error', 'details': e.errors()},
958
+ call_id=tool_call.id
927
959
  ))))
928
960
 
929
961
  return {
@@ -1017,7 +1049,7 @@ async def _execute_tool_calls(
1017
1049
  else:
1018
1050
  timeout = None
1019
1051
  if timeout is None:
1020
- timeout = config.default_tool_timeout if config.default_tool_timeout is not None else 30.0
1052
+ timeout = config.default_tool_timeout if config.default_tool_timeout is not None else 300.0
1021
1053
 
1022
1054
  # Merge additional context if provided through approval
1023
1055
  additional_context = approval_status.additional_context if approval_status else None
@@ -1061,7 +1093,8 @@ async def _execute_tool_calls(
1061
1093
  trace_id=state.trace_id,
1062
1094
  run_id=state.run_id,
1063
1095
  status='timeout',
1064
- tool_result={'error': 'timeout_error'}
1096
+ tool_result={'error': 'timeout_error'},
1097
+ call_id=tool_call.id
1065
1098
  ))))
1066
1099
 
1067
1100
  return {
@@ -1113,7 +1146,8 @@ async def _execute_tool_calls(
1113
1146
  trace_id=state.trace_id,
1114
1147
  run_id=state.run_id,
1115
1148
  tool_result=tool_result,
1116
- status='success'
1149
+ status='success',
1150
+ call_id=tool_call.id
1117
1151
  ))))
1118
1152
 
1119
1153
  # Check for handoff
@@ -1151,7 +1185,8 @@ async def _execute_tool_calls(
1151
1185
  trace_id=state.trace_id,
1152
1186
  run_id=state.run_id,
1153
1187
  status='error',
1154
- tool_result={'error': 'execution_error', 'detail': str(error)}
1188
+ tool_result={'error': 'execution_error', 'detail': str(error)},
1189
+ call_id=tool_call.id
1155
1190
  ))))
1156
1191
 
1157
1192
  return {
@@ -0,0 +1,339 @@
1
+ """
2
+ Parallel Agent Execution for JAF Framework.
3
+
4
+ This module provides functionality to execute multiple sub-agents in parallel groups,
5
+ allowing for coordinated parallel execution with configurable grouping and result aggregation.
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ from dataclasses import dataclass
11
+ from typing import Any, Dict, List, Optional, Union, Callable, TypeVar
12
+
13
+ from .types import (
14
+ Agent,
15
+ Tool,
16
+ ToolSchema,
17
+ ToolSource,
18
+ RunConfig,
19
+ RunState,
20
+ RunResult,
21
+ Message,
22
+ ContentRole,
23
+ generate_run_id,
24
+ generate_trace_id,
25
+ )
26
+ from .agent_tool import create_agent_tool, AgentToolInput
27
+
28
+ Ctx = TypeVar('Ctx')
29
+ Out = TypeVar('Out')
30
+
31
+
32
@dataclass
class ParallelAgentGroup:
    """Configuration for a group of agents to be executed in parallel.

    Attributes:
        name: Label for the group; used as the key of this group's results.
        agents: Agents that are run concurrently when the group executes.
        shared_input: Whether all agents receive the same input. Only the
            shared-input mode is currently implemented.
        result_aggregation: How successful results are combined: "combine",
            "first", "majority" or "custom".
        custom_aggregator: Called with the list of successful agent results
            when ``result_aggregation == "custom"``.
        timeout: Optional time limit for the whole group, in seconds.
        metadata: Optional extra data describing the group.
    """
    name: str
    # Agent[Any, Any] instead of the module-level Ctx/Out TypeVars: this class
    # is not Generic, so those TypeVars would be unbound here.
    agents: List[Agent[Any, Any]]
    shared_input: bool = True
    result_aggregation: str = "combine"  # "combine", "first", "majority", "custom"
    custom_aggregator: Optional[Callable[[List[str]], str]] = None
    timeout: Optional[float] = None
    metadata: Optional[Dict[str, Any]] = None
42
+
43
+
44
@dataclass()
class ParallelExecutionConfig:
    """Settings that control how a collection of agent groups is run.

    Attributes:
        groups: The agent groups to execute.
        inter_group_execution: "parallel" runs all groups concurrently; any
            other value (default "sequential") runs them one after another.
        global_timeout: Timeout forwarded to the tool schema, or None.
        preserve_session: Passed to every generated agent tool.
    """

    groups: List[ParallelAgentGroup]
    # Execution mode across groups.
    inter_group_execution: str = "sequential"
    # Overall limit exposed on the tool schema.
    global_timeout: Optional[float] = None
    # Whether sub-agent tools keep their session.
    preserve_session: bool = False
51
+
52
+
53
class ParallelAgentsTool:
    """Tool that executes multiple agent groups in parallel.

    Each group's agents are wrapped with ``create_agent_tool`` and run
    concurrently. The groups themselves run concurrently or one after another,
    depending on ``config.inter_group_execution``. The combined outcome is
    returned as a JSON string.
    """

    def __init__(
        self,
        config: ParallelExecutionConfig,
        tool_name: str = "execute_parallel_agents",
        tool_description: str = "Execute multiple agents in parallel groups"
    ):
        """Store the configuration and build the tool schema.

        Args:
            config: Groups and execution options used on every invocation.
            tool_name: Name published in the tool schema.
            tool_description: Description published in the tool schema.
        """
        self.config = config
        self.tool_name = tool_name
        self.tool_description = tool_description

        # The tool takes the same single-string input as an agent tool; the
        # global timeout is forwarded as the schema's timeout.
        self.schema = ToolSchema(
            name=tool_name,
            description=tool_description,
            parameters=AgentToolInput,
            timeout=config.global_timeout
        )
        self.source = ToolSource.NATIVE
        self.metadata = {"source": "parallel_agents", "groups": len(config.groups)}

    async def execute(self, args: AgentToolInput, context: Ctx) -> str:
        """Execute all configured agent groups on ``args.input``.

        Returns:
            A JSON string whose "parallel_execution_results" object maps each
            group name to that group's result. If anything fails, the string
            is instead a JSON object with "error": "parallel_execution_failed".
        """
        try:
            groups = self.config.groups
            if self.config.inter_group_execution == "parallel":
                group_results = await asyncio.gather(*[
                    self._execute_group(group, args.input, context)
                    for group in groups
                ])
            else:
                # Any mode other than "parallel" runs groups sequentially.
                group_results = []
                for group in groups:
                    group_results.append(
                        await self._execute_group(group, args.input, context)
                    )

            # Disambiguate repeated group names so no group's result is lost.
            group_names = self._unique_names([group.name for group in groups])
            final_result = {
                "parallel_execution_results": dict(zip(group_names, group_results)),
                "execution_mode": self.config.inter_group_execution,
                "total_groups": len(groups)
            }

            return json.dumps(final_result, indent=2)

        except Exception as e:
            return json.dumps({
                "error": "parallel_execution_failed",
                "message": f"Failed to execute parallel agents: {str(e)}",
                "groups_attempted": len(self.config.groups)
            })

    async def _execute_group(
        self,
        group: ParallelAgentGroup,
        input_text: str,
        context: Ctx
    ) -> Dict[str, Any]:
        """Execute a single group of agents in parallel.

        Returns:
            A dict with per-agent results and the aggregated result. On
            timeout or unexpected failure, a dict with an "error" key instead.
        """
        try:
            agent_tools = [
                create_agent_tool(
                    agent=agent,
                    tool_name=f"run_{agent.name.lower().replace(' ', '_')}",
                    tool_description=f"Execute the {agent.name} agent",
                    timeout=group.timeout,
                    preserve_session=self.config.preserve_session
                )
                for agent in group.agents
            ]
            # Repeated agent names would otherwise overwrite each other's entry.
            agent_names = self._unique_names([agent.name for agent in group.agents])

            # Every agent receives the same input. shared_input=False is
            # reserved for per-agent inputs and currently behaves the same.
            tasks = [
                tool.execute(AgentToolInput(input=input_text), context)
                for tool in agent_tools
            ]
            gathered = asyncio.gather(*tasks, return_exceptions=True)

            # A falsy timeout (None or 0) disables the group-level limit.
            if group.timeout:
                results = await asyncio.wait_for(gathered, timeout=group.timeout)
            else:
                results = await gathered

            agent_results: Dict[str, Dict[str, Any]] = {}
            for agent_name, result in zip(agent_names, results):
                # gather(return_exceptions=True) may also hand back
                # BaseException instances such as asyncio.CancelledError.
                if isinstance(result, BaseException):
                    agent_results[agent_name] = {
                        "error": True,
                        "message": str(result),
                        "type": type(result).__name__
                    }
                else:
                    agent_results[agent_name] = {
                        "success": True,
                        "result": result
                    }

            aggregated_result = self._aggregate_results(group, agent_results)

            return {
                "group_name": group.name,
                "agent_count": len(group.agents),
                "individual_results": agent_results,
                "aggregated_result": aggregated_result,
                "execution_time_ms": None  # Could be added with timing
            }

        except asyncio.TimeoutError:
            return {
                "group_name": group.name,
                "error": "timeout",
                "message": f"Group {group.name} execution timed out after {group.timeout} seconds",
                "agent_count": len(group.agents)
            }
        except Exception as e:
            return {
                "group_name": group.name,
                "error": "execution_failed",
                "message": str(e),
                "agent_count": len(group.agents)
            }

    def _aggregate_results(
        self,
        group: ParallelAgentGroup,
        agent_results: Dict[str, Any]
    ) -> Union[str, Dict[str, Any]]:
        """Aggregate results from parallel agent execution.

        Only successful results take part. The group's result_aggregation
        selects the strategy:

        * "first": the first successful result.
        * "majority": the most frequent successful result, provided more than
          half of the group's agents succeeded.
        * "custom": ``group.custom_aggregator(successful_results)``.
        * anything else (including "combine" and "custom" without an
          aggregator): all successful results plus their count.
        """
        successful_results = [
            result["result"] for result in agent_results.values()
            if result.get("success") and "result" in result
        ]

        if not successful_results:
            return {"error": "no_successful_results", "message": "All agents failed"}

        strategy = group.result_aggregation
        if strategy == "first":
            return successful_results[0]
        if strategy == "majority":
            # Quorum: a strict majority of the group's agents must succeed.
            if len(successful_results) >= len(group.agents) // 2 + 1:
                return self._most_common(successful_results)
            return {"error": "no_majority", "results": successful_results}
        if strategy == "custom" and group.custom_aggregator:
            try:
                return group.custom_aggregator(successful_results)
            except Exception as e:
                return {"error": "custom_aggregation_failed", "message": str(e)}

        # "combine", and the fallback for any unrecognised strategy.
        return {
            "combined_results": successful_results,
            "result_count": len(successful_results)
        }

    @staticmethod
    def _most_common(results: List[Any]) -> Any:
        """Return the most frequent result; ties go to the earliest one.

        Falls back to the first result if the results are unhashable.
        """
        from collections import Counter

        try:
            return Counter(results).most_common(1)[0][0]
        except TypeError:
            return results[0]

    @staticmethod
    def _unique_names(names: List[str]) -> List[str]:
        """Return names with repeats suffixed ("name#2", "name#3", ...).

        This makes every entry usable as a distinct dictionary key while
        leaving first occurrences untouched.
        """
        used = set()
        unique = []
        for name in names:
            candidate, suffix = name, 1
            while candidate in used:
                suffix += 1
                candidate = f"{name}#{suffix}"
            used.add(candidate)
            unique.append(candidate)
        return unique
228
+
229
+
230
def create_parallel_agents_tool(
    groups: List[ParallelAgentGroup],
    tool_name: str = "execute_parallel_agents",
    tool_description: str = "Execute multiple agents in parallel groups",
    inter_group_execution: str = "sequential",
    global_timeout: Optional[float] = None,
    preserve_session: bool = False
) -> Tool:
    """
    Build a ParallelAgentsTool around the given agent groups.

    Args:
        groups: Agent groups the tool will run.
        tool_name: Name published in the tool schema.
        tool_description: Description published in the tool schema.
        inter_group_execution: "parallel" to run groups concurrently;
            otherwise groups run one after another ("sequential").
        global_timeout: Timeout forwarded to the tool schema.
        preserve_session: Passed through to every generated agent tool.

    Returns:
        A tool instance that runs the configured groups when executed.
    """
    return ParallelAgentsTool(
        ParallelExecutionConfig(
            groups=groups,
            inter_group_execution=inter_group_execution,
            global_timeout=global_timeout,
            preserve_session=preserve_session,
        ),
        tool_name,
        tool_description,
    )
260
+
261
+
262
def create_simple_parallel_tool(
    agents: List[Agent],
    group_name: str = "parallel_group",
    tool_name: str = "execute_parallel_agents",
    shared_input: bool = True,
    result_aggregation: str = "combine",
    timeout: Optional[float] = None,
    tool_description: str = "Execute multiple agents in parallel groups",
    custom_aggregator: Optional[Callable[[List[str]], str]] = None,
) -> Tool:
    """
    Create a simple parallel agents tool from a list of agents.

    Args:
        agents: List of agents to execute in parallel
        group_name: Name for the parallel group
        tool_name: Name of the tool
        shared_input: Whether all agents receive the same input
        result_aggregation: How to aggregate results ("combine", "first",
            "majority", "custom")
        timeout: Timeout for parallel execution
        tool_description: Description of the tool. Defaults to the same text
            create_parallel_agents_tool uses.
        custom_aggregator: Aggregator used when result_aggregation is
            "custom". Without it, "custom" cannot take effect.

    Returns:
        A Tool that executes all agents in parallel
    """
    group = ParallelAgentGroup(
        name=group_name,
        agents=agents,
        shared_input=shared_input,
        result_aggregation=result_aggregation,
        custom_aggregator=custom_aggregator,
        timeout=timeout
    )

    return create_parallel_agents_tool(
        [group],
        tool_name=tool_name,
        tool_description=tool_description
    )
293
+
294
+
295
+ # Convenience functions for common parallel execution patterns
296
+
297
def create_language_specialists_tool(
    language_agents: Dict[str, Agent],
    tool_name: str = "consult_language_specialists",
    timeout: Optional[float] = 300.0
) -> Tool:
    """Create a tool that consults multiple language specialists in parallel.

    The mapping's keys are recorded as the group's "languages" metadata, and
    its values are the agents that all receive the same input. Their results
    are combined.
    """
    languages = list(language_agents)
    specialists = list(language_agents.values())

    specialists_group = ParallelAgentGroup(
        name="language_specialists",
        agents=specialists,
        shared_input=True,
        result_aggregation="combine",
        timeout=timeout,
        metadata={"languages": languages},
    )

    description = "Consult multiple language specialists in parallel"
    return create_parallel_agents_tool(
        [specialists_group],
        tool_name=tool_name,
        tool_description=description,
    )
317
+
318
+
319
def create_domain_experts_tool(
    expert_agents: Dict[str, Agent],
    tool_name: str = "consult_domain_experts",
    result_aggregation: str = "combine",
    timeout: Optional[float] = 60.0
) -> Tool:
    """Create a tool that consults multiple domain experts in parallel.

    The mapping's keys are recorded as the group's "domains" metadata, and its
    values are the agents that all receive the same input. Their results are
    merged using ``result_aggregation``.
    """
    domains = list(expert_agents)
    experts = list(expert_agents.values())

    experts_group = ParallelAgentGroup(
        name="domain_experts",
        agents=experts,
        shared_input=True,
        result_aggregation=result_aggregation,
        timeout=timeout,
        metadata={"domains": domains},
    )

    description = "Consult multiple domain experts in parallel"
    return create_parallel_agents_tool(
        [experts_group],
        tool_name=tool_name,
        tool_description=description,
    )
@@ -209,20 +209,37 @@ async def run_streaming(
209
209
  trace_id=initial_state.trace_id
210
210
  )
211
211
 
212
- tool_call_ids = {} # To map tool calls to their IDs
212
+ tool_call_ids: Dict[str, str] = {} # Map call_id -> tool_name for in-flight tool calls
213
213
 
214
214
  def event_handler(event: TraceEvent) -> None:
215
215
  """Handle trace events and put them into the queue."""
216
216
  nonlocal tool_call_ids
217
217
  streaming_event = None
218
+ payload = event.data
219
+
220
+ def _get_event_value(keys: List[str]) -> Any:
221
+ for key in keys:
222
+ if isinstance(payload, dict) and key in payload:
223
+ return payload[key]
224
+ if hasattr(payload, key):
225
+ return getattr(payload, key)
226
+ return None
227
+
218
228
  if event.type == 'tool_call_start':
219
- # Generate a unique ID for the tool call
220
- call_id = f"call_{uuid.uuid4().hex[:8]}"
221
- tool_call_ids[event.data.tool_name] = call_id
222
-
229
+ tool_name = _get_event_value(['tool_name', 'toolName']) or 'unknown'
230
+ args = _get_event_value(['args', 'arguments'])
231
+ call_id = _get_event_value(['call_id', 'tool_call_id', 'toolCallId'])
232
+
233
+ if not call_id:
234
+ call_id = f"call_{uuid.uuid4().hex[:8]}"
235
+ if isinstance(payload, dict):
236
+ payload['call_id'] = call_id
237
+
238
+ tool_call_ids[call_id] = tool_name
239
+
223
240
  tool_call = StreamingToolCall(
224
- tool_name=event.data.tool_name,
225
- arguments=event.data.args,
241
+ tool_name=tool_name,
242
+ arguments=args,
226
243
  call_id=call_id,
227
244
  status='started'
228
245
  )
@@ -233,18 +250,26 @@ async def run_streaming(
233
250
  trace_id=initial_state.trace_id
234
251
  )
235
252
  elif event.type == 'tool_call_end':
236
- if event.data.tool_name not in tool_call_ids:
237
- raise RuntimeError(
238
- f"Tool call end event received for unknown tool '{event.data.tool_name}'. "
239
- f"Known tool calls: {list(tool_call_ids.keys())}. "
240
- f"This may indicate a missing tool_call_start event or a bug in the streaming implementation."
241
- )
242
- call_id = tool_call_ids[event.data.tool_name]
253
+ tool_name = _get_event_value(['tool_name', 'toolName']) or 'unknown'
254
+ call_id = _get_event_value(['call_id', 'tool_call_id', 'toolCallId'])
255
+
256
+ if not call_id:
257
+ # Fallback to locate a pending tool call with the same tool name
258
+ matching_call_id = next((cid for cid, name in tool_call_ids.items() if name == tool_name), None)
259
+ if matching_call_id:
260
+ call_id = matching_call_id
261
+ else:
262
+ raise RuntimeError(
263
+ f"Tool call end event received for unknown tool '{tool_name}'. "
264
+ f"Pending call IDs: {list(tool_call_ids.keys())}."
265
+ )
266
+
267
+ tool_call_ids.pop(call_id, None)
243
268
  tool_result = StreamingToolResult(
244
- tool_name=event.data.tool_name,
269
+ tool_name=tool_name,
245
270
  call_id=call_id,
246
- result=event.data.result,
247
- status=event.data.status or 'completed'
271
+ result=_get_event_value(['result']),
272
+ status=_get_event_value(['status']) or 'completed'
248
273
  )
249
274
  streaming_event = StreamingEvent(
250
275
  type=StreamingEventType.TOOL_RESULT,