wafer-cli 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/cli.py +105 -859
- wafer/evaluate.py +200 -1097
- wafer/gpu_run.py +1 -5
- wafer/targets.py +0 -158
- wafer/wevin_cli.py +0 -2
- {wafer_cli-0.2.6.dist-info → wafer_cli-0.2.8.dist-info}/METADATA +1 -1
- {wafer_cli-0.2.6.dist-info → wafer_cli-0.2.8.dist-info}/RECORD +10 -12
- wafer/problems.py +0 -357
- wafer/target_lock.py +0 -198
- {wafer_cli-0.2.6.dist-info → wafer_cli-0.2.8.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.6.dist-info → wafer_cli-0.2.8.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.6.dist-info → wafer_cli-0.2.8.dist-info}/top_level.txt +0 -0
wafer/gpu_run.py
CHANGED
|
@@ -19,10 +19,7 @@ CONTAINER_WORKSPACE = "/workspace"
|
|
|
19
19
|
class PushResult:
|
|
20
20
|
"""Result of pushing a directory to remote target."""
|
|
21
21
|
|
|
22
|
-
|
|
23
|
-
workspace_path: (
|
|
24
|
-
str # Full absolute path on remote (e.g., "/home/user/.wafer/workspaces/project")
|
|
25
|
-
)
|
|
22
|
+
workspace_path: str # Absolute path on remote (tilde-expanded)
|
|
26
23
|
files_uploaded: list[str] # Relative paths of uploaded files
|
|
27
24
|
|
|
28
25
|
|
|
@@ -74,7 +71,6 @@ def push_directory(
|
|
|
74
71
|
files_uploaded.append(str(file.relative_to(local_path)))
|
|
75
72
|
|
|
76
73
|
return PushResult(
|
|
77
|
-
workspace_name=workspace_name,
|
|
78
74
|
workspace_path=expanded_workspace,
|
|
79
75
|
files_uploaded=files_uploaded,
|
|
80
76
|
)
|
wafer/targets.py
CHANGED
|
@@ -257,164 +257,6 @@ def get_default_target() -> str | None:
|
|
|
257
257
|
return data.get("default_target")
|
|
258
258
|
|
|
259
259
|
|
|
260
|
-
# ── Pool Management ─────────────────────────────────────────────────────────
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
def get_pool(name: str) -> list[str]:
|
|
264
|
-
"""Get list of targets in a named pool.
|
|
265
|
-
|
|
266
|
-
Pools are defined in ~/.wafer/config.toml:
|
|
267
|
-
[pools.my-pool]
|
|
268
|
-
targets = ["target-1", "target-2", "target-3"]
|
|
269
|
-
|
|
270
|
-
Args:
|
|
271
|
-
name: Pool name
|
|
272
|
-
|
|
273
|
-
Returns:
|
|
274
|
-
List of target names in the pool
|
|
275
|
-
|
|
276
|
-
Raises:
|
|
277
|
-
FileNotFoundError: If pool doesn't exist
|
|
278
|
-
"""
|
|
279
|
-
if not CONFIG_FILE.exists():
|
|
280
|
-
raise FileNotFoundError(f"Pool not found: {name} (no config file)")
|
|
281
|
-
|
|
282
|
-
with open(CONFIG_FILE, "rb") as f:
|
|
283
|
-
data = tomllib.load(f)
|
|
284
|
-
|
|
285
|
-
pools = data.get("pools", {})
|
|
286
|
-
if name not in pools:
|
|
287
|
-
raise FileNotFoundError(
|
|
288
|
-
f"Pool not found: {name}\n"
|
|
289
|
-
f" Define pools in ~/.wafer/config.toml:\n"
|
|
290
|
-
f" [pools.{name}]\n"
|
|
291
|
-
f' targets = ["target-1", "target-2"]'
|
|
292
|
-
)
|
|
293
|
-
|
|
294
|
-
pool_config = pools[name]
|
|
295
|
-
targets = pool_config.get("targets", [])
|
|
296
|
-
|
|
297
|
-
if not targets:
|
|
298
|
-
raise ValueError(f"Pool '{name}' has no targets defined")
|
|
299
|
-
|
|
300
|
-
return targets
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
def list_pools() -> list[str]:
|
|
304
|
-
"""List all configured pool names.
|
|
305
|
-
|
|
306
|
-
Returns:
|
|
307
|
-
Sorted list of pool names
|
|
308
|
-
"""
|
|
309
|
-
if not CONFIG_FILE.exists():
|
|
310
|
-
return []
|
|
311
|
-
|
|
312
|
-
with open(CONFIG_FILE, "rb") as f:
|
|
313
|
-
data = tomllib.load(f)
|
|
314
|
-
|
|
315
|
-
return sorted(data.get("pools", {}).keys())
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
def save_pool(name: str, targets: list[str]) -> None:
|
|
319
|
-
"""Save or update a pool configuration.
|
|
320
|
-
|
|
321
|
-
Args:
|
|
322
|
-
name: Pool name
|
|
323
|
-
targets: List of target names (must all exist)
|
|
324
|
-
|
|
325
|
-
Raises:
|
|
326
|
-
FileNotFoundError: If any target doesn't exist
|
|
327
|
-
"""
|
|
328
|
-
# Verify all targets exist
|
|
329
|
-
existing_targets = list_targets()
|
|
330
|
-
missing = [t for t in targets if t not in existing_targets]
|
|
331
|
-
if missing:
|
|
332
|
-
raise FileNotFoundError(f"Targets not found: {', '.join(missing)}")
|
|
333
|
-
|
|
334
|
-
_ensure_dirs()
|
|
335
|
-
|
|
336
|
-
# Load existing config
|
|
337
|
-
if CONFIG_FILE.exists():
|
|
338
|
-
with open(CONFIG_FILE, "rb") as f:
|
|
339
|
-
data = tomllib.load(f)
|
|
340
|
-
else:
|
|
341
|
-
data = {}
|
|
342
|
-
|
|
343
|
-
# Update pools section
|
|
344
|
-
if "pools" not in data:
|
|
345
|
-
data["pools"] = {}
|
|
346
|
-
|
|
347
|
-
data["pools"][name] = {"targets": targets}
|
|
348
|
-
|
|
349
|
-
# Write back - need custom handling for nested structure
|
|
350
|
-
_write_config_with_pools(data)
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
def _write_config_with_pools(data: dict) -> None:
|
|
354
|
-
"""Write config file with pools support.
|
|
355
|
-
|
|
356
|
-
Handles the nested [pools.name] TOML structure and preserves
|
|
357
|
-
existing nested sections like [default], [api], [environments.*].
|
|
358
|
-
"""
|
|
359
|
-
lines = []
|
|
360
|
-
|
|
361
|
-
# Collect nested sections to write after top-level keys
|
|
362
|
-
nested_sections: dict[str, dict] = {}
|
|
363
|
-
|
|
364
|
-
# Write top-level keys first (except pools and nested dicts)
|
|
365
|
-
for key, value in data.items():
|
|
366
|
-
if key == "pools":
|
|
367
|
-
continue
|
|
368
|
-
if value is None:
|
|
369
|
-
continue
|
|
370
|
-
if isinstance(value, dict):
|
|
371
|
-
# Save nested sections for later
|
|
372
|
-
nested_sections[key] = value
|
|
373
|
-
elif isinstance(value, str):
|
|
374
|
-
lines.append(f'{key} = "{value}"')
|
|
375
|
-
elif isinstance(value, bool):
|
|
376
|
-
lines.append(f"{key} = {str(value).lower()}")
|
|
377
|
-
elif isinstance(value, int | float):
|
|
378
|
-
lines.append(f"{key} = {value}")
|
|
379
|
-
elif isinstance(value, list):
|
|
380
|
-
if all(isinstance(v, int) for v in value):
|
|
381
|
-
lines.append(f"{key} = {value}")
|
|
382
|
-
else:
|
|
383
|
-
formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
|
|
384
|
-
lines.append(f"{key} = [{formatted}]")
|
|
385
|
-
|
|
386
|
-
# Write nested sections (e.g., [default], [api], [environments.foo])
|
|
387
|
-
for section_name, section_data in nested_sections.items():
|
|
388
|
-
lines.append("")
|
|
389
|
-
lines.append(f"[{section_name}]")
|
|
390
|
-
for key, value in section_data.items():
|
|
391
|
-
if value is None:
|
|
392
|
-
continue
|
|
393
|
-
if isinstance(value, str):
|
|
394
|
-
lines.append(f'{key} = "{value}"')
|
|
395
|
-
elif isinstance(value, bool):
|
|
396
|
-
lines.append(f"{key} = {str(value).lower()}")
|
|
397
|
-
elif isinstance(value, int | float):
|
|
398
|
-
lines.append(f"{key} = {value}")
|
|
399
|
-
elif isinstance(value, list):
|
|
400
|
-
if all(isinstance(v, int) for v in value):
|
|
401
|
-
lines.append(f"{key} = {value}")
|
|
402
|
-
else:
|
|
403
|
-
formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
|
|
404
|
-
lines.append(f"{key} = [{formatted}]")
|
|
405
|
-
|
|
406
|
-
# Write pools
|
|
407
|
-
pools = data.get("pools", {})
|
|
408
|
-
for pool_name, pool_config in pools.items():
|
|
409
|
-
lines.append("")
|
|
410
|
-
lines.append(f"[pools.{pool_name}]")
|
|
411
|
-
targets = pool_config.get("targets", [])
|
|
412
|
-
formatted = ", ".join(f'"{t}"' for t in targets)
|
|
413
|
-
lines.append(f"targets = [{formatted}]")
|
|
414
|
-
|
|
415
|
-
CONFIG_FILE.write_text("\n".join(lines) + "\n")
|
|
416
|
-
|
|
417
|
-
|
|
418
260
|
def set_default_target(name: str) -> None:
|
|
419
261
|
"""Set default target.
|
|
420
262
|
|
wafer/wevin_cli.py
CHANGED
|
@@ -253,7 +253,6 @@ def _build_environment(
|
|
|
253
253
|
) -> Environment:
|
|
254
254
|
"""Build a CodingEnvironment from template config."""
|
|
255
255
|
from wafer_core.environments.coding import CodingEnvironment
|
|
256
|
-
from wafer_core.rollouts.templates import DANGEROUS_BASH_COMMANDS
|
|
257
256
|
|
|
258
257
|
working_dir = Path(corpus_path) if corpus_path else Path.cwd()
|
|
259
258
|
resolved_tools = tools_override or tpl.tools
|
|
@@ -261,7 +260,6 @@ def _build_environment(
|
|
|
261
260
|
working_dir=working_dir,
|
|
262
261
|
enabled_tools=resolved_tools,
|
|
263
262
|
bash_allowlist=tpl.bash_allowlist,
|
|
264
|
-
bash_denylist=DANGEROUS_BASH_COMMANDS,
|
|
265
263
|
) # type: ignore[assignment]
|
|
266
264
|
return env
|
|
267
265
|
|
|
@@ -5,31 +5,29 @@ wafer/api_client.py,sha256=cPULiTxqOAYYSfDTNJgd-6Pqrt3IM4Gm9903U7yGIwY,6163
|
|
|
5
5
|
wafer/auth.py,sha256=ZLsXZ73GDLD8GL7Rij1ELtuLqyJ5EU_uPBUMPVKwExA,10703
|
|
6
6
|
wafer/autotuner.py,sha256=6gH0Ho7T58EFerMQcHQxshWe3DF4qU7fb5xthAh5SPM,44364
|
|
7
7
|
wafer/billing.py,sha256=jbLB2lI4_9f2KD8uEFDi_ixLlowe5hasC0TIZJyIXRg,7163
|
|
8
|
-
wafer/cli.py,sha256=
|
|
8
|
+
wafer/cli.py,sha256=QgqaBkCrpnLD6IaY35Eo-JITR5vnMKmHCmnqniW0Yv4,184987
|
|
9
9
|
wafer/config.py,sha256=h5Eo9_yfWqWGoPNdVQikI9GoZVUeysunSYiixf1mKcw,3411
|
|
10
10
|
wafer/corpus.py,sha256=yTF3UA5bOa8BII2fmcXf-3WsIsM5DX4etysv0AzVknE,8912
|
|
11
|
-
wafer/evaluate.py,sha256=
|
|
11
|
+
wafer/evaluate.py,sha256=nIqLQap9-mUtzOWTCJXkZsNydeo36uSTfiD9dGM07aA,130748
|
|
12
12
|
wafer/global_config.py,sha256=fhaR_RU3ufMksDmOohH1OLeQ0JT0SDW1hEip_zaP75k,11345
|
|
13
|
-
wafer/gpu_run.py,sha256=
|
|
13
|
+
wafer/gpu_run.py,sha256=gUbzMsMPsw3UHcn00bI-zTSHrU8c2FEpDvbcsczlDPo,9557
|
|
14
14
|
wafer/inference.py,sha256=tZCO5i05FKY27ewis3CSBHFBeFbXY3xwj0DSjdoMY9s,4314
|
|
15
15
|
wafer/ncu_analyze.py,sha256=rAWzKQRZEY6E_CL3gAWUaW3uZ4kvQVZskVCPDpsFJuE,24633
|
|
16
16
|
wafer/nsys_analyze.py,sha256=dRsYNYp1IqzGSPrQuEMW5vRbIxr-VrQwQbotLSrPvlY,6795
|
|
17
|
-
wafer/problems.py,sha256=ce2sy10A1nnNUG3VGsseTS8jL7LZsku4dE8zVf9JHQ4,11296
|
|
18
17
|
wafer/rocprof_compute.py,sha256=Tu16Vb05b2grvheFWi1XLGlAr6m48NEDeZoDyw_4Uzw,19885
|
|
19
18
|
wafer/rocprof_sdk.py,sha256=fAYCxpfJa5BZTTkIMBOXg4KsYK4i_wNOKrJJn1ZfypM,10086
|
|
20
19
|
wafer/rocprof_systems.py,sha256=4IWbMcbYk1x_8iS7P3FC_u5sgH6EXADCtR2lV9id80M,18629
|
|
21
|
-
wafer/
|
|
22
|
-
wafer/targets.py,sha256=JlLvi18IHtOkgtBdkv_nUrzBweVmFoOQH-9tQW5s1yQ,15250
|
|
20
|
+
wafer/targets.py,sha256=WE5TJgFPGtEIh7VaTQHZ4wB2t4kW0c5K8-UmQ_39Ock,10254
|
|
23
21
|
wafer/tracelens.py,sha256=g9ZIeFyNojZn4uTd3skPqIrRiL7aMJOz_-GOd3aiyy4,7998
|
|
24
|
-
wafer/wevin_cli.py,sha256=
|
|
22
|
+
wafer/wevin_cli.py,sha256=jvj8H9cNf2EXhVnifQzDrz0aR3mzHgCv68CdIkCx6po,16685
|
|
25
23
|
wafer/workspaces.py,sha256=92LG1mtkzNz-ap3XzcqY6KnQ9SUCFG8VBIOUj1Who64,25757
|
|
26
24
|
wafer/skills/wafer-guide/SKILL.md,sha256=UfBeIe5GKFzOYcbPmcs8U2nrjbfr-jSMRwg0jQDBfb0,3058
|
|
27
25
|
wafer/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
28
26
|
wafer/templates/ask_docs.py,sha256=Lxs-faz9v5m4Qa4NjF2X_lE8KwM9ES9MNJkxo7ep56o,2256
|
|
29
27
|
wafer/templates/optimize_kernel.py,sha256=u6AL7Q3uttqlnBLzcoFdsiPq5lV2TV3bgqwCYYlK9gk,2357
|
|
30
28
|
wafer/templates/trace_analyze.py,sha256=XE1VqzVkIUsZbXF8EzQdDYgg-AZEYAOFpr6B_vnRELc,2880
|
|
31
|
-
wafer_cli-0.2.
|
|
32
|
-
wafer_cli-0.2.
|
|
33
|
-
wafer_cli-0.2.
|
|
34
|
-
wafer_cli-0.2.
|
|
35
|
-
wafer_cli-0.2.
|
|
29
|
+
wafer_cli-0.2.8.dist-info/METADATA,sha256=tihbS8AP8QoiVqZWjudfFu9iXdijuO1QVxhoQb4lml4,559
|
|
30
|
+
wafer_cli-0.2.8.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
|
|
31
|
+
wafer_cli-0.2.8.dist-info/entry_points.txt,sha256=WqB7hB__WhtPY8y1cO2sZiUz7fCq6Ik-usAigpeFvWE,41
|
|
32
|
+
wafer_cli-0.2.8.dist-info/top_level.txt,sha256=2MK1IVMWfpLL8BZCQ3E9aG6L6L666gSA_teYlwan4fs,6
|
|
33
|
+
wafer_cli-0.2.8.dist-info/RECORD,,
|
wafer/problems.py
DELETED
|
@@ -1,357 +0,0 @@
|
|
|
1
|
-
"""Problem set management for Wafer CLI.
|
|
2
|
-
|
|
3
|
-
Download and manage kernel optimization problem sets for evaluation.
|
|
4
|
-
Follows the same pattern as corpus.py for consistency.
|
|
5
|
-
"""
|
|
6
|
-
|
|
7
|
-
import shutil
|
|
8
|
-
import tarfile
|
|
9
|
-
import tempfile
|
|
10
|
-
from dataclasses import dataclass
|
|
11
|
-
from pathlib import Path
|
|
12
|
-
from typing import Literal
|
|
13
|
-
|
|
14
|
-
import httpx
|
|
15
|
-
|
|
16
|
-
PROBLEMS_CACHE_DIR = Path.home() / ".cache" / "wafer" / "problems"
|
|
17
|
-
|
|
18
|
-
ProblemSetName = Literal["kernelbench", "gpumode"]
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
@dataclass
|
|
22
|
-
class ProblemSetConfig:
|
|
23
|
-
"""Configuration for a downloadable problem set."""
|
|
24
|
-
|
|
25
|
-
name: ProblemSetName
|
|
26
|
-
description: str
|
|
27
|
-
repo: str # GitHub repo in "owner/repo" format
|
|
28
|
-
repo_paths: list[str] # Paths within repo to download
|
|
29
|
-
format_description: str # Brief description of the format
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
PROBLEM_SETS: dict[ProblemSetName, ProblemSetConfig] = {
|
|
33
|
-
"kernelbench": ProblemSetConfig(
|
|
34
|
-
name="kernelbench",
|
|
35
|
-
description="KernelBench GPU kernel optimization problems (level1-4)",
|
|
36
|
-
repo="ScalingIntelligence/KernelBench",
|
|
37
|
-
repo_paths=["KernelBench"],
|
|
38
|
-
format_description="Class-based: Model/ModelNew with get_inputs/get_init_inputs",
|
|
39
|
-
),
|
|
40
|
-
"gpumode": ProblemSetConfig(
|
|
41
|
-
name="gpumode",
|
|
42
|
-
description="GPU Mode reference kernels (pmpp, amd, nvidia, bioml)",
|
|
43
|
-
repo="gpu-mode/reference-kernels",
|
|
44
|
-
repo_paths=["problems"],
|
|
45
|
-
format_description="Functional: ref_kernel/custom_kernel with generate_input",
|
|
46
|
-
),
|
|
47
|
-
}
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def _problems_path(name: ProblemSetName) -> Path:
|
|
51
|
-
"""Get local path for problem set."""
|
|
52
|
-
return PROBLEMS_CACHE_DIR / name
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
def _ensure_cache_dir() -> None:
|
|
56
|
-
"""Ensure cache directory exists."""
|
|
57
|
-
PROBLEMS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
def _download_github_repo(config: ProblemSetConfig, dest: Path, verbose: bool = True) -> int:
|
|
61
|
-
"""Download specific paths from GitHub repo.
|
|
62
|
-
|
|
63
|
-
Args:
|
|
64
|
-
config: Problem set configuration
|
|
65
|
-
dest: Destination directory
|
|
66
|
-
verbose: Print progress
|
|
67
|
-
|
|
68
|
-
Returns:
|
|
69
|
-
Number of files downloaded
|
|
70
|
-
"""
|
|
71
|
-
# Fetch tarball from GitHub
|
|
72
|
-
resp = _fetch_github_tarball(config.repo, verbose)
|
|
73
|
-
|
|
74
|
-
# Save to temp file
|
|
75
|
-
with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
|
|
76
|
-
tmp.write(resp.content)
|
|
77
|
-
tmp_path = Path(tmp.name)
|
|
78
|
-
|
|
79
|
-
# Extract matching files
|
|
80
|
-
try:
|
|
81
|
-
downloaded = _extract_tarball(tmp_path, dest, config.repo_paths, verbose)
|
|
82
|
-
finally:
|
|
83
|
-
tmp_path.unlink()
|
|
84
|
-
|
|
85
|
-
return downloaded
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
def _fetch_github_tarball(repo: str, verbose: bool) -> httpx.Response:
|
|
89
|
-
"""Fetch tarball from GitHub, trying main then master branch."""
|
|
90
|
-
with httpx.Client(timeout=120.0, follow_redirects=True) as client:
|
|
91
|
-
for branch in ["main", "master"]:
|
|
92
|
-
tarball_url = f"https://api.github.com/repos/{repo}/tarball/{branch}"
|
|
93
|
-
if verbose:
|
|
94
|
-
print(f" Fetching {repo} ({branch} branch)...")
|
|
95
|
-
try:
|
|
96
|
-
resp = client.get(tarball_url)
|
|
97
|
-
resp.raise_for_status()
|
|
98
|
-
return resp
|
|
99
|
-
except httpx.HTTPStatusError:
|
|
100
|
-
if branch == "master":
|
|
101
|
-
raise
|
|
102
|
-
raise RuntimeError(f"Failed to fetch tarball from {repo}") # Should not reach
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
def _extract_tarball(tmp_path: Path, dest: Path, repo_paths: list[str], verbose: bool) -> int:
|
|
106
|
-
"""Extract files from tarball matching repo_paths."""
|
|
107
|
-
downloaded = 0
|
|
108
|
-
with tarfile.open(tmp_path, "r:gz") as tar:
|
|
109
|
-
for member in tar.getmembers():
|
|
110
|
-
if not member.isfile():
|
|
111
|
-
continue
|
|
112
|
-
# Strip the root directory (e.g., "ScalingIntelligence-KernelBench-abc123/")
|
|
113
|
-
rel_path = "/".join(member.name.split("/")[1:])
|
|
114
|
-
if not _matches_repo_paths(rel_path, repo_paths):
|
|
115
|
-
continue
|
|
116
|
-
target = dest / rel_path
|
|
117
|
-
target.parent.mkdir(parents=True, exist_ok=True)
|
|
118
|
-
extracted = tar.extractfile(member)
|
|
119
|
-
if extracted:
|
|
120
|
-
target.write_bytes(extracted.read())
|
|
121
|
-
downloaded += 1
|
|
122
|
-
if verbose and downloaded <= 10:
|
|
123
|
-
print(f" ✓ {rel_path}")
|
|
124
|
-
if verbose and downloaded > 10:
|
|
125
|
-
print(f" ... and {downloaded - 10} more files")
|
|
126
|
-
return downloaded
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
def _matches_repo_paths(rel_path: str, repo_paths: list[str]) -> bool:
|
|
130
|
-
"""Check if rel_path starts with any of the repo_paths."""
|
|
131
|
-
return any(rel_path.startswith(rp) for rp in repo_paths)
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
def download_problems(name: ProblemSetName, force: bool = False, verbose: bool = True) -> Path:
|
|
135
|
-
"""Download a problem set to local cache.
|
|
136
|
-
|
|
137
|
-
Args:
|
|
138
|
-
name: Problem set name
|
|
139
|
-
force: Re-download even if exists
|
|
140
|
-
verbose: Print progress
|
|
141
|
-
|
|
142
|
-
Returns:
|
|
143
|
-
Path to downloaded problem set
|
|
144
|
-
|
|
145
|
-
Raises:
|
|
146
|
-
ValueError: If problem set name is unknown
|
|
147
|
-
httpx.HTTPError: If download fails
|
|
148
|
-
"""
|
|
149
|
-
if name not in PROBLEM_SETS:
|
|
150
|
-
raise ValueError(f"Unknown problem set: {name}. Available: {list(PROBLEM_SETS.keys())}")
|
|
151
|
-
|
|
152
|
-
config = PROBLEM_SETS[name]
|
|
153
|
-
dest = _problems_path(name)
|
|
154
|
-
|
|
155
|
-
if dest.exists() and not force:
|
|
156
|
-
if verbose:
|
|
157
|
-
print(f"Problem set '{name}' already exists at {dest}")
|
|
158
|
-
print("Use --force to re-download")
|
|
159
|
-
return dest
|
|
160
|
-
|
|
161
|
-
_ensure_cache_dir()
|
|
162
|
-
|
|
163
|
-
if dest.exists():
|
|
164
|
-
shutil.rmtree(dest)
|
|
165
|
-
dest.mkdir(parents=True)
|
|
166
|
-
|
|
167
|
-
if verbose:
|
|
168
|
-
print(f"Downloading {name}: {config.description}")
|
|
169
|
-
|
|
170
|
-
try:
|
|
171
|
-
count = _download_github_repo(config, dest, verbose)
|
|
172
|
-
except Exception:
|
|
173
|
-
# Clean up partial download so next run doesn't skip with stale cache
|
|
174
|
-
if dest.exists():
|
|
175
|
-
shutil.rmtree(dest)
|
|
176
|
-
raise
|
|
177
|
-
|
|
178
|
-
if verbose:
|
|
179
|
-
print(f"Downloaded {count} files to {dest}")
|
|
180
|
-
|
|
181
|
-
return dest
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
def get_problems_path(name: ProblemSetName) -> Path | None:
|
|
185
|
-
"""Get path to downloaded problem set, or None if not downloaded.
|
|
186
|
-
|
|
187
|
-
Args:
|
|
188
|
-
name: Problem set name
|
|
189
|
-
|
|
190
|
-
Returns:
|
|
191
|
-
Path if downloaded, None otherwise
|
|
192
|
-
"""
|
|
193
|
-
if name not in PROBLEM_SETS:
|
|
194
|
-
return None
|
|
195
|
-
path = _problems_path(name)
|
|
196
|
-
return path if path.exists() else None
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
def list_problem_sets(verbose: bool = True) -> dict[ProblemSetName, bool]:
|
|
200
|
-
"""List available problem sets and their download status.
|
|
201
|
-
|
|
202
|
-
Returns:
|
|
203
|
-
Dict of problem set name -> is_downloaded
|
|
204
|
-
"""
|
|
205
|
-
result: dict[ProblemSetName, bool] = {}
|
|
206
|
-
for name, config in PROBLEM_SETS.items():
|
|
207
|
-
path = _problems_path(name)
|
|
208
|
-
exists = path.exists()
|
|
209
|
-
result[name] = exists
|
|
210
|
-
if verbose:
|
|
211
|
-
status = "✓" if exists else " "
|
|
212
|
-
print(f"[{status}] {name}: {config.description}")
|
|
213
|
-
print(f" Format: {config.format_description}")
|
|
214
|
-
if exists:
|
|
215
|
-
file_count = sum(1 for _ in path.rglob("*.py") if _.is_file())
|
|
216
|
-
print(f" Location: {path} ({file_count} Python files)")
|
|
217
|
-
return result
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
def list_problems(name: ProblemSetName, verbose: bool = True) -> list[str]:
|
|
221
|
-
"""List available problems in a problem set.
|
|
222
|
-
|
|
223
|
-
Args:
|
|
224
|
-
name: Problem set name
|
|
225
|
-
verbose: Print to stdout
|
|
226
|
-
|
|
227
|
-
Returns:
|
|
228
|
-
List of problem IDs
|
|
229
|
-
|
|
230
|
-
Raises:
|
|
231
|
-
ValueError: If problem set not downloaded
|
|
232
|
-
"""
|
|
233
|
-
path = get_problems_path(name)
|
|
234
|
-
if path is None:
|
|
235
|
-
raise ValueError(
|
|
236
|
-
f"Problem set '{name}' is not downloaded. Run:\n wafer evaluate {name} download"
|
|
237
|
-
)
|
|
238
|
-
|
|
239
|
-
if name == "kernelbench":
|
|
240
|
-
problems = _list_kernelbench_problems(path)
|
|
241
|
-
elif name == "gpumode":
|
|
242
|
-
problems = _list_gpumode_problems(path)
|
|
243
|
-
else:
|
|
244
|
-
problems = []
|
|
245
|
-
|
|
246
|
-
if verbose:
|
|
247
|
-
if not problems:
|
|
248
|
-
print(f"No problems found in {name}")
|
|
249
|
-
else:
|
|
250
|
-
print(f"Available problems in {name} ({len(problems)} total):\n")
|
|
251
|
-
for p in problems:
|
|
252
|
-
print(f" {p}")
|
|
253
|
-
|
|
254
|
-
return problems
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
def _list_kernelbench_problems(path: Path) -> list[str]:
|
|
258
|
-
"""List KernelBench problems: level1/1_Name.py format."""
|
|
259
|
-
problems: list[str] = []
|
|
260
|
-
kb_root = path / "KernelBench"
|
|
261
|
-
if not kb_root.exists():
|
|
262
|
-
kb_root = path # In case structure is flat
|
|
263
|
-
|
|
264
|
-
for level_dir in sorted(kb_root.iterdir()):
|
|
265
|
-
if not (level_dir.is_dir() and level_dir.name.startswith("level")):
|
|
266
|
-
continue
|
|
267
|
-
for problem_file in sorted(level_dir.glob("*.py")):
|
|
268
|
-
if problem_file.name.startswith("__"):
|
|
269
|
-
continue
|
|
270
|
-
problem_id = f"{level_dir.name}/{problem_file.stem}"
|
|
271
|
-
problems.append(problem_id)
|
|
272
|
-
return problems
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
def _list_gpumode_problems(path: Path) -> list[str]:
|
|
276
|
-
"""List GPUMode problems: category/problem_name format."""
|
|
277
|
-
problems: list[str] = []
|
|
278
|
-
problems_root = path / "problems"
|
|
279
|
-
if not problems_root.exists():
|
|
280
|
-
problems_root = path
|
|
281
|
-
|
|
282
|
-
for category_dir in sorted(problems_root.iterdir()):
|
|
283
|
-
if not _is_valid_problem_dir(category_dir):
|
|
284
|
-
continue
|
|
285
|
-
for problem_dir in sorted(category_dir.iterdir()):
|
|
286
|
-
if not _is_valid_problem_dir(problem_dir):
|
|
287
|
-
continue
|
|
288
|
-
# Check if it has the expected files
|
|
289
|
-
has_reference = (problem_dir / "reference.py").exists()
|
|
290
|
-
has_task = (problem_dir / "task.yml").exists()
|
|
291
|
-
if has_reference or has_task:
|
|
292
|
-
problem_id = f"{category_dir.name}/{problem_dir.name}"
|
|
293
|
-
problems.append(problem_id)
|
|
294
|
-
return problems
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
def _is_valid_problem_dir(path: Path) -> bool:
|
|
298
|
-
"""Check if path is a valid problem directory (not hidden/special)."""
|
|
299
|
-
return path.is_dir() and not path.name.startswith((".", "_"))
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
def get_problem_path(name: ProblemSetName, problem_id: str) -> Path | None:
|
|
303
|
-
"""Get path to a specific problem.
|
|
304
|
-
|
|
305
|
-
Args:
|
|
306
|
-
name: Problem set name
|
|
307
|
-
problem_id: Problem ID (e.g., "level4/103" or "pmpp/vectoradd_py")
|
|
308
|
-
|
|
309
|
-
Returns:
|
|
310
|
-
Path to problem file/directory, or None if not found
|
|
311
|
-
"""
|
|
312
|
-
base_path = get_problems_path(name)
|
|
313
|
-
if base_path is None:
|
|
314
|
-
return None
|
|
315
|
-
|
|
316
|
-
if name == "kernelbench":
|
|
317
|
-
# Parse problem_id like "level4/103" or "level4/103_GroupedQueryAttention"
|
|
318
|
-
parts = problem_id.split("/")
|
|
319
|
-
if len(parts) != 2:
|
|
320
|
-
return None
|
|
321
|
-
|
|
322
|
-
level_str, problem_part = parts
|
|
323
|
-
if not level_str.startswith("level"):
|
|
324
|
-
level_str = f"level{level_str}"
|
|
325
|
-
|
|
326
|
-
kb_root = base_path / "KernelBench"
|
|
327
|
-
if not kb_root.exists():
|
|
328
|
-
kb_root = base_path
|
|
329
|
-
|
|
330
|
-
problem_dir = kb_root / level_str
|
|
331
|
-
if not problem_dir.exists():
|
|
332
|
-
return None
|
|
333
|
-
|
|
334
|
-
# Find matching problem file
|
|
335
|
-
problem_files = list(problem_dir.glob(f"{problem_part}*.py"))
|
|
336
|
-
if not problem_files:
|
|
337
|
-
# Try exact match
|
|
338
|
-
exact = problem_dir / f"{problem_part}.py"
|
|
339
|
-
if exact.exists():
|
|
340
|
-
return exact
|
|
341
|
-
return None
|
|
342
|
-
|
|
343
|
-
return problem_files[0]
|
|
344
|
-
|
|
345
|
-
elif name == "gpumode":
|
|
346
|
-
# Parse problem_id like "pmpp/vectoradd_py"
|
|
347
|
-
problems_root = base_path / "problems"
|
|
348
|
-
if not problems_root.exists():
|
|
349
|
-
problems_root = base_path
|
|
350
|
-
|
|
351
|
-
problem_path = problems_root / problem_id
|
|
352
|
-
if problem_path.exists() and problem_path.is_dir():
|
|
353
|
-
return problem_path
|
|
354
|
-
|
|
355
|
-
return None
|
|
356
|
-
|
|
357
|
-
return None
|