ragbits-evaluate 0.0.30rc1__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,718 +1,333 @@
  """Conversation orchestration for agent simulation scenarios."""

- import sys
- from collections.abc import Awaitable, Callable
- from dataclasses import dataclass, field
  from datetime import datetime, timezone
- from typing import IO, Any
- from uuid import uuid4

  from ragbits.agents.tool import ToolCallResult
  from ragbits.chat.interface import ChatInterface
- from ragbits.chat.interface.types import (
-     ChatContext,
-     ChatResponse,
-     ConversationIdResponse,
-     StateUpdateResponse,
-     TextResponse,
- )
+ from ragbits.chat.interface.types import ChatContext
  from ragbits.core.llms import Usage
- from ragbits.evaluate.agent_simulation.checkers import CheckerContext, CheckerResult, run_checkers
+ from ragbits.evaluate.agent_simulation.context import DataSnapshot, DomainContext
+ from ragbits.evaluate.agent_simulation.deepeval_evaluator import DeepEvalEvaluator
  from ragbits.evaluate.agent_simulation.logger import ConversationLogger
- from ragbits.evaluate.agent_simulation.metrics.builtin import (
-     LatencyMetricCollector,
-     TokenUsageMetricCollector,
-     ToolUsageMetricCollector,
- )
- from ragbits.evaluate.agent_simulation.metrics.collectors import CompositeMetricCollector
- from ragbits.evaluate.agent_simulation.models import Personality, Scenario, SimulationConfig, Task, Turn
+ from ragbits.evaluate.agent_simulation.metrics.collectors import CompositeMetricCollector, MetricCollector
+ from ragbits.evaluate.agent_simulation.models import Personality, Scenario, Turn
  from ragbits.evaluate.agent_simulation.results import (
-     CheckerResultItem,
      ConversationMetrics,
      SimulationResult,
      SimulationStatus,
      TaskResult,
      TurnResult,
  )
- from ragbits.evaluate.agent_simulation.simulation import SimulatedUser, build_llm
- from ragbits.evaluate.agent_simulation.tracing import MemoryTraceHandler, TraceAnalyzer, collect_traces
+ from ragbits.evaluate.agent_simulation.simulation import GoalChecker, SimulatedUser, ToolUsageChecker, build_llm

- ProgressCallback = Callable[[str, Any], Awaitable[None]]

-
- def _serialize_response_chunk(chunk: ChatResponse | ToolCallResult | Usage | object) -> tuple[str, dict[str, Any]]:
-     """Serialize a response chunk to (type, data) tuple.
+ def _evaluate_with_deepeval(history: list[Turn], logger: ConversationLogger) -> dict[str, float]:
+     """Evaluate conversation with DeepEval metrics.

      Args:
-         chunk: The chunk to serialize (TextResponse, ToolCallResult, Usage, etc.)
+         history: List of conversation turns to evaluate
+         logger: Logger instance to record evaluation results

      Returns:
-         Tuple of (chunk_type, chunk_data_dict)
+         Dictionary of metric names to scores
      """
-     if isinstance(chunk, ChatResponse):
-         return (chunk.get_type(), chunk.content.model_dump())
-     elif isinstance(chunk, ToolCallResult):
-         return (
-             "tool_call",
-             {
-                 "name": chunk.name,
-                 "arguments": chunk.arguments,
-                 "result": chunk.result,
-             },
-         )
-     elif isinstance(chunk, Usage):
-         return (
-             "usage",
-             {
-                 "prompt_tokens": chunk.prompt_tokens,
-                 "completion_tokens": chunk.completion_tokens,
-                 "total_tokens": chunk.total_tokens,
-                 "estimated_cost": chunk.estimated_cost,
-                 "n_requests": chunk.n_requests,
-             },
-         )
-     else:
-         return ("unknown", {"raw": str(chunk), "type": type(chunk).__name__})
-
+     scores: dict[str, float] = {}
+     if not history:
+         return scores
+
+     print("\n=== Running DeepEval Evaluation ===")
+     deepeval_evaluator = DeepEvalEvaluator()
+     try:
+         evaluation_results = deepeval_evaluator.evaluate_conversation(history)
+         logger.log_deepeval_metrics(evaluation_results)
+         for metric_name, result in evaluation_results.items():
+             score = result.get("score")
+             if score is not None:
+                 print(f"{metric_name}: {score:.4f}")
+                 scores[metric_name] = float(score)
+             else:
+                 error = result.get("error", "Unknown error")
+                 print(f"{metric_name}: Error - {error}")
+     except Exception as e:
+         print(f"Error during DeepEval evaluation: {e}")
+         logger.log_deepeval_metrics({"error": str(e)})
+
+     return scores
+
+
+ async def run_simulation(  # noqa: PLR0912, PLR0915
+     scenario: Scenario,
+     chat: ChatInterface,
+     max_turns_scenario: int = 15,
+     max_turns_task: int | None = 4,
+     log_file: str | None = None,
+     agent_model_name: str | None = None,
+     sim_user_model_name: str | None = None,
+     checker_model_name: str | None = None,
+     default_model: str = "gpt-4o-mini",
+     api_key: str = "",
+     user_message_prefix: str = "",
+     personality: Personality | None = None,
+     domain_context: DomainContext | None = None,
+     data_snapshot: DataSnapshot | None = None,
+     metric_collectors: list[MetricCollector] | None = None,
+ ) -> SimulationResult:
+     """Run a conversation between an agent and a simulated user.

- def _build_checker_results(checker_result: CheckerResult) -> tuple[list[CheckerResultItem], str]:
-     """Build checker result items from a CheckerResult.
+     The conversation proceeds through tasks defined in the scenario, with a goal checker
+     determining when each task is completed. Returns structured results for programmatic access.

      Args:
-         checker_result: The result from running checkers
+         scenario: The scenario containing tasks to complete
+         chat: An instantiated ChatInterface used to drive the assistant side of the conversation.
+         max_turns_scenario: Maximum number of conversation turns for the entire scenario
+         max_turns_task: Maximum number of conversation turns per task (None for no limit)
+         log_file: Optional path to log file
+         agent_model_name: Optional override for agent LLM model name
+         sim_user_model_name: Optional override for simulated user LLM model name
+         checker_model_name: Optional override for goal checker LLM model name
+         default_model: Default LLM model name
+         api_key: API key for LLM
+         user_message_prefix: Optional prefix to add to user messages before sending to agent
+         personality: Optional personality to use for the simulated user
+         domain_context: Optional domain context for goal checking (currency, locale, business rules)
+         data_snapshot: Optional data snapshot to ground simulated user requests to available data
+         metric_collectors: Optional list of custom metric collectors for per-turn metrics

      Returns:
-         Tuple of (list of CheckerResultItem, checker_mode)
+         SimulationResult containing all turns, task results, and metrics
      """
-     checkers_list: list[CheckerResultItem] = []
-     all_results = checker_result.details.get("all_results", [])
-     checker_mode = checker_result.details.get("mode", "all")
-
-     if all_results:
-         for r in all_results:
-             checkers_list.append(
-                 CheckerResultItem(
-                     type=r.get("checker_type", "unknown"),
-                     completed=r.get("completed", False),
-                     reason=r.get("reason", ""),
-                 )
-             )
-     else:
-         checkers_list.append(
-             CheckerResultItem(
-                 type=checker_result.checker_type,
-                 completed=checker_result.completed,
-                 reason=checker_result.reason,
-             )
-         )
-
-     return checkers_list, checker_mode
-
-
- @dataclass
- class TurnData:
-     """Data collected during a single turn."""
-
-     usage: Usage
-     user_message: str
-     assistant_response: str
-     tool_calls: list[ToolCallResult] = field(default_factory=list)
-
+     start_time = datetime.now(timezone.utc)

- @dataclass
- class TaskData:
-     """Data collected for a single task."""
+     # Initialize metric collectors
+     collectors = CompositeMetricCollector(metric_collectors)

-     task_index: int
-     turns: list[TurnData] = field(default_factory=list)
-     final_reason: str = ""
-     completed: bool = False
-
-
- @dataclass
- class SimulationState:
-     """Mutable state tracking during simulation execution."""
-
-     simulation_id: str = field(default_factory=lambda: str(uuid4()))
-     start_time: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
-     current_task_idx: int = 0
-     current_turn_idx: int = 0
-     turns_for_current_task: int = 0
-     current_task_id: int | None = None
-
-     # Accumulated results
-     turn_results: list[TurnResult] = field(default_factory=list)
-     task_data: dict[int, TaskData] = field(default_factory=dict)
-     history: list[Turn] = field(default_factory=list)
-     total_usage: Usage = field(default_factory=Usage)
+     # Initialize result tracking
+     turn_results: list[TurnResult] = []
+     task_results: list[TaskResult] = []
+     task_turn_counts: dict[int, int] = {}  # task_index -> turns taken
+     task_final_reasons: dict[int, str] = {}  # task_index -> final reason

-     # Context
-     chat_context: ChatContext = field(default_factory=ChatContext)
-     status: SimulationStatus = SimulationStatus.RUNNING
+     # Simulated user uses an independent llm (can share the same provider)
+     sim_user = SimulatedUser(
+         llm=build_llm(sim_user_model_name, default_model, api_key),
+         scenario=scenario,
+         personality=personality,
+         data_snapshot=data_snapshot,
+     )
+     # Independent goal checker model
+     goal_checker = GoalChecker(llm=build_llm(checker_model_name, default_model, api_key), scenario=scenario)
+     # Tool usage checker (simple comparator, no LLM needed)
+     tool_checker = ToolUsageChecker(scenario=scenario)
+
+     history: list[Turn] = []
+     logger = ConversationLogger(log_file)
+     logger.initialize_session(scenario, agent_model_name, sim_user_model_name, checker_model_name, personality)
+     total_usage = Usage()
+
+     # Seed: ask the simulated user for the first message based on the first task
+     user_message = await sim_user.next_message(history=history)
+
+     turns_for_current_task = 0
+     current_task_id = None
+     status = SimulationStatus.RUNNING
      error_message: str | None = None

-     # Cache
-     task_checkers_cache: dict[int, list] = field(default_factory=dict)
-
-     def get_task_turn_count(self, task_index: int) -> int:
-         """Get the number of turns taken for a task."""
-         if task_index in self.task_data:
-             return len(self.task_data[task_index].turns)
-         return 0
-
-     def get_task_final_reason(self, task_index: int) -> str:
-         """Get the final reason for a task."""
-         if task_index in self.task_data:
-             return self.task_data[task_index].final_reason
-         return "Not attempted"
-
-     def ensure_task_data(self, task_index: int) -> TaskData:
-         """Ensure task data exists for the given index."""
-         if task_index not in self.task_data:
-             self.task_data[task_index] = TaskData(task_index=task_index)
-         return self.task_data[task_index]
-
-     def record_turn(
-         self,
-         task_index: int,
-         user_message: str,
-         assistant_response: str,
-         usage: Usage,
-         tool_calls: list[ToolCallResult],
-     ) -> None:
-         """Record turn data for the current task."""
-         task_data = self.ensure_task_data(task_index)
-         task_data.turns.append(
-             TurnData(
-                 usage=usage,
-                 user_message=user_message,
-                 assistant_response=assistant_response,
-                 tool_calls=tool_calls,
+     try:
+         for turn_idx in range(1, max_turns_scenario + 1):
+             current_task = sim_user.get_current_task()
+             if current_task is None:
+                 print("\nAll tasks completed!")
+                 status = SimulationStatus.COMPLETED
+                 break
+
+             current_task_index = sim_user.current_task_idx
+
+             # Track turns per task
+             if current_task_id != id(current_task):
+                 # New task started, reset counter
+                 turns_for_current_task = 0
+                 current_task_id = id(current_task)
+
+             # Check if we've exceeded max turns for this task
+             if max_turns_task is not None and turns_for_current_task >= max_turns_task:
+                 print(f"\nReached maximum number of turns ({max_turns_task}) for current task. Exiting.")
+                 status = SimulationStatus.TIMEOUT
+                 task_final_reasons[current_task_index] = f"Timeout: exceeded {max_turns_task} turns"
+                 break
+
+             turns_for_current_task += 1
+             task_turn_counts[current_task_index] = task_turn_counts.get(current_task_index, 0) + 1
+
+             print(f"\n=== Turn {turn_idx} (Task turn: {turns_for_current_task}) ===")
+             print(f"Current Task: {current_task.task}")
+             print(f"User: {user_message}")
+
+             # Notify metric collectors of turn start
+             collectors.on_turn_start(turn_idx, current_task_index, user_message)
+
+             # Add optional prefix to user message
+             full_user_message = user_message_prefix + user_message if user_message_prefix else user_message
+
+             assistant_reply_parts: list[str] = []
+             tool_calls: list[ToolCallResult] = []
+             turn_usage: Usage = Usage()
+             dummy_chat_context = ChatContext()
+
+             stream = chat.chat(
+                 message=full_user_message,
+                 history=[
+                     d
+                     for turn in history
+                     for d in ({"role": "user", "text": turn.user}, {"role": "assistant", "text": turn.assistant})
+                 ],
+                 context=dummy_chat_context,
              )
-         )
-
-     def set_task_result(self, task_index: int, completed: bool, reason: str) -> None:
-         """Set the result for a task."""
-         task_data = self.ensure_task_data(task_index)
-         task_data.completed = completed
-         task_data.final_reason = reason
-
-     def on_new_task(self, task_id: int) -> None:
-         """Handle transition to a new task."""
-         self.turns_for_current_task = 0
-         self.current_task_id = task_id
-
-
- @dataclass
- class SimulationContext:
-     """Environment and dependencies for running a simulation."""
-
-     config: SimulationConfig
-     chat: ChatInterface
-     sim_user: SimulatedUser
-     checker_context: CheckerContext
-     logger: ConversationLogger
-     collectors: CompositeMetricCollector
-     out: IO[str]
-     trace_handler: MemoryTraceHandler | None = None
-     progress_callback: ProgressCallback | None = None
-
-     async def emit_progress(self, event_type: str, **kwargs: Any) -> None:
-         """Emit a progress callback if one is configured."""
-         if self.progress_callback:
-             await self.progress_callback(event_type, **kwargs)
-
-     def print(self, message: str) -> None:
-         """Print a message to the output stream."""
-         print(message, file=self.out)
-
-
- async def _process_chat_stream(
-     ctx: SimulationContext,
-     state: SimulationState,
-     turn_idx: int,
-     task_index: int,
-     full_user_message: str,
-     user_message: str,
- ) -> tuple[str, list[ToolCallResult], Usage]:
-     """Process the chat stream and collect response data.
-
-     Tool calls and usage are extracted from traces rather than from the stream,
-     providing more reliable and complete data.
-
-     Returns:
-         Tuple of (assistant_reply, tool_calls, turn_usage)
-     """
-     assistant_reply_parts: list[str] = []
-
-     # Track span count before the stream to identify new spans from this turn
-     span_count_before = len(ctx.trace_handler.root_spans) if ctx.trace_handler else 0
-
-     stream = ctx.chat.chat(
-         message=full_user_message,
-         history=[
-             d
-             for turn in state.history
-             for d in ({"role": "user", "text": turn.user}, {"role": "assistant", "text": turn.assistant})
-         ],
-         context=state.chat_context,
-     )

-     async for chunk in stream:
-         # Emit response chunk event for real-time streaming
-         if ctx.progress_callback:
-             chunk_type, chunk_data = _serialize_response_chunk(chunk)
-             await ctx.emit_progress(
-                 "response_chunk",
+             async for chunk in stream:
+                 if isinstance(chunk, str):
+                     assistant_reply_parts.append(chunk)
+                 elif isinstance(chunk, ToolCallResult):
+                     tool_calls.append(chunk)
+                 elif isinstance(chunk, Usage):
+                     turn_usage += chunk
+
+             total_usage += turn_usage
+             assistant_reply = "".join(assistant_reply_parts).strip()
+             print(f"Assistant: {assistant_reply}")
+             if tool_calls:
+                 print(f"Tools used: {[tc.name for tc in tool_calls]}")
+             print(
+                 f"Assistant token usage: {turn_usage.total_tokens} total "
+                 f"({turn_usage.prompt_tokens} prompt + {turn_usage.completion_tokens} completion), "
+                 f"estimated cost: ${turn_usage.estimated_cost:.6f}"
+             )
+             logger.log_turn(turn_idx, current_task, user_message, assistant_reply, tool_calls, turn_usage)
+
+             # Update simulation-visible history (Turn objects)
+             history.append(Turn(user=user_message, assistant=assistant_reply))
+
+             # Ask the judge if current task is achieved
+             task_done, reason = await goal_checker.is_task_achieved(current_task, history, context=domain_context)
+             logger.log_task_check(turn_idx, task_done, reason)
+
+             # Check tool usage if expected tools are specified
+             if current_task.expected_tools:
+                 tools_used_correctly, tool_reason = tool_checker.check_tool_usage(current_task, tool_calls)
+                 logger.log_tool_check(turn_idx, tools_used_correctly, tool_reason, tool_calls)
+                 if not tools_used_correctly:
+                     print(f"Tool usage issue: {tool_reason}")
+                 else:
+                     print(f"Tool usage verified: {tool_reason}")
+
+             # Create turn result
+             turn_result = TurnResult(
                  turn_index=turn_idx,
-                 task_index=task_index,
-                 chunk_type=chunk_type,
-                 chunk_data=chunk_data,
+                 task_index=current_task_index,
+                 user_message=user_message,
+                 assistant_message=assistant_reply,
+                 tool_calls=[{"name": tc.name, "arguments": tc.arguments, "result": tc.result} for tc in tool_calls],
+                 task_completed=task_done,
+                 task_completed_reason=reason,
+                 token_usage={
+                     "total": turn_usage.total_tokens,
+                     "prompt": turn_usage.prompt_tokens,
+                     "completion": turn_usage.completion_tokens,
+                 },
              )
-
-         ctx.collectors.on_streamed_response(turn_idx, task_index, user_message, chunk)
-
-         # Process chunk by type - only collect text and context updates from stream
-         if isinstance(chunk, TextResponse):
-             assistant_reply_parts.append(chunk.content.text)
-         elif isinstance(chunk, ConversationIdResponse):
-             state.chat_context.conversation_id = chunk.content.conversation_id
-         elif isinstance(chunk, StateUpdateResponse):
-             state.chat_context.state = chunk.content.state
-
-     # Extract tool calls and usage from traces for this turn
-     tool_calls: list[ToolCallResult] = []
-     turn_usage = Usage()
-
-     if ctx.trace_handler:
-         # Get only the new spans from this turn
-         new_spans = ctx.trace_handler.root_spans[span_count_before:]
-         if new_spans:
-             # Create analyzer from just this turn's spans
-             turn_traces = [span.to_dict() for span in new_spans]
-             analyzer = TraceAnalyzer.from_traces(turn_traces)
-             tool_calls = analyzer.get_tool_calls()
-             turn_usage = analyzer.get_usage()
-
-     return "".join(assistant_reply_parts).strip(), tool_calls, turn_usage
-
-
- def _build_simulation_result(
-     scenario: Scenario,
-     ctx: SimulationContext,
-     state: SimulationState,
-     task_results: list[TaskResult],
-     collector_metrics: dict[str, Any],
-     personality: Personality | None,
-     traces: list[dict[str, Any]],
- ) -> SimulationResult:
-     """Build the final SimulationResult from state and metrics."""
-     tasks_completed = sum(1 for t in task_results if t.completed)
-     total_tasks = len(scenario.tasks)
-
-     metrics_dict: dict[str, Any] = {
-         "total_turns": len(state.turn_results),
-         "total_tasks": total_tasks,
-         "tasks_completed": tasks_completed,
-         "success_rate": tasks_completed / total_tasks if total_tasks > 0 else 0.0,
-     }
-     metrics_dict.update(collector_metrics)
-
-     return SimulationResult(
-         scenario_name=scenario.name,
-         start_time=state.start_time,
-         end_time=datetime.now(timezone.utc),
-         status=state.status,
-         agent_model=ctx.config.agent_model_name,
-         simulated_user_model=ctx.config.sim_user_model_name,
-         checker_model=ctx.config.checker_model_name,
-         persona=personality.name if personality else None,
-         error=state.error_message,
-         conversation_id=state.chat_context.conversation_id,
-         final_state=state.chat_context.state,
-         turns=state.turn_results,
-         tasks=task_results,
-         metrics=ConversationMetrics(metrics=metrics_dict),
-         traces=traces,
-     )
-
-
- def _build_task_results(
-     scenario: Scenario,
-     ctx: SimulationContext,
-     state: SimulationState,
- ) -> list[TaskResult]:
-     """Build task results from scenario and state."""
-     task_results: list[TaskResult] = []
+             turn_results.append(turn_result)
+
+             # Notify metric collectors of turn end
+             collectors.on_turn_end(turn_result)
+
+             if task_done:
+                 print(f"Task completed: {reason}")
+                 task_final_reasons[current_task_index] = reason
+                 has_next = sim_user.advance_to_next_task()
+                 if not has_next:
+                     print("\nAll tasks completed!")
+                     status = SimulationStatus.COMPLETED
+                     break
+                 next_task = sim_user.get_current_task()
+                 if next_task:
+                     logger.log_task_transition(next_task)
+                 # Reset task turn counter when moving to next task
+                 turns_for_current_task = 0
+                 current_task_id = id(next_task) if next_task else None
+             else:
+                 print(f"Task not completed: {reason}")
+                 task_final_reasons[current_task_index] = reason
+
+             # Ask the simulator for the next user message
+             user_message = await sim_user.next_message(history)
+
+         else:
+             print("\nReached maximum number of turns. Exiting.")
+             status = SimulationStatus.TIMEOUT
+
+     except Exception as e:
+         status = SimulationStatus.FAILED
+         error_message = str(e)
+         print(f"\nSimulation failed with error: {error_message}")
+
+     # Build task results
      for i, task in enumerate(scenario.tasks):
-         completed = i < ctx.sim_user.current_task_idx or (
-             i == ctx.sim_user.current_task_idx and state.status == SimulationStatus.COMPLETED
+         completed = i < sim_user.current_task_idx or (
+             i == sim_user.current_task_idx and status == SimulationStatus.COMPLETED
          )
          task_results.append(
              TaskResult(
                  task_index=i,
                  description=task.task,
+                 expected_result=task.expected_result,
                  completed=completed,
-                 turns_taken=state.get_task_turn_count(i),
-                 final_reason=state.get_task_final_reason(i),
-                 checkers=task.checkers,
-                 checker_mode=task.checker_mode,
+                 turns_taken=task_turn_counts.get(i, 0),
+                 final_reason=task_final_reasons.get(i, "Not attempted"),
              )
          )
-     return task_results
-
-
- async def _handle_task_completion(
-     ctx: SimulationContext,
-     state: SimulationState,
-     current_task_index: int,
-     current_task: Task,
-     reason: str,
- ) -> bool:
-     """Handle task completion. Returns True if simulation should continue."""
-     ctx.print(f"Task completed: {reason}")
-     state.set_task_result(current_task_index, True, reason)
-
-     await ctx.emit_progress(
-         "task_complete",
-         task_index=current_task_index,
-         task_description=current_task.task,
-         turns_taken=state.get_task_turn_count(current_task_index),
-         reason=reason,
-     )

-     has_next = ctx.sim_user.advance_to_next_task()
-     if not has_next:
-         ctx.print("\nAll tasks completed!")
-         state.status = SimulationStatus.COMPLETED
-         return False
-
-     next_task = ctx.sim_user.get_current_task()
-     if next_task:
-         ctx.logger.log_task_transition(next_task)
-     state.on_new_task(id(next_task) if next_task else None)
-     return True
-
-
- async def _run_turn_checkers(
-     ctx: SimulationContext,
-     state: SimulationState,
-     current_task: Task,
-     current_task_index: int,
-     tool_calls: list[ToolCallResult],
- ) -> CheckerResult:
-     """Run checkers for the current turn and return the result."""
-     if current_task_index not in state.task_checkers_cache:
-         state.task_checkers_cache[current_task_index] = current_task.get_parsed_checkers()
-
-     return await run_checkers(
-         checkers=state.task_checkers_cache[current_task_index],
-         task=current_task,
-         history=state.history,
-         tool_calls=tool_calls,
-         state=state.chat_context.state,
-         context=ctx.checker_context,
-         mode=current_task.checker_mode,
+     # Print total token usage summary
+     print("\n=== Total Assistant Token Usage ===")
+     print(
+         f"Total assistant tokens: {total_usage.total_tokens} "
+         f"({total_usage.prompt_tokens} prompt + {total_usage.completion_tokens} completion)"
      )
+     print(f"Total estimated cost: ${total_usage.estimated_cost:.6f}")
+     logger.log_total_usage(total_usage)

+     # Evaluate conversation with DeepEval metrics
+     deepeval_scores = _evaluate_with_deepeval(history, logger)

- async def _emit_turn_progress(
-     ctx: SimulationContext,
-     turn_idx: int,
-     current_task_index: int,
-     current_task: Task,
-     user_message: str,
-     assistant_reply: str,
-     tool_calls_data: list[dict[str, Any]],
-     checker_result: CheckerResult,
-     checkers_list: list[CheckerResultItem],
-     checker_mode: str,
- ) -> None:
-     """Emit progress events for checker decision and turn completion."""
-     task_done = checker_result.completed
-     reason = checker_result.reason
-
-     await ctx.emit_progress(
-         "response_chunk",
-         turn_index=turn_idx,
-         task_index=current_task_index,
-         chunk_type="checker_decision",
-         chunk_data={
-             "task_completed": task_done,
-             "reason": reason,
-             "task_description": current_task.task,
-             "checkers": [c.to_dict() for c in checkers_list],
-             "checker_mode": checker_mode,
-         },
-     )
-
-     await ctx.emit_progress(
-         "turn",
-         turn_index=turn_idx,
-         task_index=current_task_index,
-         user_message=user_message,
-         assistant_message=assistant_reply,
-         tool_calls=tool_calls_data,
-         task_completed=task_done,
-         task_completed_reason=reason,
-         checkers=[c.to_dict() for c in checkers_list],
-         checker_mode=checker_mode,
-     )
-
-
- def _create_simulation_context(
-     scenario: Scenario,
-     chat: ChatInterface,
-     config: SimulationConfig,
-     personality: Personality | None,
-     progress_callback: ProgressCallback | None,
-     output_stream: IO[str] | None,
-     trace_handler: MemoryTraceHandler | None = None,
- ) -> SimulationContext:
-     """Create the simulation context with all dependencies."""
-     out = output_stream if output_stream is not None else sys.stdout
-
-     # Create metric collectors
-     builtin_collectors = [
-         LatencyMetricCollector(),
-         TokenUsageMetricCollector(),
-         ToolUsageMetricCollector(),
-     ]
-     user_collectors = config.create_metric_collectors()
-     collectors = CompositeMetricCollector(builtin_collectors + user_collectors)
-
-     # Simulated user
-     sim_user = SimulatedUser(
-         llm=build_llm(config.sim_user_model_name, config.default_model, config.api_key),
-         scenario=scenario,
-         personality=personality,
-         data_snapshot=config.data_snapshot,
-     )
-
-     # Checker context
-     checker_llm = build_llm(config.checker_model_name, config.default_model, config.api_key)
-     checker_context = CheckerContext(llm=checker_llm, domain_context=config.domain_context)
-
-     # Logger
-     logger = ConversationLogger(config.log_file)
-     logger.initialize_session(
-         scenario, config.agent_model_name, config.sim_user_model_name, config.checker_model_name, personality
-     )
-
-     return SimulationContext(
-         config=config,
-         chat=chat,
-         sim_user=sim_user,
-         checker_context=checker_context,
-         logger=logger,
-         collectors=collectors,
-         out=out,
-         trace_handler=trace_handler,
-         progress_callback=progress_callback,
-     )
+     # Collect custom metrics from collectors
+     custom_metrics = collectors.on_conversation_end(turn_results)

+     logger.finalize_session()

- async def _execute_turn(
-     ctx: SimulationContext,
-     state: SimulationState,
-     turn_idx: int,
-     current_task: Task,
-     current_task_index: int,
-     user_message: str,
- ) -> tuple[bool, str]:
-     """Execute a single turn. Returns (should_continue, next_user_message)."""
-     ctx.print(f"\n=== Turn {turn_idx} (Task turn: {state.turns_for_current_task}) ===")
-     ctx.print(f"Current Task: {current_task.task}")
-     ctx.print(f"User: {user_message}")
-
-     ctx.collectors.on_turn_start(turn_idx, current_task_index, user_message)
-
-     full_user_message = (
-         ctx.config.user_message_prefix + user_message if ctx.config.user_message_prefix else user_message
-     )
-
-     # Process assistant response
-     assistant_reply, tool_calls, turn_usage = await _process_chat_stream(
-         ctx, state, turn_idx, current_task_index, full_user_message, user_message
-     )
-
-     state.total_usage += turn_usage
-     _log_turn_output(ctx, assistant_reply, tool_calls, turn_usage)
-     ctx.logger.log_turn(turn_idx, current_task, user_message, assistant_reply, tool_calls, turn_usage)
-
-     # Update state
-     state.history.append(Turn(user=user_message, assistant=assistant_reply))
-     state.record_turn(current_task_index, user_message, assistant_reply, turn_usage, tool_calls)
-
-     # Run checkers
-     checker_result = await _run_turn_checkers(ctx, state, current_task, current_task_index, tool_calls)
-     ctx.logger.log_task_check(turn_idx, checker_result.completed, checker_result.reason)
-
-     # Build and record turn result
-     checkers_list, checker_mode = _build_checker_results(checker_result)
-     tool_calls_data = [{"name": tc.name, "arguments": tc.arguments, "result": tc.result} for tc in tool_calls]
-
-     turn_result = TurnResult(
-         turn_index=turn_idx,
-         task_index=current_task_index,
-         user_message=user_message,
-         assistant_message=assistant_reply,
-         tool_calls=tool_calls_data,
-         task_completed=checker_result.completed,
-         task_completed_reason=checker_result.reason,
-         token_usage=turn_usage,
-         checkers=checkers_list,
-         checker_mode=checker_mode,
-     )
-     state.turn_results.append(turn_result)
-     ctx.collectors.on_turn_end(turn_result)
-
-     # Emit progress
-     await _emit_turn_progress(
-         ctx,
-         turn_idx,
-         current_task_index,
-         current_task,
-         user_message,
-         assistant_reply,
-         tool_calls_data,
-         checker_result,
-         checkers_list,
-         checker_mode,
+     # Build metrics
+     tasks_completed = sum(1 for t in task_results if t.completed)
+     metrics = ConversationMetrics(
+         total_turns=len(turn_results),
+         total_tasks=len(scenario.tasks),
+         tasks_completed=tasks_completed,
+         total_tokens=total_usage.total_tokens,
+         prompt_tokens=total_usage.prompt_tokens,
+         completion_tokens=total_usage.completion_tokens,
+         total_cost_usd=total_usage.estimated_cost,
+         deepeval_scores=deepeval_scores,
+         custom=custom_metrics,
      )

-     status = "✓" if checker_result.completed else "✗"
-     ctx.print(f"Checker [{checker_result.checker_type}]: {status} {checker_result.reason}")
-
-     # Handle task completion or continuation
-     if checker_result.completed:
-         should_continue = await _handle_task_completion(
-             ctx, state, current_task_index, current_task, checker_result.reason
-         )
-         if not should_continue:
-             return False, ""
-     else:
-         ctx.print(f"Task not completed: {checker_result.reason}")
-         state.set_task_result(current_task_index, False, checker_result.reason)
-
-     next_message = await ctx.sim_user.next_message(state.history)
-     return True, next_message
-
-
- def _log_turn_output(
-     ctx: SimulationContext,
-     assistant_reply: str,
-     tool_calls: list[ToolCallResult],
-     turn_usage: Usage,
- ) -> None:
-     """Log turn output to the output stream."""
-     ctx.print(f"Assistant: {assistant_reply}")
-     if tool_calls:
-         ctx.print(f"Tools used: {[tc.name for tc in tool_calls]}")
-     ctx.print(
-         f"Assistant token usage: {turn_usage.total_tokens} total "
-         f"({turn_usage.prompt_tokens} prompt + {turn_usage.completion_tokens} completion), "
-         f"estimated cost: ${turn_usage.estimated_cost:.6f}"
+     return SimulationResult(
+         scenario_name=scenario.name,
+         start_time=start_time,
+         end_time=datetime.now(timezone.utc),
+         status=status,
+         agent_model=agent_model_name,
+         simulated_user_model=sim_user_model_name,
+         checker_model=checker_model_name,
+         personality=personality.name if personality else None,
+         error=error_message,
+         turns=turn_results,
+         tasks=task_results,
+         metrics=metrics,
      )
-
-
- async def _run_simulation_loop(
-     ctx: SimulationContext,
-     state: SimulationState,
-     user_message: str,
- ) -> None:
-     """Run the main simulation loop."""
-     for turn_idx in range(1, ctx.config.max_turns_scenario + 1):
-         current_task = ctx.sim_user.get_current_task()
-         if current_task is None:
-             ctx.print("\nAll tasks completed!")
-             state.status = SimulationStatus.COMPLETED
-             return
-
-         current_task_index = ctx.sim_user.current_task_idx
-
-         if state.current_task_id != id(current_task):
-             state.on_new_task(id(current_task))
-
-         if ctx.config.max_turns_task is not None and state.turns_for_current_task >= ctx.config.max_turns_task:
-             ctx.print(f"\nReached maximum number of turns ({ctx.config.max_turns_task}) for current task. Exiting.")
-             state.status = SimulationStatus.TIMEOUT
-             state.set_task_result(current_task_index, False, f"Timeout: exceeded {ctx.config.max_turns_task} turns")
-             return
-
-         state.turns_for_current_task += 1
-         state.ensure_task_data(current_task_index)
-
-         should_continue, user_message = await _execute_turn(
-             ctx, state, turn_idx, current_task, current_task_index, user_message
-         )
-         if not should_continue:
-             return
-
-     ctx.print("\nReached maximum number of turns. Exiting.")
-     state.status = SimulationStatus.TIMEOUT
-
-
- async def run_simulation(
-     scenario: Scenario,
-     chat: ChatInterface,
-     config: SimulationConfig | None = None,
-     *,
-     personality: Personality | None = None,
-     progress_callback: ProgressCallback | None = None,
-     output_stream: IO[str] | None = None,
- ) -> SimulationResult:
-     """Run a conversation between an agent and a simulated user.
-
-     The conversation proceeds through tasks defined in the scenario, with a goal checker
-     determining when each task is completed. Returns structured results for programmatic access.
-
-     Args:
-         scenario: The scenario containing tasks to complete
-         chat: An instantiated ChatInterface used to drive the assistant side of the conversation.
-         config: Optional SimulationConfig with all simulation parameters. If not provided,
-             uses default values.
-         personality: Optional personality to use for the simulated user.
-         progress_callback: Optional async callback for progress updates (event_type, **kwargs)
-         output_stream: Optional stream for simulation output (print statements).
-             If None, defaults to sys.stdout. Pass a file handle, StringIO,
-             or any file-like object to redirect output. Use io.StringIO()
-             to capture output or a custom stream to integrate with logging.
-
-     Returns:
-         SimulationResult containing all turns, task results, and metrics
-     """
-     if config is None:
-         config = SimulationConfig()
-
-     state = SimulationState()
-
-     # Collect traces from all traceable operations during the simulation
-     with collect_traces(simulation_id=state.simulation_id) as trace_handler:
-         ctx = _create_simulation_context(
-             scenario=scenario,
-             chat=chat,
-             config=config,
-             personality=personality,
-             progress_callback=progress_callback,
-             output_stream=output_stream,
-             trace_handler=trace_handler,
-         )
-
-         user_message = await ctx.sim_user.next_message(history=state.history)
-
-         try:
-             await _run_simulation_loop(ctx, state, user_message)
-         except Exception as e:
-             state.status = SimulationStatus.FAILED
-             state.error_message = str(e)
-             ctx.print(f"\nSimulation failed with error: {state.error_message}")
-
-         traces = trace_handler.get_traces()
-
-         task_results = _build_task_results(scenario, ctx, state)
-         ctx.logger.log_total_usage(state.total_usage)
-         collector_metrics = ctx.collectors.on_conversation_end(state.turn_results)
-         ctx.logger.finalize_session()
-
-         return _build_simulation_result(scenario, ctx, state, task_results, collector_metrics, personality, traces)
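
Two usage notes for readers tracking this API change. First, the config-driven entry point (SimulationConfig, progress_callback, output_stream, tracing) is replaced by a run_simulation that takes flat keyword arguments. The sketch below shows how a call site might look against the new signature. It is illustrative only: the import path of run_simulation, the Task import, the Scenario/Task constructor fields, and the EchoChat stub are assumptions inferred from the code in this diff, not verified against the released package.

import asyncio

from ragbits.chat.interface import ChatInterface
from ragbits.core.llms import Usage
from ragbits.evaluate.agent_simulation import run_simulation  # import path assumed
from ragbits.evaluate.agent_simulation.models import Scenario, Task  # Task import assumed


class EchoChat(ChatInterface):
    """Illustrative stub; assumes chat() is the only hook the simulation exercises."""

    async def chat(self, message, history=None, context=None):
        # Plain str chunks become the assistant reply; Usage chunks are summed per turn.
        yield f"echo: {message}"
        yield Usage()


async def main() -> None:
    scenario = Scenario(  # field names inferred from scenario.name / scenario.tasks
        name="refund-flow",
        tasks=[Task(task="Request a refund for order 1234", expected_result="Refund is confirmed")],
    )
    result = await run_simulation(
        scenario=scenario,
        chat=EchoChat(),
        max_turns_scenario=15,  # hard cap across the whole scenario
        max_turns_task=4,       # per-task cap; pass None to disable
        default_model="gpt-4o-mini",
        api_key="...",
    )
    print(result.status, [(t.description, t.completed) for t in result.tasks])


asyncio.run(main())

Second, per-turn instrumentation now flows through the metric_collectors parameter rather than the removed built-in collectors. The MetricCollector base-class contract is not visible in this diff, so the hook names and signatures below are inferred solely from the CompositeMetricCollector call sites above (on_turn_start, on_turn_end, on_conversation_end); treat this as a hypothetical sketch.

from ragbits.evaluate.agent_simulation.metrics.collectors import MetricCollector
from ragbits.evaluate.agent_simulation.results import TurnResult


class ReplyLengthCollector(MetricCollector):
    """Hypothetical collector; hook signatures inferred from call sites in run_simulation."""

    def __init__(self) -> None:
        self.lengths: list[int] = []

    def on_turn_start(self, turn_idx: int, task_index: int, user_message: str) -> None:
        pass  # nothing to record before the assistant replies

    def on_turn_end(self, turn_result: TurnResult) -> None:
        self.lengths.append(len(turn_result.assistant_message))

    def on_conversation_end(self, turn_results: list[TurnResult]) -> dict[str, float]:
        # Values returned here would surface in ConversationMetrics.custom.
        avg = sum(self.lengths) / len(self.lengths) if self.lengths else 0.0
        return {"avg_assistant_reply_chars": avg}

Passing metric_collectors=[ReplyLengthCollector()] to run_simulation would then merge the returned dictionary into ConversationMetrics.custom, assuming the composite aggregates per-collector results that way.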