wafer-cli 0.2.14__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer/GUIDE.md +1 -1
- wafer/agent_defaults.py +42 -0
- wafer/auth.py +7 -0
- wafer/billing.py +6 -6
- wafer/cli.py +905 -131
- wafer/cli_instructions.py +143 -0
- wafer/corpus.py +313 -15
- wafer/evaluate.py +480 -146
- wafer/global_config.py +13 -0
- wafer/kernel_scope.py +1 -1
- wafer/ncu_analyze.py +1 -1
- wafer/nsys_analyze.py +1 -1
- wafer/skills/wafer-guide/SKILL.md +22 -6
- wafer/specs_cli.py +157 -0
- wafer/ssh_keys.py +6 -6
- wafer/targets_cli.py +472 -0
- wafer/targets_ops.py +29 -2
- wafer/templates/ask_docs.py +1 -1
- wafer/templates/optimize_kernel.py +3 -1
- wafer/templates/optimize_kernelbench.py +17 -62
- wafer/templates/trace_analyze.py +1 -1
- wafer/tests/test_eval_cli_parity.py +199 -0
- wafer/trace_compare.py +274 -0
- wafer/wevin_cli.py +125 -26
- wafer/workspaces.py +163 -16
- wafer_cli-0.2.30.dist-info/METADATA +107 -0
- wafer_cli-0.2.30.dist-info/RECORD +47 -0
- wafer_cli-0.2.14.dist-info/METADATA +0 -16
- wafer_cli-0.2.14.dist-info/RECORD +0 -41
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/WHEEL +0 -0
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/entry_points.txt +0 -0
- {wafer_cli-0.2.14.dist-info → wafer_cli-0.2.30.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
"""Generate agent system prompt instructions from the wafer CLI's own --help text.
|
|
2
|
+
|
|
3
|
+
Walks the typer/click command tree and extracts help text for commands
|
|
4
|
+
matching the bash_allowlist. This ensures agent instructions stay in sync
|
|
5
|
+
with the CLI — the --help text is the single source of truth for both
|
|
6
|
+
human users and AI agents.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
from wafer.cli_instructions import build_cli_instructions
|
|
10
|
+
|
|
11
|
+
instructions = build_cli_instructions([
|
|
12
|
+
"wafer evaluate",
|
|
13
|
+
"wafer nvidia ncu",
|
|
14
|
+
"wafer rocprof profile",
|
|
15
|
+
"python", # non-wafer commands are skipped
|
|
16
|
+
])
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
import click
|
|
22
|
+
import typer.main
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _resolve_command(root: click.BaseCommand, parts: list[str]) -> click.BaseCommand | None:
    """Walk the click command tree to find a (sub)command by name parts.

    Args:
        root: The root click command (from typer.main.get_command)
        parts: Command path segments, e.g. ["evaluate", "kernelbench"]

    Returns:
        The click command at that path, or None if not found.
    """
    current: click.BaseCommand | None = root
    for segment in parts:
        # Only group-like commands can have children to descend into.
        if not isinstance(current, click.MultiCommand):
            return None
        ctx = click.Context(current, info_name=segment)
        current = current.get_command(ctx, segment)
        if current is None:
            return None
    return current
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _format_command_help(cmd_path: str, cmd: click.BaseCommand) -> str:
    """Format a single command's help text for inclusion in a system prompt.

    Extracts the description and option help text (skipping --help itself).
    """
    out = [f"### `{cmd_path}`"]

    if cmd.help:
        out.append(cmd.help.strip())

    # One line per user-facing option; --help itself is omitted.
    opt_rows: list[str] = []
    for param in getattr(cmd, "params", []):
        if not isinstance(param, click.Option) or param.name == "help":
            continue
        label = "/".join(param.opts)
        kind = param.type.name.upper() if hasattr(param.type, "name") else ""
        desc = param.help or ""
        # Boolean flags read better without the redundant type tag.
        flag_like = kind in ("BOOL", "BOOLEAN") or param.is_flag
        if kind and not flag_like:
            opt_rows.append(f"  {label} {kind} {desc}")
        else:
            opt_rows.append(f"  {label} {desc}")

    if opt_rows:
        out += ["", "Options:", *opt_rows]

    # Groups also list each immediate subcommand with a one-line summary.
    if isinstance(cmd, click.MultiCommand):
        ctx = click.Context(cmd, info_name=cmd_path.split()[-1])
        sub_rows: list[str] = []
        for sub_name in cmd.list_commands(ctx):
            sub = cmd.get_command(ctx, sub_name)
            if sub:
                summary = (sub.help or sub.short_help or "").strip().split("\n")[0]
                sub_rows.append(f"  {cmd_path} {sub_name} {summary}")
        if sub_rows:
            out += ["", "Subcommands:", *sub_rows]

    return "\n".join(out)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def build_cli_instructions(bash_allowlist: list[str]) -> str:
    """Generate CLI instruction text from --help for allowed wafer commands.

    Walks the typer/click command tree and extracts help text for each
    wafer command in the bash_allowlist. Non-wafer commands (python, ls, etc.)
    are skipped.

    Args:
        bash_allowlist: List of allowed bash command prefixes.
            Example: ["wafer evaluate", "wafer nvidia ncu", "python"]

    Returns:
        Markdown-formatted CLI instructions, or empty string if no wafer
        commands are in the allowlist.
    """
    if not bash_allowlist:
        return ""

    # Only wafer-prefixed entries are documented; everything else is skipped.
    wafer_entries = [entry for entry in bash_allowlist if entry.startswith("wafer ")]
    if not wafer_entries:
        return ""

    # Imported lazily: wafer.cli imports this module, so a top-level
    # import would be circular.
    from wafer.cli import app

    root = typer.main.get_command(app)

    sections: list[str] = []
    for entry in wafer_entries:
        # "wafer evaluate kernelbench" -> ["evaluate", "kernelbench"]
        target = _resolve_command(root, entry.split()[1:])
        if target is None:
            # Command not found in tree — skip silently
            continue
        sections.append(_format_command_help(entry, target))

    if not sections:
        return ""

    header = (
        "## Wafer CLI Commands\n\n"
        "You do not have a local GPU. Use the wafer CLI to run on remote GPU hardware.\n"
    )
    return header + "\n\n".join(sections)
|
wafer/corpus.py
CHANGED
|
@@ -3,10 +3,12 @@
|
|
|
3
3
|
Download and manage documentation corpora for agent filesystem access.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
+
import re
|
|
6
7
|
import shutil
|
|
7
8
|
import tarfile
|
|
8
9
|
import tempfile
|
|
9
10
|
from dataclasses import dataclass
|
|
11
|
+
from html.parser import HTMLParser
|
|
10
12
|
from pathlib import Path
|
|
11
13
|
from typing import Literal
|
|
12
14
|
from urllib.parse import urlparse
|
|
@@ -33,7 +35,7 @@ class CorpusConfig:
|
|
|
33
35
|
|
|
34
36
|
name: CorpusName
|
|
35
37
|
description: str
|
|
36
|
-
source_type: Literal["nvidia_md", "github_repo", "github_multi_repo"]
|
|
38
|
+
source_type: Literal["nvidia_md", "github_repo", "github_multi_repo", "mixed"]
|
|
37
39
|
urls: list[str] | None = None
|
|
38
40
|
repo: str | None = None
|
|
39
41
|
repo_paths: list[str] | None = None
|
|
@@ -67,21 +69,74 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
|
|
|
67
69
|
),
|
|
68
70
|
"cutlass": CorpusConfig(
|
|
69
71
|
name="cutlass",
|
|
70
|
-
description="CUTLASS
|
|
71
|
-
source_type="
|
|
72
|
-
|
|
73
|
-
|
|
72
|
+
description="CUTLASS C++ documentation, examples, and tutorials",
|
|
73
|
+
source_type="mixed",
|
|
74
|
+
# Official NVIDIA CUTLASS documentation (scraped as markdown)
|
|
75
|
+
urls=[
|
|
76
|
+
"https://docs.nvidia.com/cutlass/latest/overview.html",
|
|
77
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/functionality.html",
|
|
78
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/terminology.html",
|
|
79
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/fundamental_types.html",
|
|
80
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/programming_guidelines.html",
|
|
81
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/heuristics.html",
|
|
82
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/efficient_gemm.html",
|
|
83
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/pipeline.html",
|
|
84
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/profiler.html",
|
|
85
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/dependent_kernel_launch.html",
|
|
86
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_functionality.html",
|
|
87
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/blackwell_cluster_launch_control.html",
|
|
88
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/00_quickstart.html",
|
|
89
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/01_layout.html",
|
|
90
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/02_layout_algebra.html",
|
|
91
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/03_tensor.html",
|
|
92
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/04_algorithms.html",
|
|
93
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0t_mma_atom.html",
|
|
94
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0x_gemm_tutorial.html",
|
|
95
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0y_predication.html",
|
|
96
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cute/0z_tma_tensors.html",
|
|
97
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_design.html",
|
|
98
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/cutlass_3x_backwards_compatibility.html",
|
|
99
|
+
"https://docs.nvidia.com/cutlass/latest/media/docs/cpp/gemm_api_3x.html",
|
|
100
|
+
],
|
|
101
|
+
# NVIDIA/cutlass GitHub examples (excluding python/)
|
|
102
|
+
repos=[
|
|
103
|
+
RepoSource(
|
|
104
|
+
repo="NVIDIA/cutlass",
|
|
105
|
+
paths=["examples"],
|
|
106
|
+
branch="main",
|
|
107
|
+
),
|
|
108
|
+
],
|
|
74
109
|
),
|
|
75
110
|
"hip": CorpusConfig(
|
|
76
111
|
name="hip",
|
|
77
|
-
description="HIP programming guide
|
|
78
|
-
source_type="
|
|
79
|
-
|
|
80
|
-
|
|
112
|
+
description="HIP programming guide, API reference, and examples",
|
|
113
|
+
source_type="github_multi_repo",
|
|
114
|
+
repos=[
|
|
115
|
+
# HIP - main documentation and API
|
|
116
|
+
RepoSource(
|
|
117
|
+
repo="ROCm/HIP",
|
|
118
|
+
paths=["docs"],
|
|
119
|
+
),
|
|
120
|
+
# HIP examples - code samples
|
|
121
|
+
RepoSource(
|
|
122
|
+
repo="ROCm/HIP-Examples",
|
|
123
|
+
paths=["HIP-Examples-Applications", "mini-nbody"],
|
|
124
|
+
),
|
|
125
|
+
# clr - HIP/OpenCL runtime (low-level)
|
|
126
|
+
RepoSource(
|
|
127
|
+
repo="ROCm/clr",
|
|
128
|
+
paths=["hipamd/include", "rocclr/device/gpu"],
|
|
129
|
+
),
|
|
130
|
+
# ROCm docs - official documentation
|
|
131
|
+
RepoSource(
|
|
132
|
+
repo="ROCm/ROCm",
|
|
133
|
+
paths=["docs"],
|
|
134
|
+
),
|
|
135
|
+
],
|
|
81
136
|
),
|
|
82
137
|
"amd": CorpusConfig(
|
|
83
138
|
name="amd",
|
|
84
|
-
description="AMD GPU kernel development (rocWMMA, CK, AITER, rocBLAS, HipKittens, vLLM)",
|
|
139
|
+
description="AMD GPU kernel development (rocWMMA, CK, AITER, rocBLAS, HipKittens, vLLM, FlashAttention)",
|
|
85
140
|
source_type="github_multi_repo",
|
|
86
141
|
repos=[
|
|
87
142
|
# rocWMMA - wave matrix multiply-accumulate (WMMA) intrinsics
|
|
@@ -125,11 +180,17 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
|
|
|
125
180
|
paths=["docs"],
|
|
126
181
|
branch="develop_deprecated",
|
|
127
182
|
),
|
|
128
|
-
# HipKittens - high-performance AMD kernels
|
|
183
|
+
# HipKittens - high-performance AMD kernels (main branch: MI350X/CDNA4+)
|
|
129
184
|
RepoSource(
|
|
130
185
|
repo="HazyResearch/HipKittens",
|
|
131
186
|
paths=["docs", "kernels", "include"],
|
|
132
187
|
),
|
|
188
|
+
# HipKittens cdna3 branch - MI300X/MI325X (gfx942)
|
|
189
|
+
RepoSource(
|
|
190
|
+
repo="HazyResearch/HipKittens",
|
|
191
|
+
paths=["kernels", "include", "tests"],
|
|
192
|
+
branch="cdna3",
|
|
193
|
+
),
|
|
133
194
|
# vLLM AMD kernels
|
|
134
195
|
RepoSource(
|
|
135
196
|
repo="vllm-project/vllm",
|
|
@@ -145,6 +206,46 @@ CORPORA: dict[CorpusName, CorpusConfig] = {
|
|
|
145
206
|
repo="huggingface/hf-rocm-kernels",
|
|
146
207
|
paths=["csrc", "hf_rocm_kernels", "docs"],
|
|
147
208
|
),
|
|
209
|
+
# ROCm/flash-attention - FlashAttention for AMD GPUs
|
|
210
|
+
RepoSource(
|
|
211
|
+
repo="ROCm/flash-attention",
|
|
212
|
+
paths=["csrc", "docs"],
|
|
213
|
+
),
|
|
214
|
+
# ROCm/triton - Triton compiler for AMD GPUs
|
|
215
|
+
RepoSource(
|
|
216
|
+
repo="ROCm/triton",
|
|
217
|
+
paths=["python/tutorials", "third_party/amd"],
|
|
218
|
+
),
|
|
219
|
+
# ROCm/rccl - ROCm Communication Collectives Library (multi-GPU)
|
|
220
|
+
RepoSource(
|
|
221
|
+
repo="ROCm/rccl",
|
|
222
|
+
paths=["docs"],
|
|
223
|
+
),
|
|
224
|
+
# ROCm/rocprofiler-sdk - AMD GPU profiling SDK
|
|
225
|
+
RepoSource(
|
|
226
|
+
repo="ROCm/rocprofiler-sdk",
|
|
227
|
+
paths=["docs", "samples"],
|
|
228
|
+
),
|
|
229
|
+
# ROCm/omniperf - AMD GPU profiling tool
|
|
230
|
+
RepoSource(
|
|
231
|
+
repo="ROCm/omniperf",
|
|
232
|
+
paths=["docs", "src/omniperf_analyze"],
|
|
233
|
+
),
|
|
234
|
+
# ROCm/omnitrace - Application tracing for AMD
|
|
235
|
+
RepoSource(
|
|
236
|
+
repo="ROCm/omnitrace",
|
|
237
|
+
paths=["docs"],
|
|
238
|
+
),
|
|
239
|
+
# AMD GPUOpen Performance Guides
|
|
240
|
+
RepoSource(
|
|
241
|
+
repo="GPUOpen-Tools/gpu_performance_api",
|
|
242
|
+
paths=["docs"],
|
|
243
|
+
),
|
|
244
|
+
# AMD LLVM - AMD GPU compiler backend
|
|
245
|
+
RepoSource(
|
|
246
|
+
repo="ROCm/llvm-project",
|
|
247
|
+
paths=["amd/device-libs/README.md", "llvm/docs/AMDGPUUsage.rst"],
|
|
248
|
+
),
|
|
148
249
|
],
|
|
149
250
|
),
|
|
150
251
|
}
|
|
@@ -169,19 +270,195 @@ def _url_to_filepath(url: str, base_dir: Path) -> Path:
|
|
|
169
270
|
return base_dir / "/".join(path_parts)
|
|
170
271
|
|
|
171
272
|
|
|
273
|
+
class _HTMLToMarkdown(HTMLParser):
|
|
274
|
+
"""HTML to Markdown converter for NVIDIA documentation pages.
|
|
275
|
+
|
|
276
|
+
Uses stdlib HTMLParser - requires subclassing due to callback-based API.
|
|
277
|
+
The public interface is the functional `_html_to_markdown()` below.
|
|
278
|
+
"""
|
|
279
|
+
|
|
280
|
+
def __init__(self) -> None:
|
|
281
|
+
super().__init__()
|
|
282
|
+
self.output: list[str] = []
|
|
283
|
+
self.current_tag: str = ""
|
|
284
|
+
self.in_code_block = False
|
|
285
|
+
self.in_pre = False
|
|
286
|
+
self.list_depth = 0
|
|
287
|
+
self.ordered_list_counters: list[int] = []
|
|
288
|
+
self.skip_content = False
|
|
289
|
+
self.link_href: str | None = None
|
|
290
|
+
|
|
291
|
+
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
|
|
292
|
+
self.current_tag = tag
|
|
293
|
+
attrs_dict = dict(attrs)
|
|
294
|
+
|
|
295
|
+
# Skip script, style, nav, footer, header
|
|
296
|
+
if tag in ("script", "style", "nav", "footer", "header", "aside"):
|
|
297
|
+
self.skip_content = True
|
|
298
|
+
return
|
|
299
|
+
|
|
300
|
+
if tag == "h1":
|
|
301
|
+
self.output.append("\n# ")
|
|
302
|
+
elif tag == "h2":
|
|
303
|
+
self.output.append("\n## ")
|
|
304
|
+
elif tag == "h3":
|
|
305
|
+
self.output.append("\n### ")
|
|
306
|
+
elif tag == "h4":
|
|
307
|
+
self.output.append("\n#### ")
|
|
308
|
+
elif tag == "h5":
|
|
309
|
+
self.output.append("\n##### ")
|
|
310
|
+
elif tag == "h6":
|
|
311
|
+
self.output.append("\n###### ")
|
|
312
|
+
elif tag == "p":
|
|
313
|
+
self.output.append("\n\n")
|
|
314
|
+
elif tag == "br":
|
|
315
|
+
self.output.append("\n")
|
|
316
|
+
elif tag == "strong" or tag == "b":
|
|
317
|
+
self.output.append("**")
|
|
318
|
+
elif tag == "em" or tag == "i":
|
|
319
|
+
self.output.append("*")
|
|
320
|
+
elif tag == "code" and not self.in_pre:
|
|
321
|
+
self.output.append("`")
|
|
322
|
+
self.in_code_block = True
|
|
323
|
+
elif tag == "pre":
|
|
324
|
+
self.in_pre = True
|
|
325
|
+
# Check for language hint in class
|
|
326
|
+
lang = ""
|
|
327
|
+
if class_attr := attrs_dict.get("class"):
|
|
328
|
+
if "python" in class_attr.lower():
|
|
329
|
+
lang = "python"
|
|
330
|
+
elif "cpp" in class_attr.lower() or "c++" in class_attr.lower():
|
|
331
|
+
lang = "cpp"
|
|
332
|
+
elif "cuda" in class_attr.lower():
|
|
333
|
+
lang = "cuda"
|
|
334
|
+
self.output.append(f"\n```{lang}\n")
|
|
335
|
+
elif tag == "ul":
|
|
336
|
+
self.list_depth += 1
|
|
337
|
+
self.output.append("\n")
|
|
338
|
+
elif tag == "ol":
|
|
339
|
+
self.list_depth += 1
|
|
340
|
+
self.ordered_list_counters.append(1)
|
|
341
|
+
self.output.append("\n")
|
|
342
|
+
elif tag == "li":
|
|
343
|
+
indent = " " * (self.list_depth - 1)
|
|
344
|
+
if self.ordered_list_counters:
|
|
345
|
+
num = self.ordered_list_counters[-1]
|
|
346
|
+
self.output.append(f"{indent}{num}. ")
|
|
347
|
+
self.ordered_list_counters[-1] += 1
|
|
348
|
+
else:
|
|
349
|
+
self.output.append(f"{indent}- ")
|
|
350
|
+
elif tag == "a":
|
|
351
|
+
self.link_href = attrs_dict.get("href")
|
|
352
|
+
self.output.append("[")
|
|
353
|
+
elif tag == "img":
|
|
354
|
+
alt = attrs_dict.get("alt", "image")
|
|
355
|
+
src = attrs_dict.get("src", "")
|
|
356
|
+
self.output.append(f"")
|
|
357
|
+
elif tag == "blockquote":
|
|
358
|
+
self.output.append("\n> ")
|
|
359
|
+
elif tag == "hr":
|
|
360
|
+
self.output.append("\n---\n")
|
|
361
|
+
elif tag == "table":
|
|
362
|
+
self.output.append("\n")
|
|
363
|
+
elif tag == "th":
|
|
364
|
+
self.output.append("| ")
|
|
365
|
+
elif tag == "td":
|
|
366
|
+
self.output.append("| ")
|
|
367
|
+
elif tag == "tr":
|
|
368
|
+
pass # Handled in endtag
|
|
369
|
+
|
|
370
|
+
def handle_endtag(self, tag: str) -> None:
|
|
371
|
+
if tag in ("script", "style", "nav", "footer", "header", "aside"):
|
|
372
|
+
self.skip_content = False
|
|
373
|
+
return
|
|
374
|
+
|
|
375
|
+
if tag in ("h1", "h2", "h3", "h4", "h5", "h6"):
|
|
376
|
+
self.output.append("\n")
|
|
377
|
+
elif tag == "strong" or tag == "b":
|
|
378
|
+
self.output.append("**")
|
|
379
|
+
elif tag == "em" or tag == "i":
|
|
380
|
+
self.output.append("*")
|
|
381
|
+
elif tag == "code" and not self.in_pre:
|
|
382
|
+
self.output.append("`")
|
|
383
|
+
self.in_code_block = False
|
|
384
|
+
elif tag == "pre":
|
|
385
|
+
self.in_pre = False
|
|
386
|
+
self.output.append("\n```\n")
|
|
387
|
+
elif tag == "ul":
|
|
388
|
+
self.list_depth = max(0, self.list_depth - 1)
|
|
389
|
+
elif tag == "ol":
|
|
390
|
+
self.list_depth = max(0, self.list_depth - 1)
|
|
391
|
+
if self.ordered_list_counters:
|
|
392
|
+
self.ordered_list_counters.pop()
|
|
393
|
+
elif tag == "li":
|
|
394
|
+
self.output.append("\n")
|
|
395
|
+
elif tag == "a":
|
|
396
|
+
if self.link_href:
|
|
397
|
+
self.output.append(f"]({self.link_href})")
|
|
398
|
+
else:
|
|
399
|
+
self.output.append("]")
|
|
400
|
+
self.link_href = None
|
|
401
|
+
elif tag == "p":
|
|
402
|
+
self.output.append("\n")
|
|
403
|
+
elif tag == "blockquote":
|
|
404
|
+
self.output.append("\n")
|
|
405
|
+
elif tag == "tr":
|
|
406
|
+
self.output.append("|\n")
|
|
407
|
+
elif tag == "thead":
|
|
408
|
+
# Add markdown table separator after header row
|
|
409
|
+
self.output.append("|---" * 10 + "|\n")
|
|
410
|
+
|
|
411
|
+
def handle_data(self, data: str) -> None:
|
|
412
|
+
if self.skip_content:
|
|
413
|
+
return
|
|
414
|
+
# Preserve whitespace in code blocks
|
|
415
|
+
if self.in_pre:
|
|
416
|
+
self.output.append(data)
|
|
417
|
+
else:
|
|
418
|
+
# Collapse whitespace outside code
|
|
419
|
+
text = re.sub(r"\s+", " ", data)
|
|
420
|
+
if text.strip():
|
|
421
|
+
self.output.append(text)
|
|
422
|
+
|
|
423
|
+
def get_markdown(self) -> str:
|
|
424
|
+
"""Get the converted markdown, cleaned up."""
|
|
425
|
+
md = "".join(self.output)
|
|
426
|
+
# Clean up excessive newlines
|
|
427
|
+
md = re.sub(r"\n{3,}", "\n\n", md)
|
|
428
|
+
# Clean up empty table separators
|
|
429
|
+
md = re.sub(r"\|---\|---.*\|\n(?!\|)", "", md)
|
|
430
|
+
return md.strip()
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def _html_to_markdown(html: str) -> str:
    """Convert HTML to Markdown."""
    converter = _HTMLToMarkdown()
    converter.feed(html)
    return converter.get_markdown()
|
|
438
|
+
|
|
439
|
+
|
|
172
440
|
def _download_nvidia_md(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
|
|
173
|
-
"""Download NVIDIA docs
|
|
441
|
+
"""Download NVIDIA docs and convert HTML to Markdown.
|
|
442
|
+
|
|
443
|
+
NVIDIA's .md endpoint no longer works, so we scrape HTML and convert to markdown.
|
|
444
|
+
"""
|
|
174
445
|
assert config.urls is not None
|
|
175
446
|
downloaded = 0
|
|
176
447
|
with httpx.Client(timeout=30.0, follow_redirects=True) as client:
|
|
177
448
|
for url in config.urls:
|
|
178
|
-
md_url = f"{url}.md"
|
|
179
449
|
filepath = _url_to_filepath(url, dest)
|
|
180
450
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
|
181
451
|
try:
|
|
182
|
-
|
|
452
|
+
# Fetch HTML page directly
|
|
453
|
+
resp = client.get(url)
|
|
183
454
|
resp.raise_for_status()
|
|
184
|
-
|
|
455
|
+
|
|
456
|
+
# Convert HTML to Markdown
|
|
457
|
+
markdown = _html_to_markdown(resp.text)
|
|
458
|
+
|
|
459
|
+
# Add source URL as header
|
|
460
|
+
content = f"<!-- Source: {url} -->\n\n{markdown}"
|
|
461
|
+
filepath.write_text(content)
|
|
185
462
|
downloaded += 1
|
|
186
463
|
if verbose:
|
|
187
464
|
print(f" ✓ {filepath.relative_to(dest)}")
|
|
@@ -275,6 +552,25 @@ def _download_github_multi_repo(config: CorpusConfig, dest: Path, verbose: bool
|
|
|
275
552
|
return downloaded
|
|
276
553
|
|
|
277
554
|
|
|
555
|
+
def _download_mixed(config: CorpusConfig, dest: Path, verbose: bool = True) -> int:
    """Download from mixed sources (NVIDIA docs + GitHub repos)."""
    downloaded = 0

    # NVIDIA documentation pages, scraped and converted to markdown.
    if config.urls:
        if verbose:
            print("  [NVIDIA docs]")
        downloaded += _download_nvidia_md(config, dest, verbose)

    # GitHub repository sources.
    if config.repos:
        if verbose:
            print("  [GitHub repos]")
        downloaded += _download_github_multi_repo(config, dest, verbose)

    return downloaded
|
|
572
|
+
|
|
573
|
+
|
|
278
574
|
def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True) -> Path:
|
|
279
575
|
"""Download a corpus to local cache.
|
|
280
576
|
|
|
@@ -311,6 +607,8 @@ def download_corpus(name: CorpusName, force: bool = False, verbose: bool = True)
|
|
|
311
607
|
count = _download_github_repo(config, dest, verbose)
|
|
312
608
|
elif config.source_type == "github_multi_repo":
|
|
313
609
|
count = _download_github_multi_repo(config, dest, verbose)
|
|
610
|
+
elif config.source_type == "mixed":
|
|
611
|
+
count = _download_mixed(config, dest, verbose)
|
|
314
612
|
else:
|
|
315
613
|
raise ValueError(f"Unknown source type: {config.source_type}")
|
|
316
614
|
if verbose:
|