jaf-py 2.4.4__py3-none-any.whl → 2.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaf/core/__init__.py +14 -0
- jaf/core/engine.py +213 -76
- jaf/core/guardrails.py +666 -0
- jaf/core/parallel_agents.py +339 -0
- jaf/core/streaming.py +42 -17
- jaf/core/tracing.py +22 -9
- jaf/core/types.py +89 -5
- jaf/server/server.py +2 -2
- {jaf_py-2.4.4.dist-info → jaf_py-2.4.6.dist-info}/METADATA +1 -1
- {jaf_py-2.4.4.dist-info → jaf_py-2.4.6.dist-info}/RECORD +14 -12
- {jaf_py-2.4.4.dist-info → jaf_py-2.4.6.dist-info}/WHEEL +0 -0
- {jaf_py-2.4.4.dist-info → jaf_py-2.4.6.dist-info}/entry_points.txt +0 -0
- {jaf_py-2.4.4.dist-info → jaf_py-2.4.6.dist-info}/licenses/LICENSE +0 -0
- {jaf_py-2.4.4.dist-info → jaf_py-2.4.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Parallel Agent Execution for JAF Framework.
|
|
3
|
+
|
|
4
|
+
This module provides functionality to execute multiple sub-agents in parallel groups,
|
|
5
|
+
allowing for coordinated parallel execution with configurable grouping and result aggregation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import asyncio
import json
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union

from .agent_tool import AgentToolInput, create_agent_tool
from .types import (
    Agent,
    ContentRole,
    Message,
    RunConfig,
    RunResult,
    RunState,
    Tool,
    ToolSchema,
    ToolSource,
    generate_run_id,
    generate_trace_id,
)
|
|
27
|
+
|
|
28
|
+
Ctx = TypeVar('Ctx')
|
|
29
|
+
Out = TypeVar('Out')
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class ParallelAgentGroup:
    """Configuration for a group of agents to be executed in parallel.

    A group is the unit of concurrent execution: every agent in ``agents``
    runs at the same time, and their results are merged according to
    ``result_aggregation``.
    """
    # Human-readable identifier; used as the key in the result payload.
    name: str
    # Agents that run concurrently within this group.
    agents: List[Agent[Ctx, Out]]
    shared_input: bool = True  # Whether all agents receive the same input
    result_aggregation: str = "combine"  # "combine", "first", "majority", "custom"
    # Only consulted when result_aggregation == "custom"; receives the list of
    # successful agent results and returns the aggregated value.
    custom_aggregator: Optional[Callable[[List[str]], str]] = None
    # Per-group timeout in seconds for the parallel run; None means no timeout.
    timeout: Optional[float] = None
    # Caller-supplied metadata; not interpreted by the execution machinery.
    metadata: Optional[Dict[str, Any]] = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class ParallelExecutionConfig:
    """Configuration for parallel agent execution.

    Describes how a set of :class:`ParallelAgentGroup` instances are executed
    relative to one another.
    """
    # The agent groups to execute.
    groups: List[ParallelAgentGroup]
    inter_group_execution: str = "sequential"  # "sequential" or "parallel"
    # Overall timeout in seconds recorded on the tool schema; None = unlimited.
    global_timeout: Optional[float] = None
    # Forwarded to create_agent_tool for every sub-agent invocation.
    preserve_session: bool = False
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ParallelAgentsTool:
    """Tool that executes multiple agent groups in parallel.

    Groups are run sequentially or concurrently depending on
    ``config.inter_group_execution``; within a group every agent always runs
    concurrently via ``asyncio.gather``. The tool's ``execute`` method returns
    a JSON document describing per-agent and aggregated results and never
    raises — failures are reported inside the JSON payload.
    """

    def __init__(
        self,
        config: ParallelExecutionConfig,
        tool_name: str = "execute_parallel_agents",
        tool_description: str = "Execute multiple agents in parallel groups"
    ):
        """Initialize the tool.

        Args:
            config: Grouping, timeout, and aggregation configuration.
            tool_name: Name exposed to the model for this tool.
            tool_description: Description exposed to the model.
        """
        self.config = config
        self.tool_name = tool_name
        self.tool_description = tool_description

        # Reuses AgentToolInput (a single "input" string) as the parameter
        # schema, so this tool is invoked exactly like a single agent tool.
        self.schema = ToolSchema(
            name=tool_name,
            description=tool_description,
            parameters=AgentToolInput,
            timeout=config.global_timeout
        )
        self.source = ToolSource.NATIVE
        self.metadata = {"source": "parallel_agents", "groups": len(config.groups)}

    async def execute(self, args: AgentToolInput, context: Ctx) -> str:
        """Execute all configured agent groups and return a JSON report.

        Args:
            args: Tool input; ``args.input`` is forwarded to every group.
            context: Caller context passed through to each sub-agent tool.

        Returns:
            JSON string with per-group results keyed by group name, or a JSON
            error object if orchestration itself fails.
        """
        try:
            if self.config.inter_group_execution == "parallel":
                # All groups run concurrently. _execute_group catches its own
                # exceptions, so gather does not need return_exceptions here.
                group_results = await asyncio.gather(*[
                    self._execute_group(group, args.input, context)
                    for group in self.config.groups
                ])
            else:
                # Default: groups run one after another, in declaration order.
                group_results = []
                for group in self.config.groups:
                    result = await self._execute_group(group, args.input, context)
                    group_results.append(result)

            # Combine results from all groups, keyed by group name.
            final_result = {
                "parallel_execution_results": {
                    group.name: result for group, result in zip(self.config.groups, group_results)
                },
                "execution_mode": self.config.inter_group_execution,
                "total_groups": len(self.config.groups)
            }

            return json.dumps(final_result, indent=2)

        except Exception as e:
            # Last-resort guard: report the failure as JSON instead of raising.
            return json.dumps({
                "error": "parallel_execution_failed",
                "message": f"Failed to execute parallel agents: {str(e)}",
                "groups_attempted": len(self.config.groups)
            })

    async def _execute_group(
        self,
        group: ParallelAgentGroup,
        input_text: str,
        context: Ctx
    ) -> Dict[str, Any]:
        """Execute a single group of agents in parallel.

        Args:
            group: The group configuration to run.
            input_text: Input forwarded to every agent in the group.
            context: Caller context passed to each sub-agent tool.

        Returns:
            A dict describing the group outcome (individual + aggregated
            results, or a timeout/error record). Never raises.
        """
        started_at = time.perf_counter()
        try:
            # Wrap each agent as a callable tool; name kept alongside so
            # results can be keyed per agent afterwards.
            agent_tools = []
            for agent in group.agents:
                tool = create_agent_tool(
                    agent=agent,
                    tool_name=f"run_{agent.name.lower().replace(' ', '_')}",
                    tool_description=f"Execute the {agent.name} agent",
                    timeout=group.timeout,
                    preserve_session=self.config.preserve_session
                )
                agent_tools.append((agent.name, tool))

            # Every agent currently receives the same input. Per-agent inputs
            # (group.shared_input == False) are not yet differentiated; this
            # is the extension point for that feature.
            tasks = [
                tool.execute(AgentToolInput(input=input_text), context)
                for _, tool in agent_tools
            ]

            # Execute with timeout if specified; on timeout the pending
            # sub-tasks are cancelled by wait_for.
            if group.timeout:
                results = await asyncio.wait_for(
                    asyncio.gather(*tasks, return_exceptions=True),
                    timeout=group.timeout
                )
            else:
                results = await asyncio.gather(*tasks, return_exceptions=True)

            # Map each agent name to either its result or an error record.
            # NOTE(review): agents sharing the same name overwrite each other
            # here — assumes names are unique within a group.
            agent_results = {}
            for (agent_name, _), result in zip(agent_tools, results):
                if isinstance(result, Exception):
                    agent_results[agent_name] = {
                        "error": True,
                        "message": str(result),
                        "type": type(result).__name__
                    }
                else:
                    agent_results[agent_name] = {
                        "success": True,
                        "result": result
                    }

            # Apply result aggregation.
            aggregated_result = self._aggregate_results(group, agent_results)

            return {
                "group_name": group.name,
                "agent_count": len(group.agents),
                "individual_results": agent_results,
                "aggregated_result": aggregated_result,
                # Wall-clock duration of the whole group run (was always None).
                "execution_time_ms": (time.perf_counter() - started_at) * 1000.0
            }

        except asyncio.TimeoutError:
            return {
                "group_name": group.name,
                "error": "timeout",
                "message": f"Group {group.name} execution timed out after {group.timeout} seconds",
                "agent_count": len(group.agents)
            }
        except Exception as e:
            return {
                "group_name": group.name,
                "error": "execution_failed",
                "message": str(e),
                "agent_count": len(group.agents)
            }

    def _aggregate_results(
        self,
        group: ParallelAgentGroup,
        agent_results: Dict[str, Any]
    ) -> Union[str, Dict[str, Any]]:
        """Aggregate results from parallel agent execution.

        Args:
            group: Group configuration (selects the aggregation strategy).
            agent_results: Per-agent result records from ``_execute_group``.

        Returns:
            The aggregated value — a single result ("first"/"majority"), a
            dict of combined results, or an error dict when nothing succeeded
            or aggregation itself fails.
        """
        successful_results = [
            result["result"] for result in agent_results.values()
            if result.get("success") and "result" in result
        ]

        if not successful_results:
            return {"error": "no_successful_results", "message": "All agents failed"}

        if group.result_aggregation == "first":
            return successful_results[0]
        elif group.result_aggregation == "combine":
            return {
                "combined_results": successful_results,
                "result_count": len(successful_results)
            }
        elif group.result_aggregation == "majority":
            # Simple majority logic: require a strict majority of successes,
            # then return the first success as the representative value.
            if len(successful_results) >= len(group.agents) // 2 + 1:
                return successful_results[0]
            else:
                return {"error": "no_majority", "results": successful_results}
        elif group.result_aggregation == "custom" and group.custom_aggregator:
            try:
                return group.custom_aggregator(successful_results)
            except Exception as e:
                return {"error": "custom_aggregation_failed", "message": str(e)}
        else:
            # Unknown strategy, or "custom" without an aggregator: fall back
            # to combining whatever succeeded.
            return {"combined_results": successful_results}
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
def create_parallel_agents_tool(
    groups: List[ParallelAgentGroup],
    tool_name: str = "execute_parallel_agents",
    tool_description: str = "Execute multiple agents in parallel groups",
    inter_group_execution: str = "sequential",
    global_timeout: Optional[float] = None,
    preserve_session: bool = False
) -> Tool:
    """Build a tool that executes multiple agent groups in parallel.

    Args:
        groups: Parallel agent groups to execute.
        tool_name: Name of the tool.
        tool_description: Description of the tool.
        inter_group_execution: How groups run relative to each other
            ("sequential" or "parallel").
        global_timeout: Global timeout for all executions.
        preserve_session: Whether to preserve session across agent calls.

    Returns:
        A Tool that can execute parallel agent groups.
    """
    execution_config = ParallelExecutionConfig(
        groups=groups,
        inter_group_execution=inter_group_execution,
        global_timeout=global_timeout,
        preserve_session=preserve_session,
    )
    return ParallelAgentsTool(execution_config, tool_name, tool_description)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def create_simple_parallel_tool(
    agents: List[Agent],
    group_name: str = "parallel_group",
    tool_name: str = "execute_parallel_agents",
    shared_input: bool = True,
    result_aggregation: str = "combine",
    timeout: Optional[float] = None
) -> Tool:
    """Create a simple parallel agents tool from a flat list of agents.

    All agents are placed into a single group and executed concurrently
    whenever the resulting tool is invoked.

    Args:
        agents: Agents to execute in parallel.
        group_name: Name for the single parallel group.
        tool_name: Name of the tool.
        shared_input: Whether all agents receive the same input.
        result_aggregation: How to aggregate results
            ("combine", "first", "majority").
        timeout: Timeout in seconds for the parallel execution.

    Returns:
        A Tool that executes all agents in parallel.
    """
    single_group = ParallelAgentGroup(
        name=group_name,
        agents=agents,
        shared_input=shared_input,
        result_aggregation=result_aggregation,
        timeout=timeout,
    )
    return create_parallel_agents_tool([single_group], tool_name=tool_name)
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
# Convenience functions for common parallel execution patterns
|
|
296
|
+
|
|
297
|
+
def create_language_specialists_tool(
    language_agents: Dict[str, Agent],
    tool_name: str = "consult_language_specialists",
    timeout: Optional[float] = 300.0
) -> Tool:
    """Create a tool that consults multiple language specialists in parallel.

    Args:
        language_agents: Mapping of language name -> specialist agent; the
            language names are recorded in the group metadata.
        tool_name: Name of the resulting tool.
        timeout: Per-group timeout in seconds.

    Returns:
        A Tool that runs every specialist concurrently and combines results.
    """
    specialists_group = ParallelAgentGroup(
        name="language_specialists",
        agents=list(language_agents.values()),
        shared_input=True,
        result_aggregation="combine",
        timeout=timeout,
        metadata={"languages": list(language_agents.keys())},
    )
    return create_parallel_agents_tool(
        [specialists_group],
        tool_name=tool_name,
        tool_description="Consult multiple language specialists in parallel",
    )
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def create_domain_experts_tool(
    expert_agents: Dict[str, Agent],
    tool_name: str = "consult_domain_experts",
    result_aggregation: str = "combine",
    timeout: Optional[float] = 60.0
) -> Tool:
    """Create a tool that consults multiple domain experts in parallel.

    Args:
        expert_agents: Mapping of domain name -> expert agent; the domain
            names are recorded in the group metadata.
        tool_name: Name of the resulting tool.
        result_aggregation: Strategy for merging expert results.
        timeout: Per-group timeout in seconds.

    Returns:
        A Tool that runs every expert concurrently.
    """
    experts_group = ParallelAgentGroup(
        name="domain_experts",
        agents=list(expert_agents.values()),
        shared_input=True,
        result_aggregation=result_aggregation,
        timeout=timeout,
        metadata={"domains": list(expert_agents.keys())},
    )
    return create_parallel_agents_tool(
        [experts_group],
        tool_name=tool_name,
        tool_description="Consult multiple domain experts in parallel",
    )
|
jaf/core/streaming.py
CHANGED
|
@@ -209,20 +209,37 @@ async def run_streaming(
|
|
|
209
209
|
trace_id=initial_state.trace_id
|
|
210
210
|
)
|
|
211
211
|
|
|
212
|
-
tool_call_ids = {}
|
|
212
|
+
tool_call_ids: Dict[str, str] = {} # Map call_id -> tool_name for in-flight tool calls
|
|
213
213
|
|
|
214
214
|
def event_handler(event: TraceEvent) -> None:
|
|
215
215
|
"""Handle trace events and put them into the queue."""
|
|
216
216
|
nonlocal tool_call_ids
|
|
217
217
|
streaming_event = None
|
|
218
|
+
payload = event.data
|
|
219
|
+
|
|
220
|
+
def _get_event_value(keys: List[str]) -> Any:
|
|
221
|
+
for key in keys:
|
|
222
|
+
if isinstance(payload, dict) and key in payload:
|
|
223
|
+
return payload[key]
|
|
224
|
+
if hasattr(payload, key):
|
|
225
|
+
return getattr(payload, key)
|
|
226
|
+
return None
|
|
227
|
+
|
|
218
228
|
if event.type == 'tool_call_start':
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
229
|
+
tool_name = _get_event_value(['tool_name', 'toolName']) or 'unknown'
|
|
230
|
+
args = _get_event_value(['args', 'arguments'])
|
|
231
|
+
call_id = _get_event_value(['call_id', 'tool_call_id', 'toolCallId'])
|
|
232
|
+
|
|
233
|
+
if not call_id:
|
|
234
|
+
call_id = f"call_{uuid.uuid4().hex[:8]}"
|
|
235
|
+
if isinstance(payload, dict):
|
|
236
|
+
payload['call_id'] = call_id
|
|
237
|
+
|
|
238
|
+
tool_call_ids[call_id] = tool_name
|
|
239
|
+
|
|
223
240
|
tool_call = StreamingToolCall(
|
|
224
|
-
tool_name=
|
|
225
|
-
arguments=
|
|
241
|
+
tool_name=tool_name,
|
|
242
|
+
arguments=args,
|
|
226
243
|
call_id=call_id,
|
|
227
244
|
status='started'
|
|
228
245
|
)
|
|
@@ -233,18 +250,26 @@ async def run_streaming(
|
|
|
233
250
|
trace_id=initial_state.trace_id
|
|
234
251
|
)
|
|
235
252
|
elif event.type == 'tool_call_end':
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
)
|
|
242
|
-
|
|
253
|
+
tool_name = _get_event_value(['tool_name', 'toolName']) or 'unknown'
|
|
254
|
+
call_id = _get_event_value(['call_id', 'tool_call_id', 'toolCallId'])
|
|
255
|
+
|
|
256
|
+
if not call_id:
|
|
257
|
+
# Fallback to locate a pending tool call with the same tool name
|
|
258
|
+
matching_call_id = next((cid for cid, name in tool_call_ids.items() if name == tool_name), None)
|
|
259
|
+
if matching_call_id:
|
|
260
|
+
call_id = matching_call_id
|
|
261
|
+
else:
|
|
262
|
+
raise RuntimeError(
|
|
263
|
+
f"Tool call end event received for unknown tool '{tool_name}'. "
|
|
264
|
+
f"Pending call IDs: {list(tool_call_ids.keys())}."
|
|
265
|
+
)
|
|
266
|
+
|
|
267
|
+
tool_call_ids.pop(call_id, None)
|
|
243
268
|
tool_result = StreamingToolResult(
|
|
244
|
-
tool_name=
|
|
269
|
+
tool_name=tool_name,
|
|
245
270
|
call_id=call_id,
|
|
246
|
-
result=
|
|
247
|
-
status=
|
|
271
|
+
result=_get_event_value(['result']),
|
|
272
|
+
status=_get_event_value(['status']) or 'completed'
|
|
248
273
|
)
|
|
249
274
|
streaming_event = StreamingEvent(
|
|
250
275
|
type=StreamingEventType.TOOL_RESULT,
|
jaf/core/tracing.py
CHANGED
|
@@ -10,6 +10,7 @@ import json
|
|
|
10
10
|
import time
|
|
11
11
|
from datetime import datetime
|
|
12
12
|
from typing import Any, Dict, List, Optional, Protocol
|
|
13
|
+
import uuid
|
|
13
14
|
|
|
14
15
|
from opentelemetry import trace
|
|
15
16
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
|
@@ -652,28 +653,36 @@ class LangfuseTraceCollector:
|
|
|
652
653
|
# Start a span for tool calls with detailed input information
|
|
653
654
|
tool_name = event.data.get('tool_name', 'unknown')
|
|
654
655
|
tool_args = event.data.get("args", {})
|
|
656
|
+
call_id = event.data.get("call_id")
|
|
657
|
+
if not call_id:
|
|
658
|
+
call_id = f"{tool_name}-{uuid.uuid4().hex[:8]}"
|
|
659
|
+
try:
|
|
660
|
+
event.data["call_id"] = call_id
|
|
661
|
+
except TypeError:
|
|
662
|
+
# event.data may be immutable; log and rely on synthetic ID tracking downstream
|
|
663
|
+
print(f"[LANGFUSE] Generated synthetic call_id for tool start: {call_id}")
|
|
655
664
|
|
|
656
|
-
print(f"[LANGFUSE] Starting span for tool call: {tool_name}")
|
|
665
|
+
print(f"[LANGFUSE] Starting span for tool call: {tool_name} ({call_id})")
|
|
657
666
|
|
|
658
667
|
# Track this tool call for the trace
|
|
659
668
|
tool_call_data = {
|
|
660
669
|
"tool_name": tool_name,
|
|
661
670
|
"arguments": tool_args,
|
|
662
|
-
"call_id":
|
|
671
|
+
"call_id": call_id,
|
|
663
672
|
"timestamp": datetime.now().isoformat()
|
|
664
673
|
}
|
|
665
674
|
|
|
666
675
|
# Ensure trace_id exists in tracking
|
|
667
676
|
if trace_id not in self.trace_tool_calls:
|
|
668
677
|
self.trace_tool_calls[trace_id] = []
|
|
669
|
-
|
|
678
|
+
|
|
670
679
|
self.trace_tool_calls[trace_id].append(tool_call_data)
|
|
671
680
|
|
|
672
681
|
# Create comprehensive input data for the tool call
|
|
673
682
|
tool_input = {
|
|
674
683
|
"tool_name": tool_name,
|
|
675
684
|
"arguments": tool_args,
|
|
676
|
-
"call_id":
|
|
685
|
+
"call_id": call_id,
|
|
677
686
|
"timestamp": datetime.now().isoformat()
|
|
678
687
|
}
|
|
679
688
|
|
|
@@ -682,7 +691,7 @@ class LangfuseTraceCollector:
|
|
|
682
691
|
input=tool_input,
|
|
683
692
|
metadata={
|
|
684
693
|
"tool_name": tool_name,
|
|
685
|
-
"call_id":
|
|
694
|
+
"call_id": call_id,
|
|
686
695
|
"framework": "jaf",
|
|
687
696
|
"event_type": "tool_call"
|
|
688
697
|
}
|
|
@@ -696,14 +705,15 @@ class LangfuseTraceCollector:
|
|
|
696
705
|
if span_id in self.active_spans:
|
|
697
706
|
tool_name = event.data.get('tool_name', 'unknown')
|
|
698
707
|
tool_result = event.data.get("result")
|
|
708
|
+
call_id = event.data.get("call_id")
|
|
699
709
|
|
|
700
|
-
print(f"[LANGFUSE] Ending span for tool call: {tool_name}")
|
|
710
|
+
print(f"[LANGFUSE] Ending span for tool call: {tool_name} ({call_id})")
|
|
701
711
|
|
|
702
712
|
# Track this tool result for the trace
|
|
703
713
|
tool_result_data = {
|
|
704
714
|
"tool_name": tool_name,
|
|
705
715
|
"result": tool_result,
|
|
706
|
-
"call_id":
|
|
716
|
+
"call_id": call_id,
|
|
707
717
|
"timestamp": datetime.now().isoformat(),
|
|
708
718
|
"status": event.data.get("status", "completed"),
|
|
709
719
|
"tool_result": event.data.get("tool_result")
|
|
@@ -718,7 +728,7 @@ class LangfuseTraceCollector:
|
|
|
718
728
|
tool_output = {
|
|
719
729
|
"tool_name": tool_name,
|
|
720
730
|
"result": tool_result,
|
|
721
|
-
"call_id":
|
|
731
|
+
"call_id": call_id,
|
|
722
732
|
"timestamp": datetime.now().isoformat(),
|
|
723
733
|
"status": event.data.get("status", "completed")
|
|
724
734
|
}
|
|
@@ -729,7 +739,7 @@ class LangfuseTraceCollector:
|
|
|
729
739
|
output=tool_output,
|
|
730
740
|
metadata={
|
|
731
741
|
"tool_name": tool_name,
|
|
732
|
-
"call_id":
|
|
742
|
+
"call_id": call_id,
|
|
733
743
|
"result_length": len(str(tool_result)) if tool_result else 0,
|
|
734
744
|
"framework": "jaf",
|
|
735
745
|
"event_type": "tool_call_end"
|
|
@@ -791,6 +801,9 @@ class LangfuseTraceCollector:
|
|
|
791
801
|
|
|
792
802
|
# Use consistent identifiers that don't depend on timestamp
|
|
793
803
|
if event.type.startswith('tool_call'):
|
|
804
|
+
call_id = event.data.get('call_id') or event.data.get('tool_call_id')
|
|
805
|
+
if call_id:
|
|
806
|
+
return f"tool-{trace_id}-{call_id}"
|
|
794
807
|
tool_name = event.data.get('tool_name') or event.data.get('toolName', 'unknown')
|
|
795
808
|
return f"tool-{tool_name}-{trace_id}"
|
|
796
809
|
elif event.type.startswith('llm_call'):
|
jaf/core/types.py
CHANGED
|
@@ -288,6 +288,7 @@ class Agent(Generic[Ctx, Out]):
|
|
|
288
288
|
output_codec: Optional[Any] = None # Type that can validate Out (like Pydantic model or Zod equivalent)
|
|
289
289
|
handoffs: Optional[List[str]] = None
|
|
290
290
|
model_config: Optional[ModelConfig] = None
|
|
291
|
+
advanced_config: Optional['AdvancedConfig'] = None
|
|
291
292
|
|
|
292
293
|
def as_tool(
|
|
293
294
|
self,
|
|
@@ -331,6 +332,74 @@ class Agent(Generic[Ctx, Out]):
|
|
|
331
332
|
# Guardrail type
|
|
332
333
|
Guardrail = Callable[[Any], Union[ValidationResult, Awaitable[ValidationResult]]]
|
|
333
334
|
|
|
335
|
+
@dataclass(frozen=True)
class AdvancedGuardrailsConfig:
    """Configuration for advanced guardrails with LLM-based validation."""
    # Prompt used by the LLM check that validates incoming input; None disables it.
    input_prompt: Optional[str] = None
    # Prompt used by the LLM check that validates final output; None disables it.
    output_prompt: Optional[str] = None
    # Whether citations are required in outputs.
    require_citations: bool = False
    # Optional cheaper/faster model name to run guardrail checks with.
    fast_model: Optional[str] = None
    # Behavior when a guardrail check itself errors: 'allow' lets the content
    # through, 'block' rejects it.
    fail_safe: Literal['allow', 'block'] = 'allow'
    # Run the guardrail checks concurrently or one after another.
    execution_mode: Literal['parallel', 'sequential'] = 'parallel'
    # Guardrail timeout in milliseconds; values below 1000 are raised to 1000.
    timeout_ms: int = 30000

    def __post_init__(self):
        """Validate configuration: clamp timeout_ms to a 1-second minimum."""
        if self.timeout_ms < 1000:
            # object.__setattr__ is needed because the dataclass is frozen.
            object.__setattr__(self, 'timeout_ms', 1000)
|
|
350
|
+
|
|
351
|
+
@dataclass(frozen=True)
class AdvancedConfig:
    """Advanced agent configuration including guardrails."""
    # LLM-based guardrail settings; None means no advanced guardrails.
    guardrails: Optional[AdvancedGuardrailsConfig] = None
|
|
355
|
+
|
|
356
|
+
def validate_guardrails_config(config: Optional[AdvancedGuardrailsConfig]) -> AdvancedGuardrailsConfig:
    """Validate and provide defaults for guardrails configuration.

    Non-empty string fields are passed through ``str.strip``; empty or
    non-string values are normalized to None. ``timeout_ms`` is clamped to a
    minimum of 1000 ms. A None config yields the default configuration.
    """
    if config is None:
        return AdvancedGuardrailsConfig()

    def _clean(value):
        # Strip truthy strings; anything else (None, empty, non-str) -> None.
        if isinstance(value, str) and value:
            return value.strip()
        return None

    return AdvancedGuardrailsConfig(
        input_prompt=_clean(config.input_prompt),
        output_prompt=_clean(config.output_prompt),
        require_citations=config.require_citations,
        fast_model=_clean(config.fast_model),
        fail_safe=config.fail_safe,
        execution_mode=config.execution_mode,
        timeout_ms=max(1000, config.timeout_ms)
    )
|
|
370
|
+
|
|
371
|
+
def json_parse_llm_output(text: str) -> Optional[Dict[str, Any]]:
|
|
372
|
+
"""Parse JSON from LLM output, handling common formatting issues."""
|
|
373
|
+
import json
|
|
374
|
+
import re
|
|
375
|
+
|
|
376
|
+
if not text:
|
|
377
|
+
return None
|
|
378
|
+
|
|
379
|
+
# Try direct parsing first
|
|
380
|
+
try:
|
|
381
|
+
return json.loads(text)
|
|
382
|
+
except json.JSONDecodeError:
|
|
383
|
+
pass
|
|
384
|
+
|
|
385
|
+
# Try to extract JSON from markdown code blocks
|
|
386
|
+
json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
|
|
387
|
+
if json_match:
|
|
388
|
+
try:
|
|
389
|
+
return json.loads(json_match.group(1))
|
|
390
|
+
except json.JSONDecodeError:
|
|
391
|
+
pass
|
|
392
|
+
|
|
393
|
+
# Try to find the first JSON object in the text
|
|
394
|
+
json_match = re.search(r'\{.*?\}', text, re.DOTALL)
|
|
395
|
+
if json_match:
|
|
396
|
+
try:
|
|
397
|
+
return json.loads(json_match.group(0))
|
|
398
|
+
except json.JSONDecodeError:
|
|
399
|
+
pass
|
|
400
|
+
|
|
401
|
+
return None
|
|
402
|
+
|
|
334
403
|
@dataclass(frozen=True)
|
|
335
404
|
class ApprovalValue:
|
|
336
405
|
"""Represents an approval decision with context."""
|
|
@@ -541,11 +610,12 @@ class ToolCallStartEventData:
|
|
|
541
610
|
args: Any
|
|
542
611
|
trace_id: TraceId
|
|
543
612
|
run_id: RunId
|
|
613
|
+
call_id: Optional[str] = None
|
|
544
614
|
|
|
545
615
|
@dataclass(frozen=True)
|
|
546
616
|
class ToolCallStartEvent:
|
|
547
617
|
type: Literal['tool_call_start'] = 'tool_call_start'
|
|
548
|
-
data: ToolCallStartEventData = field(default_factory=lambda: ToolCallStartEventData("", None, TraceId(""), RunId("")))
|
|
618
|
+
data: ToolCallStartEventData = field(default_factory=lambda: ToolCallStartEventData("", None, TraceId(""), RunId(""), None))
|
|
549
619
|
|
|
550
620
|
@dataclass(frozen=True)
|
|
551
621
|
class ToolCallEndEventData:
|
|
@@ -556,11 +626,12 @@ class ToolCallEndEventData:
|
|
|
556
626
|
run_id: RunId
|
|
557
627
|
tool_result: Optional[Any] = None
|
|
558
628
|
status: Optional[str] = None
|
|
629
|
+
call_id: Optional[str] = None
|
|
559
630
|
|
|
560
631
|
@dataclass(frozen=True)
|
|
561
632
|
class ToolCallEndEvent:
|
|
562
633
|
type: Literal['tool_call_end'] = 'tool_call_end'
|
|
563
|
-
data: ToolCallEndEventData = field(default_factory=lambda: ToolCallEndEventData("", "", TraceId(""), RunId("")))
|
|
634
|
+
data: ToolCallEndEventData = field(default_factory=lambda: ToolCallEndEventData("", "", TraceId(""), RunId(""), None, None))
|
|
564
635
|
|
|
565
636
|
@dataclass(frozen=True)
|
|
566
637
|
class HandoffEventData:
|
|
@@ -598,6 +669,17 @@ class GuardrailEvent:
|
|
|
598
669
|
type: Literal['guardrail_check'] = 'guardrail_check'
|
|
599
670
|
data: GuardrailEventData = field(default_factory=lambda: GuardrailEventData(""))
|
|
600
671
|
|
|
672
|
+
@dataclass(frozen=True)
class GuardrailViolationEventData:
    """Data for guardrail violation events."""
    # Which guardrail stage was violated: the input check or the output check.
    stage: Literal['input', 'output']
    # Human-readable explanation of why the content was rejected.
    reason: str

@dataclass(frozen=True)
class GuardrailViolationEvent:
    """Trace event emitted when a guardrail rejects input or output."""
    type: Literal['guardrail_violation'] = 'guardrail_violation'
    data: GuardrailViolationEventData = field(default_factory=lambda: GuardrailViolationEventData("input", ""))
|
|
682
|
+
|
|
601
683
|
@dataclass(frozen=True)
|
|
602
684
|
class MemoryEventData:
|
|
603
685
|
"""Data for memory operation events."""
|
|
@@ -630,6 +712,7 @@ class OutputParseEvent:
|
|
|
630
712
|
TraceEvent = Union[
|
|
631
713
|
RunStartEvent,
|
|
632
714
|
GuardrailEvent,
|
|
715
|
+
GuardrailViolationEvent,
|
|
633
716
|
MemoryEvent,
|
|
634
717
|
OutputParseEvent,
|
|
635
718
|
LLMCallStartEvent,
|
|
@@ -708,7 +791,8 @@ class RunConfig(Generic[Ctx]):
|
|
|
708
791
|
initial_input_guardrails: Optional[List[Guardrail]] = None
|
|
709
792
|
final_output_guardrails: Optional[List[Guardrail]] = None
|
|
710
793
|
on_event: Optional[Callable[[TraceEvent], None]] = None
|
|
711
|
-
memory: Optional[
|
|
794
|
+
memory: Optional[Any] = None # MemoryConfig - avoiding circular import
|
|
712
795
|
conversation_id: Optional[str] = None
|
|
713
|
-
|
|
714
|
-
|
|
796
|
+
default_fast_model: Optional[str] = None # Default model for fast operations like guardrails
|
|
797
|
+
default_tool_timeout: Optional[float] = 300.0 # Default timeout for tool execution in seconds
|
|
798
|
+
approval_storage: Optional['ApprovalStorage'] = None # Storage for approval decisions
|