wafer-cli 0.2.14__py3-none-any.whl → 0.2.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/cli_instructions.py ADDED
@@ -0,0 +1,143 @@
+ """Generate agent system prompt instructions from the wafer CLI's own --help text.
+
+ Walks the typer/click command tree and extracts help text for commands
+ matching the bash_allowlist. This ensures agent instructions stay in sync
+ with the CLI — the --help text is the single source of truth for both
+ human users and AI agents.
+
+ Usage:
+     from wafer.cli_instructions import build_cli_instructions
+
+     instructions = build_cli_instructions([
+         "wafer evaluate",
+         "wafer nvidia ncu",
+         "wafer rocprof profile",
+         "python",  # non-wafer commands are skipped
+     ])
+ """
+
+ from __future__ import annotations
+
+ import click
+ import typer.main
+
+
+ def _resolve_command(root: click.BaseCommand, parts: list[str]) -> click.BaseCommand | None:
+     """Walk the click command tree to find a (sub)command by name parts.
+
+     Args:
+         root: The root click command (from typer.main.get_command)
+         parts: Command path segments, e.g. ["evaluate", "kernelbench"]
+
+     Returns:
+         The click command at that path, or None if not found.
+     """
+     cmd = root
+     for part in parts:
+         if not isinstance(cmd, click.MultiCommand):
+             return None
+         ctx = click.Context(cmd, info_name=part)
+         child = cmd.get_command(ctx, part)
+         if child is None:
+             return None
+         cmd = child
+     return cmd
+
+
+ def _format_command_help(cmd_path: str, cmd: click.BaseCommand) -> str:
+     """Format a single command's help text for inclusion in a system prompt.
+
+     Extracts the description and option help text (skipping --help itself).
+     """
+     lines = [f"### `{cmd_path}`"]
+
+     if cmd.help:
+         lines.append(cmd.help.strip())
+
+     # Extract option help
+     option_lines = []
+     for param in getattr(cmd, "params", []):
+         if not isinstance(param, click.Option):
+             continue
+         # Skip --help
+         if param.name == "help":
+             continue
+         name = "/".join(param.opts)
+         type_name = param.type.name.upper() if hasattr(param.type, "name") else ""
+         help_text = param.help or ""
+         is_flag = type_name in ("BOOL", "BOOLEAN") or param.is_flag
+         if type_name and not is_flag:
+             option_lines.append(f" {name} {type_name} {help_text}")
+         else:
+             option_lines.append(f" {name} {help_text}")
+
+     if option_lines:
+         lines.append("")
+         lines.append("Options:")
+         lines.extend(option_lines)
+
+     # List subcommands if this is a group
+     if isinstance(cmd, click.MultiCommand):
+         ctx = click.Context(cmd, info_name=cmd_path.split()[-1])
+         subcmd_names = cmd.list_commands(ctx)
+         if subcmd_names:
+             subcmd_lines = []
+             for name in subcmd_names:
+                 subcmd = cmd.get_command(ctx, name)
+                 if subcmd:
+                     desc = (subcmd.help or subcmd.short_help or "").strip().split("\n")[0]
+                     subcmd_lines.append(f" {cmd_path} {name} {desc}")
+             if subcmd_lines:
+                 lines.append("")
+                 lines.append("Subcommands:")
+                 lines.extend(subcmd_lines)
+
+     return "\n".join(lines)
+
+
+ def build_cli_instructions(bash_allowlist: list[str]) -> str:
+     """Generate CLI instruction text from --help for allowed wafer commands.
+
+     Walks the typer/click command tree and extracts help text for each
+     wafer command in the bash_allowlist. Non-wafer commands (python, ls, etc.)
+     are skipped.
+
+     Args:
+         bash_allowlist: List of allowed bash command prefixes.
+             Example: ["wafer evaluate", "wafer nvidia ncu", "python"]
+
+     Returns:
+         Markdown-formatted CLI instructions, or empty string if no wafer
+         commands are in the allowlist.
+     """
+     if not bash_allowlist:
+         return ""
+
+     # Filter to wafer commands only
+     wafer_commands = [cmd for cmd in bash_allowlist if cmd.startswith("wafer ")]
+     if not wafer_commands:
+         return ""
+
+     # Lazy import to avoid circular deps at module level
+     from wafer.cli import app
+
+     root = typer.main.get_command(app)
+
+     sections = []
+     for cmd_str in wafer_commands:
+         # "wafer evaluate kernelbench" -> ["evaluate", "kernelbench"]
+         parts = cmd_str.split()[1:]  # drop "wafer" prefix
+         cmd = _resolve_command(root, parts)
+         if cmd is None:
+             # Command not found in tree — skip silently
+             continue
+         sections.append(_format_command_help(cmd_str, cmd))
+
+     if not sections:
+         return ""
+
+     header = (
+         "## Wafer CLI Commands\n\n"
+         "You do not have a local GPU. Use the wafer CLI to run on remote GPU hardware.\n"
+     )
+     return header + "\n\n".join(sections)
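To make the behaviour of the new module concrete, here is a minimal sketch of how an agent harness might call it. The allowlist values and the surrounding prompt handling are illustrative assumptions; only `build_cli_instructions` comes from the file above.

```python
# Illustrative sketch, not shipped in this package.
from wafer.cli_instructions import build_cli_instructions

# Hypothetical agent configuration: only these bash prefixes are allowed.
allowlist = ["wafer evaluate", "wafer nvidia ncu", "python"]

cli_docs = build_cli_instructions(allowlist)
if cli_docs:
    # Returns a "## Wafer CLI Commands" header followed by one "### `wafer ...`"
    # section per allowlisted wafer command; non-wafer entries such as "python"
    # are ignored, and unknown wafer subcommands are skipped silently.
    print(cli_docs)
else:
    print("No wafer commands in the allowlist; nothing to add to the prompt.")
```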
wafer/corpus.py CHANGED
@@ -3,10 +3,12 @@
  Download and manage documentation corpora for agent filesystem access.
  """

+ import re
  import shutil
  import tarfile
  import tempfile
  from dataclasses import dataclass
+ from html.parser import HTMLParser
  from pathlib import Path
  from typing import Literal
  from urllib.parse import urlparse
@@ -33,7 +35,7 @@ class CorpusConfig:

      name: CorpusName
      description: str
-     source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
+     source_type: Literal["nvidia_md", "github_repo", "github_multi_repo", "mixed"]
      urls: list[str] | None = None
      repo: str | None = None
      repo_paths: list[str] | None = None
@@ -67,21 +69,74 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
      ),
      "cutlass": CorpusConfig(
          name="cutlass",
-         description="CUTLASS and CuTe DSL documentation",
-         source_type="github_repo",
-         repo="NVIDIA/cutlass",
-         repo_paths=["media/docs", "python/cutlass/docs"],
+         description="CUTLASS C++ documentation, examples, and tutorials",
+         source_type="mixed",
+         # Official NVIDIA CUTLASS documentation (scraped as markdown)
+         urls=[
+             "https://docs.nvidia.com/cutlass/latest/overview.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/heuristics.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/pipeline.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/01_layout.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/02_layout_algebra.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/03_tensor.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/04_algorithms.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0t_mma_atom.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0x_gemm_tutorial.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0y_predication.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0z_tma_tensors.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_backwards_compatibility.html",
+             "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html",
+         ],
+         # NVIDIA/cutlass GitHub examples (excluding python/)
+         repos=[
+             RepoSource(
+                 repo="NVIDIA/cutlass",
+                 paths=["examples"],
+                 branch="main",
+             ),
+         ],
      ),
      "hip": CorpusConfig(
          name="hip",
-         description="HIP programming guide and API reference",
-         source_type="github_repo",
-         repo="ROCm/HIP",
-         repo_paths=["docs"],
+         description="HIP programming guide, API reference, and examples",
+         source_type="github_multi_repo",
+         repos=[
+             # HIP - main documentation and API
+             RepoSource(
+                 repo="ROCm/HIP",
+                 paths=["docs"],
+             ),
+             # HIP examples - code samples
+             RepoSource(
+                 repo="ROCm/HIP-Examples",
+                 paths=["HIP-Examples-Applications", "mini-nbody"],
+             ),
+             # clr - HIP/OpenCL runtime (low-level)
+             RepoSource(
+                 repo="ROCm/clr",
+                 paths=["hipamd/include", "rocclr/device/gpu"],
+             ),
+             # ROCm docs - official documentation
+             RepoSource(
+                 repo="ROCm/ROCm",
+                 paths=["docs"],
+             ),
+         ],
      ),
      "amd": CorpusConfig(
          name="amd",
-         description="AMD GPU kernel development (rocWMMA, CK, AITER, rocBLAS, HipKittens, vLLM)",
+         description="AMD GPU kernel development (rocWMMA, CK, AITER, rocBLAS, HipKittens, vLLM, FlashAttention)",
          source_type="github_multi_repo",
          repos=[
              # rocWMMA - wave matrix multiply-accumulate (WMMA) intrinsics
@@ -125,11 +180,17 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
                  paths=["docs"],
                  branch="develop_deprecated",
              ),
-             # HipKittens - high-performance AMD kernels
+             # HipKittens - high-performance AMD kernels (main branch: MI350X/CDNA4+)
              RepoSource(
                  repo="HazyResearch/HipKittens",
                  paths=["docs", "kernels", "include"],
              ),
+             # HipKittens cdna3 branch - MI300X/MI325X (gfx942)
+             RepoSource(
+                 repo="HazyResearch/HipKittens",
+                 paths=["kernels", "include", "tests"],
+                 branch="cdna3",
+             ),
              # vLLM AMD kernels
              RepoSource(
                  repo="vllm-project/vllm",
@@ -145,6 +206,46 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
                  repo="huggingface/hf-rocm-kernels",
                  paths=["csrc", "hf_rocm_kernels", "docs"],
              ),
+             # ROCm/flash-attention - FlashAttention for AMD GPUs
+             RepoSource(
+                 repo="ROCm/flash-attention",
+                 paths=["csrc", "docs"],
+             ),
+             # ROCm/triton - Triton compiler for AMD GPUs
+             RepoSource(
+                 repo="ROCm/triton",
+                 paths=["python/tutorials", "third_party/amd"],
+             ),
+             # ROCm/rccl - ROCm Communication Collectives Library (multi-GPU)
+             RepoSource(
+                 repo="ROCm/rccl",
+                 paths=["docs"],
+             ),
+             # ROCm/rocprofiler-sdk - AMD GPU profiling SDK
+             RepoSource(
+                 repo="ROCm/rocprofiler-sdk",
+                 paths=["docs", "samples"],
+             ),
+             # ROCm/omniperf - AMD GPU profiling tool
+             RepoSource(
+                 repo="ROCm/omniperf",
+                 paths=["docs", "src/omniperf_analyze"],
+             ),
+             # ROCm/omnitrace - Application tracing for AMD
+             RepoSource(
+                 repo="ROCm/omnitrace",
+                 paths=["docs"],
+             ),
+             # AMD GPUOpen Performance Guides
+             RepoSource(
+                 repo="GPUOpen-Tools/gpu_performance_api",
+                 paths=["docs"],
+             ),
+             # AMD LLVM - AMD GPU compiler backend
+             RepoSource(
+                 repo="ROCm/llvm-project",
+                 paths=["amd/device-libs/README.md", "llvm/docs/AMDGPUUsage.rst"],
+             ),
          ],
      ),
  }
@@ -169,19 +270,195 @@ def _url_to_filepath(url: str, base_dir: Path) -> Path:
      return base_dir / "/".join(path_parts)


+ class _HTMLToMarkdown(HTMLParser):
+     """HTML to Markdown converter for NVIDIA documentation pages.
+
+     Uses stdlib HTMLParser - requires subclassing due to callback-based API.
+     The public interface is the functional `_html_to_markdown()` below.
+     """
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.output: list[str] = []
+         self.current_tag: str = ""
+         self.in_code_block = False
+         self.in_pre = False
+         self.list_depth = 0
+         self.ordered_list_counters: list[int] = []
+         self.skip_content = False
+         self.link_href: str | None = None
+
+     def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
+         self.current_tag = tag
+         attrs_dict = dict(attrs)
+
+         # Skip script, style, nav, footer, header
+         if tag in ("script", "style", "nav", "footer", "header", "aside"):
+             self.skip_content = True
+             return
+
+         if tag == "h1":
+             self.output.append("\n# ")
+         elif tag == "h2":
+             self.output.append("\n## ")
+         elif tag == "h3":
+             self.output.append("\n### ")
+         elif tag == "h4":
+             self.output.append("\n#### ")
+         elif tag == "h5":
+             self.output.append("\n##### ")
+         elif tag == "h6":
+             self.output.append("\n###### ")
+         elif tag == "p":
+             self.output.append("\n\n")
+         elif tag == "br":
+             self.output.append("\n")
+         elif tag == "strong" or tag == "b":
+             self.output.append("**")
+         elif tag == "em" or tag == "i":
+             self.output.append("*")
+         elif tag == "code" and not self.in_pre:
+             self.output.append("`")
+             self.in_code_block = True
+         elif tag == "pre":
+             self.in_pre = True
+             # Check for language hint in class
+             lang = ""
+             if class_attr := attrs_dict.get("class"):
+                 if "python" in class_attr.lower():
+                     lang = "python"
+                 elif "cpp" in class_attr.lower() or "c++" in class_attr.lower():
+                     lang = "cpp"
+                 elif "cuda" in class_attr.lower():
+                     lang = "cuda"
+             self.output.append(f"\n```{lang}\n")
+         elif tag == "ul":
+             self.list_depth += 1
+             self.output.append("\n")
+         elif tag == "ol":
+             self.list_depth += 1
+             self.ordered_list_counters.append(1)
+             self.output.append("\n")
+         elif tag == "li":
+             indent = " " * (self.list_depth - 1)
+             if self.ordered_list_counters:
+                 num = self.ordered_list_counters[-1]
+                 self.output.append(f"{indent}{num}. ")
+                 self.ordered_list_counters[-1] += 1
+             else:
+                 self.output.append(f"{indent}- ")
+         elif tag == "a":
+             self.link_href = attrs_dict.get("href")
+             self.output.append("[")
+         elif tag == "img":
+             alt = attrs_dict.get("alt", "image")
+             src = attrs_dict.get("src", "")
+             self.output.append(f"![{alt}]({src})")
+         elif tag == "blockquote":
+             self.output.append("\n> ")
+         elif tag == "hr":
+             self.output.append("\n---\n")
+         elif tag == "table":
+             self.output.append("\n")
+         elif tag == "th":
+             self.output.append("| ")
+         elif tag == "td":
+             self.output.append("| ")
+         elif tag == "tr":
+             pass  # Handled in endtag
+
+     def handle_endtag(self, tag: str) -> None:
+         if tag in ("script", "style", "nav", "footer", "header", "aside"):
+             self.skip_content = False
+             return
+
+         if tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
+             self.output.append("\n")
+         elif tag == "strong" or tag == "b":
+             self.output.append("**")
+         elif tag == "em" or tag == "i":
+             self.output.append("*")
+         elif tag == "code" and not self.in_pre:
+             self.output.append("`")
+             self.in_code_block = False
+         elif tag == "pre":
+             self.in_pre = False
+             self.output.append("\n```\n")
+         elif tag == "ul":
+             self.list_depth = max(0, self.list_depth - 1)
+         elif tag == "ol":
+             self.list_depth = max(0, self.list_depth - 1)
+             if self.ordered_list_counters:
+                 self.ordered_list_counters.pop()
+         elif tag == "li":
+             self.output.append("\n")
+         elif tag == "a":
+             if self.link_href:
+                 self.output.append(f"]({self.link_href})")
+             else:
+                 self.output.append("]")
+             self.link_href = None
+         elif tag == "p":
+             self.output.append("\n")
+         elif tag == "blockquote":
+             self.output.append("\n")
+         elif tag == "tr":
+             self.output.append("|\n")
+         elif tag == "thead":
+             # Add markdown table separator after header row
+             self.output.append("|---" * 10 + "|\n")
+
+     def handle_data(self, data: str) -> None:
+         if self.skip_content:
+             return
+         # Preserve whitespace in code blocks
+         if self.in_pre:
+             self.output.append(data)
+         else:
+             # Collapse whitespace outside code
+             text = re.sub(r"\s+", " ", data)
+             if text.strip():
+                 self.output.append(text)
+
+     def get_markdown(self) -> str:
+         """Get the converted markdown, cleaned up."""
+         md = "".join(self.output)
+         # Clean up excessive newlines
+         md = re.sub(r"\n{3,}", "\n\n", md)
+         # Clean up empty table separators
+         md = re.sub(r"\|---\|---.*\|\n(?!\|)", "", md)
+         return md.strip()
+
+
+ def _html_to_markdown(html: str) -> str:
+     """Convert HTML to Markdown."""
+     parser = _HTMLToMarkdown()
+     parser.feed(html)
+     return parser.get_markdown()
+
+
  def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
-     """Download NVIDIA docs using .md endpoint."""
+     """Download NVIDIA docs and convert HTML to Markdown.
+
+     NVIDIA's .md endpoint no longer works, so we scrape HTML and convert to markdown.
+     """
      assert config.urls is not None
      downloaded = 0
      with httpx.Client(timeout=30.0, follow_redirects=True) as client:
          for url in config.urls:
-             md_url = f"{url}.md"
              filepath = _url_to_filepath(url, dest)
              filepath.parent.mkdir(parents=True, exist_ok=True)
              try:
-                 resp = client.get(md_url)
+                 # Fetch HTML page directly
+                 resp = client.get(url)
                  resp.raise_for_status()
-                 filepath.write_text(resp.text)
+
+                 # Convert HTML to Markdown
+                 markdown = _html_to_markdown(resp.text)
+
+                 # Add source URL as header
+                 content = f"<!-- Source: {url} -->\n\n{markdown}"
+                 filepath.write_text(content)
                  downloaded += 1
                  if verbose:
                      print(f" ✓ {filepath.relative_to(dest)}")
@@ -275,6 +552,25 @@ def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool
      return downloaded


+ def _download_mixed(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
+     """Download from mixed sources (NVIDIA docs + GitHub repos)."""
+     total = 0
+
+     # Download NVIDIA markdown docs (urls)
+     if config.urls:
+         if verbose:
+             print(" [NVIDIA docs]")
+         total += _download_nvidia_md(config, dest, verbose)
+
+     # Download GitHub repos
+     if config.repos:
+         if verbose:
+             print(" [GitHub repos]")
+         total += _download_github_multi_repo(config, dest, verbose)
+
+     return total
+
+
  def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True) -> Path:
      """Download a corpus to local cache.

@@ -311,6 +607,8 @@ def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True)
          count = _download_github_repo(config, dest, verbose)
      elif config.source_type == "github_multi_repo":
          count = _download_github_multi_repo(config, dest, verbose)
+     elif config.source_type == "mixed":
+         count = _download_mixed(config, dest, verbose)
      else:
          raise ValueError(f"Unknown source type: {config.source_type}")
      if verbose:
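Putting the corpus changes together, a minimal sketch of pulling the reworked corpora through the public entry point. The corpus names come from the CORPORA table above; the loop itself is illustrative.

```python
# Illustrative only: cache the corpora touched by this diff.
# "cutlass" now takes the mixed path (NVIDIA HTML docs + GitHub examples);
# "hip" and "amd" take the multi-repo path.
from wafer.corpus import download_corpus

for name in ("cutlass", "hip", "amd"):
    dest = download_corpus(name, force=False, verbose=True)
    print(f"{name}: cached at {dest}")
```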