wafer-cli 0.2.14__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/targets_cli.py ADDED
@@ -0,0 +1,472 @@
+"""CLI commands for wafer targets — live resource management.
+
+These commands always hit provider APIs to show real state.
+Registered as: wafer targets list|terminate|reconcile|provision|pools|probe
+"""
+
+from __future__ import annotations
+
+from datetime import UTC, datetime
+from typing import TYPE_CHECKING
+
+import typer
+
+if TYPE_CHECKING:
+    from wafer_core.targets.types import Target
+
+targets_live_app = typer.Typer(
+    name="targets",
+    help="""Manage live GPU resources across cloud providers.
+
+Unlike 'wafer specs' (local config files), these commands query provider APIs
+to show what's actually running.
+
+    wafer targets list                      # All running resources
+    wafer targets list --provider runpod    # Filter by provider
+    wafer targets list --pool mi300x        # Filter by a pool query
+    wafer targets terminate <resource-id>   # Kill a resource
+    wafer targets terminate --pool mi300x   # Kill everything matching a pool
+    wafer targets reconcile                 # Refresh bindings
+    wafer targets provision <spec-name>     # Provision from a spec
+""",
+)
+
+
+@targets_live_app.command("list")
+def targets_list(
+    provider: str | None = typer.Option(None, "--provider", "-p", help="Filter by provider"),
+    pool: str | None = typer.Option(None, "--pool", help="Filter by pool query from config.toml"),
+) -> None:
+    """List all running GPU resources across providers.
+
+    Queries RunPod and DigitalOcean APIs to show live state.
+
+    Examples:
+        wafer targets list
+        wafer targets list --provider runpod
+        wafer targets list --pool mi300x-rocm7
+    """
+    import trio
+    from wafer_core.targets.providers import get_all_cloud_providers, get_provider
+    from wafer_core.targets.types import Target, TargetProvider
+
+    async def _list() -> list[Target]:
+        all_targets: list[Target] = []
+
+        if provider:
+            prov = get_provider(provider)
+            all_targets = await prov.list_targets()
+        else:
+            providers = get_all_cloud_providers()
+
+            async def _fetch(prov_impl: TargetProvider, results: list[Target]) -> None:
+                try:
+                    targets = await prov_impl.list_targets()
+                    results.extend(targets)
+                except Exception as e:
+                    typer.echo(
+                        f"  Warning: failed to query {type(prov_impl).__name__}: {e}", err=True
+                    )
+
+            async with trio.open_nursery() as nursery:
+                for _, prov_impl in providers:
+                    nursery.start_soon(_fetch, prov_impl, all_targets)
+
+        return all_targets
+
+    all_targets = trio.run(_list)
+
+    # Hydrate targets with cached labels
+    from dataclasses import replace
+
+    from wafer_core.targets.state_cache import load_all_labels
+
+    cached_labels = load_all_labels()
+    all_targets = [
+        replace(t, labels=cached_labels[t.resource_id])
+        if t.resource_id in cached_labels
+        else t
+        for t in all_targets
+    ]
+
+    # Apply pool filter if specified
+    if pool:
+        from wafer_core.targets.pool import load_pool_query, match_targets
+
+        try:
+            query = load_pool_query(pool)
+        except KeyError as e:
+            typer.echo(str(e), err=True)
+            raise typer.Exit(1) from None
+
+        all_targets = match_targets(query, all_targets)
+        typer.echo(f"Pool {pool!r}: {len(all_targets)} matching target(s)\n")
+
+    if not all_targets:
+        typer.echo("No running resources found.")
+        return
+
+    typer.echo(f"{len(all_targets)} resource(s):\n")
+    for target in all_targets:
+        _print_target(target)
+
+
+def _print_target(target: Target) -> None:
+    """Print a single target's info."""
+    ssh_info = ""
+    if target.public_ip and target.ssh_port:
+        ssh_info = f" ssh={target.ssh_username}@{target.public_ip}:{target.ssh_port}"
+
+    name_part = f" name={target.name}" if target.name else ""
+    spec_part = f" spec={target.spec_name}" if target.spec_name else ""
+    price_part = f" ${target.price_per_hour:.2f}/hr" if target.price_per_hour else ""
+
+    # Show interesting labels (skip 'image' — too long)
+    label_keys = sorted(k for k in target.labels if k != "image")
+    labels_part = ""
+    if label_keys:
+        labels_part = " " + " ".join(f"{k}={target.labels[k]}" for k in label_keys)
+
+    typer.echo(
+        f"  {target.resource_id} [{target.provider}] "
+        f"status={target.status} gpu={target.gpu_type}"
+        f"{spec_part}{name_part}{ssh_info}{price_part}{labels_part}"
+    )
+    typer.echo()
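
Note: `load_pool_query` and `match_targets` live in `wafer_core.targets.pool` and are not part of this diff. Judging only from the fields this file exercises (`gpu_type`, `provider`, `status`, and a `labels` dict), the matching is presumably equality on the scalar fields plus a label-subset check. A hypothetical sketch, not the actual wafer_core implementation:

# Hypothetical sketch of pool matching; field names are taken from the CLI code above.
from dataclasses import dataclass, field


@dataclass(frozen=True)
class PoolQuery:
    gpu_type: str | None = None
    provider: str | None = None
    status: str | None = "running"  # `wafer targets pools` treats "running" as the default
    labels: dict[str, str] = field(default_factory=dict)


def match_targets(query: PoolQuery, targets: list) -> list:
    """Keep targets whose scalar fields match and whose labels include the query's."""
    return [
        t
        for t in targets
        if (query.gpu_type is None or t.gpu_type == query.gpu_type)
        and (query.provider is None or t.provider == query.provider)
        and (query.status is None or t.status == query.status)
        and all(t.labels.get(k) == v for k, v in query.labels.items())
    ]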
+
+
+@targets_live_app.command("terminate")
+def targets_terminate(
+    resource_id: str | None = typer.Argument(None, help="Resource ID to terminate"),
+    pool_name: str | None = typer.Option(
+        None, "--pool", help="Terminate all targets matching a pool query"
+    ),
+    provider_name: str | None = typer.Option(
+        None, "--provider", "-p", help="Provider hint (avoids querying all providers)"
+    ),
+    yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation"),
+) -> None:
+    """Terminate a running resource by ID, or all targets matching a pool query.
+
+    Examples:
+        wafer targets terminate tkru24z7npcgth
+        wafer targets terminate --pool mi300x --yes
+        wafer targets terminate --pool runpod-only --provider runpod
+    """
+    import trio
+    from wafer_core.targets.providers import get_all_cloud_providers, get_provider
+    from wafer_core.targets.state_cache import remove_binding
+
+    if pool_name:
+        _terminate_pool(pool_name, provider_name, yes)
+        return
+
+    if not resource_id:
+        typer.echo("Provide a resource ID or use --pool <name>.", err=True)
+        raise typer.Exit(1)
+
+    async def _terminate() -> bool:
+        if provider_name:
+            prov = get_provider(provider_name)
+            success = await prov.terminate(resource_id)
+            if success:
+                remove_binding(resource_id)
+                typer.echo(f"Terminated {resource_id} ({provider_name})")
+            return success
+
+        for name, prov in get_all_cloud_providers():
+            target = await prov.get_target(resource_id)
+            if target is not None:
+                success = await prov.terminate(resource_id)
+                if success:
+                    remove_binding(resource_id)
+                    typer.echo(f"Terminated {resource_id} ({name})")
+                return success
+
+        typer.echo(f"Resource {resource_id} not found on any provider.", err=True)
+        return False
+
+    success = trio.run(_terminate)
+    if not success:
+        raise typer.Exit(1)
+
+
+def _terminate_pool(pool_name: str, provider_name: str | None, yes: bool) -> None:
+    """Terminate all targets matching a pool query."""
+    import trio
+    from wafer_core.targets.pool import load_pool_query, match_targets
+    from wafer_core.targets.providers import get_all_cloud_providers, get_provider
+    from wafer_core.targets.state_cache import remove_binding
+    from wafer_core.targets.types import Target
+
+    try:
+        query = load_pool_query(pool_name)
+    except KeyError as e:
+        typer.echo(str(e), err=True)
+        raise typer.Exit(1) from None
+
+    async def _do_terminate() -> int:
+        all_targets: list[Target] = []
+        if provider_name:
+            prov = get_provider(provider_name)
+            all_targets = await prov.list_targets()
+        else:
+            for _, prov in get_all_cloud_providers():
+                try:
+                    all_targets.extend(await prov.list_targets())
+                except Exception:
+                    pass  # Skip providers that fail (e.g. missing credentials)
+
+        matched = match_targets(query, all_targets)
+
+        if not matched:
+            typer.echo(f"No targets match pool {pool_name!r}.")
+            return 0
+
+        typer.echo(f"Found {len(matched)} target(s) matching pool {pool_name!r}:")
+        for t in matched:
+            name_part = f" name={t.name}" if t.name else ""
+            typer.echo(f"  {t.resource_id} [{t.provider}] gpu={t.gpu_type}{name_part}")
+
+        if not yes:
+            confirm = typer.confirm("Terminate all?")
+            if not confirm:
+                return 0
+
+        count = 0
+        for t in matched:
+            prov = get_provider(t.provider)
+            if await prov.terminate(t.resource_id):
+                remove_binding(t.resource_id)
+                typer.echo(f"  Terminated {t.resource_id}")
+                count += 1
+            else:
+                typer.echo(f"  Failed to terminate {t.resource_id}", err=True)
+
+        return count
+
+    count = trio.run(_do_terminate)
+    typer.echo(f"\nTerminated {count} resource(s).")
+
+
+@targets_live_app.command("reconcile")
+def targets_reconcile() -> None:
+    """Refresh local binding cache from provider APIs.
+
+    Queries all cloud providers, matches resources to specs, and updates
+    the local state cache. Reports any drift.
+
+    Example:
+        wafer targets reconcile
+    """
+    import trio
+    from wafer_core.targets.providers import get_all_cloud_providers
+    from wafer_core.targets.reconcile import reconcile
+    from wafer_core.targets.spec_store import load_all_specs
+    from wafer_core.targets.state_cache import (
+        BindingEntry,
+        get_binding_hints,
+        save_bindings,
+    )
+    from wafer_core.targets.types import Target
+
+    async def _reconcile() -> None:
+        specs = load_all_specs()
+
+        all_targets: list[Target] = []
+        for name, prov in get_all_cloud_providers():
+            typer.echo(f"Querying {name}...")
+            try:
+                targets = await prov.list_targets()
+                typer.echo(f"  Found {len(targets)} resource(s)")
+                all_targets.extend(targets)
+            except Exception as e:
+                typer.echo(f"  Failed: {e}", err=True)
+
+        hints = get_binding_hints()
+        result = reconcile(specs, all_targets, binding_hints=hints)
+
+        # Update binding cache with bound results
+        new_bindings = {}
+        now = datetime.now(UTC).isoformat()
+        for spec, target in result.bound:
+            new_bindings[target.resource_id] = BindingEntry(
+                spec_name=spec.name,
+                provider=target.provider,
+                bound_at=now,
+            )
+        save_bindings(new_bindings)
+
+        typer.echo("\nReconcile complete:")
+        typer.echo(f"  Total resources:  {len(all_targets)}")
+        typer.echo(f"  Matched to specs: {len(result.bound)}")
+        typer.echo(f"  No matching spec: {len(result.unbound)}")
+
+    trio.run(_reconcile)
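
`reconcile()` itself is defined in `wafer_core.targets.reconcile`, outside this diff. From the call site above, its contract appears to be: take the loaded specs, the live targets, and cached binding hints, and return a result with `.bound` (spec, target) pairs plus `.unbound` targets. A rough sketch under those assumptions:

# Assumed contract only; the real reconcile logic is not shown in this diff.
from dataclasses import dataclass, field


@dataclass
class ReconcileResult:
    bound: list[tuple[object, object]] = field(default_factory=list)  # (spec, target)
    unbound: list[object] = field(default_factory=list)


def reconcile(specs, targets, *, binding_hints=None) -> ReconcileResult:
    hints = binding_hints or {}  # assumed: resource_id -> spec name, from the local cache
    by_name = {s.name: s for s in specs}
    result = ReconcileResult()
    for t in targets:
        spec = by_name.get(hints.get(t.resource_id) or t.spec_name or "")
        if spec is not None:
            result.bound.append((spec, t))
        else:
            result.unbound.append(t)
    return result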
+
+
+@targets_live_app.command("provision")
+def targets_provision(
+    spec_name: str = typer.Argument(..., help="Spec name to provision from"),
+) -> None:
+    """Explicitly provision a resource from a spec.
+
+    Creates a new cloud resource and binds it to the spec.
+
+    Example:
+        wafer targets provision runpod-mi300x
+    """
+    import trio
+    from wafer_core.targets.providers import get_provider
+    from wafer_core.targets.spec_store import load_spec
+    from wafer_core.targets.state_cache import BindingEntry, add_binding
+    from wafer_core.utils.kernel_utils.targets.config import (
+        DigitalOceanTarget,
+        RunPodTarget,
+    )
+
+    try:
+        spec = load_spec(spec_name)
+    except FileNotFoundError:
+        typer.echo(f"Spec not found: {spec_name}", err=True)
+        raise typer.Exit(1) from None
+
+    if isinstance(spec, RunPodTarget):
+        provider_name = "runpod"
+    elif isinstance(spec, DigitalOceanTarget):
+        provider_name = "digitalocean"
+    else:
+        typer.echo(f"Spec type {type(spec).__name__} cannot be provisioned.", err=True)
+        raise typer.Exit(1)
+
+    async def _provision() -> None:
+        from wafer_core.targets.probe import probe_target_labels
+        from wafer_core.targets.state_cache import save_labels
+
+        prov = get_provider(provider_name)
+        typer.echo(f"Provisioning {spec_name} via {provider_name}...")
+        target = await prov.provision(spec)
+
+        # Cache the binding
+        add_binding(
+            target.resource_id,
+            BindingEntry(
+                spec_name=spec_name,
+                provider=provider_name,
+                bound_at=datetime.now(UTC).isoformat(),
+            ),
+        )
+
+        typer.echo(f"\nProvisioned: {target.resource_id}")
+        if target.public_ip:
+            typer.echo(f"  SSH: {target.ssh_username}@{target.public_ip}:{target.ssh_port}")
+
+        # Probe software labels (sync — runs subprocess ssh)
+        if target.public_ip and target.ssh_port:
+            typer.echo("  Probing software versions...")
+            try:
+                ssh_key = getattr(spec, "ssh_key", None)
+                labels = probe_target_labels(
+                    host=target.public_ip,
+                    port=target.ssh_port,
+                    username=target.ssh_username,
+                    ssh_key_path=ssh_key,
+                )
+                save_labels(target.resource_id, labels)
+                if labels:
+                    typer.echo(
+                        f"  Labels: {' '.join(f'{k}={v}' for k, v in sorted(labels.items()))}"
+                    )
+            except Exception as e:
+                typer.echo(f"  Warning: probe failed: {e}", err=True)
+
+    trio.run(_provision)
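
`add_binding` and `save_labels` persist to the cache that `wafer targets probe` (below) says lives at ~/.wafer/target_state.json. The schema is not part of this diff; given the fields used here (`spec_name`, `provider`, `bound_at`, plus free-form label strings), a plausible layout, shown purely for illustration, would be:

# Plausible on-disk layout for ~/.wafer/target_state.json; only the field names
# are taken from the code above. The values here are invented examples.
EXAMPLE_STATE = {
    "bindings": {
        "tkru24z7npcgth": {
            "spec_name": "runpod-mi300x",
            "provider": "runpod",
            "bound_at": "2025-01-01T00:00:00+00:00",
        },
    },
    "labels": {
        "tkru24z7npcgth": {"rocm_version": "7.0.2", "torch_version": "2.5.1"},
    },
}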
+
+
+@targets_live_app.command("pools")
+def targets_pools() -> None:
+    """List configured pool queries from config.toml.
+
+    Example:
+        wafer targets pools
+    """
+    from wafer_core.targets.pool import list_pool_names, load_pool_query
+
+    names = list_pool_names()
+    if not names:
+        typer.echo("No pools configured in ~/.wafer/config.toml.")
+        typer.echo("\nAdd a pool:\n")
+        typer.echo("  [pools.mi300x]")
+        typer.echo('  gpu_type = "MI300X"')
+        typer.echo("")
+        typer.echo("  [pools.mi300x-rocm7]")
+        typer.echo('  gpu_type = "MI300X"')
+        typer.echo("  [pools.mi300x-rocm7.labels]")
+        typer.echo('  rocm_version = "7.0.2"')
+        return
+
+    typer.echo(f"{len(names)} pool(s):\n")
+    for name in names:
+        query = load_pool_query(name)
+        parts = []
+        if query.gpu_type:
+            parts.append(f"gpu_type={query.gpu_type}")
+        if query.provider:
+            parts.append(f"provider={query.provider}")
+        if query.status and query.status != "running":
+            parts.append(f"status={query.status}")
+        for k, v in sorted(query.labels.items()):
+            parts.append(f"{k}={v}")
+        criteria = " ".join(parts) if parts else "(match all)"
+        typer.echo(f"  {name}: {criteria}")
+
+
+@targets_live_app.command("probe")
+def targets_probe(
+    resource_id: str = typer.Argument(..., help="Resource ID to probe"),
+    provider_name: str | None = typer.Option(
+        None, "--provider", "-p", help="Provider hint (avoids querying all providers)"
+    ),
+) -> None:
+    """Probe a running target's software versions via SSH.
+
+    Results are cached in ~/.wafer/target_state.json and shown
+    by `wafer targets list`. Useful for targets not provisioned by wafer
+    (e.g. dashboard-created pods).
+
+    Examples:
+        wafer targets probe ewfo5ckpxlg7y2
+        wafer targets probe 543538453 --provider digitalocean
+    """
+    import trio
+    from wafer_core.targets.probe import probe_target_labels
+    from wafer_core.targets.providers import get_all_cloud_providers, get_provider
+    from wafer_core.targets.state_cache import save_labels
+    from wafer_core.targets.types import Target
+
+    # Find the target (async — needs provider API)
+    async def _find_target() -> Target | None:
+        if provider_name:
+            prov = get_provider(provider_name)
+            return await prov.get_target(resource_id)
+
+        for _, prov in get_all_cloud_providers():
+            target = await prov.get_target(resource_id)
+            if target is not None:
+                return target
+        return None
+
+    target = trio.run(_find_target)
+
+    if target is None:
+        typer.echo(f"Resource {resource_id} not found.", err=True)
+        raise typer.Exit(1)
+
+    if not target.public_ip or not target.ssh_port:
+        typer.echo(f"Resource {resource_id} has no SSH info (status={target.status}).", err=True)
+        raise typer.Exit(1)
+
+    typer.echo(
+        f"Probing {resource_id} ({target.ssh_username}@{target.public_ip}:{target.ssh_port})..."
+    )
+
+    labels = probe_target_labels(
+        host=target.public_ip,
+        port=target.ssh_port,
+        username=target.ssh_username,
+    )
+
+    save_labels(resource_id, labels)
+
+    if labels:
+        typer.echo(f"Labels cached for {resource_id}:")
+        for k, v in sorted(labels.items()):
+            typer.echo(f"  {k}={v}")
+    else:
+        typer.echo("Probe returned no labels.")
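
`probe_target_labels` (from `wafer_core.targets.probe`) is likewise not defined in this diff. The provision command's comment ("sync — runs subprocess ssh") and the label names seen elsewhere (`rocm_version`, for example) suggest it shells out over SSH and parses version strings. A purely hypothetical sketch of such a probe:

# Purely hypothetical; the real wafer_core.targets.probe is not included here.
import subprocess


def probe_target_labels(
    host: str, port: int, username: str, ssh_key_path: str | None = None
) -> dict[str, str]:
    """Run version-reporting commands on the target and collect them as labels."""
    cmds = {
        # Label name -> remote command; both sides are illustrative guesses.
        "rocm_version": "cat /opt/rocm/.info/version",
        "torch_version": "python3 -c 'import torch; print(torch.__version__)'",
    }
    ssh = ["ssh", "-p", str(port), "-o", "StrictHostKeyChecking=no"]
    if ssh_key_path:
        ssh += ["-i", ssh_key_path]
    labels: dict[str, str] = {}
    for key, cmd in cmds.items():
        proc = subprocess.run(
            [*ssh, f"{username}@{host}", cmd],
            capture_output=True, text=True, timeout=30,
        )
        if proc.returncode == 0 and proc.stdout.strip():
            labels[key] = proc.stdout.strip()
    return labels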
wafer/targets_ops.py CHANGED
@@ -15,6 +15,7 @@ import logging
 import subprocess
 from collections.abc import Callable
 from dataclasses import dataclass, replace
+from datetime import UTC
 from pathlib import Path
 from typing import TYPE_CHECKING

@@ -30,6 +31,26 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


+def _update_binding_cache(resource_id: str, spec_name: str, provider: str) -> None:
+    """Update the new target state cache when provisioning through the legacy path.
+
+    This bridges the old per-provider state files with the new unified cache
+    so that `wafer targets list` can see resources provisioned via the old flow.
+    """
+    from datetime import datetime
+
+    from wafer_core.targets.state_cache import BindingEntry, add_binding
+
+    add_binding(
+        resource_id,
+        BindingEntry(
+            spec_name=spec_name,
+            provider=provider,
+            bound_at=datetime.now(UTC).isoformat(),
+        ),
+    )
+
+
 @dataclass(frozen=True)
 class TargetSSHInfo:
     """SSH connection info for a target."""
@@ -135,7 +156,8 @@ async def _get_runpod_ssh_info(target: RunPodTarget) -> TargetSSHInfo:
     # Check if pod already exists and is running
     existing = get_pod_state(target.name)
     if existing and await check_pod_running(existing.pod_id):
-        # Reuse existing pod
+        # Reuse existing pod — also update the new state cache
+        _update_binding_cache(existing.pod_id, target.name, "runpod")
         return TargetSSHInfo(
             host=existing.public_ip,
             port=existing.ssh_port,
@@ -151,6 +173,8 @@ async def _get_runpod_ssh_info(target: RunPodTarget) -> TargetSSHInfo:
     target_keep_alive = replace(target, keep_alive=True)

     async with runpod_ssh_context(target_keep_alive) as ssh_info:
+        # Update new state cache with provisioned pod
+        _update_binding_cache(ssh_info.pod_id, target.name, "runpod")
         return TargetSSHInfo(
             host=ssh_info.host,
             port=ssh_info.port,
@@ -172,7 +196,8 @@ async def _get_digitalocean_ssh_info(target: DigitalOceanTarget) -> TargetSSHInfo:
     # Check if droplet already exists and is running
     existing = get_droplet_state(target.name)
     if existing and await check_droplet_running(existing.droplet_id):
-        # Reuse existing droplet
+        # Reuse existing droplet — also update the new state cache
+        _update_binding_cache(existing.droplet_id, target.name, "digitalocean")
         return TargetSSHInfo(
             host=existing.public_ip,
             port=22,  # DigitalOcean uses standard SSH port
@@ -184,6 +209,8 @@ async def _get_digitalocean_ssh_info(target: DigitalOceanTarget) -> TargetSSHInfo:
     target_keep_alive = replace(target, keep_alive=True)

     async with digitalocean_ssh_context(target_keep_alive) as ssh_info:
+        # Update new state cache with provisioned droplet
+        _update_binding_cache(ssh_info.droplet_id, target.name, "digitalocean")
         return TargetSSHInfo(
             host=ssh_info.host,
             port=ssh_info.port,
@@ -51,7 +51,7 @@ Output your answer directly. Be concise but thorough. Include code examples when
         "python -c",
     ],
     # Model config
-    model="anthropic/claude-sonnet-4-5-20250929",
+    model="anthropic/claude-opus-4-5-20251101",
     max_tokens=8192,
     # Thinking config - disabled for simple doc queries
     thinking=False,
@@ -56,7 +56,7 @@ IMPORTANT: Always verify correctness with wafer evaluate before claiming success
         "python -c",
     ],
     # Model config - use thinking for complex optimization reasoning
-    model="anthropic/claude-sonnet-4-5-20250929",
+    model="anthropic/claude-opus-4-5-20251101",
    max_tokens=16384,
     # Thinking config - enabled for complex kernel optimization
     thinking=True,
@@ -68,4 +68,6 @@ IMPORTANT: Always verify correctness with wafer evaluate before claiming success
         "kernel": "./kernel.cu",
         "target": "H100",
     },
+    # Enable skill discovery (agent can load wafer-guide, etc.)
+    include_skills=True,
 )
@@ -1,4 +1,4 @@
-"""Template for KernelBench optimization - matches eval system prompt.
+"""Template for KernelBench optimization.

 Usage:
     # Run on a specific problem
@@ -26,12 +26,18 @@ try:
 except ImportError:
     from rollouts.templates import TemplateConfig

-# System prompt matches optimize_kernelbench_eval/base_config.py SYSTEM_PROMPT
+from wafer.agent_defaults import ENABLED_TOOLS, KERNELBENCH_BASH_ALLOWLIST
+
+# Task-specific instructions only — must stay in sync with the eval's SYSTEM_PROMPT
+# in research/evals/optimize_kernelbench_eval/.../base_config.py.
+# Run test_eval_cli_parity.py to verify.
+# Wafer CLI command docs are auto-generated from --help text and composed
+# at runtime by wevin_cli.py (see wafer.cli_instructions.build_cli_instructions).
+# TODO: Consider having both eval and template import SYSTEM_PROMPT from a shared
+# module so there's only one copy to maintain.
 SYSTEM_PROMPT = """\
 You are a GPU kernel optimization expert. Your task is to write optimized GPU kernels that are correct and faster than the PyTorch baseline.

-IMPORTANT: You do NOT have a local GPU. You MUST use `wafer evaluate kernelbench` to test kernels on remote GPU hardware.
-
 ## Kernel Format (KernelBench)

 The reference file contains a PyTorch `Model` class. You must write a `ModelNew` class that:
@@ -43,49 +49,14 @@ The reference file also provides:
 - `get_inputs()` - generates test inputs for forward()
 - `get_init_inputs()` - generates constructor arguments

-## Available Tools
-
-- read(file_path): Read source files
-- write(file_path, content): Write your optimized kernel
-- glob(pattern): Find files by pattern
-- grep(pattern): Search code
-- bash(command): Run shell commands including wafer CLI
-
 ## Workflow

 1. Read the reference problem file to understand what `Model` does
 2. Analyze the computation and identify optimization opportunities
 3. Write an optimized `ModelNew` class with custom $backend_upper kernels using `__global__` kernel definitions and `torch.utils.cpp_extension.load_inline`
-4. Test with: `wafer evaluate kernelbench $target_flag --backend $backend --impl <your_file.py> --reference <problem.py> --benchmark`
+4. Test with: `wafer evaluate kernelbench $target_flag --backend $backend --impl optimized.py --reference <problem.py> --benchmark`
 5. Iterate based on feedback until correct and fast

-## Example Command
-
-```bash
-wafer evaluate kernelbench \\
-    $target_flag \\
-    --backend $backend \\
-    --impl optimized_kernel.py \\
-    --reference $reference \\
-    --benchmark
-```
-
-## Profiling Tools (USE THESE!)
-
-When your kernel is slower than expected, use profiling to understand WHY:
-
-- `wafer rocprof profile --impl <file> --reference <ref>` - AMD GPU profiling
-- `wafer nvidia ncu --impl <file> --reference <ref>` - NVIDIA NCU profiling
-
-## CRITICAL: Reactive Debugging
-
-After EVERY `wafer evaluate` call:
-1. Check the speedup result
-2. If speedup < 1.0x (slowdown), STOP and analyze:
-   - Run profiling to identify the bottleneck
-   - Ask: "Why is this slow?" before trying another approach
-3. Don't just try random optimizations - understand the root cause
-
 Your kernel MUST:
 - Pass correctness tests (outputs match reference within tolerance)
 - Achieve speedup > 1.0x over PyTorch baseline
@@ -96,32 +67,16 @@ You MUST run `wafer evaluate kernelbench` to verify your kernel. Your score depe
 template = TemplateConfig(
     # Identity
     name="optimize-kernelbench",
-    description="Optimize KernelBench problems (matches eval system prompt)",
-    # System prompt
+    description="Optimize KernelBench problems",
+    # System prompt (task-specific; CLI docs appended at runtime)
     system_prompt=SYSTEM_PROMPT,
     # Tools
-    tools=["read", "write", "edit", "glob", "grep", "bash"],
-    bash_allowlist=[
-        "wafer evaluate",
-        "wafer nvidia ncu",
-        "wafer nvidia nsys",
-        "wafer rocprof",
-        "wafer compiler-analyze",
-        "python",
-        "python3",
-        "timeout",
-        "ls",
-        "cat",
-        "head",
-        "tail",
-        "wc",
-        "pwd",
-        "which",
-    ],
-    # Model config - match eval settings
+    tools=ENABLED_TOOLS,
+    bash_allowlist=KERNELBENCH_BASH_ALLOWLIST,
+    # Model config
     model="anthropic/claude-opus-4-5-20251101",
     max_tokens=8192,
-    # No thinking by default (match eval), can override with --thinking
+    # No thinking by default, can override with --thinking
     thinking=False,
     # Multi-turn for iterative optimization
     single_turn=False,
@@ -60,7 +60,7 @@ Use `--json` flags when available for structured output that's easier to parse.
         "python -c",
     ],
     # Model config
-    model="anthropic/claude-sonnet-4-5-20250929",
+    model="anthropic/claude-opus-4-5-20251101",
     max_tokens=8192,
     # Thinking config - disabled for trace analysis (mostly parsing)
     thinking=False,
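
Taken together, the template changes above replace hand-copied tool and allowlist lists with shared defaults from `wafer.agent_defaults`. As a rough illustration of the resulting pattern (the `TemplateConfig` fields mirror those visible in this diff; the name and prompt below are placeholders, not a shipped template):

# Illustrative only: a minimal template built on the shared defaults above.
from rollouts.templates import TemplateConfig

from wafer.agent_defaults import ENABLED_TOOLS, KERNELBENCH_BASH_ALLOWLIST

template = TemplateConfig(
    name="example-template",  # hypothetical
    description="Example template using shared agent defaults",
    system_prompt="Task-specific instructions only; CLI docs are appended at runtime.",
    tools=ENABLED_TOOLS,  # shared tool list instead of a per-template copy
    bash_allowlist=KERNELBENCH_BASH_ALLOWLIST,  # shared allowlist
    model="anthropic/claude-opus-4-5-20251101",
    max_tokens=8192,
    thinking=False,  # can be overridden with --thinking
    single_turn=False,  # multi-turn for iterative optimization
    include_skills=True,  # agent can load wafer-guide, etc.
)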