wafer-cli 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer/corpus.py CHANGED
@@ -15,7 +15,16 @@ import httpx
15
15
 
16
16
  CACHE_DIR = Path.home() / ".cache" / "wafer" / "corpora"
17
17
 
18
- CorpusName = Literal["cuda", "cutlass", "hip"]
18
+ CorpusName = Literal["cuda", "cutlass", "hip", "amd"]
19
+
20
+
21
@dataclass
class RepoSource:
    """One GitHub repository contributing files to a multi-repo corpus.

    Attributes:
        repo: Repository identifier in ``owner/name`` form
            (for example ``"ROCm/rocWMMA"``).
        paths: Repo-relative path prefixes whose files should be fetched.
        branch: Branch to fetch the tarball from; ``"main"`` unless overridden.
    """

    # Required fields come first so positional construction keeps working.
    repo: str
    paths: list[str]
    # Optional: several upstream repos publish from "develop" or similar.
    branch: str = "main"
19
28
 
20
29
 
21
30
  @dataclass
@@ -24,10 +33,11 @@ class CorpusConfig:
24
33
 
25
34
  name: CorpusName
26
35
  description: str
27
- source_type: Literal["nvidia_md", "github_repo"]
36
+ source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
28
37
  urls: list[str] | None = None
29
38
  repo: str | None = None
30
39
  repo_paths: list[str] | None = None
40
+ repos: list[RepoSource] | None = None # For multi-repo corpora
31
41
 
32
42
 
33
43
  CORPORA: dict[CorpusName, CorpusConfig] = {
@@ -69,6 +79,74 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
69
79
  repo="ROCm/HIP",
70
80
  repo_paths=["docs"],
71
81
  ),
82
+ "amd": CorpusConfig(
83
+ name="amd",
84
+ description="AMD GPU kernel development (rocWMMA, CK, AITER, rocBLAS, HipKittens, vLLM)",
85
+ source_type="github_multi_repo",
86
+ repos=[
87
+ # rocWMMA - wave matrix multiply-accumulate (WMMA) intrinsics
88
+ RepoSource(
89
+ repo="ROCm/rocWMMA",
90
+ paths=["docs", "samples", "library/include"],
91
+ branch="develop",
92
+ ),
93
+ # Composable Kernel - tile-based GPU programming
94
+ RepoSource(
95
+ repo="ROCm/composable_kernel",
96
+ paths=["docs", "example", "tutorial", "include/ck_tile"],
97
+ branch="develop",
98
+ ),
99
+ # AITER - AMD inference tensor runtime
100
+ RepoSource(
101
+ repo="ROCm/aiter",
102
+ paths=["docs", "aiter/ops"],
103
+ ),
104
+ # MIOpen - deep learning primitives (deprecated, use rocm-libraries)
105
+ RepoSource(
106
+ repo="ROCm/MIOpen",
107
+ paths=["docs"],
108
+ branch="develop_deprecated",
109
+ ),
110
+ # rocBLAS - BLAS library (deprecated, use rocm-libraries)
111
+ RepoSource(
112
+ repo="ROCm/rocBLAS",
113
+ paths=["docs"],
114
+ branch="develop_deprecated",
115
+ ),
116
+ # hipBLASLt - lightweight BLAS (deprecated, use rocm-libraries)
117
+ RepoSource(
118
+ repo="ROCm/hipBLASLt",
119
+ paths=["docs"],
120
+ branch="develop_deprecated",
121
+ ),
122
+ # Tensile - GEMM code generator (deprecated, use rocm-libraries)
123
+ RepoSource(
124
+ repo="ROCm/Tensile",
125
+ paths=["docs"],
126
+ branch="develop_deprecated",
127
+ ),
128
+ # HipKittens - high-performance AMD kernels
129
+ RepoSource(
130
+ repo="HazyResearch/HipKittens",
131
+ paths=["docs", "kernels", "include"],
132
+ ),
133
+ # vLLM AMD kernels
134
+ RepoSource(
135
+ repo="vllm-project/vllm",
136
+ paths=["csrc/rocm"],
137
+ ),
138
+ # SGLang AMD kernels
139
+ RepoSource(
140
+ repo="sgl-project/sglang",
141
+ paths=["3rdparty/amd"],
142
+ ),
143
+ # HuggingFace ROCm kernels
144
+ RepoSource(
145
+ repo="huggingface/hf-rocm-kernels",
146
+ paths=["csrc", "hf_rocm_kernels", "docs"],
147
+ ),
148
+ ],
149
+ ),
72
150
  }
73
151
 
74
152
 
@@ -113,41 +191,87 @@ def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True)
113
191
  return downloaded
114
192
 
115
193
 
194
+ def _extract_matching_files(
195
+ tar: tarfile.TarFile,
196
+ repo_paths: list[str],
197
+ dest: Path,
198
+ verbose: bool,
199
+ ) -> int:
200
+ """Extract files matching repo_paths from tarball."""
201
+ downloaded = 0
202
+ for member in tar.getmembers():
203
+ if not member.isfile():
204
+ continue
205
+ rel_path = "/".join(member.name.split("/")[1:])
206
+ if not any(rel_path.startswith(rp) for rp in repo_paths):
207
+ continue
208
+ target = dest / rel_path
209
+ target.parent.mkdir(parents=True, exist_ok=True)
210
+ src = tar.extractfile(member)
211
+ if src:
212
+ target.write_bytes(src.read())
213
+ downloaded += 1
214
+ if verbose:
215
+ print(f" ✓ {rel_path}")
216
+ return downloaded
217
+
218
+
219
def _download_single_github_repo(
    client: httpx.Client,
    repo: str,
    repo_paths: list[str],
    dest: Path,
    branch: str = "main",
    verbose: bool = True,
) -> int:
    """Download one GitHub repo tarball and extract the requested paths.

    Args:
        client: HTTP client used for the request.
        repo: Repository in ``owner/name`` form; interpolated into the
            GitHub API tarball URL.
        repo_paths: Repo-relative paths to extract from the tarball.
        dest: Directory the extracted files are written under.
        branch: Branch whose tarball is requested.
        verbose: When true, print progress output.

    Returns:
        The number of files extracted.

    Raises:
        Whatever ``client.get`` or ``resp.raise_for_status()`` raise for a
        failed request, and ``tarfile`` errors for an unreadable archive.
    """
    tarball_url = f"https://api.github.com/repos/{repo}/tarball/{branch}"
    if verbose:
        print(f" Fetching {repo}...")
    resp = client.get(tarball_url)
    resp.raise_for_status()
    # A self-deleting temporary file is used so nothing is left on disk even
    # if writing the payload or reading the archive fails part-way through.
    with tempfile.TemporaryFile(suffix=".tar.gz") as tmp:
        tmp.write(resp.content)
        tmp.seek(0)  # rewind so tarfile reads from the start of the payload
        with tarfile.open(fileobj=tmp, mode="r:gz") as tar:
            return _extract_matching_files(tar, repo_paths, dest, verbose)
241
+
242
+
116
243
def _download_github_repo(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
    """Fetch a single-repository corpus from GitHub into ``dest``.

    Requires ``config.repo`` and ``config.repo_paths`` to be set; the tarball
    is requested with the helper's default branch.

    Returns:
        The number of files extracted.
    """
    assert config.repo is not None
    assert config.repo_paths is not None
    repo, wanted_paths = config.repo, config.repo_paths
    with httpx.Client(timeout=60.0, follow_redirects=True) as session:
        file_count = _download_single_github_repo(
            session,
            repo,
            wanted_paths,
            dest,
            verbose=verbose,
        )
    return file_count
251
+
252
+
253
def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
    """Download the configured paths from every repo of a multi-repo corpus.

    Each repository is extracted into its own sub-directory of ``dest``,
    named after the last component of ``owner/name``. The download is
    best-effort: a repository that fails is reported (when ``verbose``) and
    skipped, and the remaining repositories are still attempted.

    Args:
        config: Corpus configuration; ``config.repos`` must be set.
        dest: Root directory for the corpus.
        verbose: When true, print progress and per-repo failures.

    Returns:
        Total number of files extracted across all repositories.
    """
    assert config.repos is not None
    downloaded = 0
    with httpx.Client(timeout=120.0, follow_redirects=True) as client:
        for repo_source in config.repos:
            repo_name = repo_source.repo.split("/")[-1]
            repo_dest = dest / repo_name
            repo_dest.mkdir(parents=True, exist_ok=True)
            try:
                downloaded += _download_single_github_repo(
                    client,
                    repo_source.repo,
                    repo_source.paths,
                    repo_dest,
                    branch=repo_source.branch,
                    verbose=verbose,
                )
            except (httpx.HTTPError, tarfile.TarError) as e:
                # An unreadable archive is treated like a failed request so
                # that one bad repository does not abort the rest.
                if verbose:
                    print(f" ✗ {repo_source.repo}: {e}")
    return downloaded
152
276
 
153
277
 
@@ -185,6 +309,8 @@ def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True)
185
309
  count = _download_nvidia_md(config, dest, verbose)
186
310
  elif config.source_type == "github_repo":
187
311
  count = _download_github_repo(config, dest, verbose)
312
+ elif config.source_type == "github_multi_repo":
313
+ count = _download_github_multi_repo(config, dest, verbose)
188
314
  else:
189
315
  raise ValueError(f"Unknown source type: {config.source_type}")
190
316
  if verbose: