wafer-cli 0.2.14__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +42 -0
- wafer/auth.py +7 -0
- wafer/billing.py +6 -6
- wafer/cli.py +905 -131
- wafer/cli_instructions.py +143 -0
- wafer/corpus.py +313 -15
- wafer/evaluate.py +480 -146
- wafer/global_config.py +13 -0
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +22 -6
- wafer/specs_cli.py +157 -0
- wafer/ssh_keys.py +6 -6
- wafer/targets_cli.py +472 -0
- wafer/targets_ops.py +29 -2
- wafer/templates/ask_docs.py +1 -1
- wafer/templates/optimize_kernel.py +3 -1
- wafer/templates/optimize_kernelbench.py +17 -62
- wafer/templates/trace_analyze.py +1 -1
- wafer/tests/test_eval_cli_parity.py +199 -0
- wafer/trace_compare.py +274 -0
- wafer/wevin_cli.py +125 -26
- wafer/workspaces.py +163 -16
- wafer_cli-0.2.30.dist-info/METADATA +107 -0
- wafer_cli-0.2.30.dist-info/RECORD +47 -0
- wafer_cli-0.2.14.dist-info/METADATA +0 -16
- wafer_cli-0.2.14.dist-info/RECORD +0 -41
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/top_level.txt +0 -0
wafer/targets_cli.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
1
|
+
"""CLI commands for wafer targets — live resource management.
|
|
2
|
+
|
|
3
|
+
These commands always hit provider APIs to show real state.
|
|
4
|
+
Registered as: wafer targets list|show|terminate|sync|provision
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from datetime import UTC, datetime
|
|
10
|
+
|
|
11
|
+
import typer
|
|
12
|
+
|
|
13
|
+
targets_live_app = typer.Typer(
|
|
14
|
+
name="targets",
|
|
15
|
+
help="""Manage live GPU resources across cloud providers.
|
|
16
|
+
|
|
17
|
+
Unlike 'wafer specs' (local config files), these commands query provider APIs
|
|
18
|
+
to show what's actually running.
|
|
19
|
+
|
|
20
|
+
wafer targets list # All running resources
|
|
21
|
+
wafer targets list --unbound # Orphans (no matching spec)
|
|
22
|
+
wafer targets list --provider runpod # Filter by provider
|
|
23
|
+
wafer targets terminate <resource-id> # Kill a resource
|
|
24
|
+
wafer targets terminate --unbound # Kill all orphans
|
|
25
|
+
wafer targets sync # Refresh bindings
|
|
26
|
+
wafer targets provision <spec-name> # Provision from a spec
|
|
27
|
+
""",
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@targets_live_app.command("list")
def targets_list(
    provider: str | None = typer.Option(None, "--provider", "-p", help="Filter by provider"),
    pool: str | None = typer.Option(None, "--pool", help="Filter by pool query from config.toml"),
) -> None:
    """List all running GPU resources across providers.

    Queries RunPod and DigitalOcean APIs to show live state.

    Examples:
        wafer targets list
        wafer targets list --provider runpod
        wafer targets list --pool mi300x-rocm7
    """
    import trio
    from wafer_core.targets.providers import get_all_cloud_providers, get_provider
    from wafer_core.targets.types import Target, TargetProvider

    async def _list() -> list[Target]:
        all_targets: list[Target] = []

        if provider:
            # Explicit provider hint: query just that one (failures propagate).
            prov = get_provider(provider)
            all_targets = await prov.list_targets()
        else:
            providers = get_all_cloud_providers()

            # Each fetch extends the shared results list; a failing provider
            # is reported as a warning rather than aborting the whole listing.
            async def _fetch(prov_impl: TargetProvider, results: list[Target]) -> None:
                try:
                    targets = await prov_impl.list_targets()
                    results.extend(targets)
                except Exception as e:
                    typer.echo(
                        f" Warning: failed to query {type(prov_impl).__name__}: {e}", err=True
                    )

            # Query all providers concurrently; result order depends on
            # completion order, not provider registration order.
            async with trio.open_nursery() as nursery:
                for _, prov_impl in providers:
                    nursery.start_soon(_fetch, prov_impl, all_targets)

        return all_targets

    all_targets = trio.run(_list)

    # Hydrate targets with cached labels
    # (labels previously probed via `wafer targets probe` / provision).
    from dataclasses import replace
    from wafer_core.targets.state_cache import load_all_labels

    cached_labels = load_all_labels()
    all_targets = [
        replace(t, labels=cached_labels[t.resource_id])
        if t.resource_id in cached_labels
        else t
        for t in all_targets
    ]

    # Apply pool filter if specified
    if pool:
        from wafer_core.targets.pool import load_pool_query, match_targets

        try:
            query = load_pool_query(pool)
        except KeyError as e:
            # Unknown pool name: report and exit non-zero.
            typer.echo(str(e), err=True)
            raise typer.Exit(1) from None

        all_targets = match_targets(query, all_targets)
        typer.echo(f"Pool {pool!r}: {len(all_targets)} matching target(s)\n")

    if not all_targets:
        typer.echo("No running resources found.")
        return

    typer.echo(f"{len(all_targets)} resource(s):\n")
    for target in all_targets:
        _print_target(target)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _print_target(target: Target) -> None:
    """Echo a one-line summary of one target, followed by a blank line."""
    segments: list[str] = []

    # Optional fields are appended in a fixed order: spec, name, ssh, price.
    if target.spec_name:
        segments.append(f" spec={target.spec_name}")
    if target.name:
        segments.append(f" name={target.name}")
    if target.public_ip and target.ssh_port:
        segments.append(f" ssh={target.ssh_username}@{target.public_ip}:{target.ssh_port}")
    if target.price_per_hour:
        segments.append(f" ${target.price_per_hour:.2f}/hr")

    # Labels come last, sorted by key; 'image' is omitted (too long).
    shown_keys = sorted(k for k in target.labels if k != "image")
    if shown_keys:
        segments.append(" " + " ".join(f"{k}={target.labels[k]}" for k in shown_keys))

    typer.echo(
        f" {target.resource_id} [{target.provider}] "
        f"status={target.status} gpu={target.gpu_type}" + "".join(segments)
    )
    typer.echo()
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@targets_live_app.command("terminate")
def targets_terminate(
    resource_id: str | None = typer.Argument(None, help="Resource ID to terminate"),
    pool_name: str | None = typer.Option(
        None, "--pool", help="Terminate all targets matching a pool query"
    ),
    provider_name: str | None = typer.Option(
        None, "--provider", "-p", help="Provider hint (avoids querying all providers)"
    ),
    yes: bool = typer.Option(False, "--yes", "-y", help="Skip confirmation"),
) -> None:
    """Terminate a running resource by ID, or all targets matching a pool query.

    Exits non-zero if the resource cannot be found or termination fails.

    Examples:
        wafer targets terminate tkru24z7npcgth
        wafer targets terminate --pool mi300x --yes
        wafer targets terminate --pool runpod-only --provider runpod
    """
    import trio
    from wafer_core.targets.providers import get_all_cloud_providers, get_provider
    from wafer_core.targets.state_cache import remove_binding

    if pool_name:
        _terminate_pool(pool_name, provider_name, yes)
        return

    if not resource_id:
        typer.echo("Provide a resource ID or use --pool <name>.", err=True)
        raise typer.Exit(1)

    async def _terminate() -> bool:
        if provider_name:
            # Provider hint: terminate directly without discovery.
            # Fix: this path previously skipped remove_binding() and the
            # confirmation echo, leaving a stale entry in the local cache.
            prov = get_provider(provider_name)
            success = await prov.terminate(resource_id)
            if success:
                remove_binding(resource_id)
                typer.echo(f"Terminated {resource_id} ({provider_name})")
            return success

        # No hint: probe each provider until one recognizes the resource.
        for name, prov in get_all_cloud_providers():
            target = await prov.get_target(resource_id)
            if target is not None:
                success = await prov.terminate(resource_id)
                if success:
                    # Drop the spec binding so `wafer targets list` stays accurate.
                    remove_binding(resource_id)
                    typer.echo(f"Terminated {resource_id} ({name})")
                return success

        typer.echo(f"Resource {resource_id} not found on any provider.", err=True)
        return False

    success = trio.run(_terminate)
    if not success:
        raise typer.Exit(1)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _terminate_pool(pool_name: str, provider_name: str | None, yes: bool) -> None:
    """Terminate all targets matching a pool query.

    Args:
        pool_name: Pool name from config.toml ([pools.<name>] section).
        provider_name: Optional provider hint; restricts listing to one
            provider instead of querying all of them.
        yes: Skip the interactive confirmation prompt.

    Raises:
        typer.Exit: If the pool name is unknown.
    """
    import trio
    from wafer_core.targets.pool import load_pool_query, match_targets
    from wafer_core.targets.providers import get_all_cloud_providers, get_provider
    from wafer_core.targets.state_cache import remove_binding
    from wafer_core.targets.types import Target

    try:
        query = load_pool_query(pool_name)
    except KeyError as e:
        typer.echo(str(e), err=True)
        raise typer.Exit(1) from None

    async def _do_terminate() -> int:
        all_targets: list[Target] = []
        if provider_name:
            prov = get_provider(provider_name)
            all_targets = await prov.list_targets()
        else:
            for name, prov in get_all_cloud_providers():
                try:
                    all_targets.extend(await prov.list_targets())
                except Exception as e:
                    # Fix: previously swallowed silently; one flaky provider
                    # should not abort the operation, but the failure must be
                    # surfaced (consistent with `wafer targets list`).
                    typer.echo(f" Warning: failed to query {name}: {e}", err=True)

        matched = match_targets(query, all_targets)

        if not matched:
            typer.echo(f"No targets match pool {pool_name!r}.")
            return 0

        # Show what will be terminated before asking for confirmation.
        typer.echo(f"Found {len(matched)} target(s) matching pool {pool_name!r}:")
        for t in matched:
            name_part = f" name={t.name}" if t.name else ""
            typer.echo(f" {t.resource_id} [{t.provider}] gpu={t.gpu_type}{name_part}")

        if not yes:
            confirm = typer.confirm("Terminate all?")
            if not confirm:
                return 0

        count = 0
        for t in matched:
            prov = get_provider(t.provider)
            if await prov.terminate(t.resource_id):
                # Keep the local binding cache consistent with live state.
                remove_binding(t.resource_id)
                typer.echo(f" Terminated {t.resource_id}")
                count += 1
            else:
                typer.echo(f" Failed to terminate {t.resource_id}", err=True)

        return count

    count = trio.run(_do_terminate)
    typer.echo(f"\nTerminated {count} resource(s).")
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
@targets_live_app.command("reconcile")
def targets_reconcile() -> None:
    """Refresh local binding cache from provider APIs.

    Queries all cloud providers, matches resources to specs, and updates
    the local state cache. Reports any drift.

    Example:
        wafer targets reconcile
    """
    # NOTE(review): registered as "reconcile" but the module docstring / app
    # help advertise "sync", and the summary below prints "Sync complete" —
    # confirm which command name is canonical.
    import trio
    from wafer_core.targets.providers import get_all_cloud_providers
    from wafer_core.targets.reconcile import reconcile
    from wafer_core.targets.spec_store import load_all_specs
    from wafer_core.targets.state_cache import (
        BindingEntry,
        get_binding_hints,
        save_bindings,
    )
    from wafer_core.targets.types import Target

    async def _sync() -> None:
        specs = load_all_specs()

        # Gather live resources from every provider; a failing provider is
        # reported but does not abort the sync.
        all_targets: list[Target] = []
        for name, prov in get_all_cloud_providers():
            typer.echo(f"Querying {name}...")
            try:
                targets = await prov.list_targets()
                typer.echo(f" Found {len(targets)} resource(s)")
                all_targets.extend(targets)
            except Exception as e:
                typer.echo(f" Failed: {e}", err=True)

        # Existing bindings are passed as hints so resources keep their spec
        # association across runs where possible.
        hints = get_binding_hints()
        result = reconcile(specs, all_targets, binding_hints=hints)

        # Update binding cache with bound results
        # (presumably save_bindings replaces the cache wholesale, dropping
        # entries for unbound resources — verify against state_cache).
        new_bindings = {}
        now = datetime.now(UTC).isoformat()
        for spec, target in result.bound:
            new_bindings[target.resource_id] = BindingEntry(
                spec_name=spec.name,
                provider=target.provider,
                bound_at=now,
            )
        save_bindings(new_bindings)

        typer.echo("\nSync complete:")
        typer.echo(f" Total resources: {len(all_targets)}")
        typer.echo(f" Matched to specs: {len(result.bound)}")
        typer.echo(f" No matching spec: {len(result.unbound)}")

    trio.run(_sync)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
@targets_live_app.command("provision")
def targets_provision(
    spec_name: str = typer.Argument(..., help="Spec name to provision from"),
) -> None:
    """Explicitly provision a resource from a spec.

    Creates a new cloud resource, binds it to the spec in the local state
    cache, and best-effort probes its software versions over SSH.

    Raises:
        typer.Exit: If the spec does not exist or its type has no
            provisioning provider.

    Example:
        wafer targets provision runpod-mi300x
    """
    import trio
    from wafer_core.targets.providers import get_provider
    from wafer_core.targets.spec_store import load_spec
    from wafer_core.targets.state_cache import BindingEntry, add_binding
    from wafer_core.utils.kernel_utils.targets.config import (
        DigitalOceanTarget,
        RunPodTarget,
    )

    try:
        spec = load_spec(spec_name)
    except FileNotFoundError:
        typer.echo(f"Spec not found: {spec_name}", err=True)
        raise typer.Exit(1) from None

    # Map the spec type to the provider that can provision it.
    if isinstance(spec, RunPodTarget):
        provider_name = "runpod"
    elif isinstance(spec, DigitalOceanTarget):
        provider_name = "digitalocean"
    else:
        typer.echo(f"Spec type {type(spec).__name__} cannot be provisioned.", err=True)
        raise typer.Exit(1) from None

    async def _provision() -> None:
        from wafer_core.targets.probe import probe_target_labels
        from wafer_core.targets.state_cache import save_labels

        prov = get_provider(provider_name)
        typer.echo(f"Provisioning {spec_name} via {provider_name}...")
        target = await prov.provision(spec)

        # Cache the binding so `wafer targets list` shows the spec name.
        add_binding(
            target.resource_id,
            BindingEntry(
                spec_name=spec_name,
                provider=provider_name,
                bound_at=datetime.now(UTC).isoformat(),
            ),
        )

        typer.echo(f"\nProvisioned: {target.resource_id}")
        if target.public_ip:
            typer.echo(f" SSH: {target.ssh_username}@{target.public_ip}:{target.ssh_port}")

        # Probe software labels (sync — runs subprocess ssh)
        if target.public_ip and target.ssh_port:
            typer.echo(" Probing software versions...")
            try:
                # Idiom fix: getattr with default replaces the hasattr/else
                # dance. Not all spec types carry an ssh_key.
                ssh_key = getattr(spec, "ssh_key", None)
                labels = probe_target_labels(
                    host=target.public_ip,
                    port=target.ssh_port,
                    username=target.ssh_username,
                    ssh_key_path=ssh_key,
                )
                save_labels(target.resource_id, labels)
                if labels:
                    typer.echo(f" Labels: {' '.join(f'{k}={v}' for k, v in sorted(labels.items()))}")
            except Exception as e:
                # Best-effort: a failed probe should not fail provisioning.
                typer.echo(f" Warning: probe failed: {e}", err=True)

    trio.run(_provision)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
@targets_live_app.command("pools")
def targets_pools() -> None:
    """List configured pool queries from config.toml.

    Example:
        wafer targets pools
    """
    from wafer_core.targets.pool import list_pool_names, load_pool_query

    pool_names = list_pool_names()
    if not pool_names:
        # No pools configured yet: print a how-to snippet for config.toml.
        typer.echo("No pools configured in ~/.wafer/config.toml.")
        typer.echo("\nAdd a pool:\n")
        for snippet_line in (
            " [pools.mi300x]",
            ' gpu_type = "MI300X"',
            "",
            " [pools.mi300x-rocm7]",
            ' gpu_type = "MI300X"',
            " [pools.mi300x-rocm7.labels]",
            ' rocm_version = "7.0.2"',
        ):
            typer.echo(snippet_line)
        return

    typer.echo(f"{len(pool_names)} pool(s):\n")
    for name in pool_names:
        query = load_pool_query(name)

        # Assemble the human-readable match criteria in a stable order:
        # gpu_type, provider, non-default status, then sorted labels.
        bits: list[str] = []
        if query.gpu_type:
            bits.append(f"gpu_type={query.gpu_type}")
        if query.provider:
            bits.append(f"provider={query.provider}")
        if query.status and query.status != "running":
            bits.append(f"status={query.status}")
        bits.extend(f"{k}={v}" for k, v in sorted(query.labels.items()))

        criteria = " ".join(bits) if bits else "(match all)"
        typer.echo(f" {name}: {criteria}")
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
@targets_live_app.command("probe")
def targets_probe(
    resource_id: str = typer.Argument(..., help="Resource ID to probe"),
    provider_name: str | None = typer.Option(
        None, "--provider", "-p", help="Provider hint (avoids querying all providers)"
    ),
) -> None:
    """Probe a running target's software versions via SSH.

    Results are cached in ~/.wafer/target_state.json and shown
    by wafer targets list. Used for targets not provisioned by wafer
    (e.g. dashboard-created pods).

    Examples:
        wafer targets probe ewfo5ckpxlg7y2
        wafer targets probe 543538453 --provider digitalocean
    """
    import trio
    from wafer_core.targets.probe import probe_target_labels
    from wafer_core.targets.providers import get_all_cloud_providers, get_provider
    from wafer_core.targets.state_cache import save_labels

    # Find the target (async — needs provider API)
    async def _find_target():
        if provider_name:
            prov = get_provider(provider_name)
            return await prov.get_target(resource_id)

        # No hint: ask each provider in turn until one recognizes the ID.
        for _, prov in get_all_cloud_providers():
            target = await prov.get_target(resource_id)
            if target is not None:
                return target
        return None

    target = trio.run(_find_target)

    if target is None:
        typer.echo(f"Resource {resource_id} not found.", err=True)
        raise typer.Exit(1)

    if not target.public_ip or not target.ssh_port:
        # Cannot SSH to a resource that is not up or lacks a public IP.
        typer.echo(f"Resource {resource_id} has no SSH info (status={target.status}).", err=True)
        raise typer.Exit(1)

    typer.echo(f"Probing {resource_id} ({target.ssh_username}@{target.public_ip}:{target.ssh_port})...")

    # NOTE(review): unlike the provision flow, no ssh_key_path is passed here —
    # presumably the prober falls back to a default key; confirm this works for
    # targets whose spec requires a specific key.
    labels = probe_target_labels(
        host=target.public_ip,
        port=target.ssh_port,
        username=target.ssh_username,
    )

    save_labels(resource_id, labels)

    if labels:
        typer.echo(f"Labels cached for {resource_id}:")
        for k, v in sorted(labels.items()):
            typer.echo(f" {k}={v}")
    else:
        typer.echo("Probe returned no labels.")
|
wafer/targets_ops.py
CHANGED
|
@@ -15,6 +15,7 @@ import logging
|
|
|
15
15
|
import subprocess
|
|
16
16
|
from collections.abc import Callable
|
|
17
17
|
from dataclasses import dataclass, replace
|
|
18
|
+
from datetime import UTC
|
|
18
19
|
from pathlib import Path
|
|
19
20
|
from typing import TYPE_CHECKING
|
|
20
21
|
|
|
@@ -30,6 +31,26 @@ if TYPE_CHECKING:
|
|
|
30
31
|
logger = logging.getLogger(__name__)
|
|
31
32
|
|
|
32
33
|
|
|
34
|
+
def _update_binding_cache(resource_id: str, spec_name: str, provider: str) -> None:
    """Record a resource→spec binding in the unified target state cache.

    Bridges the old per-provider state files with the new unified cache
    so that `wafer targets list` can see resources provisioned via the
    legacy flow.
    """
    from datetime import datetime

    from wafer_core.targets.state_cache import BindingEntry, add_binding

    # Timestamp the binding in UTC so entries compare consistently.
    entry = BindingEntry(
        spec_name=spec_name,
        provider=provider,
        bound_at=datetime.now(UTC).isoformat(),
    )
    add_binding(resource_id, entry)
|
|
52
|
+
|
|
53
|
+
|
|
33
54
|
@dataclass(frozen=True)
|
|
34
55
|
class TargetSSHInfo:
|
|
35
56
|
"""SSH connection info for a target."""
|
|
@@ -135,7 +156,8 @@ async def _get_runpod_ssh_info(target: RunPodTarget) -> TargetSSHInfo:
|
|
|
135
156
|
# Check if pod already exists and is running
|
|
136
157
|
existing = get_pod_state(target.name)
|
|
137
158
|
if existing and await check_pod_running(existing.pod_id):
|
|
138
|
-
# Reuse existing pod
|
|
159
|
+
# Reuse existing pod — also update the new state cache
|
|
160
|
+
_update_binding_cache(existing.pod_id, target.name, "runpod")
|
|
139
161
|
return TargetSSHInfo(
|
|
140
162
|
host=existing.public_ip,
|
|
141
163
|
port=existing.ssh_port,
|
|
@@ -151,6 +173,8 @@ async def _get_runpod_ssh_info(target: RunPodTarget) -> TargetSSHInfo:
|
|
|
151
173
|
target_keep_alive = replace(target, keep_alive=True)
|
|
152
174
|
|
|
153
175
|
async with runpod_ssh_context(target_keep_alive) as ssh_info:
|
|
176
|
+
# Update new state cache with provisioned pod
|
|
177
|
+
_update_binding_cache(ssh_info.pod_id, target.name, "runpod")
|
|
154
178
|
return TargetSSHInfo(
|
|
155
179
|
host=ssh_info.host,
|
|
156
180
|
port=ssh_info.port,
|
|
@@ -172,7 +196,8 @@ async def _get_digitalocean_ssh_info(target: DigitalOceanTarget) -> TargetSSHInf
|
|
|
172
196
|
# Check if droplet already exists and is running
|
|
173
197
|
existing = get_droplet_state(target.name)
|
|
174
198
|
if existing and await check_droplet_running(existing.droplet_id):
|
|
175
|
-
# Reuse existing droplet
|
|
199
|
+
# Reuse existing droplet — also update the new state cache
|
|
200
|
+
_update_binding_cache(existing.droplet_id, target.name, "digitalocean")
|
|
176
201
|
return TargetSSHInfo(
|
|
177
202
|
host=existing.public_ip,
|
|
178
203
|
port=22, # DigitalOcean uses standard SSH port
|
|
@@ -184,6 +209,8 @@ async def _get_digitalocean_ssh_info(target: DigitalOceanTarget) -> TargetSSHInf
|
|
|
184
209
|
target_keep_alive = replace(target, keep_alive=True)
|
|
185
210
|
|
|
186
211
|
async with digitalocean_ssh_context(target_keep_alive) as ssh_info:
|
|
212
|
+
# Update new state cache with provisioned droplet
|
|
213
|
+
_update_binding_cache(ssh_info.droplet_id, target.name, "digitalocean")
|
|
187
214
|
return TargetSSHInfo(
|
|
188
215
|
host=ssh_info.host,
|
|
189
216
|
port=ssh_info.port,
|
wafer/templates/ask_docs.py
CHANGED
|
@@ -51,7 +51,7 @@ Output your answer directly. Be concise but thorough. Include code examples when
|
|
|
51
51
|
"python -c",
|
|
52
52
|
],
|
|
53
53
|
# Model config
|
|
54
|
-
model="anthropic/claude-
|
|
54
|
+
model="anthropic/claude-opus-4-5-20251101",
|
|
55
55
|
max_tokens=8192,
|
|
56
56
|
# Thinking config - disabled for simple doc queries
|
|
57
57
|
thinking=False,
|
|
@@ -56,7 +56,7 @@ IMPORTANT: Always verify correctness with wafer evaluate before claiming success
|
|
|
56
56
|
"python -c",
|
|
57
57
|
],
|
|
58
58
|
# Model config - use thinking for complex optimization reasoning
|
|
59
|
-
model="anthropic/claude-
|
|
59
|
+
model="anthropic/claude-opus-4-5-20251101",
|
|
60
60
|
max_tokens=16384,
|
|
61
61
|
# Thinking config - enabled for complex kernel optimization
|
|
62
62
|
thinking=True,
|
|
@@ -68,4 +68,6 @@ IMPORTANT: Always verify correctness with wafer evaluate before claiming success
|
|
|
68
68
|
"kernel": "./kernel.cu",
|
|
69
69
|
"target": "H100",
|
|
70
70
|
},
|
|
71
|
+
# Enable skill discovery (agent can load wafer-guide, etc.)
|
|
72
|
+
include_skills=True,
|
|
71
73
|
)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Template for KernelBench optimization
|
|
1
|
+
"""Template for KernelBench optimization.
|
|
2
2
|
|
|
3
3
|
Usage:
|
|
4
4
|
# Run on a specific problem
|
|
@@ -26,12 +26,18 @@ try:
|
|
|
26
26
|
except ImportError:
|
|
27
27
|
from rollouts.templates import TemplateConfig
|
|
28
28
|
|
|
29
|
-
|
|
29
|
+
from wafer.agent_defaults import ENABLED_TOOLS, KERNELBENCH_BASH_ALLOWLIST
|
|
30
|
+
|
|
31
|
+
# Task-specific instructions only — must stay in sync with the eval's SYSTEM_PROMPT
|
|
32
|
+
# in research/evals/optimize_kernelbench_eval/.../base_config.py.
|
|
33
|
+
# Run test_eval_cli_parity.py to verify.
|
|
34
|
+
# Wafer CLI command docs are auto-generated from --help text and composed
|
|
35
|
+
# at runtime by wevin_cli.py (see wafer.cli_instructions.build_cli_instructions).
|
|
36
|
+
# TODO: Consider having both eval and template import SYSTEM_PROMPT from a shared
|
|
37
|
+
# module so there's only one copy to maintain.
|
|
30
38
|
SYSTEM_PROMPT = """\
|
|
31
39
|
You are a GPU kernel optimization expert. Your task is to write optimized GPU kernels that are correct and faster than the PyTorch baseline.
|
|
32
40
|
|
|
33
|
-
IMPORTANT: You do NOT have a local GPU. You MUST use `wafer evaluate kernelbench` to test kernels on remote GPU hardware.
|
|
34
|
-
|
|
35
41
|
## Kernel Format (KernelBench)
|
|
36
42
|
|
|
37
43
|
The reference file contains a PyTorch `Model` class. You must write a `ModelNew` class that:
|
|
@@ -43,49 +49,14 @@ The reference file also provides:
|
|
|
43
49
|
- `get_inputs()` - generates test inputs for forward()
|
|
44
50
|
- `get_init_inputs()` - generates constructor arguments
|
|
45
51
|
|
|
46
|
-
## Available Tools
|
|
47
|
-
|
|
48
|
-
- read(file_path): Read source files
|
|
49
|
-
- write(file_path, content): Write your optimized kernel
|
|
50
|
-
- glob(pattern): Find files by pattern
|
|
51
|
-
- grep(pattern): Search code
|
|
52
|
-
- bash(command): Run shell commands including wafer CLI
|
|
53
|
-
|
|
54
52
|
## Workflow
|
|
55
53
|
|
|
56
54
|
1. Read the reference problem file to understand what `Model` does
|
|
57
55
|
2. Analyze the computation and identify optimization opportunities
|
|
58
56
|
3. Write an optimized `ModelNew` class with custom $backend_upper kernels using `__global__` kernel definitions and `torch.utils.cpp_extension.load_inline`
|
|
59
|
-
4. Test with: `wafer evaluate kernelbench $target_flag --backend $backend --impl
|
|
57
|
+
4. Test with: `wafer evaluate kernelbench $target_flag --backend $backend --impl optimized.py --reference <problem.py> --benchmark`
|
|
60
58
|
5. Iterate based on feedback until correct and fast
|
|
61
59
|
|
|
62
|
-
## Example Command
|
|
63
|
-
|
|
64
|
-
```bash
|
|
65
|
-
wafer evaluate kernelbench \\
|
|
66
|
-
$target_flag \\
|
|
67
|
-
--backend $backend \\
|
|
68
|
-
--impl optimized_kernel.py \\
|
|
69
|
-
--reference $reference \\
|
|
70
|
-
--benchmark
|
|
71
|
-
```
|
|
72
|
-
|
|
73
|
-
## Profiling Tools (USE THESE!)
|
|
74
|
-
|
|
75
|
-
When your kernel is slower than expected, use profiling to understand WHY:
|
|
76
|
-
|
|
77
|
-
- `wafer rocprof profile --impl <file> --reference <ref>` - AMD GPU profiling
|
|
78
|
-
- `wafer nvidia ncu --impl <file> --reference <ref>` - NVIDIA NCU profiling
|
|
79
|
-
|
|
80
|
-
## CRITICAL: Reactive Debugging
|
|
81
|
-
|
|
82
|
-
After EVERY `wafer evaluate` call:
|
|
83
|
-
1. Check the speedup result
|
|
84
|
-
2. If speedup < 1.0x (slowdown), STOP and analyze:
|
|
85
|
-
- Run profiling to identify the bottleneck
|
|
86
|
-
- Ask: "Why is this slow?" before trying another approach
|
|
87
|
-
3. Don't just try random optimizations - understand the root cause
|
|
88
|
-
|
|
89
60
|
Your kernel MUST:
|
|
90
61
|
- Pass correctness tests (outputs match reference within tolerance)
|
|
91
62
|
- Achieve speedup > 1.0x over PyTorch baseline
|
|
@@ -96,32 +67,16 @@ You MUST run `wafer evaluate kernelbench` to verify your kernel. Your score depe
|
|
|
96
67
|
template = TemplateConfig(
|
|
97
68
|
# Identity
|
|
98
69
|
name="optimize-kernelbench",
|
|
99
|
-
description="Optimize KernelBench problems
|
|
100
|
-
# System prompt
|
|
70
|
+
description="Optimize KernelBench problems",
|
|
71
|
+
# System prompt (task-specific; CLI docs appended at runtime)
|
|
101
72
|
system_prompt=SYSTEM_PROMPT,
|
|
102
73
|
# Tools
|
|
103
|
-
tools=
|
|
104
|
-
bash_allowlist=
|
|
105
|
-
|
|
106
|
-
"wafer nvidia ncu",
|
|
107
|
-
"wafer nvidia nsys",
|
|
108
|
-
"wafer rocprof",
|
|
109
|
-
"wafer compiler-analyze",
|
|
110
|
-
"python",
|
|
111
|
-
"python3",
|
|
112
|
-
"timeout",
|
|
113
|
-
"ls",
|
|
114
|
-
"cat",
|
|
115
|
-
"head",
|
|
116
|
-
"tail",
|
|
117
|
-
"wc",
|
|
118
|
-
"pwd",
|
|
119
|
-
"which",
|
|
120
|
-
],
|
|
121
|
-
# Model config - match eval settings
|
|
74
|
+
tools=ENABLED_TOOLS,
|
|
75
|
+
bash_allowlist=KERNELBENCH_BASH_ALLOWLIST,
|
|
76
|
+
# Model config
|
|
122
77
|
model="anthropic/claude-opus-4-5-20251101",
|
|
123
78
|
max_tokens=8192,
|
|
124
|
-
# No thinking by default
|
|
79
|
+
# No thinking by default, can override with --thinking
|
|
125
80
|
thinking=False,
|
|
126
81
|
# Multi-turn for iterative optimization
|
|
127
82
|
single_turn=False,
|
wafer/templates/trace_analyze.py
CHANGED
|
@@ -60,7 +60,7 @@ Use `--json` flags when available for structured output that's easier to parse.
|
|
|
60
60
|
"python -c",
|
|
61
61
|
],
|
|
62
62
|
# Model config
|
|
63
|
-
model="anthropic/claude-
|
|
63
|
+
model="anthropic/claude-opus-4-5-20251101",
|
|
64
64
|
max_tokens=8192,
|
|
65
65
|
# Thinking config - disabled for trace analysis (mostly parsing)
|
|
66
66
|
thinking=False,
|