wafer-cli 0.2.20__tar.gz → 0.2.22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/PKG-INFO +1 -1
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/pyproject.toml +1 -1
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/cli.py +205 -22
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/corpus.py +241 -9
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/evaluate.py +426 -8
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/templates/optimize_kernel.py +2 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/wevin_cli.py +39 -16
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/PKG-INFO +1 -1
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/README.md +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/setup.cfg +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_analytics.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_auth.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_billing.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_cli_coverage.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_cli_parity_integration.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_config_integration.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_file_operations_integration.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_kernel_scope_cli.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_nsys_analyze.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_nsys_profile.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_output.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_rocprof_compute_integration.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_skill_commands.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_ssh_integration.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_targets_ops.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_wevin_cli.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/tests/test_workflow_integration.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/GUIDE.md +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/__init__.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/analytics.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/api_client.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/auth.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/autotuner.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/billing.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/config.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/global_config.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/gpu_run.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/inference.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/kernel_scope.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/ncu_analyze.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/nsys_analyze.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/nsys_profile.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/output.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/problems.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/rocprof_compute.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/rocprof_sdk.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/rocprof_systems.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/skills/wafer-guide/SKILL.md +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/ssh_keys.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/target_lock.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/targets.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/targets_ops.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/templates/__init__.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/templates/ask_docs.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/templates/optimize_kernelbench.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/templates/trace_analyze.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/tracelens.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer/workspaces.py +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/SOURCES.txt +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/dependency_links.txt +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/entry_points.txt +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/requires.txt +0 -0
- {wafer_cli-0.2.20 → wafer_cli-0.2.22}/wafer_cli.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,8 @@
|
|
|
1
|
-
# ruff: noqa: PLR0913
|
|
1
|
+
# ruff: noqa: PLR0913, E402
|
|
2
2
|
# PLR0913 (too many arguments) is suppressed because Typer CLI commands
|
|
3
3
|
# naturally have many parameters - each --flag becomes a function argument.
|
|
4
|
+
# E402 (module level import not at top) is suppressed because we intentionally
|
|
5
|
+
# load .env files before importing other modules that may read env vars.
|
|
4
6
|
"""Wafer CLI - GPU development toolkit for LLM coding agents.
|
|
5
7
|
|
|
6
8
|
Core commands:
|
|
@@ -27,6 +29,12 @@ from pathlib import Path
|
|
|
27
29
|
|
|
28
30
|
import trio
|
|
29
31
|
import typer
|
|
32
|
+
from dotenv import load_dotenv
|
|
33
|
+
|
|
34
|
+
# Auto-load .env from current directory and ~/.wafer/.env
|
|
35
|
+
# This runs at import time so env vars are available before any config is accessed
|
|
36
|
+
load_dotenv() # cwd/.env
|
|
37
|
+
load_dotenv(Path.home() / ".wafer" / ".env") # ~/.wafer/.env
|
|
30
38
|
|
|
31
39
|
from .config import WaferConfig, WaferEnvironment
|
|
32
40
|
from .inference import infer_upload_files, resolve_environment
|
|
@@ -42,6 +50,7 @@ from .problems import (
|
|
|
42
50
|
app = typer.Typer(
|
|
43
51
|
help="GPU development toolkit for LLM coding agents",
|
|
44
52
|
no_args_is_help=True,
|
|
53
|
+
pretty_exceptions_show_locals=False, # Don't dump local vars (makes tracebacks huge)
|
|
45
54
|
)
|
|
46
55
|
|
|
47
56
|
# =============================================================================
|
|
@@ -58,11 +67,11 @@ def _show_version() -> None:
|
|
|
58
67
|
"""Show CLI version and environment, then exit."""
|
|
59
68
|
from .analytics import _get_cli_version
|
|
60
69
|
from .global_config import load_global_config
|
|
61
|
-
|
|
70
|
+
|
|
62
71
|
version = _get_cli_version()
|
|
63
72
|
config = load_global_config()
|
|
64
73
|
environment = config.environment
|
|
65
|
-
|
|
74
|
+
|
|
66
75
|
typer.echo(f"wafer-cli {version} ({environment})")
|
|
67
76
|
raise typer.Exit()
|
|
68
77
|
|
|
@@ -110,7 +119,7 @@ def main_callback(
|
|
|
110
119
|
if version:
|
|
111
120
|
_show_version()
|
|
112
121
|
return
|
|
113
|
-
|
|
122
|
+
|
|
114
123
|
global _command_start_time, _command_outcome
|
|
115
124
|
_command_start_time = time.time()
|
|
116
125
|
_command_outcome = "success" # Default to success, mark failure on exceptions
|
|
@@ -121,6 +130,7 @@ def main_callback(
|
|
|
121
130
|
analytics.init_analytics()
|
|
122
131
|
|
|
123
132
|
# Install exception hook to catch SystemExit and mark failures
|
|
133
|
+
# Also prints error message FIRST so it's visible even when traceback is truncated
|
|
124
134
|
original_excepthook = sys.excepthook
|
|
125
135
|
|
|
126
136
|
def custom_excepthook(
|
|
@@ -136,7 +146,11 @@ def main_callback(
|
|
|
136
146
|
_command_outcome = "failure"
|
|
137
147
|
else:
|
|
138
148
|
_command_outcome = "failure"
|
|
139
|
-
|
|
149
|
+
# Print error summary FIRST (before traceback) so it's visible even if truncated
|
|
150
|
+
print(
|
|
151
|
+
f"\n\033[1;31m>>> ERROR: {exc_type.__name__}: {exc_value}\033[0m\n", file=sys.stderr
|
|
152
|
+
)
|
|
153
|
+
# Call original excepthook (prints the full traceback)
|
|
140
154
|
original_excepthook(exc_type, exc_value, exc_traceback)
|
|
141
155
|
|
|
142
156
|
sys.excepthook = custom_excepthook
|
|
@@ -591,7 +605,7 @@ app.add_typer(provider_auth_app, name="auth")
|
|
|
591
605
|
def provider_auth_login(
|
|
592
606
|
provider: str = typer.Argument(
|
|
593
607
|
...,
|
|
594
|
-
help="Provider name: runpod, digitalocean, or
|
|
608
|
+
help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
|
|
595
609
|
),
|
|
596
610
|
api_key: str | None = typer.Option(
|
|
597
611
|
None,
|
|
@@ -600,15 +614,16 @@ def provider_auth_login(
|
|
|
600
614
|
help="API key (if not provided, reads from stdin)",
|
|
601
615
|
),
|
|
602
616
|
) -> None:
|
|
603
|
-
"""Save API key for a
|
|
617
|
+
"""Save API key for a provider.
|
|
604
618
|
|
|
605
619
|
Stores the key in ~/.wafer/auth.json. Environment variables
|
|
606
|
-
(e.g.,
|
|
620
|
+
(e.g., ANTHROPIC_API_KEY) take precedence over stored keys.
|
|
607
621
|
|
|
608
622
|
Examples:
|
|
623
|
+
wafer auth login anthropic --api-key sk-ant-xxx
|
|
609
624
|
wafer auth login runpod --api-key rp_xxx
|
|
610
|
-
wafer auth login
|
|
611
|
-
echo $API_KEY | wafer auth login
|
|
625
|
+
wafer auth login openai --api-key sk-xxx
|
|
626
|
+
echo $API_KEY | wafer auth login anthropic
|
|
612
627
|
"""
|
|
613
628
|
import sys
|
|
614
629
|
|
|
@@ -642,7 +657,7 @@ def provider_auth_login(
|
|
|
642
657
|
def provider_auth_logout(
|
|
643
658
|
provider: str = typer.Argument(
|
|
644
659
|
...,
|
|
645
|
-
help="Provider name: runpod, digitalocean, or
|
|
660
|
+
help="Provider name: runpod, digitalocean, modal, anthropic, or openai",
|
|
646
661
|
),
|
|
647
662
|
) -> None:
|
|
648
663
|
"""Remove stored API key for a cloud GPU provider.
|
|
@@ -3473,7 +3488,7 @@ def init_runpod(
|
|
|
3473
3488
|
gpu_configs = {
|
|
3474
3489
|
"MI300X": {
|
|
3475
3490
|
"gpu_type_id": "AMD Instinct MI300X OAM",
|
|
3476
|
-
"image": "
|
|
3491
|
+
"image": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
|
|
3477
3492
|
"compute_capability": "9.4",
|
|
3478
3493
|
},
|
|
3479
3494
|
"H100": {
|
|
@@ -3569,7 +3584,7 @@ def init_digitalocean(
|
|
|
3569
3584
|
"ssh_key": ssh_key,
|
|
3570
3585
|
"region": region,
|
|
3571
3586
|
"size_slug": "gpu-mi300x1-192gb-devcloud",
|
|
3572
|
-
"image": "
|
|
3587
|
+
"image": "amd-pytorchrocm7", # PyTorch (ROCm7) marketplace image
|
|
3573
3588
|
"provision_timeout": 600,
|
|
3574
3589
|
"eval_timeout": 600,
|
|
3575
3590
|
"keep_alive": keep_alive,
|
|
@@ -4071,6 +4086,164 @@ def targets_cleanup(
|
|
|
4071
4086
|
raise typer.Exit(1) from None
|
|
4072
4087
|
|
|
4073
4088
|
|
|
4089
|
+
# Known libraries that can be installed on targets
|
|
4090
|
+
# TODO: Consider adding HipKittens to the default RunPod/DO Docker images
|
|
4091
|
+
# so this install step isn't needed. For now, this command handles it.
|
|
4092
|
+
INSTALLABLE_LIBRARIES: dict[str, dict[str, object]] = {
|
|
4093
|
+
"hipkittens": {
|
|
4094
|
+
"description": "HipKittens - AMD port of ThunderKittens for MI300X",
|
|
4095
|
+
"git_url": "https://github.com/HazyResearch/hipkittens.git",
|
|
4096
|
+
"install_path": "/opt/hipkittens",
|
|
4097
|
+
"requires_amd": True,
|
|
4098
|
+
},
|
|
4099
|
+
# CK is already installed with ROCm 7.0, no action needed
|
|
4100
|
+
"repair-headers": {
|
|
4101
|
+
"description": "Repair ROCm thrust headers (fixes hipify corruption)",
|
|
4102
|
+
"custom_script": "apt-get update -qq && apt-get install --reinstall -y rocthrust >/dev/null 2>&1 && echo REPAIRED",
|
|
4103
|
+
"requires_amd": True,
|
|
4104
|
+
},
|
|
4105
|
+
}
|
|
4106
|
+
|
|
4107
|
+
|
|
4108
|
+
@targets_app.command("install")
|
|
4109
|
+
def targets_install(
|
|
4110
|
+
name: str = typer.Argument(..., help="Target name"),
|
|
4111
|
+
library: str = typer.Argument(..., help="Library to install (hipkittens, repair-headers)"),
|
|
4112
|
+
) -> None:
|
|
4113
|
+
"""Install a library or run maintenance on a target (idempotent).
|
|
4114
|
+
|
|
4115
|
+
Installs header-only libraries like HipKittens on remote targets.
|
|
4116
|
+
Safe to run multiple times - will skip if already installed.
|
|
4117
|
+
|
|
4118
|
+
Available libraries:
|
|
4119
|
+
hipkittens - HipKittens (AMD ThunderKittens port)
|
|
4120
|
+
repair-headers - Fix ROCm thrust headers (after hipify corruption)
|
|
4121
|
+
|
|
4122
|
+
Examples:
|
|
4123
|
+
wafer config targets install runpod-mi300x hipkittens
|
|
4124
|
+
wafer config targets install runpod-mi300x repair-headers
|
|
4125
|
+
wafer config targets install do-mi300x hipkittens
|
|
4126
|
+
"""
|
|
4127
|
+
import subprocess
|
|
4128
|
+
|
|
4129
|
+
from .targets import load_target
|
|
4130
|
+
from .targets_ops import get_target_ssh_info
|
|
4131
|
+
|
|
4132
|
+
if library not in INSTALLABLE_LIBRARIES:
|
|
4133
|
+
available = ", ".join(INSTALLABLE_LIBRARIES.keys())
|
|
4134
|
+
typer.echo(f"Error: Unknown library '{library}'. Available: {available}", err=True)
|
|
4135
|
+
raise typer.Exit(1)
|
|
4136
|
+
|
|
4137
|
+
lib_info = INSTALLABLE_LIBRARIES[library]
|
|
4138
|
+
|
|
4139
|
+
try:
|
|
4140
|
+
target = load_target(name)
|
|
4141
|
+
except FileNotFoundError as e:
|
|
4142
|
+
typer.echo(f"Error: {e}", err=True)
|
|
4143
|
+
raise typer.Exit(1) from None
|
|
4144
|
+
|
|
4145
|
+
# Check if target is AMD (for AMD-only libraries)
|
|
4146
|
+
if lib_info.get("requires_amd"):
|
|
4147
|
+
from wafer_core.utils.kernel_utils.targets.config import (
|
|
4148
|
+
DigitalOceanTarget,
|
|
4149
|
+
RunPodTarget,
|
|
4150
|
+
)
|
|
4151
|
+
|
|
4152
|
+
is_amd = isinstance(target, (RunPodTarget, DigitalOceanTarget))
|
|
4153
|
+
if not is_amd and hasattr(target, "compute_capability"):
|
|
4154
|
+
# Check compute capability for MI300X (gfx942 = 9.4)
|
|
4155
|
+
is_amd = target.compute_capability.startswith("9.")
|
|
4156
|
+
if not is_amd:
|
|
4157
|
+
typer.echo(f"Error: {library} requires an AMD GPU target", err=True)
|
|
4158
|
+
raise typer.Exit(1)
|
|
4159
|
+
|
|
4160
|
+
typer.echo(f"Installing {library} on {name}...")
|
|
4161
|
+
typer.echo(f" {lib_info['description']}")
|
|
4162
|
+
|
|
4163
|
+
async def _install() -> bool:
|
|
4164
|
+
# get_target_ssh_info uses pure trio async (no asyncio bridging needed)
|
|
4165
|
+
# and we use subprocess for SSH, not AsyncSSHClient
|
|
4166
|
+
ssh_info = await get_target_ssh_info(target)
|
|
4167
|
+
|
|
4168
|
+
ssh_cmd = [
|
|
4169
|
+
"ssh",
|
|
4170
|
+
"-o",
|
|
4171
|
+
"StrictHostKeyChecking=no",
|
|
4172
|
+
"-o",
|
|
4173
|
+
"UserKnownHostsFile=/dev/null",
|
|
4174
|
+
"-o",
|
|
4175
|
+
"ConnectTimeout=30",
|
|
4176
|
+
"-i",
|
|
4177
|
+
str(ssh_info.key_path),
|
|
4178
|
+
"-p",
|
|
4179
|
+
str(ssh_info.port),
|
|
4180
|
+
f"{ssh_info.user}@{ssh_info.host}",
|
|
4181
|
+
]
|
|
4182
|
+
|
|
4183
|
+
# Handle custom scripts (like repair-headers) vs git installs
|
|
4184
|
+
if "custom_script" in lib_info:
|
|
4185
|
+
install_script = str(lib_info["custom_script"])
|
|
4186
|
+
success_marker = "REPAIRED"
|
|
4187
|
+
else:
|
|
4188
|
+
install_path = lib_info["install_path"]
|
|
4189
|
+
git_url = lib_info["git_url"]
|
|
4190
|
+
|
|
4191
|
+
# Idempotent install script
|
|
4192
|
+
install_script = f"""
|
|
4193
|
+
if [ -d "{install_path}" ]; then
|
|
4194
|
+
echo "ALREADY_INSTALLED: {install_path} exists"
|
|
4195
|
+
cd {install_path} && git pull --quiet 2>/dev/null || true
|
|
4196
|
+
else
|
|
4197
|
+
echo "INSTALLING: cloning to {install_path}"
|
|
4198
|
+
git clone --quiet {git_url} {install_path}
|
|
4199
|
+
fi
|
|
4200
|
+
echo "DONE"
|
|
4201
|
+
"""
|
|
4202
|
+
success_marker = "DONE"
|
|
4203
|
+
|
|
4204
|
+
def run_ssh() -> subprocess.CompletedProcess[str]:
|
|
4205
|
+
return subprocess.run(
|
|
4206
|
+
ssh_cmd + [install_script],
|
|
4207
|
+
capture_output=True,
|
|
4208
|
+
text=True,
|
|
4209
|
+
timeout=120,
|
|
4210
|
+
)
|
|
4211
|
+
|
|
4212
|
+
result = await trio.to_thread.run_sync(run_ssh)
|
|
4213
|
+
|
|
4214
|
+
if result.returncode != 0:
|
|
4215
|
+
typer.echo(f"Error: {result.stderr}", err=True)
|
|
4216
|
+
return False
|
|
4217
|
+
|
|
4218
|
+
output = result.stdout.strip()
|
|
4219
|
+
if "ALREADY_INSTALLED" in output:
|
|
4220
|
+
typer.echo(f" Already installed at {lib_info.get('install_path', 'N/A')}")
|
|
4221
|
+
elif "INSTALLING" in output:
|
|
4222
|
+
typer.echo(f" Installed to {lib_info.get('install_path', 'N/A')}")
|
|
4223
|
+
elif "REPAIRED" in output:
|
|
4224
|
+
typer.echo(" ROCm headers repaired")
|
|
4225
|
+
|
|
4226
|
+
return success_marker in output
|
|
4227
|
+
|
|
4228
|
+
try:
|
|
4229
|
+
success = trio.run(_install)
|
|
4230
|
+
except Exception as e:
|
|
4231
|
+
typer.echo(f"Error: {e}", err=True)
|
|
4232
|
+
raise typer.Exit(1) from None
|
|
4233
|
+
|
|
4234
|
+
if success:
|
|
4235
|
+
typer.echo(f"✓ {library} ready on {name}")
|
|
4236
|
+
|
|
4237
|
+
# Print usage hint
|
|
4238
|
+
if library == "hipkittens":
|
|
4239
|
+
typer.echo("")
|
|
4240
|
+
typer.echo("Usage in load_inline:")
|
|
4241
|
+
typer.echo(' extra_include_paths=["/opt/hipkittens/include", "/opt/rocm/include/hip"]')
|
|
4242
|
+
else:
|
|
4243
|
+
typer.echo(f"Failed to install {library}", err=True)
|
|
4244
|
+
raise typer.Exit(1)
|
|
4245
|
+
|
|
4246
|
+
|
|
4074
4247
|
@targets_app.command("pods")
|
|
4075
4248
|
def targets_pods() -> None:
|
|
4076
4249
|
"""List all running RunPod pods.
|
|
@@ -4406,9 +4579,13 @@ def workspaces_list(
|
|
|
4406
4579
|
@workspaces_app.command("create")
|
|
4407
4580
|
def workspaces_create(
|
|
4408
4581
|
name: str = typer.Argument(..., help="Workspace name"),
|
|
4409
|
-
gpu_type: str = typer.Option(
|
|
4582
|
+
gpu_type: str = typer.Option(
|
|
4583
|
+
"B200", "--gpu", "-g", help="GPU type: MI300X (AMD) or B200 (NVIDIA, default)"
|
|
4584
|
+
),
|
|
4410
4585
|
image: str | None = typer.Option(None, "--image", "-i", help="Docker image (optional)"),
|
|
4411
|
-
wait: bool = typer.Option(
|
|
4586
|
+
wait: bool = typer.Option(
|
|
4587
|
+
False, "--wait", "-w", help="Wait for provisioning and show SSH credentials"
|
|
4588
|
+
),
|
|
4412
4589
|
json_output: bool = typer.Option(False, "--json", "-j", help="Output as JSON"),
|
|
4413
4590
|
) -> None:
|
|
4414
4591
|
"""Create a new workspace.
|
|
@@ -4717,19 +4894,25 @@ def workspaces_ssh(
|
|
|
4717
4894
|
ssh_host = ws.get("ssh_host")
|
|
4718
4895
|
ssh_port = ws.get("ssh_port")
|
|
4719
4896
|
ssh_user = ws.get("ssh_user")
|
|
4720
|
-
|
|
4897
|
+
|
|
4721
4898
|
if not ssh_host or not ssh_port or not ssh_user:
|
|
4722
4899
|
typer.echo("Error: Workspace not ready. Wait a few seconds and retry.", err=True)
|
|
4723
4900
|
raise typer.Exit(1)
|
|
4724
4901
|
|
|
4725
4902
|
# Connect via SSH
|
|
4726
|
-
os.execvp(
|
|
4903
|
+
os.execvp(
|
|
4727
4904
|
"ssh",
|
|
4728
|
-
|
|
4729
|
-
|
|
4730
|
-
|
|
4731
|
-
|
|
4732
|
-
|
|
4905
|
+
[
|
|
4906
|
+
"ssh",
|
|
4907
|
+
"-p",
|
|
4908
|
+
str(ssh_port),
|
|
4909
|
+
"-o",
|
|
4910
|
+
"StrictHostKeyChecking=no",
|
|
4911
|
+
"-o",
|
|
4912
|
+
"UserKnownHostsFile=/dev/null",
|
|
4913
|
+
f"{ssh_user}@{ssh_host}",
|
|
4914
|
+
],
|
|
4915
|
+
)
|
|
4733
4916
|
|
|
4734
4917
|
|
|
4735
4918
|
@workspaces_app.command("sync")
|
|
@@ -3,10 +3,12 @@
|
|
|
3
3
|
Download and manage documentation corpora for agent filesystem access.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
import re
|
|
6
7
|
import shutil
|
|
7
8
|
import tarfile
|
|
8
9
|
import tempfile
|
|
9
10
|
from dataclasses import dataclass
|
|
11
|
+
from html.parser import HTMLParser
|
|
10
12
|
from pathlib import Path
|
|
11
13
|
from typing import Literal
|
|
12
14
|
from urllib.parse import urlparse
|
|
@@ -33,7 +35,7 @@ class CorpusConfig:
|
|
|
33
35
|
|
|
34
36
|
name: CorpusName
|
|
35
37
|
description: str
|
|
36
|
-
source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
|
|
38
|
+
source_type: Literal["nvidia_md", "github_repo", "github_multi_repo", "mixed"]
|
|
37
39
|
urls: list[str] | None = None
|
|
38
40
|
repo: str | None = None
|
|
39
41
|
repo_paths: list[str] | None = None
|
|
@@ -67,10 +69,43 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
|
|
|
67
69
|
),
|
|
68
70
|
"cutlass": CorpusConfig(
|
|
69
71
|
name="cutlass",
|
|
70
|
-
description="CUTLASS
|
|
71
|
-
source_type="
|
|
72
|
-
|
|
73
|
-
|
|
72
|
+
description="CUTLASS C++ documentation, examples, and tutorials",
|
|
73
|
+
source_type="mixed",
|
|
74
|
+
# Official NVIDIA CUTLASS documentation (scraped as markdown)
|
|
75
|
+
urls=[
|
|
76
|
+
"https://docs.nvidia.com/cutlass/latest/overview.html",
|
|
77
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html",
|
|
78
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html",
|
|
79
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html",
|
|
80
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html",
|
|
81
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/heuristics.html",
|
|
82
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html",
|
|
83
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/pipeline.html",
|
|
84
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html",
|
|
85
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html",
|
|
86
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html",
|
|
87
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html",
|
|
88
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html",
|
|
89
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/01_layout.html",
|
|
90
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/02_layout_algebra.html",
|
|
91
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/03_tensor.html",
|
|
92
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/04_algorithms.html",
|
|
93
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0t_mma_atom.html",
|
|
94
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0x_gemm_tutorial.html",
|
|
95
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0y_predication.html",
|
|
96
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0z_tma_tensors.html",
|
|
97
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html",
|
|
98
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_backwards_compatibility.html",
|
|
99
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html",
|
|
100
|
+
],
|
|
101
|
+
# NVIDIA/cutlass GitHub examples (excluding python/)
|
|
102
|
+
repos=[
|
|
103
|
+
RepoSource(
|
|
104
|
+
repo="NVIDIA/cutlass",
|
|
105
|
+
paths=["examples"],
|
|
106
|
+
branch="main",
|
|
107
|
+
),
|
|
108
|
+
],
|
|
74
109
|
),
|
|
75
110
|
"hip": CorpusConfig(
|
|
76
111
|
name="hip",
|
|
@@ -169,19 +204,195 @@ def _url_to_filepath(url: str, base_dir: Path) -> Path:
|
|
|
169
204
|
return base_dir / "/".join(path_parts)
|
|
170
205
|
|
|
171
206
|
|
|
207
|
+
class _HTMLToMarkdown(HTMLParser):
|
|
208
|
+
"""HTML to Markdown converter for NVIDIA documentation pages.
|
|
209
|
+
|
|
210
|
+
Uses stdlib HTMLParser - requires subclassing due to callback-based API.
|
|
211
|
+
The public interface is the functional `_html_to_markdown()` below.
|
|
212
|
+
"""
|
|
213
|
+
|
|
214
|
+
def __init__(self) -> None:
|
|
215
|
+
super().__init__()
|
|
216
|
+
self.output: list[str] = []
|
|
217
|
+
self.current_tag: str = ""
|
|
218
|
+
self.in_code_block = False
|
|
219
|
+
self.in_pre = False
|
|
220
|
+
self.list_depth = 0
|
|
221
|
+
self.ordered_list_counters: list[int] = []
|
|
222
|
+
self.skip_content = False
|
|
223
|
+
self.link_href: str | None = None
|
|
224
|
+
|
|
225
|
+
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
|
226
|
+
self.current_tag = tag
|
|
227
|
+
attrs_dict = dict(attrs)
|
|
228
|
+
|
|
229
|
+
# Skip script, style, nav, footer, header
|
|
230
|
+
if tag in ("script", "style", "nav", "footer", "header", "aside"):
|
|
231
|
+
self.skip_content = True
|
|
232
|
+
return
|
|
233
|
+
|
|
234
|
+
if tag == "h1":
|
|
235
|
+
self.output.append("\n# ")
|
|
236
|
+
elif tag == "h2":
|
|
237
|
+
self.output.append("\n## ")
|
|
238
|
+
elif tag == "h3":
|
|
239
|
+
self.output.append("\n### ")
|
|
240
|
+
elif tag == "h4":
|
|
241
|
+
self.output.append("\n#### ")
|
|
242
|
+
elif tag == "h5":
|
|
243
|
+
self.output.append("\n##### ")
|
|
244
|
+
elif tag == "h6":
|
|
245
|
+
self.output.append("\n###### ")
|
|
246
|
+
elif tag == "p":
|
|
247
|
+
self.output.append("\n\n")
|
|
248
|
+
elif tag == "br":
|
|
249
|
+
self.output.append("\n")
|
|
250
|
+
elif tag == "strong" or tag == "b":
|
|
251
|
+
self.output.append("**")
|
|
252
|
+
elif tag == "em" or tag == "i":
|
|
253
|
+
self.output.append("*")
|
|
254
|
+
elif tag == "code" and not self.in_pre:
|
|
255
|
+
self.output.append("`")
|
|
256
|
+
self.in_code_block = True
|
|
257
|
+
elif tag == "pre":
|
|
258
|
+
self.in_pre = True
|
|
259
|
+
# Check for language hint in class
|
|
260
|
+
lang = ""
|
|
261
|
+
if class_attr := attrs_dict.get("class"):
|
|
262
|
+
if "python" in class_attr.lower():
|
|
263
|
+
lang = "python"
|
|
264
|
+
elif "cpp" in class_attr.lower() or "c++" in class_attr.lower():
|
|
265
|
+
lang = "cpp"
|
|
266
|
+
elif "cuda" in class_attr.lower():
|
|
267
|
+
lang = "cuda"
|
|
268
|
+
self.output.append(f"\n```{lang}\n")
|
|
269
|
+
elif tag == "ul":
|
|
270
|
+
self.list_depth += 1
|
|
271
|
+
self.output.append("\n")
|
|
272
|
+
elif tag == "ol":
|
|
273
|
+
self.list_depth += 1
|
|
274
|
+
self.ordered_list_counters.append(1)
|
|
275
|
+
self.output.append("\n")
|
|
276
|
+
elif tag == "li":
|
|
277
|
+
indent = " " * (self.list_depth - 1)
|
|
278
|
+
if self.ordered_list_counters:
|
|
279
|
+
num = self.ordered_list_counters[-1]
|
|
280
|
+
self.output.append(f"{indent}{num}. ")
|
|
281
|
+
self.ordered_list_counters[-1] += 1
|
|
282
|
+
else:
|
|
283
|
+
self.output.append(f"{indent}- ")
|
|
284
|
+
elif tag == "a":
|
|
285
|
+
self.link_href = attrs_dict.get("href")
|
|
286
|
+
self.output.append("[")
|
|
287
|
+
elif tag == "img":
|
|
288
|
+
alt = attrs_dict.get("alt", "image")
|
|
289
|
+
src = attrs_dict.get("src", "")
|
|
290
|
+
self.output.append(f"")
|
|
291
|
+
elif tag == "blockquote":
|
|
292
|
+
self.output.append("\n> ")
|
|
293
|
+
elif tag == "hr":
|
|
294
|
+
self.output.append("\n---\n")
|
|
295
|
+
elif tag == "table":
|
|
296
|
+
self.output.append("\n")
|
|
297
|
+
elif tag == "th":
|
|
298
|
+
self.output.append("| ")
|
|
299
|
+
elif tag == "td":
|
|
300
|
+
self.output.append("| ")
|
|
301
|
+
elif tag == "tr":
|
|
302
|
+
pass # Handled in endtag
|
|
303
|
+
|
|
304
|
+
def handle_endtag(self, tag: str) -> None:
|
|
305
|
+
if tag in ("script", "style", "nav", "footer", "header", "aside"):
|
|
306
|
+
self.skip_content = False
|
|
307
|
+
return
|
|
308
|
+
|
|
309
|
+
if tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
|
|
310
|
+
self.output.append("\n")
|
|
311
|
+
elif tag == "strong" or tag == "b":
|
|
312
|
+
self.output.append("**")
|
|
313
|
+
elif tag == "em" or tag == "i":
|
|
314
|
+
self.output.append("*")
|
|
315
|
+
elif tag == "code" and not self.in_pre:
|
|
316
|
+
self.output.append("`")
|
|
317
|
+
self.in_code_block = False
|
|
318
|
+
elif tag == "pre":
|
|
319
|
+
self.in_pre = False
|
|
320
|
+
self.output.append("\n```\n")
|
|
321
|
+
elif tag == "ul":
|
|
322
|
+
self.list_depth = max(0, self.list_depth - 1)
|
|
323
|
+
elif tag == "ol":
|
|
324
|
+
self.list_depth = max(0, self.list_depth - 1)
|
|
325
|
+
if self.ordered_list_counters:
|
|
326
|
+
self.ordered_list_counters.pop()
|
|
327
|
+
elif tag == "li":
|
|
328
|
+
self.output.append("\n")
|
|
329
|
+
elif tag == "a":
|
|
330
|
+
if self.link_href:
|
|
331
|
+
self.output.append(f"]({self.link_href})")
|
|
332
|
+
else:
|
|
333
|
+
self.output.append("]")
|
|
334
|
+
self.link_href = None
|
|
335
|
+
elif tag == "p":
|
|
336
|
+
self.output.append("\n")
|
|
337
|
+
elif tag == "blockquote":
|
|
338
|
+
self.output.append("\n")
|
|
339
|
+
elif tag == "tr":
|
|
340
|
+
self.output.append("|\n")
|
|
341
|
+
elif tag == "thead":
|
|
342
|
+
# Add markdown table separator after header row
|
|
343
|
+
self.output.append("|---" * 10 + "|\n")
|
|
344
|
+
|
|
345
|
+
def handle_data(self, data: str) -> None:
|
|
346
|
+
if self.skip_content:
|
|
347
|
+
return
|
|
348
|
+
# Preserve whitespace in code blocks
|
|
349
|
+
if self.in_pre:
|
|
350
|
+
self.output.append(data)
|
|
351
|
+
else:
|
|
352
|
+
# Collapse whitespace outside code
|
|
353
|
+
text = re.sub(r"\s+", " ", data)
|
|
354
|
+
if text.strip():
|
|
355
|
+
self.output.append(text)
|
|
356
|
+
|
|
357
|
+
def get_markdown(self) -> str:
|
|
358
|
+
"""Get the converted markdown, cleaned up."""
|
|
359
|
+
md = "".join(self.output)
|
|
360
|
+
# Clean up excessive newlines
|
|
361
|
+
md = re.sub(r"\n{3,}", "\n\n", md)
|
|
362
|
+
# Clean up empty table separators
|
|
363
|
+
md = re.sub(r"\|---\|---.*\|\n(?!\|)", "", md)
|
|
364
|
+
return md.strip()
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _html_to_markdown(html: str) -> str:
|
|
368
|
+
"""Convert HTML to Markdown."""
|
|
369
|
+
parser = _HTMLToMarkdown()
|
|
370
|
+
parser.feed(html)
|
|
371
|
+
return parser.get_markdown()
|
|
372
|
+
|
|
373
|
+
|
|
172
374
|
def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
|
|
173
|
-
"""Download NVIDIA docs
|
|
375
|
+
"""Download NVIDIA docs and convert HTML to Markdown.
|
|
376
|
+
|
|
377
|
+
NVIDIA's .md endpoint no longer works, so we scrape HTML and convert to markdown.
|
|
378
|
+
"""
|
|
174
379
|
assert config.urls is not None
|
|
175
380
|
downloaded = 0
|
|
176
381
|
with httpx.Client(timeout=30.0, follow_redirects=True) as client:
|
|
177
382
|
for url in config.urls:
|
|
178
|
-
md_url = f"{url}.md"
|
|
179
383
|
filepath = _url_to_filepath(url, dest)
|
|
180
384
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
|
181
385
|
try:
|
|
182
|
-
|
|
386
|
+
# Fetch HTML page directly
|
|
387
|
+
resp = client.get(url)
|
|
183
388
|
resp.raise_for_status()
|
|
184
|
-
|
|
389
|
+
|
|
390
|
+
# Convert HTML to Markdown
|
|
391
|
+
markdown = _html_to_markdown(resp.text)
|
|
392
|
+
|
|
393
|
+
# Add source URL as header
|
|
394
|
+
content = f"<!-- Source: {url} -->\n\n{markdown}"
|
|
395
|
+
filepath.write_text(content)
|
|
185
396
|
downloaded += 1
|
|
186
397
|
if verbose:
|
|
187
398
|
print(f" ✓ {filepath.relative_to(dest)}")
|
|
@@ -275,6 +486,25 @@ def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool
|
|
|
275
486
|
return downloaded
|
|
276
487
|
|
|
277
488
|
|
|
489
|
+
def _download_mixed(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
|
|
490
|
+
"""Download from mixed sources (NVIDIA docs + GitHub repos)."""
|
|
491
|
+
total = 0
|
|
492
|
+
|
|
493
|
+
# Download NVIDIA markdown docs (urls)
|
|
494
|
+
if config.urls:
|
|
495
|
+
if verbose:
|
|
496
|
+
print(" [NVIDIA docs]")
|
|
497
|
+
total += _download_nvidia_md(config, dest, verbose)
|
|
498
|
+
|
|
499
|
+
# Download GitHub repos
|
|
500
|
+
if config.repos:
|
|
501
|
+
if verbose:
|
|
502
|
+
print(" [GitHub repos]")
|
|
503
|
+
total += _download_github_multi_repo(config, dest, verbose)
|
|
504
|
+
|
|
505
|
+
return total
|
|
506
|
+
|
|
507
|
+
|
|
278
508
|
def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True) -> Path:
|
|
279
509
|
"""Download a corpus to local cache.
|
|
280
510
|
|
|
@@ -311,6 +541,8 @@ def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True)
|
|
|
311
541
|
count = _download_github_repo(config, dest, verbose)
|
|
312
542
|
elif config.source_type == "github_multi_repo":
|
|
313
543
|
count = _download_github_multi_repo(config, dest, verbose)
|
|
544
|
+
elif config.source_type == "mixed":
|
|
545
|
+
count = _download_mixed(config, dest, verbose)
|
|
314
546
|
else:
|
|
315
547
|
raise ValueError(f"Unknown source type: {config.source_type}")
|
|
316
548
|
if verbose:
|