jaf-py 2.4.4__py3-none-any.whl → 2.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,339 @@
1
+ """
2
+ Parallel Agent Execution for JAF Framework.
3
+
4
+ This module provides functionality to execute multiple sub-agents in parallel groups,
5
+ allowing for coordinated parallel execution with configurable grouping and result aggregation.
6
+ """
7
+
8
+ import asyncio
9
+ import json
10
+ from dataclasses import dataclass
11
+ from typing import Any, Dict, List, Optional, Union, Callable, TypeVar
12
+
13
+ from .types import (
14
+ Agent,
15
+ Tool,
16
+ ToolSchema,
17
+ ToolSource,
18
+ RunConfig,
19
+ RunState,
20
+ RunResult,
21
+ Message,
22
+ ContentRole,
23
+ generate_run_id,
24
+ generate_trace_id,
25
+ )
26
+ from .agent_tool import create_agent_tool, AgentToolInput
27
+
28
+ Ctx = TypeVar('Ctx')
29
+ Out = TypeVar('Out')
30
+
31
+
32
@dataclass
class ParallelAgentGroup:
    """A named set of agents that are executed together as one parallel group."""

    # Label for the group; used as the key for this group's entry in the
    # tool's JSON report.
    name: str
    # Agents that are launched concurrently when the group runs.
    agents: List[Agent[Ctx, Out]]
    # Whether every agent receives the same input text.
    shared_input: bool = True
    # Aggregation strategy: "combine", "first", "majority" or "custom".
    result_aggregation: str = "combine"
    # Callable applied to the successful results when the strategy is "custom".
    custom_aggregator: Optional[Callable[[List[str]], str]] = None
    # Time limit in seconds for the group; None means no limit.
    timeout: Optional[float] = None
    # Free-form extra information attached to the group.
    metadata: Optional[Dict[str, Any]] = None
42
+
43
+
44
@dataclass
class ParallelExecutionConfig:
    """Top-level settings describing which groups to run and how to schedule them."""

    # Groups to execute, in declaration order.
    groups: List[ParallelAgentGroup]
    # "sequential" runs the groups one after another; "parallel" runs them together.
    inter_group_execution: str = "sequential"
    # Overall time limit in seconds; None means no limit.
    global_timeout: Optional[float] = None
    # Whether sub-agent calls keep the session.
    preserve_session: bool = False
51
+
52
+
53
class ParallelAgentsTool:
    """Tool that runs every configured agent group and returns a JSON report.

    Groups run one after another or all at once depending on
    ``config.inter_group_execution``; the agents inside a single group are
    always launched concurrently.
    """

    def __init__(
        self,
        config: ParallelExecutionConfig,
        tool_name: str = "execute_parallel_agents",
        tool_description: str = "Execute multiple agents in parallel groups"
    ):
        """Store the configuration and build the tool schema and metadata.

        Args:
            config: Groups to run and how to schedule them.
            tool_name: Name under which the tool is exposed.
            tool_description: Human-readable description of the tool.
        """
        self.config = config
        self.tool_name = tool_name
        self.tool_description = tool_description

        # Create tool schema
        self.schema = ToolSchema(
            name=tool_name,
            description=tool_description,
            parameters=AgentToolInput,
            timeout=config.global_timeout
        )
        self.source = ToolSource.NATIVE
        self.metadata = {"source": "parallel_agents", "groups": len(config.groups)}

    async def execute(self, args: AgentToolInput, context: Ctx) -> str:
        """Execute all configured agent groups.

        Args:
            args: Tool input; ``args.input`` is forwarded to every group.
            context: Caller context handed through to each sub-agent tool.

        Returns:
            A JSON string. On success it maps each group name to that group's
            report; if anything raises, it is an error object with the keys
            ``error``, ``message`` and ``groups_attempted`` instead.
        """
        try:
            if self.config.inter_group_execution == "parallel":
                # Execute all groups in parallel
                group_results = await asyncio.gather(*[
                    self._execute_group(group, args.input, context)
                    for group in self.config.groups
                ])
            else:
                # Execute groups sequentially
                group_results = []
                for group in self.config.groups:
                    result = await self._execute_group(group, args.input, context)
                    group_results.append(result)

            # Combine results from all groups, keyed by group name
            final_result = {
                "parallel_execution_results": {
                    group.name: result
                    for group, result in zip(self.config.groups, group_results)
                },
                "execution_mode": self.config.inter_group_execution,
                "total_groups": len(self.config.groups)
            }

            return json.dumps(final_result, indent=2)

        except Exception as e:
            return json.dumps({
                "error": "parallel_execution_failed",
                "message": f"Failed to execute parallel agents: {str(e)}",
                "groups_attempted": len(self.config.groups)
            })

    async def _execute_group(
        self,
        group: ParallelAgentGroup,
        input_text: str,
        context: Ctx
    ) -> Dict[str, Any]:
        """Execute a single group of agents in parallel.

        Args:
            group: The group whose agents should be run.
            input_text: Input handed to every agent in the group.
            context: Caller context forwarded to each agent tool.

        Returns:
            A dict with the per-agent outcomes and the aggregated result, or
            an error dict (``"timeout"`` / ``"execution_failed"``) when the
            group as a whole could not complete.
        """
        try:
            # Wrap every agent of the group in an agent tool
            agent_tools = [
                (
                    agent.name,
                    create_agent_tool(
                        agent=agent,
                        tool_name=f"run_{agent.name.lower().replace(' ', '_')}",
                        tool_description=f"Execute the {agent.name} agent",
                        timeout=group.timeout,
                        preserve_session=self.config.preserve_session
                    ),
                )
                for agent in group.agents
            ]

            # Every agent receives the same input. ``group.shared_input`` has
            # no effect yet: per-agent inputs are not implemented, so both
            # settings build the same calls.
            tasks = [
                tool.execute(AgentToolInput(input=input_text), context)
                for _, tool in agent_tools
            ]

            # return_exceptions=True keeps one failing agent from discarding
            # the results of the others.
            gathered = asyncio.gather(*tasks, return_exceptions=True)
            if group.timeout:
                results = await asyncio.wait_for(gathered, timeout=group.timeout)
            else:
                results = await gathered

            # Process results. They are keyed by agent name, so agents that
            # share a name overwrite one another's entry.
            agent_results = {}
            for (agent_name, _), result in zip(agent_tools, results):
                # BaseException, not Exception: a cancelled sub-task is
                # reported by gather() as a CancelledError instance, which is
                # not an Exception subclass. Treating it as a success would
                # put an unserialisable object into the JSON report.
                if isinstance(result, BaseException):
                    agent_results[agent_name] = {
                        "error": True,
                        "message": str(result),
                        "type": type(result).__name__
                    }
                else:
                    agent_results[agent_name] = {
                        "success": True,
                        "result": result
                    }

            # Apply result aggregation
            aggregated_result = self._aggregate_results(group, agent_results)

            return {
                "group_name": group.name,
                "agent_count": len(group.agents),
                "individual_results": agent_results,
                "aggregated_result": aggregated_result,
                "execution_time_ms": None  # Could be added with timing
            }

        except asyncio.TimeoutError:
            return {
                "group_name": group.name,
                "error": "timeout",
                "message": f"Group {group.name} execution timed out after {group.timeout} seconds",
                "agent_count": len(group.agents)
            }
        except Exception as e:
            return {
                "group_name": group.name,
                "error": "execution_failed",
                "message": str(e),
                "agent_count": len(group.agents)
            }

    def _aggregate_results(
        self,
        group: ParallelAgentGroup,
        agent_results: Dict[str, Any]
    ) -> Union[str, Dict[str, Any]]:
        """Aggregate results from parallel agent execution.

        Args:
            group: Group whose ``result_aggregation`` strategy is applied.
            agent_results: Per-agent outcome dicts keyed by agent name.

        Returns:
            The aggregated value for the chosen strategy, or an error dict
            when no agent succeeded, no majority was reached, or the custom
            aggregator raised.
        """
        successful_results = [
            outcome["result"]
            for outcome in agent_results.values()
            if outcome.get("success") and "result" in outcome
        ]

        if not successful_results:
            return {"error": "no_successful_results", "message": "All agents failed"}

        strategy = group.result_aggregation

        if strategy == "first":
            return successful_results[0]

        if strategy == "combine":
            return {
                "combined_results": successful_results,
                "result_count": len(successful_results)
            }

        if strategy == "majority":
            # Simple majority logic: this only checks that more than half of
            # the agents succeeded; it does not compare the answers.
            if len(successful_results) >= len(group.agents) // 2 + 1:
                return successful_results[0]  # Return first as majority representative
            return {"error": "no_majority", "results": successful_results}

        if strategy == "custom" and group.custom_aggregator:
            try:
                return group.custom_aggregator(successful_results)
            except Exception as e:
                return {"error": "custom_aggregation_failed", "message": str(e)}

        # Unknown strategy, or "custom" without an aggregator.
        return {"combined_results": successful_results}
228
+
229
+
230
def create_parallel_agents_tool(
    groups: List[ParallelAgentGroup],
    tool_name: str = "execute_parallel_agents",
    tool_description: str = "Execute multiple agents in parallel groups",
    inter_group_execution: str = "sequential",
    global_timeout: Optional[float] = None,
    preserve_session: bool = False
) -> Tool:
    """
    Build a tool that runs the given agent groups.

    Args:
        groups: Parallel agent groups to execute
        tool_name: Name of the resulting tool
        tool_description: Description of the resulting tool
        inter_group_execution: Scheduling of the groups ("sequential" or "parallel")
        global_timeout: Global timeout for all executions
        preserve_session: Whether to preserve session across agent calls

    Returns:
        A Tool that can execute the parallel agent groups
    """
    return ParallelAgentsTool(
        ParallelExecutionConfig(
            groups=groups,
            inter_group_execution=inter_group_execution,
            global_timeout=global_timeout,
            preserve_session=preserve_session,
        ),
        tool_name,
        tool_description,
    )
260
+
261
+
262
def create_simple_parallel_tool(
    agents: List[Agent],
    group_name: str = "parallel_group",
    tool_name: str = "execute_parallel_agents",
    shared_input: bool = True,
    result_aggregation: str = "combine",
    timeout: Optional[float] = None
) -> Tool:
    """
    Wrap a flat list of agents in a single group and build a tool for it.

    Args:
        agents: Agents to execute in parallel
        group_name: Name given to the single group
        tool_name: Name of the resulting tool
        shared_input: Whether all agents receive the same input
        result_aggregation: Aggregation strategy ("combine", "first", "majority")
        timeout: Timeout for the parallel execution

    Returns:
        A Tool that executes all agents in parallel
    """
    only_group = ParallelAgentGroup(
        name=group_name,
        agents=agents,
        shared_input=shared_input,
        result_aggregation=result_aggregation,
        timeout=timeout,
    )
    # All other tool options keep the defaults of create_parallel_agents_tool.
    return create_parallel_agents_tool([only_group], tool_name=tool_name)
293
+
294
+
295
+ # Convenience functions for common parallel execution patterns
296
+
297
def create_language_specialists_tool(
    language_agents: Dict[str, Agent],
    tool_name: str = "consult_language_specialists",
    timeout: Optional[float] = 300.0
) -> Tool:
    """Build a tool that asks every language specialist the same question in parallel.

    Args:
        language_agents: Mapping of language name to the agent for that language.
        tool_name: Name of the resulting tool.
        timeout: Timeout in seconds for the specialist group.

    Returns:
        A Tool whose single group combines the specialists' answers.
    """
    specialists = list(language_agents.values())
    languages = list(language_agents.keys())

    specialist_group = ParallelAgentGroup(
        name="language_specialists",
        agents=specialists,
        shared_input=True,
        result_aggregation="combine",
        timeout=timeout,
        metadata={"languages": languages},
    )

    return create_parallel_agents_tool(
        [specialist_group],
        tool_name=tool_name,
        tool_description="Consult multiple language specialists in parallel",
    )
317
+
318
+
319
def create_domain_experts_tool(
    expert_agents: Dict[str, Agent],
    tool_name: str = "consult_domain_experts",
    result_aggregation: str = "combine",
    timeout: Optional[float] = 60.0
) -> Tool:
    """Build a tool that consults every domain expert in parallel.

    Args:
        expert_agents: Mapping of domain name to the expert agent for that domain.
        tool_name: Name of the resulting tool.
        result_aggregation: Aggregation strategy applied to the experts' answers.
        timeout: Timeout in seconds for the expert group.

    Returns:
        A Tool with a single group containing all experts.
    """
    experts = list(expert_agents.values())
    domains = list(expert_agents.keys())

    expert_group = ParallelAgentGroup(
        name="domain_experts",
        agents=experts,
        shared_input=True,
        result_aggregation=result_aggregation,
        timeout=timeout,
        metadata={"domains": domains},
    )

    return create_parallel_agents_tool(
        [expert_group],
        tool_name=tool_name,
        tool_description="Consult multiple domain experts in parallel",
    )
jaf/core/streaming.py CHANGED
@@ -209,20 +209,37 @@ async def run_streaming(
209
209
  trace_id=initial_state.trace_id
210
210
  )
211
211
 
212
- tool_call_ids = {} # To map tool calls to their IDs
212
+ tool_call_ids: Dict[str, str] = {} # Map call_id -> tool_name for in-flight tool calls
213
213
 
214
214
  def event_handler(event: TraceEvent) -> None:
215
215
  """Handle trace events and put them into the queue."""
216
216
  nonlocal tool_call_ids
217
217
  streaming_event = None
218
+ payload = event.data
219
+
220
+ def _get_event_value(keys: List[str]) -> Any:
221
+ for key in keys:
222
+ if isinstance(payload, dict) and key in payload:
223
+ return payload[key]
224
+ if hasattr(payload, key):
225
+ return getattr(payload, key)
226
+ return None
227
+
218
228
  if event.type == 'tool_call_start':
219
- # Generate a unique ID for the tool call
220
- call_id = f"call_{uuid.uuid4().hex[:8]}"
221
- tool_call_ids[event.data.tool_name] = call_id
222
-
229
+ tool_name = _get_event_value(['tool_name', 'toolName']) or 'unknown'
230
+ args = _get_event_value(['args', 'arguments'])
231
+ call_id = _get_event_value(['call_id', 'tool_call_id', 'toolCallId'])
232
+
233
+ if not call_id:
234
+ call_id = f"call_{uuid.uuid4().hex[:8]}"
235
+ if isinstance(payload, dict):
236
+ payload['call_id'] = call_id
237
+
238
+ tool_call_ids[call_id] = tool_name
239
+
223
240
  tool_call = StreamingToolCall(
224
- tool_name=event.data.tool_name,
225
- arguments=event.data.args,
241
+ tool_name=tool_name,
242
+ arguments=args,
226
243
  call_id=call_id,
227
244
  status='started'
228
245
  )
@@ -233,18 +250,26 @@ async def run_streaming(
233
250
  trace_id=initial_state.trace_id
234
251
  )
235
252
  elif event.type == 'tool_call_end':
236
- if event.data.tool_name not in tool_call_ids:
237
- raise RuntimeError(
238
- f"Tool call end event received for unknown tool '{event.data.tool_name}'. "
239
- f"Known tool calls: {list(tool_call_ids.keys())}. "
240
- f"This may indicate a missing tool_call_start event or a bug in the streaming implementation."
241
- )
242
- call_id = tool_call_ids[event.data.tool_name]
253
+ tool_name = _get_event_value(['tool_name', 'toolName']) or 'unknown'
254
+ call_id = _get_event_value(['call_id', 'tool_call_id', 'toolCallId'])
255
+
256
+ if not call_id:
257
+ # Fallback to locate a pending tool call with the same tool name
258
+ matching_call_id = next((cid for cid, name in tool_call_ids.items() if name == tool_name), None)
259
+ if matching_call_id:
260
+ call_id = matching_call_id
261
+ else:
262
+ raise RuntimeError(
263
+ f"Tool call end event received for unknown tool '{tool_name}'. "
264
+ f"Pending call IDs: {list(tool_call_ids.keys())}."
265
+ )
266
+
267
+ tool_call_ids.pop(call_id, None)
243
268
  tool_result = StreamingToolResult(
244
- tool_name=event.data.tool_name,
269
+ tool_name=tool_name,
245
270
  call_id=call_id,
246
- result=event.data.result,
247
- status=event.data.status or 'completed'
271
+ result=_get_event_value(['result']),
272
+ status=_get_event_value(['status']) or 'completed'
248
273
  )
249
274
  streaming_event = StreamingEvent(
250
275
  type=StreamingEventType.TOOL_RESULT,
jaf/core/tracing.py CHANGED
@@ -10,6 +10,7 @@ import json
10
10
  import time
11
11
  from datetime import datetime
12
12
  from typing import Any, Dict, List, Optional, Protocol
13
+ import uuid
13
14
 
14
15
  from opentelemetry import trace
15
16
  from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
@@ -652,28 +653,36 @@ class LangfuseTraceCollector:
652
653
  # Start a span for tool calls with detailed input information
653
654
  tool_name = event.data.get('tool_name', 'unknown')
654
655
  tool_args = event.data.get("args", {})
656
+ call_id = event.data.get("call_id")
657
+ if not call_id:
658
+ call_id = f"{tool_name}-{uuid.uuid4().hex[:8]}"
659
+ try:
660
+ event.data["call_id"] = call_id
661
+ except TypeError:
662
+ # event.data may be immutable; log and rely on synthetic ID tracking downstream
663
+ print(f"[LANGFUSE] Generated synthetic call_id for tool start: {call_id}")
655
664
 
656
- print(f"[LANGFUSE] Starting span for tool call: {tool_name}")
665
+ print(f"[LANGFUSE] Starting span for tool call: {tool_name} ({call_id})")
657
666
 
658
667
  # Track this tool call for the trace
659
668
  tool_call_data = {
660
669
  "tool_name": tool_name,
661
670
  "arguments": tool_args,
662
- "call_id": event.data.get("call_id"),
671
+ "call_id": call_id,
663
672
  "timestamp": datetime.now().isoformat()
664
673
  }
665
674
 
666
675
  # Ensure trace_id exists in tracking
667
676
  if trace_id not in self.trace_tool_calls:
668
677
  self.trace_tool_calls[trace_id] = []
669
-
678
+
670
679
  self.trace_tool_calls[trace_id].append(tool_call_data)
671
680
 
672
681
  # Create comprehensive input data for the tool call
673
682
  tool_input = {
674
683
  "tool_name": tool_name,
675
684
  "arguments": tool_args,
676
- "call_id": event.data.get("call_id"),
685
+ "call_id": call_id,
677
686
  "timestamp": datetime.now().isoformat()
678
687
  }
679
688
 
@@ -682,7 +691,7 @@ class LangfuseTraceCollector:
682
691
  input=tool_input,
683
692
  metadata={
684
693
  "tool_name": tool_name,
685
- "call_id": event.data.get("call_id"),
694
+ "call_id": call_id,
686
695
  "framework": "jaf",
687
696
  "event_type": "tool_call"
688
697
  }
@@ -696,14 +705,15 @@ class LangfuseTraceCollector:
696
705
  if span_id in self.active_spans:
697
706
  tool_name = event.data.get('tool_name', 'unknown')
698
707
  tool_result = event.data.get("result")
708
+ call_id = event.data.get("call_id")
699
709
 
700
- print(f"[LANGFUSE] Ending span for tool call: {tool_name}")
710
+ print(f"[LANGFUSE] Ending span for tool call: {tool_name} ({call_id})")
701
711
 
702
712
  # Track this tool result for the trace
703
713
  tool_result_data = {
704
714
  "tool_name": tool_name,
705
715
  "result": tool_result,
706
- "call_id": event.data.get("call_id"),
716
+ "call_id": call_id,
707
717
  "timestamp": datetime.now().isoformat(),
708
718
  "status": event.data.get("status", "completed"),
709
719
  "tool_result": event.data.get("tool_result")
@@ -718,7 +728,7 @@ class LangfuseTraceCollector:
718
728
  tool_output = {
719
729
  "tool_name": tool_name,
720
730
  "result": tool_result,
721
- "call_id": event.data.get("call_id"),
731
+ "call_id": call_id,
722
732
  "timestamp": datetime.now().isoformat(),
723
733
  "status": event.data.get("status", "completed")
724
734
  }
@@ -729,7 +739,7 @@ class LangfuseTraceCollector:
729
739
  output=tool_output,
730
740
  metadata={
731
741
  "tool_name": tool_name,
732
- "call_id": event.data.get("call_id"),
742
+ "call_id": call_id,
733
743
  "result_length": len(str(tool_result)) if tool_result else 0,
734
744
  "framework": "jaf",
735
745
  "event_type": "tool_call_end"
@@ -791,6 +801,9 @@ class LangfuseTraceCollector:
791
801
 
792
802
  # Use consistent identifiers that don't depend on timestamp
793
803
  if event.type.startswith('tool_call'):
804
+ call_id = event.data.get('call_id') or event.data.get('tool_call_id')
805
+ if call_id:
806
+ return f"tool-{trace_id}-{call_id}"
794
807
  tool_name = event.data.get('tool_name') or event.data.get('toolName', 'unknown')
795
808
  return f"tool-{tool_name}-{trace_id}"
796
809
  elif event.type.startswith('llm_call'):
jaf/core/types.py CHANGED
@@ -288,6 +288,7 @@ class Agent(Generic[Ctx, Out]):
288
288
  output_codec: Optional[Any] = None # Type that can validate Out (like Pydantic model or Zod equivalent)
289
289
  handoffs: Optional[List[str]] = None
290
290
  model_config: Optional[ModelConfig] = None
291
+ advanced_config: Optional['AdvancedConfig'] = None
291
292
 
292
293
  def as_tool(
293
294
  self,
@@ -331,6 +332,74 @@ class Agent(Generic[Ctx, Out]):
331
332
  # Guardrail type
332
333
  Guardrail = Callable[[Any], Union[ValidationResult, Awaitable[ValidationResult]]]
333
334
 
335
@dataclass(frozen=True)
class AdvancedGuardrailsConfig:
    """Settings for the advanced, LLM-based guardrail checks."""

    # Prompt used for the input check; None when unset.
    input_prompt: Optional[str] = None
    # Prompt used for the output check; None when unset.
    output_prompt: Optional[str] = None
    # Whether citations are required.
    require_citations: bool = False
    # Name of the fast model; None when unset.
    fast_model: Optional[str] = None
    # Fail-safe policy: 'allow' or 'block'.
    fail_safe: Literal['allow', 'block'] = 'allow'
    # Whether the checks run in 'parallel' or 'sequential' mode.
    execution_mode: Literal['parallel', 'sequential'] = 'parallel'
    # Timeout in milliseconds; raised to 1000 on construction if lower.
    timeout_ms: int = 30000

    def __post_init__(self):
        """Raise ``timeout_ms`` to the 1000 ms floor when a smaller value was given."""
        if not self.timeout_ms < 1000:
            return
        # The dataclass is frozen, so the field has to be set through
        # object.__setattr__ rather than plain assignment.
        object.__setattr__(self, 'timeout_ms', 1000)
350
+
351
@dataclass(frozen=True)
class AdvancedConfig:
    """Container for an agent's advanced options; currently only guardrails."""

    # Guardrail settings; None when no advanced guardrails are configured.
    guardrails: Optional[AdvancedGuardrailsConfig] = None
355
+
356
def validate_guardrails_config(config: Optional[AdvancedGuardrailsConfig]) -> AdvancedGuardrailsConfig:
    """Validate and provide defaults for guardrails configuration.

    Args:
        config: The configuration to normalise, or None.

    Returns:
        A new ``AdvancedGuardrailsConfig``. With ``config`` of None this is the
        all-defaults configuration. Otherwise the text fields are stripped,
        blank or non-string text fields become None, and ``timeout_ms`` is
        at least 1000.
    """
    if config is None:
        return AdvancedGuardrailsConfig()

    def _clean_text(value: object) -> Optional[str]:
        """Return the stripped string, or None for non-strings and blank text."""
        if not isinstance(value, str):
            return None
        # ``or None`` also maps whitespace-only input to None, so that it is
        # treated the same way as an empty string.
        return value.strip() or None

    return AdvancedGuardrailsConfig(
        input_prompt=_clean_text(config.input_prompt),
        output_prompt=_clean_text(config.output_prompt),
        require_citations=config.require_citations,
        fast_model=_clean_text(config.fast_model),
        fail_safe=config.fail_safe,
        execution_mode=config.execution_mode,
        timeout_ms=max(1000, config.timeout_ms)
    )
370
+
371
def json_parse_llm_output(text: str) -> Optional[Dict[str, Any]]:
    """Parse JSON from LLM output, handling common formatting issues.

    Three strategies are tried in order:

    1. Parse the whole text as JSON. Whatever ``json.loads`` yields is
       returned as-is, including a top-level value that is not an object.
    2. Parse a JSON object wrapped in a Markdown code fence.
    3. Scan the text for the first position at which a complete JSON object
       can be decoded, which also covers nested objects surrounded by prose.

    Args:
        text: Raw model output.

    Returns:
        The parsed value, or None when ``text`` is empty or contains no
        decodable JSON.
    """
    import json
    import re

    if not text:
        return None

    # Try direct parsing first
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Try to extract JSON from markdown code blocks
    fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
    if fenced:
        try:
            return json.loads(fenced.group(1))
        except json.JSONDecodeError:
            pass

    # Try to find the first JSON object in the text. raw_decode() reads
    # exactly one JSON value starting at the given index and ignores trailing
    # text, so nested braces and braces inside strings are handled correctly,
    # which a non-greedy ``\{.*?\}`` regex cannot do.
    decoder = json.JSONDecoder()
    brace_at = text.find('{')
    while brace_at != -1:
        try:
            candidate, _ = decoder.raw_decode(text, brace_at)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        brace_at = text.find('{', brace_at + 1)

    return None
402
+
334
403
  @dataclass(frozen=True)
335
404
  class ApprovalValue:
336
405
  """Represents an approval decision with context."""
@@ -541,11 +610,12 @@ class ToolCallStartEventData:
541
610
  args: Any
542
611
  trace_id: TraceId
543
612
  run_id: RunId
613
+ call_id: Optional[str] = None
544
614
 
545
615
  @dataclass(frozen=True)
546
616
  class ToolCallStartEvent:
547
617
  type: Literal['tool_call_start'] = 'tool_call_start'
548
- data: ToolCallStartEventData = field(default_factory=lambda: ToolCallStartEventData("", None, TraceId(""), RunId("")))
618
+ data: ToolCallStartEventData = field(default_factory=lambda: ToolCallStartEventData("", None, TraceId(""), RunId(""), None))
549
619
 
550
620
  @dataclass(frozen=True)
551
621
  class ToolCallEndEventData:
@@ -556,11 +626,12 @@ class ToolCallEndEventData:
556
626
  run_id: RunId
557
627
  tool_result: Optional[Any] = None
558
628
  status: Optional[str] = None
629
+ call_id: Optional[str] = None
559
630
 
560
631
  @dataclass(frozen=True)
561
632
  class ToolCallEndEvent:
562
633
  type: Literal['tool_call_end'] = 'tool_call_end'
563
- data: ToolCallEndEventData = field(default_factory=lambda: ToolCallEndEventData("", "", TraceId(""), RunId("")))
634
+ data: ToolCallEndEventData = field(default_factory=lambda: ToolCallEndEventData("", "", TraceId(""), RunId(""), None, None))
564
635
 
565
636
  @dataclass(frozen=True)
566
637
  class HandoffEventData:
@@ -598,6 +669,17 @@ class GuardrailEvent:
598
669
  type: Literal['guardrail_check'] = 'guardrail_check'
599
670
  data: GuardrailEventData = field(default_factory=lambda: GuardrailEventData(""))
600
671
 
672
@dataclass(frozen=True)
class GuardrailViolationEventData:
    """Payload carried by a guardrail violation event."""

    # Stage at which the violation occurred: 'input' or 'output'.
    stage: Literal['input', 'output']
    # Explanation of the violation.
    reason: str
677
+
678
@dataclass(frozen=True)
class GuardrailViolationEvent:
    """Trace event whose ``type`` is 'guardrail_violation'."""

    type: Literal['guardrail_violation'] = 'guardrail_violation'
    # Defaults to an 'input' stage payload with an empty reason.
    data: GuardrailViolationEventData = field(
        default_factory=lambda: GuardrailViolationEventData(stage='input', reason='')
    )
682
+
601
683
  @dataclass(frozen=True)
602
684
  class MemoryEventData:
603
685
  """Data for memory operation events."""
@@ -630,6 +712,7 @@ class OutputParseEvent:
630
712
  TraceEvent = Union[
631
713
  RunStartEvent,
632
714
  GuardrailEvent,
715
+ GuardrailViolationEvent,
633
716
  MemoryEvent,
634
717
  OutputParseEvent,
635
718
  LLMCallStartEvent,
@@ -708,7 +791,8 @@ class RunConfig(Generic[Ctx]):
708
791
  initial_input_guardrails: Optional[List[Guardrail]] = None
709
792
  final_output_guardrails: Optional[List[Guardrail]] = None
710
793
  on_event: Optional[Callable[[TraceEvent], None]] = None
711
- memory: Optional['MemoryConfig'] = None
794
+ memory: Optional[Any] = None # MemoryConfig - avoiding circular import
712
795
  conversation_id: Optional[str] = None
713
- default_tool_timeout: Optional[float] = 30.0 # Default timeout for tool execution in seconds
714
- approval_storage: Optional['ApprovalStorage'] = None # Storage for approval decisions
796
+ default_fast_model: Optional[str] = None # Default model for fast operations like guardrails
797
+ default_tool_timeout: Optional[float] = 300.0 # Default timeout for tool execution in seconds
798
+ approval_storage: Optional['ApprovalStorage'] = None # Storage for approval decisions