contextforge-eval 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- context_forge/__init__.py +95 -0
- context_forge/core/__init__.py +55 -0
- context_forge/core/trace.py +369 -0
- context_forge/core/types.py +121 -0
- context_forge/evaluation.py +267 -0
- context_forge/exceptions.py +56 -0
- context_forge/graders/__init__.py +44 -0
- context_forge/graders/base.py +264 -0
- context_forge/graders/deterministic/__init__.py +11 -0
- context_forge/graders/deterministic/memory_corruption.py +130 -0
- context_forge/graders/hybrid.py +190 -0
- context_forge/graders/judges/__init__.py +11 -0
- context_forge/graders/judges/backends/__init__.py +9 -0
- context_forge/graders/judges/backends/ollama.py +173 -0
- context_forge/graders/judges/base.py +158 -0
- context_forge/graders/judges/memory_hygiene_judge.py +332 -0
- context_forge/graders/judges/models.py +113 -0
- context_forge/harness/__init__.py +43 -0
- context_forge/harness/user_simulator/__init__.py +70 -0
- context_forge/harness/user_simulator/adapters/__init__.py +13 -0
- context_forge/harness/user_simulator/adapters/base.py +67 -0
- context_forge/harness/user_simulator/adapters/crewai.py +100 -0
- context_forge/harness/user_simulator/adapters/langgraph.py +157 -0
- context_forge/harness/user_simulator/adapters/pydanticai.py +105 -0
- context_forge/harness/user_simulator/llm/__init__.py +5 -0
- context_forge/harness/user_simulator/llm/ollama.py +119 -0
- context_forge/harness/user_simulator/models.py +103 -0
- context_forge/harness/user_simulator/persona.py +154 -0
- context_forge/harness/user_simulator/runner.py +342 -0
- context_forge/harness/user_simulator/scenario.py +95 -0
- context_forge/harness/user_simulator/simulator.py +307 -0
- context_forge/instrumentation/__init__.py +23 -0
- context_forge/instrumentation/base.py +307 -0
- context_forge/instrumentation/instrumentors/__init__.py +17 -0
- context_forge/instrumentation/instrumentors/langchain.py +671 -0
- context_forge/instrumentation/instrumentors/langgraph.py +534 -0
- context_forge/instrumentation/tracer.py +588 -0
- context_forge/py.typed +0 -0
- contextforge_eval-0.1.0.dist-info/METADATA +420 -0
- contextforge_eval-0.1.0.dist-info/RECORD +43 -0
- contextforge_eval-0.1.0.dist-info/WHEEL +5 -0
- contextforge_eval-0.1.0.dist-info/licenses/LICENSE +201 -0
- contextforge_eval-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""Persona and behavior definitions for user simulation."""
|
|
2
|
+
|
|
3
|
+
from enum import Enum
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CommunicationStyle(str, Enum):
    """How the persona communicates.

    Inherits from ``str`` so members compare equal to their plain string
    values and serialize cleanly inside Pydantic models.
    """

    CONCISE = "concise"      # brief, to-the-point replies
    VERBOSE = "verbose"      # detailed replies with extra context
    CASUAL = "casual"        # informal, conversational language
    FORMAL = "formal"        # professional, polished language
    CONFUSED = "confused"    # often asks for clarification / expresses uncertainty
    IMPATIENT = "impatient"  # expresses urgency, wants quick answers
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TechnicalLevel(str, Enum):
    """Technical sophistication of the persona.

    Inherits from ``str`` so members serialize as their plain string values.
    """

    NOVICE = "novice"              # avoids jargon, asks for simpler explanations
    INTERMEDIATE = "intermediate"  # comfortable with basic domain terminology
    EXPERT = "expert"              # uses technical terms confidently
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Behavior(BaseModel):
    """Behavioral traits that influence response generation."""

    # Core disposition; the human-readable meaning of each value is spelled
    # out in Persona.to_system_prompt's style/tech lookup tables.
    communication_style: CommunicationStyle = CommunicationStyle.CASUAL
    technical_level: TechnicalLevel = TechnicalLevel.INTERMEDIATE
    # Patience on a 1-10 scale (bounds enforced by the Field validators).
    patience_level: int = Field(default=5, ge=1, le=10)

    # Response patterns
    asks_followup_questions: bool = True
    provides_context_upfront: bool = True
    corrects_misunderstandings: bool = True

    # Conversation dynamics
    # Probability (0..1) that the simulated user drifts off-topic.
    topic_drift_probability: float = Field(default=0.1, ge=0, le=1)
    # NOTE(review): threshold semantics are not visible in this module —
    # presumably the number of unclear agent replies tolerated before asking
    # for clarification; confirm against the simulator implementation.
    clarification_threshold: int = Field(default=2)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class Goal(BaseModel):
    """A specific goal the persona wants to achieve."""

    description: str       # what the simulated user is trying to accomplish
    success_criteria: str  # how to tell whether the goal was met
    # Relative importance, 1 (low) to 5 (high), validated by Field.
    priority: int = Field(default=1, ge=1, le=5)
    # Flipped to True by Persona.mark_goal_achieved; reset by Persona.reset_goals.
    is_achieved: bool = False
    # Free-form extra data attached to the goal.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class Persona(BaseModel):
    """Complete persona definition for user simulation.

    A persona describes a simulated user — identity, situation, behavioral
    traits, and conversation goals. ``to_system_prompt`` turns the whole
    definition into the system prompt that steers LLM response generation.
    """

    persona_id: str
    name: str
    description: str = ""

    # Context that shapes responses
    background: str = ""
    situation: str = ""

    # Behavioral configuration
    behavior: Behavior = Field(default_factory=Behavior)

    # Goals for this conversation
    goals: list[Goal] = Field(default_factory=list)

    # Domain-specific context
    context: dict[str, Any] = Field(default_factory=dict)

    # Example phrases this persona might use
    example_phrases: list[str] = Field(default_factory=list)

    def to_system_prompt(self) -> str:
        """Render this persona as a system prompt for the user-simulator LLM."""
        # Lookup tables translating enum members into prompt instructions.
        style_hints = {
            CommunicationStyle.CONCISE: "Keep responses brief and to the point.",
            CommunicationStyle.VERBOSE: "Provide detailed responses with context.",
            CommunicationStyle.CASUAL: "Use informal, conversational language.",
            CommunicationStyle.FORMAL: "Use professional, polished language.",
            CommunicationStyle.CONFUSED: "Often ask for clarification or express uncertainty.",
            CommunicationStyle.IMPATIENT: "Express urgency, want quick answers.",
        }
        level_hints = {
            TechnicalLevel.NOVICE: "Avoid technical jargon. Ask for simpler explanations.",
            TechnicalLevel.INTERMEDIATE: "Comfortable with basic domain terminology.",
            TechnicalLevel.EXPERT: "Use technical terms confidently. Challenge vague answers.",
        }

        # Only goals not yet achieved are surfaced to the LLM.
        pending = [f"- {g.description}" for g in self.goals if not g.is_achieved]
        goals_block = "\n".join(pending)

        lines = [f"You are simulating a user named {self.name}."]

        if self.background:
            lines.append(f"\nBackground: {self.background}")
        if self.situation:
            lines.append(f"Current Situation: {self.situation}")

        lines.append(
            f"\nCommunication Style: {style_hints[self.behavior.communication_style]}"
        )
        lines.append(f"Technical Level: {level_hints[self.behavior.technical_level]}")

        if goals_block:
            lines.append(f"\nYour goals for this conversation:\n{goals_block}")
        else:
            lines.append("\nYour goal: Have a productive conversation")

        if self.context:
            pairs = ", ".join(f"{k}: {v}" for k, v in self.context.items())
            lines.append(f"\nAdditional context: {pairs}")

        if self.example_phrases:
            # Cap at three examples to keep the prompt short.
            quoted = ", ".join(f'"{p}"' for p in self.example_phrases[:3])
            lines.append(f"\nExample phrases you might use: {quoted}")

        lines.append(
            "\n\nRespond as this user would, staying in character. "
            "Generate only the user's message, not the agent's response."
        )

        return "\n".join(lines)

    def mark_goal_achieved(self, goal_description: str) -> bool:
        """Mark the first goal whose description matches exactly.

        Returns True when a matching goal was found and flagged, else False.
        """
        matched = next(
            (g for g in self.goals if g.description == goal_description), None
        )
        if matched is None:
            return False
        matched.is_achieved = True
        return True

    def get_pending_goals(self) -> list[Goal]:
        """Return every goal that has not been achieved yet."""
        remaining: list[Goal] = []
        for goal in self.goals:
            if not goal.is_achieved:
                remaining.append(goal)
        return remaining

    def reset_goals(self) -> None:
        """Clear the achieved flag on every goal."""
        for goal in self.goals:
            goal.is_achieved = False
|
|
@@ -0,0 +1,342 @@
|
|
|
1
|
+
"""Simulation runner for orchestrating user-agent conversations."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
import uuid
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Callable, Optional, Union
|
|
9
|
+
|
|
10
|
+
from langchain_core.messages import HumanMessage
|
|
11
|
+
|
|
12
|
+
from .adapters.base import AgentAdapter
|
|
13
|
+
from .models import (
|
|
14
|
+
ConversationRole,
|
|
15
|
+
SimulationResult,
|
|
16
|
+
SimulationState,
|
|
17
|
+
SimulationTurn,
|
|
18
|
+
)
|
|
19
|
+
from .persona import Persona
|
|
20
|
+
from .scenario import GenerativeScenario, Scenario, ScriptedScenario
|
|
21
|
+
from .simulator import LLMUserSimulator, ScriptedUserSimulator, UserSimulator
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SimulationRunner:
    """Orchestrates simulation runs between user simulator and agent adapter.

    Handles the conversation loop, trace capture integration, and
    termination conditions.

    Example usage:
        from context_forge.harness import SimulationRunner, LangGraphAdapter

        adapter = LangGraphAdapter(graph=my_graph, ...)
        scenario = GenerativeScenario(...)

        runner = SimulationRunner(
            adapter=adapter,
            trace_output_dir="./traces",
        )

        result = await runner.run(scenario)
    """

    def __init__(
        self,
        adapter: AgentAdapter,
        trace_output_dir: Optional[Union[str, Path]] = None,
        default_max_turns: int = 20,
    ) -> None:
        """Initialize the simulation runner.

        Args:
            adapter: Framework adapter for agent invocation
            trace_output_dir: Directory for trace files
            default_max_turns: Default maximum turns if not specified in scenario
        """
        self._adapter = adapter
        # Normalize to Path up front; None disables trace saving in run().
        self._trace_output_dir = Path(trace_output_dir) if trace_output_dir else None
        # NOTE(review): stored but never consulted below — run() always uses
        # scenario.max_turns. Confirm whether this default was meant to apply
        # when the scenario leaves max_turns unset.
        self._default_max_turns = default_max_turns

    async def run(
        self,
        scenario: Scenario,
        config: Optional[dict[str, Any]] = None,
    ) -> SimulationResult:
        """Run a complete simulation.

        Args:
            scenario: Scenario definition (scripted or generative)
            config: Additional configuration for adapter/simulator

        Returns:
            SimulationResult with conversation history and metrics
        """
        simulation_id = str(uuid.uuid4())

        # Create user simulator based on scenario type
        simulator = self._create_simulator(scenario)

        # Initialize state
        state = SimulationState(
            simulation_id=simulation_id,
            scenario_id=scenario.scenario_id,
            persona_id=scenario.persona.persona_id,
            max_turns=scenario.max_turns,
        )

        # Initialize adapter and simulator. The hasattr guard makes the
        # simulator's initialize/cleanup hooks optional.
        await self._adapter.initialize(config)
        if hasattr(simulator, "initialize"):
            await simulator.initialize()

        try:
            # Run conversation loop
            await self._run_conversation_loop(state, simulator, scenario)

            # Mark success
            state.status = "completed"
            # NOTE(review): naive local time, not timezone-aware — confirm
            # that SimulationState.started_at uses the same convention.
            state.ended_at = datetime.now()

            # Calculate metrics
            metrics = self._calculate_metrics(state)

            # Save trace if configured
            trace_path = None
            if self._trace_output_dir:
                trace_path = await self._save_trace(state)

            return SimulationResult(
                simulation_id=simulation_id,
                state=state,
                trace_path=str(trace_path) if trace_path else None,
                metrics=metrics,
                success=True,
            )

        except Exception as e:
            # Any failure in the loop (or trace saving) yields a failed
            # result rather than propagating to the caller.
            state.status = "failed"
            state.ended_at = datetime.now()
            state.termination_reason = str(e)

            return SimulationResult(
                simulation_id=simulation_id,
                state=state,
                success=False,
                error=str(e),
            )

        finally:
            # Always release adapter/simulator resources, success or failure.
            await self._adapter.cleanup()
            if hasattr(simulator, "cleanup"):
                await simulator.cleanup()

    async def _run_conversation_loop(
        self,
        state: SimulationState,
        simulator: UserSimulator,
        scenario: Scenario,
    ) -> None:
        """Execute the main conversation loop."""
        # Get initial message
        initial_message_text = scenario.get_initial_message()
        initial_message = HumanMessage(content=initial_message_text)

        # Add initial user turn
        state.turns.append(SimulationTurn(
            turn_number=0,
            role=ConversationRole.USER,
            message=initial_message,
        ))

        # Invoke agent with initial message
        agent_response = await self._adapter.invoke(initial_message, state)

        # NOTE(review): the opening agent turn shares turn_number 0 with the
        # opening user turn (subsequent pairs share 1, 2, ...) — confirm
        # downstream consumers expect paired numbering rather than a strictly
        # increasing per-message counter.
        state.turns.append(SimulationTurn(
            turn_number=0,
            role=ConversationRole.AGENT,
            message=agent_response,
        ))

        state.current_turn = 1

        # Main loop
        while state.current_turn < state.max_turns:
            # Check termination
            should_stop, reason = await simulator.should_terminate(state)
            if should_stop:
                state.termination_reason = reason
                break

            # Generate user response
            try:
                user_message = await simulator.generate_response(agent_response, state)
            # NOTE(review): a StopIteration raised *inside* a coroutine is
            # converted to RuntimeError by the interpreter, so this handler
            # only fires if generate_response re-raises StopIteration from a
            # sync path — verify against the simulator implementations.
            except StopIteration as e:
                state.termination_reason = str(e)
                break

            state.turns.append(SimulationTurn(
                turn_number=state.current_turn,
                role=ConversationRole.USER,
                message=user_message,
            ))

            # Invoke agent
            agent_response = await self._adapter.invoke(user_message, state)

            state.turns.append(SimulationTurn(
                turn_number=state.current_turn,
                role=ConversationRole.AGENT,
                message=agent_response,
            ))

            # Update agent state snapshot
            state.agent_state = self._adapter.get_state()

            state.current_turn += 1

    def _create_simulator(self, scenario: Scenario) -> UserSimulator:
        """Create appropriate simulator for scenario type."""
        if isinstance(scenario, ScriptedScenario):
            # Scripted scenarios may fall back to LLM generation once the
            # script is exhausted (scenario.fallback == "generative").
            llm_fallback = None
            if scenario.fallback == "generative":
                llm_fallback = LLMUserSimulator(scenario.persona)
            return ScriptedUserSimulator(scenario, llm_fallback)
        else:
            return LLMUserSimulator(scenario.persona)

    def _calculate_metrics(self, state: SimulationState) -> dict[str, Any]:
        """Calculate simulation metrics."""
        user_turns = [t for t in state.turns if t.role == ConversationRole.USER]
        agent_turns = [t for t in state.turns if t.role == ConversationRole.AGENT]

        duration = 0.0
        if state.ended_at and state.started_at:
            duration = (state.ended_at - state.started_at).total_seconds()

        # max(..., 1) guards the division when a side produced no turns.
        return {
            "total_turns": len(state.turns),
            "user_turns": len(user_turns),
            "agent_turns": len(agent_turns),
            "avg_user_message_length": (
                sum(len(t.message.content) for t in user_turns) / max(len(user_turns), 1)
            ),
            "avg_agent_message_length": (
                sum(len(t.message.content) for t in agent_turns) / max(len(agent_turns), 1)
            ),
            "duration_seconds": duration,
            "termination_reason": state.termination_reason,
        }

    async def _save_trace(self, state: SimulationState) -> Path:
        """Save simulation state as a trace file.

        Returns:
            Path of the written JSON trace file.

        Raises:
            ValueError: If no trace output directory was configured.
        """
        if not self._trace_output_dir:
            raise ValueError("No trace output directory configured")

        self._trace_output_dir.mkdir(parents=True, exist_ok=True)

        trace_file = self._trace_output_dir / f"simulation_{state.simulation_id}.json"

        # Convert to JSON-serializable format. NOTE(review): this wrapper is
        # built with success=True and without metrics, independent of the
        # final result returned by run() — confirm that is intentional.
        result = SimulationResult(
            simulation_id=state.simulation_id,
            state=state,
            success=True,
        )
        trace_data = result.to_dict()

        # default=str stringifies anything json can't encode (e.g. datetimes).
        with open(trace_file, "w") as f:
            json.dump(trace_data, f, indent=2, default=str)

        return trace_file
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class BatchSimulationRunner:
    """Run multiple simulations with different scenarios/configurations.

    Each scenario gets a fresh adapter from the factory and its own
    SimulationRunner; results are collected in order. Useful for
    evaluation runs across multiple test cases.

    Example usage:
        def adapter_factory():
            return LangGraphAdapter(graph=build_graph(), ...)

        runner = BatchSimulationRunner(
            adapter_factory=adapter_factory,
            trace_output_dir="./traces",
            parallel=True,
        )

        results = await runner.run_all(scenarios)
    """

    def __init__(
        self,
        adapter_factory: Callable[[], AgentAdapter],
        trace_output_dir: Optional[Union[str, Path]] = None,
        parallel: bool = False,
        max_parallel: int = 4,
    ):
        """Initialize batch simulation runner.

        Args:
            adapter_factory: Factory function to create adapters
            trace_output_dir: Directory for trace files
            parallel: Whether to run simulations in parallel
            max_parallel: Maximum concurrent simulations
        """
        self._adapter_factory = adapter_factory
        self._trace_output_dir = Path(trace_output_dir) if trace_output_dir else None
        self._parallel = parallel
        self._max_parallel = max_parallel

    async def run_all(
        self,
        scenarios: list[Scenario],
    ) -> list[SimulationResult]:
        """Run all scenarios and collect results.

        Args:
            scenarios: List of scenarios to run

        Returns:
            List of simulation results
        """
        # Dispatch to the configured execution strategy.
        execute = self._run_parallel if self._parallel else self._run_sequential
        return await execute(scenarios)

    async def _run_one(self, scenario: Scenario) -> SimulationResult:
        """Build a fresh adapter/runner pair and execute a single scenario."""
        single_runner = SimulationRunner(
            adapter=self._adapter_factory(),
            trace_output_dir=self._trace_output_dir,
        )
        return await single_runner.run(scenario)

    async def _run_sequential(
        self,
        scenarios: list[Scenario],
    ) -> list[SimulationResult]:
        """Run scenarios one at a time."""
        collected: list[SimulationResult] = []
        for item in scenarios:
            collected.append(await self._run_one(item))
        return collected

    async def _run_parallel(
        self,
        scenarios: list[Scenario],
    ) -> list[SimulationResult]:
        """Run scenarios in parallel with concurrency limit."""
        # Semaphore caps how many simulations execute at once.
        gate = asyncio.Semaphore(self._max_parallel)

        async def bounded(item: Scenario) -> SimulationResult:
            async with gate:
                return await self._run_one(item)

        return await asyncio.gather(*(bounded(s) for s in scenarios))
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""Scenario definitions for user simulation."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal, Optional, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
from .persona import Persona
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TerminationCondition(BaseModel):
    """Condition that can end a simulation."""

    # Discriminator selecting how `value` is interpreted.
    condition_type: Literal["max_turns", "goal_achieved", "keyword", "custom"]
    # Payload whose meaning depends on condition_type (e.g. a turn count or a
    # keyword). NOTE(review): the evaluation logic is not in this module —
    # confirm exact semantics against the runner/simulator that consumes it.
    value: Any
    description: str = ""  # human-readable explanation of the condition
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ScriptedTurn(BaseModel):
    """A pre-defined turn in a scripted scenario."""

    # Position matched by ScriptedScenario.get_turn_message; the first list
    # entry also supplies the opening message via get_initial_message.
    turn_number: int
    user_message: str  # exact text the simulated user sends on this turn
    # Keywords the agent's reply is expected to contain.
    # NOTE(review): checked by a consumer outside this module — confirm where.
    expected_keywords: list[str] = Field(default_factory=list)
    # Whether the simulator may deviate from the exact scripted wording.
    allow_variation: bool = False
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ScriptedScenario(BaseModel):
    """A scenario with pre-defined user messages.

    Useful for regression testing and specific edge case validation.
    """

    scenario_id: str
    name: str
    description: str = ""
    persona: Persona

    # Pre-defined conversation script
    turns: list[ScriptedTurn]

    # What to do after script exhausted
    fallback: Literal["loop", "generative", "terminate"] = "terminate"

    # Termination conditions
    max_turns: int = Field(default=50)
    termination_conditions: list[TerminationCondition] = Field(default_factory=list)

    def get_turn_message(self, turn_number: int) -> Optional[str]:
        """Return the scripted message for *turn_number*, or None if unscripted."""
        return next(
            (t.user_message for t in self.turns if t.turn_number == turn_number),
            None,
        )

    def get_initial_message(self) -> str:
        """Return the first scripted user message.

        Raises:
            ValueError: If the scenario defines no turns.
        """
        if not self.turns:
            raise ValueError("Scripted scenario has no turns defined")
        return self.turns[0].user_message
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class GenerativeScenario(BaseModel):
    """A scenario where user responses are LLM-generated.

    The persona and goals guide response generation. More flexible
    than scripted scenarios for exploratory testing.
    """

    scenario_id: str
    name: str
    description: str = ""
    persona: Persona  # drives the simulated user's tone, goals, and context

    # Initial user message to start conversation
    initial_message: str

    # Constraints on response generation
    max_turns: int = Field(default=20)  # hard cap on conversation length
    termination_conditions: list[TerminationCondition] = Field(default_factory=list)

    # Response generation parameters (temperature bounded to the usual 0-2 range)
    temperature: float = Field(default=0.7, ge=0, le=2)
    max_response_tokens: int = Field(default=500)

    # Topic boundaries
    # NOTE(review): enforcement of allowed/forbidden topics is not visible in
    # this module — confirm how the simulator applies these lists.
    allowed_topics: list[str] = Field(default_factory=list)
    forbidden_topics: list[str] = Field(default_factory=list)

    def get_initial_message(self) -> str:
        """Return the configured opening user message."""
        return self.initial_message
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Union type for all scenarios: code accepting a Scenario must handle both the
# scripted (fixed turn list) and generative (LLM-driven) variants.
Scenario = Union[ScriptedScenario, GenerativeScenario]
|