ragbits-evaluate 0.0.30rc1__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- ragbits/evaluate/agent_simulation/__init__.py +4 -49
- ragbits/evaluate/agent_simulation/conversation.py +278 -663
- ragbits/evaluate/agent_simulation/logger.py +1 -1
- ragbits/evaluate/agent_simulation/metrics/__init__.py +0 -10
- ragbits/evaluate/agent_simulation/metrics/builtin.py +49 -59
- ragbits/evaluate/agent_simulation/metrics/collectors.py +17 -37
- ragbits/evaluate/agent_simulation/models.py +18 -198
- ragbits/evaluate/agent_simulation/results.py +49 -125
- ragbits/evaluate/agent_simulation/scenarios.py +19 -95
- ragbits/evaluate/agent_simulation/simulation.py +166 -72
- ragbits/evaluate/metrics/question_answer.py +25 -8
- {ragbits_evaluate-0.0.30rc1.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +2 -6
- {ragbits_evaluate-0.0.30rc1.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/RECORD +14 -25
- ragbits/evaluate/agent_simulation/checkers.py +0 -591
- ragbits/evaluate/agent_simulation/display.py +0 -118
- ragbits/evaluate/agent_simulation/metrics/deepeval.py +0 -295
- ragbits/evaluate/agent_simulation/tracing.py +0 -233
- ragbits/evaluate/api.py +0 -603
- ragbits/evaluate/api_types.py +0 -343
- ragbits/evaluate/execution_manager.py +0 -451
- ragbits/evaluate/stores/__init__.py +0 -36
- ragbits/evaluate/stores/base.py +0 -98
- ragbits/evaluate/stores/file.py +0 -466
- ragbits/evaluate/stores/kv.py +0 -535
- {ragbits_evaluate-0.0.30rc1.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +0 -0
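
The largest change is in `ragbits/evaluate/agent_simulation/conversation.py`, reproduced below: `run_simulation` drops the old `SimulationConfig`/`progress_callback`/`output_stream` parameters (along with the checker, tracing, and display machinery removed above) in favor of flat keyword arguments, with task completion judged by a `GoalChecker` and tool usage by a `ToolUsageChecker`. A minimal sketch of calling the new signature follows; the scenario and `ChatInterface` construction are placeholders, and only the `run_simulation` parameters and result fields are taken from the diff itself:

```python
import asyncio

from ragbits.evaluate.agent_simulation.conversation import run_simulation


async def main() -> None:
    scenario = ...  # a Scenario whose tasks define what the simulated user wants
    chat = ...      # any instantiated ChatInterface implementation

    result = await run_simulation(
        scenario=scenario,
        chat=chat,
        max_turns_scenario=15,  # overall turn budget (default shown in the diff)
        max_turns_task=4,       # per-task turn budget; None disables the limit
        default_model="gpt-4o-mini",
        log_file="simulation.log",
    )
    # SimulationResult fields per the new code: status, turns, tasks, metrics, ...
    print(result.status, result.metrics.tasks_completed)


asyncio.run(main())
```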
ragbits/evaluate/agent_simulation/conversation.py
@@ -1,718 +1,333 @@
 """Conversation orchestration for agent simulation scenarios."""

-import sys
-from collections.abc import Awaitable, Callable
-from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import IO, Any
-from uuid import uuid4

 from ragbits.agents.tool import ToolCallResult
 from ragbits.chat.interface import ChatInterface
-from ragbits.chat.interface.types import (
-    ChatContext,
-    ChatResponse,
-    ConversationIdResponse,
-    StateUpdateResponse,
-    TextResponse,
-)
+from ragbits.chat.interface.types import ChatContext
 from ragbits.core.llms import Usage
-from ragbits.evaluate.agent_simulation.
+from ragbits.evaluate.agent_simulation.context import DataSnapshot, DomainContext
+from ragbits.evaluate.agent_simulation.deepeval_evaluator import DeepEvalEvaluator
 from ragbits.evaluate.agent_simulation.logger import ConversationLogger
-from ragbits.evaluate.agent_simulation.metrics.
-
-    TokenUsageMetricCollector,
-    ToolUsageMetricCollector,
-)
-from ragbits.evaluate.agent_simulation.metrics.collectors import CompositeMetricCollector
-from ragbits.evaluate.agent_simulation.models import Personality, Scenario, SimulationConfig, Task, Turn
+from ragbits.evaluate.agent_simulation.metrics.collectors import CompositeMetricCollector, MetricCollector
+from ragbits.evaluate.agent_simulation.models import Personality, Scenario, Turn
 from ragbits.evaluate.agent_simulation.results import (
-    CheckerResultItem,
     ConversationMetrics,
     SimulationResult,
     SimulationStatus,
     TaskResult,
     TurnResult,
 )
-from ragbits.evaluate.agent_simulation.simulation import SimulatedUser, build_llm
-from ragbits.evaluate.agent_simulation.tracing import MemoryTraceHandler, TraceAnalyzer, collect_traces
+from ragbits.evaluate.agent_simulation.simulation import GoalChecker, SimulatedUser, ToolUsageChecker, build_llm

-ProgressCallback = Callable[[str, Any], Awaitable[None]]

-
-
-    """Serialize a response chunk to (type, data) tuple.
+def _evaluate_with_deepeval(history: list[Turn], logger: ConversationLogger) -> dict[str, float]:
+    """Evaluate conversation with DeepEval metrics.

     Args:
-
+        history: List of conversation turns to evaluate
+        logger: Logger instance to record evaluation results

     Returns:
-
+        Dictionary of metric names to scores
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-"
-"
-
-
-
-
-
-
-
-
+    scores: dict[str, float] = {}
+    if not history:
+        return scores
+
+    print("\n=== Running DeepEval Evaluation ===")
+    deepeval_evaluator = DeepEvalEvaluator()
+    try:
+        evaluation_results = deepeval_evaluator.evaluate_conversation(history)
+        logger.log_deepeval_metrics(evaluation_results)
+        for metric_name, result in evaluation_results.items():
+            score = result.get("score")
+            if score is not None:
+                print(f"{metric_name}: {score:.4f}")
+                scores[metric_name] = float(score)
+            else:
+                error = result.get("error", "Unknown error")
+                print(f"{metric_name}: Error - {error}")
+    except Exception as e:
+        print(f"Error during DeepEval evaluation: {e}")
+        logger.log_deepeval_metrics({"error": str(e)})
+
+    return scores
+
+
+async def run_simulation(  # noqa: PLR0912, PLR0915
+    scenario: Scenario,
+    chat: ChatInterface,
+    max_turns_scenario: int = 15,
+    max_turns_task: int | None = 4,
+    log_file: str | None = None,
+    agent_model_name: str | None = None,
+    sim_user_model_name: str | None = None,
+    checker_model_name: str | None = None,
+    default_model: str = "gpt-4o-mini",
+    api_key: str = "",
+    user_message_prefix: str = "",
+    personality: Personality | None = None,
+    domain_context: DomainContext | None = None,
+    data_snapshot: DataSnapshot | None = None,
+    metric_collectors: list[MetricCollector] | None = None,
+) -> SimulationResult:
+    """Run a conversation between an agent and a simulated user.

-
-
+    The conversation proceeds through tasks defined in the scenario, with a goal checker
+    determining when each task is completed. Returns structured results for programmatic access.

     Args:
-
+        scenario: The scenario containing tasks to complete
+        chat: An instantiated ChatInterface used to drive the assistant side of the conversation.
+        max_turns_scenario: Maximum number of conversation turns for the entire scenario
+        max_turns_task: Maximum number of conversation turns per task (None for no limit)
+        log_file: Optional path to log file
+        agent_model_name: Optional override for agent LLM model name
+        sim_user_model_name: Optional override for simulated user LLM model name
+        checker_model_name: Optional override for goal checker LLM model name
+        default_model: Default LLM model name
+        api_key: API key for LLM
+        user_message_prefix: Optional prefix to add to user messages before sending to agent
+        personality: Optional personality to use for the simulated user
+        domain_context: Optional domain context for goal checking (currency, locale, business rules)
+        data_snapshot: Optional data snapshot to ground simulated user requests to available data
+        metric_collectors: Optional list of custom metric collectors for per-turn metrics

     Returns:
-
+        SimulationResult containing all turns, task results, and metrics
     """
-
-    all_results = checker_result.details.get("all_results", [])
-    checker_mode = checker_result.details.get("mode", "all")
-
-    if all_results:
-        for r in all_results:
-            checkers_list.append(
-                CheckerResultItem(
-                    type=r.get("checker_type", "unknown"),
-                    completed=r.get("completed", False),
-                    reason=r.get("reason", ""),
-                )
-            )
-    else:
-        checkers_list.append(
-            CheckerResultItem(
-                type=checker_result.checker_type,
-                completed=checker_result.completed,
-                reason=checker_result.reason,
-            )
-        )
-
-    return checkers_list, checker_mode
-
-
-@dataclass
-class TurnData:
-    """Data collected during a single turn."""
-
-    usage: Usage
-    user_message: str
-    assistant_response: str
-    tool_calls: list[ToolCallResult] = field(default_factory=list)
-
+    start_time = datetime.now(timezone.utc)

-
-
-    """Data collected for a single task."""
+    # Initialize metric collectors
+    collectors = CompositeMetricCollector(metric_collectors)

-
-
-
-
-
-
-@dataclass
-class SimulationState:
-    """Mutable state tracking during simulation execution."""
-
-    simulation_id: str = field(default_factory=lambda: str(uuid4()))
-    start_time: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
-    current_task_idx: int = 0
-    current_turn_idx: int = 0
-    turns_for_current_task: int = 0
-    current_task_id: int | None = None
-
-    # Accumulated results
-    turn_results: list[TurnResult] = field(default_factory=list)
-    task_data: dict[int, TaskData] = field(default_factory=dict)
-    history: list[Turn] = field(default_factory=list)
-    total_usage: Usage = field(default_factory=Usage)
+    # Initialize result tracking
+    turn_results: list[TurnResult] = []
+    task_results: list[TaskResult] = []
+    task_turn_counts: dict[int, int] = {}  # task_index -> turns taken
+    task_final_reasons: dict[int, str] = {}  # task_index -> final reason

-    #
-
-
+    # Simulated user uses an independent llm (can share the same provider)
+    sim_user = SimulatedUser(
+        llm=build_llm(sim_user_model_name, default_model, api_key),
+        scenario=scenario,
+        personality=personality,
+        data_snapshot=data_snapshot,
+    )
+    # Independent goal checker model
+    goal_checker = GoalChecker(llm=build_llm(checker_model_name, default_model, api_key), scenario=scenario)
+    # Tool usage checker (simple comparator, no LLM needed)
+    tool_checker = ToolUsageChecker(scenario=scenario)
+
+    history: list[Turn] = []
+    logger = ConversationLogger(log_file)
+    logger.initialize_session(scenario, agent_model_name, sim_user_model_name, checker_model_name, personality)
+    total_usage = Usage()
+
+    # Seed: ask the simulated user for the first message based on the first task
+    user_message = await sim_user.next_message(history=history)
+
+    turns_for_current_task = 0
+    current_task_id = None
+    status = SimulationStatus.RUNNING
     error_message: str | None = None

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        for turn_idx in range(1, max_turns_scenario + 1):
+            current_task = sim_user.get_current_task()
+            if current_task is None:
+                print("\nAll tasks completed!")
+                status = SimulationStatus.COMPLETED
+                break
+
+            current_task_index = sim_user.current_task_idx
+
+            # Track turns per task
+            if current_task_id != id(current_task):
+                # New task started, reset counter
+                turns_for_current_task = 0
+                current_task_id = id(current_task)
+
+            # Check if we've exceeded max turns for this task
+            if max_turns_task is not None and turns_for_current_task >= max_turns_task:
+                print(f"\nReached maximum number of turns ({max_turns_task}) for current task. Exiting.")
+                status = SimulationStatus.TIMEOUT
+                task_final_reasons[current_task_index] = f"Timeout: exceeded {max_turns_task} turns"
+                break
+
+            turns_for_current_task += 1
+            task_turn_counts[current_task_index] = task_turn_counts.get(current_task_index, 0) + 1
+
+            print(f"\n=== Turn {turn_idx} (Task turn: {turns_for_current_task}) ===")
+            print(f"Current Task: {current_task.task}")
+            print(f"User: {user_message}")
+
+            # Notify metric collectors of turn start
+            collectors.on_turn_start(turn_idx, current_task_index, user_message)
+
+            # Add optional prefix to user message
+            full_user_message = user_message_prefix + user_message if user_message_prefix else user_message
+
+            assistant_reply_parts: list[str] = []
+            tool_calls: list[ToolCallResult] = []
+            turn_usage: Usage = Usage()
+            dummy_chat_context = ChatContext()
+
+            stream = chat.chat(
+                message=full_user_message,
+                history=[
+                    d
+                    for turn in history
+                    for d in ({"role": "user", "text": turn.user}, {"role": "assistant", "text": turn.assistant})
+                ],
+                context=dummy_chat_context,
             )
-        )
-
-    def set_task_result(self, task_index: int, completed: bool, reason: str) -> None:
-        """Set the result for a task."""
-        task_data = self.ensure_task_data(task_index)
-        task_data.completed = completed
-        task_data.final_reason = reason
-
-    def on_new_task(self, task_id: int) -> None:
-        """Handle transition to a new task."""
-        self.turns_for_current_task = 0
-        self.current_task_id = task_id
-
-
-@dataclass
-class SimulationContext:
-    """Environment and dependencies for running a simulation."""
-
-    config: SimulationConfig
-    chat: ChatInterface
-    sim_user: SimulatedUser
-    checker_context: CheckerContext
-    logger: ConversationLogger
-    collectors: CompositeMetricCollector
-    out: IO[str]
-    trace_handler: MemoryTraceHandler | None = None
-    progress_callback: ProgressCallback | None = None
-
-    async def emit_progress(self, event_type: str, **kwargs: Any) -> None:
-        """Emit a progress callback if one is configured."""
-        if self.progress_callback:
-            await self.progress_callback(event_type, **kwargs)
-
-    def print(self, message: str) -> None:
-        """Print a message to the output stream."""
-        print(message, file=self.out)
-
-
-async def _process_chat_stream(
-    ctx: SimulationContext,
-    state: SimulationState,
-    turn_idx: int,
-    task_index: int,
-    full_user_message: str,
-    user_message: str,
-) -> tuple[str, list[ToolCallResult], Usage]:
-    """Process the chat stream and collect response data.
-
-    Tool calls and usage are extracted from traces rather than from the stream,
-    providing more reliable and complete data.
-
-    Returns:
-        Tuple of (assistant_reply, tool_calls, turn_usage)
-    """
-    assistant_reply_parts: list[str] = []
-
-    # Track span count before the stream to identify new spans from this turn
-    span_count_before = len(ctx.trace_handler.root_spans) if ctx.trace_handler else 0
-
-    stream = ctx.chat.chat(
-        message=full_user_message,
-        history=[
-            d
-            for turn in state.history
-            for d in ({"role": "user", "text": turn.user}, {"role": "assistant", "text": turn.assistant})
-        ],
-        context=state.chat_context,
-    )

-
-
-
-
-
-
+            async for chunk in stream:
+                if isinstance(chunk, str):
+                    assistant_reply_parts.append(chunk)
+                elif isinstance(chunk, ToolCallResult):
+                    tool_calls.append(chunk)
+                elif isinstance(chunk, Usage):
+                    turn_usage += chunk
+
+            total_usage += turn_usage
+            assistant_reply = "".join(assistant_reply_parts).strip()
+            print(f"Assistant: {assistant_reply}")
+            if tool_calls:
+                print(f"Tools used: {[tc.name for tc in tool_calls]}")
+            print(
+                f"Assistant token usage: {turn_usage.total_tokens} total "
+                f"({turn_usage.prompt_tokens} prompt + {turn_usage.completion_tokens} completion), "
+                f"estimated cost: ${turn_usage.estimated_cost:.6f}"
+            )
+            logger.log_turn(turn_idx, current_task, user_message, assistant_reply, tool_calls, turn_usage)
+
+            # Update simulation-visible history (Turn objects)
+            history.append(Turn(user=user_message, assistant=assistant_reply))
+
+            # Ask the judge if current task is achieved
+            task_done, reason = await goal_checker.is_task_achieved(current_task, history, context=domain_context)
+            logger.log_task_check(turn_idx, task_done, reason)
+
+            # Check tool usage if expected tools are specified
+            if current_task.expected_tools:
+                tools_used_correctly, tool_reason = tool_checker.check_tool_usage(current_task, tool_calls)
+                logger.log_tool_check(turn_idx, tools_used_correctly, tool_reason, tool_calls)
+                if not tools_used_correctly:
+                    print(f"Tool usage issue: {tool_reason}")
+                else:
+                    print(f"Tool usage verified: {tool_reason}")
+
+            # Create turn result
+            turn_result = TurnResult(
                 turn_index=turn_idx,
-                task_index=
-
-
+                task_index=current_task_index,
+                user_message=user_message,
+                assistant_message=assistant_reply,
+                tool_calls=[{"name": tc.name, "arguments": tc.arguments, "result": tc.result} for tc in tool_calls],
+                task_completed=task_done,
+                task_completed_reason=reason,
+                token_usage={
+                    "total": turn_usage.total_tokens,
+                    "prompt": turn_usage.prompt_tokens,
+                    "completion": turn_usage.completion_tokens,
+                },
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-) -> SimulationResult:
-    """Build the final SimulationResult from state and metrics."""
-    tasks_completed = sum(1 for t in task_results if t.completed)
-    total_tasks = len(scenario.tasks)
-
-    metrics_dict: dict[str, Any] = {
-        "total_turns": len(state.turn_results),
-        "total_tasks": total_tasks,
-        "tasks_completed": tasks_completed,
-        "success_rate": tasks_completed / total_tasks if total_tasks > 0 else 0.0,
-    }
-    metrics_dict.update(collector_metrics)
-
-    return SimulationResult(
-        scenario_name=scenario.name,
-        start_time=state.start_time,
-        end_time=datetime.now(timezone.utc),
-        status=state.status,
-        agent_model=ctx.config.agent_model_name,
-        simulated_user_model=ctx.config.sim_user_model_name,
-        checker_model=ctx.config.checker_model_name,
-        persona=personality.name if personality else None,
-        error=state.error_message,
-        conversation_id=state.chat_context.conversation_id,
-        final_state=state.chat_context.state,
-        turns=state.turn_results,
-        tasks=task_results,
-        metrics=ConversationMetrics(metrics=metrics_dict),
-        traces=traces,
-    )
-
-
-def _build_task_results(
-    scenario: Scenario,
-    ctx: SimulationContext,
-    state: SimulationState,
-) -> list[TaskResult]:
-    """Build task results from scenario and state."""
-    task_results: list[TaskResult] = []
+            turn_results.append(turn_result)
+
+            # Notify metric collectors of turn end
+            collectors.on_turn_end(turn_result)
+
+            if task_done:
+                print(f"Task completed: {reason}")
+                task_final_reasons[current_task_index] = reason
+                has_next = sim_user.advance_to_next_task()
+                if not has_next:
+                    print("\nAll tasks completed!")
+                    status = SimulationStatus.COMPLETED
+                    break
+                next_task = sim_user.get_current_task()
+                if next_task:
+                    logger.log_task_transition(next_task)
+                # Reset task turn counter when moving to next task
+                turns_for_current_task = 0
+                current_task_id = id(next_task) if next_task else None
+            else:
+                print(f"Task not completed: {reason}")
+                task_final_reasons[current_task_index] = reason
+
+            # Ask the simulator for the next user message
+            user_message = await sim_user.next_message(history)
+
+        else:
+            print("\nReached maximum number of turns. Exiting.")
+            status = SimulationStatus.TIMEOUT
+
+    except Exception as e:
+        status = SimulationStatus.FAILED
+        error_message = str(e)
+        print(f"\nSimulation failed with error: {error_message}")
+
+    # Build task results
     for i, task in enumerate(scenario.tasks):
-        completed = i <
-            i ==
+        completed = i < sim_user.current_task_idx or (
+            i == sim_user.current_task_idx and status == SimulationStatus.COMPLETED
         )
         task_results.append(
             TaskResult(
                 task_index=i,
                 description=task.task,
+                expected_result=task.expected_result,
                 completed=completed,
-                turns_taken=
-                final_reason=
-                checkers=task.checkers,
-                checker_mode=task.checker_mode,
+                turns_taken=task_turn_counts.get(i, 0),
+                final_reason=task_final_reasons.get(i, "Not attempted"),
            )
        )
-    return task_results
-
-
-async def _handle_task_completion(
-    ctx: SimulationContext,
-    state: SimulationState,
-    current_task_index: int,
-    current_task: Task,
-    reason: str,
-) -> bool:
-    """Handle task completion. Returns True if simulation should continue."""
-    ctx.print(f"Task completed: {reason}")
-    state.set_task_result(current_task_index, True, reason)
-
-    await ctx.emit_progress(
-        "task_complete",
-        task_index=current_task_index,
-        task_description=current_task.task,
-        turns_taken=state.get_task_turn_count(current_task_index),
-        reason=reason,
-    )

-
-
-
-
-
-
-    next_task = ctx.sim_user.get_current_task()
-    if next_task:
-        ctx.logger.log_task_transition(next_task)
-    state.on_new_task(id(next_task) if next_task else None)
-    return True
-
-
-async def _run_turn_checkers(
-    ctx: SimulationContext,
-    state: SimulationState,
-    current_task: Task,
-    current_task_index: int,
-    tool_calls: list[ToolCallResult],
-) -> CheckerResult:
-    """Run checkers for the current turn and return the result."""
-    if current_task_index not in state.task_checkers_cache:
-        state.task_checkers_cache[current_task_index] = current_task.get_parsed_checkers()
-
-    return await run_checkers(
-        checkers=state.task_checkers_cache[current_task_index],
-        task=current_task,
-        history=state.history,
-        tool_calls=tool_calls,
-        state=state.chat_context.state,
-        context=ctx.checker_context,
-        mode=current_task.checker_mode,
+    # Print total token usage summary
+    print("\n=== Total Assistant Token Usage ===")
+    print(
+        f"Total assistant tokens: {total_usage.total_tokens} "
+        f"({total_usage.prompt_tokens} prompt + {total_usage.completion_tokens} completion)"
     )
+    print(f"Total estimated cost: ${total_usage.estimated_cost:.6f}")
+    logger.log_total_usage(total_usage)

+    # Evaluate conversation with DeepEval metrics
+    deepeval_scores = _evaluate_with_deepeval(history, logger)

-
-
-    turn_idx: int,
-    current_task_index: int,
-    current_task: Task,
-    user_message: str,
-    assistant_reply: str,
-    tool_calls_data: list[dict[str, Any]],
-    checker_result: CheckerResult,
-    checkers_list: list[CheckerResultItem],
-    checker_mode: str,
-) -> None:
-    """Emit progress events for checker decision and turn completion."""
-    task_done = checker_result.completed
-    reason = checker_result.reason
-
-    await ctx.emit_progress(
-        "response_chunk",
-        turn_index=turn_idx,
-        task_index=current_task_index,
-        chunk_type="checker_decision",
-        chunk_data={
-            "task_completed": task_done,
-            "reason": reason,
-            "task_description": current_task.task,
-            "checkers": [c.to_dict() for c in checkers_list],
-            "checker_mode": checker_mode,
-        },
-    )
-
-    await ctx.emit_progress(
-        "turn",
-        turn_index=turn_idx,
-        task_index=current_task_index,
-        user_message=user_message,
-        assistant_message=assistant_reply,
-        tool_calls=tool_calls_data,
-        task_completed=task_done,
-        task_completed_reason=reason,
-        checkers=[c.to_dict() for c in checkers_list],
-        checker_mode=checker_mode,
-    )
-
-
-def _create_simulation_context(
-    scenario: Scenario,
-    chat: ChatInterface,
-    config: SimulationConfig,
-    personality: Personality | None,
-    progress_callback: ProgressCallback | None,
-    output_stream: IO[str] | None,
-    trace_handler: MemoryTraceHandler | None = None,
-) -> SimulationContext:
-    """Create the simulation context with all dependencies."""
-    out = output_stream if output_stream is not None else sys.stdout
-
-    # Create metric collectors
-    builtin_collectors = [
-        LatencyMetricCollector(),
-        TokenUsageMetricCollector(),
-        ToolUsageMetricCollector(),
-    ]
-    user_collectors = config.create_metric_collectors()
-    collectors = CompositeMetricCollector(builtin_collectors + user_collectors)
-
-    # Simulated user
-    sim_user = SimulatedUser(
-        llm=build_llm(config.sim_user_model_name, config.default_model, config.api_key),
-        scenario=scenario,
-        personality=personality,
-        data_snapshot=config.data_snapshot,
-    )
-
-    # Checker context
-    checker_llm = build_llm(config.checker_model_name, config.default_model, config.api_key)
-    checker_context = CheckerContext(llm=checker_llm, domain_context=config.domain_context)
-
-    # Logger
-    logger = ConversationLogger(config.log_file)
-    logger.initialize_session(
-        scenario, config.agent_model_name, config.sim_user_model_name, config.checker_model_name, personality
-    )
-
-    return SimulationContext(
-        config=config,
-        chat=chat,
-        sim_user=sim_user,
-        checker_context=checker_context,
-        logger=logger,
-        collectors=collectors,
-        out=out,
-        trace_handler=trace_handler,
-        progress_callback=progress_callback,
-    )
+    # Collect custom metrics from collectors
+    custom_metrics = collectors.on_conversation_end(turn_results)

+    logger.finalize_session()

-
-
-
-
-
-
-
-
-
-
-
-
-
-        ctx.collectors.on_turn_start(turn_idx, current_task_index, user_message)
-
-        full_user_message = (
-            ctx.config.user_message_prefix + user_message if ctx.config.user_message_prefix else user_message
-        )
-
-        # Process assistant response
-        assistant_reply, tool_calls, turn_usage = await _process_chat_stream(
-            ctx, state, turn_idx, current_task_index, full_user_message, user_message
-        )
-
-        state.total_usage += turn_usage
-        _log_turn_output(ctx, assistant_reply, tool_calls, turn_usage)
-        ctx.logger.log_turn(turn_idx, current_task, user_message, assistant_reply, tool_calls, turn_usage)
-
-        # Update state
-        state.history.append(Turn(user=user_message, assistant=assistant_reply))
-        state.record_turn(current_task_index, user_message, assistant_reply, turn_usage, tool_calls)
-
-        # Run checkers
-        checker_result = await _run_turn_checkers(ctx, state, current_task, current_task_index, tool_calls)
-        ctx.logger.log_task_check(turn_idx, checker_result.completed, checker_result.reason)
-
-        # Build and record turn result
-        checkers_list, checker_mode = _build_checker_results(checker_result)
-        tool_calls_data = [{"name": tc.name, "arguments": tc.arguments, "result": tc.result} for tc in tool_calls]
-
-        turn_result = TurnResult(
-            turn_index=turn_idx,
-            task_index=current_task_index,
-            user_message=user_message,
-            assistant_message=assistant_reply,
-            tool_calls=tool_calls_data,
-            task_completed=checker_result.completed,
-            task_completed_reason=checker_result.reason,
-            token_usage=turn_usage,
-            checkers=checkers_list,
-            checker_mode=checker_mode,
-        )
-        state.turn_results.append(turn_result)
-        ctx.collectors.on_turn_end(turn_result)
-
-        # Emit progress
-        await _emit_turn_progress(
-            ctx,
-            turn_idx,
-            current_task_index,
-            current_task,
-            user_message,
-            assistant_reply,
-            tool_calls_data,
-            checker_result,
-            checkers_list,
-            checker_mode,
+    # Build metrics
+    tasks_completed = sum(1 for t in task_results if t.completed)
+    metrics = ConversationMetrics(
+        total_turns=len(turn_results),
+        total_tasks=len(scenario.tasks),
+        tasks_completed=tasks_completed,
+        total_tokens=total_usage.total_tokens,
+        prompt_tokens=total_usage.prompt_tokens,
+        completion_tokens=total_usage.completion_tokens,
+        total_cost_usd=total_usage.estimated_cost,
+        deepeval_scores=deepeval_scores,
+        custom=custom_metrics,
    )

-
-
-
-
-
-
-
-
-    if
-
-
-
-
-
-        next_message = await ctx.sim_user.next_message(state.history)
-        return True, next_message
-
-
-def _log_turn_output(
-    ctx: SimulationContext,
-    assistant_reply: str,
-    tool_calls: list[ToolCallResult],
-    turn_usage: Usage,
-) -> None:
-    """Log turn output to the output stream."""
-    ctx.print(f"Assistant: {assistant_reply}")
-    if tool_calls:
-        ctx.print(f"Tools used: {[tc.name for tc in tool_calls]}")
-    ctx.print(
-        f"Assistant token usage: {turn_usage.total_tokens} total "
-        f"({turn_usage.prompt_tokens} prompt + {turn_usage.completion_tokens} completion), "
-        f"estimated cost: ${turn_usage.estimated_cost:.6f}"
+    return SimulationResult(
+        scenario_name=scenario.name,
+        start_time=start_time,
+        end_time=datetime.now(timezone.utc),
+        status=status,
+        agent_model=agent_model_name,
+        simulated_user_model=sim_user_model_name,
+        checker_model=checker_model_name,
+        personality=personality.name if personality else None,
+        error=error_message,
+        turns=turn_results,
+        tasks=task_results,
+        metrics=metrics,
    )
-
-
-async def _run_simulation_loop(
-    ctx: SimulationContext,
-    state: SimulationState,
-    user_message: str,
-) -> None:
-    """Run the main simulation loop."""
-    for turn_idx in range(1, ctx.config.max_turns_scenario + 1):
-        current_task = ctx.sim_user.get_current_task()
-        if current_task is None:
-            ctx.print("\nAll tasks completed!")
-            state.status = SimulationStatus.COMPLETED
-            return
-
-        current_task_index = ctx.sim_user.current_task_idx
-
-        if state.current_task_id != id(current_task):
-            state.on_new_task(id(current_task))
-
-        if ctx.config.max_turns_task is not None and state.turns_for_current_task >= ctx.config.max_turns_task:
-            ctx.print(f"\nReached maximum number of turns ({ctx.config.max_turns_task}) for current task. Exiting.")
-            state.status = SimulationStatus.TIMEOUT
-            state.set_task_result(current_task_index, False, f"Timeout: exceeded {ctx.config.max_turns_task} turns")
-            return
-
-        state.turns_for_current_task += 1
-        state.ensure_task_data(current_task_index)
-
-        should_continue, user_message = await _execute_turn(
-            ctx, state, turn_idx, current_task, current_task_index, user_message
-        )
-        if not should_continue:
-            return
-
-    ctx.print("\nReached maximum number of turns. Exiting.")
-    state.status = SimulationStatus.TIMEOUT
-
-
-async def run_simulation(
-    scenario: Scenario,
-    chat: ChatInterface,
-    config: SimulationConfig | None = None,
-    *,
-    personality: Personality | None = None,
-    progress_callback: ProgressCallback | None = None,
-    output_stream: IO[str] | None = None,
-) -> SimulationResult:
-    """Run a conversation between an agent and a simulated user.
-
-    The conversation proceeds through tasks defined in the scenario, with a goal checker
-    determining when each task is completed. Returns structured results for programmatic access.
-
-    Args:
-        scenario: The scenario containing tasks to complete
-        chat: An instantiated ChatInterface used to drive the assistant side of the conversation.
-        config: Optional SimulationConfig with all simulation parameters. If not provided,
-            uses default values.
-        personality: Optional personality to use for the simulated user.
-        progress_callback: Optional async callback for progress updates (event_type, **kwargs)
-        output_stream: Optional stream for simulation output (print statements).
-            If None, defaults to sys.stdout. Pass a file handle, StringIO,
-            or any file-like object to redirect output. Use io.StringIO()
-            to capture output or a custom stream to integrate with logging.
-
-    Returns:
-        SimulationResult containing all turns, task results, and metrics
-    """
-    if config is None:
-        config = SimulationConfig()
-
-    state = SimulationState()
-
-    # Collect traces from all traceable operations during the simulation
-    with collect_traces(simulation_id=state.simulation_id) as trace_handler:
-        ctx = _create_simulation_context(
-            scenario=scenario,
-            chat=chat,
-            config=config,
-            personality=personality,
-            progress_callback=progress_callback,
-            output_stream=output_stream,
-            trace_handler=trace_handler,
-        )
-
-        user_message = await ctx.sim_user.next_message(history=state.history)
-
-        try:
-            await _run_simulation_loop(ctx, state, user_message)
-        except Exception as e:
-            state.status = SimulationStatus.FAILED
-            state.error_message = str(e)
-            ctx.print(f"\nSimulation failed with error: {state.error_message}")
-
-        traces = trace_handler.get_traces()
-
-        task_results = _build_task_results(scenario, ctx, state)
-        ctx.logger.log_total_usage(state.total_usage)
-        collector_metrics = ctx.collectors.on_conversation_end(state.turn_results)
-        ctx.logger.finalize_session()
-
-    return _build_simulation_result(scenario, ctx, state, task_results, collector_metrics, personality, traces)