deepfabric-4.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. deepfabric/__init__.py +70 -0
  2. deepfabric/__main__.py +6 -0
  3. deepfabric/auth.py +382 -0
  4. deepfabric/builders.py +303 -0
  5. deepfabric/builders_agent.py +1304 -0
  6. deepfabric/cli.py +1288 -0
  7. deepfabric/config.py +899 -0
  8. deepfabric/config_manager.py +251 -0
  9. deepfabric/constants.py +94 -0
  10. deepfabric/dataset_manager.py +534 -0
  11. deepfabric/error_codes.py +581 -0
  12. deepfabric/evaluation/__init__.py +47 -0
  13. deepfabric/evaluation/backends/__init__.py +32 -0
  14. deepfabric/evaluation/backends/ollama_backend.py +137 -0
  15. deepfabric/evaluation/backends/tool_call_parsers.py +409 -0
  16. deepfabric/evaluation/backends/transformers_backend.py +326 -0
  17. deepfabric/evaluation/evaluator.py +845 -0
  18. deepfabric/evaluation/evaluators/__init__.py +13 -0
  19. deepfabric/evaluation/evaluators/base.py +104 -0
  20. deepfabric/evaluation/evaluators/builtin/__init__.py +5 -0
  21. deepfabric/evaluation/evaluators/builtin/tool_calling.py +93 -0
  22. deepfabric/evaluation/evaluators/registry.py +66 -0
  23. deepfabric/evaluation/inference.py +155 -0
  24. deepfabric/evaluation/metrics.py +397 -0
  25. deepfabric/evaluation/parser.py +304 -0
  26. deepfabric/evaluation/reporters/__init__.py +13 -0
  27. deepfabric/evaluation/reporters/base.py +56 -0
  28. deepfabric/evaluation/reporters/cloud_reporter.py +195 -0
  29. deepfabric/evaluation/reporters/file_reporter.py +61 -0
  30. deepfabric/evaluation/reporters/multi_reporter.py +56 -0
  31. deepfabric/exceptions.py +67 -0
  32. deepfabric/factory.py +26 -0
  33. deepfabric/generator.py +1084 -0
  34. deepfabric/graph.py +545 -0
  35. deepfabric/hf_hub.py +214 -0
  36. deepfabric/kaggle_hub.py +219 -0
  37. deepfabric/llm/__init__.py +41 -0
  38. deepfabric/llm/api_key_verifier.py +534 -0
  39. deepfabric/llm/client.py +1206 -0
  40. deepfabric/llm/errors.py +105 -0
  41. deepfabric/llm/rate_limit_config.py +262 -0
  42. deepfabric/llm/rate_limit_detector.py +278 -0
  43. deepfabric/llm/retry_handler.py +270 -0
  44. deepfabric/metrics.py +212 -0
  45. deepfabric/progress.py +262 -0
  46. deepfabric/prompts.py +290 -0
  47. deepfabric/schemas.py +1000 -0
  48. deepfabric/spin/__init__.py +6 -0
  49. deepfabric/spin/client.py +263 -0
  50. deepfabric/spin/models.py +26 -0
  51. deepfabric/stream_simulator.py +90 -0
  52. deepfabric/tools/__init__.py +5 -0
  53. deepfabric/tools/defaults.py +85 -0
  54. deepfabric/tools/loader.py +87 -0
  55. deepfabric/tools/mcp_client.py +677 -0
  56. deepfabric/topic_manager.py +303 -0
  57. deepfabric/topic_model.py +20 -0
  58. deepfabric/training/__init__.py +35 -0
  59. deepfabric/training/api_key_prompt.py +302 -0
  60. deepfabric/training/callback.py +363 -0
  61. deepfabric/training/metrics_sender.py +301 -0
  62. deepfabric/tree.py +438 -0
  63. deepfabric/tui.py +1267 -0
  64. deepfabric/update_checker.py +166 -0
  65. deepfabric/utils.py +150 -0
  66. deepfabric/validation.py +143 -0
  67. deepfabric-4.4.0.dist-info/METADATA +702 -0
  68. deepfabric-4.4.0.dist-info/RECORD +71 -0
  69. deepfabric-4.4.0.dist-info/WHEEL +4 -0
  70. deepfabric-4.4.0.dist-info/entry_points.txt +2 -0
  71. deepfabric-4.4.0.dist-info/licenses/LICENSE +201 -0
deepfabric/evaluation/evaluator.py
@@ -0,0 +1,845 @@
+ """Main evaluator for running model evaluation."""
+
+ import json
+
+ from pathlib import Path
+ from typing import Any
+
+ from datasets import Dataset as HFDataset
+ from pydantic import BaseModel, Field, ValidationError
+ from rich.console import Console
+ from tqdm.auto import tqdm
+
+ from ..metrics import trace
+ from ..schemas import ToolDefinition
+ from .evaluators import EvaluationContext, EvaluatorRegistry, EvaluatorResult
+ from .inference import InferenceConfig, ModelResponse, create_inference_backend
+ from .metrics import (
+     EvaluationMetrics,
+     SampleEvaluation,
+     compute_metrics,
+ )
+ from .parser import ExpectedToolCall, GroundTruth, GroundTruthParser
+ from .reporters import BaseReporter, CloudReporter, FileReporter, MultiReporter
+
+ console = Console()
+
+
+ class EvaluatorConfig(BaseModel):
+     """Configuration for evaluation run."""
+
+     dataset_path: str | None = Field(
+         default=None,
+         description="Path to evaluation dataset (JSONL). Optional if passing dataset to evaluate().",
+     )
+     output_path: str | None = Field(
+         default=None,
+         description="Path to save evaluation results",
+     )
+     model_path: str | None = Field(
+         default=None,
+         description="Path to model to evaluate (overrides inference_config.model_path)",
+     )
+     inference_config: InferenceConfig = Field(
+         description="Inference backend configuration (includes model_path)",
+     )
+     batch_size: int = Field(
+         default=1,
+         ge=1,
+         description="Batch size for evaluation",
+     )
+     max_samples: int | None = Field(
+         default=None,
+         description="Maximum number of samples to evaluate (None for all)",
+     )
+     save_predictions: bool = Field(
+         default=True,
+         description="Save individual predictions to output file",
+     )
+     metric_weights: dict[str, float] | None = Field(
+         default=None,
+         description="Custom weights for overall score computation",
+     )
+     evaluators: list[str] | dict[str, dict] = Field(
+         default=["tool_calling"],
+         description="List of evaluator names or dict of name -> config",
+     )
+     reporters: list[str] | dict[str, dict] = Field(
+         default=["file"],
+         description="List of reporter names or dict of name -> config",
+     )
+     cloud_api_key: str | None = Field(
+         default=None,
+         description="DeepFabric cloud API key (or use DEEPFABRIC_API_KEY env var)",
+     )
+     multi_turn: bool = Field(
+         default=False,
+         description="Enable multi-turn evaluation (loops through conversation using ground truth tool responses)",
+     )
+     max_turns: int = Field(
+         default=10,
+         ge=1,
+         description="Maximum number of turns in multi-turn evaluation",
+     )
+
+
+ class EvaluationResult(BaseModel):
+     """Complete evaluation result."""
+
+     metrics: EvaluationMetrics = Field(description="Computed metrics")
+     predictions: list[SampleEvaluation] = Field(
+         description="Individual sample evaluations",
+     )
+     config: EvaluatorConfig = Field(description="Evaluation configuration used")
+
+
+ class Evaluator:
+     """Orchestrates model evaluation on tool-calling tasks."""
+
+     def __init__(self, config: EvaluatorConfig):
+         """Initialize evaluator.
+
+         Args:
+             config: Evaluation configuration
+         """
+         self.config = config
+         self.backend = create_inference_backend(config.inference_config)
+         # Parser will be configured per-sample based on conversation metadata
+         self.parser: GroundTruthParser | None = None
+
+         # Initialize evaluator registry and active evaluators
+         self.registry = EvaluatorRegistry()
+         self.active_evaluators = self._initialize_evaluators()
+
+         # Initialize reporters
+         self.reporter = self._initialize_reporters()
+
+         # Track evaluator creation
+         trace(
+             "evaluator_created",
+             {
+                 "backend": self.config.inference_config.backend,
+                 "model_path": self.config.inference_config.model_path,
+                 "has_adapter": self.config.inference_config.adapter_path is not None,
+                 "evaluators": (
+                     list(self.config.evaluators)
+                     if isinstance(self.config.evaluators, list)
+                     else list(self.config.evaluators.keys())
+                 ),
+                 "reporters": (
+                     list(self.config.reporters)
+                     if isinstance(self.config.reporters, list)
+                     else list(self.config.reporters.keys())
+                 ),
+             },
+         )
+
+     def _initialize_evaluators(self) -> list:
+         """Initialize evaluators based on config.
+
+         Returns:
+             List of active evaluator instances
+         """
+         evaluators = []
+
+         if isinstance(self.config.evaluators, list):
+             # Simple list of names
+             for name in self.config.evaluators:
+                 evaluators.append(self.registry.get(name))
+         else:
+             # Dict with configs
+             for name, eval_config in self.config.evaluators.items():
+                 evaluators.append(self.registry.get(name, config=eval_config))
+
+         return evaluators
+
+     def _initialize_reporters(self) -> BaseReporter:
+         """Initialize reporters based on config.
+
+         Returns:
+             Reporter instance (may be MultiReporter)
+         """
+         reporters: list[BaseReporter] = []
+
+         if isinstance(self.config.reporters, list):
+             # Simple list of names
+             for name in self.config.reporters:
+                 if name == "file":
+                     reporters.append(FileReporter({"path": self.config.output_path}))
+                 elif name == "cloud":
+                     reporters.append(CloudReporter({"api_key": self.config.cloud_api_key}))
+         else:
+             # Dict with configs
+             for name, reporter_config in self.config.reporters.items():
+                 if name == "file":
+                     # Merge output_path if not in config
+                     if "path" not in reporter_config and self.config.output_path:
+                         reporter_config["path"] = self.config.output_path
+                     reporters.append(FileReporter(reporter_config))
+                 elif name == "cloud":
+                     # Merge api_key if not in config
+                     if "api_key" not in reporter_config and self.config.cloud_api_key:
+                         reporter_config["api_key"] = self.config.cloud_api_key
+                     reporters.append(CloudReporter(reporter_config))
+
+         # Return single reporter or MultiReporter
+         if len(reporters) == 0:
+             # Default to file reporter
+             return FileReporter({"path": self.config.output_path})
+         if len(reporters) == 1:
+             return reporters[0]
+         return MultiReporter(reporters)
+
+     def load_dataset(self, dataset: HFDataset | None = None) -> list[dict[str, Any]]:
+         """Load evaluation dataset from HFDataset or JSONL file.
+
+         Args:
+             dataset: Optional HuggingFace Dataset. If provided, uses this instead
+                 of loading from config.dataset_path.
+
+         Returns:
+             List of dataset samples
+
+         Raises:
+             FileNotFoundError: If dataset file doesn't exist (when using file path)
+             ValueError: If dataset format is invalid or no dataset source provided
+         """
+         if dataset is not None:
+             # Use provided HuggingFace Dataset
+             samples = [dict(sample) for sample in dataset]
+         elif self.config.dataset_path is not None:
+             # Load from file path
+             dataset_path = Path(self.config.dataset_path)
+             if not dataset_path.exists():
+                 msg = f"Dataset file not found: {dataset_path}"
+                 raise FileNotFoundError(msg)
+
+             samples = []
+             with dataset_path.open() as f:
+                 for line_num, line in enumerate(f, 1):
+                     try:
+                         sample = json.loads(line.strip())
+                         samples.append(sample)
+                     except json.JSONDecodeError as e:
+                         msg = f"Invalid JSON on line {line_num}: {e}"
+                         raise ValueError(msg) from e
+         else:
+             msg = "No dataset provided. Either pass a HuggingFace Dataset to evaluate() or set dataset_path in config."
+             raise ValueError(msg)
+
+         if self.config.max_samples is not None:
+             samples = samples[: self.config.max_samples]
+
+         return samples
+
+     def extract_ground_truth(self, sample: dict[str, Any]) -> GroundTruth:
+         """Extract ground truth from sample.
+
+         Args:
+             sample: Dataset sample
+
+         Returns:
+             Parsed ground truth
+         """
+         # Create parser for this sample's conversation type
+         from ..schemas import Conversation  # noqa: PLC0415
+
+         # Convert sample dict to Conversation object
+         conversation = Conversation.model_validate(sample)
+
+         # Determine conversation type from metadata
+         metadata = conversation.metadata or {}
+         conv_type = metadata.get("conversation_type", "basic")
+         reasoning_style = metadata.get("reasoning_style")
+         agent_mode = metadata.get("agent_mode")
+
+         # Create parser with appropriate config
+         parser = GroundTruthParser(
+             conversation_type=conv_type,  # type: ignore[arg-type]
+             reasoning_style=reasoning_style,  # type: ignore[arg-type]
+             agent_mode=agent_mode,  # type: ignore[arg-type]
+         )
+
+         return parser.parse(conversation)
+
+     def prepare_messages(self, sample: dict[str, Any]) -> list[dict[str, Any]]:
+         """Prepare messages for model inference.
+
+         Extracts conversation up to the assistant's tool call.
+
+         Args:
+             sample: Dataset sample
+
+         Returns:
+             List of messages for inference
+         """
+         messages = []
+         for msg in sample["messages"]:
+             # Stop before first assistant message (where tool call should be generated)
+             if msg["role"] == "assistant":
+                 break
+             messages.append({"role": msg["role"], "content": msg["content"]})
+
+         return messages
+
+     def prepare_tools(self, sample: dict[str, Any]) -> list[ToolDefinition]:
+         """Prepare tool definitions from sample.
+
+         Args:
+             sample: Dataset sample
+
+         Returns:
+             List of available tools
+         """
+         from ..schemas import Conversation  # noqa: PLC0415
+
+         # Convert to Conversation to access tools field
+         conversation = Conversation.model_validate(sample)
+
+         if not conversation.tools:
+             return []
+
+         # Convert from OpenAI format back to ToolDefinition
+         return [ToolDefinition.from_openai(tool) for tool in conversation.tools]
+
+     def build_tool_response_lookup(self, sample: dict[str, Any]) -> dict[str, dict[str, str]]:
+         """Build lookup of tool responses by tool name and arguments.
+
+         For multi-turn evaluation, we need to look up tool responses when the
+         model makes tool calls. We index by (tool_name, arguments_json) to find
+         matching responses from ground truth.
+
+         Args:
+             sample: Dataset sample
+
+         Returns:
+             Dict mapping "tool_name:args_json" -> {"content": response, "tool_call_id": id}
+         """
+         lookup: dict[str, dict[str, str]] = {}
+         messages = sample.get("messages", [])
+
+         # Track pending tool calls from assistant messages
+         pending_tool_calls: dict[str, dict] = {}  # tool_call_id -> {name, arguments}
+
+         for msg in messages:
+             role = msg.get("role")
+
+             # Collect tool calls from assistant messages
+             if role == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
+                 for tc in msg["tool_calls"]:
+                     tc_id = tc.get("id")
+                     func = tc.get("function", {})
+                     name = func.get("name", "")
+                     args = func.get("arguments", "{}")
+                     if tc_id:
+                         pending_tool_calls[tc_id] = {"name": name, "arguments": args}
+
+             # Match tool responses to their calls
+             if role == "tool":
+                 tc_id = msg.get("tool_call_id")
+                 content = msg.get("content", "")
+                 if tc_id and tc_id in pending_tool_calls:
+                     call_info = pending_tool_calls[tc_id]
+                     # Create lookup key from tool name + normalized arguments
+                     try:
+                         args_dict = json.loads(call_info["arguments"])
+                         normalized_args = json.dumps(args_dict, sort_keys=True)
+                         key = f"{call_info['name']}:{normalized_args}"
+                     except (json.JSONDecodeError, TypeError):
+                         # Fallback if arguments are not a valid JSON string
+                         key = f"{call_info['name']}:{call_info['arguments']}"
+                     lookup[key] = {"content": content, "tool_call_id": tc_id}
+
+         return lookup
+
+     def find_tool_response(
+         self,
+         tool_call: dict,
+         lookup: dict[str, dict[str, str]],
+     ) -> dict[str, str] | None:
+         """Find a matching tool response for a predicted tool call.
+
+         Args:
+             tool_call: Predicted tool call with 'name' and 'arguments'
+             lookup: Tool response lookup from build_tool_response_lookup
+
+         Returns:
+             Dict with 'content' and 'tool_call_id' if found, None otherwise
+         """
+         name = tool_call.get("name", "")
+         args = tool_call.get("arguments", {})
+
+         # Normalize arguments to JSON string for comparison
+         args_json = json.dumps(args, sort_keys=True) if isinstance(args, dict) else str(args)
+
+         # Try exact match first
+         key = f"{name}:{args_json}"
+         if key in lookup:
+             return lookup[key]
+
+         # Try matching just by tool name (less strict)
+         # This helps when parameter values differ slightly
+         for lookup_key, response in lookup.items():
+             if lookup_key.startswith(f"{name}:"):
+                 return response
+
+         return None
+
+     def evaluate_sample(
+         self,
+         sample: dict[str, Any],
+         sample_id: int,
+     ) -> SampleEvaluation:
+         """Evaluate a single sample using configured evaluators.
+
+         Args:
+             sample: Dataset sample
+             sample_id: Sample index
+
+         Returns:
+             Evaluation result for this sample
+         """
+         # Use multi-turn evaluation if enabled
+         if self.config.multi_turn:
+             return self.evaluate_sample_multi_turn(sample, sample_id)
+
+         try:
+             # Extract ground truth
+             ground_truth = self.extract_ground_truth(sample)
+
+             # Prepare inputs
+             messages = self.prepare_messages(sample)
+             tools = self.prepare_tools(sample)
+
+             # Run inference
+             response: ModelResponse = self.backend.generate(messages, tools)
+
+             # Create evaluation context
+             context = EvaluationContext(
+                 messages=messages,
+                 tools=tools,
+                 sample_id=sample_id,
+             )
+
+             # Run all active evaluators
+             evaluator_results: list[EvaluatorResult] = []
+             for evaluator in self.active_evaluators:
+                 result = evaluator.evaluate(ground_truth, response, context)
+                 if result is not None:  # Evaluator may skip
+                     evaluator_results.append(result)
+
+             # Aggregate results for backwards compatibility
+             return self._aggregate_results(
+                 sample_id=sample_id,
+                 ground_truth=ground_truth,
+                 response=response,
+                 evaluator_results=evaluator_results,
+             )
+
+         except Exception as e:  # noqa: BLE001
+             # Return failed evaluation with safe defaults
+             query = ""
+             expected_tool = None
+             expected_params: dict[str, Any] = {}
+             expected_answer = None
+
+             # Try to extract ground truth if available
+             try:
+                 gt = self.extract_ground_truth(sample)
+                 query = gt.query
+                 expected_tool = gt.expected_tool
+                 expected_params = gt.expected_parameters
+                 expected_answer = gt.expected_answer
+             except (KeyError, AttributeError, ValidationError):
+                 pass
+
+             return SampleEvaluation(
+                 sample_id=sample_id,
+                 query=query,
+                 expected_tool=expected_tool,
+                 predicted_tool=None,
+                 expected_parameters=expected_params,
+                 predicted_parameters={},
+                 expected_answer=expected_answer,
+                 predicted_answer=None,
+                 tool_selection_correct=False,
+                 parameters_correct=False,
+                 execution_valid=False,
+                 response_score=0.0,
+                 error=str(e),
+             )
+
+     def evaluate_sample_multi_turn(
+         self,
+         sample: dict[str, Any],
+         sample_id: int,
+     ) -> SampleEvaluation:
+         """Evaluate a single sample using multi-turn conversation.
+
+         Loops through the conversation, feeding tool responses back to the model
+         until it generates a final answer (no tool calls) or max turns reached.
+
+         Args:
+             sample: Dataset sample
+             sample_id: Sample index
+
+         Returns:
+             Evaluation result for this sample
+         """
+         try:
+             # Extract ground truth (includes all expected tools)
+             ground_truth = self.extract_ground_truth(sample)
+
+             # Prepare initial inputs
+             messages = self.prepare_messages(sample)
+             tools = self.prepare_tools(sample)
+
+             # Build lookup for tool responses from ground truth
+             tool_response_lookup = self.build_tool_response_lookup(sample)
+
+             # Track all predicted tool calls across turns
+             all_predicted_tool_calls: list[dict] = []
+             final_content = ""
+
+             # Multi-turn loop
+             for turn in range(self.config.max_turns):
+                 # Run inference
+                 response: ModelResponse = self.backend.generate(messages, tools)
+                 final_content = response.content
+
+                 # Check if model made tool calls
+                 if not response.tool_calls:
+                     # No tool calls - this is the final answer
+                     break
+
+                 # Process each tool call
+                 for tool_call in response.tool_calls:
+                     all_predicted_tool_calls.append(tool_call)
+
+                     # Find matching tool response from ground truth
+                     tool_response = self.find_tool_response(tool_call, tool_response_lookup)
+
+                     if tool_response is None:
+                         # Model called a tool we don't have a response for
+                         # Continue anyway with an error message
+                         tool_response = {
+                             "content": json.dumps({"error": "Tool not found in ground truth"}),
+                             "tool_call_id": f"generated_{turn}_{len(all_predicted_tool_calls)}",
+                         }
+
+                     # Add assistant message with tool call to conversation
+                     messages.append(
+                         {
+                             "role": "assistant",
+                             "content": "",
+                             "tool_calls": [
+                                 {
+                                     "id": tool_response["tool_call_id"],
+                                     "type": "function",
+                                     "function": {
+                                         "name": tool_call.get("name", ""),
+                                         "arguments": json.dumps(tool_call.get("arguments", {})),
+                                     },
+                                 }
+                             ],
+                         }
+                     )
+
+                     # Add tool response to conversation
+                     messages.append(
+                         {
+                             "role": "tool",
+                             "tool_call_id": tool_response["tool_call_id"],
+                             "content": tool_response["content"],
+                         }
+                     )
+
+             # Now compute metrics comparing predicted vs expected tool calls
+             return self._compute_multi_turn_metrics(
+                 sample_id=sample_id,
+                 ground_truth=ground_truth,
+                 predicted_tool_calls=all_predicted_tool_calls,
+                 final_content=final_content,
+             )
+
+         except Exception as e:  # noqa: BLE001
+             # Return failed evaluation with safe defaults
+             query = ""
+             expected_tool = None
+             expected_params: dict[str, Any] = {}
+             expected_answer = None
+
+             try:
+                 gt = self.extract_ground_truth(sample)
+                 query = gt.query
+                 expected_tool = gt.expected_tool
+                 expected_params = gt.expected_parameters
+                 expected_answer = gt.expected_answer
+             except (KeyError, AttributeError, ValidationError):
+                 pass
+
+             return SampleEvaluation(
+                 sample_id=sample_id,
+                 query=query,
+                 expected_tool=expected_tool,
+                 predicted_tool=None,
+                 expected_parameters=expected_params,
+                 predicted_parameters={},
+                 expected_answer=expected_answer,
+                 predicted_answer=None,
+                 tool_selection_correct=False,
+                 parameters_correct=False,
+                 execution_valid=False,
+                 response_score=0.0,
+                 error=str(e),
+             )
+
+     def _compute_multi_turn_metrics(
+         self,
+         sample_id: int,
+         ground_truth: GroundTruth,
+         predicted_tool_calls: list[dict],
+         final_content: str,
+     ) -> SampleEvaluation:
+         """Compute metrics for multi-turn evaluation.
+
+         Compares predicted tool calls against expected tools using set comparison.
+
+         Args:
+             sample_id: Sample index
+             ground_truth: Expected values including all expected tools
+             predicted_tool_calls: All tool calls made by model across turns
+             final_content: Final model response content
+
+         Returns:
+             SampleEvaluation with computed metrics
+         """
+         # Build sets of tool names for comparison
+         expected_tools_list = ground_truth.expected_tools or []
+         expected_tool_names = {tc.tool_name for tc in expected_tools_list}
+         predicted_tool_names = {tc.get("name", "") for tc in (predicted_tool_calls or [])}
+
+         # Tool set coverage: what fraction of expected tools were called?
+         if expected_tool_names:
+             matched_tools = expected_tool_names & predicted_tool_names
+             tool_coverage = len(matched_tools) / len(expected_tool_names)
+         else:
+             tool_coverage = 1.0 if not predicted_tool_names else 0.0
+
+         # Tool set precision: what fraction of predicted tools were expected?
+         # (Computed but stored in response_score for now, could be expanded later)
+         if predicted_tool_names:
+             matched_tools = expected_tool_names & predicted_tool_names
+             _tool_precision = len(matched_tools) / len(predicted_tool_names)  # noqa: F841
+         else:
+             _tool_precision = 1.0 if not expected_tool_names else 0.0  # noqa: F841
+
+         # Overall tool selection is correct if coverage is 100%
+         tool_selection_correct = tool_coverage == 1.0
+
+         # For backwards compatibility, use first predicted/expected tool
+         first_predicted_tool = predicted_tool_calls[0].get("name") if predicted_tool_calls else None
+         first_predicted_params = (
+             predicted_tool_calls[0].get("arguments", {}) if predicted_tool_calls else {}
+         )
+
+         # Parameter accuracy: for matched tools, check if params are structurally correct
+         params_correct = self._check_parameter_structure(
+             ground_truth.expected_tools,
+             predicted_tool_calls,
+         )
+
+         # Execution valid if we got through the conversation
+         execution_valid = len(predicted_tool_calls) > 0 or final_content != ""
+
+         return SampleEvaluation(
+             sample_id=sample_id,
+             query=ground_truth.query,
+             expected_tool=ground_truth.expected_tool,
+             predicted_tool=first_predicted_tool,
+             expected_parameters=ground_truth.expected_parameters,
+             predicted_parameters=first_predicted_params
+             if isinstance(first_predicted_params, dict)
+             else {},
+             expected_answer=ground_truth.expected_answer,
+             predicted_answer=final_content,
+             tool_selection_correct=tool_selection_correct,
+             parameters_correct=params_correct,
+             execution_valid=execution_valid,
+             response_score=tool_coverage,  # Use coverage as response score
+             error=None,
+         )
+
+     def _check_parameter_structure(
+         self,
+         expected_tools: list[ExpectedToolCall],
+         predicted_tool_calls: list[dict],
+     ) -> bool:
+         """Check if predicted tool calls have correct parameter structure.
+
+         For each matched tool, verifies that predicted params have the same keys.
+         Does not check parameter values, only structure.
+
+         Args:
+             expected_tools: List of ExpectedToolCall from ground truth
+             predicted_tool_calls: List of predicted tool call dicts
+
+         Returns:
+             True if all matched tools have correct parameter structure
+         """
+         # Build lookup of expected params by tool name
+         expected_params_by_tool: dict[str, set[str]] = {}
+         for tc in expected_tools or []:
+             if tc.tool_name not in expected_params_by_tool:
+                 expected_params_by_tool[tc.tool_name] = set(tc.parameters.keys())
+
+         # Check each predicted tool call
+         for pred_call in predicted_tool_calls or []:
+             tool_name = pred_call.get("name", "")
+             pred_args = pred_call.get("arguments", {})
+
+             if tool_name in expected_params_by_tool:
+                 expected_keys = expected_params_by_tool[tool_name]
+                 pred_keys = set(pred_args.keys()) if isinstance(pred_args, dict) else set()
+
+                 # Check if predicted has all expected keys (may have extra)
+                 if not expected_keys.issubset(pred_keys):
+                     return False
+
+         return True
+
+     def _aggregate_results(
+         self,
+         sample_id: int,
+         ground_truth: GroundTruth,
+         response: ModelResponse,
+         evaluator_results: list[EvaluatorResult],
+     ) -> SampleEvaluation:
+         """Aggregate evaluator results into SampleEvaluation.
+
+         Args:
+             sample_id: Sample index
+             ground_truth: Expected values
+             response: Model response
+             evaluator_results: Results from all evaluators
+
+         Returns:
+             SampleEvaluation with aggregated metrics
+         """
+         # Extract tool calling metrics from evaluator results
+         tool_correct = False
+         params_correct = False
+         execution_valid = False
+         predicted_tool = None
+         predicted_params = {}
+
+         # Extract predictions from response
+         if response.tool_call:
+             predicted_tool = response.tool_call.get("name")
+             predicted_params = response.tool_call.get("arguments", {})
+
+         # Get metrics from tool_calling evaluator
+         for result in evaluator_results:
+             if result.evaluator_name == "tool_calling":
+                 metrics = result.metrics
+                 tool_correct = metrics.get("tool_selection_accuracy", 0.0) == 1.0
+                 params_correct = metrics.get("parameter_accuracy", 0.0) == 1.0
+                 execution_valid = metrics.get("execution_valid", 0.0) == 1.0
+
+         # Return backwards-compatible SampleEvaluation
+         return SampleEvaluation(
+             sample_id=sample_id,
+             query=ground_truth.query,
+             expected_tool=ground_truth.expected_tool,
+             predicted_tool=predicted_tool,
+             expected_parameters=ground_truth.expected_parameters,
+             predicted_parameters=predicted_params,
+             expected_answer=ground_truth.expected_answer,
+             predicted_answer=response.content,
+             tool_selection_correct=tool_correct,
+             parameters_correct=params_correct,
+             execution_valid=execution_valid,
+             response_score=0.0,  # TODO: Could use semantic similarity for response quality evaluation in the future, but disabled for tool-calling mode
+             error=None,
+         )
+
+     def evaluate(self, dataset: HFDataset | None = None) -> EvaluationResult:
+         """Run full evaluation.
+
+         Args:
+             dataset: Optional HuggingFace Dataset to evaluate. If not provided,
+                 loads from config.dataset_path.
+
+         Returns:
+             Complete evaluation result with metrics and predictions
+         """
+         console.print("[bold blue]Loading dataset...[/bold blue]")
+         samples = self.load_dataset(dataset)
+         console.print(f"Loaded {len(samples)} samples")
+
+         console.print("[bold blue]Running evaluation...[/bold blue]")
+         evaluations = []
+
+         for idx, sample in tqdm(enumerate(samples), total=len(samples), desc="Evaluating"):
+             eval_result = self.evaluate_sample(sample, idx)
+             evaluations.append(eval_result)
+
+             # Stream sample to reporters (for cloud real-time tracking)
+             self.reporter.report_sample(eval_result)
+
+         console.print("[bold green]Evaluation complete![/bold green]")
+
+         # Compute metrics
+         metrics = compute_metrics(evaluations, self.config.metric_weights)
+
+         # Create result
+         result = EvaluationResult(
+             metrics=metrics,
+             predictions=evaluations,
+             config=self.config,
+         )
+
+         # Track evaluation completion
+         trace(
+             "evaluation_completed",
+             {
+                 "backend": self.config.inference_config.backend,
+                 "model_path": self.config.inference_config.model_path,
+                 "has_adapter": self.config.inference_config.adapter_path is not None,
+                 "samples_evaluated": metrics.samples_evaluated,
+                 "samples_processed": metrics.samples_processed,
+                 "processing_errors": metrics.processing_errors,
+                 "tool_selection_accuracy": round(metrics.tool_selection_accuracy, 4),
+                 "parameter_accuracy": round(metrics.parameter_accuracy, 4),
+                 "execution_success_rate": round(metrics.execution_success_rate, 4),
+                 "overall_score": round(metrics.overall_score, 4),
+                 "success": metrics.processing_errors == 0,
+             },
+         )
+
+         # Report results using configured reporters
+         if self.config.save_predictions:
+             self.reporter.report(result)
+
+         return result
+
+     def cleanup(self) -> None:
+         """Clean up resources."""
+         self.backend.cleanup()
+
+     def print_summary(self, metrics: EvaluationMetrics) -> None:
+         """Print evaluation summary.
+
+         Args:
+             metrics: Computed metrics
+         """
+         console.print("\n[bold]Evaluation Summary[/bold]")
+         console.print(f"Samples Evaluated: {metrics.samples_evaluated}")
+         console.print(f"Processed Successfully: {metrics.samples_processed}")
+         console.print(f"Processing Errors: {metrics.processing_errors}")
+         console.print("\n[bold]Metrics[/bold]")
+         console.print(f"Tool Selection Accuracy: {metrics.tool_selection_accuracy:.2%}")
+         console.print(f"Parameter Accuracy: {metrics.parameter_accuracy:.2%}")
+         console.print(f"Execution Success Rate: {metrics.execution_success_rate:.2%}")
+         console.print(f"Response Quality: {metrics.response_quality:.2%}")
+         console.print(f"\n[bold green]Overall Score: {metrics.overall_score:.2%}[/bold green]")
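The module above is driven entirely through EvaluatorConfig. For orientation, here is a minimal usage sketch, not part of the packaged file: the EvaluatorConfig fields and the Evaluator methods are taken from the code above, but the InferenceConfig field values, the "ollama" backend name, and the model identifier are illustrative assumptions, as is the exact shape of the JSONL dataset.

    # Hypothetical usage sketch -- values below are assumptions, not taken from the package.
    from deepfabric.evaluation.evaluator import Evaluator, EvaluatorConfig
    from deepfabric.evaluation.inference import InferenceConfig

    config = EvaluatorConfig(
        dataset_path="eval.jsonl",            # one chat-format conversation (messages + tools) per line
        output_path="eval_results.json",
        inference_config=InferenceConfig(     # backend/model_path appear in the trace() payloads above
            backend="ollama",                 # assumed backend name, matching the bundled ollama_backend.py
            model_path="llama3.1:8b",         # placeholder model identifier
        ),
        evaluators=["tool_calling"],
        reporters=["file"],
        max_samples=50,
        multi_turn=True,                      # replay ground-truth tool responses across turns
        max_turns=5,
    )

    evaluator = Evaluator(config)
    try:
        result = evaluator.evaluate()         # or evaluator.evaluate(hf_dataset) with a datasets.Dataset
        evaluator.print_summary(result.metrics)
    finally:
        evaluator.cleanup()

Reporting runs through the configured reporters ("file" writes to output_path; "cloud" streams per-sample results when a DeepFabric API key is supplied), so no wiring beyond the config is needed.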