wafer-cli 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/gpu_run.py CHANGED
@@ -19,10 +19,7 @@ CONTAINER_WORKSPACE = "/workspace"
19
19
  class PushResult:
20
20
  """Result of pushing a directory to remote target."""
21
21
 
22
- workspace_name: str # Just the workspace name (e.g., "project")
23
- workspace_path: (
24
- str # Full absolute path on remote (e.g., "/home/user/.wafer/workspaces/project")
25
- )
22
+ workspace_path: str # Absolute path on remote (tilde-expanded)
26
23
  files_uploaded: list[str] # Relative paths of uploaded files
27
24
 
28
25
 
@@ -74,7 +71,6 @@ def push_directory(
74
71
  files_uploaded.append(str(file.relative_to(local_path)))
75
72
 
76
73
  return PushResult(
77
- workspace_name=workspace_name,
78
74
  workspace_path=expanded_workspace,
79
75
  files_uploaded=files_uploaded,
80
76
  )
wafer/targets.py CHANGED
@@ -257,164 +257,6 @@ def get_default_target() -> str | None:
257
257
  return data.get("default_target")
258
258
 
259
259
 
260
- # ── Pool Management ─────────────────────────────────────────────────────────
261
-
262
-
263
- def get_pool(name: str) -> list[str]:
264
- """Get list of targets in a named pool.
265
-
266
- Pools are defined in ~/.wafer/config.toml:
267
- [pools.my-pool]
268
- targets = ["target-1", "target-2", "target-3"]
269
-
270
- Args:
271
- name: Pool name
272
-
273
- Returns:
274
- List of target names in the pool
275
-
276
- Raises:
277
- FileNotFoundError: If pool doesn't exist
278
- """
279
- if not CONFIG_FILE.exists():
280
- raise FileNotFoundError(f"Pool not found: {name} (no config file)")
281
-
282
- with open(CONFIG_FILE, "rb") as f:
283
- data = tomllib.load(f)
284
-
285
- pools = data.get("pools", {})
286
- if name not in pools:
287
- raise FileNotFoundError(
288
- f"Pool not found: {name}\n"
289
- f" Define pools in ~/.wafer/config.toml:\n"
290
- f" [pools.{name}]\n"
291
- f' targets = ["target-1", "target-2"]'
292
- )
293
-
294
- pool_config = pools[name]
295
- targets = pool_config.get("targets", [])
296
-
297
- if not targets:
298
- raise ValueError(f"Pool '{name}' has no targets defined")
299
-
300
- return targets
301
-
302
-
303
- def list_pools() -> list[str]:
304
- """List all configured pool names.
305
-
306
- Returns:
307
- Sorted list of pool names
308
- """
309
- if not CONFIG_FILE.exists():
310
- return []
311
-
312
- with open(CONFIG_FILE, "rb") as f:
313
- data = tomllib.load(f)
314
-
315
- return sorted(data.get("pools", {}).keys())
316
-
317
-
318
- def save_pool(name: str, targets: list[str]) -> None:
319
- """Save or update a pool configuration.
320
-
321
- Args:
322
- name: Pool name
323
- targets: List of target names (must all exist)
324
-
325
- Raises:
326
- FileNotFoundError: If any target doesn't exist
327
- """
328
- # Verify all targets exist
329
- existing_targets = list_targets()
330
- missing = [t for t in targets if t not in existing_targets]
331
- if missing:
332
- raise FileNotFoundError(f"Targets not found: {', '.join(missing)}")
333
-
334
- _ensure_dirs()
335
-
336
- # Load existing config
337
- if CONFIG_FILE.exists():
338
- with open(CONFIG_FILE, "rb") as f:
339
- data = tomllib.load(f)
340
- else:
341
- data = {}
342
-
343
- # Update pools section
344
- if "pools" not in data:
345
- data["pools"] = {}
346
-
347
- data["pools"][name] = {"targets": targets}
348
-
349
- # Write back - need custom handling for nested structure
350
- _write_config_with_pools(data)
351
-
352
-
353
- def _write_config_with_pools(data: dict) -> None:
354
- """Write config file with pools support.
355
-
356
- Handles the nested [pools.name] TOML structure and preserves
357
- existing nested sections like [default], [api], [environments.*].
358
- """
359
- lines = []
360
-
361
- # Collect nested sections to write after top-level keys
362
- nested_sections: dict[str, dict] = {}
363
-
364
- # Write top-level keys first (except pools and nested dicts)
365
- for key, value in data.items():
366
- if key == "pools":
367
- continue
368
- if value is None:
369
- continue
370
- if isinstance(value, dict):
371
- # Save nested sections for later
372
- nested_sections[key] = value
373
- elif isinstance(value, str):
374
- lines.append(f'{key} = "{value}"')
375
- elif isinstance(value, bool):
376
- lines.append(f"{key} = {str(value).lower()}")
377
- elif isinstance(value, int | float):
378
- lines.append(f"{key} = {value}")
379
- elif isinstance(value, list):
380
- if all(isinstance(v, int) for v in value):
381
- lines.append(f"{key} = {value}")
382
- else:
383
- formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
384
- lines.append(f"{key} = [{formatted}]")
385
-
386
- # Write nested sections (e.g., [default], [api], [environments.foo])
387
- for section_name, section_data in nested_sections.items():
388
- lines.append("")
389
- lines.append(f"[{section_name}]")
390
- for key, value in section_data.items():
391
- if value is None:
392
- continue
393
- if isinstance(value, str):
394
- lines.append(f'{key} = "{value}"')
395
- elif isinstance(value, bool):
396
- lines.append(f"{key} = {str(value).lower()}")
397
- elif isinstance(value, int | float):
398
- lines.append(f"{key} = {value}")
399
- elif isinstance(value, list):
400
- if all(isinstance(v, int) for v in value):
401
- lines.append(f"{key} = {value}")
402
- else:
403
- formatted = ", ".join(f'"{v}"' if isinstance(v, str) else str(v) for v in value)
404
- lines.append(f"{key} = [{formatted}]")
405
-
406
- # Write pools
407
- pools = data.get("pools", {})
408
- for pool_name, pool_config in pools.items():
409
- lines.append("")
410
- lines.append(f"[pools.{pool_name}]")
411
- targets = pool_config.get("targets", [])
412
- formatted = ", ".join(f'"{t}"' for t in targets)
413
- lines.append(f"targets = [{formatted}]")
414
-
415
- CONFIG_FILE.write_text("\n".join(lines) + "\n")
416
-
417
-
418
260
  def set_default_target(name: str) -> None:
419
261
  """Set default target.
420
262
 
wafer/wevin_cli.py CHANGED
@@ -253,7 +253,6 @@ def _build_environment(
253
253
  ) -> Environment:
254
254
  """Build a CodingEnvironment from template config."""
255
255
  from wafer_core.environments.coding import CodingEnvironment
256
- from wafer_core.rollouts.templates import DANGEROUS_BASH_COMMANDS
257
256
 
258
257
  working_dir = Path(corpus_path) if corpus_path else Path.cwd()
259
258
  resolved_tools = tools_override or tpl.tools
@@ -261,7 +260,6 @@ def _build_environment(
261
260
  working_dir=working_dir,
262
261
  enabled_tools=resolved_tools,
263
262
  bash_allowlist=tpl.bash_allowlist,
264
- bash_denylist=DANGEROUS_BASH_COMMANDS,
265
263
  ) # type: ignore[assignment]
266
264
  return env
267
265
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wafer-cli
3
- Version: 0.2.6
3
+ Version: 0.2.8
4
4
  Summary: CLI tool for running commands on remote GPUs and GPU kernel optimization agent
5
5
  Requires-Python: >=3.11
6
6
  Requires-Dist: typer>=0.12.0
@@ -5,31 +5,29 @@ wafer/api_client.py,sha256=cPULiTxqOAYYSfDTNJgd-6Pqrt3IM4Gm9903U7yGIwY,6163
5
5
  wafer/auth.py,sha256=ZLsXZ73GDLD8GL7Rij1ELtuLqyJ5EU_uPBUMPVKwExA,10703
6
6
  wafer/autotuner.py,sha256=6gH0Ho7T58EFerMQcHQxshWe3DF4qU7fb5xthAh5SPM,44364
7
7
  wafer/billing.py,sha256=jbLB2lI4_9f2KD8uEFDi_ixLlowe5hasC0TIZJyIXRg,7163
8
- wafer/cli.py,sha256=rdR84w5ubXO1W2xnrURTrX3AtXy3VeUFVekeGw0djyA,211114
8
+ wafer/cli.py,sha256=QgqaBkCrpnLD6IaY35Eo-JITR5vnMKmHCmnqniW0Yv4,184987
9
9
  wafer/config.py,sha256=h5Eo9_yfWqWGoPNdVQikI9GoZVUeysunSYiixf1mKcw,3411
10
10
  wafer/corpus.py,sha256=yTF3UA5bOa8BII2fmcXf-3WsIsM5DX4etysv0AzVknE,8912
11
- wafer/evaluate.py,sha256=D_0WqIw1HysH7XXiyjxRpv6BUuOoPljN9-1vjPt4xFo,166765
11
+ wafer/evaluate.py,sha256=nIqLQap9-mUtzOWTCJXkZsNydeo36uSTfiD9dGM07aA,130748
12
12
  wafer/global_config.py,sha256=fhaR_RU3ufMksDmOohH1OLeQ0JT0SDW1hEip_zaP75k,11345
13
- wafer/gpu_run.py,sha256=TwqXy72T7f2I7e6n5WWod3xgxCPnDhU0BgLsB4CUoQY,9716
13
+ wafer/gpu_run.py,sha256=gUbzMsMPsw3UHcn00bI-zTSHrU8c2FEpDvbcsczlDPo,9557
14
14
  wafer/inference.py,sha256=tZCO5i05FKY27ewis3CSBHFBeFbXY3xwj0DSjdoMY9s,4314
15
15
  wafer/ncu_analyze.py,sha256=rAWzKQRZEY6E_CL3gAWUaW3uZ4kvQVZskVCPDpsFJuE,24633
16
16
  wafer/nsys_analyze.py,sha256=dRsYNYp1IqzGSPrQuEMW5vRbIxr-VrQwQbotLSrPvlY,6795
17
- wafer/problems.py,sha256=ce2sy10A1nnNUG3VGsseTS8jL7LZsku4dE8zVf9JHQ4,11296
18
17
  wafer/rocprof_compute.py,sha256=Tu16Vb05b2grvheFWi1XLGlAr6m48NEDeZoDyw_4Uzw,19885
19
18
  wafer/rocprof_sdk.py,sha256=fAYCxpfJa5BZTTkIMBOXg4KsYK4i_wNOKrJJn1ZfypM,10086
20
19
  wafer/rocprof_systems.py,sha256=4IWbMcbYk1x_8iS7P3FC_u5sgH6EXADCtR2lV9id80M,18629
21
- wafer/target_lock.py,sha256=QW0NMlu9Paa28O5iSAvGtN11j3kU2di0lADBrfwr2js,5160
22
- wafer/targets.py,sha256=JlLvi18IHtOkgtBdkv_nUrzBweVmFoOQH-9tQW5s1yQ,15250
20
+ wafer/targets.py,sha256=WE5TJgFPGtEIh7VaTQHZ4wB2t4kW0c5K8-UmQ_39Ock,10254
23
21
  wafer/tracelens.py,sha256=g9ZIeFyNojZn4uTd3skPqIrRiL7aMJOz_-GOd3aiyy4,7998
24
- wafer/wevin_cli.py,sha256=1_o2P47namZmPkbt47TnyYDmwhEzQYbSg5zjHffu2JQ,16802
22
+ wafer/wevin_cli.py,sha256=jvj8H9cNf2EXhVnifQzDrz0aR3mzHgCv68CdIkCx6po,16685
25
23
  wafer/workspaces.py,sha256=92LG1mtkzNz-ap3XzcqY6KnQ9SUCFG8VBIOUj1Who64,25757
26
24
  wafer/skills/wafer-guide/SKILL.md,sha256=UfBeIe5GKFzOYcbPmcs8U2nrjbfr-jSMRwg0jQDBfb0,3058
27
25
  wafer/templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
26
  wafer/templates/ask_docs.py,sha256=Lxs-faz9v5m4Qa4NjF2X_lE8KwM9ES9MNJkxo7ep56o,2256
29
27
  wafer/templates/optimize_kernel.py,sha256=u6AL7Q3uttqlnBLzcoFdsiPq5lV2TV3bgqwCYYlK9gk,2357
30
28
  wafer/templates/trace_analyze.py,sha256=XE1VqzVkIUsZbXF8EzQdDYgg-AZEYAOFpr6B_vnRELc,2880
31
- wafer_cli-0.2.6.dist-info/METADATA,sha256=I72dRpc7rN96uFpTsyJkaJ0gS55LA9xpugMUGo2Dayw,559
32
- wafer_cli-0.2.6.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
33
- wafer_cli-0.2.6.dist-info/entry_points.txt,sha256=WqB7hB__WhtPY8y1cO2sZiUz7fCq6Ik-usAigpeFvWE,41
34
- wafer_cli-0.2.6.dist-info/top_level.txt,sha256=2MK1IVMWfpLL8BZCQ3E9aG6L6L666gSA_teYlwan4fs,6
35
- wafer_cli-0.2.6.dist-info/RECORD,,
29
+ wafer_cli-0.2.8.dist-info/METADATA,sha256=tihbS8AP8QoiVqZWjudfFu9iXdijuO1QVxhoQb4lml4,559
30
+ wafer_cli-0.2.8.dist-info/WHEEL,sha256=qELbo2s1Yzl39ZmrAibXA2jjPLUYfnVhUNTlyF1rq0Y,92
31
+ wafer_cli-0.2.8.dist-info/entry_points.txt,sha256=WqB7hB__WhtPY8y1cO2sZiUz7fCq6Ik-usAigpeFvWE,41
32
+ wafer_cli-0.2.8.dist-info/top_level.txt,sha256=2MK1IVMWfpLL8BZCQ3E9aG6L6L666gSA_teYlwan4fs,6
33
+ wafer_cli-0.2.8.dist-info/RECORD,,
wafer/problems.py DELETED
@@ -1,357 +0,0 @@
1
- """Problem set management for Wafer CLI.
2
-
3
- Download and manage kernel optimization problem sets for evaluation.
4
- Follows the same pattern as corpus.py for consistency.
5
- """
6
-
7
- import shutil
8
- import tarfile
9
- import tempfile
10
- from dataclasses import dataclass
11
- from pathlib import Path
12
- from typing import Literal
13
-
14
- import httpx
15
-
16
- PROBLEMS_CACHE_DIR = Path.home() / ".cache" / "wafer" / "problems"
17
-
18
- ProblemSetName = Literal["kernelbench", "gpumode"]
19
-
20
-
21
- @dataclass
22
- class ProblemSetConfig:
23
- """Configuration for a downloadable problem set."""
24
-
25
- name: ProblemSetName
26
- description: str
27
- repo: str # GitHub repo in "owner/repo" format
28
- repo_paths: list[str] # Paths within repo to download
29
- format_description: str # Brief description of the format
30
-
31
-
32
- PROBLEM_SETS: dict[ProblemSetName, ProblemSetConfig] = {
33
- "kernelbench": ProblemSetConfig(
34
- name="kernelbench",
35
- description="KernelBench GPU kernel optimization problems (level1-4)",
36
- repo="ScalingIntelligence/KernelBench",
37
- repo_paths=["KernelBench"],
38
- format_description="Class-based: Model/ModelNew with get_inputs/get_init_inputs",
39
- ),
40
- "gpumode": ProblemSetConfig(
41
- name="gpumode",
42
- description="GPU Mode reference kernels (pmpp, amd, nvidia, bioml)",
43
- repo="gpu-mode/reference-kernels",
44
- repo_paths=["problems"],
45
- format_description="Functional: ref_kernel/custom_kernel with generate_input",
46
- ),
47
- }
48
-
49
-
50
- def _problems_path(name: ProblemSetName) -> Path:
51
- """Get local path for problem set."""
52
- return PROBLEMS_CACHE_DIR / name
53
-
54
-
55
- def _ensure_cache_dir() -> None:
56
- """Ensure cache directory exists."""
57
- PROBLEMS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
58
-
59
-
60
- def _download_github_repo(config: ProblemSetConfig, dest: Path, verbose: bool = True) -> int:
61
- """Download specific paths from GitHub repo.
62
-
63
- Args:
64
- config: Problem set configuration
65
- dest: Destination directory
66
- verbose: Print progress
67
-
68
- Returns:
69
- Number of files downloaded
70
- """
71
- # Fetch tarball from GitHub
72
- resp = _fetch_github_tarball(config.repo, verbose)
73
-
74
- # Save to temp file
75
- with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
76
- tmp.write(resp.content)
77
- tmp_path = Path(tmp.name)
78
-
79
- # Extract matching files
80
- try:
81
- downloaded = _extract_tarball(tmp_path, dest, config.repo_paths, verbose)
82
- finally:
83
- tmp_path.unlink()
84
-
85
- return downloaded
86
-
87
-
88
- def _fetch_github_tarball(repo: str, verbose: bool) -> httpx.Response:
89
- """Fetch tarball from GitHub, trying main then master branch."""
90
- with httpx.Client(timeout=120.0, follow_redirects=True) as client:
91
- for branch in ["main", "master"]:
92
- tarball_url = f"https://api.github.com/repos/{repo}/tarball/{branch}"
93
- if verbose:
94
- print(f" Fetching {repo} ({branch} branch)...")
95
- try:
96
- resp = client.get(tarball_url)
97
- resp.raise_for_status()
98
- return resp
99
- except httpx.HTTPStatusError:
100
- if branch == "master":
101
- raise
102
- raise RuntimeError(f"Failed to fetch tarball from {repo}") # Should not reach
103
-
104
-
105
- def _extract_tarball(tmp_path: Path, dest: Path, repo_paths: list[str], verbose: bool) -> int:
106
- """Extract files from tarball matching repo_paths."""
107
- downloaded = 0
108
- with tarfile.open(tmp_path, "r:gz") as tar:
109
- for member in tar.getmembers():
110
- if not member.isfile():
111
- continue
112
- # Strip the root directory (e.g., "ScalingIntelligence-KernelBench-abc123/")
113
- rel_path = "/".join(member.name.split("/")[1:])
114
- if not _matches_repo_paths(rel_path, repo_paths):
115
- continue
116
- target = dest / rel_path
117
- target.parent.mkdir(parents=True, exist_ok=True)
118
- extracted = tar.extractfile(member)
119
- if extracted:
120
- target.write_bytes(extracted.read())
121
- downloaded += 1
122
- if verbose and downloaded <= 10:
123
- print(f" ✓ {rel_path}")
124
- if verbose and downloaded > 10:
125
- print(f" ... and {downloaded - 10} more files")
126
- return downloaded
127
-
128
-
129
- def _matches_repo_paths(rel_path: str, repo_paths: list[str]) -> bool:
130
- """Check if rel_path starts with any of the repo_paths."""
131
- return any(rel_path.startswith(rp) for rp in repo_paths)
132
-
133
-
134
- def download_problems(name: ProblemSetName, force: bool = False, verbose: bool = True) -> Path:
135
- """Download a problem set to local cache.
136
-
137
- Args:
138
- name: Problem set name
139
- force: Re-download even if exists
140
- verbose: Print progress
141
-
142
- Returns:
143
- Path to downloaded problem set
144
-
145
- Raises:
146
- ValueError: If problem set name is unknown
147
- httpx.HTTPError: If download fails
148
- """
149
- if name not in PROBLEM_SETS:
150
- raise ValueError(f"Unknown problem set: {name}. Available: {list(PROBLEM_SETS.keys())}")
151
-
152
- config = PROBLEM_SETS[name]
153
- dest = _problems_path(name)
154
-
155
- if dest.exists() and not force:
156
- if verbose:
157
- print(f"Problem set '{name}' already exists at {dest}")
158
- print("Use --force to re-download")
159
- return dest
160
-
161
- _ensure_cache_dir()
162
-
163
- if dest.exists():
164
- shutil.rmtree(dest)
165
- dest.mkdir(parents=True)
166
-
167
- if verbose:
168
- print(f"Downloading {name}: {config.description}")
169
-
170
- try:
171
- count = _download_github_repo(config, dest, verbose)
172
- except Exception:
173
- # Clean up partial download so next run doesn't skip with stale cache
174
- if dest.exists():
175
- shutil.rmtree(dest)
176
- raise
177
-
178
- if verbose:
179
- print(f"Downloaded {count} files to {dest}")
180
-
181
- return dest
182
-
183
-
184
- def get_problems_path(name: ProblemSetName) -> Path | None:
185
- """Get path to downloaded problem set, or None if not downloaded.
186
-
187
- Args:
188
- name: Problem set name
189
-
190
- Returns:
191
- Path if downloaded, None otherwise
192
- """
193
- if name not in PROBLEM_SETS:
194
- return None
195
- path = _problems_path(name)
196
- return path if path.exists() else None
197
-
198
-
199
- def list_problem_sets(verbose: bool = True) -> dict[ProblemSetName, bool]:
200
- """List available problem sets and their download status.
201
-
202
- Returns:
203
- Dict of problem set name -> is_downloaded
204
- """
205
- result: dict[ProblemSetName, bool] = {}
206
- for name, config in PROBLEM_SETS.items():
207
- path = _problems_path(name)
208
- exists = path.exists()
209
- result[name] = exists
210
- if verbose:
211
- status = "✓" if exists else " "
212
- print(f"[{status}] {name}: {config.description}")
213
- print(f" Format: {config.format_description}")
214
- if exists:
215
- file_count = sum(1 for _ in path.rglob("*.py") if _.is_file())
216
- print(f" Location: {path} ({file_count} Python files)")
217
- return result
218
-
219
-
220
- def list_problems(name: ProblemSetName, verbose: bool = True) -> list[str]:
221
- """List available problems in a problem set.
222
-
223
- Args:
224
- name: Problem set name
225
- verbose: Print to stdout
226
-
227
- Returns:
228
- List of problem IDs
229
-
230
- Raises:
231
- ValueError: If problem set not downloaded
232
- """
233
- path = get_problems_path(name)
234
- if path is None:
235
- raise ValueError(
236
- f"Problem set '{name}' is not downloaded. Run:\n wafer evaluate {name} download"
237
- )
238
-
239
- if name == "kernelbench":
240
- problems = _list_kernelbench_problems(path)
241
- elif name == "gpumode":
242
- problems = _list_gpumode_problems(path)
243
- else:
244
- problems = []
245
-
246
- if verbose:
247
- if not problems:
248
- print(f"No problems found in {name}")
249
- else:
250
- print(f"Available problems in {name} ({len(problems)} total):\n")
251
- for p in problems:
252
- print(f" {p}")
253
-
254
- return problems
255
-
256
-
257
- def _list_kernelbench_problems(path: Path) -> list[str]:
258
- """List KernelBench problems: level1/1_Name.py format."""
259
- problems: list[str] = []
260
- kb_root = path / "KernelBench"
261
- if not kb_root.exists():
262
- kb_root = path # In case structure is flat
263
-
264
- for level_dir in sorted(kb_root.iterdir()):
265
- if not (level_dir.is_dir() and level_dir.name.startswith("level")):
266
- continue
267
- for problem_file in sorted(level_dir.glob("*.py")):
268
- if problem_file.name.startswith("__"):
269
- continue
270
- problem_id = f"{level_dir.name}/{problem_file.stem}"
271
- problems.append(problem_id)
272
- return problems
273
-
274
-
275
- def _list_gpumode_problems(path: Path) -> list[str]:
276
- """List GPUMode problems: category/problem_name format."""
277
- problems: list[str] = []
278
- problems_root = path / "problems"
279
- if not problems_root.exists():
280
- problems_root = path
281
-
282
- for category_dir in sorted(problems_root.iterdir()):
283
- if not _is_valid_problem_dir(category_dir):
284
- continue
285
- for problem_dir in sorted(category_dir.iterdir()):
286
- if not _is_valid_problem_dir(problem_dir):
287
- continue
288
- # Check if it has the expected files
289
- has_reference = (problem_dir / "reference.py").exists()
290
- has_task = (problem_dir / "task.yml").exists()
291
- if has_reference or has_task:
292
- problem_id = f"{category_dir.name}/{problem_dir.name}"
293
- problems.append(problem_id)
294
- return problems
295
-
296
-
297
- def _is_valid_problem_dir(path: Path) -> bool:
298
- """Check if path is a valid problem directory (not hidden/special)."""
299
- return path.is_dir() and not path.name.startswith((".", "_"))
300
-
301
-
302
- def get_problem_path(name: ProblemSetName, problem_id: str) -> Path | None:
303
- """Get path to a specific problem.
304
-
305
- Args:
306
- name: Problem set name
307
- problem_id: Problem ID (e.g., "level4/103" or "pmpp/vectoradd_py")
308
-
309
- Returns:
310
- Path to problem file/directory, or None if not found
311
- """
312
- base_path = get_problems_path(name)
313
- if base_path is None:
314
- return None
315
-
316
- if name == "kernelbench":
317
- # Parse problem_id like "level4/103" or "level4/103_GroupedQueryAttention"
318
- parts = problem_id.split("/")
319
- if len(parts) != 2:
320
- return None
321
-
322
- level_str, problem_part = parts
323
- if not level_str.startswith("level"):
324
- level_str = f"level{level_str}"
325
-
326
- kb_root = base_path / "KernelBench"
327
- if not kb_root.exists():
328
- kb_root = base_path
329
-
330
- problem_dir = kb_root / level_str
331
- if not problem_dir.exists():
332
- return None
333
-
334
- # Find matching problem file
335
- problem_files = list(problem_dir.glob(f"{problem_part}*.py"))
336
- if not problem_files:
337
- # Try exact match
338
- exact = problem_dir / f"{problem_part}.py"
339
- if exact.exists():
340
- return exact
341
- return None
342
-
343
- return problem_files[0]
344
-
345
- elif name == "gpumode":
346
- # Parse problem_id like "pmpp/vectoradd_py"
347
- problems_root = base_path / "problems"
348
- if not problems_root.exists():
349
- problems_root = base_path
350
-
351
- problem_path = problems_root / problem_id
352
- if problem_path.exists() and problem_path.is_dir():
353
- return problem_path
354
-
355
- return None
356
-
357
- return None