wafer_cli-0.2.14-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/config.py ADDED
@@ -0,0 +1,105 @@
+"""Configuration management for Wafer CLI.
+
+Immutable dataclasses for config with TOML parsing.
+"""
+
+import tomllib
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass(frozen=True)
+class WaferEnvironment:
+    """Docker environment configuration. Immutable."""
+
+    name: str
+    docker: str
+    description: str = ""
+
+    def __post_init__(self) -> None:
+        """Validate environment configuration."""
+        assert self.name, "environment name cannot be empty"
+        assert self.docker, "docker image cannot be empty"
+
+
+@dataclass(frozen=True)
+class WaferConfig:
+    """Wafer CLI configuration. Immutable."""
+
+    target: str
+    ssh_key: str
+    environments: dict[str, WaferEnvironment]
+    default_environment: str | None = None
+
+    def __post_init__(self) -> None:
+        """Validate configuration."""
+        assert self.target, "target cannot be empty"
+        assert self.ssh_key, "ssh_key cannot be empty"
+        assert self.environments, "at least one environment must be defined"
+
+        # Validate default_environment exists if specified
+        if self.default_environment:
+            assert (
+                self.default_environment in self.environments
+            ), f"default_environment '{self.default_environment}' not found in environments"
+
+    @classmethod
+    def from_toml(cls, path: Path) -> "WaferConfig":
+        """Parse config from TOML file.
+
+        Args:
+            path: Path to config file
+
+        Returns:
+            WaferConfig instance
+
+        Raises:
+            AssertionError: If config is invalid or missing required fields
+            FileNotFoundError: If config file doesn't exist
+
+        Example config file (~/.wafer/config.toml):
+
+            [default]
+            target = "root@b200:22"
+            ssh_key = "~/.ssh/id_ed25519"
+            environment = "cutlass"  # Optional default
+
+            [environments.cutlass]
+            docker = "nvcr.io/nvidia/cutlass:4.3-devel"
+            description = "CUDA 13 + Cutlass 4.3"
+
+            [environments.pytorch]
+            docker = "pytorch/pytorch:2.5-cuda12.4"
+            description = "PyTorch with CUDA 12.4"
+        """
+        assert path.exists(), f"Config file not found: {path}"
+
+        with open(path, "rb") as f:
+            data = tomllib.load(f)
+
+        # Validate required sections
+        assert "default" in data, "Config must have [default] section"
+        assert "target" in data["default"], "Config must have default.target"
+        assert "ssh_key" in data["default"], "Config must have default.ssh_key"
+
+        # Parse environments
+        environments = {}
+        env_data = data.get("environments", {})
+        assert env_data, "Config must have at least one environment defined"
+
+        for name, env_config in env_data.items():
+            assert isinstance(env_config, dict), f"Environment {name} must be a table/dict"
+            assert "docker" in env_config, f"Environment {name} must have docker image"
+
+            environments[name] = WaferEnvironment(
+                name=name,
+                docker=env_config["docker"],
+                description=env_config.get("description", ""),
+            )
+
+        return cls(
+            target=data["default"]["target"],
+            ssh_key=data["default"]["ssh_key"],
+            environments=environments,
+            default_environment=data["default"].get("environment"),
+        )
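For reference, a minimal sketch of loading this configuration (the ~/.wafer/config.toml location comes from the from_toml docstring above; the surrounding error handling is illustrative, not part of the package):

from pathlib import Path

from wafer.config import WaferConfig

# Assumed default location, per the from_toml docstring; expand ~ first.
config_path = Path("~/.wafer/config.toml").expanduser()

try:
    config = WaferConfig.from_toml(config_path)
except AssertionError as exc:
    # from_toml validates with assertions, so malformed configs surface here.
    raise SystemExit(f"Invalid wafer config: {exc}")

# Pick the configured default environment if set, otherwise the first defined one.
env_name = config.default_environment or next(iter(config.environments))
print(config.target, env_name, config.environments[env_name].docker)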
wafer/corpus.py ADDED
@@ -0,0 +1,366 @@
+"""Corpus management for Wafer CLI.
+
+Download and manage documentation corpora for agent filesystem access.
+"""
+
+import shutil
+import tarfile
+import tempfile
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Literal
+from urllib.parse import urlparse
+
+import httpx
+
+CACHE_DIR = Path.home() / ".cache" / "wafer" / "corpora"
+
+CorpusName = Literal["cuda", "cutlass", "hip", "amd"]
+
+
+@dataclass
+class RepoSource:
+    """A single GitHub repo source within a corpus."""
+
+    repo: str
+    paths: list[str]
+    branch: str = "main"
+
+
+@dataclass
+class CorpusConfig:
+    """Configuration for a downloadable corpus."""
+
+    name: CorpusName
+    description: str
+    source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
+    urls: list[str] | None = None
+    repo: str | None = None
+    repo_paths: list[str] | None = None
+    repos: list[RepoSource] | None = None  # For multi-repo corpora
+
+
+CORPORA: dict[CorpusName, CorpusConfig] = {
+    "cuda": CorpusConfig(
+        name="cuda",
+        description="CUDA Programming Guide and Best Practices",
+        source_type="nvidia_md",
+        urls=[
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/index.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/01-introduction/introduction.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/01-introduction/programming-model.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/01-introduction/cuda-platform.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/02-basics/intro-to-cuda-cpp.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/02-basics/understanding-memory.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/02-basics/nvcc.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/03-advanced/advanced-kernel-programming.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/03-advanced/advanced-host-programming.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/cuda-graphs.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/stream-ordered-memory-allocation.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/dynamic-parallelism.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/virtual-memory-management.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/cluster-launch-control.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/graphics-interop.html",
+            "https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/cpp-language-extensions.html",
+            "https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html",
+        ],
+    ),
+    "cutlass": CorpusConfig(
+        name="cutlass",
+        description="CUTLASS and CuTe DSL documentation",
+        source_type="github_repo",
+        repo="NVIDIA/cutlass",
+        repo_paths=["media/docs", "python/cutlass/docs"],
+    ),
+    "hip": CorpusConfig(
+        name="hip",
+        description="HIP programming guide and API reference",
+        source_type="github_repo",
+        repo="ROCm/HIP",
+        repo_paths=["docs"],
+    ),
+    "amd": CorpusConfig(
+        name="amd",
+        description="AMD GPU kernel development (rocWMMA, CK, AITER, rocBLAS, HipKittens, vLLM)",
+        source_type="github_multi_repo",
+        repos=[
+            # rocWMMA - wave matrix multiply-accumulate (WMMA) intrinsics
+            RepoSource(
+                repo="ROCm/rocWMMA",
+                paths=["docs", "samples", "library/include"],
+                branch="develop",
+            ),
+            # Composable Kernel - tile-based GPU programming
+            RepoSource(
+                repo="ROCm/composable_kernel",
+                paths=["docs", "example", "tutorial", "include/ck_tile"],
+                branch="develop",
+            ),
+            # AITER - AMD inference tensor runtime
+            RepoSource(
+                repo="ROCm/aiter",
+                paths=["docs", "aiter/ops"],
+            ),
+            # MIOpen - deep learning primitives (deprecated, use rocm-libraries)
+            RepoSource(
+                repo="ROCm/MIOpen",
+                paths=["docs"],
+                branch="develop_deprecated",
+            ),
+            # rocBLAS - BLAS library (deprecated, use rocm-libraries)
+            RepoSource(
+                repo="ROCm/rocBLAS",
+                paths=["docs"],
+                branch="develop_deprecated",
+            ),
+            # hipBLASLt - lightweight BLAS (deprecated, use rocm-libraries)
+            RepoSource(
+                repo="ROCm/hipBLASLt",
+                paths=["docs"],
+                branch="develop_deprecated",
+            ),
+            # Tensile - GEMM code generator (deprecated, use rocm-libraries)
+            RepoSource(
+                repo="ROCm/Tensile",
+                paths=["docs"],
+                branch="develop_deprecated",
+            ),
+            # HipKittens - high-performance AMD kernels
+            RepoSource(
+                repo="HazyResearch/HipKittens",
+                paths=["docs", "kernels", "include"],
+            ),
+            # vLLM AMD kernels
+            RepoSource(
+                repo="vllm-project/vllm",
+                paths=["csrc/rocm"],
+            ),
+            # SGLang AMD kernels
+            RepoSource(
+                repo="sgl-project/sglang",
+                paths=["3rdparty/amd"],
+            ),
+            # HuggingFace ROCm kernels
+            RepoSource(
+                repo="huggingface/hf-rocm-kernels",
+                paths=["csrc", "hf_rocm_kernels", "docs"],
+            ),
+        ],
+    ),
+}
+
+
+def _corpus_path(name: CorpusName) -> Path:
+    """Get local path for corpus."""
+    return CACHE_DIR / name
+
+
+def _ensure_cache_dir() -> None:
+    """Ensure cache directory exists."""
+    CACHE_DIR.mkdir(parents=True, exist_ok=True)
+
+
+def _url_to_filepath(url: str, base_dir: Path) -> Path:
+    """Convert URL to local filepath preserving structure."""
+    parsed = urlparse(url)
+    path_parts = parsed.path.strip("/").split("/")
+    if path_parts[-1].endswith(".html"):
+        path_parts[-1] = path_parts[-1].replace(".html", ".md")
+    return base_dir / "/".join(path_parts)
+
+
+def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
+    """Download NVIDIA docs using .md endpoint."""
+    assert config.urls is not None
+    downloaded = 0
+    with httpx.Client(timeout=30.0, follow_redirects=True) as client:
+        for url in config.urls:
+            md_url = f"{url}.md"
+            filepath = _url_to_filepath(url, dest)
+            filepath.parent.mkdir(parents=True, exist_ok=True)
+            try:
+                resp = client.get(md_url)
+                resp.raise_for_status()
+                filepath.write_text(resp.text)
+                downloaded += 1
+                if verbose:
+                    print(f"  ✓ {filepath.relative_to(dest)}")
+            except httpx.HTTPError as e:
+                if verbose:
+                    print(f"  ✗ {url}: {e}")
+    return downloaded
+
+
+def _extract_matching_files(
+    tar: tarfile.TarFile,
+    repo_paths: list[str],
+    dest: Path,
+    verbose: bool,
+) -> int:
+    """Extract files matching repo_paths from tarball."""
+    downloaded = 0
+    for member in tar.getmembers():
+        if not member.isfile():
+            continue
+        rel_path = "/".join(member.name.split("/")[1:])
+        if not any(rel_path.startswith(rp) for rp in repo_paths):
+            continue
+        target = dest / rel_path
+        target.parent.mkdir(parents=True, exist_ok=True)
+        src = tar.extractfile(member)
+        if src:
+            target.write_bytes(src.read())
+            downloaded += 1
+            if verbose:
+                print(f"  ✓ {rel_path}")
+    return downloaded
+
+
+def _download_single_github_repo(
+    client: httpx.Client,
+    repo: str,
+    repo_paths: list[str],
+    dest: Path,
+    branch: str = "main",
+    verbose: bool = True,
+) -> int:
+    """Download specific paths from a single GitHub repo."""
+    tarball_url = f"https://api.github.com/repos/{repo}/tarball/{branch}"
+    if verbose:
+        print(f"  Fetching {repo}...")
+    resp = client.get(tarball_url)
+    resp.raise_for_status()
+    with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
+        tmp.write(resp.content)
+        tmp_path = Path(tmp.name)
+    try:
+        with tarfile.open(tmp_path, "r:gz") as tar:
+            return _extract_matching_files(tar, repo_paths, dest, verbose)
+    finally:
+        tmp_path.unlink()
+
+
+def _download_github_repo(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
+    """Download specific paths from GitHub repo."""
+    assert config.repo is not None
+    assert config.repo_paths is not None
+    with httpx.Client(timeout=60.0, follow_redirects=True) as client:
+        return _download_single_github_repo(
+            client, config.repo, config.repo_paths, dest, verbose=verbose
+        )
+
+
+def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
+    """Download specific paths from multiple GitHub repos."""
+    assert config.repos is not None
+    downloaded = 0
+    with httpx.Client(timeout=120.0, follow_redirects=True) as client:
+        for repo_source in config.repos:
+            repo_name = repo_source.repo.split("/")[-1]
+            repo_dest = dest / repo_name
+            repo_dest.mkdir(parents=True, exist_ok=True)
+            try:
+                count = _download_single_github_repo(
+                    client,
+                    repo_source.repo,
+                    repo_source.paths,
+                    repo_dest,
+                    branch=repo_source.branch,
+                    verbose=verbose,
+                )
+                downloaded += count
+            except httpx.HTTPError as e:
+                if verbose:
+                    print(f"  ✗ {repo_source.repo}: {e}")
+    return downloaded
+
+
+def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True) -> Path:
+    """Download a corpus to local cache.
+
+    Args:
+        name: Corpus name
+        force: Re-download even if exists
+        verbose: Print progress
+
+    Returns:
+        Path to downloaded corpus
+
+    Raises:
+        ValueError: If corpus name is unknown
+        httpx.HTTPError: If download fails
+    """
+    if name not in CORPORA:
+        raise ValueError(f"Unknown corpus: {name}. Available: {list(CORPORA.keys())}")
+    config = CORPORA[name]
+    dest = _corpus_path(name)
+    if dest.exists() and not force:
+        if verbose:
+            print(f"Corpus '{name}' already exists at {dest}")
+            print("Use --force to re-download")
+        return dest
+    _ensure_cache_dir()
+    if dest.exists():
+        shutil.rmtree(dest)
+    dest.mkdir(parents=True)
+    if verbose:
+        print(f"Downloading {name}: {config.description}")
+    if config.source_type == "nvidia_md":
+        count = _download_nvidia_md(config, dest, verbose)
+    elif config.source_type == "github_repo":
+        count = _download_github_repo(config, dest, verbose)
+    elif config.source_type == "github_multi_repo":
+        count = _download_github_multi_repo(config, dest, verbose)
+    else:
+        raise ValueError(f"Unknown source type: {config.source_type}")
+    if verbose:
+        print(f"Downloaded {count} files to {dest}")
+    return dest
+
+
+def sync_corpus(name: CorpusName, verbose: bool = True) -> Path:
+    """Sync (re-download) a corpus.
+
+    Args:
+        name: Corpus name
+        verbose: Print progress
+
+    Returns:
+        Path to synced corpus
+    """
+    return download_corpus(name, force=True, verbose=verbose)
+
+
+def list_corpora(verbose: bool = True) -> dict[CorpusName, bool]:
+    """List available corpora and their download status.
+
+    Returns:
+        Dict of corpus name -> is_downloaded
+    """
+    result: dict[CorpusName, bool] = {}
+    for name, config in CORPORA.items():
+        path = _corpus_path(name)
+        exists = path.exists()
+        result[name] = exists
+        if verbose:
+            status = "✓" if exists else " "
+            print(f"[{status}] {name}: {config.description}")
+            if exists:
+                file_count = sum(1 for p in path.rglob("*") if p.is_file())
+                print(f"    {path} ({file_count} files)")
+    return result
+
+
+def get_corpus_path(name: CorpusName) -> Path | None:
+    """Get path to downloaded corpus, or None if not downloaded.
+
+    Args:
+        name: Corpus name
+
+    Returns:
+        Path if downloaded, None otherwise
+    """
+    if name not in CORPORA:
+        return None
+    path = _corpus_path(name)
+    return path if path.exists() else None
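A usage sketch for the corpus helpers above (the function names and the ~/.cache/wafer/corpora layout come from this module; how a CLI command would wire them together is an assumption):

from wafer.corpus import download_corpus, get_corpus_path, list_corpora

# Report which corpora are already cached under ~/.cache/wafer/corpora.
status = list_corpora(verbose=True)

# Fetch the CUTLASS docs only if they are not cached yet; force=True re-downloads.
if not status.get("cutlass"):
    download_corpus("cutlass")

# Resolve the local path, e.g. to expose the docs to an agent's filesystem.
corpus_path = get_corpus_path("cutlass")
print(corpus_path)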