ragbits-evaluate 0.5.0__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only and reflects the changes between those published versions.
- ragbits/evaluate/agent_simulation/__init__.py +87 -0
- ragbits/evaluate/agent_simulation/context.py +118 -0
- ragbits/evaluate/agent_simulation/conversation.py +333 -0
- ragbits/evaluate/agent_simulation/deepeval_evaluator.py +92 -0
- ragbits/evaluate/agent_simulation/logger.py +165 -0
- ragbits/evaluate/agent_simulation/metrics/__init__.py +19 -0
- ragbits/evaluate/agent_simulation/metrics/builtin.py +221 -0
- ragbits/evaluate/agent_simulation/metrics/collectors.py +142 -0
- ragbits/evaluate/agent_simulation/models.py +37 -0
- ragbits/evaluate/agent_simulation/results.py +200 -0
- ragbits/evaluate/agent_simulation/scenarios.py +129 -0
- ragbits/evaluate/agent_simulation/simulation.py +243 -0
- ragbits/evaluate/cli.py +150 -0
- ragbits/evaluate/config.py +11 -0
- ragbits/evaluate/dataloaders/__init__.py +3 -0
- ragbits/evaluate/dataloaders/base.py +95 -0
- ragbits/evaluate/dataloaders/document_search.py +61 -0
- ragbits/evaluate/dataloaders/exceptions.py +25 -0
- ragbits/evaluate/dataloaders/gaia.py +78 -0
- ragbits/evaluate/dataloaders/hotpot_qa.py +95 -0
- ragbits/evaluate/dataloaders/human_eval.py +70 -0
- ragbits/evaluate/dataloaders/question_answer.py +56 -0
- ragbits/evaluate/dataset_generator/pipeline.py +4 -4
- ragbits/evaluate/dataset_generator/prompts/qa.py +2 -4
- ragbits/evaluate/dataset_generator/tasks/corpus_generation.py +2 -4
- ragbits/evaluate/dataset_generator/tasks/text_generation/base.py +3 -5
- ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py +3 -3
- ragbits/evaluate/evaluator.py +178 -50
- ragbits/evaluate/factories/__init__.py +42 -0
- ragbits/evaluate/metrics/__init__.py +2 -23
- ragbits/evaluate/metrics/base.py +40 -17
- ragbits/evaluate/metrics/document_search.py +40 -23
- ragbits/evaluate/metrics/gaia.py +84 -0
- ragbits/evaluate/metrics/hotpot_qa.py +51 -0
- ragbits/evaluate/metrics/human_eval.py +105 -0
- ragbits/evaluate/metrics/question_answer.py +222 -0
- ragbits/evaluate/optimizer.py +138 -86
- ragbits/evaluate/pipelines/__init__.py +37 -0
- ragbits/evaluate/pipelines/base.py +34 -10
- ragbits/evaluate/pipelines/document_search.py +72 -67
- ragbits/evaluate/pipelines/gaia.py +249 -0
- ragbits/evaluate/pipelines/hotpot_qa.py +342 -0
- ragbits/evaluate/pipelines/human_eval.py +323 -0
- ragbits/evaluate/pipelines/question_answer.py +96 -0
- ragbits/evaluate/utils.py +86 -59
- {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +33 -9
- ragbits_evaluate-1.4.0.dev202602030301.dist-info/RECORD +59 -0
- {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +1 -1
- ragbits/evaluate/callbacks/base.py +0 -22
- ragbits/evaluate/callbacks/neptune.py +0 -26
- ragbits/evaluate/loaders/__init__.py +0 -21
- ragbits/evaluate/loaders/base.py +0 -24
- ragbits/evaluate/loaders/hf.py +0 -25
- ragbits_evaluate-0.5.0.dist-info/RECORD +0 -33
- /ragbits/evaluate/{callbacks/__init__.py → py.typed} +0 -0
ragbits/evaluate/agent_simulation/results.py
@@ -0,0 +1,200 @@
"""Result models for agent simulation scenarios."""

from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any


class SimulationStatus(str, Enum):
    """Status of a simulation run."""

    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    TIMEOUT = "timeout"


@dataclass
class TurnResult:
    """Result of a single conversation turn."""

    turn_index: int
    task_index: int
    user_message: str
    assistant_message: str
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    task_completed: bool = False
    task_completed_reason: str = ""
    token_usage: dict[str, int] | None = None
    latency_ms: float | None = None


@dataclass
class TaskResult:
    """Result of a single task within the scenario."""

    task_index: int
    description: str
    expected_result: str | None
    completed: bool
    turns_taken: int
    final_reason: str


@dataclass
class ConversationMetrics:
    """Aggregate metrics for the conversation."""

    total_turns: int
    total_tasks: int
    tasks_completed: int
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_cost_usd: float = 0.0
    deepeval_scores: dict[str, float] = field(default_factory=dict)
    custom: dict[str, Any] = field(default_factory=dict)

    @property
    def success_rate(self) -> float:
        """Calculate task success rate."""
        if self.total_tasks == 0:
            return 0.0
        return self.tasks_completed / self.total_tasks


@dataclass
class SimulationResult:
    """Complete result for a scenario simulation."""

    # Metadata
    scenario_name: str
    start_time: datetime
    status: SimulationStatus

    # Detailed data
    turns: list[TurnResult] = field(default_factory=list)
    tasks: list[TaskResult] = field(default_factory=list)
    metrics: ConversationMetrics | None = None

    # Optional metadata
    end_time: datetime | None = None
    agent_model: str | None = None
    simulated_user_model: str | None = None
    checker_model: str | None = None
    personality: str | None = None
    error: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to JSON-serializable dictionary."""
        return {
            "scenario_name": self.scenario_name,
            "start_time": self.start_time.isoformat(),
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "status": self.status.value,
            "agent_model": self.agent_model,
            "simulated_user_model": self.simulated_user_model,
            "checker_model": self.checker_model,
            "personality": self.personality,
            "error": self.error,
            "turns": [
                {
                    "turn_index": t.turn_index,
                    "task_index": t.task_index,
                    "user_message": t.user_message,
                    "assistant_message": t.assistant_message,
                    "tool_calls": t.tool_calls,
                    "task_completed": t.task_completed,
                    "task_completed_reason": t.task_completed_reason,
                    "token_usage": t.token_usage,
                    "latency_ms": t.latency_ms,
                }
                for t in self.turns
            ],
            "tasks": [
                {
                    "task_index": t.task_index,
                    "description": t.description,
                    "expected_result": t.expected_result,
                    "completed": t.completed,
                    "turns_taken": t.turns_taken,
                    "final_reason": t.final_reason,
                }
                for t in self.tasks
            ],
            "metrics": {
                "total_turns": self.metrics.total_turns,
                "total_tasks": self.metrics.total_tasks,
                "tasks_completed": self.metrics.tasks_completed,
                "success_rate": self.metrics.success_rate,
                "total_tokens": self.metrics.total_tokens,
                "prompt_tokens": self.metrics.prompt_tokens,
                "completion_tokens": self.metrics.completion_tokens,
                "total_cost_usd": self.metrics.total_cost_usd,
                "deepeval_scores": self.metrics.deepeval_scores,
                "custom": self.metrics.custom,
            }
            if self.metrics
            else None,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SimulationResult":
        """Create from dictionary."""
        turns = [
            TurnResult(
                turn_index=t["turn_index"],
                task_index=t["task_index"],
                user_message=t["user_message"],
                assistant_message=t["assistant_message"],
                tool_calls=t.get("tool_calls", []),
                task_completed=t.get("task_completed", False),
                task_completed_reason=t.get("task_completed_reason", ""),
                token_usage=t.get("token_usage"),
                latency_ms=t.get("latency_ms"),
            )
            for t in data.get("turns", [])
        ]

        tasks = [
            TaskResult(
                task_index=t["task_index"],
                description=t["description"],
                expected_result=t.get("expected_result"),
                completed=t["completed"],
                turns_taken=t["turns_taken"],
                final_reason=t["final_reason"],
            )
            for t in data.get("tasks", [])
        ]

        metrics_data = data.get("metrics")
        metrics = None
        if metrics_data:
            metrics = ConversationMetrics(
                total_turns=metrics_data["total_turns"],
                total_tasks=metrics_data["total_tasks"],
                tasks_completed=metrics_data["tasks_completed"],
                total_tokens=metrics_data.get("total_tokens", 0),
                prompt_tokens=metrics_data.get("prompt_tokens", 0),
                completion_tokens=metrics_data.get("completion_tokens", 0),
                total_cost_usd=metrics_data.get("total_cost_usd", 0.0),
                deepeval_scores=metrics_data.get("deepeval_scores", {}),
                custom=metrics_data.get("custom", {}),
            )

        return cls(
            scenario_name=data["scenario_name"],
            start_time=datetime.fromisoformat(data["start_time"]),
            end_time=datetime.fromisoformat(data["end_time"]) if data.get("end_time") else None,
            status=SimulationStatus(data["status"]),
            agent_model=data.get("agent_model"),
            simulated_user_model=data.get("simulated_user_model"),
            checker_model=data.get("checker_model"),
            personality=data.get("personality"),
            error=data.get("error"),
            turns=turns,
            tasks=tasks,
            metrics=metrics,
        )
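For orientation, a minimal usage sketch of the result models above (an editor's illustration, not part of the diff): it imports from the module path shown in the file list, builds a small result with illustrative values, and round-trips it through to_dict / from_dict.

# Hypothetical usage sketch; scenario and task values are illustrative.
import json
from datetime import datetime

from ragbits.evaluate.agent_simulation.results import (
    ConversationMetrics,
    SimulationResult,
    SimulationStatus,
    TaskResult,
    TurnResult,
)

result = SimulationResult(
    scenario_name="checkout-happy-path",
    start_time=datetime.now(),
    status=SimulationStatus.COMPLETED,
    turns=[TurnResult(0, 0, "Hi, I want to order a laptop.", "Sure, which model?")],
    tasks=[
        TaskResult(
            task_index=0,
            description="Place an order",
            expected_result="Order is confirmed",
            completed=True,
            turns_taken=1,
            final_reason="Order confirmed",
        )
    ],
    metrics=ConversationMetrics(total_turns=1, total_tasks=1, tasks_completed=1),
)

payload = json.dumps(result.to_dict(), indent=2)  # JSON-serializable by construction
restored = SimulationResult.from_dict(json.loads(payload))
assert restored.metrics is not None and restored.metrics.success_rate == 1.0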
ragbits/evaluate/agent_simulation/scenarios.py
@@ -0,0 +1,129 @@
"""Scenario loading functionality for agent simulation."""

import json
from pathlib import Path

from ragbits.evaluate.agent_simulation.models import Personality, Scenario, Task


def load_scenarios(scenarios_file: str = "scenarios.json") -> list[Scenario]:
    """Load scenarios from a JSON file.

    Expected JSON format:
    [
        {
            "name": "Scenario 1",
            "tasks": [
                {
                    "task": "task description",
                    "expected_result": "expected result description",
                    "expected_tools": ["tool1", "tool2"] # optional
                },
                ...
            ]
        },
        ...
    ]

    Args:
        scenarios_file: Path to the JSON file containing scenarios

    Returns:
        List of Scenario objects

    Raises:
        FileNotFoundError: If the scenarios file doesn't exist
        ValueError: If the file format is invalid
    """
    scenarios_path = Path(scenarios_file)
    if not scenarios_path.exists():
        raise FileNotFoundError(f"Scenarios file not found: {scenarios_path}")

    with scenarios_path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise ValueError(f"Scenarios file must contain a JSON array, got {type(data).__name__}")

    scenarios: list[Scenario] = []
    for scenario_data in data:
        if not isinstance(scenario_data, dict):
            raise ValueError(f"Each scenario must be a JSON object, got {type(scenario_data).__name__}")

        name = scenario_data.get("name", "")
        tasks_data = scenario_data.get("tasks", [])

        if not isinstance(tasks_data, list):
            raise ValueError(f"Tasks must be a JSON array, got {type(tasks_data).__name__}")

        tasks: list[Task] = []
        for task_data in tasks_data:
            if not isinstance(task_data, dict):
                raise ValueError(f"Each task must be a JSON object, got {type(task_data).__name__}")

            task_desc = task_data.get("task", "")
            expected_result = task_data.get("expected_result", "")
            expected_tools = task_data.get("expected_tools")
            if expected_tools is not None and not isinstance(expected_tools, list):
                raise ValueError(f"expected_tools must be a list or null, got {type(expected_tools).__name__}")
            tasks.append(Task(task=task_desc, expected_result=expected_result, expected_tools=expected_tools))

        scenarios.append(Scenario(name=name, tasks=tasks))

    if not scenarios:
        raise ValueError(f"No scenarios found in {scenarios_path}")

    return scenarios


def load_personalities(personalities_file: str = "personalities.json") -> list[Personality]:
    """Load personalities from a JSON file.

    Expected JSON format:
    [
        {
            "name": "Personality 1",
            "description": "Personality description that will be used in the system prompt"
        },
        ...
    ]

    Args:
        personalities_file: Path to the JSON file containing personalities

    Returns:
        List of Personality objects

    Raises:
        FileNotFoundError: If the personalities file doesn't exist
        ValueError: If the file format is invalid
    """
    personalities_path = Path(personalities_file)
    if not personalities_path.exists():
        raise FileNotFoundError(f"Personalities file not found: {personalities_path}")

    with personalities_path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise ValueError(f"Personalities file must contain a JSON array, got {type(data).__name__}")

    personalities: list[Personality] = []
    for personality_data in data:
        if not isinstance(personality_data, dict):
            raise ValueError(f"Each personality must be a JSON object, got {type(personality_data).__name__}")

        name = personality_data.get("name", "")
        description = personality_data.get("description", "")

        if not name:
            raise ValueError("Each personality must have a 'name' field")
        if not description:
            raise ValueError("Each personality must have a 'description' field")

        personalities.append(Personality(name=name, description=description))

    if not personalities:
        raise ValueError(f"No personalities found in {personalities_path}")

    return personalities
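A small loading sketch for the format documented in the docstrings above (an editor's illustration, not part of the diff; file names, scenario content, and the "search_products" tool name are made up):

# Hypothetical usage sketch. A matching scenarios.json could look like:
# [
#   {
#     "name": "Order a laptop",
#     "tasks": [
#       {
#         "task": "Ask the assistant to find a laptop under 5000",
#         "expected_result": "The assistant proposes at least one matching laptop",
#         "expected_tools": ["search_products"]
#       }
#     ]
#   }
# ]
from ragbits.evaluate.agent_simulation.scenarios import load_personalities, load_scenarios

scenarios = load_scenarios("scenarios.json")  # raises FileNotFoundError / ValueError on bad input
personalities = load_personalities("personalities.json")
print(scenarios[0].name, len(scenarios[0].tasks), personalities[0].name)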
ragbits/evaluate/agent_simulation/simulation.py
@@ -0,0 +1,243 @@
"""Simulation components for agent evaluation scenarios."""

from __future__ import annotations

import json
import re
from typing import TYPE_CHECKING

from ragbits.agents.tool import ToolCallResult
from ragbits.core.llms import LiteLLM
from ragbits.evaluate.agent_simulation.models import Personality, Scenario, Task, Turn

if TYPE_CHECKING:
    from ragbits.evaluate.agent_simulation.context import DataSnapshot, DomainContext


class SimulatedUser:
    """A simple LLM-driven user simulator that works through tasks sequentially.

    It generates the next user utterance based on the conversation so far
    and the current task. It only moves to the next task when the current one is completed.

    Supports optional data grounding via DataSnapshot to ensure the simulated user
    only requests items that actually exist in the available data.
    """

    def __init__(
        self,
        llm: LiteLLM,
        scenario: Scenario,
        personality: Personality | None = None,
        data_snapshot: DataSnapshot | None = None,
    ) -> None:
        """Initialize the simulated user.

        Args:
            llm: The LLM to use for generating messages.
            scenario: The scenario containing tasks to work through.
            personality: Optional personality to influence communication style.
            data_snapshot: Optional data snapshot for grounding requests to available data.
        """
        self.llm = llm
        self.scenario = scenario
        self.personality = personality
        self.data_snapshot = data_snapshot
        self.current_task_idx = 0

    def get_current_task(self) -> Task | None:
        """Get the current task, or None if all tasks are completed."""
        if self.current_task_idx < len(self.scenario.tasks):
            return self.scenario.tasks[self.current_task_idx]
        return None

    def advance_to_next_task(self) -> bool:
        """Move to the next task. Returns True if there is a next task, False otherwise."""
        if self.current_task_idx < len(self.scenario.tasks) - 1:
            self.current_task_idx += 1
            return True
        return False

    async def next_message(self, history: list[Turn]) -> str:
        """Generate the next user message based on conversation history and current task."""
        current_task = self.get_current_task()
        if current_task is None:
            return "Thank you, all tasks are completed."

        history_text = []
        for t in history:
            history_text.append(f"User: {t.user}\nAssistant: {t.assistant}")
        history_block = "\n\n".join(history_text) if history_text else "(no prior messages)"

        task_context = f"Current task: {current_task.task}"
        if self.current_task_idx > 0:
            completed_tasks = ", ".join([t.task for t in self.scenario.tasks[: self.current_task_idx]])
            task_context += f"\nCompleted tasks: {completed_tasks}"

        personality_instruction = ""
        if self.personality:
            personality_instruction = f"\n\nPersonality: {self.personality.description}"

        # Build data grounding block if snapshot is provided
        grounding_block = ""
        if self.data_snapshot:
            grounding_block = (
                "\n\n[AVAILABLE DATA]\n"
                f"{self.data_snapshot.format_for_prompt()}\n\n"
                "IMPORTANT: Only reference items that exist in the AVAILABLE DATA above. "
                "Do not ask for entities that are not listed."
            )

        prompt = (
            "[SYSTEM]\n"
            "You are simulating a concise human user in a terminal chat. "
            f"Scenario: {self.scenario.name}\n"
            f"{task_context}{personality_instruction}{grounding_block}\n"
            "Given the assistant's last reply and the conversation so far, "
            "write ONLY the next user message to work on the current task. Be specific and brief.\n\n"
            "[CONVERSATION]\n"
            f"{history_block}\n\n"
            "[TASK]\nWrite the next USER message now:"
        )

        response = await self.llm.generate(prompt=prompt)
        return response.strip()


class GoalChecker:
    """A lightweight judge model that decides whether the current task has been achieved.

    It inspects the conversation so far and checks if the task matches the expected result.
    Supports optional domain context for accurate evaluation in specific domains.
    """

    def __init__(self, llm: LiteLLM, scenario: Scenario) -> None:
        self.llm = llm
        self.scenario = scenario

    async def is_task_achieved(
        self,
        current_task: Task,
        history: list[Turn],
        context: DomainContext | None = None,
    ) -> tuple[bool, str]:
        """Check if the current task has been completed based on the conversation history.

        Args:
            current_task: The task to check completion for.
            history: List of conversation turns so far.
            context: Optional domain context for accurate evaluation (e.g., currency, locale).

        Returns:
            Tuple of (is_completed, reason).
        """
        history_text = []
        for t in history:
            history_text.append(f"User: {t.user}\nAssistant: {t.assistant}")
        history_block = "\n\n".join(history_text) if history_text else "(no prior messages)"

        # Build context block if provided
        context_block = ""
        if context:
            context_block = (
                "\n[IMPORTANT CONTEXT]\n"
                f"{context.format_for_prompt()}\n\n"
                "When evaluating task completion, consider the domain context above "
                f"and use {context.locale} locale conventions.\n\n"
            )

        prompt = (
            "[SYSTEM]\n"
            "You are a strict task-completion judge for a user-assistant conversation. "
            "Decide if the assistant has fulfilled the current task.\n"
            f"Current task: {current_task.task}\n"
            f"Expected result: {current_task.expected_result}\n"
            f"{context_block}"
            "Respond with a concise JSON object ONLY, no extra text, with fields:\n"
            '{"done": true|false, "reason": "short reason"}\n\n'
            "[CONVERSATION]\n"
            f"{history_block}\n\n"
            "[TASK]\nReturn the JSON now:"
        )

        response = await self.llm.generate(prompt=prompt)
        text = response.strip()
        # Be robust to slight deviations by attempting a minimal parse
        done = False
        reason = ""

        if not text:
            return False, "Empty response from goal checker"

        # Try to extract JSON from markdown code blocks if present
        code_block_match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
        if code_block_match:
            text = code_block_match.group(1)

        # Try to find JSON object in the text
        json_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", text, re.DOTALL)
        if json_match:
            text = json_match.group(0)

        try:
            data = json.loads(text)
            done = bool(data.get("done", False))
            reason = str(data.get("reason", "")).strip()
        except json.JSONDecodeError:
            # If JSON parsing fails, try to infer from response text
            reason = f"Failed to parse JSON response: {text[:100]}"
            # Heuristic: if response contains "done" or "completed" or "true", assume done
            text_lower = text.lower()
            if any(word in text_lower for word in ["done", "completed", "true", "yes", "success"]):
                done = True
            elif any(word in text_lower for word in ["not done", "incomplete", "false", "no", "failed"]):
                done = False

        return done, reason


class ToolUsageChecker:
    """A simple comparator that verifies whether the agent used the expected tools.

    It checks if all expected tools from the task were called during the conversation turn.
    """

    def __init__(self, scenario: Scenario) -> None:
        self.scenario = scenario

    def check_tool_usage(self, current_task: Task, tool_calls: list[ToolCallResult]) -> tuple[bool, str]:  # noqa: PLR6301
        """Check if the expected tools were used for the current task.

        Args:
            current_task: The current task being evaluated
            tool_calls: List of tool calls made during this turn

        Returns:
            Tuple of (success: bool, reason: str)
        """
        if not current_task.expected_tools:
            return True, "No expected tools specified"

        if not tool_calls:
            return False, "No tools were called, but tools were expected"

        # Get the names of tools that were actually called
        called_tool_names = [tc.name for tc in tool_calls]
        expected_tool_names = current_task.expected_tools

        # Check if all expected tools were used
        missing_tools = [tool for tool in expected_tool_names if tool not in called_tool_names]

        if missing_tools:
            return (
                False,
                f"Expected tools not used: {', '.join(missing_tools)}. Tools called: {', '.join(called_tool_names)}",
            )

        return True, f"All expected tools used: {', '.join(called_tool_names)}"


def build_llm(model_name: str | None, default_model: str, api_key: str) -> LiteLLM:
    """Build an LLM instance with the specified model name or default."""
    model = model_name or default_model
    return LiteLLM(model_name=model, use_structured_output=True, api_key=api_key)
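A rough wiring sketch for the components above (an editor's illustration, not part of the diff; the package's conversation.py likely provides the real loop). It assumes an OpenAI-compatible model name, that Turn accepts user/assistant keyword arguments, and stubs out the agent under test:

# Hypothetical wiring sketch; model name, Turn constructor, and the agent stub are assumptions.
import asyncio

from ragbits.evaluate.agent_simulation.models import Scenario, Task, Turn
from ragbits.evaluate.agent_simulation.simulation import GoalChecker, SimulatedUser, build_llm


async def my_agent_respond(message: str) -> str:
    """Stand-in for the agent being evaluated."""
    return f"(agent reply to: {message})"


async def run_once(api_key: str) -> None:
    scenario = Scenario(
        name="Laptop purchase",  # illustrative
        tasks=[Task(task="Find a laptop under 5000", expected_result="A matching laptop is proposed", expected_tools=None)],
    )
    llm = build_llm(None, default_model="gpt-4o-mini", api_key=api_key)  # default model is an assumption
    user = SimulatedUser(llm=llm, scenario=scenario)
    checker = GoalChecker(llm=llm, scenario=scenario)

    history: list[Turn] = []
    for _ in range(5):  # cap the number of turns
        task = user.get_current_task()
        if task is None:
            break
        user_msg = await user.next_message(history)
        assistant_msg = await my_agent_respond(user_msg)
        history.append(Turn(user=user_msg, assistant=assistant_msg))  # assumes these field names

        done, reason = await checker.is_task_achieved(task, history)
        if done and not user.advance_to_next_task():
            print("All tasks completed:", reason)
            break


asyncio.run(run_once(api_key="sk-..."))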