mcpbr 0.4.16__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,369 @@
1
+ """Docker image pre-warming for mcpbr benchmark evaluations.
2
+
3
+ Pre-pulls Docker images needed for a benchmark run before evaluation begins,
4
+ so that image pull time does not inflate task-level timing measurements.
5
+ Supports parallel pulling, progress reporting, and local cache detection.
6
+ """
7
+
8
+ import asyncio
9
+ import logging
10
+ import time
11
+ from dataclasses import dataclass, field
12
+ from typing import Any, Callable
13
+
14
+ import docker
15
+ import docker.errors
16
+ from rich.console import Console
17
+ from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
18
+ from rich.table import Table
19
+
20
+ from .docker_env import SWEBENCH_IMAGE_REGISTRY, get_swebench_image_name
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Default base images used by non-SWE-bench benchmarks
25
# Every benchmark listed here maps to its own single-element list holding the
# same stock image name; the key order below is preserved in the mapping.
DEFAULT_BASE_IMAGES: dict[str, list[str]] = {
    benchmark: ["python:3.11-slim"]
    for benchmark in (
        "humaneval",
        "mbpp",
        "apps",
        "codecontests",
        "bigcodebench",
        "leetcode",
        "codereval",
        "gsm8k",
        "math",
        "truthfulqa",
        "bigbench-hard",
        "hellaswag",
        "arc",
        "repoqa",
        "toolbench",
        "aider-polyglot",
        "terminalbench",
        "gaia",
        "agentbench",
        "webarena",
        "mlagentbench",
        "intercode",
        "mmmu",
        "longbench",
        "adversarial",
        "mcptoolbench",
        "custom",
        "cybergym",
    )
}
55
+
56
+
57
@dataclass
class PrewarmResult:
    """Outcome summary for one Docker image pre-warm pass.

    Every field has a default, so ``PrewarmResult()`` describes a pass in
    which nothing was requested.

    Attributes:
        total_images: How many images the pass was asked to make available.
        already_cached: How many of those were found in the local image store.
        newly_pulled: How many were fetched successfully from a registry.
        failed: Names of the images that could not be pulled.
        pull_time_seconds: Wall-clock duration of the whole pass, in seconds.
    """

    total_images: int = 0
    already_cached: int = 0
    newly_pulled: int = 0
    # default_factory gives each instance its own list instead of a shared one.
    failed: list[str] = field(default_factory=list)
    pull_time_seconds: float = 0.0
74
+
75
+
76
def get_required_images(benchmark_name: str, tasks: list[dict[str, Any]]) -> list[str]:
    """Work out which Docker images a benchmark run will need.

    Benchmarks whose name starts with ``"swe-bench"`` get one image per
    distinct, non-empty ``instance_id`` found in ``tasks``; the image name is
    produced by ``get_swebench_image_name``. Every other benchmark gets a copy
    of its ``DEFAULT_BASE_IMAGES`` entry, or ``["python:3.11-slim"]`` when the
    benchmark has no entry.

    Args:
        benchmark_name: Name of the benchmark (e.g., ``"swe-bench-lite"``).
        tasks: Task dictionaries loaded from the benchmark. Only the
            ``"instance_id"`` key is read, and only for SWE-bench variants.

    Returns:
        Image names in first-seen task order for SWE-bench variants, otherwise
        a fresh list of the benchmark's base images.
    """
    if not benchmark_name.startswith("swe-bench"):
        # Copy so callers cannot mutate the shared defaults table.
        return list(DEFAULT_BASE_IMAGES.get(benchmark_name, ["python:3.11-slim"]))

    image_names: list[str] = []
    visited_ids: set[str] = set()
    for entry in tasks:
        ident = entry.get("instance_id", "")
        # Skip tasks with a missing/empty id and ids that were already handled.
        if not ident or ident in visited_ids:
            continue
        visited_ids.add(ident)
        image_names.append(get_swebench_image_name(ident))
    return image_names
107
+
108
+
109
def check_cached_images(images: list[str]) -> dict[str, bool]:
    """Report which Docker images are already present in the local image store.

    Connects with ``docker.from_env()`` and looks each image up via
    ``client.images.get``. An image is marked ``True`` when the lookup
    succeeds and ``False`` when it raises ``ImageNotFound`` or any other
    ``APIError``. If the daemon cannot be reached at all, a warning is logged
    and every image is reported as ``False``.

    Args:
        images: Docker image names to check.

    Returns:
        Mapping from each image name to whether it is cached locally. An
        empty input yields an empty mapping without contacting the daemon.
    """
    # Nothing to look up: skip the daemon connection entirely.
    if not images:
        return {}

    try:
        client = docker.from_env()
    except docker.errors.DockerException:
        logger.warning("Could not connect to Docker daemon for cache check")
        return {image: False for image in images}

    result: dict[str, bool] = {}
    try:
        for image in images:
            try:
                client.images.get(image)
            except (docker.errors.ImageNotFound, docker.errors.APIError):
                # Missing and inaccessible images are both treated as "not cached".
                result[image] = False
            else:
                result[image] = True
    finally:
        # Release the client's HTTP session; the original left it open.
        client.close()

    return result
140
+
141
+
142
async def _pull_single_image(
    client: docker.DockerClient,
    image: str,
    semaphore: asyncio.Semaphore,
    on_progress: Callable[[str, str], None] | None = None,
) -> tuple[str, bool]:
    """Pull one Docker image while holding a slot of the concurrency semaphore.

    The blocking ``client.images.pull`` call runs in the event loop's default
    executor so other pulls can proceed concurrently. Pull failures are logged
    and reported through the return value rather than raised, so the final
    progress callback always fires and the caller always learns which image
    failed.

    Args:
        client: Docker client instance used to perform the pull.
        image: Full image name to pull.
        semaphore: Asyncio semaphore limiting the number of parallel pulls.
        on_progress: Optional callback ``(image, status)``. It is invoked with
            ``"pulling"`` before the pull starts and with ``"done"`` or
            ``"failed"`` once it finishes.

    Returns:
        Tuple of ``(image_name, success)``.
    """

    def _do_pull() -> bool:
        # Request linux/amd64 for images hosted on the SWE-bench registry;
        # for everything else let the daemon choose the platform.
        platform = "linux/amd64" if SWEBENCH_IMAGE_REGISTRY in image else None
        try:
            client.images.pull(image, platform=platform)
        except docker.errors.ImageNotFound:
            logger.warning("Image not found in registry: %s", image)
            return False
        except docker.errors.APIError as exc:
            logger.warning("Failed to pull image %s: %s", image, exc)
            return False
        except Exception as exc:
            # Worker-thread boundary: pre-warming is best-effort, so any other
            # failure (e.g. a transport error) is reported as a failed pull
            # instead of escaping and skipping the "failed" progress update.
            logger.warning("Unexpected error pulling image %s: %s", image, exc)
            return False
        return True

    async with semaphore:
        if on_progress:
            on_progress(image, "pulling")

        # get_running_loop() is the supported way to obtain the loop from
        # inside a coroutine; get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        success = await loop.run_in_executor(None, _do_pull)

        if on_progress:
            on_progress(image, "done" if success else "failed")

        return image, success
184
+
185
+
186
async def prewarm_images(
    benchmark_name: str,
    tasks: list[dict[str, Any]],
    max_parallel: int = 3,
    on_progress: Callable[[str, str], None] | None = None,
) -> PrewarmResult:
    """Pre-pull all Docker images needed for a benchmark run.

    Resolves the required images, checks the local cache, then pulls the
    missing ones concurrently (at most ``max_parallel`` at a time) and
    returns a summary of the operation.

    Args:
        benchmark_name: Name of the benchmark (e.g., ``"swe-bench-verified"``).
        tasks: List of task dictionaries from the benchmark loader.
        max_parallel: Maximum number of concurrent image pulls. Defaults to 3.
            Values below 1 are treated as 1, because a zero-slot semaphore
            would block every pull forever.
        on_progress: Optional callback ``(image_name, status_string)`` invoked
            when an image starts pulling or completes.

    Returns:
        PrewarmResult summarising cached, pulled, and failed images. ``failed``
        always contains image names, including for pulls that ended with an
        unexpected exception. If the Docker daemon cannot be reached for
        pulling, every uncached image is reported as failed.
    """
    start_time = time.monotonic()

    images = get_required_images(benchmark_name, tasks)
    total = len(images)

    if total == 0:
        return PrewarmResult(pull_time_seconds=time.monotonic() - start_time)

    # Check local cache
    cache_status = check_cached_images(images)
    already_cached = sum(1 for cached in cache_status.values() if cached)
    to_pull = [img for img, cached in cache_status.items() if not cached]

    if not to_pull:
        return PrewarmResult(
            total_images=total,
            already_cached=already_cached,
            newly_pulled=0,
            failed=[],
            pull_time_seconds=time.monotonic() - start_time,
        )

    # Pull missing images in parallel
    try:
        client = docker.from_env()
    except docker.errors.DockerException as exc:
        logger.error("Cannot connect to Docker daemon: %s", exc)
        return PrewarmResult(
            total_images=total,
            already_cached=already_cached,
            newly_pulled=0,
            failed=to_pull,
            pull_time_seconds=time.monotonic() - start_time,
        )

    try:
        # Clamp so a caller passing 0 or a negative value cannot deadlock
        # (Semaphore(0)) or crash (Semaphore(<0) raises ValueError).
        semaphore = asyncio.Semaphore(max(1, max_parallel))
        pull_tasks = [
            _pull_single_image(client, image, semaphore, on_progress=on_progress)
            for image in to_pull
        ]
        results = await asyncio.gather(*pull_tasks, return_exceptions=True)
    finally:
        # Release the client's HTTP session once pulling is over.
        client.close()

    newly_pulled = 0
    failed: list[str] = []
    # gather() returns results in the order the awaitables were passed, so
    # pairing with to_pull recovers the image name even when a pull raised.
    for image, outcome in zip(to_pull, results):
        if isinstance(outcome, BaseException):
            logger.error("Unexpected error during image pull of %s: %s", image, outcome)
            failed.append(image)
            continue
        _, success = outcome
        if success:
            newly_pulled += 1
        else:
            failed.append(image)

    return PrewarmResult(
        total_images=total,
        already_cached=already_cached,
        newly_pulled=newly_pulled,
        failed=failed,
        pull_time_seconds=time.monotonic() - start_time,
    )
270
+
271
+
272
def format_prewarm_report(result: PrewarmResult) -> None:
    """Render a pre-warm summary to the terminal using rich.

    Prints a two-column table (total, cached, pulled, failed, elapsed time),
    then lists any images that failed, and finishes with a one-line status
    message chosen from the counts in ``result``.

    Args:
        result: PrewarmResult from a completed pre-warm operation.
    """
    console = Console()

    summary = Table(title="Docker Image Pre-warm Summary", show_header=True, header_style="bold")
    summary.add_column("Metric", style="cyan", no_wrap=True)
    summary.add_column("Value", justify="right")

    rows = (
        ("Total images", str(result.total_images)),
        ("Already cached", str(result.already_cached)),
        ("Newly pulled", str(result.newly_pulled)),
        ("Failed", str(len(result.failed))),
        ("Pull time", f"{result.pull_time_seconds:.1f}s"),
    )
    for label, value in rows:
        summary.add_row(label, value)

    console.print()
    console.print(summary)

    if result.failed:
        # First the itemised failures, then the advisory note.
        console.print()
        console.print("[red bold]Failed to pull the following images:[/red bold]")
        for failed_image in result.failed:
            console.print(f"[red] - {failed_image}[/red]")
        console.print()
        console.print(
            "[yellow]Some images could not be pre-warmed. "
            "Evaluation will attempt to pull them at task time.[/yellow]"
        )
    elif result.newly_pulled > 0:
        console.print()
        console.print("[green bold]All images pre-warmed successfully.[/green bold]")
    elif result.total_images > 0:
        console.print()
        console.print("[green]All images already cached locally.[/green]")

    # Trailing blank line after the report, whichever branch ran.
    console.print()
315
+
316
+
317
async def prewarm_images_with_progress(
    benchmark_name: str,
    tasks: list[dict[str, Any]],
    max_parallel: int = 3,
) -> PrewarmResult:
    """Run :func:`prewarm_images` behind a rich progress bar.

    The set of uncached images is computed up front to size the bar. When
    nothing needs pulling, the call is delegated to :func:`prewarm_images`
    without creating any progress display.

    Args:
        benchmark_name: Name of the benchmark.
        tasks: List of task dictionaries from the benchmark loader.
        max_parallel: Maximum number of concurrent image pulls. Defaults to 3.

    Returns:
        PrewarmResult summarising cached, pulled, and failed images.
    """
    required = get_required_images(benchmark_name, tasks)
    missing = [name for name, present in check_cached_images(required).items() if not present]

    if not missing:
        # No pulls pending; let prewarm_images build the summary directly.
        return await prewarm_images(benchmark_name, tasks, max_parallel)

    terminal = Console()
    bar = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        console=terminal,
    )
    bar_task = bar.add_task("Pre-warming Docker images", total=len(missing))
    finished = 0

    def _on_progress(image: str, status: str) -> None:
        nonlocal finished
        if status not in ("done", "failed"):
            # A pull is starting: show the last path segment, capped at 40 chars.
            label = image.split("/")[-1][:40]
            bar.update(bar_task, description=f"Pulling {label}")
            return
        # A pull ended (successfully or not): advance the bar by one.
        finished += 1
        bar.update(bar_task, completed=finished)

    with bar:
        outcome = await prewarm_images(
            benchmark_name, tasks, max_parallel, on_progress=_on_progress
        )
    return outcome