wafer-cli 0.2.20.tar.gz → 0.2.22.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/PKG-INFO +1 -1
  2. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/pyproject.toml +1 -1
  3. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/cli.py +205 -22
  4. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/corpus.py +241 -9
  5. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/evaluate.py +426 -8
  6. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/templates/optimize_kernel.py +2 -0
  7. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/wevin_cli.py +39 -16
  8. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/PKG-INFO +1 -1
  9. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/README.md +0 -0
  10. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/setup.cfg +0 -0
  11. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_analytics.py +0 -0
  12. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_auth.py +0 -0
  13. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_billing.py +0 -0
  14. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_cli_coverage.py +0 -0
  15. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_cli_parity_integration.py +0 -0
  16. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_config_integration.py +0 -0
  17. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_file_operations_integration.py +0 -0
  18. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_kernel_scope_cli.py +0 -0
  19. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_nsys_analyze.py +0 -0
  20. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_nsys_profile.py +0 -0
  21. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_output.py +0 -0
  22. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_rocprof_compute_integration.py +0 -0
  23. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_skill_commands.py +0 -0
  24. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_ssh_integration.py +0 -0
  25. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_targets_ops.py +0 -0
  26. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_wevin_cli.py +0 -0
  27. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_workflow_integration.py +0 -0
  28. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/GUIDE.md +0 -0
  29. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/__init__.py +0 -0
  30. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/analytics.py +0 -0
  31. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/api_client.py +0 -0
  32. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/auth.py +0 -0
  33. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/autotuner.py +0 -0
  34. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/billing.py +0 -0
  35. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/config.py +0 -0
  36. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/global_config.py +0 -0
  37. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/gpu_run.py +0 -0
  38. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/inference.py +0 -0
  39. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/kernel_scope.py +0 -0
  40. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/ncu_analyze.py +0 -0
  41. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/nsys_analyze.py +0 -0
  42. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/nsys_profile.py +0 -0
  43. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/output.py +0 -0
  44. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/problems.py +0 -0
  45. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/rocprof_compute.py +0 -0
  46. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/rocprof_sdk.py +0 -0
  47. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/rocprof_systems.py +0 -0
  48. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/skills/wafer-guide/SKILL.md +0 -0
  49. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/ssh_keys.py +0 -0
  50. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/target_lock.py +0 -0
  51. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/targets.py +0 -0
  52. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/targets_ops.py +0 -0
  53. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/templates/__init__.py +0 -0
  54. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/templates/ask_docs.py +0 -0
  55. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/templates/optimize_kernelbench.py +0 -0
  56. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/templates/trace_analyze.py +0 -0
  57. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/tracelens.py +0 -0
  58. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/workspaces.py +0 -0
  59. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/SOURCES.txt +0 -0
  60. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/dependency_links.txt +0 -0
  61. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/entry_points.txt +0 -0
  62. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/requires.txt +0 -0
  63. {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wafer-cli
3
- Version: 0.2.20
3
+ Version: 0.2.22
4
4
  Summary: CLI tool for running commands on remote GPUs and GPU kernel optimization agent
5
5
  Requires-Python: >=3.11
6
6
  Requires-Dist: typer>=0.12.0
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "wafer-cli"
3
- version = "0.2.20"
3
+ version = "0.2.22"
4
4
  description = "CLI tool for running commands on remote GPUs and GPU kernel optimization agent"
5
5
  requires-python = ">=3.11"
6
6
  dependencies = [
@@ -1,6 +1,8 @@
1
- # ruff: noqa: PLR0913
1
+ # ruff: noqa: PLR0913, E402
2
2
  # PLR0913 (too many arguments) is suppressed because Typer CLI commands
3
3
  # naturally have many parameters - each --flag becomes a function argument.
4
+ # E402 (module level import not at top) is suppressed because we intentionally
5
+ # load .env files before importing other modules that may read env vars.
4
6
  """Wafer CLI - GPU development toolkit for LLM coding agents.
5
7
 
6
8
  Core commands:
@@ -27,6 +29,12 @@ from pathlib import Path
27
29
 
28
30
  import trio
29
31
  import typer
32
+ from dotenv import load_dotenv
33
+
34
+ # Auto-load .env from current directory and ~/.wafer/.env
35
+ # This runs at import time so env vars are available before any config is accessed
36
+ load_dotenv() # cwd/.env
37
+ load_dotenv(Path.home() / ".wafer" / ".env") # ~/.wafer/.env
30
38
 
31
39
  from .config import WaferConfig, WaferEnvironment
32
40
  from .inference import infer_upload_files, resolve_environment
@@ -42,6 +50,7 @@ from .problems import (
42
50
  app = typer.Typer(
43
51
  help="GPU development toolkit for LLM coding agents",
44
52
  no_args_is_help=True,
53
+ pretty_exceptions_show_locals=False, # Don't dump local vars (makes tracebacks huge)
45
54
  )
46
55
 
47
56
  # =============================================================================
@@ -58,11 +67,11 @@ def _show_version() -> None:
58
67
  """Show CLI version and environment, then exit."""
59
68
  from .analytics import _get_cli_version
60
69
  from .global_config import load_global_config
61
-
70
+
62
71
  version = _get_cli_version()
63
72
  config = load_global_config()
64
73
  environment = config.environment
65
-
74
+
66
75
  typer.echo(f"wafer-cli {version} ({environment})")
67
76
  raise typer.Exit()
68
77
 
@@ -110,7 +119,7 @@ def main_callback(
110
119
  if version:
111
120
  _show_version()
112
121
  return
113
-
122
+
114
123
  global _command_start_time, _command_outcome
115
124
  _command_start_time = time.time()
116
125
  _command_outcome = "success" # Default to success, mark failure on exceptions
@@ -121,6 +130,7 @@ def main_callback(
121
130
  analytics.init_analytics()
122
131
 
123
132
  # Install exception hook to catch SystemExit and mark failures
133
+ # Also prints error message FIRST so it's visible even when traceback is truncated
124
134
  original_excepthook = sys.excepthook
125
135
 
126
136
  def custom_excepthook(
@@ -136,7 +146,11 @@ def main_callback(
136
146
  _command_outcome = "failure"
137
147
  else:
138
148
  _command_outcome = "failure"
139
- # Call original excepthook
149
+ # Print error summary FIRST (before traceback) so it's visible even if truncated
150
+ print(
151
+ f"\n\033[1;31m>>> ERROR: {exc_type.__name__}: {exc_value}\033[0m\n", file=sys.stderr
152
+ )
153
+ # Call original excepthook (prints the full traceback)
140
154
  original_excepthook(exc_type, exc_value, exc_traceback)
141
155
 
142
156
  sys.excepthook = custom_excepthook
@@ -591,7 +605,7 @@ app.add_typer(provider_auth_app, name="auth")
591
605
  def provider_auth_login(
592
606
  provider: str = typer.Argument(
593
607
  ...,
594
- help="Provider name: runpod, digitalocean, or modal",
608
+ help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
595
609
  ),
596
610
  api_key: str | None = typer.Option(
597
611
  None,
@@ -600,15 +614,16 @@ def provider_auth_login(
600
614
  help="API key (if not provided, reads from stdin)",
601
615
  ),
602
616
  ) -> None:
603
- """Save API key for a cloud GPU provider.
617
+ """Save API key for a provider.
604
618
 
605
619
  Stores the key in ~/.wafer/auth.json. Environment variables
606
- (e.g., WAFER_RUNPOD_API_KEY) take precedence over stored keys.
620
+ (e.g., ANTHROPIC_API_KEY) take precedence over stored keys.
607
621
 
608
622
  Examples:
623
+ wafer auth login anthropic --api-key sk-ant-xxx
609
624
  wafer auth login runpod --api-key rp_xxx
610
- wafer auth login digitalocean --api-key dop_v1_xxx
611
- echo $API_KEY | wafer auth login runpod
625
+ wafer auth login openai --api-key sk-xxx
626
+ echo $API_KEY | wafer auth login anthropic
612
627
  """
613
628
  import sys
614
629
 
@@ -642,7 +657,7 @@ def provider_auth_login(
642
657
  def provider_auth_logout(
643
658
  provider: str = typer.Argument(
644
659
  ...,
645
- help="Provider name: runpod, digitalocean, or modal",
660
+ help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
646
661
  ),
647
662
  ) -> None:
648
663
  """Remove stored API key for a cloud GPU provider.
@@ -3473,7 +3488,7 @@ def init_runpod(
3473
3488
  gpu_configs = {
3474
3489
  "MI300X": {
3475
3490
  "gpu_type_id": "AMD Instinct MI300X OAM",
3476
- "image": "runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04",
3491
+ "image": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
3477
3492
  "compute_capability": "9.4",
3478
3493
  },
3479
3494
  "H100": {
@@ -3569,7 +3584,7 @@ def init_digitalocean(
3569
3584
  "ssh_key": ssh_key,
3570
3585
  "region": region,
3571
3586
  "size_slug": "gpu-mi300x1-192gb-devcloud",
3572
- "image": "gpu-amd-base",
3587
+ "image": "amd-pytorchrocm7", # PyTorch (ROCm7) marketplace image
3573
3588
  "provision_timeout": 600,
3574
3589
  "eval_timeout": 600,
3575
3590
  "keep_alive": keep_alive,
@@ -4071,6 +4086,164 @@ def targets_cleanup(
4071
4086
  raise typer.Exit(1) from None
4072
4087
 
4073
4088
 
4089
+ # Known libraries that can be installed on targets
4090
+ # TODO: Consider adding HipKittens to the default RunPod/DO Docker images
4091
+ # so this install step isn't needed. For now, this command handles it.
4092
+ INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
4093
+ "hipkittens": {
4094
+ "description": "HipKittens - AMD port of ThunderKittens for MI300X",
4095
+ "git_url": "https://github.com/HazyResearch/hipkittens.git",
4096
+ "install_path": "/opt/hipkittens",
4097
+ "requires_amd": True,
4098
+ },
4099
+ # CK is already installed with ROCm 7.0, no action needed
4100
+ "repair-headers": {
4101
+ "description": "Repair ROCm thrust headers (fixes hipify corruption)",
4102
+ "custom_script": "apt-get update -qq && apt-get install --reinstall -y rocthrust >/dev/null 2>&1 && echo REPAIRED",
4103
+ "requires_amd": True,
4104
+ },
4105
+ }
4106
+
4107
+
4108
+ @targets_app.command("install")
4109
+ def targets_install(
4110
+ name: str = typer.Argument(..., help="Target name"),
4111
+ library: str = typer.Argument(..., help="Library to install (hipkittens, repair-headers)"),
4112
+ ) -> None:
4113
+ """Install a library or run maintenance on a target (idempotent).
4114
+
4115
+ Installs header-only libraries like HipKittens on remote targets.
4116
+ Safe to run multiple times - will skip if already installed.
4117
+
4118
+ Available libraries:
4119
+ hipkittens - HipKittens (AMD ThunderKittens port)
4120
+ repair-headers - Fix ROCm thrust headers (after hipify corruption)
4121
+
4122
+ Examples:
4123
+ wafer config targets install runpod-mi300x hipkittens
4124
+ wafer config targets install runpod-mi300x repair-headers
4125
+ wafer config targets install do-mi300x hipkittens
4126
+ """
4127
+ import subprocess
4128
+
4129
+ from .targets import load_target
4130
+ from .targets_ops import get_target_ssh_info
4131
+
4132
+ if library not in INSTALLABLE_LIBRARIES:
4133
+ available = ", ".join(INSTALLABLE_LIBRARIES.keys())
4134
+ typer.echo(f"Error: Unknown library '{library}'. Available: {available}", err=True)
4135
+ raise typer.Exit(1)
4136
+
4137
+ lib_info = INSTALLABLE_LIBRARIES[library]
4138
+
4139
+ try:
4140
+ target = load_target(name)
4141
+ except FileNotFoundError as e:
4142
+ typer.echo(f"Error: {e}", err=True)
4143
+ raise typer.Exit(1) from None
4144
+
4145
+ # Check if target is AMD (for AMD-only libraries)
4146
+ if lib_info.get("requires_amd"):
4147
+ from wafer_core.utils.kernel_utils.targets.config import (
4148
+ DigitalOceanTarget,
4149
+ RunPodTarget,
4150
+ )
4151
+
4152
+ is_amd = isinstance(target, (RunPodTarget, DigitalOceanTarget))
4153
+ if not is_amd and hasattr(target, "compute_capability"):
4154
+ # Check compute capability for MI300X (gfx942 = 9.4)
4155
+ is_amd = target.compute_capability.startswith("9.")
4156
+ if not is_amd:
4157
+ typer.echo(f"Error: {library} requires an AMD GPU target", err=True)
4158
+ raise typer.Exit(1)
4159
+
4160
+ typer.echo(f"Installing {library} on {name}...")
4161
+ typer.echo(f" {lib_info['description']}")
4162
+
4163
+ async def _install() -> bool:
4164
+ # get_target_ssh_info uses pure trio async (no asyncio bridging needed)
4165
+ # and we use subprocess for SSH, not AsyncSSHClient
4166
+ ssh_info = await get_target_ssh_info(target)
4167
+
4168
+ ssh_cmd = [
4169
+ "ssh",
4170
+ "-o",
4171
+ "StrictHostKeyChecking=no",
4172
+ "-o",
4173
+ "UserKnownHostsFile=/dev/null",
4174
+ "-o",
4175
+ "ConnectTimeout=30",
4176
+ "-i",
4177
+ str(ssh_info.key_path),
4178
+ "-p",
4179
+ str(ssh_info.port),
4180
+ f"{ssh_info.user}@{ssh_info.host}",
4181
+ ]
4182
+
4183
+ # Handle custom scripts (like repair-headers) vs git installs
4184
+ if "custom_script" in lib_info:
4185
+ install_script = str(lib_info["custom_script"])
4186
+ success_marker = "REPAIRED"
4187
+ else:
4188
+ install_path = lib_info["install_path"]
4189
+ git_url = lib_info["git_url"]
4190
+
4191
+ # Idempotent install script
4192
+ install_script = f"""
4193
+ if [ -d "{install_path}" ]; then
4194
+ echo "ALREADY_INSTALLED: {install_path} exists"
4195
+ cd {install_path} && git pull --quiet 2>/dev/null || true
4196
+ else
4197
+ echo "INSTALLING: cloning to {install_path}"
4198
+ git clone --quiet {git_url} {install_path}
4199
+ fi
4200
+ echo "DONE"
4201
+ """
4202
+ success_marker = "DONE"
4203
+
4204
+ def run_ssh() -> subprocess.CompletedProcess[str]:
4205
+ return subprocess.run(
4206
+ ssh_cmd + [install_script],
4207
+ capture_output=True,
4208
+ text=True,
4209
+ timeout=120,
4210
+ )
4211
+
4212
+ result = await trio.to_thread.run_sync(run_ssh)
4213
+
4214
+ if result.returncode != 0:
4215
+ typer.echo(f"Error: {result.stderr}", err=True)
4216
+ return False
4217
+
4218
+ output = result.stdout.strip()
4219
+ if "ALREADY_INSTALLED" in output:
4220
+ typer.echo(f" Already installed at {lib_info.get('install_path', 'N/A')}")
4221
+ elif "INSTALLING" in output:
4222
+ typer.echo(f" Installed to {lib_info.get('install_path', 'N/A')}")
4223
+ elif "REPAIRED" in output:
4224
+ typer.echo(" ROCm headers repaired")
4225
+
4226
+ return success_marker in output
4227
+
4228
+ try:
4229
+ success = trio.run(_install)
4230
+ except Exception as e:
4231
+ typer.echo(f"Error: {e}", err=True)
4232
+ raise typer.Exit(1) from None
4233
+
4234
+ if success:
4235
+ typer.echo(f"✓ {library} ready on {name}")
4236
+
4237
+ # Print usage hint
4238
+ if library == "hipkittens":
4239
+ typer.echo("")
4240
+ typer.echo("Usage in load_inline:")
4241
+ typer.echo(' extra_include_paths=["/opt/hipkittens/include", "/opt/rocm/include/hip"]')
4242
+ else:
4243
+ typer.echo(f"Failed to install {library}", err=True)
4244
+ raise typer.Exit(1)
4245
+
4246
+
4074
4247
  @targets_app.command("pods")
4075
4248
  def targets_pods() -> None:
4076
4249
  """List all running RunPod pods.
@@ -4406,9 +4579,13 @@ def workspaces_list(
4406
4579
  @workspaces_app.command("create")
4407
4580
  def workspaces_create(
4408
4581
  name: str = typer.Argument(..., help="Workspace name"),
4409
- gpu_type: str = typer.Option("B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"),
4582
+ gpu_type: str = typer.Option(
4583
+ "B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"
4584
+ ),
4410
4585
  image: str | None = typer.Option(None, "--image", "-i", help="Docker image (optional)"),
4411
- wait: bool = typer.Option(False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"),
4586
+ wait: bool = typer.Option(
4587
+ False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"
4588
+ ),
4412
4589
  json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
4413
4590
  ) -> None:
4414
4591
  """Create a new workspace.
@@ -4717,19 +4894,25 @@ def workspaces_ssh(
4717
4894
  ssh_host = ws.get("ssh_host")
4718
4895
  ssh_port = ws.get("ssh_port")
4719
4896
  ssh_user = ws.get("ssh_user")
4720
-
4897
+
4721
4898
  if not ssh_host or not ssh_port or not ssh_user:
4722
4899
  typer.echo("Error: Workspace not ready. Wait a few seconds and retry.", err=True)
4723
4900
  raise typer.Exit(1)
4724
4901
 
4725
4902
  # Connect via SSH
4726
- os.execvp("ssh", [
4903
+ os.execvp(
4727
4904
  "ssh",
4728
- "-p", str(ssh_port),
4729
- "-o", "StrictHostKeyChecking=no",
4730
- "-o", "UserKnownHostsFile=/dev/null",
4731
- f"{ssh_user}@{ssh_host}",
4732
- ])
4905
+ [
4906
+ "ssh",
4907
+ "-p",
4908
+ str(ssh_port),
4909
+ "-o",
4910
+ "StrictHostKeyChecking=no",
4911
+ "-o",
4912
+ "UserKnownHostsFile=/dev/null",
4913
+ f"{ssh_user}@{ssh_host}",
4914
+ ],
4915
+ )
4733
4916
 
4734
4917
 
4735
4918
  @workspaces_app.command("sync")
@@ -3,10 +3,12 @@
3
3
  Download and manage documentation corpora for agent filesystem access.
4
4
  """
5
5
 
6
+ import re
6
7
  import shutil
7
8
  import tarfile
8
9
  import tempfile
9
10
  from dataclasses import dataclass
11
+ from html.parser import HTMLParser
10
12
  from pathlib import Path
11
13
  from typing import Literal
12
14
  from urllib.parse import urlparse
@@ -33,7 +35,7 @@ class CorpusConfig:
33
35
 
34
36
  name: CorpusName
35
37
  description: str
36
- source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
38
+ source_type: Literal["nvidia_md", "github_repo", "github_multi_repo", "mixed"]
37
39
  urls: list[str] | None = None
38
40
  repo: str | None = None
39
41
  repo_paths: list[str] | None = None
@@ -67,10 +69,43 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
67
69
  ),
68
70
  "cutlass": CorpusConfig(
69
71
  name="cutlass",
70
- description="CUTLASS and CuTe DSL documentation",
71
- source_type="github_repo",
72
- repo="NVIDIA/cutlass",
73
- repo_paths=["media/docs", "python/cutlass/docs"],
72
+ description="CUTLASS C++ documentation, examples, and tutorials",
73
+ source_type="mixed",
74
+ # Official NVIDIA CUTLASS documentation (scraped as markdown)
75
+ urls=[
76
+ "https://docs.nvidia.com/cutlass/latest/overview.html",
77
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html",
78
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html",
79
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html",
80
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html",
81
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/heuristics.html",
82
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html",
83
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/pipeline.html",
84
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html",
85
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html",
86
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html",
87
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html",
88
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html",
89
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/01_layout.html",
90
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/02_layout_algebra.html",
91
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/03_tensor.html",
92
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/04_algorithms.html",
93
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0t_mma_atom.html",
94
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0x_gemm_tutorial.html",
95
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0y_predication.html",
96
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0z_tma_tensors.html",
97
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html",
98
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_backwards_compatibility.html",
99
+ "https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html",
100
+ ],
101
+ # NVIDIA/cutlass GitHub examples (excluding python/)
102
+ repos=[
103
+ RepoSource(
104
+ repo="NVIDIA/cutlass",
105
+ paths=["examples"],
106
+ branch="main",
107
+ ),
108
+ ],
74
109
  ),
75
110
  "hip": CorpusConfig(
76
111
  name="hip",
@@ -169,19 +204,195 @@ def _url_to_filepath(url: str, base_dir: Path) -> Path:
169
204
  return base_dir / "/".join(path_parts)
170
205
 
171
206
 
207
+ class _HTMLToMarkdown(HTMLParser):
208
+ """HTML to Markdown converter for NVIDIA documentation pages.
209
+
210
+ Uses stdlib HTMLParser - requires subclassing due to callback-based API.
211
+ The public interface is the functional `_html_to_markdown()` below.
212
+ """
213
+
214
+ def __init__(self) -> None:
215
+ super().__init__()
216
+ self.output: list[str] = []
217
+ self.current_tag: str = ""
218
+ self.in_code_block = False
219
+ self.in_pre = False
220
+ self.list_depth = 0
221
+ self.ordered_list_counters: list[int] = []
222
+ self.skip_content = False
223
+ self.link_href: str | None = None
224
+
225
+ def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
226
+ self.current_tag = tag
227
+ attrs_dict = dict(attrs)
228
+
229
+ # Skip script, style, nav, footer, header
230
+ if tag in ("script", "style", "nav", "footer", "header", "aside"):
231
+ self.skip_content = True
232
+ return
233
+
234
+ if tag == "h1":
235
+ self.output.append("\n# ")
236
+ elif tag == "h2":
237
+ self.output.append("\n## ")
238
+ elif tag == "h3":
239
+ self.output.append("\n### ")
240
+ elif tag == "h4":
241
+ self.output.append("\n#### ")
242
+ elif tag == "h5":
243
+ self.output.append("\n##### ")
244
+ elif tag == "h6":
245
+ self.output.append("\n###### ")
246
+ elif tag == "p":
247
+ self.output.append("\n\n")
248
+ elif tag == "br":
249
+ self.output.append("\n")
250
+ elif tag == "strong" or tag == "b":
251
+ self.output.append("**")
252
+ elif tag == "em" or tag == "i":
253
+ self.output.append("*")
254
+ elif tag == "code" and not self.in_pre:
255
+ self.output.append("`")
256
+ self.in_code_block = True
257
+ elif tag == "pre":
258
+ self.in_pre = True
259
+ # Check for language hint in class
260
+ lang = ""
261
+ if class_attr := attrs_dict.get("class"):
262
+ if "python" in class_attr.lower():
263
+ lang = "python"
264
+ elif "cpp" in class_attr.lower() or "c++" in class_attr.lower():
265
+ lang = "cpp"
266
+ elif "cuda" in class_attr.lower():
267
+ lang = "cuda"
268
+ self.output.append(f"\n```{lang}\n")
269
+ elif tag == "ul":
270
+ self.list_depth += 1
271
+ self.output.append("\n")
272
+ elif tag == "ol":
273
+ self.list_depth += 1
274
+ self.ordered_list_counters.append(1)
275
+ self.output.append("\n")
276
+ elif tag == "li":
277
+ indent = " " * (self.list_depth - 1)
278
+ if self.ordered_list_counters:
279
+ num = self.ordered_list_counters[-1]
280
+ self.output.append(f"{indent}{num}. ")
281
+ self.ordered_list_counters[-1] += 1
282
+ else:
283
+ self.output.append(f"{indent}- ")
284
+ elif tag == "a":
285
+ self.link_href = attrs_dict.get("href")
286
+ self.output.append("[")
287
+ elif tag == "img":
288
+ alt = attrs_dict.get("alt", "image")
289
+ src = attrs_dict.get("src", "")
290
+ self.output.append(f"![{alt}]({src})")
291
+ elif tag == "blockquote":
292
+ self.output.append("\n> ")
293
+ elif tag == "hr":
294
+ self.output.append("\n---\n")
295
+ elif tag == "table":
296
+ self.output.append("\n")
297
+ elif tag == "th":
298
+ self.output.append("| ")
299
+ elif tag == "td":
300
+ self.output.append("| ")
301
+ elif tag == "tr":
302
+ pass # Handled in endtag
303
+
304
+ def handle_endtag(self, tag: str) -> None:
305
+ if tag in ("script", "style", "nav", "footer", "header", "aside"):
306
+ self.skip_content = False
307
+ return
308
+
309
+ if tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
310
+ self.output.append("\n")
311
+ elif tag == "strong" or tag == "b":
312
+ self.output.append("**")
313
+ elif tag == "em" or tag == "i":
314
+ self.output.append("*")
315
+ elif tag == "code" and not self.in_pre:
316
+ self.output.append("`")
317
+ self.in_code_block = False
318
+ elif tag == "pre":
319
+ self.in_pre = False
320
+ self.output.append("\n```\n")
321
+ elif tag == "ul":
322
+ self.list_depth = max(0, self.list_depth - 1)
323
+ elif tag == "ol":
324
+ self.list_depth = max(0, self.list_depth - 1)
325
+ if self.ordered_list_counters:
326
+ self.ordered_list_counters.pop()
327
+ elif tag == "li":
328
+ self.output.append("\n")
329
+ elif tag == "a":
330
+ if self.link_href:
331
+ self.output.append(f"]({self.link_href})")
332
+ else:
333
+ self.output.append("]")
334
+ self.link_href = None
335
+ elif tag == "p":
336
+ self.output.append("\n")
337
+ elif tag == "blockquote":
338
+ self.output.append("\n")
339
+ elif tag == "tr":
340
+ self.output.append("|\n")
341
+ elif tag == "thead":
342
+ # Add markdown table separator after header row
343
+ self.output.append("|---" * 10 + "|\n")
344
+
345
+ def handle_data(self, data: str) -> None:
346
+ if self.skip_content:
347
+ return
348
+ # Preserve whitespace in code blocks
349
+ if self.in_pre:
350
+ self.output.append(data)
351
+ else:
352
+ # Collapse whitespace outside code
353
+ text = re.sub(r"\s+", " ", data)
354
+ if text.strip():
355
+ self.output.append(text)
356
+
357
+ def get_markdown(self) -> str:
358
+ """Get the converted markdown, cleaned up."""
359
+ md = "".join(self.output)
360
+ # Clean up excessive newlines
361
+ md = re.sub(r"\n{3,}", "\n\n", md)
362
+ # Clean up empty table separators
363
+ md = re.sub(r"\|---\|---.*\|\n(?!\|)", "", md)
364
+ return md.strip()
365
+
366
+
367
+ def _html_to_markdown(html: str) -> str:
368
+ """Convert HTML to Markdown."""
369
+ parser = _HTMLToMarkdown()
370
+ parser.feed(html)
371
+ return parser.get_markdown()
372
+
373
+
172
374
  def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
173
- """Download NVIDIA docs using .md endpoint."""
375
+ """Download NVIDIA docs and convert HTML to Markdown.
376
+
377
+ NVIDIA's .md endpoint no longer works, so we scrape HTML and convert to markdown.
378
+ """
174
379
  assert config.urls is not None
175
380
  downloaded = 0
176
381
  with httpx.Client(timeout=30.0, follow_redirects=True) as client:
177
382
  for url in config.urls:
178
- md_url = f"{url}.md"
179
383
  filepath = _url_to_filepath(url, dest)
180
384
  filepath.parent.mkdir(parents=True, exist_ok=True)
181
385
  try:
182
- resp = client.get(md_url)
386
+ # Fetch HTML page directly
387
+ resp = client.get(url)
183
388
  resp.raise_for_status()
184
- filepath.write_text(resp.text)
389
+
390
+ # Convert HTML to Markdown
391
+ markdown = _html_to_markdown(resp.text)
392
+
393
+ # Add source URL as header
394
+ content = f"<!-- Source: {url} -->\n\n{markdown}"
395
+ filepath.write_text(content)
185
396
  downloaded += 1
186
397
  if verbose:
187
398
  print(f" ✓ {filepath.relative_to(dest)}")
@@ -275,6 +486,25 @@ def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool
275
486
  return downloaded
276
487
 
277
488
 
489
+ def _download_mixed(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
490
+ """Download from mixed sources (NVIDIA docs + GitHub repos)."""
491
+ total = 0
492
+
493
+ # Download NVIDIA markdown docs (urls)
494
+ if config.urls:
495
+ if verbose:
496
+ print(" [NVIDIA docs]")
497
+ total += _download_nvidia_md(config, dest, verbose)
498
+
499
+ # Download GitHub repos
500
+ if config.repos:
501
+ if verbose:
502
+ print(" [GitHub repos]")
503
+ total += _download_github_multi_repo(config, dest, verbose)
504
+
505
+ return total
506
+
507
+
278
508
  def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True) -> Path:
279
509
  """Download a corpus to local cache.
280
510
 
@@ -311,6 +541,8 @@ def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True)
311
541
  count = _download_github_repo(config, dest, verbose)
312
542
  elif config.source_type == "github_multi_repo":
313
543
  count = _download_github_multi_repo(config, dest, verbose)
544
+ elif config.source_type == "mixed":
545
+ count = _download_mixed(config, dest, verbose)
314
546
  else:
315
547
  raise ValueError(f"Unknown source type: {config.source_type}")
316
548
  if verbose: