RouteKitAI 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64) hide show
  1. routekitai/__init__.py +53 -0
  2. routekitai/cli/__init__.py +18 -0
  3. routekitai/cli/main.py +40 -0
  4. routekitai/cli/replay.py +80 -0
  5. routekitai/cli/run.py +95 -0
  6. routekitai/cli/serve.py +966 -0
  7. routekitai/cli/test_agent.py +178 -0
  8. routekitai/cli/trace.py +209 -0
  9. routekitai/cli/trace_analyze.py +120 -0
  10. routekitai/cli/trace_search.py +126 -0
  11. routekitai/core/__init__.py +58 -0
  12. routekitai/core/agent.py +325 -0
  13. routekitai/core/errors.py +49 -0
  14. routekitai/core/hooks.py +174 -0
  15. routekitai/core/memory.py +54 -0
  16. routekitai/core/message.py +132 -0
  17. routekitai/core/model.py +91 -0
  18. routekitai/core/policies.py +373 -0
  19. routekitai/core/policy.py +85 -0
  20. routekitai/core/policy_adapter.py +133 -0
  21. routekitai/core/runtime.py +1403 -0
  22. routekitai/core/tool.py +148 -0
  23. routekitai/core/tools.py +180 -0
  24. routekitai/evals/__init__.py +13 -0
  25. routekitai/evals/dataset.py +75 -0
  26. routekitai/evals/metrics.py +101 -0
  27. routekitai/evals/runner.py +184 -0
  28. routekitai/graphs/__init__.py +12 -0
  29. routekitai/graphs/executors.py +457 -0
  30. routekitai/graphs/graph.py +164 -0
  31. routekitai/memory/__init__.py +13 -0
  32. routekitai/memory/episodic.py +242 -0
  33. routekitai/memory/kv.py +34 -0
  34. routekitai/memory/retrieval.py +192 -0
  35. routekitai/memory/vector.py +700 -0
  36. routekitai/memory/working.py +66 -0
  37. routekitai/message.py +29 -0
  38. routekitai/model.py +48 -0
  39. routekitai/observability/__init__.py +21 -0
  40. routekitai/observability/analyzer.py +314 -0
  41. routekitai/observability/exporters/__init__.py +10 -0
  42. routekitai/observability/exporters/base.py +30 -0
  43. routekitai/observability/exporters/jsonl.py +81 -0
  44. routekitai/observability/exporters/otel.py +119 -0
  45. routekitai/observability/spans.py +111 -0
  46. routekitai/observability/streaming.py +117 -0
  47. routekitai/observability/trace.py +144 -0
  48. routekitai/providers/__init__.py +9 -0
  49. routekitai/providers/anthropic.py +227 -0
  50. routekitai/providers/azure_openai.py +243 -0
  51. routekitai/providers/local.py +196 -0
  52. routekitai/providers/openai.py +321 -0
  53. routekitai/py.typed +0 -0
  54. routekitai/sandbox/__init__.py +12 -0
  55. routekitai/sandbox/filesystem.py +131 -0
  56. routekitai/sandbox/network.py +142 -0
  57. routekitai/sandbox/permissions.py +70 -0
  58. routekitai/tool.py +33 -0
  59. routekitai-0.1.0.dist-info/METADATA +328 -0
  60. routekitai-0.1.0.dist-info/RECORD +64 -0
  61. routekitai-0.1.0.dist-info/WHEEL +5 -0
  62. routekitai-0.1.0.dist-info/entry_points.txt +2 -0
  63. routekitai-0.1.0.dist-info/licenses/LICENSE +21 -0
  64. routekitai-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1403 @@
1
+ """Runtime/Orchestrator for RouteKit."""
2
+
3
+ import asyncio
4
+ import json
5
+ import time
6
+ import uuid
7
+ from abc import ABC, abstractmethod
8
+ from collections.abc import AsyncIterator, Callable
9
+ from pathlib import Path
10
+ from typing import TYPE_CHECKING, Any
11
+
12
+ from pydantic import BaseModel, Field
13
+
14
+ from routekitai.core.errors import ModelError, ToolError
15
+ from routekitai.core.errors import RuntimeError as RouteKitRuntimeError
16
+ from routekitai.core.hooks import PolicyHooks
17
+ from routekitai.core.message import Message, MessageRole
18
+ from routekitai.core.model import ModelResponse, StreamEvent
19
+ from routekitai.core.tool import Tool
20
+ from routekitai.observability.exporters.jsonl import JSONLExporter
21
+ from routekitai.observability.trace import Trace, TraceEvent
22
+ from routekitai.sandbox.permissions import PermissionManager
23
+
24
+ if TYPE_CHECKING:
25
+ from routekitai.core.agent import Agent, RunResult
26
+
27
+
28
class Step(BaseModel):
    """A single execution step produced by a :class:`Policy`.

    The runtime executes each step and fills in ``output_data``,
    ``latency_ms`` and ``error`` after execution, so a completed ``Step``
    doubles as the execution record for that step.
    """

    # Unique identifier; used to correlate step_started/step_completed
    # trace events and replay lookups.
    step_id: str = Field(..., description="Step ID")
    # Discriminator the runtime dispatches on: "model_call", "tool_call"
    # or "subagent_call".
    step_type: str = Field(..., description="Step type (model_call, tool_call, etc.)")
    # Step-type-specific payload (e.g. "messages" for model calls,
    # "tool_name"/"tool_args" for tool calls, "agent_name"/"prompt" for
    # sub-agent calls).
    input_data: dict[str, Any] = Field(..., description="Step input")
    # Populated by the runtime on success; None until executed or on error.
    output_data: dict[str, Any] | None = Field(default=None, description="Step output")
    # Wall-clock execution time, set in the runtime's finally block.
    latency_ms: float | None = Field(default=None, description="Step latency in milliseconds")
    # Stringified exception if the step failed.
    error: str | None = Field(default=None, description="Error if step failed")
37
+
38
+
39
class Policy(ABC):
    """Policy interface for agent execution.

    A policy determines the next step(s) to execute based on current state.
    Implementations are queried once per runtime iteration; returning an
    empty list signals that execution should finalize.
    """

    @abstractmethod
    async def next_steps(
        self,
        agent: "Agent",
        messages: list[Message],
        state: dict[str, Any],
    ) -> list[Step]:
        """Determine next steps to execute.

        Args:
            agent: Agent instance
            messages: Current conversation messages
            state: Current agent state (mutable; the runtime stores
                ``iteration``, ``memory``, ``runtime`` and sub-agent
                results in it)

        Returns:
            List of steps to execute (can be parallel); an empty list
            ends the run loop.
        """
        # Defensive: @abstractmethod already prevents instantiation, but an
        # explicit raise also guards direct super() calls.
        raise NotImplementedError("Subclasses must implement next_steps")
63
+
64
+
65
class ReplayMismatchError(RouteKitRuntimeError):
    """Error raised when replay encounters a mismatch.

    Raised by the runtime when a recorded trace does not contain an event
    matching the step currently being replayed (e.g. fewer model/tool call
    events than the live execution attempts).
    """

    def __init__(self, message: str, context: dict[str, Any] | None = None) -> None:
        """Initialize replay mismatch error.

        Args:
            message: Error message
            context: Optional context (step_id, trace_id, etc.)
        """
        # Delegate to the package error base so .context is handled uniformly.
        super().__init__(message, context=context)
76
+
77
+
78
class Runtime(BaseModel):
    """Step-based runtime with tracing, permissions, and replay.

    Orchestrates registered agents: a :class:`Policy` proposes steps, the
    runtime executes them (model calls, tool calls, sub-agent calls) with
    bounded concurrency, records every event on a :class:`Trace`, and can
    optionally export traces as JSONL and replay them.
    """

    # Registered agents, keyed by agent.name (see register_agent()).
    agents: dict[str, "Agent"] = Field(default_factory=dict, description="Registered agents")
    # When set, traces are exported as JSONL into this directory.
    trace_dir: Path | None = Field(default=None, description="Directory for trace files")
    max_retries: int = Field(default=3, description="Maximum retries for operations")
    timeout: float | None = Field(default=None, description="Default timeout in seconds")
    # Semaphore bound used by _execute_steps_parallel.
    max_concurrency: int = Field(default=5, description="Maximum concurrent tool executions")
    permission_manager: PermissionManager | None = Field(
        default=None, description="Permission manager for tool execution"
    )
    # Governance hooks; pii_redaction (if present) scrubs prompts before
    # they are written to traces.
    policy_hooks: PolicyHooks | None = Field(
        default=None, description="Policy hooks for governance"
    )
    retry_backoff_base: float = Field(default=1.0, description="Base delay for exponential backoff")
    retry_backoff_max: float = Field(
        default=60.0, description="Maximum delay for exponential backoff"
    )
    config: dict[str, Any] = Field(default_factory=dict, description="Runtime configuration")
97
+
98
+ def __init__(self, **kwargs: Any) -> None:
99
+ """Initialize runtime with replay state."""
100
+ super().__init__(**kwargs)
101
+ self._replay_mode: bool = False
102
+ self._replay_trace: Trace | None = None
103
+ self._replay_step_index: int = 0
104
+ self._replay_step_map: dict[str, TraceEvent] = {} # Map step_id -> event for replay
105
+ self._cancellation_token: asyncio.CancelledError | None = None
106
+ # Progress tracking
107
+ self._current_step: int = 0
108
+ self._total_steps: int = 0
109
+ self._progress_callbacks: list[Callable[[dict[str, Any]], None]] = []
110
+
111
+ def register_agent(self, agent: "Agent") -> None:
112
+ """Register an agent with the runtime.
113
+
114
+ Args:
115
+ agent: Agent to register
116
+ """
117
+ self.agents[agent.name] = agent
118
+
119
+ def add_progress_callback(self, callback: Callable[[dict[str, Any]], None]) -> None:
120
+ """Add a callback for progress updates.
121
+
122
+ Args:
123
+ callback: Function that receives progress dict with keys:
124
+ - current_step: int
125
+ - total_steps: int
126
+ - progress_percent: float
127
+ - current_step_type: str | None
128
+ """
129
+ self._progress_callbacks.append(callback)
130
+
131
+ def remove_progress_callback(self, callback: Callable[[dict[str, Any]], None]) -> None:
132
+ """Remove a progress callback.
133
+
134
+ Args:
135
+ callback: Callback to remove
136
+ """
137
+ if callback in self._progress_callbacks:
138
+ self._progress_callbacks.remove(callback)
139
+
140
+ def _emit_progress(self, trace: Trace, step_type: str | None = None) -> None:
141
+ """Emit progress update to callbacks and trace.
142
+
143
+ Args:
144
+ trace: Trace to add progress event to
145
+ step_type: Optional current step type
146
+ """
147
+ progress_data = {
148
+ "current_step": self._current_step,
149
+ "total_steps": self._total_steps,
150
+ "progress_percent": (
151
+ (self._current_step / self._total_steps * 100) if self._total_steps > 0 else 0.0
152
+ ),
153
+ "current_step_type": step_type,
154
+ }
155
+
156
+ # Notify callbacks
157
+ for callback in self._progress_callbacks:
158
+ try:
159
+ callback(progress_data)
160
+ except Exception:
161
+ # Don't let callback errors break execution
162
+ pass
163
+
164
+ # Add to trace
165
+ trace.add_event("progress_update", progress_data)
166
+
167
+ async def run(
168
+ self,
169
+ agent_name: str,
170
+ prompt: str,
171
+ policy: Policy | None = None,
172
+ cancellation_token: asyncio.CancelledError | None = None,
173
+ **kwargs: Any,
174
+ ) -> "RunResult":
175
+ """Run an agent with step-based execution and tracing.
176
+
177
+ Args:
178
+ agent_name: Name of the agent to run
179
+ prompt: User prompt
180
+ policy: Optional policy (uses agent's default if not provided)
181
+ cancellation_token: Optional cancellation token for async cancellation
182
+ **kwargs: Additional runtime parameters
183
+
184
+ Returns:
185
+ RunResult with trace_id and full execution details
186
+
187
+ Raises:
188
+ RouteKitRuntimeError: If runtime operation fails
189
+ asyncio.CancelledError: If execution is cancelled
190
+ """
191
+ if agent_name not in self.agents:
192
+ raise RouteKitRuntimeError(f"Agent {agent_name} not found")
193
+
194
+ self._cancellation_token = cancellation_token
195
+
196
+ agent = self.agents[agent_name]
197
+ # In replay mode, reuse the original trace_id if available
198
+ if self._replay_mode and self._replay_trace:
199
+ trace_id = self._replay_trace.trace_id
200
+ else:
201
+ trace_id = str(uuid.uuid4())
202
+
203
+ # Apply PII redaction to prompt if hook is configured
204
+ redacted_prompt = prompt
205
+ if self.policy_hooks and self.policy_hooks.pii_redaction:
206
+ redacted_prompt = self.policy_hooks.pii_redaction.redact(prompt)
207
+
208
+ trace = Trace(trace_id=trace_id, metadata={"agent": agent_name, "prompt": redacted_prompt})
209
+ trace.add_event(
210
+ "run_started", {"trace_id": trace_id, "agent": agent_name, "prompt": redacted_prompt}
211
+ )
212
+
213
+ # Export trace if trace_dir is set (lazy/async export)
214
+ exporter = None
215
+ export_task = None
216
+ if self.trace_dir:
217
+ self.trace_dir.mkdir(parents=True, exist_ok=True)
218
+ exporter = JSONLExporter(output_dir=self.trace_dir)
219
+
220
+ try:
221
+ result = await self._execute_steps(agent, prompt, trace, policy, **kwargs)
222
+ # Clean final_state to remove non-serializable objects before storing in trace
223
+ cleaned_final_state: dict[str, Any] = {}
224
+ for key, value in result.final_state.items():
225
+ if isinstance(value, (str, int, float, bool, type(None))):
226
+ cleaned_final_state[key] = value
227
+ elif isinstance(value, (list, dict)):
228
+ # Recursively clean nested structures
229
+ try:
230
+ json.dumps(value) # Test if serializable
231
+ cleaned_final_state[key] = value
232
+ except (TypeError, ValueError):
233
+ # Convert to string if not serializable
234
+ cleaned_final_state[key] = str(value)
235
+ else:
236
+ cleaned_final_state[key] = str(value) # Convert to string representation
237
+
238
+ # Store only serializable parts of result
239
+ result_dict = {
240
+ "output": result.output.model_dump(mode="json"),
241
+ "trace_id": result.trace_id,
242
+ "final_state": cleaned_final_state,
243
+ "messages": [msg.model_dump(mode="json") for msg in result.messages],
244
+ }
245
+ trace.add_event("run_completed", {"trace_id": trace_id, "result": result_dict})
246
+
247
+ # Export trace asynchronously (fire and forget)
248
+ if exporter and self.trace_dir:
249
+ # Ensure directory exists before async export
250
+ self.trace_dir.mkdir(parents=True, exist_ok=True)
251
+ # Create background task for trace export
252
+ export_task = asyncio.create_task(exporter.export(trace))
253
+ # Store task reference for potential cleanup
254
+ # Note: Task will complete in background, errors are logged by exporter
255
+ # For testing, we could await here, but in production we want fire-and-forget
256
+ # The exporter now creates the directory itself, so this should work
257
+
258
+ return result
259
+ except asyncio.CancelledError:
260
+ trace.add_event("cancelled", {"trace_id": trace_id})
261
+ # Cancel export task if it exists
262
+ if export_task and not export_task.done():
263
+ export_task.cancel()
264
+ try:
265
+ await export_task
266
+ except asyncio.CancelledError:
267
+ pass
268
+ if exporter:
269
+ await exporter.export(trace)
270
+ raise
271
+ except (RouteKitRuntimeError, ToolError, ModelError) as e:
272
+ # Re-raise known routkitai errors without wrapping
273
+ trace.add_event(
274
+ "error",
275
+ {
276
+ "error": str(e),
277
+ "error_type": type(e).__name__,
278
+ "context": getattr(e, "context", {}),
279
+ },
280
+ )
281
+ if export_task and not export_task.done():
282
+ export_task.cancel()
283
+ try:
284
+ await export_task
285
+ except asyncio.CancelledError:
286
+ pass
287
+ if exporter:
288
+ await exporter.export(trace)
289
+ raise
290
+ except Exception as e:
291
+ trace.add_event(
292
+ "error",
293
+ {
294
+ "error": str(e),
295
+ "error_type": type(e).__name__,
296
+ "context": {"trace_id": trace_id, "agent_name": agent_name},
297
+ },
298
+ )
299
+ # Cancel export task if it exists
300
+ if export_task and not export_task.done():
301
+ export_task.cancel()
302
+ try:
303
+ await export_task
304
+ except asyncio.CancelledError:
305
+ pass
306
+ if exporter:
307
+ await exporter.export(trace)
308
+ # Wrap unknown exceptions in RouteKitRuntimeError
309
+ raise RouteKitRuntimeError(
310
+ f"Runtime execution failed: {e}",
311
+ context={
312
+ "trace_id": trace_id,
313
+ "agent_name": agent_name,
314
+ "error_type": type(e).__name__,
315
+ },
316
+ ) from e
317
+
318
    async def _execute_steps(
        self,
        agent: "Agent",
        prompt: str,
        trace: Trace,
        policy: Policy | None,
        **kwargs: Any,
    ) -> "RunResult":
        """Execute agent using step-based policy loop.

        Repeatedly asks the policy for steps, executes them (possibly in
        parallel) and folds the results back into the conversation until
        the policy returns no steps or ``max_iterations`` is reached.

        Args:
            agent: Agent to execute
            prompt: User prompt
            trace: Trace for recording events
            policy: Policy for step execution (ReAct default when None)
            **kwargs: Additional parameters; ``max_iterations`` (default 50)
                bounds the loop

        Returns:
            RunResult
        """
        messages: list[Message] = [Message.user(prompt)]
        state: dict[str, Any] = {
            "memory": agent.memory,  # Make memory available to policies
            "runtime": self,  # Make runtime available for supervisor policy
        }

        # Use ReActPolicy as default if not provided
        if policy is None:
            # Imported lazily to avoid circular imports at module load time.
            from routekitai.core.policies import ReActPolicy
            from routekitai.core.policy_adapter import PolicyAdapter

            policy = PolicyAdapter(ReActPolicy())

        max_iterations = kwargs.get("max_iterations", 50)
        iteration = 0
        self._current_step = 0
        self._total_steps = max_iterations  # Estimate, will be updated as we go

        # Emit initial progress
        self._emit_progress(trace, "initialization")

        while iteration < max_iterations:
            # Check for cancellation (token is treated as a boolean flag).
            if self._cancellation_token:
                raise asyncio.CancelledError("Agent execution cancelled")

            # Update state with current iteration
            state["iteration"] = iteration
            self._current_step = iteration

            # Get next steps from policy
            steps = await policy.next_steps(agent, messages, state)

            if not steps:
                # No more steps, finalize: reuse the last assistant message
                # if there is one, otherwise ask the model for a closing reply.
                if messages and messages[-1].role == MessageRole.ASSISTANT:
                    output_message = messages[-1]
                else:
                    # Generate final response
                    final_response = await self._call_model(agent, messages, trace, stream=False)
                    assert isinstance(final_response, ModelResponse), (
                        "Expected ModelResponse when stream=False"
                    )
                    output_message = Message.assistant(final_response.content)
                    messages.append(output_message)
                break

            # Update total steps estimate if we have more steps
            if len(steps) > 0:
                self._total_steps = max(self._total_steps, iteration + len(steps))

            # Emit progress before executing steps
            step_type = steps[0].step_type if steps else None
            self._emit_progress(trace, step_type)

            # Execute steps (potentially in parallel)
            step_results = await self._execute_steps_parallel(steps, agent, trace)

            # Process step results, appending to the conversation per type.
            for step_result in step_results:
                if step_result.step_type == "model_call":
                    # Handle model response
                    if (
                        step_result.output_data
                        and isinstance(step_result.output_data, dict)
                        and "response" in step_result.output_data
                    ):
                        response_data = step_result.output_data.get("response", {})
                        content = (
                            response_data.get("content", "")
                            if isinstance(response_data, dict)
                            else ""
                        )
                        tool_calls_data = (
                            response_data.get("tool_calls", [])
                            if isinstance(response_data, dict)
                            else []
                        )

                        # Create assistant message with tool calls
                        tool_calls: list[dict[str, Any]] | None = None
                        if tool_calls_data and isinstance(tool_calls_data, list):
                            # Filter and validate tool calls: drop entries
                            # that are not dicts or lack a name; coerce
                            # id/name to str and non-dict arguments to {}.
                            tool_calls = [
                                {
                                    "id": str(tc.get("id", "")) if isinstance(tc, dict) else "",
                                    "name": str(tc.get("name", "")) if isinstance(tc, dict) else "",
                                    "arguments": tc.get("arguments", {})
                                    if isinstance(tc, dict)
                                    and isinstance(tc.get("arguments"), dict)
                                    else {},
                                }
                                for tc in tool_calls_data
                                if isinstance(tc, dict)
                                and tc.get("name")  # Only include valid tool calls with names
                            ]
                            # Set to None if empty after filtering
                            if not tool_calls:
                                tool_calls = None

                        messages.append(Message.assistant(content, tool_calls=tool_calls))

                elif step_result.step_type == "tool_call":
                    # Handle tool call result
                    if step_result.output_data and "result" in step_result.output_data:
                        tool_name = (
                            step_result.input_data.get("tool_name", "")
                            if step_result.input_data
                            else ""
                        )
                        tool_result = step_result.output_data["result"]

                        # Add tool result message
                        messages.append(
                            Message.tool(
                                f"Tool {tool_name} executed",
                                {"result": tool_result, "tool": tool_name},
                            )
                        )

                elif step_result.step_type == "subagent_call":
                    # Handle sub-agent call result (for supervisor policy)
                    if (
                        step_result.output_data
                        and isinstance(step_result.output_data, dict)
                        and "result" in step_result.output_data
                    ):
                        subagent_name = (
                            step_result.input_data.get("agent_name", "")
                            if step_result.input_data
                            else ""
                        )
                        subagent_result = step_result.output_data.get("result")

                        # Add sub-agent result message
                        if subagent_result is not None:
                            # Convert result to string for message content,
                            # preferring an explicit "content" field/attr.
                            if isinstance(subagent_result, str):
                                result_content = subagent_result
                            elif isinstance(subagent_result, dict) and "content" in subagent_result:
                                result_content = str(subagent_result["content"])
                            elif hasattr(subagent_result, "content"):
                                result_content = str(subagent_result.content)
                            else:
                                result_content = str(subagent_result)

                            messages.append(
                                Message.assistant(
                                    f"Sub-agent {subagent_name} completed: {result_content}"
                                )
                            )

                        # Update state for supervisor policy
                        state["subagent_result"] = {
                            "agent": subagent_name,
                            "output": subagent_result,
                            "trace_id": step_result.output_data.get("trace_id"),
                        }
                        state["waiting_for_subagent"] = False

            iteration += 1

        # Finalize if we hit max iterations (same fallback as the
        # no-more-steps branch above).
        if iteration >= max_iterations:
            if messages and messages[-1].role == MessageRole.ASSISTANT:
                output_message = messages[-1]
            else:
                final_response = await self._call_model(agent, messages, trace, stream=False)
                assert isinstance(final_response, ModelResponse), (
                    "Expected ModelResponse when stream=False"
                )
                output_message = Message.assistant(final_response.content)
                messages.append(output_message)

        # Import here to avoid circular import
        from routekitai.core.agent import RunResult

        return RunResult(
            output=output_message,
            trace_id=trace.trace_id,
            final_state=state,
            messages=messages,
        )
521
+
522
+ async def _execute_steps_parallel(
523
+ self, steps: list[Step], agent: "Agent", trace: Trace
524
+ ) -> list[Step]:
525
+ """Execute steps in parallel with concurrency control.
526
+
527
+ Args:
528
+ steps: Steps to execute
529
+ agent: Agent instance
530
+ trace: Trace for recording
531
+
532
+ Returns:
533
+ List of completed steps with outputs (may include steps with errors)
534
+
535
+ Note:
536
+ If a step fails, it will have step.error set and the exception will be
537
+ captured. Other steps will continue executing. The first error will be
538
+ raised after all steps complete.
539
+ """
540
+ semaphore = asyncio.Semaphore(self.max_concurrency)
541
+ tasks = [self._execute_step(step, agent, trace, semaphore) for step in steps]
542
+ # Use return_exceptions=True to collect all results, even if some fail
543
+ results: list[Step | BaseException] = await asyncio.gather(*tasks, return_exceptions=True)
544
+
545
+ # Convert exceptions to steps with errors
546
+ completed_steps: list[Step] = []
547
+ first_error: BaseException | None = None
548
+
549
+ for i, result in enumerate(results):
550
+ if isinstance(result, BaseException):
551
+ # Create a step with error
552
+ step = steps[i]
553
+ step.error = str(result)
554
+ step.output_data = None
555
+ completed_steps.append(step)
556
+ if first_error is None:
557
+ first_error = result
558
+ elif isinstance(result, Step):
559
+ completed_steps.append(result)
560
+ else:
561
+ # Unexpected type - wrap in error
562
+ step = steps[i]
563
+ step.error = f"Unexpected result type: {type(result).__name__}"
564
+ step.output_data = None
565
+ completed_steps.append(step)
566
+ if first_error is None:
567
+ first_error = RuntimeError(f"Unexpected result type: {type(result).__name__}")
568
+
569
+ # Raise first error if any occurred
570
+ if first_error:
571
+ raise first_error
572
+
573
+ return completed_steps
574
+
575
    async def _execute_step(
        self, step: Step, agent: "Agent", trace: Trace, semaphore: asyncio.Semaphore
    ) -> Step:
        """Execute a single step (model call, tool call or sub-agent call).

        In replay mode the step is satisfied from previously recorded trace
        events instead of calling the live model/tool.

        Args:
            step: Step to execute
            agent: Agent instance
            trace: Trace for recording
            semaphore: Semaphore for concurrency control

        Returns:
            Completed step with output (``output_data`` and ``latency_ms``
            set; ``error`` set on failure before the exception propagates)
        """
        async with semaphore:
            start_time = time.time()

            # Add step_started event for trace completeness
            trace.add_event(
                "step_started",
                {
                    "step_id": step.step_id,
                    "step_type": step.step_type,
                    "input_data": step.input_data,
                },
            )

            try:
                if step.step_type == "model_call":
                    # Check if in replay mode
                    # NOTE(review): the _replay_* event lists used below are
                    # populated by replay setup code outside this view.
                    if self._replay_mode and self._replay_trace:
                        # Match by sequential order using step_completed events.
                        # model_call_count = how many model calls we have
                        # already replayed in this run.
                        model_call_count = len(
                            [s for s in self._replay_step_index_history if s == "model_call"]
                        )
                        # Find the N-th recorded model_call step event.
                        matching_step_event = None
                        step_count = 0
                        for step_event in self._replay_step_events:
                            if step_event.data.get("step_type") == "model_call":
                                if step_count == model_call_count:
                                    matching_step_event = step_event
                                    break
                                step_count += 1

                        if matching_step_event:
                            # Find the corresponding model_called event by step_id
                            step_id = matching_step_event.data.get("step_id")
                            matching_event = next(
                                (
                                    e
                                    for e in self._replay_model_events
                                    if e.data.get("step_id") == step_id
                                ),
                                None,
                            )
                            if not matching_event and model_call_count < len(
                                self._replay_model_events
                            ):
                                # Fallback: use sequential order
                                matching_event = self._replay_model_events[model_call_count]
                        elif model_call_count < len(self._replay_model_events):
                            # Fallback: use sequential order
                            matching_event = self._replay_model_events[model_call_count]
                        else:
                            # Live run is attempting more model calls than
                            # the recorded trace contains.
                            raise ReplayMismatchError(
                                f"Replay mismatch: expected model call at index {model_call_count}, "
                                f"but only {len(self._replay_model_events)} model call events in trace",
                                context={
                                    "step_id": step.step_id,
                                    "trace_id": trace.trace_id,
                                    "expected_index": model_call_count,
                                    "available_events": len(self._replay_model_events),
                                },
                            )

                        if not matching_event:
                            raise ReplayMismatchError(
                                "Replay mismatch: could not find matching model call event",
                                context={
                                    "step_id": step.step_id,
                                    "trace_id": trace.trace_id,
                                    "model_call_count": model_call_count,
                                },
                            )

                        # Record that one more model call has been replayed.
                        self._replay_step_index_history.append("model_call")
                        response_data = matching_event.data.get("response", {})
                    else:
                        # Call model (check cancellation)
                        if self._cancellation_token:
                            raise asyncio.CancelledError("Model call cancelled")
                        # Validate and extract messages
                        messages_data = step.input_data.get("messages", [])
                        if not isinstance(messages_data, list):
                            raise RouteKitRuntimeError(
                                f"Invalid messages data in step: expected list, got {type(messages_data).__name__}",
                                context={"step_id": step.step_id, "step_type": step.step_type},
                            )
                        # Convert dict messages to Message objects if needed
                        # (optimize: avoid conversion if already Message)
                        messages: list[Message] = []
                        for msg_data in messages_data:
                            if isinstance(msg_data, Message):
                                messages.append(msg_data)
                            elif isinstance(msg_data, dict):
                                try:
                                    messages.append(Message(**msg_data))
                                except Exception as e:
                                    raise RouteKitRuntimeError(
                                        f"Invalid message format in step: {e}",
                                        context={
                                            "step_id": step.step_id,
                                            "step_type": step.step_type,
                                            "message_data": str(msg_data)[:100],
                                        },
                                    ) from e
                            else:
                                raise RouteKitRuntimeError(
                                    f"Invalid message format in step: expected Message or dict, got {type(msg_data).__name__}",
                                    context={"step_id": step.step_id, "step_type": step.step_type},
                                )
                        response = await self._call_model(agent, messages, trace, stream=False)
                        assert isinstance(response, ModelResponse), (
                            "Expected ModelResponse when stream=False"
                        )
                        # Flatten the response into plain JSON-friendly data.
                        response_data = {
                            "content": response.content,
                            "tool_calls": [
                                {
                                    "id": tc.id,
                                    "name": tc.name,
                                    "arguments": tc.arguments,
                                }
                                for tc in (response.tool_calls or [])
                            ],
                            "usage": response.usage.model_dump() if response.usage else None,
                        }

                    step.output_data = {"response": response_data}
                    trace.add_event(
                        "model_called",
                        {
                            "step_id": step.step_id,
                            "response": response_data,
                        },
                    )

                elif step.step_type == "tool_call":
                    # Check if in replay mode for tool calls
                    if self._replay_mode and self._replay_trace:
                        # Match tool calls by sequential order
                        tool_call_index = len(
                            [s for s in self._replay_step_index_history if s == "tool_call"]
                        )
                        if tool_call_index >= len(self._replay_tool_call_events):
                            raise ReplayMismatchError(
                                f"Replay mismatch: expected tool call at index {tool_call_index}, "
                                f"but only {len(self._replay_tool_call_events)} tool call events in trace",
                                context={
                                    "step_id": step.step_id,
                                    "trace_id": trace.trace_id,
                                    "tool_name": step.input_data.get("tool_name", "unknown"),
                                    "expected_index": tool_call_index,
                                },
                            )
                        matching_call = self._replay_tool_call_events[tool_call_index]
                        tool_name = matching_call.data.get("tool", "")

                        # Match tool result by sequential order (tool results
                        # should be in same order as tool calls)
                        if tool_call_index >= len(self._replay_tool_result_events):
                            # No result available - might be an error case
                            tool_result = ""
                        else:
                            matching_result = self._replay_tool_result_events[tool_call_index]
                            tool_result = matching_result.data.get("result", "")

                        step.output_data = {"result": tool_result}
                        # Track that we processed a tool call
                        self._replay_step_index_history.append("tool_call")
                    else:
                        tool_name = step.input_data.get("tool_name")
                        tool_args = step.input_data.get("tool_args", {})

                        # Validate tool_name
                        if not tool_name:
                            raise RouteKitRuntimeError(
                                "Missing tool_name in step input_data",
                                context={"step_id": step.step_id, "step_type": step.step_type},
                            )
                        if not isinstance(tool_args, dict):
                            raise RouteKitRuntimeError(
                                f"Invalid tool_args in step: expected dict, got {type(tool_args).__name__}",
                                context={
                                    "step_id": step.step_id,
                                    "step_type": step.step_type,
                                    "tool_name": tool_name,
                                },
                            )

                        # Find tool
                        tool = next((t for t in agent.tools if t.name == tool_name), None)
                        if not tool:
                            step.error = f"Tool {tool_name} not found"
                            raise ToolError(
                                f"Tool '{tool_name}' not found in agent '{agent.name}'",
                                context={
                                    "agent_name": agent.name,
                                    "tool_name": tool_name,
                                    "step_id": step.step_id,
                                },
                            )

                        # Execute tool (pass agent for agent-level filters)
                        try:
                            tool_result = await self._execute_tool(
                                tool, tool_args, trace, step.step_id, agent=agent
                            )
                            step.output_data = {"result": tool_result}
                        except ToolError as e:
                            # Wrap ToolError from tool filters/approval gates
                            # in RouteKitRuntimeError for consistency.
                            # NOTE(review): matching on message substrings is
                            # fragile — confirm against the filter error texts.
                            error_msg = str(e)
                            if (
                                "not allowed" in error_msg
                                or "filtered" in error_msg
                                or "requires approval" in error_msg
                                or "blocked" in error_msg
                            ):
                                raise RouteKitRuntimeError(
                                    error_msg, context=getattr(e, "context", {})
                                ) from e
                            raise

                elif step.step_type == "subagent_call":
                    # Execute sub-agent (for supervisor policy)
                    subagent_name = step.input_data.get("agent_name")
                    prompt = step.input_data.get("prompt", "")

                    # Validate subagent_name
                    if not subagent_name:
                        raise RouteKitRuntimeError(
                            "Missing agent_name in subagent_call step",
                            context={"step_id": step.step_id, "step_type": step.step_type},
                        )

                    if subagent_name not in self.agents:
                        raise RouteKitRuntimeError(
                            f"Sub-agent '{subagent_name}' not found",
                            context={
                                "step_id": step.step_id,
                                "step_type": step.step_type,
                                "agent_name": subagent_name,
                            },
                        )

                    # Execute sub-agent via a full nested run (own trace).
                    subagent_result = await self.run(subagent_name, prompt)
                    step.output_data = {
                        "result": subagent_result.output.content,
                        "trace_id": subagent_result.trace_id,
                        "messages": [m.model_dump() for m in subagent_result.messages],
                    }

                trace.add_event(
                    "step_completed",
                    {
                        "step_id": step.step_id,
                        "step_type": step.step_type,
                        "latency_ms": (time.time() - start_time) * 1000,
                    },
                )

            except (RouteKitRuntimeError, ToolError, ModelError, asyncio.CancelledError) as e:
                # Re-raise known errors after recording them on the step/trace.
                step.error = str(e)
                trace.add_event(
                    "error",
                    {
                        "step_id": step.step_id,
                        "error": str(e),
                        "error_type": type(e).__name__,
                        "context": getattr(e, "context", {}),
                    },
                )
                raise
            except Exception as e:
                # Wrap unknown exceptions, preserving context
                step.error = str(e)
                error_context = {
                    "step_id": step.step_id,
                    "step_type": step.step_type,
                }
                # Preserve context from original exception if available
                if hasattr(e, "context") and isinstance(e.context, dict):
                    error_context.update(e.context)

                trace.add_event(
                    "error",
                    {
                        "step_id": step.step_id,
                        "error": str(e),
                        "error_type": type(e).__name__,
                        "context": error_context,
                    },
                )
                raise RouteKitRuntimeError(
                    f"Step execution failed: {e}", context=error_context
                ) from e

            finally:
                # Always record wall-clock latency, even on failure.
                step.latency_ms = (time.time() - start_time) * 1000

            return step
887
+
888
+ async def _call_model(
889
+ self,
890
+ agent: "Agent",
891
+ messages: list[Message],
892
+ trace: Trace,
893
+ stream: bool = False,
894
+ ) -> ModelResponse | AsyncIterator[StreamEvent]:
895
+ """Call the agent's model.
896
+
897
+ Args:
898
+ agent: Agent instance
899
+ messages: Messages to send
900
+ trace: Trace for recording
901
+ stream: Whether to stream the response
902
+
903
+ Returns:
904
+ Model response or stream of events
905
+ """
906
+ start_time = time.time()
907
+ try:
908
+ response = await agent.model.chat(messages, tools=agent.tools, stream=stream)
909
+ latency_ms = (time.time() - start_time) * 1000
910
+
911
+ if stream:
912
+ # Return streaming iterator
913
+ assert isinstance(response, AsyncIterator), (
914
+ "Expected AsyncIterator when stream=True"
915
+ )
916
+
917
+ async def stream_wrapper() -> AsyncIterator[StreamEvent]:
918
+ content_buffer = ""
919
+ tool_calls_buffer: list[dict[str, Any]] = []
920
+ usage = None
921
+
922
+ async for event in response:
923
+ # Forward stream events to trace
924
+ if event.content:
925
+ content_buffer += event.content or ""
926
+ if event.tool_calls:
927
+ tool_calls_buffer.extend(
928
+ [
929
+ {
930
+ "id": tc.id,
931
+ "name": tc.name,
932
+ "arguments": tc.arguments,
933
+ }
934
+ for tc in (event.tool_calls or [])
935
+ ]
936
+ )
937
+ if event.usage:
938
+ usage = event.usage
939
+
940
+ # Emit streaming event to trace
941
+ trace.add_event(
942
+ "model_stream_chunk",
943
+ {
944
+ "model": agent.model.name,
945
+ "chunk": event.content or "",
946
+ "event_type": event.type,
947
+ },
948
+ )
949
+
950
+ yield event
951
+
952
+ # Emit final model_called event with complete response
953
+ trace.add_event(
954
+ "model_called",
955
+ {
956
+ "model": agent.model.name,
957
+ "messages_count": len(messages),
958
+ "response": {
959
+ "content": content_buffer,
960
+ "tool_calls": tool_calls_buffer,
961
+ "usage": usage.model_dump() if usage else None,
962
+ },
963
+ "latency_ms": latency_ms,
964
+ "streamed": True,
965
+ },
966
+ )
967
+
968
+ return stream_wrapper()
969
+
970
+ # Non-streaming mode
971
+ if not isinstance(response, ModelResponse):
972
+ raise ModelError("Model returned a stream when stream=False was requested")
973
+
974
+ trace.add_event(
975
+ "model_called",
976
+ {
977
+ "model": agent.model.name,
978
+ "messages_count": len(messages),
979
+ "response": {
980
+ "content": response.content,
981
+ "tool_calls": [
982
+ {
983
+ "id": tc.id,
984
+ "name": tc.name,
985
+ "arguments": tc.arguments,
986
+ }
987
+ for tc in (response.tool_calls or [])
988
+ ],
989
+ "usage": response.usage.model_dump() if response.usage else None,
990
+ },
991
+ "latency_ms": latency_ms,
992
+ },
993
+ )
994
+
995
+ return response
996
+ except (ModelError, asyncio.CancelledError) as e:
997
+ # Re-raise model errors and cancellations
998
+ trace.add_event(
999
+ "error",
1000
+ {
1001
+ "error": str(e),
1002
+ "error_type": type(e).__name__,
1003
+ "context": {"model": agent.model.name, "context": "model_call"},
1004
+ },
1005
+ )
1006
+ raise
1007
+ except Exception as e:
1008
+ # Wrap unknown exceptions
1009
+ trace.add_event(
1010
+ "error",
1011
+ {
1012
+ "error": str(e),
1013
+ "error_type": type(e).__name__,
1014
+ "context": {"model": agent.model.name, "context": "model_call"},
1015
+ },
1016
+ )
1017
+ raise ModelError(
1018
+ f"Model call failed: {e}",
1019
+ context={"model": agent.model.name, "error_type": type(e).__name__},
1020
+ ) from e
1021
+
1022
+ async def _execute_tool(
1023
+ self,
1024
+ tool: Tool,
1025
+ tool_args: dict[str, Any],
1026
+ trace: Trace,
1027
+ step_id: str | None = None,
1028
+ agent: "Agent | None" = None,
1029
+ ) -> Any:
1030
+ """Execute a tool with permission checks, retries, and timeout.
1031
+
1032
+ Args:
1033
+ tool: Tool to execute
1034
+ tool_args: Tool arguments
1035
+ trace: Trace for recording
1036
+ step_id: Optional step ID
1037
+ agent: Optional agent instance (for agent-level filters)
1038
+
1039
+ Returns:
1040
+ Tool execution result
1041
+
1042
+ Raises:
1043
+ ToolError: If tool execution fails
1044
+ """
1045
+ # Check tool filter (allow/deny list) - agent level first, then runtime level
1046
+ # Agent-level filter takes precedence - if agent has a filter, only check that
1047
+ if agent and agent.tool_filter:
1048
+ if not agent.tool_filter.is_allowed(tool.name):
1049
+ error_msg = f"Tool {tool.name} is not allowed (filtered by agent policy)"
1050
+ trace.add_event(
1051
+ "tool_called",
1052
+ {
1053
+ "tool": tool.name,
1054
+ "step_id": step_id,
1055
+ "error": error_msg,
1056
+ },
1057
+ )
1058
+ raise ToolError(
1059
+ error_msg,
1060
+ context={"tool_name": tool.name, "step_id": step_id, "filter_level": "agent"},
1061
+ )
1062
+ # Runtime-level filter (only checked if agent doesn't have a filter)
1063
+ elif self.policy_hooks and self.policy_hooks.tool_filter:
1064
+ if not self.policy_hooks.tool_filter.is_allowed(tool.name):
1065
+ error_msg = f"Tool {tool.name} is not allowed (filtered by runtime policy)"
1066
+ trace.add_event(
1067
+ "tool_called",
1068
+ {
1069
+ "tool": tool.name,
1070
+ "step_id": step_id,
1071
+ "error": error_msg,
1072
+ },
1073
+ )
1074
+ raise ToolError(
1075
+ error_msg,
1076
+ context={"tool_name": tool.name, "step_id": step_id, "filter_level": "agent"},
1077
+ )
1078
+
1079
+ # Check approval gate
1080
+ if self.policy_hooks and self.policy_hooks.approval_gate:
1081
+ # Convert ToolPermission enum to strings for approval gate
1082
+ permission_strings = [
1083
+ p.value if hasattr(p, "value") else str(p) for p in (tool.permissions or [])
1084
+ ]
1085
+ if self.policy_hooks.approval_gate.requires_approval(
1086
+ tool.name, tool_args, permission_strings
1087
+ ):
1088
+ if not self.policy_hooks.approval_gate.is_approved(tool.name, tool_args):
1089
+ error_msg = f"Tool {tool.name} requires approval (blocked by approval gate)"
1090
+ trace.add_event(
1091
+ "tool_called",
1092
+ {
1093
+ "tool": tool.name,
1094
+ "step_id": step_id,
1095
+ "error": error_msg,
1096
+ },
1097
+ )
1098
+ raise ToolError(
1099
+ error_msg,
1100
+ context={
1101
+ "tool_name": tool.name,
1102
+ "step_id": step_id,
1103
+ "filter_level": "agent",
1104
+ },
1105
+ )
1106
+
1107
+ # Check permissions
1108
+ if self.permission_manager:
1109
+ # Check if tool requires permissions
1110
+ if tool.permissions:
1111
+ # Check each required permission
1112
+ for perm in tool.permissions:
1113
+ if not self.permission_manager.check_permission(perm, tool.name):
1114
+ error_msg = f"Permission denied for tool {tool.name} (requires {perm})"
1115
+ trace.add_event(
1116
+ "tool_called",
1117
+ {
1118
+ "tool": tool.name,
1119
+ "step_id": step_id,
1120
+ "error": error_msg,
1121
+ },
1122
+ )
1123
+ raise ToolError(
1124
+ error_msg,
1125
+ context={
1126
+ "tool_name": tool.name,
1127
+ "step_id": step_id,
1128
+ "filter_level": "agent",
1129
+ },
1130
+ )
1131
+
1132
+ # Redact sensitive fields before logging
1133
+ redacted_args = tool.redact_data(tool_args)
1134
+
1135
+ # Apply PII redaction hook if configured
1136
+ if self.policy_hooks and self.policy_hooks.pii_redaction:
1137
+ redacted_args = self.policy_hooks.pii_redaction.redact_dict(redacted_args)
1138
+
1139
+ trace.add_event(
1140
+ "tool_called",
1141
+ {
1142
+ "tool": tool.name,
1143
+ "arguments": redacted_args,
1144
+ "step_id": step_id,
1145
+ },
1146
+ )
1147
+
1148
+ start_time = time.time()
1149
+ timeout = tool.timeout or self.timeout
1150
+
1151
+ # Retry logic with exponential backoff
1152
+ last_error: Exception | None = None
1153
+ for attempt in range(self.max_retries + 1):
1154
+ try:
1155
+ # Check for cancellation
1156
+ if self._cancellation_token:
1157
+ raise asyncio.CancelledError("Tool execution cancelled")
1158
+
1159
+ if timeout:
1160
+ result = await asyncio.wait_for(tool.execute(**tool_args), timeout=timeout)
1161
+ else:
1162
+ result = await tool.execute(**tool_args)
1163
+
1164
+ latency_ms = (time.time() - start_time) * 1000
1165
+ trace.add_event(
1166
+ "tool_result",
1167
+ {
1168
+ "tool": tool.name,
1169
+ "result": str(result),
1170
+ "step_id": step_id,
1171
+ "latency_ms": latency_ms,
1172
+ },
1173
+ )
1174
+
1175
+ return result
1176
+
1177
+ except asyncio.CancelledError:
1178
+ # Don't retry on cancellation
1179
+ raise
1180
+ except TimeoutError as e:
1181
+ last_error = e
1182
+ if attempt < self.max_retries:
1183
+ # Exponential backoff: base * (2^attempt), capped at max
1184
+ backoff_delay = min(
1185
+ self.retry_backoff_base * (2**attempt), self.retry_backoff_max
1186
+ )
1187
+ await asyncio.sleep(backoff_delay)
1188
+ continue
1189
+ raise ToolError(
1190
+ f"Tool '{tool.name}' timed out after {timeout}s",
1191
+ context={"tool_name": tool.name, "timeout": timeout, "step_id": step_id},
1192
+ ) from e
1193
+ except ToolError as e:
1194
+ # Retry ToolError from execution failures (not validation/permission errors)
1195
+ # Check if it's a retryable error by examining the message
1196
+ error_msg = str(e).lower()
1197
+ is_retryable = (
1198
+ (
1199
+ "execution failed" in error_msg
1200
+ or "intentional failure" in error_msg
1201
+ or "failed:" in error_msg
1202
+ )
1203
+ and "validation" not in error_msg
1204
+ and "permission" not in error_msg
1205
+ )
1206
+
1207
+ if is_retryable and attempt < self.max_retries:
1208
+ last_error = e
1209
+ # Exponential backoff
1210
+ backoff_delay = min(
1211
+ self.retry_backoff_base * (2**attempt), self.retry_backoff_max
1212
+ )
1213
+ await asyncio.sleep(backoff_delay)
1214
+ continue
1215
+ # Don't retry validation/permission errors or if retries exhausted
1216
+ raise
1217
+ except Exception as e:
1218
+ last_error = e
1219
+ if attempt < self.max_retries:
1220
+ # Exponential backoff
1221
+ backoff_delay = min(
1222
+ self.retry_backoff_base * (2**attempt), self.retry_backoff_max
1223
+ )
1224
+ await asyncio.sleep(backoff_delay)
1225
+ continue
1226
+ # Wrap unknown exceptions in ToolError
1227
+ raise ToolError(
1228
+ f"Tool {tool.name} failed: {e}",
1229
+ context={"tool_name": tool.name, "attempt": attempt + 1, "step_id": step_id},
1230
+ ) from e
1231
+
1232
+ raise ToolError(
1233
+ f"Tool {tool.name} failed after {self.max_retries} retries",
1234
+ context={"tool_name": tool.name, "max_retries": self.max_retries, "step_id": step_id},
1235
+ ) from last_error
1236
+
1237
+ async def replay(
1238
+ self,
1239
+ trace_id: str,
1240
+ agent_name: str,
1241
+ verify_output: bool = True,
1242
+ strict: bool = True,
1243
+ ) -> "RunResult":
1244
+ """Replay a trace with deterministic execution.
1245
+
1246
+ This method loads a trace file and re-executes the agent run using
1247
+ the recorded model responses and tool results. This enables:
1248
+ - Deterministic testing of agent behavior
1249
+ - Debugging failed runs
1250
+ - Reproducing issues in production
1251
+
1252
+ Args:
1253
+ trace_id: Trace ID to replay
1254
+ agent_name: Agent name to use for replay (must match original agent)
1255
+ verify_output: If True, verify replay output matches original
1256
+ strict: If True, raise ReplayMismatchError on any mismatch
1257
+
1258
+ Returns:
1259
+ RunResult from replay
1260
+
1261
+ Raises:
1262
+ RouteKitRuntimeError: If trace not found or agent mismatch
1263
+ ReplayMismatchError: If replay encounters a mismatch and strict=True
1264
+ """
1265
+ if agent_name not in self.agents:
1266
+ raise RouteKitRuntimeError(f"Agent {agent_name} not found")
1267
+
1268
+ # Load trace
1269
+ if not self.trace_dir:
1270
+ raise RouteKitRuntimeError(
1271
+ "trace_dir must be set for replay",
1272
+ context={"trace_id": trace_id, "agent_name": agent_name},
1273
+ )
1274
+
1275
+ exporter = JSONLExporter(output_dir=self.trace_dir)
1276
+ trace = await exporter.load(trace_id)
1277
+ if not trace:
1278
+ raise RouteKitRuntimeError(
1279
+ f"Trace {trace_id} not found",
1280
+ context={"trace_id": trace_id, "trace_dir": str(self.trace_dir)},
1281
+ )
1282
+
1283
+ # Verify agent matches original
1284
+ run_started = trace.get_events_by_type("run_started")
1285
+ if run_started:
1286
+ original_agent = run_started[0].data.get("agent", "")
1287
+ if original_agent and original_agent != agent_name:
1288
+ if strict:
1289
+ raise ReplayMismatchError(
1290
+ f"Agent mismatch: trace was for '{original_agent}', "
1291
+ f"replay requested '{agent_name}'"
1292
+ )
1293
+
1294
+ # Set replay mode
1295
+ self._replay_mode = True
1296
+ self._replay_trace = trace
1297
+ self._replay_step_index = 0
1298
+ self._replay_step_index_history: list[str] = [] # Track step types for sequential matching
1299
+ # Build ordered lists of events by type for sequential matching
1300
+ self._replay_model_events = trace.get_events_by_type("model_called")
1301
+ self._replay_tool_call_events = trace.get_events_by_type("tool_called")
1302
+ self._replay_tool_result_events = trace.get_events_by_type("tool_result")
1303
+ # Build step_completed events in order for unified step matching
1304
+ self._replay_step_events = [e for e in trace.events if e.type == "step_completed"]
1305
+
1306
+ try:
1307
+ # Extract prompt from trace
1308
+ if not run_started:
1309
+ raise RouteKitRuntimeError(
1310
+ "Trace missing run_started event", context={"trace_id": trace_id}
1311
+ )
1312
+ prompt = run_started[0].data.get("prompt", "")
1313
+
1314
+ # Replay execution
1315
+ agent = self.agents[agent_name]
1316
+ result = await self._execute_steps(agent, prompt, trace, None)
1317
+
1318
+ # Verify output if requested
1319
+ if verify_output:
1320
+ run_completed = trace.get_events_by_type("run_completed")
1321
+ if run_completed:
1322
+ original_result = run_completed[0].data.get("result", {})
1323
+ original_output = original_result.get("output", {}).get("content", "")
1324
+ if original_output != result.output.content:
1325
+ error_msg = (
1326
+ f"Output mismatch: original='{original_output}', "
1327
+ f"replay='{result.output.content}'"
1328
+ )
1329
+ if strict:
1330
+ raise ReplayMismatchError(
1331
+ error_msg,
1332
+ context={
1333
+ "trace_id": trace_id,
1334
+ "original_output": original_output,
1335
+ "replay_output": result.output.content,
1336
+ },
1337
+ )
1338
+ else:
1339
+ # Log warning but continue
1340
+ trace.add_event("replay_warning", {"message": error_msg})
1341
+
1342
+ return result
1343
+
1344
+ finally:
1345
+ self._replay_mode = False
1346
+ self._replay_trace = None
1347
+ self._replay_step_index = 0
1348
+ self._replay_step_index_history = []
1349
+ self._replay_model_events = []
1350
+ self._replay_tool_call_events = []
1351
+ self._replay_tool_result_events = []
1352
+
1353
+
1354
class DefaultPolicy(Policy):
    """Default policy for simple agent execution."""

    async def next_steps(
        self,
        agent: "Agent",
        messages: list[Message],
        state: dict[str, Any],
    ) -> list[Step]:
        """Default policy: call model, then execute any tool calls.

        Args:
            agent: Agent instance
            messages: Current messages
            state: Current state

        Returns:
            List of steps
        """
        # A model call is needed whenever the conversation is empty or the
        # latest message did not come from the assistant.
        needs_model_call = not messages or messages[-1].role != MessageRole.ASSISTANT
        if needs_model_call:
            model_step = Step(
                step_id=str(uuid.uuid4()),
                step_type="model_call",
                input_data={"messages": [msg.model_dump() for msg in messages]},
            )
            return [model_step]

        # The assistant spoke last: schedule one tool_call step per pending
        # tool call, if any.
        pending_calls = messages[-1].tool_calls
        if not pending_calls:
            # Assistant replied without requesting tools: nothing left to do.
            return []

        return [
            Step(
                step_id=str(uuid.uuid4()),
                step_type="tool_call",
                input_data={
                    "tool_name": call["name"],
                    "tool_args": call.get("arguments", {}),
                },
            )
            for call in pending_calls
        ]