mcpbr 0.4.14__py3-none-any.whl → 0.4.16__py3-none-any.whl
This diff shows the changes between publicly released versions of this package as they appear in their public registry, and is provided for informational purposes only.
- mcpbr/benchmarks/__init__.py +12 -0
- mcpbr/benchmarks/adversarial.py +341 -0
- mcpbr/benchmarks/custom.py +607 -0
- mcpbr/benchmarks/longbench.py +623 -0
- mcpbr/benchmarks/mmmu.py +353 -0
- mcpbr/config.py +4 -0
- mcpbr/custom_metrics.py +405 -0
- mcpbr/dataset_versioning.py +222 -0
- mcpbr/docker_env.py +6 -0
- mcpbr/failure_analysis.py +558 -0
- mcpbr/few_shot.py +367 -0
- mcpbr/gpu_support.py +157 -0
- mcpbr/harness.py +8 -0
- mcpbr/latency_metrics.py +317 -0
- mcpbr/sampling.py +193 -0
- {mcpbr-0.4.14.dist-info → mcpbr-0.4.16.dist-info}/METADATA +10 -6
- {mcpbr-0.4.14.dist-info → mcpbr-0.4.16.dist-info}/RECORD +27 -16
- {mcpbr-0.4.14.data → mcpbr-0.4.16.data}/data/mcpbr/data/templates/brave-search.yaml +0 -0
- {mcpbr-0.4.14.data → mcpbr-0.4.16.data}/data/mcpbr/data/templates/filesystem.yaml +0 -0
- {mcpbr-0.4.14.data → mcpbr-0.4.16.data}/data/mcpbr/data/templates/github.yaml +0 -0
- {mcpbr-0.4.14.data → mcpbr-0.4.16.data}/data/mcpbr/data/templates/google-maps.yaml +0 -0
- {mcpbr-0.4.14.data → mcpbr-0.4.16.data}/data/mcpbr/data/templates/postgres.yaml +0 -0
- {mcpbr-0.4.14.data → mcpbr-0.4.16.data}/data/mcpbr/data/templates/slack.yaml +0 -0
- {mcpbr-0.4.14.data → mcpbr-0.4.16.data}/data/mcpbr/data/templates/sqlite.yaml +0 -0
- {mcpbr-0.4.14.dist-info → mcpbr-0.4.16.dist-info}/WHEEL +0 -0
- {mcpbr-0.4.14.dist-info → mcpbr-0.4.16.dist-info}/entry_points.txt +0 -0
- {mcpbr-0.4.14.dist-info → mcpbr-0.4.16.dist-info}/licenses/LICENSE +0 -0
mcpbr/custom_metrics.py
ADDED
@@ -0,0 +1,405 @@
"""Custom metrics framework for flexible evaluation beyond standard accuracy/pass rates.

This module provides:
- MetricDefinition dataclass for declaring metrics with name, description, compute
  function, aggregation strategy, and direction (higher_is_better).
- MetricRegistry for registering, looking up, and managing metrics.
- Built-in metrics: accuracy, pass_rate, avg_tokens, avg_cost, avg_time,
  tool_call_rate, failure_rate.
- Support for composite metrics (e.g., cost_efficiency = pass_rate / avg_cost).
- compute_metrics() to evaluate a set of metrics against result data.
- validate_metric() to check metric definition validity.
"""

from __future__ import annotations

import math
import statistics
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class MetricDefinition:
    """Definition of a single evaluation metric.

    Attributes:
        name: Unique identifier for the metric.
        description: Human-readable description of what the metric measures.
        compute_fn: Either a callable ``(list[dict]) -> float`` that computes the
            metric from a list of result dicts, or a string expression referencing
            other metric names (for composite metrics).
        aggregation: Aggregation strategy used when summarising per-task values.
            One of ``"mean"``, ``"sum"``, ``"min"``, ``"max"``, ``"median"``.
        higher_is_better: Whether a higher value is considered better.
    """

    name: str
    description: str
    compute_fn: Callable[[list[dict[str, Any]]], float] | str
    aggregation: str = "mean"
    higher_is_better: bool = True


_VALID_AGGREGATIONS = frozenset({"mean", "sum", "min", "max", "median"})


class MetricRegistry:
    """Registry for looking up and managing metric definitions.

    Provides ``register``, ``get``, ``list_metrics``, and ``unregister`` operations.
    """

    def __init__(self) -> None:
        self._metrics: dict[str, MetricDefinition] = {}

    # -- public API ----------------------------------------------------------

    def register(self, metric: MetricDefinition) -> None:
        """Register a metric definition.

        Args:
            metric: The metric to register.

        Raises:
            ValueError: If a metric with the same name is already registered.
        """
        if metric.name in self._metrics:
            raise ValueError(f"Metric '{metric.name}' is already registered")
        self._metrics[metric.name] = metric

    def get(self, name: str) -> MetricDefinition | None:
        """Look up a metric by name.

        Args:
            name: Metric name.

        Returns:
            The MetricDefinition if found, otherwise ``None``.
        """
        return self._metrics.get(name)

    def list_metrics(self) -> list[str]:
        """Return a sorted list of all registered metric names."""
        return sorted(self._metrics.keys())

    def unregister(self, name: str) -> bool:
        """Remove a metric from the registry.

        Args:
            name: Metric name to remove.

        Returns:
            ``True`` if the metric was removed, ``False`` if it was not found.
        """
        if name in self._metrics:
            del self._metrics[name]
            return True
        return False

    def __contains__(self, name: str) -> bool:
        return name in self._metrics

    def __len__(self) -> int:
        return len(self._metrics)


# ---------------------------------------------------------------------------
# Built-in metric compute functions
# ---------------------------------------------------------------------------


def _compute_accuracy(results: list[dict[str, Any]]) -> float:
    """Fraction of results where ``resolved`` is truthy."""
    if not results:
        return 0.0
    resolved = sum(1 for r in results if r.get("resolved"))
    return resolved / len(results)


def _compute_pass_rate(results: list[dict[str, Any]]) -> float:
    """Fraction of results where ``resolved`` is truthy (alias of accuracy)."""
    return _compute_accuracy(results)


def _compute_avg_tokens(results: list[dict[str, Any]]) -> float:
    """Average total token count per result."""
    token_counts: list[int] = []
    for r in results:
        tokens = r.get("tokens", {})
        total = tokens.get("input", 0) + tokens.get("output", 0)
        token_counts.append(total)
    if not token_counts:
        return 0.0
    return float(statistics.mean(token_counts))


def _compute_avg_cost(results: list[dict[str, Any]]) -> float:
    """Average cost per result."""
    costs = [r.get("cost", 0.0) for r in results]
    if not costs:
        return 0.0
    return statistics.mean(costs)


def _compute_avg_time(results: list[dict[str, Any]]) -> float:
    """Average runtime in seconds per result."""
    runtimes = [r.get("runtime_seconds", 0.0) for r in results]
    if not runtimes:
        return 0.0
    return statistics.mean(runtimes)


def _compute_tool_call_rate(results: list[dict[str, Any]]) -> float:
    """Fraction of results that contain at least one tool call."""
    if not results:
        return 0.0
    with_tools = sum(1 for r in results if r.get("tool_usage"))
    return with_tools / len(results)


def _compute_failure_rate(results: list[dict[str, Any]]) -> float:
    """Fraction of results where ``error`` is present and non-empty."""
    if not results:
        return 0.0
    with_errors = sum(1 for r in results if r.get("error"))
    return with_errors / len(results)


# ---------------------------------------------------------------------------
# Built-in metric definitions
# ---------------------------------------------------------------------------

BUILTIN_METRICS: list[MetricDefinition] = [
    MetricDefinition(
        name="accuracy",
        description="Fraction of tasks resolved successfully",
        compute_fn=_compute_accuracy,
        aggregation="mean",
        higher_is_better=True,
    ),
    MetricDefinition(
        name="pass_rate",
        description="Fraction of tasks that pass (alias for accuracy)",
        compute_fn=_compute_pass_rate,
        aggregation="mean",
        higher_is_better=True,
    ),
    MetricDefinition(
        name="avg_tokens",
        description="Average total tokens (input + output) per task",
        compute_fn=_compute_avg_tokens,
        aggregation="mean",
        higher_is_better=False,
    ),
    MetricDefinition(
        name="avg_cost",
        description="Average API cost per task in USD",
        compute_fn=_compute_avg_cost,
        aggregation="mean",
        higher_is_better=False,
    ),
    MetricDefinition(
        name="avg_time",
        description="Average runtime per task in seconds",
        compute_fn=_compute_avg_time,
        aggregation="mean",
        higher_is_better=False,
    ),
    MetricDefinition(
        name="tool_call_rate",
        description="Fraction of tasks that used at least one tool",
        compute_fn=_compute_tool_call_rate,
        aggregation="mean",
        higher_is_better=True,
    ),
    MetricDefinition(
        name="failure_rate",
        description="Fraction of tasks that encountered an error",
        compute_fn=_compute_failure_rate,
        aggregation="mean",
        higher_is_better=False,
    ),
]


def create_default_registry() -> MetricRegistry:
    """Create a MetricRegistry pre-populated with all built-in metrics.

    Returns:
        A MetricRegistry instance containing the built-in metrics.
    """
    registry = MetricRegistry()
    for metric in BUILTIN_METRICS:
        registry.register(metric)
    return registry


# ---------------------------------------------------------------------------
# Aggregation helpers
# ---------------------------------------------------------------------------


def _aggregate(values: list[float], method: str) -> float:
    """Aggregate a list of floats using the specified method.

    Args:
        values: Numeric values to aggregate.
        method: One of ``"mean"``, ``"sum"``, ``"min"``, ``"max"``, ``"median"``.

    Returns:
        Aggregated value.

    Raises:
        ValueError: If the method is unrecognised.
    """
    if not values:
        return 0.0
    if method == "mean":
        return statistics.mean(values)
    elif method == "sum":
        return math.fsum(values)
    elif method == "min":
        return min(values)
    elif method == "max":
        return max(values)
    elif method == "median":
        return statistics.median(values)
    else:
        raise ValueError(f"Unknown aggregation method: {method!r}")


# ---------------------------------------------------------------------------
# Core public API
# ---------------------------------------------------------------------------


def compute_metrics(
    results: list[dict[str, Any]],
    metrics: list[str],
    registry: MetricRegistry | None = None,
) -> dict[str, float]:
    """Compute the requested metrics over a list of result dicts.

    Each result dict is expected to follow the structure used elsewhere in mcpbr
    (keys such as ``resolved``, ``tokens``, ``cost``, ``runtime_seconds``,
    ``tool_usage``, ``error``).

    Composite metrics (whose ``compute_fn`` is a string expression) are resolved
    by first computing all non-composite metrics they reference, then evaluating the
    expression in a restricted namespace.

    Args:
        results: List of per-task result dictionaries.
        metrics: List of metric names to compute.
        registry: Optional MetricRegistry. If ``None``, the default registry
            (containing built-in metrics) is used.

    Returns:
        Dictionary mapping metric names to their computed float values.

    Raises:
        KeyError: If a requested metric is not found in the registry.
        ValueError: If a composite expression references an unknown metric or
            fails to evaluate.
    """
    if registry is None:
        registry = create_default_registry()

    computed: dict[str, float] = {}

    # Separate callable and composite (expression-based) metrics
    callable_names: list[str] = []
    composite_names: list[str] = []

    for name in metrics:
        metric_def = registry.get(name)
        if metric_def is None:
            raise KeyError(f"Metric '{name}' is not registered")
        if callable(metric_def.compute_fn):
            callable_names.append(name)
        else:
            composite_names.append(name)

    # Phase 1: compute all callable metrics
    for name in callable_names:
        metric_def = registry.get(name)
        assert metric_def is not None  # guaranteed above
        assert callable(metric_def.compute_fn)
        computed[name] = metric_def.compute_fn(results)

    # Phase 2: resolve composite metrics
    for name in composite_names:
        metric_def = registry.get(name)
        assert metric_def is not None
        assert isinstance(metric_def.compute_fn, str)

        # Build a namespace of already-computed values. If the expression
        # references a metric that hasn't been computed yet, compute it now.
        ns: dict[str, float] = {}
        for existing_name, existing_val in computed.items():
            ns[existing_name] = existing_val

        # Evaluate the expression. We deliberately restrict the namespace to
        # only contain computed metric values (no builtins).
        try:
            value = float(eval(metric_def.compute_fn, {"__builtins__": {}}, ns))  # noqa: S307
        except ZeroDivisionError:
            value = 0.0
        except Exception as exc:
            raise ValueError(
                f"Failed to evaluate composite metric '{name}' "
                f"expression '{metric_def.compute_fn}': {exc}"
            ) from exc

        computed[name] = value

    return computed


def validate_metric(metric_def: dict[str, Any]) -> bool:
    """Validate a metric definition dictionary.

    Checks that the definition contains all required fields with correct types
    and valid values.

    Required keys:
    - ``name`` (str, non-empty)
    - ``description`` (str)
    - ``compute_fn`` (callable or str)

    Optional keys (with defaults):
    - ``aggregation`` (str, one of mean/sum/min/max/median)
    - ``higher_is_better`` (bool)

    Args:
        metric_def: Dictionary representing a metric definition.

    Returns:
        ``True`` if the definition is valid, ``False`` otherwise.
    """
    # Required fields
    if not isinstance(metric_def.get("name"), str) or not metric_def["name"].strip():
        return False

    if not isinstance(metric_def.get("description"), str):
        return False

    compute_fn = metric_def.get("compute_fn")
    if compute_fn is None:
        return False
    if not callable(compute_fn) and not isinstance(compute_fn, str):
        return False
    if isinstance(compute_fn, str) and not compute_fn.strip():
        return False

    # Optional fields
    aggregation = metric_def.get("aggregation", "mean")
    if aggregation not in _VALID_AGGREGATIONS:
        return False

    higher_is_better = metric_def.get("higher_is_better", True)
    if not isinstance(higher_is_better, bool):
        return False

    return True
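For orientation, here is a minimal usage sketch of the metrics framework above. It assumes the module is importable as `mcpbr.custom_metrics`, the result dicts are made-up examples following the keys described in the docstrings, and the `cost_efficiency` composite metric is illustrative (the module docstring names it only as an example).

# Usage sketch (hypothetical): assumes the module is importable as
# mcpbr.custom_metrics and that result dicts use the keys described above.
from mcpbr.custom_metrics import (
    MetricDefinition,
    compute_metrics,
    create_default_registry,
    validate_metric,
)

# Two made-up per-task results in the shape the built-in metrics expect.
results = [
    {"resolved": True, "tokens": {"input": 900, "output": 150}, "cost": 0.012,
     "runtime_seconds": 31.0, "tool_usage": [{"tool": "search"}], "error": None},
    {"resolved": False, "tokens": {"input": 1200, "output": 80}, "cost": 0.016,
     "runtime_seconds": 44.5, "tool_usage": [], "error": "timeout"},
]

registry = create_default_registry()

# Composite metric: compute_fn is a string expression over other metric names.
registry.register(
    MetricDefinition(
        name="cost_efficiency",
        description="Pass rate per dollar spent (illustrative)",
        compute_fn="pass_rate / avg_cost",
    )
)

# validate_metric() checks plain dicts, e.g. metric definitions read from config.
assert validate_metric(
    {"name": "cost_efficiency", "description": "demo", "compute_fn": "pass_rate / avg_cost"}
)

# The metrics a composite references must be requested alongside it so their
# values are in the namespace when the expression is evaluated.
values = compute_metrics(results, ["pass_rate", "avg_cost", "cost_efficiency"], registry)
print(values)  # e.g. {'pass_rate': 0.5, 'avg_cost': 0.014, 'cost_efficiency': 35.71...}

Note that `compute_metrics` raises `ValueError` if a composite expression names a metric that was not requested (and therefore not yet computed), which is why `pass_rate` and `avg_cost` appear in the list above.
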
mcpbr/dataset_versioning.py
ADDED
@@ -0,0 +1,222 @@
"""Dataset versioning for reproducible benchmark evaluations.

This module provides utilities to pin and track HuggingFace dataset versions,
ensuring that benchmark runs can be reproduced with the exact same data.
Version information includes dataset revision hashes, download timestamps,
and optional checksums for data integrity verification.
"""

import hashlib
import json
import logging
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from datasets import Dataset, load_dataset
from huggingface_hub import dataset_info

logger = logging.getLogger(__name__)


@dataclass
class DatasetVersion:
    """Pinned version information for a HuggingFace dataset.

    Attributes:
        dataset_id: HuggingFace dataset identifier (e.g., 'SWE-bench/SWE-bench_Lite').
        revision: Git revision hash of the dataset (None for latest).
        download_date: ISO 8601 timestamp of when the version was pinned.
        checksum: Optional SHA256 checksum of the dataset content for integrity verification.
    """

    dataset_id: str
    revision: str | None
    download_date: str
    checksum: str | None


def pin_dataset_version(
    dataset_id: str,
    revision: str | None = None,
) -> DatasetVersion:
    """Record the current version of a HuggingFace dataset.

    Fetches dataset metadata from the HuggingFace Hub to determine the
    current revision. If a specific revision is provided, it is used directly.

    Args:
        dataset_id: HuggingFace dataset identifier (e.g., 'SWE-bench/SWE-bench_Lite').
        revision: Specific git revision to pin. If None, the latest revision is fetched.

    Returns:
        DatasetVersion with the pinned revision and metadata.

    Raises:
        Exception: If the dataset cannot be found or accessed on the HuggingFace Hub.
    """
    info = dataset_info(dataset_id, revision=revision)
    resolved_revision = info.sha

    # Compute a checksum from the dataset card and file metadata for integrity
    checksum_data = f"{dataset_id}:{resolved_revision}"
    if info.siblings:
        file_names = sorted(s.rfilename for s in info.siblings)
        checksum_data += ":" + ",".join(file_names)
    checksum = hashlib.sha256(checksum_data.encode()).hexdigest()

    download_date = datetime.now(timezone.utc).isoformat()

    version = DatasetVersion(
        dataset_id=dataset_id,
        revision=resolved_revision,
        download_date=download_date,
        checksum=checksum,
    )

    logger.info(
        "Pinned dataset %s at revision %s",
        dataset_id,
        resolved_revision,
    )

    return version


def load_dataset_pinned(
    dataset_id: str,
    version: DatasetVersion | None = None,
    **kwargs: Any,
) -> Dataset:
    """Load a HuggingFace dataset using a pinned version for reproducibility.

    Wraps the standard ``datasets.load_dataset`` call, injecting the pinned
    revision so that the exact same data snapshot is used across runs.

    Args:
        dataset_id: HuggingFace dataset identifier.
        version: Pinned version to use. If None, loads the latest version.
        **kwargs: Additional keyword arguments passed to ``datasets.load_dataset``
            (e.g., split, name, streaming).

    Returns:
        The loaded HuggingFace Dataset.
    """
    revision = None
    if version is not None:
        revision = version.revision
        logger.info(
            "Loading dataset %s at pinned revision %s (pinned on %s)",
            dataset_id,
            revision,
            version.download_date,
        )
    else:
        logger.info("Loading dataset %s at latest revision", dataset_id)

    return load_dataset(dataset_id, revision=revision, **kwargs)


def save_version_manifest(
    versions: dict[str, DatasetVersion],
    path: Path,
) -> None:
    """Save dataset version pins to a JSON manifest file.

    The manifest file records all pinned dataset versions so they can be
    shared across team members or CI environments for reproducible runs.

    Args:
        versions: Mapping of dataset identifiers to their pinned versions.
        path: File path to write the JSON manifest.
    """
    manifest: dict[str, Any] = {
        "format_version": "1.0",
        "created_at": datetime.now(timezone.utc).isoformat(),
        "datasets": {},
    }

    for dataset_id, version in versions.items():
        manifest["datasets"][dataset_id] = asdict(version)

    path.parent.mkdir(parents=True, exist_ok=True)

    with open(path, "w") as f:
        json.dump(manifest, f, indent=2)

    logger.info("Saved version manifest with %d datasets to %s", len(versions), path)


def load_version_manifest(path: Path) -> dict[str, DatasetVersion]:
    """Load pinned dataset versions from a JSON manifest file.

    Args:
        path: File path to the JSON manifest.

    Returns:
        Mapping of dataset identifiers to their pinned versions.

    Raises:
        FileNotFoundError: If the manifest file does not exist.
        json.JSONDecodeError: If the manifest file contains invalid JSON.
        KeyError: If the manifest is missing required fields.
    """
    with open(path) as f:
        manifest = json.load(f)

    versions: dict[str, DatasetVersion] = {}
    datasets_data = manifest.get("datasets", {})

    for dataset_id, version_data in datasets_data.items():
        versions[dataset_id] = DatasetVersion(
            dataset_id=version_data["dataset_id"],
            revision=version_data.get("revision"),
            download_date=version_data["download_date"],
            checksum=version_data.get("checksum"),
        )

    logger.info("Loaded version manifest with %d datasets from %s", len(versions), path)

    return versions


def get_dataset_info(dataset_id: str) -> dict[str, Any]:
    """Get metadata about a HuggingFace dataset.

    Retrieves information such as the latest revision, description,
    file listing, and other Hub metadata.

    Args:
        dataset_id: HuggingFace dataset identifier.

    Returns:
        Dictionary containing dataset metadata with keys:
        - dataset_id: The dataset identifier.
        - latest_revision: The current HEAD revision hash.
        - description: Dataset description text.
        - tags: List of dataset tags.
        - downloads: Number of downloads.
        - last_modified: Last modification timestamp.
        - files: List of files in the dataset repository.

    Raises:
        Exception: If the dataset cannot be found or accessed on the HuggingFace Hub.
    """
    info = dataset_info(dataset_id)

    files: list[str] = []
    if info.siblings:
        files = [s.rfilename for s in info.siblings]

    result: dict[str, Any] = {
        "dataset_id": dataset_id,
        "latest_revision": info.sha,
        "description": info.description or "",
        "tags": list(info.tags) if info.tags else [],
        "downloads": info.downloads if info.downloads is not None else 0,
        "last_modified": info.last_modified.isoformat() if info.last_modified else None,
        "files": files,
    }

    return result
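A minimal end-to-end sketch of the pinning workflow above. It assumes the module is importable as `mcpbr.dataset_versioning`, the machine can reach the HuggingFace Hub, and the manifest filename is illustrative; the dataset id is the example used in the docstrings.

# Usage sketch (hypothetical): pin once, persist the manifest, reload elsewhere.
from pathlib import Path

from mcpbr.dataset_versioning import (
    load_dataset_pinned,
    load_version_manifest,
    pin_dataset_version,
    save_version_manifest,
)

dataset_id = "SWE-bench/SWE-bench_Lite"  # example id taken from the docstrings

# Pin the current HEAD revision and persist it for teammates/CI.
version = pin_dataset_version(dataset_id)
save_version_manifest({dataset_id: version}, Path("dataset_versions.json"))

# Later, or on another machine: reload the manifest and fetch the same snapshot.
pinned = load_version_manifest(Path("dataset_versions.json"))
dataset = load_dataset_pinned(dataset_id, pinned[dataset_id], split="test")
print(pinned[dataset_id].revision, len(dataset))
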
mcpbr/docker_env.py
CHANGED
@@ -724,6 +724,12 @@ CMD ["/bin/bash"]
         if self in _active_managers:
             _active_managers.remove(self)

+        # Close the Docker client to release background threads/connections
+        try:
+            self.client.close()
+        except Exception:
+            pass
+
         if report and cleanup_report.total_removed > 0:
             logger.info(str(cleanup_report))

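The docker_env.py change above closes the Docker SDK client during cleanup so its pooled HTTP connections and background resources are released. For context only (this is not mcpbr code), the usual docker-py lifecycle pairs `docker.from_env()` with `client.close()`:

# Standalone docker-py sketch, independent of mcpbr.
import docker

client = docker.from_env()
try:
    client.ping()  # any API use; raises if the daemon is unreachable
finally:
    # Releases the underlying HTTP connection pool held by the client.
    client.close()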