wafer-cli 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +118 -0
- wafer/__init__.py +3 -0
- wafer/analytics.py +306 -0
- wafer/api_client.py +195 -0
- wafer/auth.py +432 -0
- wafer/autotuner.py +1080 -0
- wafer/billing.py +233 -0
- wafer/cli.py +7289 -0
- wafer/config.py +105 -0
- wafer/corpus.py +366 -0
- wafer/evaluate.py +4593 -0
- wafer/global_config.py +350 -0
- wafer/gpu_run.py +307 -0
- wafer/inference.py +148 -0
- wafer/kernel_scope.py +552 -0
- wafer/ncu_analyze.py +651 -0
- wafer/nsys_analyze.py +1042 -0
- wafer/nsys_profile.py +510 -0
- wafer/output.py +248 -0
- wafer/problems.py +357 -0
- wafer/rocprof_compute.py +490 -0
- wafer/rocprof_sdk.py +274 -0
- wafer/rocprof_systems.py +520 -0
- wafer/skills/wafer-guide/SKILL.md +129 -0
- wafer/ssh_keys.py +261 -0
- wafer/target_lock.py +270 -0
- wafer/targets.py +842 -0
- wafer/targets_ops.py +717 -0
- wafer/templates/__init__.py +0 -0
- wafer/templates/ask_docs.py +61 -0
- wafer/templates/optimize_kernel.py +71 -0
- wafer/templates/optimize_kernelbench.py +137 -0
- wafer/templates/trace_analyze.py +74 -0
- wafer/tracelens.py +218 -0
- wafer/wevin_cli.py +577 -0
- wafer/workspaces.py +852 -0
- wafer_cli-0.2.14.dist-info/METADATA +16 -0
- wafer_cli-0.2.14.dist-info/RECORD +41 -0
- wafer_cli-0.2.14.dist-info/WHEEL +5 -0
- wafer_cli-0.2.14.dist-info/entry_points.txt +2 -0
- wafer_cli-0.2.14.dist-info/top_level.txt +1 -0
wafer/config.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""Configuration management for Wafer CLI.
|
|
2
|
+
|
|
3
|
+
Immutable dataclasses for config with TOML parsing.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import tomllib
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
class WaferEnvironment:
    """Docker environment configuration. Immutable.

    Attributes:
        name: Name of the environment.
        docker: Docker image reference used for the environment.
        description: Optional free-form description; defaults to "".

    Raises:
        AssertionError: If ``name`` or ``docker`` is empty.
    """

    name: str
    docker: str
    description: str = ""

    def __post_init__(self) -> None:
        """Validate environment configuration."""
        # Raise explicitly rather than via ``assert`` statements: asserts are
        # removed under ``python -O``, which would silently skip validation.
        # The exception type and messages are unchanged for existing callers.
        if not self.name:
            raise AssertionError("environment name cannot be empty")
        if not self.docker:
            raise AssertionError("docker image cannot be empty")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
class WaferConfig:
    """Wafer CLI configuration. Immutable.

    Attributes:
        target: Value of ``default.target`` in the config file.
        ssh_key: Value of ``default.ssh_key`` in the config file.
        environments: Mapping of environment name to its configuration.
        default_environment: Optional name of the default environment; when
            set it must be a key of ``environments``.

    Raises:
        AssertionError: If a required value is empty or ``default_environment``
            does not name a known environment.
    """

    target: str
    ssh_key: str
    environments: dict[str, WaferEnvironment]
    default_environment: str | None = None

    def __post_init__(self) -> None:
        """Validate configuration."""
        # Explicit raises instead of ``assert`` statements so the checks are
        # not stripped under ``python -O``; the exception type is unchanged.
        if not self.target:
            raise AssertionError("target cannot be empty")
        if not self.ssh_key:
            raise AssertionError("ssh_key cannot be empty")
        if not self.environments:
            raise AssertionError("at least one environment must be defined")

        # Validate default_environment exists if specified
        if self.default_environment and self.default_environment not in self.environments:
            raise AssertionError(
                f"default_environment '{self.default_environment}' not found in environments"
            )

    @classmethod
    def from_toml(cls, path: Path) -> "WaferConfig":
        """Parse config from TOML file.

        Args:
            path: Path to config file

        Returns:
            WaferConfig instance

        Raises:
            AssertionError: If the config file does not exist, is invalid, or
                is missing required fields
            tomllib.TOMLDecodeError: If the file is not valid TOML

        Example config file (~/.wafer/config.toml):

            [default]
            target = "root@b200:22"
            ssh_key = "~/.ssh/id_ed25519"
            environment = "cutlass"  # Optional default

            [environments.cutlass]
            docker = "nvcr.io/nvidia/cutlass:4.3-devel"
            description = "CUDA 13 + Cutlass 4.3"

            [environments.pytorch]
            docker = "pytorch/pytorch:2.5-cuda12.4"
            description = "PyTorch with CUDA 12.4"
        """
        if not path.exists():
            raise AssertionError(f"Config file not found: {path}")

        with open(path, "rb") as f:
            data = tomllib.load(f)

        # Validate required sections
        if "default" not in data:
            raise AssertionError("Config must have [default] section")
        default = data["default"]
        # ``default = "x"`` at top level parses as a scalar, not a table.
        if not isinstance(default, dict):
            raise AssertionError("Config [default] must be a table/dict")
        if "target" not in default:
            raise AssertionError("Config must have default.target")
        if "ssh_key" not in default:
            raise AssertionError("Config must have default.ssh_key")

        # Parse environments
        env_data = data.get("environments", {})
        if not env_data:
            raise AssertionError("Config must have at least one environment defined")
        if not isinstance(env_data, dict):
            raise AssertionError("Config [environments] must be a table/dict")

        environments = {}
        for name, env_config in env_data.items():
            if not isinstance(env_config, dict):
                raise AssertionError(f"Environment {name} must be a table/dict")
            if "docker" not in env_config:
                raise AssertionError(f"Environment {name} must have docker image")

            environments[name] = WaferEnvironment(
                name=name,
                docker=env_config["docker"],
                description=env_config.get("description", ""),
            )

        return cls(
            target=default["target"],
            ssh_key=default["ssh_key"],
            environments=environments,
            default_environment=default.get("environment"),
        )
|
wafer/corpus.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
"""Corpus management for Wafer CLI.
|
|
2
|
+
|
|
3
|
+
Download and manage documentation corpora for agent filesystem access.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import shutil
|
|
7
|
+
import tarfile
|
|
8
|
+
import tempfile
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Literal
|
|
12
|
+
from urllib.parse import urlparse
|
|
13
|
+
|
|
14
|
+
import httpx
|
|
15
|
+
|
|
16
|
+
# Root directory under the user's home where downloaded corpora are stored.
CACHE_DIR = Path.home().joinpath(".cache", "wafer", "corpora")

# Closed set of corpus identifiers known to this module.
CorpusName = Literal["cuda", "cutlass", "hip", "amd"]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class RepoSource:
    """One GitHub repository contributing files to a corpus.

    Attributes:
        repo: Repository identifier in ``owner/name`` form.
        paths: Paths inside the repository whose files are included.
        branch: Ref to fetch; defaults to ``"main"``.
    """

    repo: str
    paths: list[str]
    branch: str = "main"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class CorpusConfig:
    """Describes where a downloadable corpus comes from.

    Which optional fields are populated depends on ``source_type``:
    ``urls`` for ``"nvidia_md"``, ``repo`` plus ``repo_paths`` for
    ``"github_repo"``, and ``repos`` for ``"github_multi_repo"``.

    Attributes:
        name: Corpus identifier.
        description: Human-readable summary shown to the user.
        source_type: Selects the download strategy.
        urls: Documentation page URLs.
        repo: Single GitHub repository in ``owner/name`` form.
        repo_paths: Paths inside ``repo`` to include.
        repos: Repository sources for multi-repo corpora.
    """

    name: CorpusName
    description: str
    source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
    urls: list[str] | None = None
    repo: str | None = None
    repo_paths: list[str] | None = None
    repos: list[RepoSource] | None = None  # For multi-repo corpora
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Registry of every known corpus. Each entry is keyed by its own ``name`` so
# the key and the config can never disagree; insertion order is
# cuda, cutlass, hip, amd.
CORPORA: dict[CorpusName, CorpusConfig] = {
    corpus.name: corpus
    for corpus in (
        # NVIDIA CUDA guides, fetched page by page.
        CorpusConfig(
            name="cuda",
            description="CUDA Programming Guide and Best Practices",
            source_type="nvidia_md",
            urls=[
                "https://docs.nvidia.com/cuda/cuda-programming-guide/index.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/01-introduction/introduction.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/01-introduction/programming-model.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/01-introduction/cuda-platform.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/02-basics/intro-to-cuda-cpp.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/02-basics/understanding-memory.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/02-basics/nvcc.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/03-advanced/advanced-kernel-programming.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/03-advanced/advanced-host-programming.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/cuda-graphs.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/stream-ordered-memory-allocation.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/dynamic-parallelism.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/virtual-memory-management.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/cluster-launch-control.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/graphics-interop.html",
                "https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/cpp-language-extensions.html",
                "https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html",
            ],
        ),
        # CUTLASS docs from a single GitHub repository.
        CorpusConfig(
            name="cutlass",
            description="CUTLASS and CuTe DSL documentation",
            source_type="github_repo",
            repo="NVIDIA/cutlass",
            repo_paths=["media/docs", "python/cutlass/docs"],
        ),
        # HIP docs from a single GitHub repository.
        CorpusConfig(
            name="hip",
            description="HIP programming guide and API reference",
            source_type="github_repo",
            repo="ROCm/HIP",
            repo_paths=["docs"],
        ),
        # AMD kernel-development material gathered from several repositories.
        CorpusConfig(
            name="amd",
            description="AMD GPU kernel development (rocWMMA, CK, AITER, rocBLAS, HipKittens, vLLM)",
            source_type="github_multi_repo",
            repos=[
                # rocWMMA - wave matrix multiply-accumulate (WMMA) intrinsics
                RepoSource(
                    repo="ROCm/rocWMMA",
                    paths=["docs", "samples", "library/include"],
                    branch="develop",
                ),
                # Composable Kernel - tile-based GPU programming
                RepoSource(
                    repo="ROCm/composable_kernel",
                    paths=["docs", "example", "tutorial", "include/ck_tile"],
                    branch="develop",
                ),
                # AITER - AMD inference tensor runtime
                RepoSource(
                    repo="ROCm/aiter",
                    paths=["docs", "aiter/ops"],
                ),
                # MIOpen - deep learning primitives (deprecated, use rocm-libraries)
                RepoSource(
                    repo="ROCm/MIOpen",
                    paths=["docs"],
                    branch="develop_deprecated",
                ),
                # rocBLAS - BLAS library (deprecated, use rocm-libraries)
                RepoSource(
                    repo="ROCm/rocBLAS",
                    paths=["docs"],
                    branch="develop_deprecated",
                ),
                # hipBLASLt - lightweight BLAS (deprecated, use rocm-libraries)
                RepoSource(
                    repo="ROCm/hipBLASLt",
                    paths=["docs"],
                    branch="develop_deprecated",
                ),
                # Tensile - GEMM code generator (deprecated, use rocm-libraries)
                RepoSource(
                    repo="ROCm/Tensile",
                    paths=["docs"],
                    branch="develop_deprecated",
                ),
                # HipKittens - high-performance AMD kernels
                RepoSource(
                    repo="HazyResearch/HipKittens",
                    paths=["docs", "kernels", "include"],
                ),
                # vLLM AMD kernels
                RepoSource(
                    repo="vllm-project/vllm",
                    paths=["csrc/rocm"],
                ),
                # SGLang AMD kernels
                RepoSource(
                    repo="sgl-project/sglang",
                    paths=["3rdparty/amd"],
                ),
                # HuggingFace ROCm kernels
                RepoSource(
                    repo="huggingface/hf-rocm-kernels",
                    paths=["csrc", "hf_rocm_kernels", "docs"],
                ),
            ],
        ),
    )
}
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _corpus_path(name: CorpusName) -> Path:
    """Return the cache directory assigned to corpus ``name``.

    The path is computed only; nothing is created or checked on disk.
    """
    return CACHE_DIR.joinpath(name)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _ensure_cache_dir() -> None:
    """Create the corpus cache root (and any missing parents) if absent."""
    CACHE_DIR.mkdir(exist_ok=True, parents=True)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _url_to_filepath(url: str, base_dir: Path) -> Path:
    """Convert URL to local filepath preserving structure.

    The URL's path component is mirrored beneath ``base_dir``. If the final
    segment ends in ``.html`` that extension is swapped for ``.md``; other
    segments and other extensions are left untouched.

    Args:
        url: Source URL; only its path component is used.
        base_dir: Directory under which the mirrored path is placed.

    Returns:
        ``base_dir`` joined with the URL path (extension adjusted as above).
    """
    segments = urlparse(url).path.strip("/").split("/")
    filename = segments[-1]
    if filename.endswith(".html"):
        # Swap only the trailing extension. ``str.replace`` would also rewrite
        # any ".html" occurring earlier in the name (e.g. "a.html5.html").
        segments[-1] = filename.removesuffix(".html") + ".md"
    return base_dir / "/".join(segments)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
    """Download NVIDIA docs using .md endpoint.

    For every URL in ``config.urls`` the same URL with ``.md`` appended is
    fetched and the response body is stored under ``dest`` at the location
    given by ``_url_to_filepath``. A URL whose request fails is reported (when
    ``verbose``) and skipped; remaining URLs are still attempted.

    Args:
        config: Corpus configuration; ``urls`` must be set.
        dest: Directory that receives the downloaded files.
        verbose: Print one line per URL describing success or failure.

    Returns:
        Number of files written.
    """
    assert config.urls is not None
    downloaded = 0
    with httpx.Client(timeout=30.0, follow_redirects=True) as client:
        for url in config.urls:
            md_url = f"{url}.md"
            filepath = _url_to_filepath(url, dest)
            filepath.parent.mkdir(parents=True, exist_ok=True)
            try:
                resp = client.get(md_url)
                resp.raise_for_status()
                # Write as UTF-8 explicitly: the default is the locale
                # encoding, which can fail or corrupt non-ASCII documentation
                # on systems whose locale is not UTF-8.
                filepath.write_text(resp.text, encoding="utf-8")
                downloaded += 1
                if verbose:
                    print(f" ✓ {filepath.relative_to(dest)}")
            except httpx.HTTPError as e:
                if verbose:
                    print(f" ✗ {url}: {e}")
    return downloaded
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _extract_matching_files(
    tar: tarfile.TarFile,
    repo_paths: list[str],
    dest: Path,
    verbose: bool,
) -> int:
    """Extract files matching repo_paths from tarball.

    Each member name has its first path component removed before matching
    (the archive's top-level directory). A member is extracted when its
    remaining path equals one of ``repo_paths`` or lies beneath one of them.

    Args:
        tar: Open tar archive to read members from.
        repo_paths: Repository-relative files or directories to keep; leading
            and trailing slashes are ignored, and an empty entry matches
            every file.
        dest: Directory that receives the extracted files.
        verbose: Print one line per extracted file.

    Returns:
        Number of files written.
    """
    wanted = [rp.strip("/") for rp in repo_paths]

    def _is_wanted(rel_path: str) -> bool:
        # Match on whole path components so "docs" does not also select
        # siblings such as "docs_old/..." or "docsify.md".
        return any(
            rp == "" or rel_path == rp or rel_path.startswith(f"{rp}/")
            for rp in wanted
        )

    downloaded = 0
    for member in tar.getmembers():
        if not member.isfile():
            continue
        rel_path = "/".join(member.name.split("/")[1:])
        # Refuse names that could escape ``dest``: an absolute remainder would
        # replace ``dest`` when joined, and ".." components climb out of it.
        if not rel_path or rel_path.startswith("/") or ".." in rel_path.split("/"):
            continue
        if not _is_wanted(rel_path):
            continue
        target = dest / rel_path
        target.parent.mkdir(parents=True, exist_ok=True)
        src = tar.extractfile(member)
        if src:
            target.write_bytes(src.read())
            downloaded += 1
            if verbose:
                print(f" ✓ {rel_path}")
    return downloaded
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def _download_single_github_repo(
    client: httpx.Client,
    repo: str,
    repo_paths: list[str],
    dest: Path,
    branch: str = "main",
    verbose: bool = True,
) -> int:
    """Download specific paths from a single GitHub repo.

    Fetches the repository tarball for ``branch`` from the GitHub API, spools
    it to a temporary file, and extracts the members selected by
    ``repo_paths`` into ``dest``.

    Args:
        client: HTTP client used for the request.
        repo: Repository in ``owner/name`` form.
        repo_paths: Repository-relative paths to extract.
        dest: Directory that receives the extracted files.
        branch: Ref whose tarball is requested.
        verbose: Print progress.

    Returns:
        Number of files extracted.

    Raises:
        httpx.HTTPError: If the request fails or returns an error status.
    """
    tarball_url = f"https://api.github.com/repos/{repo}/tarball/{branch}"
    if verbose:
        print(f" Fetching {repo}...")
    resp = client.get(tarball_url)
    resp.raise_for_status()

    tmp_path: Path | None = None
    try:
        # The write happens inside the try so the temp file is removed even
        # when spooling the payload fails (e.g. the disk fills up).
        with tempfile.NamedTemporaryFile(suffix=".tar.gz", delete=False) as tmp:
            tmp_path = Path(tmp.name)
            tmp.write(resp.content)
        with tarfile.open(tmp_path, "r:gz") as tar:
            return _extract_matching_files(tar, repo_paths, dest, verbose)
    finally:
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _download_github_repo(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
    """Fetch the configured paths of a single-repo corpus into ``dest``.

    Requires ``config.repo`` and ``config.repo_paths`` to be set. The
    download helper is called without a ``branch`` argument, so its default
    applies.

    Returns:
        Number of files extracted.
    """
    repo, wanted_paths = config.repo, config.repo_paths
    assert repo is not None
    assert wanted_paths is not None
    http = httpx.Client(timeout=60.0, follow_redirects=True)
    with http:
        file_count = _download_single_github_repo(
            http, repo, wanted_paths, dest, verbose=verbose
        )
    return file_count
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
    """Fetch every repository of a multi-repo corpus into ``dest``.

    Each repository gets its own subdirectory of ``dest`` named after the
    last ``/``-separated part of its identifier. A repository whose download
    raises ``httpx.HTTPError`` is reported (when ``verbose``) and skipped.

    Returns:
        Total number of files extracted across all repositories.
    """
    assert config.repos is not None
    total = 0
    with httpx.Client(timeout=120.0, follow_redirects=True) as http:
        for source in config.repos:
            subdir = dest / source.repo.rsplit("/", 1)[-1]
            subdir.mkdir(parents=True, exist_ok=True)
            try:
                total += _download_single_github_repo(
                    http,
                    source.repo,
                    source.paths,
                    subdir,
                    branch=source.branch,
                    verbose=verbose,
                )
            except httpx.HTTPError as err:
                if verbose:
                    print(f" ✗ {source.repo}: {err}")
    return total
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True) -> Path:
    """Download a corpus to local cache.

    Files are downloaded into a staging directory next to the final location
    and moved into place only after the download function returns. A failed
    download therefore never leaves a partial corpus that later calls would
    treat as "already exists", and a failed ``force`` re-download keeps the
    previously downloaded copy intact.

    Args:
        name: Corpus name
        force: Re-download even if exists
        verbose: Print progress

    Returns:
        Path to downloaded corpus

    Raises:
        ValueError: If corpus name or its source type is unknown
        httpx.HTTPError: If download fails
    """
    if name not in CORPORA:
        raise ValueError(f"Unknown corpus: {name}. Available: {list(CORPORA.keys())}")
    config = CORPORA[name]
    dest = _corpus_path(name)
    if dest.exists() and not force:
        if verbose:
            print(f"Corpus '{name}' already exists at {dest}")
            print("Use --force to re-download")
        return dest

    # Resolve the strategy before touching the filesystem so an unknown
    # source type cannot leave an empty corpus directory behind.
    downloaders = {
        "nvidia_md": _download_nvidia_md,
        "github_repo": _download_github_repo,
        "github_multi_repo": _download_github_multi_repo,
    }
    downloader = downloaders.get(config.source_type)
    if downloader is None:
        raise ValueError(f"Unknown source type: {config.source_type}")

    _ensure_cache_dir()
    staging = dest.with_name(f".{dest.name}.partial")
    if staging.exists():
        # Left over from an earlier interrupted run.
        shutil.rmtree(staging)
    staging.mkdir(parents=True)
    if verbose:
        print(f"Downloading {name}: {config.description}")
    try:
        count = downloader(config, staging, verbose)
        # Replace the old copy only once the new one is complete.
        if dest.exists():
            shutil.rmtree(dest)
        staging.rename(dest)
    except BaseException:
        shutil.rmtree(staging, ignore_errors=True)
        raise
    if verbose:
        print(f"Downloaded {count} files to {dest}")
    return dest
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def sync_corpus(name: CorpusName, verbose: bool = True) -> Path:
    """Refresh a corpus by downloading it again.

    Equivalent to ``download_corpus`` with ``force`` enabled.

    Args:
        name: Corpus name
        verbose: Print progress

    Returns:
        Path to the refreshed corpus
    """
    refreshed = download_corpus(name, True, verbose)
    return refreshed
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def list_corpora(verbose: bool = True) -> dict[CorpusName, bool]:
    """Report which of the known corpora are present in the local cache.

    When ``verbose`` is set, one status line is printed per corpus, followed
    by its location and file count if it is present.

    Returns:
        Mapping of corpus name to whether its cache directory exists
    """
    status_by_name: dict[CorpusName, bool] = {}
    for corpus_name, corpus_cfg in CORPORA.items():
        corpus_dir = _corpus_path(corpus_name)
        is_downloaded = corpus_dir.exists()
        status_by_name[corpus_name] = is_downloaded
        if not verbose:
            continue
        marker = "✓" if is_downloaded else " "
        print(f"[{marker}] {corpus_name}: {corpus_cfg.description}")
        if is_downloaded:
            # Count regular files anywhere below the corpus directory.
            n_files = len([entry for entry in corpus_dir.rglob("*") if entry.is_file()])
            print(f" {corpus_dir} ({n_files} files)")
    return status_by_name
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def get_corpus_path(name: CorpusName) -> Path | None:
    """Look up the cache location of a corpus that is already downloaded.

    Args:
        name: Corpus name

    Returns:
        The corpus directory if ``name`` is a known corpus and the directory
        exists; otherwise None
    """
    if name in CORPORA:
        candidate = _corpus_path(name)
        if candidate.exists():
            return candidate
    return None
|