DeepFabric 4.4.0 (deepfabric-4.4.0-py3-none-any.whl)
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
- deepfabric/__init__.py +70 -0
- deepfabric/__main__.py +6 -0
- deepfabric/auth.py +382 -0
- deepfabric/builders.py +303 -0
- deepfabric/builders_agent.py +1304 -0
- deepfabric/cli.py +1288 -0
- deepfabric/config.py +899 -0
- deepfabric/config_manager.py +251 -0
- deepfabric/constants.py +94 -0
- deepfabric/dataset_manager.py +534 -0
- deepfabric/error_codes.py +581 -0
- deepfabric/evaluation/__init__.py +47 -0
- deepfabric/evaluation/backends/__init__.py +32 -0
- deepfabric/evaluation/backends/ollama_backend.py +137 -0
- deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
- deepfabric/evaluation/backends/transformers_backend.py +326 -0
- deepfabric/evaluation/evaluator.py +845 -0
- deepfabric/evaluation/evaluators/__init__.py +13 -0
- deepfabric/evaluation/evaluators/base.py +104 -0
- deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
- deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
- deepfabric/evaluation/evaluators/registry.py +66 -0
- deepfabric/evaluation/inference.py +155 -0
- deepfabric/evaluation/metrics.py +397 -0
- deepfabric/evaluation/parser.py +304 -0
- deepfabric/evaluation/reporters/__init__.py +13 -0
- deepfabric/evaluation/reporters/base.py +56 -0
- deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
- deepfabric/evaluation/reporters/file_reporter.py +61 -0
- deepfabric/evaluation/reporters/multi_reporter.py +56 -0
- deepfabric/exceptions.py +67 -0
- deepfabric/factory.py +26 -0
- deepfabric/generator.py +1084 -0
- deepfabric/graph.py +545 -0
- deepfabric/hf_hub.py +214 -0
- deepfabric/kaggle_hub.py +219 -0
- deepfabric/llm/__init__.py +41 -0
- deepfabric/llm/api_key_verifier.py +534 -0
- deepfabric/llm/client.py +1206 -0
- deepfabric/llm/errors.py +105 -0
- deepfabric/llm/rate_limit_config.py +262 -0
- deepfabric/llm/rate_limit_detector.py +278 -0
- deepfabric/llm/retry_handler.py +270 -0
- deepfabric/metrics.py +212 -0
- deepfabric/progress.py +262 -0
- deepfabric/prompts.py +290 -0
- deepfabric/schemas.py +1000 -0
- deepfabric/spin/__init__.py +6 -0
- deepfabric/spin/client.py +263 -0
- deepfabric/spin/models.py +26 -0
- deepfabric/stream_simulator.py +90 -0
- deepfabric/tools/__init__.py +5 -0
- deepfabric/tools/defaults.py +85 -0
- deepfabric/tools/loader.py +87 -0
- deepfabric/tools/mcp_client.py +677 -0
- deepfabric/topic_manager.py +303 -0
- deepfabric/topic_model.py +20 -0
- deepfabric/training/__init__.py +35 -0
- deepfabric/training/api_key_prompt.py +302 -0
- deepfabric/training/callback.py +363 -0
- deepfabric/training/metrics_sender.py +301 -0
- deepfabric/tree.py +438 -0
- deepfabric/tui.py +1267 -0
- deepfabric/update_checker.py +166 -0
- deepfabric/utils.py +150 -0
- deepfabric/validation.py +143 -0
- deepfabric-4.4.0.dist-info/METADATA +702 -0
- deepfabric-4.4.0.dist-info/RECORD +71 -0
- deepfabric-4.4.0.dist-info/WHEEL +4 -0
- deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
- deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/evaluation/evaluator.py
@@ -0,0 +1,845 @@
"""Main evaluator for running model evaluation."""

import json

from pathlib import Path
from typing import Any

from datasets import Dataset as HFDataset
from pydantic import BaseModel, Field, ValidationError
from rich.console import Console
from tqdm.auto import tqdm

from ..metrics import trace
from ..schemas import ToolDefinition
from .evaluators import EvaluationContext, EvaluatorRegistry, EvaluatorResult
from .inference import InferenceConfig, ModelResponse, create_inference_backend
from .metrics import (
    EvaluationMetrics,
    SampleEvaluation,
    compute_metrics,
)
from .parser import ExpectedToolCall, GroundTruth, GroundTruthParser
from .reporters import BaseReporter, CloudReporter, FileReporter, MultiReporter

console = Console()


class EvaluatorConfig(BaseModel):
    """Configuration for evaluation run."""

    dataset_path: str | None = Field(
        default=None,
        description="Path to evaluation dataset (JSONL). Optional if passing dataset to evaluate().",
    )
    output_path: str | None = Field(
        default=None,
        description="Path to save evaluation results",
    )
    model_path: str | None = Field(
        default=None,
        description="Path to model to evaluate (overrides inference_config.model_path)",
    )
    inference_config: InferenceConfig = Field(
        description="Inference backend configuration (includes model_path)",
    )
    batch_size: int = Field(
        default=1,
        ge=1,
        description="Batch size for evaluation",
    )
    max_samples: int | None = Field(
        default=None,
        description="Maximum number of samples to evaluate (None for all)",
    )
    save_predictions: bool = Field(
        default=True,
        description="Save individual predictions to output file",
    )
    metric_weights: dict[str, float] | None = Field(
        default=None,
        description="Custom weights for overall score computation",
    )
    evaluators: list[str] | dict[str, dict] = Field(
        default=["tool_calling"],
        description="List of evaluator names or dict of name -> config",
    )
    reporters: list[str] | dict[str, dict] = Field(
        default=["file"],
        description="List of reporter names or dict of name -> config",
    )
    cloud_api_key: str | None = Field(
        default=None,
        description="DeepFabric cloud API key (or use DEEPFABRIC_API_KEY env var)",
    )
    multi_turn: bool = Field(
        default=False,
        description="Enable multi-turn evaluation (loops through conversation using ground truth tool responses)",
    )
    max_turns: int = Field(
        default=10,
        ge=1,
        description="Maximum number of turns in multi-turn evaluation",
    )


class EvaluationResult(BaseModel):
    """Complete evaluation result."""

    metrics: EvaluationMetrics = Field(description="Computed metrics")
    predictions: list[SampleEvaluation] = Field(
        description="Individual sample evaluations",
    )
    config: EvaluatorConfig = Field(description="Evaluation configuration used")


class Evaluator:
    """Orchestrates model evaluation on tool-calling tasks."""

    def __init__(self, config: EvaluatorConfig):
        """Initialize evaluator.

        Args:
            config: Evaluation configuration
        """
        self.config = config
        self.backend = create_inference_backend(config.inference_config)
        # Parser will be configured per-sample based on conversation metadata
        self.parser: GroundTruthParser | None = None

        # Initialize evaluator registry and active evaluators
        self.registry = EvaluatorRegistry()
        self.active_evaluators = self._initialize_evaluators()

        # Initialize reporters
        self.reporter = self._initialize_reporters()

        # Track evaluator creation
        trace(
            "evaluator_created",
            {
                "backend": self.config.inference_config.backend,
                "model_path": self.config.inference_config.model_path,
                "has_adapter": self.config.inference_config.adapter_path is not None,
                "evaluators": (
                    list(self.config.evaluators)
                    if isinstance(self.config.evaluators, list)
                    else list(self.config.evaluators.keys())
                ),
                "reporters": (
                    list(self.config.reporters)
                    if isinstance(self.config.reporters, list)
                    else list(self.config.reporters.keys())
                ),
            },
        )

    def _initialize_evaluators(self) -> list:
        """Initialize evaluators based on config.

        Returns:
            List of active evaluator instances
        """
        evaluators = []

        if isinstance(self.config.evaluators, list):
            # Simple list of names
            for name in self.config.evaluators:
                evaluators.append(self.registry.get(name))
        else:
            # Dict with configs
            for name, eval_config in self.config.evaluators.items():
                evaluators.append(self.registry.get(name, config=eval_config))

        return evaluators

    def _initialize_reporters(self) -> BaseReporter:
        """Initialize reporters based on config.

        Returns:
            Reporter instance (may be MultiReporter)
        """
        reporters: list[BaseReporter] = []

        if isinstance(self.config.reporters, list):
            # Simple list of names
            for name in self.config.reporters:
                if name == "file":
                    reporters.append(FileReporter({"path": self.config.output_path}))
                elif name == "cloud":
                    reporters.append(CloudReporter({"api_key": self.config.cloud_api_key}))
        else:
            # Dict with configs
            for name, reporter_config in self.config.reporters.items():
                if name == "file":
                    # Merge output_path if not in config
                    if "path" not in reporter_config and self.config.output_path:
                        reporter_config["path"] = self.config.output_path
                    reporters.append(FileReporter(reporter_config))
                elif name == "cloud":
                    # Merge api_key if not in config
                    if "api_key" not in reporter_config and self.config.cloud_api_key:
                        reporter_config["api_key"] = self.config.cloud_api_key
                    reporters.append(CloudReporter(reporter_config))

        # Return single reporter or MultiReporter
        if len(reporters) == 0:
            # Default to file reporter
            return FileReporter({"path": self.config.output_path})
        if len(reporters) == 1:
            return reporters[0]
        return MultiReporter(reporters)

    def load_dataset(self, dataset: HFDataset | None = None) -> list[dict[str, Any]]:
        """Load evaluation dataset from HFDataset or JSONL file.

        Args:
            dataset: Optional HuggingFace Dataset. If provided, uses this instead
                of loading from config.dataset_path.

        Returns:
            List of dataset samples

        Raises:
            FileNotFoundError: If dataset file doesn't exist (when using file path)
            ValueError: If dataset format is invalid or no dataset source provided
        """
        if dataset is not None:
            # Use provided HuggingFace Dataset
            samples = [dict(sample) for sample in dataset]
        elif self.config.dataset_path is not None:
            # Load from file path
            dataset_path = Path(self.config.dataset_path)
            if not dataset_path.exists():
                msg = f"Dataset file not found: {dataset_path}"
                raise FileNotFoundError(msg)

            samples = []
            with dataset_path.open() as f:
                for line_num, line in enumerate(f, 1):
                    try:
                        sample = json.loads(line.strip())
                        samples.append(sample)
                    except json.JSONDecodeError as e:
                        msg = f"Invalid JSON on line {line_num}: {e}"
                        raise ValueError(msg) from e
        else:
            msg = "No dataset provided. Either pass a HuggingFace Dataset to evaluate() or set dataset_path in config."
            raise ValueError(msg)

        if self.config.max_samples is not None:
            samples = samples[: self.config.max_samples]

        return samples

    def extract_ground_truth(self, sample: dict[str, Any]) -> GroundTruth:
        """Extract ground truth from sample.

        Args:
            sample: Dataset sample

        Returns:
            Parsed ground truth
        """
        # Create parser for this sample's conversation type
        from ..schemas import Conversation  # noqa: PLC0415

        # Convert sample dict to Conversation object
        conversation = Conversation.model_validate(sample)

        # Determine conversation type from metadata
        metadata = conversation.metadata or {}
        conv_type = metadata.get("conversation_type", "basic")
        reasoning_style = metadata.get("reasoning_style")
        agent_mode = metadata.get("agent_mode")

        # Create parser with appropriate config
        parser = GroundTruthParser(
            conversation_type=conv_type,  # type: ignore[arg-type]
            reasoning_style=reasoning_style,  # type: ignore[arg-type]
            agent_mode=agent_mode,  # type: ignore[arg-type]
        )

        return parser.parse(conversation)

    def prepare_messages(self, sample: dict[str, Any]) -> list[dict[str, Any]]:
        """Prepare messages for model inference.

        Extracts conversation up to the assistant's tool call.

        Args:
            sample: Dataset sample

        Returns:
            List of messages for inference
        """
        messages = []
        for msg in sample["messages"]:
            # Stop before first assistant message (where tool call should be generated)
            if msg["role"] == "assistant":
                break
            messages.append({"role": msg["role"], "content": msg["content"]})

        return messages

    def prepare_tools(self, sample: dict[str, Any]) -> list[ToolDefinition]:
        """Prepare tool definitions from sample.

        Args:
            sample: Dataset sample

        Returns:
            List of available tools
        """
        from ..schemas import Conversation  # noqa: PLC0415

        # Convert to Conversation to access tools field
        conversation = Conversation.model_validate(sample)

        if not conversation.tools:
            return []

        # Convert from OpenAI format back to ToolDefinition
        return [ToolDefinition.from_openai(tool) for tool in conversation.tools]

    def build_tool_response_lookup(self, sample: dict[str, Any]) -> dict[str, dict[str, str]]:
        """Build lookup of tool responses by tool name and arguments.

        For multi-turn evaluation, we need to look up tool responses when the
        model makes tool calls. We index by (tool_name, arguments_json) to find
        matching responses from ground truth.

        Args:
            sample: Dataset sample

        Returns:
            Dict mapping "tool_name:args_json" -> {"content": response, "tool_call_id": id}
        """
        lookup: dict[str, dict[str, str]] = {}
        messages = sample.get("messages", [])

        # Track pending tool calls from assistant messages
        pending_tool_calls: dict[str, dict] = {}  # tool_call_id -> {name, arguments}

        for msg in messages:
            role = msg.get("role")

            # Collect tool calls from assistant messages
            if role == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
                for tc in msg["tool_calls"]:
                    tc_id = tc.get("id")
                    func = tc.get("function", {})
                    name = func.get("name", "")
                    args = func.get("arguments", "{}")
                    if tc_id:
                        pending_tool_calls[tc_id] = {"name": name, "arguments": args}

            # Match tool responses to their calls
            if role == "tool":
                tc_id = msg.get("tool_call_id")
                content = msg.get("content", "")
                if tc_id and tc_id in pending_tool_calls:
                    call_info = pending_tool_calls[tc_id]
                    # Create lookup key from tool name + normalized arguments
                    try:
                        args_dict = json.loads(call_info["arguments"])
                        normalized_args = json.dumps(args_dict, sort_keys=True)
                        key = f"{call_info['name']}:{normalized_args}"
                    except (json.JSONDecodeError, TypeError):
                        # Fallback if arguments are not a valid JSON string
                        key = f"{call_info['name']}:{call_info['arguments']}"
                    lookup[key] = {"content": content, "tool_call_id": tc_id}

        return lookup

    def find_tool_response(
        self,
        tool_call: dict,
        lookup: dict[str, dict[str, str]],
    ) -> dict[str, str] | None:
        """Find a matching tool response for a predicted tool call.

        Args:
            tool_call: Predicted tool call with 'name' and 'arguments'
            lookup: Tool response lookup from build_tool_response_lookup

        Returns:
            Dict with 'content' and 'tool_call_id' if found, None otherwise
        """
        name = tool_call.get("name", "")
        args = tool_call.get("arguments", {})

        # Normalize arguments to JSON string for comparison
        args_json = json.dumps(args, sort_keys=True) if isinstance(args, dict) else str(args)

        # Try exact match first
        key = f"{name}:{args_json}"
        if key in lookup:
            return lookup[key]

        # Try matching just by tool name (less strict)
        # This helps when parameter values differ slightly
        for lookup_key, response in lookup.items():
            if lookup_key.startswith(f"{name}:"):
                return response

        return None

    def evaluate_sample(
        self,
        sample: dict[str, Any],
        sample_id: int,
    ) -> SampleEvaluation:
        """Evaluate a single sample using configured evaluators.

        Args:
            sample: Dataset sample
            sample_id: Sample index

        Returns:
            Evaluation result for this sample
        """
        # Use multi-turn evaluation if enabled
        if self.config.multi_turn:
            return self.evaluate_sample_multi_turn(sample, sample_id)

        try:
            # Extract ground truth
            ground_truth = self.extract_ground_truth(sample)

            # Prepare inputs
            messages = self.prepare_messages(sample)
            tools = self.prepare_tools(sample)

            # Run inference
            response: ModelResponse = self.backend.generate(messages, tools)

            # Create evaluation context
            context = EvaluationContext(
                messages=messages,
                tools=tools,
                sample_id=sample_id,
            )

            # Run all active evaluators
            evaluator_results: list[EvaluatorResult] = []
            for evaluator in self.active_evaluators:
                result = evaluator.evaluate(ground_truth, response, context)
                if result is not None:  # Evaluator may skip
                    evaluator_results.append(result)

            # Aggregate results for backwards compatibility
            return self._aggregate_results(
                sample_id=sample_id,
                ground_truth=ground_truth,
                response=response,
                evaluator_results=evaluator_results,
            )

        except Exception as e:  # noqa: BLE001
            # Return failed evaluation with safe defaults
            query = ""
            expected_tool = None
            expected_params: dict[str, Any] = {}
            expected_answer = None

            # Try to extract ground truth if available
            try:
                gt = self.extract_ground_truth(sample)
                query = gt.query
                expected_tool = gt.expected_tool
                expected_params = gt.expected_parameters
                expected_answer = gt.expected_answer
            except (KeyError, AttributeError, ValidationError):
                pass

            return SampleEvaluation(
                sample_id=sample_id,
                query=query,
                expected_tool=expected_tool,
                predicted_tool=None,
                expected_parameters=expected_params,
                predicted_parameters={},
                expected_answer=expected_answer,
                predicted_answer=None,
                tool_selection_correct=False,
                parameters_correct=False,
                execution_valid=False,
                response_score=0.0,
                error=str(e),
            )

    def evaluate_sample_multi_turn(
        self,
        sample: dict[str, Any],
        sample_id: int,
    ) -> SampleEvaluation:
        """Evaluate a single sample using multi-turn conversation.

        Loops through the conversation, feeding tool responses back to the model
        until it generates a final answer (no tool calls) or max turns reached.

        Args:
            sample: Dataset sample
            sample_id: Sample index

        Returns:
            Evaluation result for this sample
        """
        try:
            # Extract ground truth (includes all expected tools)
            ground_truth = self.extract_ground_truth(sample)

            # Prepare initial inputs
            messages = self.prepare_messages(sample)
            tools = self.prepare_tools(sample)

            # Build lookup for tool responses from ground truth
            tool_response_lookup = self.build_tool_response_lookup(sample)

            # Track all predicted tool calls across turns
            all_predicted_tool_calls: list[dict] = []
            final_content = ""

            # Multi-turn loop
            for turn in range(self.config.max_turns):
                # Run inference
                response: ModelResponse = self.backend.generate(messages, tools)
                final_content = response.content

                # Check if model made tool calls
                if not response.tool_calls:
                    # No tool calls - this is the final answer
                    break

                # Process each tool call
                for tool_call in response.tool_calls:
                    all_predicted_tool_calls.append(tool_call)

                    # Find matching tool response from ground truth
                    tool_response = self.find_tool_response(tool_call, tool_response_lookup)

                    if tool_response is None:
                        # Model called a tool we don't have a response for
                        # Continue anyway with an error message
                        tool_response = {
                            "content": json.dumps({"error": "Tool not found in ground truth"}),
                            "tool_call_id": f"generated_{turn}_{len(all_predicted_tool_calls)}",
                        }

                    # Add assistant message with tool call to conversation
                    messages.append(
                        {
                            "role": "assistant",
                            "content": "",
                            "tool_calls": [
                                {
                                    "id": tool_response["tool_call_id"],
                                    "type": "function",
                                    "function": {
                                        "name": tool_call.get("name", ""),
                                        "arguments": json.dumps(tool_call.get("arguments", {})),
                                    },
                                }
                            ],
                        }
                    )

                    # Add tool response to conversation
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": tool_response["tool_call_id"],
                            "content": tool_response["content"],
                        }
                    )

            # Now compute metrics comparing predicted vs expected tool calls
            return self._compute_multi_turn_metrics(
                sample_id=sample_id,
                ground_truth=ground_truth,
                predicted_tool_calls=all_predicted_tool_calls,
                final_content=final_content,
            )

        except Exception as e:  # noqa: BLE001
            # Return failed evaluation with safe defaults
            query = ""
            expected_tool = None
            expected_params: dict[str, Any] = {}
            expected_answer = None

            try:
                gt = self.extract_ground_truth(sample)
                query = gt.query
                expected_tool = gt.expected_tool
                expected_params = gt.expected_parameters
                expected_answer = gt.expected_answer
            except (KeyError, AttributeError, ValidationError):
                pass

            return SampleEvaluation(
                sample_id=sample_id,
                query=query,
                expected_tool=expected_tool,
                predicted_tool=None,
                expected_parameters=expected_params,
                predicted_parameters={},
                expected_answer=expected_answer,
                predicted_answer=None,
                tool_selection_correct=False,
                parameters_correct=False,
                execution_valid=False,
                response_score=0.0,
                error=str(e),
            )

    def _compute_multi_turn_metrics(
        self,
        sample_id: int,
        ground_truth: GroundTruth,
        predicted_tool_calls: list[dict],
        final_content: str,
    ) -> SampleEvaluation:
        """Compute metrics for multi-turn evaluation.

        Compares predicted tool calls against expected tools using set comparison.

        Args:
            sample_id: Sample index
            ground_truth: Expected values including all expected tools
            predicted_tool_calls: All tool calls made by model across turns
            final_content: Final model response content

        Returns:
            SampleEvaluation with computed metrics
        """
        # Build sets of tool names for comparison
        expected_tools_list = ground_truth.expected_tools or []
        expected_tool_names = {tc.tool_name for tc in expected_tools_list}
        predicted_tool_names = {tc.get("name", "") for tc in (predicted_tool_calls or [])}

        # Tool set coverage: what fraction of expected tools were called?
        if expected_tool_names:
            matched_tools = expected_tool_names & predicted_tool_names
            tool_coverage = len(matched_tools) / len(expected_tool_names)
        else:
            tool_coverage = 1.0 if not predicted_tool_names else 0.0

        # Tool set precision: what fraction of predicted tools were expected?
        # (Computed but stored in response_score for now, could be expanded later)
        if predicted_tool_names:
            matched_tools = expected_tool_names & predicted_tool_names
            _tool_precision = len(matched_tools) / len(predicted_tool_names)  # noqa: F841
        else:
            _tool_precision = 1.0 if not expected_tool_names else 0.0  # noqa: F841

        # Overall tool selection is correct if coverage is 100%
        tool_selection_correct = tool_coverage == 1.0

        # For backwards compatibility, use first predicted/expected tool
        first_predicted_tool = predicted_tool_calls[0].get("name") if predicted_tool_calls else None
        first_predicted_params = (
            predicted_tool_calls[0].get("arguments", {}) if predicted_tool_calls else {}
        )

        # Parameter accuracy: for matched tools, check if params are structurally correct
        params_correct = self._check_parameter_structure(
            ground_truth.expected_tools,
            predicted_tool_calls,
        )

        # Execution valid if we got through the conversation
        execution_valid = len(predicted_tool_calls) > 0 or final_content != ""

        return SampleEvaluation(
            sample_id=sample_id,
            query=ground_truth.query,
            expected_tool=ground_truth.expected_tool,
            predicted_tool=first_predicted_tool,
            expected_parameters=ground_truth.expected_parameters,
            predicted_parameters=first_predicted_params
            if isinstance(first_predicted_params, dict)
            else {},
            expected_answer=ground_truth.expected_answer,
            predicted_answer=final_content,
            tool_selection_correct=tool_selection_correct,
            parameters_correct=params_correct,
            execution_valid=execution_valid,
            response_score=tool_coverage,  # Use coverage as response score
            error=None,
        )

    def _check_parameter_structure(
        self,
        expected_tools: list[ExpectedToolCall],
        predicted_tool_calls: list[dict],
    ) -> bool:
        """Check if predicted tool calls have correct parameter structure.

        For each matched tool, verifies that predicted params have the same keys.
        Does not check parameter values, only structure.

        Args:
            expected_tools: List of ExpectedToolCall from ground truth
            predicted_tool_calls: List of predicted tool call dicts

        Returns:
            True if all matched tools have correct parameter structure
        """
        # Build lookup of expected params by tool name
        expected_params_by_tool: dict[str, set[str]] = {}
        for tc in expected_tools or []:
            if tc.tool_name not in expected_params_by_tool:
                expected_params_by_tool[tc.tool_name] = set(tc.parameters.keys())

        # Check each predicted tool call
        for pred_call in predicted_tool_calls or []:
            tool_name = pred_call.get("name", "")
            pred_args = pred_call.get("arguments", {})

            if tool_name in expected_params_by_tool:
                expected_keys = expected_params_by_tool[tool_name]
                pred_keys = set(pred_args.keys()) if isinstance(pred_args, dict) else set()

                # Check if predicted has all expected keys (may have extra)
                if not expected_keys.issubset(pred_keys):
                    return False

        return True

    def _aggregate_results(
        self,
        sample_id: int,
        ground_truth: GroundTruth,
        response: ModelResponse,
        evaluator_results: list[EvaluatorResult],
    ) -> SampleEvaluation:
        """Aggregate evaluator results into SampleEvaluation.

        Args:
            sample_id: Sample index
            ground_truth: Expected values
            response: Model response
            evaluator_results: Results from all evaluators

        Returns:
            SampleEvaluation with aggregated metrics
        """
        # Extract tool calling metrics from evaluator results
        tool_correct = False
        params_correct = False
        execution_valid = False
        predicted_tool = None
        predicted_params = {}

        # Extract predictions from response
        if response.tool_call:
            predicted_tool = response.tool_call.get("name")
            predicted_params = response.tool_call.get("arguments", {})

        # Get metrics from tool_calling evaluator
        for result in evaluator_results:
            if result.evaluator_name == "tool_calling":
                metrics = result.metrics
                tool_correct = metrics.get("tool_selection_accuracy", 0.0) == 1.0
                params_correct = metrics.get("parameter_accuracy", 0.0) == 1.0
                execution_valid = metrics.get("execution_valid", 0.0) == 1.0

        # Return backwards-compatible SampleEvaluation
        return SampleEvaluation(
            sample_id=sample_id,
            query=ground_truth.query,
            expected_tool=ground_truth.expected_tool,
            predicted_tool=predicted_tool,
            expected_parameters=ground_truth.expected_parameters,
            predicted_parameters=predicted_params,
            expected_answer=ground_truth.expected_answer,
            predicted_answer=response.content,
            tool_selection_correct=tool_correct,
            parameters_correct=params_correct,
            execution_valid=execution_valid,
            response_score=0.0,  # TODO: Could use semantic similarity for response quality evaluation in the future, but disabled for tool-calling mode
            error=None,
        )

    def evaluate(self, dataset: HFDataset | None = None) -> EvaluationResult:
        """Run full evaluation.

        Args:
            dataset: Optional HuggingFace Dataset to evaluate. If not provided,
                loads from config.dataset_path.

        Returns:
            Complete evaluation result with metrics and predictions
        """
        console.print("[bold blue]Loading dataset...[/bold blue]")
        samples = self.load_dataset(dataset)
        console.print(f"Loaded {len(samples)} samples")

        console.print("[bold blue]Running evaluation...[/bold blue]")
        evaluations = []

        for idx, sample in tqdm(enumerate(samples), total=len(samples), desc="Evaluating"):
            eval_result = self.evaluate_sample(sample, idx)
            evaluations.append(eval_result)

            # Stream sample to reporters (for cloud real-time tracking)
            self.reporter.report_sample(eval_result)

        console.print("[bold green]Evaluation complete![/bold green]")

        # Compute metrics
        metrics = compute_metrics(evaluations, self.config.metric_weights)

        # Create result
        result = EvaluationResult(
            metrics=metrics,
            predictions=evaluations,
            config=self.config,
        )

        # Track evaluation completion
        trace(
            "evaluation_completed",
            {
                "backend": self.config.inference_config.backend,
                "model_path": self.config.inference_config.model_path,
                "has_adapter": self.config.inference_config.adapter_path is not None,
                "samples_evaluated": metrics.samples_evaluated,
                "samples_processed": metrics.samples_processed,
                "processing_errors": metrics.processing_errors,
                "tool_selection_accuracy": round(metrics.tool_selection_accuracy, 4),
                "parameter_accuracy": round(metrics.parameter_accuracy, 4),
                "execution_success_rate": round(metrics.execution_success_rate, 4),
                "overall_score": round(metrics.overall_score, 4),
                "success": metrics.processing_errors == 0,
            },
        )

        # Report results using configured reporters
        if self.config.save_predictions:
            self.reporter.report(result)

        return result

    def cleanup(self) -> None:
        """Clean up resources."""
        self.backend.cleanup()

    def print_summary(self, metrics: EvaluationMetrics) -> None:
        """Print evaluation summary.

        Args:
            metrics: Computed metrics
        """
        console.print("\n[bold]Evaluation Summary[/bold]")
        console.print(f"Samples Evaluated: {metrics.samples_evaluated}")
        console.print(f"Processed Successfully: {metrics.samples_processed}")
        console.print(f"Processing Errors: {metrics.processing_errors}")
        console.print("\n[bold]Metrics[/bold]")
        console.print(f"Tool Selection Accuracy: {metrics.tool_selection_accuracy:.2%}")
        console.print(f"Parameter Accuracy: {metrics.parameter_accuracy:.2%}")
        console.print(f"Execution Success Rate: {metrics.execution_success_rate:.2%}")
        console.print(f"Response Quality: {metrics.response_quality:.2%}")
        console.print(f"\n[bold green]Overall Score: {metrics.overall_score:.2%}[/bold green]")