ragbits-evaluate 0.5.0__py3-none-any.whl → 1.4.0.dev202602030301__py3-none-any.whl
This diff represents the changes between publicly available package versions as they have been released to their respective public registries; it is provided for informational purposes only.
- ragbits/evaluate/agent_simulation/__init__.py +87 -0
- ragbits/evaluate/agent_simulation/context.py +118 -0
- ragbits/evaluate/agent_simulation/conversation.py +333 -0
- ragbits/evaluate/agent_simulation/deepeval_evaluator.py +92 -0
- ragbits/evaluate/agent_simulation/logger.py +165 -0
- ragbits/evaluate/agent_simulation/metrics/__init__.py +19 -0
- ragbits/evaluate/agent_simulation/metrics/builtin.py +221 -0
- ragbits/evaluate/agent_simulation/metrics/collectors.py +142 -0
- ragbits/evaluate/agent_simulation/models.py +37 -0
- ragbits/evaluate/agent_simulation/results.py +200 -0
- ragbits/evaluate/agent_simulation/scenarios.py +129 -0
- ragbits/evaluate/agent_simulation/simulation.py +243 -0
- ragbits/evaluate/cli.py +150 -0
- ragbits/evaluate/config.py +11 -0
- ragbits/evaluate/dataloaders/__init__.py +3 -0
- ragbits/evaluate/dataloaders/base.py +95 -0
- ragbits/evaluate/dataloaders/document_search.py +61 -0
- ragbits/evaluate/dataloaders/exceptions.py +25 -0
- ragbits/evaluate/dataloaders/gaia.py +78 -0
- ragbits/evaluate/dataloaders/hotpot_qa.py +95 -0
- ragbits/evaluate/dataloaders/human_eval.py +70 -0
- ragbits/evaluate/dataloaders/question_answer.py +56 -0
- ragbits/evaluate/dataset_generator/pipeline.py +4 -4
- ragbits/evaluate/dataset_generator/prompts/qa.py +2 -4
- ragbits/evaluate/dataset_generator/tasks/corpus_generation.py +2 -4
- ragbits/evaluate/dataset_generator/tasks/text_generation/base.py +3 -5
- ragbits/evaluate/dataset_generator/tasks/text_generation/qa.py +3 -3
- ragbits/evaluate/evaluator.py +178 -50
- ragbits/evaluate/factories/__init__.py +42 -0
- ragbits/evaluate/metrics/__init__.py +2 -23
- ragbits/evaluate/metrics/base.py +40 -17
- ragbits/evaluate/metrics/document_search.py +40 -23
- ragbits/evaluate/metrics/gaia.py +84 -0
- ragbits/evaluate/metrics/hotpot_qa.py +51 -0
- ragbits/evaluate/metrics/human_eval.py +105 -0
- ragbits/evaluate/metrics/question_answer.py +222 -0
- ragbits/evaluate/optimizer.py +138 -86
- ragbits/evaluate/pipelines/__init__.py +37 -0
- ragbits/evaluate/pipelines/base.py +34 -10
- ragbits/evaluate/pipelines/document_search.py +72 -67
- ragbits/evaluate/pipelines/gaia.py +249 -0
- ragbits/evaluate/pipelines/hotpot_qa.py +342 -0
- ragbits/evaluate/pipelines/human_eval.py +323 -0
- ragbits/evaluate/pipelines/question_answer.py +96 -0
- ragbits/evaluate/utils.py +86 -59
- {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/METADATA +33 -9
- ragbits_evaluate-1.4.0.dev202602030301.dist-info/RECORD +59 -0
- {ragbits_evaluate-0.5.0.dist-info → ragbits_evaluate-1.4.0.dev202602030301.dist-info}/WHEEL +1 -1
- ragbits/evaluate/callbacks/base.py +0 -22
- ragbits/evaluate/callbacks/neptune.py +0 -26
- ragbits/evaluate/loaders/__init__.py +0 -21
- ragbits/evaluate/loaders/base.py +0 -24
- ragbits/evaluate/loaders/hf.py +0 -25
- ragbits_evaluate-0.5.0.dist-info/RECORD +0 -33
- /ragbits/evaluate/{callbacks/__init__.py → py.typed} +0 -0
ragbits/evaluate/agent_simulation/__init__.py
@@ -0,0 +1,87 @@

```python
"""Agent simulation utilities for evaluation scenarios.

This module uses lazy imports for components that require optional dependencies
(ragbits-agents, ragbits-chat) to allow importing result models independently.
"""

from typing import TYPE_CHECKING

# Import context, metrics, and result models eagerly - they have no external dependencies
from ragbits.evaluate.agent_simulation.context import DataSnapshot, DomainContext
from ragbits.evaluate.agent_simulation.metrics import (
    CompositeMetricCollector,
    LatencyMetricCollector,
    MetricCollector,
    TokenUsageMetricCollector,
    ToolUsageMetricCollector,
)
from ragbits.evaluate.agent_simulation.results import (
    ConversationMetrics,
    SimulationResult,
    SimulationStatus,
    TaskResult,
    TurnResult,
)

if TYPE_CHECKING:
    from ragbits.evaluate.agent_simulation.conversation import run_simulation
    from ragbits.evaluate.agent_simulation.deepeval_evaluator import DeepEvalEvaluator
    from ragbits.evaluate.agent_simulation.logger import ConversationLogger
    from ragbits.evaluate.agent_simulation.models import Personality, Scenario, Task, Turn
    from ragbits.evaluate.agent_simulation.scenarios import load_personalities, load_scenarios
    from ragbits.evaluate.agent_simulation.simulation import GoalChecker, SimulatedUser

__all__ = [
    "CompositeMetricCollector",
    "ConversationLogger",
    "ConversationMetrics",
    "DataSnapshot",
    "DeepEvalEvaluator",
    "DomainContext",
    "GoalChecker",
    "LatencyMetricCollector",
    "MetricCollector",
    "Personality",
    "Scenario",
    "SimulatedUser",
    "SimulationResult",
    "SimulationStatus",
    "Task",
    "TaskResult",
    "TokenUsageMetricCollector",
    "ToolUsageMetricCollector",
    "Turn",
    "TurnResult",
    "load_personalities",
    "load_scenarios",
    "run_simulation",
]


def __getattr__(name: str) -> object:
    """Lazy import for components with optional dependencies."""
    if name == "run_simulation":
        from ragbits.evaluate.agent_simulation.conversation import run_simulation

        return run_simulation
    if name == "DeepEvalEvaluator":
        from ragbits.evaluate.agent_simulation.deepeval_evaluator import DeepEvalEvaluator

        return DeepEvalEvaluator
    if name == "ConversationLogger":
        from ragbits.evaluate.agent_simulation.logger import ConversationLogger

        return ConversationLogger
    if name in ("Personality", "Scenario", "Task", "Turn"):
        from ragbits.evaluate.agent_simulation import models

        return getattr(models, name)
    if name in ("load_personalities", "load_scenarios"):
        from ragbits.evaluate.agent_simulation import scenarios

        return getattr(scenarios, name)
    if name in ("GoalChecker", "SimulatedUser"):
        from ragbits.evaluate.agent_simulation import simulation

        return getattr(simulation, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```
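Note: the package `__init__` above splits its public API into eager exports (result models, metric collectors, context) and lazy exports resolved through the PEP 562 module-level `__getattr__`, so the optional `ragbits-agents`/`ragbits-chat` (and `deepeval`) dependencies are only needed once a lazy name is touched. A minimal, illustrative sketch of that split (not part of the diff):

```python
# Eagerly exported: importable without the optional extras installed.
from ragbits.evaluate.agent_simulation import ConversationMetrics, SimulationResult

# Lazily exported: __getattr__ resolves these on first access, so
# ragbits-agents / ragbits-chat (and deepeval) must be installed by then.
from ragbits.evaluate.agent_simulation import DeepEvalEvaluator, run_simulation
```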
ragbits/evaluate/agent_simulation/context.py
@@ -0,0 +1,118 @@

```python
"""Context models for agent simulation scenarios."""

from dataclasses import dataclass, field
from typing import Any

DEFAULT_MAX_ITEMS_IN_PROMPT = 15


@dataclass
class DomainContext:
    """Domain-specific context for goal checking and simulation.

    Provides additional context to the GoalChecker to avoid false negatives
    from value interpretation differences or missing domain knowledge.

    The context is intentionally generic - use the `metadata` field for any
    domain-specific information that doesn't fit the standard fields.

    Example:
        >>> context = DomainContext(
        ...     domain_type="customer_support",
        ...     locale="en_US",
        ...     metadata={"ticket_statuses": ["open", "pending", "resolved"]},
        ... )
        >>> result = await goal_checker.is_task_achieved(task, history, context=context)
    """

    domain_type: str
    """Type of domain (e.g., "customer_support", "booking", "search", "qa")."""

    locale: str = "en_US"
    """Locale for language and formatting (e.g., "en_US", "de_DE")."""

    metadata: dict[str, Any] = field(default_factory=dict)
    """Arbitrary domain-specific metadata for goal checking context."""

    def format_for_prompt(self) -> str:
        """Format context for inclusion in LLM prompts.

        Returns:
            Formatted string suitable for prompt injection.
        """
        parts = [
            f"Domain: {self.domain_type}",
            f"Locale: {self.locale}",
        ]

        if self.metadata:
            parts.append("Additional context:")
            for key, value in self.metadata.items():
                if isinstance(value, list) and len(value) > DEFAULT_MAX_ITEMS_IN_PROMPT:
                    truncated = value[:DEFAULT_MAX_ITEMS_IN_PROMPT]
                    parts.append(f" {key}: {truncated} ... and {len(value) - DEFAULT_MAX_ITEMS_IN_PROMPT} more")
                else:
                    parts.append(f" {key}: {value}")

        return "\n".join(parts)


@dataclass
class DataSnapshot:
    """Sample of available data to ground simulated user requests.

    Provides the simulated user with knowledge of what data actually exists,
    preventing unrealistic requests for non-existent entities.

    The snapshot is intentionally generic - store any domain-specific data
    in the `entities` dict with descriptive keys.

    Example:
        >>> snapshot = DataSnapshot(
        ...     entities={
        ...         "available_topics": ["billing", "technical", "returns"],
        ...         "sample_users": [{"id": "u1", "name": "John"}],
        ...     },
        ...     description="Customer support knowledge base",
        ... )
        >>> # SimulatedUser will only reference items from this data
    """

    entities: dict[str, list[Any]] = field(default_factory=dict)
    """Named collections of available entities (e.g., {"users": [...], "documents": [...]})."""

    description: str = ""
    """Optional description of the data snapshot for context."""

    def format_for_prompt(self, max_items: int = DEFAULT_MAX_ITEMS_IN_PROMPT) -> str:
        """Format data snapshot for inclusion in LLM prompts.

        Args:
            max_items: Maximum number of items to include per entity type.

        Returns:
            Formatted string suitable for prompt injection.
        """
        parts = []

        if self.description:
            parts.append(f"Context: {self.description}")

        for entity_name, entity_list in self.entities.items():
            if not entity_list:
                continue

            truncated = entity_list[:max_items]
            # Format items - if dicts with 'name', use that; otherwise str()
            formatted_items = []
            for item in truncated:
                if isinstance(item, dict) and "name" in item:
                    formatted_items.append(item["name"])
                else:
                    formatted_items.append(str(item))

            parts.append(f"{entity_name}: {', '.join(formatted_items)}")
            if len(entity_list) > max_items:
                parts.append(f" ... and {len(entity_list) - max_items} more")

        return "\n".join(parts)
```
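Note: both dataclasses above are plain containers whose only behaviour is `format_for_prompt()`. An illustrative sketch of the strings they produce, with the expected output shown as comments (not part of the diff):

```python
from ragbits.evaluate.agent_simulation.context import DataSnapshot, DomainContext

context = DomainContext(
    domain_type="customer_support",
    metadata={"ticket_statuses": ["open", "pending", "resolved"]},
)
print(context.format_for_prompt())
# Domain: customer_support
# Locale: en_US
# Additional context:
#  ticket_statuses: ['open', 'pending', 'resolved']

snapshot = DataSnapshot(
    entities={"available_topics": ["billing", "technical", "returns"]},
    description="Customer support knowledge base",
)
print(snapshot.format_for_prompt())
# Context: Customer support knowledge base
# available_topics: billing, technical, returns
```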
ragbits/evaluate/agent_simulation/conversation.py
@@ -0,0 +1,333 @@

```python
"""Conversation orchestration for agent simulation scenarios."""

from datetime import datetime, timezone

from ragbits.agents.tool import ToolCallResult
from ragbits.chat.interface import ChatInterface
from ragbits.chat.interface.types import ChatContext
from ragbits.core.llms import Usage
from ragbits.evaluate.agent_simulation.context import DataSnapshot, DomainContext
from ragbits.evaluate.agent_simulation.deepeval_evaluator import DeepEvalEvaluator
from ragbits.evaluate.agent_simulation.logger import ConversationLogger
from ragbits.evaluate.agent_simulation.metrics.collectors import CompositeMetricCollector, MetricCollector
from ragbits.evaluate.agent_simulation.models import Personality, Scenario, Turn
from ragbits.evaluate.agent_simulation.results import (
    ConversationMetrics,
    SimulationResult,
    SimulationStatus,
    TaskResult,
    TurnResult,
)
from ragbits.evaluate.agent_simulation.simulation import GoalChecker, SimulatedUser, ToolUsageChecker, build_llm


def _evaluate_with_deepeval(history: list[Turn], logger: ConversationLogger) -> dict[str, float]:
    """Evaluate conversation with DeepEval metrics.

    Args:
        history: List of conversation turns to evaluate
        logger: Logger instance to record evaluation results

    Returns:
        Dictionary of metric names to scores
    """
    scores: dict[str, float] = {}
    if not history:
        return scores

    print("\n=== Running DeepEval Evaluation ===")
    deepeval_evaluator = DeepEvalEvaluator()
    try:
        evaluation_results = deepeval_evaluator.evaluate_conversation(history)
        logger.log_deepeval_metrics(evaluation_results)
        for metric_name, result in evaluation_results.items():
            score = result.get("score")
            if score is not None:
                print(f"{metric_name}: {score:.4f}")
                scores[metric_name] = float(score)
            else:
                error = result.get("error", "Unknown error")
                print(f"{metric_name}: Error - {error}")
    except Exception as e:
        print(f"Error during DeepEval evaluation: {e}")
        logger.log_deepeval_metrics({"error": str(e)})

    return scores


async def run_simulation(  # noqa: PLR0912, PLR0915
    scenario: Scenario,
    chat: ChatInterface,
    max_turns_scenario: int = 15,
    max_turns_task: int | None = 4,
    log_file: str | None = None,
    agent_model_name: str | None = None,
    sim_user_model_name: str | None = None,
    checker_model_name: str | None = None,
    default_model: str = "gpt-4o-mini",
    api_key: str = "",
    user_message_prefix: str = "",
    personality: Personality | None = None,
    domain_context: DomainContext | None = None,
    data_snapshot: DataSnapshot | None = None,
    metric_collectors: list[MetricCollector] | None = None,
) -> SimulationResult:
    """Run a conversation between an agent and a simulated user.

    The conversation proceeds through tasks defined in the scenario, with a goal checker
    determining when each task is completed. Returns structured results for programmatic access.

    Args:
        scenario: The scenario containing tasks to complete
        chat: An instantiated ChatInterface used to drive the assistant side of the conversation.
        max_turns_scenario: Maximum number of conversation turns for the entire scenario
        max_turns_task: Maximum number of conversation turns per task (None for no limit)
        log_file: Optional path to log file
        agent_model_name: Optional override for agent LLM model name
        sim_user_model_name: Optional override for simulated user LLM model name
        checker_model_name: Optional override for goal checker LLM model name
        default_model: Default LLM model name
        api_key: API key for LLM
        user_message_prefix: Optional prefix to add to user messages before sending to agent
        personality: Optional personality to use for the simulated user
        domain_context: Optional domain context for goal checking (currency, locale, business rules)
        data_snapshot: Optional data snapshot to ground simulated user requests to available data
        metric_collectors: Optional list of custom metric collectors for per-turn metrics

    Returns:
        SimulationResult containing all turns, task results, and metrics
    """
    start_time = datetime.now(timezone.utc)

    # Initialize metric collectors
    collectors = CompositeMetricCollector(metric_collectors)

    # Initialize result tracking
    turn_results: list[TurnResult] = []
    task_results: list[TaskResult] = []
    task_turn_counts: dict[int, int] = {}  # task_index -> turns taken
    task_final_reasons: dict[int, str] = {}  # task_index -> final reason

    # Simulated user uses an independent llm (can share the same provider)
    sim_user = SimulatedUser(
        llm=build_llm(sim_user_model_name, default_model, api_key),
        scenario=scenario,
        personality=personality,
        data_snapshot=data_snapshot,
    )
    # Independent goal checker model
    goal_checker = GoalChecker(llm=build_llm(checker_model_name, default_model, api_key), scenario=scenario)
    # Tool usage checker (simple comparator, no LLM needed)
    tool_checker = ToolUsageChecker(scenario=scenario)

    history: list[Turn] = []
    logger = ConversationLogger(log_file)
    logger.initialize_session(scenario, agent_model_name, sim_user_model_name, checker_model_name, personality)
    total_usage = Usage()

    # Seed: ask the simulated user for the first message based on the first task
    user_message = await sim_user.next_message(history=history)

    turns_for_current_task = 0
    current_task_id = None
    status = SimulationStatus.RUNNING
    error_message: str | None = None

    try:
        for turn_idx in range(1, max_turns_scenario + 1):
            current_task = sim_user.get_current_task()
            if current_task is None:
                print("\nAll tasks completed!")
                status = SimulationStatus.COMPLETED
                break

            current_task_index = sim_user.current_task_idx

            # Track turns per task
            if current_task_id != id(current_task):
                # New task started, reset counter
                turns_for_current_task = 0
                current_task_id = id(current_task)

            # Check if we've exceeded max turns for this task
            if max_turns_task is not None and turns_for_current_task >= max_turns_task:
                print(f"\nReached maximum number of turns ({max_turns_task}) for current task. Exiting.")
                status = SimulationStatus.TIMEOUT
                task_final_reasons[current_task_index] = f"Timeout: exceeded {max_turns_task} turns"
                break

            turns_for_current_task += 1
            task_turn_counts[current_task_index] = task_turn_counts.get(current_task_index, 0) + 1

            print(f"\n=== Turn {turn_idx} (Task turn: {turns_for_current_task}) ===")
            print(f"Current Task: {current_task.task}")
            print(f"User: {user_message}")

            # Notify metric collectors of turn start
            collectors.on_turn_start(turn_idx, current_task_index, user_message)

            # Add optional prefix to user message
            full_user_message = user_message_prefix + user_message if user_message_prefix else user_message

            assistant_reply_parts: list[str] = []
            tool_calls: list[ToolCallResult] = []
            turn_usage: Usage = Usage()
            dummy_chat_context = ChatContext()

            stream = chat.chat(
                message=full_user_message,
                history=[
                    d
                    for turn in history
                    for d in ({"role": "user", "text": turn.user}, {"role": "assistant", "text": turn.assistant})
                ],
                context=dummy_chat_context,
            )

            async for chunk in stream:
                if isinstance(chunk, str):
                    assistant_reply_parts.append(chunk)
                elif isinstance(chunk, ToolCallResult):
                    tool_calls.append(chunk)
                elif isinstance(chunk, Usage):
                    turn_usage += chunk

            total_usage += turn_usage
            assistant_reply = "".join(assistant_reply_parts).strip()
            print(f"Assistant: {assistant_reply}")
            if tool_calls:
                print(f"Tools used: {[tc.name for tc in tool_calls]}")
            print(
                f"Assistant token usage: {turn_usage.total_tokens} total "
                f"({turn_usage.prompt_tokens} prompt + {turn_usage.completion_tokens} completion), "
                f"estimated cost: ${turn_usage.estimated_cost:.6f}"
            )
            logger.log_turn(turn_idx, current_task, user_message, assistant_reply, tool_calls, turn_usage)

            # Update simulation-visible history (Turn objects)
            history.append(Turn(user=user_message, assistant=assistant_reply))

            # Ask the judge if current task is achieved
            task_done, reason = await goal_checker.is_task_achieved(current_task, history, context=domain_context)
            logger.log_task_check(turn_idx, task_done, reason)

            # Check tool usage if expected tools are specified
            if current_task.expected_tools:
                tools_used_correctly, tool_reason = tool_checker.check_tool_usage(current_task, tool_calls)
                logger.log_tool_check(turn_idx, tools_used_correctly, tool_reason, tool_calls)
                if not tools_used_correctly:
                    print(f"Tool usage issue: {tool_reason}")
                else:
                    print(f"Tool usage verified: {tool_reason}")

            # Create turn result
            turn_result = TurnResult(
                turn_index=turn_idx,
                task_index=current_task_index,
                user_message=user_message,
                assistant_message=assistant_reply,
                tool_calls=[{"name": tc.name, "arguments": tc.arguments, "result": tc.result} for tc in tool_calls],
                task_completed=task_done,
                task_completed_reason=reason,
                token_usage={
                    "total": turn_usage.total_tokens,
                    "prompt": turn_usage.prompt_tokens,
                    "completion": turn_usage.completion_tokens,
                },
            )
            turn_results.append(turn_result)

            # Notify metric collectors of turn end
            collectors.on_turn_end(turn_result)

            if task_done:
                print(f"Task completed: {reason}")
                task_final_reasons[current_task_index] = reason
                has_next = sim_user.advance_to_next_task()
                if not has_next:
                    print("\nAll tasks completed!")
                    status = SimulationStatus.COMPLETED
                    break
                next_task = sim_user.get_current_task()
                if next_task:
                    logger.log_task_transition(next_task)
                # Reset task turn counter when moving to next task
                turns_for_current_task = 0
                current_task_id = id(next_task) if next_task else None
            else:
                print(f"Task not completed: {reason}")
                task_final_reasons[current_task_index] = reason

            # Ask the simulator for the next user message
            user_message = await sim_user.next_message(history)

        else:
            print("\nReached maximum number of turns. Exiting.")
            status = SimulationStatus.TIMEOUT

    except Exception as e:
        status = SimulationStatus.FAILED
        error_message = str(e)
        print(f"\nSimulation failed with error: {error_message}")

    # Build task results
    for i, task in enumerate(scenario.tasks):
        completed = i < sim_user.current_task_idx or (
            i == sim_user.current_task_idx and status == SimulationStatus.COMPLETED
        )
        task_results.append(
            TaskResult(
                task_index=i,
                description=task.task,
                expected_result=task.expected_result,
                completed=completed,
                turns_taken=task_turn_counts.get(i, 0),
                final_reason=task_final_reasons.get(i, "Not attempted"),
            )
        )

    # Print total token usage summary
    print("\n=== Total Assistant Token Usage ===")
    print(
        f"Total assistant tokens: {total_usage.total_tokens} "
        f"({total_usage.prompt_tokens} prompt + {total_usage.completion_tokens} completion)"
    )
    print(f"Total estimated cost: ${total_usage.estimated_cost:.6f}")
    logger.log_total_usage(total_usage)

    # Evaluate conversation with DeepEval metrics
    deepeval_scores = _evaluate_with_deepeval(history, logger)

    # Collect custom metrics from collectors
    custom_metrics = collectors.on_conversation_end(turn_results)

    logger.finalize_session()

    # Build metrics
    tasks_completed = sum(1 for t in task_results if t.completed)
    metrics = ConversationMetrics(
        total_turns=len(turn_results),
        total_tasks=len(scenario.tasks),
        tasks_completed=tasks_completed,
        total_tokens=total_usage.total_tokens,
        prompt_tokens=total_usage.prompt_tokens,
        completion_tokens=total_usage.completion_tokens,
        total_cost_usd=total_usage.estimated_cost,
        deepeval_scores=deepeval_scores,
        custom=custom_metrics,
    )

    return SimulationResult(
        scenario_name=scenario.name,
        start_time=start_time,
        end_time=datetime.now(timezone.utc),
        status=status,
        agent_model=agent_model_name,
        simulated_user_model=sim_user_model_name,
        checker_model=checker_model_name,
        personality=personality.name if personality else None,
        error=error_message,
        turns=turn_results,
        tasks=task_results,
        metrics=metrics,
    )
```
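Note: an illustrative sketch of driving `run_simulation` end to end (not part of the diff). It assumes a `Scenario` (e.g. loaded via `load_scenarios`) and a concrete `ChatInterface` implementation are constructed elsewhere; `models.py` and `scenarios.py` are part of this release but their hunks are not reproduced here, so construction details are omitted, and the `log_file` path below is purely hypothetical:

```python
from ragbits.chat.interface import ChatInterface
from ragbits.evaluate.agent_simulation import run_simulation
from ragbits.evaluate.agent_simulation.models import Scenario


async def evaluate_scenario(scenario: Scenario, chat: ChatInterface) -> None:
    # log_file is optional; the path here is illustrative only.
    result = await run_simulation(
        scenario=scenario,
        chat=chat,
        max_turns_scenario=10,
        max_turns_task=3,
        log_file="simulation_log.json",
    )
    metrics = result.metrics
    print(
        f"{result.scenario_name}: {result.status} - "
        f"{metrics.tasks_completed}/{metrics.total_tasks} tasks, "
        f"{metrics.total_tokens} tokens, ${metrics.total_cost_usd:.4f}"
    )


# Run with: asyncio.run(evaluate_scenario(scenario, chat))
```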
ragbits/evaluate/agent_simulation/deepeval_evaluator.py
@@ -0,0 +1,92 @@

```python
"""DeepEval integration for multi-turn conversation evaluation."""

from deepeval.metrics import (  # type: ignore[attr-defined]
    ConversationCompletenessMetric,
    ConversationRelevancyMetric,
    KnowledgeRetentionMetric,
)
from deepeval.test_case import ConversationalTestCase, LLMTestCase  # type: ignore[attr-defined]

from ragbits.evaluate.agent_simulation.models import Turn


class DeepEvalEvaluator:
    """Evaluator using DeepEval metrics for multi-turn conversations."""

    def __init__(self) -> None:
        """Initialize the DeepEval evaluator with multi-turn metrics."""
        self.completeness_metric = ConversationCompletenessMetric()
        self.knowledge_retention_metric = KnowledgeRetentionMetric()
        self.conversation_relevancy_metric = ConversationRelevancyMetric()

    @staticmethod
    def _evaluate_metric(
        metric: ConversationCompletenessMetric | KnowledgeRetentionMetric | ConversationRelevancyMetric,
        test_case: ConversationalTestCase,
    ) -> dict[str, float | str | None]:
        """Evaluate a single metric on a test case.

        Args:
            metric: The metric instance to evaluate
            test_case: The conversational test case to evaluate

        Returns:
            Dictionary containing score, reason, success, and optionally error
        """
        try:
            metric.measure(test_case)
            return {
                "score": metric.score,
                "reason": getattr(metric, "reason", None),
                "success": getattr(metric, "success", None),
            }
        except Exception as e:
            return {
                "score": None,
                "reason": None,
                "success": None,
                "error": str(e),
            }

    def _evaluate_all_metrics(self, test_case: ConversationalTestCase) -> dict[str, dict[str, float | str | None]]:
        """Evaluate all metrics on a test case.

        Args:
            test_case: The conversational test case to evaluate

        Returns:
            Dictionary containing evaluation results for each metric
        """
        results: dict[str, dict[str, float | str | None]] = {}

        results["ConversationCompletenessMetric"] = self._evaluate_metric(self.completeness_metric, test_case)

        results["KnowledgeRetentionMetric"] = self._evaluate_metric(self.knowledge_retention_metric, test_case)

        results["ConversationRelevancyMetric"] = self._evaluate_metric(self.conversation_relevancy_metric, test_case)

        return results

    def evaluate_conversation(self, turns: list[Turn]) -> dict[str, dict[str, float | str | None]]:
        """Evaluate a conversation using DeepEval metrics.

        Args:
            turns: List of conversation turns to evaluate

        Returns:
            Dictionary containing evaluation results for each metric
        """
        if not turns:
            return {}

        # Convert ragbits Turn objects to deepeval LLMTestCase objects
        deepeval_turns = []
        for turn in turns:
            # Each turn becomes an LLMTestCase where input is user message and actual_output is assistant response
            deepeval_turns.append(LLMTestCase(input=turn.user, actual_output=turn.assistant))

        # Create conversational test case
        test_case = ConversationalTestCase(turns=deepeval_turns)

        # Evaluate with each metric
        return self._evaluate_all_metrics(test_case)
```
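Note: `DeepEvalEvaluator` can also be used on its own, outside `run_simulation`. A minimal sketch (not part of the diff); the DeepEval metrics call their own LLM judge internally, so DeepEval's usual model configuration (e.g. an OpenAI API key) is assumed to be set up:

```python
from ragbits.evaluate.agent_simulation import DeepEvalEvaluator
from ragbits.evaluate.agent_simulation.models import Turn

turns = [
    Turn(user="Where is my order?", assistant="Order 1042 shipped yesterday and should arrive Friday."),
    Turn(user="Can I still change the delivery address?", assistant="Yes, I've updated it to the address on file."),
]

evaluator = DeepEvalEvaluator()
for metric_name, outcome in evaluator.evaluate_conversation(turns).items():
    # Each outcome holds "score", "reason", "success", and "error" when measurement failed.
    print(metric_name, outcome.get("score"), outcome.get("reason") or outcome.get("error"))
```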