mcpbr 0.4.16__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcpbr/__init__.py +20 -1
- mcpbr/config.py +37 -1
- mcpbr/config_migration.py +470 -0
- mcpbr/config_wizard.py +647 -0
- mcpbr/dashboard.py +619 -0
- mcpbr/dataset_streaming.py +491 -0
- mcpbr/docker_cache.py +539 -0
- mcpbr/docker_env.py +2 -1
- mcpbr/docker_prewarm.py +370 -0
- mcpbr/dry_run.py +533 -0
- mcpbr/formatting.py +444 -0
- mcpbr/gpu_support.py +2 -1
- mcpbr/graceful_degradation.py +277 -0
- mcpbr/harness.py +38 -4
- mcpbr/languages.py +228 -0
- mcpbr/logging_config.py +207 -0
- mcpbr/models.py +66 -0
- mcpbr/preflight.py +2 -1
- mcpbr/pricing.py +72 -0
- mcpbr/providers.py +316 -3
- mcpbr/resource_limits.py +487 -0
- mcpbr/result_streaming.py +519 -0
- mcpbr/sdk.py +264 -0
- mcpbr/smoke_test.py +2 -1
- mcpbr/task_batching.py +403 -0
- mcpbr/task_scheduler.py +468 -0
- {mcpbr-0.4.16.dist-info → mcpbr-0.6.0.dist-info}/METADATA +8 -1
- {mcpbr-0.4.16.dist-info → mcpbr-0.6.0.dist-info}/RECORD +38 -22
- {mcpbr-0.4.16.data → mcpbr-0.6.0.data}/data/mcpbr/data/templates/brave-search.yaml +0 -0
- {mcpbr-0.4.16.data → mcpbr-0.6.0.data}/data/mcpbr/data/templates/filesystem.yaml +0 -0
- {mcpbr-0.4.16.data → mcpbr-0.6.0.data}/data/mcpbr/data/templates/github.yaml +0 -0
- {mcpbr-0.4.16.data → mcpbr-0.6.0.data}/data/mcpbr/data/templates/google-maps.yaml +0 -0
- {mcpbr-0.4.16.data → mcpbr-0.6.0.data}/data/mcpbr/data/templates/postgres.yaml +0 -0
- {mcpbr-0.4.16.data → mcpbr-0.6.0.data}/data/mcpbr/data/templates/slack.yaml +0 -0
- {mcpbr-0.4.16.data → mcpbr-0.6.0.data}/data/mcpbr/data/templates/sqlite.yaml +0 -0
- {mcpbr-0.4.16.dist-info → mcpbr-0.6.0.dist-info}/WHEEL +0 -0
- {mcpbr-0.4.16.dist-info → mcpbr-0.6.0.dist-info}/entry_points.txt +0 -0
- {mcpbr-0.4.16.dist-info → mcpbr-0.6.0.dist-info}/licenses/LICENSE +0 -0
mcpbr/sdk.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
"""Public Python SDK for mcpbr.
|
|
2
|
+
|
|
3
|
+
Provides a programmatic interface for running MCP server benchmarks
|
|
4
|
+
from Python code, without requiring the CLI.
|
|
5
|
+
|
|
6
|
+
Example usage::
|
|
7
|
+
|
|
8
|
+
from mcpbr import MCPBenchmark, list_benchmarks, list_models
|
|
9
|
+
|
|
10
|
+
# List available benchmarks
|
|
11
|
+
for b in list_benchmarks():
|
|
12
|
+
print(b["name"])
|
|
13
|
+
|
|
14
|
+
# Create and run a benchmark
|
|
15
|
+
bench = MCPBenchmark({
|
|
16
|
+
"mcp_server": {
|
|
17
|
+
"command": "npx",
|
|
18
|
+
"args": ["-y", "@modelcontextprotocol/server-filesystem", "{workdir}"],
|
|
19
|
+
},
|
|
20
|
+
"benchmark": "humaneval",
|
|
21
|
+
"model": "sonnet",
|
|
22
|
+
})
|
|
23
|
+
|
|
24
|
+
is_valid, errors = bench.validate()
|
|
25
|
+
plan = bench.dry_run()
|
|
26
|
+
|
|
27
|
+
# Async execution
|
|
28
|
+
import asyncio
|
|
29
|
+
result = asyncio.run(bench.run())
|
|
30
|
+
print(result.success, result.summary)
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
from dataclasses import dataclass
|
|
34
|
+
from pathlib import Path
|
|
35
|
+
from typing import Any
|
|
36
|
+
|
|
37
|
+
from . import __version__
|
|
38
|
+
from .benchmarks import BENCHMARK_REGISTRY
|
|
39
|
+
from .config import VALID_PROVIDERS, HarnessConfig, load_config
|
|
40
|
+
from .models import SUPPORTED_MODELS, validate_model
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class BenchmarkResult:
    """Outcome of one benchmark run.

    The first four fields are required; the numeric totals fall back to
    zero when not supplied.

    Attributes:
        success: True when the benchmark completed successfully.
        summary: Aggregated results (e.g., pass rate, resolved count).
        tasks: Per-task results, one dict per task.
        metadata: Run metadata (benchmark name, model, timestamps, etc.).
        total_cost: Total API cost in USD.
        total_tokens: Total tokens consumed.
        duration_seconds: Wall-clock duration of the run.
    """

    # Required: overall outcome and the result payloads.
    success: bool
    summary: dict[str, Any]
    tasks: list[dict[str, Any]]
    metadata: dict[str, Any]

    # Optional accounting totals; zero means "not reported".
    total_cost: float = 0.0
    total_tokens: int = 0
    duration_seconds: float = 0.0
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class MCPBenchmark:
    """Programmatic front end for configuring and running MCP benchmarks.

    The constructor accepts three kinds of input: an existing
    ``HarnessConfig``, a path to a YAML config file (``str`` or ``Path``),
    or a plain dict of config values.

    Args:
        config: A HarnessConfig instance, a path to a YAML config file
            (str or Path), or a dict of config values.

    Raises:
        FileNotFoundError: If a file path is given and the file does not exist.
        TypeError: If ``config`` is none of the supported types.
        ValueError: If the config dict is invalid (raised by
            ``HarnessConfig`` itself, not by this class).
    """

    def __init__(self, config: dict[str, Any] | str | Path | HarnessConfig) -> None:
        # A ready-made config object is stored as-is.
        if isinstance(config, HarnessConfig):
            self.config: HarnessConfig = config
            return

        # A str/Path is treated as the location of a config file.
        if isinstance(config, (str, Path)):
            config_path = Path(config)
            if not config_path.exists():
                raise FileNotFoundError(f"Config file not found: {config_path}")
            self.config = load_config(config_path, warn_security=False)
            return

        # A dict is expanded into HarnessConfig keyword arguments.
        if isinstance(config, dict):
            self.config = HarnessConfig(**config)
            return

        raise TypeError(
            f"config must be a dict, str, Path, or HarnessConfig, got {type(config).__name__}"
        )

    @staticmethod
    def _describe_server(server: Any) -> dict[str, Any]:
        """Return the command/args/name triple of an MCP server config."""
        return {
            "command": server.command,
            "args": server.args,
            "name": server.name,
        }

    def validate(self) -> tuple[bool, list[str]]:
        """Check the model, benchmark and provider of the current config.

        Three independent checks are run and every failure is collected, so
        a single call reports all problems at once.

        Returns:
            A tuple ``(is_valid, messages)``; ``is_valid`` is True exactly
            when ``messages`` is empty.
        """
        cfg = self.config
        problems: list[str] = []

        # Model must be accepted by the model registry's validator.
        model_ok, model_problem = validate_model(cfg.model)
        if not model_ok:
            problems.append(f"Model warning: {model_problem}")

        # Benchmark must be a registered benchmark name.
        if cfg.benchmark not in BENCHMARK_REGISTRY:
            known_benchmarks = ", ".join(BENCHMARK_REGISTRY.keys())
            problems.append(f"Unknown benchmark: {cfg.benchmark}. Available: {known_benchmarks}")

        # Provider must be one of the supported providers.
        if cfg.provider not in VALID_PROVIDERS:
            known_providers = ", ".join(VALID_PROVIDERS)
            problems.append(f"Unknown provider: {cfg.provider}. Available: {known_providers}")

        return not problems, problems

    def dry_run(self) -> dict[str, Any]:
        """Describe what a run would do, without executing anything.

        Returns:
            A dict with the benchmark, model, provider and runtime settings.
            MCP server descriptions, the comparison-mode flag and the
            optional budget/prompt settings are included only when set.
        """
        cfg = self.config

        # Core settings are always reported.
        plan: dict[str, Any] = {
            "benchmark": cfg.benchmark,
            "model": cfg.model,
            "provider": cfg.provider,
            "agent_harness": cfg.agent_harness,
            "timeout_seconds": cfg.timeout_seconds,
            "max_concurrent": cfg.max_concurrent,
            "max_iterations": cfg.max_iterations,
            "sample_size": cfg.sample_size,
        }

        # Single-server setup.
        if cfg.mcp_server:
            plan["mcp_server"] = self._describe_server(cfg.mcp_server)

        # Comparison mode may carry up to two additional servers.
        if cfg.comparison_mode:
            plan["comparison_mode"] = True
            for slot in ("mcp_server_a", "mcp_server_b"):
                server = getattr(cfg, slot)
                if server:
                    plan[slot] = self._describe_server(server)

        # Optional settings appear only when explicitly configured.
        for option in ("budget", "thinking_budget", "agent_prompt"):
            value = getattr(cfg, option)
            if value is not None:
                plan[option] = value

        return plan

    async def run(self, **kwargs: Any) -> BenchmarkResult:
        """Run the benchmark.

        Public entry point for programmatic runs. All work is forwarded to
        ``_execute``, which is the hook to override or mock in tests.

        Args:
            **kwargs: Extra keyword arguments handed on to ``_execute``.

        Returns:
            BenchmarkResult with the evaluation results.
        """
        result = await self._execute(**kwargs)
        return result

    async def _execute(self, **kwargs: Any) -> BenchmarkResult:
        """Execution hook behind ``run``.

        Currently a placeholder: it always raises. Override or mock it for
        testing.

        Args:
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            BenchmarkResult with the evaluation results (never reached in
            this implementation).

        Raises:
            NotImplementedError: Always; the full execution pipeline is not
                yet wired into the SDK. Use the CLI for actual runs.
        """
        message = (
            "Full benchmark execution via the SDK is not yet implemented. "
            "Use the `mcpbr` CLI for actual benchmark runs, or mock "
            "MCPBenchmark._execute for testing."
        )
        raise NotImplementedError(message)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def list_benchmarks() -> list[dict[str, str]]:
    """Enumerate the registered benchmarks.

    Returns:
        One dict per registry entry, in registry order, with 'name' (the
        registry key) and 'class' (the ``__name__`` of the registered class).
    """
    catalog: list[dict[str, str]] = []
    for identifier, benchmark_cls in BENCHMARK_REGISTRY.items():
        catalog.append({"name": identifier, "class": benchmark_cls.__name__})
    return catalog
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def list_providers() -> list[str]:
    """Return the supported model providers as a new list.

    Returns:
        A fresh list built from ``VALID_PROVIDERS``; mutating it does not
        affect the underlying constant.
    """
    return [*VALID_PROVIDERS]
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def list_models() -> list[dict[str, Any]]:
    """List all supported models with their metadata.

    Each registry entry is flattened into a plain dict. Values are copied
    through without conversion, so they keep whatever type the registry
    entry holds.

    Returns:
        A list of dicts, each containing 'id', 'provider', 'display_name',
        'context_window', 'supports_tools', and 'notes'.
    """
    # NOTE(review): context_window / supports_tools are presumably int / bool
    # on the registry entries (confirm against the models module); that is why
    # the value type is Any rather than str.
    return [
        {
            "id": info.id,
            "provider": info.provider,
            "display_name": info.display_name,
            "context_window": info.context_window,
            "supports_tools": info.supports_tools,
            "notes": info.notes,
        }
        for info in SUPPORTED_MODELS.values()
    ]
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def get_version() -> str:
    """Report the installed mcpbr version.

    Returns:
        The package's ``__version__`` string (e.g., '0.6.0').
    """
    version: str = __version__
    return version
|
mcpbr/smoke_test.py
CHANGED
|
@@ -7,12 +7,13 @@ from dataclasses import dataclass
|
|
|
7
7
|
from pathlib import Path
|
|
8
8
|
from typing import Any
|
|
9
9
|
|
|
10
|
-
import docker
|
|
11
10
|
from anthropic import Anthropic
|
|
12
11
|
from rich.console import Console
|
|
13
12
|
from rich.panel import Panel
|
|
14
13
|
from rich.table import Table
|
|
15
14
|
|
|
15
|
+
import docker
|
|
16
|
+
|
|
16
17
|
from .config import load_config
|
|
17
18
|
from .config_validator import validate_config
|
|
18
19
|
|
mcpbr/task_batching.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
1
|
+
"""Task batching with smart scheduling for efficient batch execution.
|
|
2
|
+
|
|
3
|
+
Groups similar benchmark tasks to minimize Docker container restarts and
|
|
4
|
+
maximize resource reuse. Supports multiple batching strategies including
|
|
5
|
+
repo-based, image-based, category-based, fixed-size, and adaptive grouping.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import uuid
|
|
9
|
+
from collections import defaultdict
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
# Estimated cost, in seconds, of one Docker container restart. Used as the
# per-restart unit when estimating how much time batching saves.
_CONTAINER_RESTART_OVERHEAD_SECONDS: float = 30.0
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BatchStrategy(Enum):
    """How tasks are grouped into batches.

    Members:
        BY_REPO: Tasks sharing a repository are batched together.
        BY_IMAGE: Tasks requiring the same Docker image are batched together.
        BY_CATEGORY: Tasks of the same benchmark category are batched together.
        FIXED_SIZE: Tasks are cut into fixed-size chunks, ignoring similarity.
        ADAPTIVE: Batch size is chosen dynamically from task similarity signals.
    """

    # Field-based grouping strategies.
    BY_REPO = "by_repo"
    BY_IMAGE = "by_image"
    BY_CATEGORY = "by_category"

    # Size-driven strategies.
    FIXED_SIZE = "fixed_size"
    ADAPTIVE = "adaptive"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
|
|
37
|
+
class TaskBatch:
|
|
38
|
+
"""A batch of grouped tasks for efficient execution.
|
|
39
|
+
|
|
40
|
+
Attributes:
|
|
41
|
+
batch_id: Unique identifier for this batch.
|
|
42
|
+
tasks: List of task dictionaries in this batch.
|
|
43
|
+
common_image: Shared Docker image if all tasks use the same one, else None.
|
|
44
|
+
common_repo: Shared repository if all tasks target the same repo, else None.
|
|
45
|
+
batch_size: Number of tasks in this batch.
|
|
46
|
+
estimated_savings_seconds: Estimated time saved by batching vs individual execution.
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
batch_id: str
|
|
50
|
+
tasks: list[dict[str, Any]]
|
|
51
|
+
common_image: str | None = None
|
|
52
|
+
common_repo: str | None = None
|
|
53
|
+
batch_size: int = 0
|
|
54
|
+
estimated_savings_seconds: float = 0.0
|
|
55
|
+
|
|
56
|
+
def __post_init__(self) -> None:
|
|
57
|
+
"""Compute batch_size from tasks if not explicitly set."""
|
|
58
|
+
if self.batch_size == 0 and self.tasks:
|
|
59
|
+
self.batch_size = len(self.tasks)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass
class BatchSavings:
    """Summary of what a batching plan is expected to save.

    All fields default to zero, which is also what an empty plan reports.

    Attributes:
        total_batches: How many batches the plan contains.
        avg_batch_size: Mean number of tasks per batch.
        estimated_container_reuse: Number of container restarts avoided.
        estimated_time_saved_seconds: Total estimated time saved, in seconds.
    """

    # Plan shape.
    total_batches: int = 0
    avg_batch_size: float = 0.0

    # Estimated benefit.
    estimated_container_reuse: int = 0
    estimated_time_saved_seconds: float = 0.0
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class TaskBatcher:
    """Groups benchmark tasks into batches for efficient execution.

    Batching reduces Docker container restarts by grouping tasks that share
    common requirements (repository, image, category). Supports multiple
    strategies and configurable batch sizes.

    Args:
        strategy: Batching strategy to use.
        max_batch_size: Maximum number of tasks per batch.
        min_batch_size: Lower bound used by the ADAPTIVE strategy when it
            picks a per-group chunk size. Groups smaller than this are still
            returned as batches (no tasks are dropped).

    Example:
        >>> batcher = TaskBatcher(strategy=BatchStrategy.BY_REPO, max_batch_size=5)
        >>> tasks = [{"instance_id": "t1", "repo": "org/repo1"}, ...]
        >>> batches = batcher.batch(tasks)
        >>> print(batcher.preview(batches))
    """

    # Fields the ADAPTIVE strategy combines into its grouping key, in order.
    _SIMILARITY_FIELDS: tuple[str, ...] = ("repo", "image", "category")
    # Placeholder stored in an adaptive grouping key when a task lacks a field.
    _MISSING_FIELD: str = "_"

    def __init__(
        self,
        strategy: BatchStrategy = BatchStrategy.BY_REPO,
        max_batch_size: int = 10,
        min_batch_size: int = 2,
    ) -> None:
        """Initialize the TaskBatcher.

        Args:
            strategy: Batching strategy to use.
            max_batch_size: Maximum number of tasks per batch.
            min_batch_size: Minimum number of tasks to form a batch.

        Raises:
            ValueError: If max_batch_size < 1 or min_batch_size < 1 or
                min_batch_size > max_batch_size.
        """
        if max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")
        if min_batch_size < 1:
            raise ValueError(f"min_batch_size must be >= 1, got {min_batch_size}")
        if min_batch_size > max_batch_size:
            raise ValueError(
                f"min_batch_size ({min_batch_size}) must be <= max_batch_size ({max_batch_size})"
            )
        self.strategy = strategy
        self.max_batch_size = max_batch_size
        self.min_batch_size = min_batch_size

    def batch(self, tasks: list[dict[str, Any]]) -> list[TaskBatch]:
        """Group tasks into batches using the configured strategy.

        Args:
            tasks: List of task dictionaries to batch. Each task should have
                at minimum an ``instance_id`` key. Depending on the strategy,
                ``repo``, ``image``, and ``category`` fields are also used.

        Returns:
            List of TaskBatch objects. Every input task appears in exactly one
            batch. Batches are in descending size order.

        Raises:
            ValueError: If ``self.strategy`` is not a known BatchStrategy.
        """
        if not tasks:
            return []

        if self.strategy == BatchStrategy.BY_REPO:
            return self._batch_by_field(tasks, "repo")
        if self.strategy == BatchStrategy.BY_IMAGE:
            return self._batch_by_field(tasks, "image")
        if self.strategy == BatchStrategy.BY_CATEGORY:
            return self._batch_by_field(tasks, "category")
        if self.strategy == BatchStrategy.FIXED_SIZE:
            return self._batch_fixed_size(tasks)
        if self.strategy == BatchStrategy.ADAPTIVE:
            return self._batch_adaptive(tasks)
        raise ValueError(f"Unknown batch strategy: {self.strategy}")

    def estimate_savings(self, batches: list[TaskBatch]) -> BatchSavings:
        """Estimate time saved by batching compared to individual execution.

        The savings come primarily from container reuse: tasks in the same batch
        can share a Docker container instead of each requiring a fresh one.

        Args:
            batches: List of TaskBatch objects to analyze.

        Returns:
            BatchSavings with estimated metrics. An empty list yields an
            all-zero BatchSavings.
        """
        if not batches:
            return BatchSavings()

        total_batches = len(batches)
        total_tasks = sum(b.batch_size for b in batches)

        # Without batching, each task needs its own container restart; with
        # batching, only the first task in each batch does. Clamp per batch so
        # an empty batch contributes 0 instead of -1, matching
        # _estimate_batch_savings.
        container_reuse = sum(max(b.batch_size - 1, 0) for b in batches)
        time_saved = container_reuse * _CONTAINER_RESTART_OVERHEAD_SECONDS

        return BatchSavings(
            total_batches=total_batches,
            avg_batch_size=round(total_tasks / total_batches, 2),
            estimated_container_reuse=container_reuse,
            estimated_time_saved_seconds=round(time_saved, 2),
        )

    def preview(self, batches: list[TaskBatch]) -> str:
        """Generate a formatted preview of the batching plan.

        Args:
            batches: List of TaskBatch objects to preview.

        Returns:
            Human-readable string summarizing the batches and estimated savings.
            At most five task ids are listed per batch; the rest are
            summarized as a "+N more" line.
        """
        if not batches:
            return "No batches to preview."

        savings = self.estimate_savings(batches)
        lines: list[str] = [
            f"Batch Plan ({self.strategy.value})",
            "=" * 50,
            f"Total batches: {savings.total_batches}",
            f"Average batch size: {savings.avg_batch_size}",
            f"Estimated container reuse: {savings.estimated_container_reuse}",
            f"Estimated time saved: {savings.estimated_time_saved_seconds:.1f}s",
            "",
        ]

        for i, b in enumerate(batches, 1):
            label_parts: list[str] = []
            if b.common_repo:
                label_parts.append(f"repo={b.common_repo}")
            if b.common_image:
                label_parts.append(f"image={b.common_image}")
            label = ", ".join(label_parts) if label_parts else "mixed"

            lines.append(f" Batch {i}: {b.batch_size} tasks ({label})")
            task_ids = [t.get("instance_id", "?") for t in b.tasks[:5]]
            if b.batch_size > 5:
                task_ids.append(f"... +{b.batch_size - 5} more")
            lines.extend(f" - {tid}" for tid in task_ids)

        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _make_batch(self, chunk: list[dict[str, Any]]) -> TaskBatch:
        """Wrap one chunk of tasks in a TaskBatch with a fresh UUID and metadata."""
        return TaskBatch(
            batch_id=str(uuid.uuid4()),
            tasks=chunk,
            common_image=self._common_value(chunk, "image"),
            common_repo=self._common_value(chunk, "repo"),
            batch_size=len(chunk),
            estimated_savings_seconds=self._estimate_batch_savings(len(chunk)),
        )

    def _batch_by_field(self, tasks: list[dict[str, Any]], field_name: str) -> list[TaskBatch]:
        """Group tasks by a shared field, then split into max-sized chunks.

        Tasks that lack the field are collected under a single
        ``"_ungrouped_"`` key. Field values are compared by their ``str()``.

        Args:
            tasks: List of task dictionaries.
            field_name: Key to group tasks by (e.g. "repo", "image", "category").

        Returns:
            List of TaskBatch objects, largest first.
        """
        groups: dict[str, list[dict[str, Any]]] = defaultdict(list)
        for task in tasks:
            groups[str(task.get(field_name, "_ungrouped_"))].append(task)

        # Walk groups in key order so ties in the final size sort are
        # deterministic for a given input.
        batches: list[TaskBatch] = [
            self._make_batch(chunk)
            for _key, group_tasks in sorted(groups.items())
            for chunk in self._split_into_chunks(group_tasks)
        ]

        # Sort largest first for better scheduling
        batches.sort(key=lambda b: b.batch_size, reverse=True)
        return batches

    def _batch_fixed_size(self, tasks: list[dict[str, Any]]) -> list[TaskBatch]:
        """Split tasks into chunks of max_batch_size, preserving input order.

        No explicit sort is performed: every chunk is full except possibly
        the last, so the result is already in descending size order.

        Args:
            tasks: List of task dictionaries.

        Returns:
            List of TaskBatch objects in input order.
        """
        return [self._make_batch(chunk) for chunk in self._split_into_chunks(tasks)]

    def _batch_adaptive(self, tasks: list[dict[str, Any]]) -> list[TaskBatch]:
        """Adaptively group tasks based on multi-field similarity.

        Tasks are grouped by the tuple of their repo, image and category
        values. The more of those fields a group actually has, the larger
        its chunk size: from min_batch_size (no fields present) up to
        max_batch_size (all fields present).

        Args:
            tasks: List of task dictionaries.

        Returns:
            List of TaskBatch objects, largest first.
        """
        fields = self._SIMILARITY_FIELDS
        missing = self._MISSING_FIELD

        # Key on a tuple, not a "|"-joined string: a value containing "|"
        # would otherwise merge unrelated groups and, once split again,
        # inflate the shared-field count past len(fields).
        groups: dict[tuple[str, ...], list[dict[str, Any]]] = defaultdict(list)
        for task in tasks:
            groups[tuple(str(task.get(f, missing)) for f in fields)].append(task)

        size_range = self.max_batch_size - self.min_batch_size
        batches: list[TaskBatch] = []
        # Order groups by the joined key so the visiting order matches the
        # previous string-keyed implementation.
        for key, group_tasks in sorted(groups.items(), key=lambda item: "|".join(item[0])):
            # Count how many fields this group actually has.
            shared_count = sum(1 for part in key if part != missing)

            # Scale batch size based on similarity: more shared fields -> larger batches
            similarity_ratio = shared_count / len(fields)
            adaptive_max = self.min_batch_size + int(size_range * similarity_ratio)
            # Defensive clamp: never leave [min_batch_size, max_batch_size].
            adaptive_max = max(self.min_batch_size, min(adaptive_max, self.max_batch_size))

            batches.extend(
                self._make_batch(chunk)
                for chunk in self._split_into_chunks(group_tasks, max_size=adaptive_max)
            )

        batches.sort(key=lambda b: b.batch_size, reverse=True)
        return batches

    def _split_into_chunks(
        self,
        tasks: list[dict[str, Any]],
        max_size: int | None = None,
    ) -> list[list[dict[str, Any]]]:
        """Split a list of tasks into chunks of at most max_size.

        Args:
            tasks: Tasks to split.
            max_size: Override for maximum chunk size. Defaults to
                self.max_batch_size. Values below 1 are treated as 1.

        Returns:
            List of task sublists, in input order; empty when ``tasks`` is empty.
        """
        size = max_size if max_size is not None else self.max_batch_size
        size = max(size, 1)  # a non-positive step would make range() fail or loop wrongly
        return [tasks[i : i + size] for i in range(0, len(tasks), size)]

    @staticmethod
    def _common_value(tasks: list[dict[str, Any]], field_name: str) -> str | None:
        """Return the single value tasks hold for a field, else None.

        Tasks that lack the field (or hold None for it) are ignored, so a
        value is still returned when only some tasks carry the field, as long
        as those that do all agree.

        Args:
            tasks: List of task dictionaries.
            field_name: Key to check.

        Returns:
            The common value as a string, or None if tasks disagree, no task
            has the field, or ``tasks`` is empty.
        """
        if not tasks:
            return None
        values = {t.get(field_name) for t in tasks}
        values.discard(None)
        if len(values) == 1:
            return str(values.pop())
        return None

    @staticmethod
    def _estimate_batch_savings(batch_size: int) -> float:
        """Estimate time saved for a single batch.

        Each additional task in a batch beyond the first avoids one container
        restart.

        Args:
            batch_size: Number of tasks in the batch.

        Returns:
            Estimated time saved in seconds; 0.0 for batches of one task or fewer.
        """
        if batch_size <= 1:
            return 0.0
        return (batch_size - 1) * _CONTAINER_RESTART_OVERHEAD_SECONDS
|