wafer-cli 0.2.8__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
wafer/problems.py ADDED
@@ -0,0 +1,357 @@
+ """Problem set management for Wafer CLI.
+ 
+ Download and manage kernel optimization problem sets for evaluation.
+ Follows the same pattern as corpus.py for consistency.
+ """
+ 
+ import shutil
+ import tarfile
+ import tempfile
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Literal
+ 
+ import httpx
+ 
+ PROBLEMS_CACHE_DIR = Path.home() / ".cache" / "wafer" / "problems"
+ 
+ ProblemSetName = Literal["kernelbench", "gpumode"]
+ 
+ 
+ @dataclass
+ class ProblemSetConfig:
+     """Configuration for a downloadable problem set."""
+ 
+     name: ProblemSetName
+     description: str
+     repo: str  # GitHub repo in "owner/repo" format
+     repo_paths: list[str]  # Paths within repo to download
+     format_description: str  # Brief description of the format
+ 
+ 
+ PROBLEM_SETS: dict[ProblemSetName, ProblemSetConfig] = {
+     "kernelbench": ProblemSetConfig(
+         name="kernelbench",
+         description="KernelBench GPU kernel optimization problems (level1-4)",
+         repo="ScalingIntelligence/KernelBench",
+         repo_paths=["KernelBench"],
+         format_description="Class-based: Model/ModelNew with get_inputs/get_init_inputs",
+     ),
+     "gpumode": ProblemSetConfig(
+         name="gpumode",
+         description="GPU Mode reference kernels (pmpp, amd, nvidia, bioml)",
+         repo="gpu-mode/reference-kernels",
+         repo_paths=["problems"],
+         format_description="Functional: ref_kernel/custom_kernel with generate_input",
+     ),
+ }
+ 
+ 
+ def _problems_path(name: ProblemSetName) -> Path:
+     """Get local path for problem set."""
+     return PROBLEMS_CACHE_DIR / name
+ 
+ 
+ def _ensure_cache_dir() -> None:
+     """Ensure cache directory exists."""
+     PROBLEMS_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+ 
+ 
+ def _download_github_repo(config: ProblemSetConfig, dest: Path, verbose: bool = True) -> int:
+     """Download specific paths from GitHub repo.
+ 
+     Args:
+         config: Problem set configuration
+         dest: Destination directory
+         verbose: Print progress
+ 
+     Returns:
+         Number of files downloaded
+     """
+     # Fetch tarball from GitHub
+     resp = _fetch_github_tarball(config.repo, verbose)
+ 
+     # Save to temp file
+     with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
+         tmp.write(resp.content)
+         tmp_path = Path(tmp.name)
+ 
+     # Extract matching files
+     try:
+         downloaded = _extract_tarball(tmp_path, dest, config.repo_paths, verbose)
+     finally:
+         tmp_path.unlink()
+ 
+     return downloaded
+ 
+ 
+ def _fetch_github_tarball(repo: str, verbose: bool) -> httpx.Response:
+     """Fetch tarball from GitHub, trying main then master branch."""
+     with httpx.Client(timeout=120.0, follow_redirects=True) as client:
+         for branch in ["main", "master"]:
+             tarball_url = f"https://api.github.com/repos/{repo}/tarball/{branch}"
+             if verbose:
+                 print(f" Fetching {repo} ({branch} branch)...")
+             try:
+                 resp = client.get(tarball_url)
+                 resp.raise_for_status()
+                 return resp
+             except httpx.HTTPStatusError:
+                 if branch == "master":
+                     raise
+     raise RuntimeError(f"Failed to fetch tarball from {repo}")  # Should not reach
+ 
+ 
+ def _extract_tarball(tmp_path: Path, dest: Path, repo_paths: list[str], verbose: bool) -> int:
+     """Extract files from tarball matching repo_paths."""
+     downloaded = 0
+     with tarfile.open(tmp_path, "r:gz") as tar:
+         for member in tar.getmembers():
+             if not member.isfile():
+                 continue
+             # Strip the root directory (e.g., "ScalingIntelligence-KernelBench-abc123/")
+             rel_path = "/".join(member.name.split("/")[1:])
+             if not _matches_repo_paths(rel_path, repo_paths):
+                 continue
+             target = dest / rel_path
+             target.parent.mkdir(parents=True, exist_ok=True)
+             extracted = tar.extractfile(member)
+             if extracted:
+                 target.write_bytes(extracted.read())
+                 downloaded += 1
+                 if verbose and downloaded <= 10:
+                     print(f" ✓ {rel_path}")
+     if verbose and downloaded > 10:
+         print(f" ... and {downloaded - 10} more files")
+     return downloaded
+ 
+ 
+ def _matches_repo_paths(rel_path: str, repo_paths: list[str]) -> bool:
+     """Check if rel_path starts with any of the repo_paths."""
+     return any(rel_path.startswith(rp) for rp in repo_paths)
+ 
+ 
+ def download_problems(name: ProblemSetName, force: bool = False, verbose: bool = True) -> Path:
+     """Download a problem set to local cache.
+ 
+     Args:
+         name: Problem set name
+         force: Re-download even if exists
+         verbose: Print progress
+ 
+     Returns:
+         Path to downloaded problem set
+ 
+     Raises:
+         ValueError: If problem set name is unknown
+         httpx.HTTPError: If download fails
+     """
+     if name not in PROBLEM_SETS:
+         raise ValueError(f"Unknown problem set: {name}. Available: {list(PROBLEM_SETS.keys())}")
+ 
+     config = PROBLEM_SETS[name]
+     dest = _problems_path(name)
+ 
+     if dest.exists() and not force:
+         if verbose:
+             print(f"Problem set '{name}' already exists at {dest}")
+             print("Use --force to re-download")
+         return dest
+ 
+     _ensure_cache_dir()
+ 
+     if dest.exists():
+         shutil.rmtree(dest)
+     dest.mkdir(parents=True)
+ 
+     if verbose:
+         print(f"Downloading {name}: {config.description}")
+ 
+     try:
+         count = _download_github_repo(config, dest, verbose)
+     except Exception:
+         # Clean up partial download so next run doesn't skip with stale cache
+         if dest.exists():
+             shutil.rmtree(dest)
+         raise
+ 
+     if verbose:
+         print(f"Downloaded {count} files to {dest}")
+ 
+     return dest
+ 
+ 
+ def get_problems_path(name: ProblemSetName) -> Path | None:
+     """Get path to downloaded problem set, or None if not downloaded.
+ 
+     Args:
+         name: Problem set name
+ 
+     Returns:
+         Path if downloaded, None otherwise
+     """
+     if name not in PROBLEM_SETS:
+         return None
+     path = _problems_path(name)
+     return path if path.exists() else None
+ 
+ 
+ def list_problem_sets(verbose: bool = True) -> dict[ProblemSetName, bool]:
+     """List available problem sets and their download status.
+ 
+     Returns:
+         Dict of problem set name -> is_downloaded
+     """
+     result: dict[ProblemSetName, bool] = {}
+     for name, config in PROBLEM_SETS.items():
+         path = _problems_path(name)
+         exists = path.exists()
+         result[name] = exists
+         if verbose:
+             status = "✓" if exists else " "
+             print(f"[{status}] {name}: {config.description}")
+             print(f" Format: {config.format_description}")
+             if exists:
+                 file_count = sum(1 for _ in path.rglob("*.py") if _.is_file())
+                 print(f" Location: {path} ({file_count} Python files)")
+     return result
+ 
+ 
+ def list_problems(name: ProblemSetName, verbose: bool = True) -> list[str]:
+     """List available problems in a problem set.
+ 
+     Args:
+         name: Problem set name
+         verbose: Print to stdout
+ 
+     Returns:
+         List of problem IDs
+ 
+     Raises:
+         ValueError: If problem set not downloaded
+     """
+     path = get_problems_path(name)
+     if path is None:
+         raise ValueError(
+             f"Problem set '{name}' is not downloaded. Run:\n wafer evaluate {name} download"
+         )
+ 
+     if name == "kernelbench":
+         problems = _list_kernelbench_problems(path)
+     elif name == "gpumode":
+         problems = _list_gpumode_problems(path)
+     else:
+         problems = []
+ 
+     if verbose:
+         if not problems:
+             print(f"No problems found in {name}")
+         else:
+             print(f"Available problems in {name} ({len(problems)} total):\n")
+             for p in problems:
+                 print(f" {p}")
+ 
+     return problems
+ 
+ 
+ def _list_kernelbench_problems(path: Path) -> list[str]:
+     """List KernelBench problems: level1/1_Name.py format."""
+     problems: list[str] = []
+     kb_root = path / "KernelBench"
+     if not kb_root.exists():
+         kb_root = path  # In case structure is flat
+ 
+     for level_dir in sorted(kb_root.iterdir()):
+         if not (level_dir.is_dir() and level_dir.name.startswith("level")):
+             continue
+         for problem_file in sorted(level_dir.glob("*.py")):
+             if problem_file.name.startswith("__"):
+                 continue
+             problem_id = f"{level_dir.name}/{problem_file.stem}"
+             problems.append(problem_id)
+     return problems
+ 
+ 
+ def _list_gpumode_problems(path: Path) -> list[str]:
+     """List GPUMode problems: category/problem_name format."""
+     problems: list[str] = []
+     problems_root = path / "problems"
+     if not problems_root.exists():
+         problems_root = path
+ 
+     for category_dir in sorted(problems_root.iterdir()):
+         if not _is_valid_problem_dir(category_dir):
+             continue
+         for problem_dir in sorted(category_dir.iterdir()):
+             if not _is_valid_problem_dir(problem_dir):
+                 continue
+             # Check if it has the expected files
+             has_reference = (problem_dir / "reference.py").exists()
+             has_task = (problem_dir / "task.yml").exists()
+             if has_reference or has_task:
+                 problem_id = f"{category_dir.name}/{problem_dir.name}"
+                 problems.append(problem_id)
+     return problems
+ 
+ 
+ def _is_valid_problem_dir(path: Path) -> bool:
+     """Check if path is a valid problem directory (not hidden/special)."""
+     return path.is_dir() and not path.name.startswith((".", "_"))
+ 
+ 
+ def get_problem_path(name: ProblemSetName, problem_id: str) -> Path | None:
+     """Get path to a specific problem.
+ 
+     Args:
+         name: Problem set name
+         problem_id: Problem ID (e.g., "level4/103" or "pmpp/vectoradd_py")
+ 
+     Returns:
+         Path to problem file/directory, or None if not found
+     """
+     base_path = get_problems_path(name)
+     if base_path is None:
+         return None
+ 
+     if name == "kernelbench":
+         # Parse problem_id like "level4/103" or "level4/103_GroupedQueryAttention"
+         parts = problem_id.split("/")
+         if len(parts) != 2:
+             return None
+ 
+         level_str, problem_part = parts
+         if not level_str.startswith("level"):
+             level_str = f"level{level_str}"
+ 
+         kb_root = base_path / "KernelBench"
+         if not kb_root.exists():
+             kb_root = base_path
+ 
+         problem_dir = kb_root / level_str
+         if not problem_dir.exists():
+             return None
+ 
+         # Find matching problem file
+         problem_files = list(problem_dir.glob(f"{problem_part}*.py"))
+         if not problem_files:
+             # Try exact match
+             exact = problem_dir / f"{problem_part}.py"
+             if exact.exists():
+                 return exact
+             return None
+ 
+         return problem_files[0]
+ 
+     elif name == "gpumode":
+         # Parse problem_id like "pmpp/vectoradd_py"
+         problems_root = base_path / "problems"
+         if not problems_root.exists():
+             problems_root = base_path
+ 
+         problem_path = problems_root / problem_id
+         if problem_path.exists() and problem_path.is_dir():
+             return problem_path
+ 
+         return None
+ 
+     return None
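
As a quick orientation to the new module, here is a minimal usage sketch (not part of the package), assuming `wafer.problems` imports from the installed wheel. The problem-set names come from `PROBLEM_SETS` above; the resolved paths and the example problem ID depend on what has actually been downloaded locally.

# Illustrative only: exercises the public functions added above.
from wafer.problems import (
    download_problems,
    get_problem_path,
    list_problem_sets,
    list_problems,
)

# Fetch KernelBench into ~/.cache/wafer/problems/kernelbench (no-op if cached).
path = download_problems("kernelbench", force=False, verbose=True)

# Show which problem sets are cached locally, e.g. {"kernelbench": True, "gpumode": False}.
status = list_problem_sets(verbose=True)

# Enumerate problem IDs of the form "level1/<N>_<Name>".
problems = list_problems("kernelbench", verbose=False)

# Resolve a problem ID prefix to a file on disk (None if nothing matches).
print(get_problem_path("kernelbench", "level1/1"))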
wafer/target_lock.py ADDED
@@ -0,0 +1,270 @@
+ """Target locking for concurrent access control.
+ 
+ Uses file locks (fcntl.flock) to ensure only one process uses a target at a time.
+ Locks are automatically released when the process exits or crashes.
+ 
+ Usage:
+     # Try to acquire a single target
+     with try_acquire_target("mi300x-1") as acquired:
+         if acquired:
+             # Got the lock, run eval
+             ...
+         else:
+             # Target busy
+             ...
+ 
+     # Acquire first available from a pool
+     with acquire_from_pool(["mi300x-1", "mi300x-2", "mi300x-3"]) as target:
+         if target:
+             # Got a target, run eval
+             ...
+         else:
+             # All targets busy
+             ...
+ """
+ 
+ from __future__ import annotations
+ 
+ import fcntl
+ import json
+ import os
+ import sys
+ import time
+ from collections.abc import Iterator
+ from contextlib import contextmanager
+ from datetime import UTC
+ from pathlib import Path
+ 
+ 
+ def _emit_gpu_event(event_type: str, **data: dict) -> None:
+     """Emit structured GPU event to stderr as JSON.
+ 
+     Events are written to stderr (not stdout) to avoid interfering with
+     command output parsing. Format: JSON with newline.
+ 
+     These events can be:
+     1. Parsed from bash output in eval events.jsonl
+     2. Piped to observability systems
+     3. Aggregated for GPU utilization metrics
+     """
+     from datetime import datetime
+ 
+     event = {
+         "type": event_type,
+         "timestamp": datetime.now(UTC).isoformat(),
+         "pid": os.getpid(),
+         **data,
+     }
+     # Write to stderr so it doesn't interfere with stdout capture
+     print(f"[GPU_EVENT] {json.dumps(event)}", file=sys.stderr, flush=True)
+ 
+ 
+ # Lock directory
+ LOCKS_DIR = Path.home() / ".wafer" / "locks"
+ 
+ 
+ def _ensure_locks_dir() -> None:
+     """Ensure locks directory exists."""
+     LOCKS_DIR.mkdir(parents=True, exist_ok=True)
+ 
+ 
+ def _lock_path(target_name: str) -> Path:
+     """Get path to lock file for a target."""
+     return LOCKS_DIR / f"{target_name}.lock"
+ 
+ 
+ @contextmanager
+ def try_acquire_target(target_name: str) -> Iterator[bool]:
+     """Try to acquire exclusive lock on a target.
+ 
+     Args:
+         target_name: Name of the target to lock
+ 
+     Yields:
+         True if lock was acquired, False if target is busy
+ 
+     The lock is automatically released when the context exits,
+     or if the process crashes.
+     """
+     _ensure_locks_dir()
+     lock_file = _lock_path(target_name)
+ 
+     # Open or create lock file
+     fd = os.open(str(lock_file), os.O_CREAT | os.O_RDWR)
+ 
+     try:
+         # Try non-blocking exclusive lock
+         fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
+         # Write PID to lock file for debugging
+         os.ftruncate(fd, 0)
+         os.write(fd, f"{os.getpid()}\n".encode())
+         acquire_time = time.time()
+         _emit_gpu_event("gpu_acquire", target=target_name)
+         try:
+             yield True
+         finally:
+             # Release lock
+             hold_duration_ms = (time.time() - acquire_time) * 1000
+             _emit_gpu_event(
+                 "gpu_release",
+                 target=target_name,
+                 hold_duration_ms=round(hold_duration_ms, 1),
+             )
+             fcntl.flock(fd, fcntl.LOCK_UN)
+     except BlockingIOError:
+         # Lock is held by another process
+         yield False
+     finally:
+         os.close(fd)
+ 
+ 
+ @contextmanager
+ def acquire_from_pool(
+     target_names: list[str],
+     timeout: float | None = None,
+     poll_interval: float = 1.0,
+ ) -> Iterator[str | None]:
+     """Acquire first available target from a list.
+ 
+     Tries each target in order, returns the first one that's available.
+     If all targets are busy and timeout is set, waits and retries.
+ 
+     Args:
+         target_names: List of target names to try
+         timeout: Max seconds to wait for a target. None = no waiting (fail immediately).
+             Use float('inf') to wait forever.
+         poll_interval: Seconds between retries when waiting
+ 
+     Yields:
+         Name of acquired target, or None if all are busy (and timeout expired)
+ 
+     Example:
+         # Wait up to 5 minutes for a target
+         with acquire_from_pool(["gpu-1", "gpu-2", "gpu-3"], timeout=300) as target:
+             if target:
+                 print(f"Got {target}")
+                 run_eval(target)
+             else:
+                 print("All targets busy after timeout")
+     """
+     _ensure_locks_dir()
+ 
+     start_time = time.monotonic()
+ 
+     while True:
+         # Try each target in order
+         for target_name in target_names:
+             lock_file = _lock_path(target_name)
+             fd = os.open(str(lock_file), os.O_CREAT | os.O_RDWR)
+ 
+             try:
+                 fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
+                 # Got the lock - write PID and yield
+                 os.ftruncate(fd, 0)
+                 os.write(fd, f"{os.getpid()}\n".encode())
+                 acquire_time = time.time()
+                 _emit_gpu_event("gpu_acquire", target=target_name, pool=target_names)
+                 try:
+                     yield target_name
+                     return  # Success - exit after context
+                 finally:
+                     hold_duration_ms = (time.time() - acquire_time) * 1000
+                     _emit_gpu_event(
+                         "gpu_release",
+                         target=target_name,
+                         pool=target_names,
+                         hold_duration_ms=round(hold_duration_ms, 1),
+                     )
+                     fcntl.flock(fd, fcntl.LOCK_UN)
+                     os.close(fd)
+             except BlockingIOError:
+                 # This target is busy, try next
+                 os.close(fd)
+                 continue
+ 
+         # All targets busy - check if we should wait
+         if timeout is None:
+             # No waiting, fail immediately
+             break
+ 
+         elapsed = time.monotonic() - start_time
+         if elapsed >= timeout:
+             # Timeout expired
+             break
+ 
+         # Wait and retry
+         remaining = timeout - elapsed
+         print(f" All targets busy, waiting... ({int(remaining)}s remaining)", file=sys.stderr)
+         time.sleep(poll_interval)
+ 
+     # All targets busy (timeout expired or no waiting)
+     yield None
+ 
+ 
+ def is_target_locked(target_name: str) -> bool:
+     """Check if a target is currently locked.
+ 
+     Note: This is a point-in-time check - the lock status can change
+     immediately after this returns.
+ 
+     Args:
+         target_name: Name of the target to check
+ 
+     Returns:
+         True if target is locked, False if available
+     """
+     _ensure_locks_dir()
+     lock_file = _lock_path(target_name)
+ 
+     if not lock_file.exists():
+         return False
+ 
+     fd = os.open(str(lock_file), os.O_RDONLY)
+     try:
+         # Try non-blocking lock
+         fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
+         # Got it - so it wasn't locked
+         fcntl.flock(fd, fcntl.LOCK_UN)
+         return False
+     except BlockingIOError:
+         return True
+     finally:
+         os.close(fd)
+ 
+ 
+ def get_lock_holder(target_name: str) -> int | None:
+     """Get PID of process holding lock on a target.
+ 
+     Args:
+         target_name: Name of the target
+ 
+     Returns:
+         PID of lock holder, or None if not locked or unknown
+     """
+     lock_file = _lock_path(target_name)
+ 
+     if not lock_file.exists():
+         return None
+ 
+     try:
+         content = lock_file.read_text().strip()
+         return int(content)
+     except (ValueError, OSError):
+         return None
+ 
+ 
+ def list_locked_targets() -> list[str]:
+     """List all currently locked targets.
+ 
+     Returns:
+         List of target names that are currently locked
+     """
+     _ensure_locks_dir()
+ 
+     locked = []
+     for lock_file in LOCKS_DIR.glob("*.lock"):
+         target_name = lock_file.stem
+         if is_target_locked(target_name):
+             locked.append(target_name)
+ 
+     return sorted(locked)
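
To complement the docstring examples, here is a minimal sketch (not part of the package) of how the pool helper composes with the inspection functions, assuming `wafer.target_lock` imports from the installed wheel. The `mi300x-*` names are placeholders for whatever targets the CLI has configured.

# Illustrative only: pool acquisition plus the lock-inspection helpers above.
from wafer.target_lock import (
    acquire_from_pool,
    get_lock_holder,
    list_locked_targets,
)

POOL = ["mi300x-1", "mi300x-2"]  # placeholder target names

# Wait up to 60 seconds for any target in the pool to become free.
with acquire_from_pool(POOL, timeout=60, poll_interval=2.0) as target:
    if target is None:
        # Every target stayed locked for the full timeout; report the holders.
        for name in list_locked_targets():
            print(f"{name} held by pid {get_lock_holder(name)}")
    else:
        # The flock is held for the duration of this block and released on exit.
        print(f"acquired {target}")
        ...  # run the evaluation against `target`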