mcpbr 0.4.14__py3-none-any.whl → 0.4.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,405 @@
+"""Custom metrics framework for flexible evaluation beyond standard accuracy/pass rates.
+
+This module provides:
+- MetricDefinition dataclass for declaring metrics with name, description, compute
+  function, aggregation strategy, and direction (higher_is_better).
+- MetricRegistry for registering, looking up, and managing metrics.
+- Built-in metrics: accuracy, pass_rate, avg_tokens, avg_cost, avg_time,
+  tool_call_rate, failure_rate.
+- Support for composite metrics (e.g., cost_efficiency = pass_rate / avg_cost).
+- compute_metrics() to evaluate a set of metrics against result data.
+- validate_metric() to check metric definition validity.
+"""
+
+from __future__ import annotations
+
+import math
+import statistics
+from dataclasses import dataclass
+from typing import Any, Callable
+
+
+@dataclass
+class MetricDefinition:
+    """Definition of a single evaluation metric.
+
+    Attributes:
+        name: Unique identifier for the metric.
+        description: Human-readable description of what the metric measures.
+        compute_fn: Either a callable ``(list[dict]) -> float`` that computes the
+            metric from a list of result dicts, or a string expression referencing
+            other metric names (for composite metrics).
+        aggregation: Aggregation strategy used when summarising per-task values.
+            One of ``"mean"``, ``"sum"``, ``"min"``, ``"max"``, ``"median"``.
+        higher_is_better: Whether a higher value is considered better.
+    """
+
+    name: str
+    description: str
+    compute_fn: Callable[[list[dict[str, Any]]], float] | str
+    aggregation: str = "mean"
+    higher_is_better: bool = True
+
+
+_VALID_AGGREGATIONS = frozenset({"mean", "sum", "min", "max", "median"})
+
+
+class MetricRegistry:
+    """Registry for looking up and managing metric definitions.
+
+    Provides ``register``, ``get``, ``list_metrics``, and ``unregister`` operations.
+    """
+
+    def __init__(self) -> None:
+        self._metrics: dict[str, MetricDefinition] = {}
+
+    # -- public API ----------------------------------------------------------
+
+    def register(self, metric: MetricDefinition) -> None:
+        """Register a metric definition.
+
+        Args:
+            metric: The metric to register.
+
+        Raises:
+            ValueError: If a metric with the same name is already registered.
+        """
+        if metric.name in self._metrics:
+            raise ValueError(f"Metric '{metric.name}' is already registered")
+        self._metrics[metric.name] = metric
+
+    def get(self, name: str) -> MetricDefinition | None:
+        """Look up a metric by name.
+
+        Args:
+            name: Metric name.
+
+        Returns:
+            The MetricDefinition if found, otherwise ``None``.
+        """
+        return self._metrics.get(name)
+
+    def list_metrics(self) -> list[str]:
+        """Return a sorted list of all registered metric names."""
+        return sorted(self._metrics.keys())
+
+    def unregister(self, name: str) -> bool:
+        """Remove a metric from the registry.
+
+        Args:
+            name: Metric name to remove.
+
+        Returns:
+            ``True`` if the metric was removed, ``False`` if it was not found.
+        """
+        if name in self._metrics:
+            del self._metrics[name]
+            return True
+        return False
+
+    def __contains__(self, name: str) -> bool:
+        return name in self._metrics
+
+    def __len__(self) -> int:
+        return len(self._metrics)
+
+
+# ---------------------------------------------------------------------------
+# Built-in metric compute functions
+# ---------------------------------------------------------------------------
+
+
+def _compute_accuracy(results: list[dict[str, Any]]) -> float:
+    """Fraction of results where ``resolved`` is truthy."""
+    if not results:
+        return 0.0
+    resolved = sum(1 for r in results if r.get("resolved"))
+    return resolved / len(results)
+
+
+def _compute_pass_rate(results: list[dict[str, Any]]) -> float:
+    """Fraction of results where ``resolved`` is truthy (alias of accuracy)."""
+    return _compute_accuracy(results)
+
+
+def _compute_avg_tokens(results: list[dict[str, Any]]) -> float:
+    """Average total token count per result."""
+    token_counts: list[int] = []
+    for r in results:
+        tokens = r.get("tokens", {})
+        total = tokens.get("input", 0) + tokens.get("output", 0)
+        token_counts.append(total)
+    if not token_counts:
+        return 0.0
+    return float(statistics.mean(token_counts))
+
+
+def _compute_avg_cost(results: list[dict[str, Any]]) -> float:
+    """Average cost per result."""
+    costs = [r.get("cost", 0.0) for r in results]
+    if not costs:
+        return 0.0
+    return statistics.mean(costs)
+
+
+def _compute_avg_time(results: list[dict[str, Any]]) -> float:
+    """Average runtime in seconds per result."""
+    runtimes = [r.get("runtime_seconds", 0.0) for r in results]
+    if not runtimes:
+        return 0.0
+    return statistics.mean(runtimes)
+
+
+def _compute_tool_call_rate(results: list[dict[str, Any]]) -> float:
+    """Fraction of results that contain at least one tool call."""
+    if not results:
+        return 0.0
+    with_tools = sum(1 for r in results if r.get("tool_usage"))
+    return with_tools / len(results)
+
+
+def _compute_failure_rate(results: list[dict[str, Any]]) -> float:
+    """Fraction of results where ``error`` is present and non-empty."""
+    if not results:
+        return 0.0
+    with_errors = sum(1 for r in results if r.get("error"))
+    return with_errors / len(results)
+
+
+# ---------------------------------------------------------------------------
+# Built-in metric definitions
+# ---------------------------------------------------------------------------
+
+BUILTIN_METRICS: list[MetricDefinition] = [
+    MetricDefinition(
+        name="accuracy",
+        description="Fraction of tasks resolved successfully",
+        compute_fn=_compute_accuracy,
+        aggregation="mean",
+        higher_is_better=True,
+    ),
+    MetricDefinition(
+        name="pass_rate",
+        description="Fraction of tasks that pass (alias for accuracy)",
+        compute_fn=_compute_pass_rate,
+        aggregation="mean",
+        higher_is_better=True,
+    ),
+    MetricDefinition(
+        name="avg_tokens",
+        description="Average total tokens (input + output) per task",
+        compute_fn=_compute_avg_tokens,
+        aggregation="mean",
+        higher_is_better=False,
+    ),
+    MetricDefinition(
+        name="avg_cost",
+        description="Average API cost per task in USD",
+        compute_fn=_compute_avg_cost,
+        aggregation="mean",
+        higher_is_better=False,
+    ),
+    MetricDefinition(
+        name="avg_time",
+        description="Average runtime per task in seconds",
+        compute_fn=_compute_avg_time,
+        aggregation="mean",
+        higher_is_better=False,
+    ),
+    MetricDefinition(
+        name="tool_call_rate",
+        description="Fraction of tasks that used at least one tool",
+        compute_fn=_compute_tool_call_rate,
+        aggregation="mean",
+        higher_is_better=True,
+    ),
+    MetricDefinition(
+        name="failure_rate",
+        description="Fraction of tasks that encountered an error",
+        compute_fn=_compute_failure_rate,
+        aggregation="mean",
+        higher_is_better=False,
+    ),
+]
+
+
+def create_default_registry() -> MetricRegistry:
+    """Create a MetricRegistry pre-populated with all built-in metrics.
+
+    Returns:
+        A MetricRegistry instance containing the built-in metrics.
+    """
+    registry = MetricRegistry()
+    for metric in BUILTIN_METRICS:
+        registry.register(metric)
+    return registry
+
+
+# ---------------------------------------------------------------------------
+# Aggregation helpers
+# ---------------------------------------------------------------------------
+
+
+def _aggregate(values: list[float], method: str) -> float:
+    """Aggregate a list of floats using the specified method.
+
+    Args:
+        values: Numeric values to aggregate.
+        method: One of ``"mean"``, ``"sum"``, ``"min"``, ``"max"``, ``"median"``.
+
+    Returns:
+        Aggregated value.
+
+    Raises:
+        ValueError: If the method is unrecognised.
+    """
+    if not values:
+        return 0.0
+    if method == "mean":
+        return statistics.mean(values)
+    elif method == "sum":
+        return math.fsum(values)
+    elif method == "min":
+        return min(values)
+    elif method == "max":
+        return max(values)
+    elif method == "median":
+        return statistics.median(values)
+    else:
+        raise ValueError(f"Unknown aggregation method: {method!r}")
+
+
+# ---------------------------------------------------------------------------
+# Core public API
+# ---------------------------------------------------------------------------
+
+
+def compute_metrics(
+    results: list[dict[str, Any]],
+    metrics: list[str],
+    registry: MetricRegistry | None = None,
+) -> dict[str, float]:
+    """Compute the requested metrics over a list of result dicts.
+
+    Each result dict is expected to follow the structure used elsewhere in mcpbr
+    (keys such as ``resolved``, ``tokens``, ``cost``, ``runtime_seconds``,
+    ``tool_usage``, ``error``).
+
+    Composite metrics (whose ``compute_fn`` is a string expression) are resolved
+    by first computing all non-composite metrics they reference, then evaluating the
+    expression in a restricted namespace.
+
+    Args:
+        results: List of per-task result dictionaries.
+        metrics: List of metric names to compute.
+        registry: Optional MetricRegistry. If ``None``, the default registry
+            (containing built-in metrics) is used.
+
+    Returns:
+        Dictionary mapping metric names to their computed float values.
+
+    Raises:
+        KeyError: If a requested metric is not found in the registry.
+        ValueError: If a composite expression references an unknown metric or
+            fails to evaluate.
+    """
+    if registry is None:
+        registry = create_default_registry()
+
+    computed: dict[str, float] = {}
+
+    # Separate callable and composite (expression-based) metrics
+    callable_names: list[str] = []
+    composite_names: list[str] = []
+
+    for name in metrics:
+        metric_def = registry.get(name)
+        if metric_def is None:
+            raise KeyError(f"Metric '{name}' is not registered")
+        if callable(metric_def.compute_fn):
+            callable_names.append(name)
+        else:
+            composite_names.append(name)
+
+    # Phase 1: compute all callable metrics
+    for name in callable_names:
+        metric_def = registry.get(name)
+        assert metric_def is not None  # guaranteed above
+        assert callable(metric_def.compute_fn)
+        computed[name] = metric_def.compute_fn(results)
+
+    # Phase 2: resolve composite metrics
+    for name in composite_names:
+        metric_def = registry.get(name)
+        assert metric_def is not None
+        assert isinstance(metric_def.compute_fn, str)
+
+        # Build a namespace of already-computed values. If the expression
+        # references a metric that hasn't been computed yet, compute it now.
+        ns: dict[str, float] = {}
+        for existing_name, existing_val in computed.items():
+            ns[existing_name] = existing_val
+
+        # Evaluate the expression. We deliberately restrict the namespace to
+        # only contain computed metric values (no builtins).
+        try:
+            value = float(eval(metric_def.compute_fn, {"__builtins__": {}}, ns))  # noqa: S307
+        except ZeroDivisionError:
+            value = 0.0
+        except Exception as exc:
+            raise ValueError(
+                f"Failed to evaluate composite metric '{name}' "
+                f"expression '{metric_def.compute_fn}': {exc}"
+            ) from exc
+
+        computed[name] = value
+
+    return computed
+
+
+def validate_metric(metric_def: dict[str, Any]) -> bool:
+    """Validate a metric definition dictionary.
+
+    Checks that the definition contains all required fields with correct types
+    and valid values.
+
+    Required keys:
+    - ``name`` (str, non-empty)
+    - ``description`` (str)
+    - ``compute_fn`` (callable or str)
+
+    Optional keys (with defaults):
+    - ``aggregation`` (str, one of mean/sum/min/max/median)
+    - ``higher_is_better`` (bool)
+
+    Args:
+        metric_def: Dictionary representing a metric definition.
+
+    Returns:
+        ``True`` if the definition is valid, ``False`` otherwise.
+    """
+    # Required fields
+    if not isinstance(metric_def.get("name"), str) or not metric_def["name"].strip():
+        return False
+
+    if not isinstance(metric_def.get("description"), str):
+        return False
+
+    compute_fn = metric_def.get("compute_fn")
+    if compute_fn is None:
+        return False
+    if not callable(compute_fn) and not isinstance(compute_fn, str):
+        return False
+    if isinstance(compute_fn, str) and not compute_fn.strip():
+        return False
+
+    # Optional fields
+    aggregation = metric_def.get("aggregation", "mean")
+    if aggregation not in _VALID_AGGREGATIONS:
+        return False
+
+    higher_is_better = metric_def.get("higher_is_better", True)
+    if not isinstance(higher_is_better, bool):
+        return False
+
+    return True
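
A minimal usage sketch of the new metrics framework. The import path is an assumption (this diff does not show the new module's filename), and the example data uses the result keys the module documents (resolved, tokens, cost, error):

    # Hypothetical import path; the diff does not name the new module file.
    from mcpbr.custom_metrics import MetricDefinition, compute_metrics, create_default_registry

    # Toy per-task results with the keys the built-in compute functions expect.
    results = [
        {"resolved": True, "tokens": {"input": 1200, "output": 300}, "cost": 0.02},
        {"resolved": False, "tokens": {"input": 900, "output": 150}, "cost": 0.01, "error": "timeout"},
    ]

    registry = create_default_registry()

    # Composite metric: compute_fn is a string expression over other metric names.
    registry.register(
        MetricDefinition(
            name="cost_efficiency",
            description="Pass rate per dollar of API cost",
            compute_fn="pass_rate / avg_cost",
        )
    )

    # The metrics referenced by the expression must be requested too,
    # so they are computed in phase 1 before the expression is evaluated.
    values = compute_metrics(results, ["pass_rate", "avg_cost", "cost_efficiency"], registry)
    # -> {"pass_rate": 0.5, "avg_cost": 0.015, "cost_efficiency": 33.33...}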
@@ -0,0 +1,222 @@
+"""Dataset versioning for reproducible benchmark evaluations.
+
+This module provides utilities to pin and track HuggingFace dataset versions,
+ensuring that benchmark runs can be reproduced with the exact same data.
+Version information includes dataset revision hashes, download timestamps,
+and optional checksums for data integrity verification.
+"""
+
+import hashlib
+import json
+import logging
+from dataclasses import asdict, dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+from datasets import Dataset, load_dataset
+from huggingface_hub import dataset_info
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DatasetVersion:
+    """Pinned version information for a HuggingFace dataset.
+
+    Attributes:
+        dataset_id: HuggingFace dataset identifier (e.g., 'SWE-bench/SWE-bench_Lite').
+        revision: Git revision hash of the dataset (None for latest).
+        download_date: ISO 8601 timestamp of when the version was pinned.
+        checksum: Optional SHA256 checksum of the dataset content for integrity verification.
+    """
+
+    dataset_id: str
+    revision: str | None
+    download_date: str
+    checksum: str | None
+
+
+def pin_dataset_version(
+    dataset_id: str,
+    revision: str | None = None,
+) -> DatasetVersion:
+    """Record the current version of a HuggingFace dataset.
+
+    Fetches dataset metadata from the HuggingFace Hub to determine the
+    current revision. If a specific revision is provided, it is used directly.
+
+    Args:
+        dataset_id: HuggingFace dataset identifier (e.g., 'SWE-bench/SWE-bench_Lite').
+        revision: Specific git revision to pin. If None, the latest revision is fetched.
+
+    Returns:
+        DatasetVersion with the pinned revision and metadata.
+
+    Raises:
+        Exception: If the dataset cannot be found or accessed on the HuggingFace Hub.
+    """
+    info = dataset_info(dataset_id, revision=revision)
+    resolved_revision = info.sha
+
+    # Compute a checksum from the dataset card and file metadata for integrity
+    checksum_data = f"{dataset_id}:{resolved_revision}"
+    if info.siblings:
+        file_names = sorted(s.rfilename for s in info.siblings)
+        checksum_data += ":" + ",".join(file_names)
+    checksum = hashlib.sha256(checksum_data.encode()).hexdigest()
+
+    download_date = datetime.now(timezone.utc).isoformat()
+
+    version = DatasetVersion(
+        dataset_id=dataset_id,
+        revision=resolved_revision,
+        download_date=download_date,
+        checksum=checksum,
+    )
+
+    logger.info(
+        "Pinned dataset %s at revision %s",
+        dataset_id,
+        resolved_revision,
+    )
+
+    return version
+
+
+def load_dataset_pinned(
+    dataset_id: str,
+    version: DatasetVersion | None = None,
+    **kwargs: Any,
+) -> Dataset:
+    """Load a HuggingFace dataset using a pinned version for reproducibility.
+
+    Wraps the standard ``datasets.load_dataset`` call, injecting the pinned
+    revision so that the exact same data snapshot is used across runs.
+
+    Args:
+        dataset_id: HuggingFace dataset identifier.
+        version: Pinned version to use. If None, loads the latest version.
+        **kwargs: Additional keyword arguments passed to ``datasets.load_dataset``
+            (e.g., split, name, streaming).
+
+    Returns:
+        The loaded HuggingFace Dataset.
+    """
+    revision = None
+    if version is not None:
+        revision = version.revision
+        logger.info(
+            "Loading dataset %s at pinned revision %s (pinned on %s)",
+            dataset_id,
+            revision,
+            version.download_date,
+        )
+    else:
+        logger.info("Loading dataset %s at latest revision", dataset_id)
+
+    return load_dataset(dataset_id, revision=revision, **kwargs)
+
+
+def save_version_manifest(
+    versions: dict[str, DatasetVersion],
+    path: Path,
+) -> None:
+    """Save dataset version pins to a JSON manifest file.
+
+    The manifest file records all pinned dataset versions so they can be
+    shared across team members or CI environments for reproducible runs.
+
+    Args:
+        versions: Mapping of dataset identifiers to their pinned versions.
+        path: File path to write the JSON manifest.
+    """
+    manifest: dict[str, Any] = {
+        "format_version": "1.0",
+        "created_at": datetime.now(timezone.utc).isoformat(),
+        "datasets": {},
+    }
+
+    for dataset_id, version in versions.items():
+        manifest["datasets"][dataset_id] = asdict(version)
+
+    path.parent.mkdir(parents=True, exist_ok=True)
+
+    with open(path, "w") as f:
+        json.dump(manifest, f, indent=2)
+
+    logger.info("Saved version manifest with %d datasets to %s", len(versions), path)
+
+
+def load_version_manifest(path: Path) -> dict[str, DatasetVersion]:
+    """Load pinned dataset versions from a JSON manifest file.
+
+    Args:
+        path: File path to the JSON manifest.
+
+    Returns:
+        Mapping of dataset identifiers to their pinned versions.
+
+    Raises:
+        FileNotFoundError: If the manifest file does not exist.
+        json.JSONDecodeError: If the manifest file contains invalid JSON.
+        KeyError: If the manifest is missing required fields.
+    """
+    with open(path) as f:
+        manifest = json.load(f)
+
+    versions: dict[str, DatasetVersion] = {}
+    datasets_data = manifest.get("datasets", {})
+
+    for dataset_id, version_data in datasets_data.items():
+        versions[dataset_id] = DatasetVersion(
+            dataset_id=version_data["dataset_id"],
+            revision=version_data.get("revision"),
+            download_date=version_data["download_date"],
+            checksum=version_data.get("checksum"),
+        )
+
+    logger.info("Loaded version manifest with %d datasets from %s", len(versions), path)
+
+    return versions
+
+
+def get_dataset_info(dataset_id: str) -> dict[str, Any]:
+    """Get metadata about a HuggingFace dataset.
+
+    Retrieves information such as the latest revision, description,
+    file listing, and other Hub metadata.
+
+    Args:
+        dataset_id: HuggingFace dataset identifier.
+
+    Returns:
+        Dictionary containing dataset metadata with keys:
+        - dataset_id: The dataset identifier.
+        - latest_revision: The current HEAD revision hash.
+        - description: Dataset description text.
+        - tags: List of dataset tags.
+        - downloads: Number of downloads.
+        - last_modified: Last modification timestamp.
+        - files: List of files in the dataset repository.
+
+    Raises:
+        Exception: If the dataset cannot be found or accessed on the HuggingFace Hub.
+    """
+    info = dataset_info(dataset_id)
+
+    files: list[str] = []
+    if info.siblings:
+        files = [s.rfilename for s in info.siblings]
+
+    result: dict[str, Any] = {
+        "dataset_id": dataset_id,
+        "latest_revision": info.sha,
+        "description": info.description or "",
+        "tags": list(info.tags) if info.tags else [],
+        "downloads": info.downloads if info.downloads is not None else 0,
+        "last_modified": info.last_modified.isoformat() if info.last_modified else None,
+        "files": files,
+    }
+
+    return result
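
A minimal usage sketch of the dataset versioning helpers, again with an assumed import path (the diff does not show the new module's filename):

    # Hypothetical import path; the diff does not name the new module file.
    from pathlib import Path

    from mcpbr.dataset_versioning import (
        load_dataset_pinned,
        load_version_manifest,
        pin_dataset_version,
        save_version_manifest,
    )

    # Pin the current revision of a benchmark dataset and record it in a manifest.
    version = pin_dataset_version("SWE-bench/SWE-bench_Lite")
    save_version_manifest({"SWE-bench/SWE-bench_Lite": version}, Path("dataset_versions.json"))

    # Later, or on another machine, reload the manifest and fetch the same snapshot.
    pinned = load_version_manifest(Path("dataset_versions.json"))
    ds = load_dataset_pinned(
        "SWE-bench/SWE-bench_Lite",
        version=pinned["SWE-bench/SWE-bench_Lite"],
        split="test",
    )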
mcpbr/docker_env.py CHANGED
@@ -724,6 +724,12 @@ CMD ["/bin/bash"]
         if self in _active_managers:
             _active_managers.remove(self)
 
+        # Close the Docker client to release background threads/connections
+        try:
+            self.client.close()
+        except Exception:
+            pass
+
         if report and cleanup_report.total_removed > 0:
             logger.info(str(cleanup_report))