wafer-core 0.1.21__py3-none-any.whl → 0.1.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/auth.py +38 -6
- wafer_core/environments/coding.py +8 -0
- wafer_core/rollouts/dtypes.py +4 -0
- wafer_core/rollouts/environments/localfs.py +50 -2
- wafer_core/rollouts/evaluation.py +17 -1
- wafer_core/rollouts/prompt.py +14 -4
- wafer_core/rollouts/skills.py +176 -0
- wafer_core/rollouts/templates/base.py +3 -0
- wafer_core/targets/runpod.py +154 -15
- wafer_core/tools/__init__.py +14 -0
- wafer_core/tools/file_tools/grep_tool.py +56 -29
- wafer_core/tools/search_docs_tool.py +196 -0
- wafer_core/tools/skill_tool.py +64 -0
- wafer_core/utils/backend.py +3 -0
- wafer_core/utils/kernel_utils/static_checker.py +175 -3
- wafer_core/utils/kernel_utils/targets/config.py +58 -24
- wafer_core/utils/modal_execution/modal_app.py +14 -27
- {wafer_core-0.1.21.dist-info → wafer_core-0.1.23.dist-info}/METADATA +1 -1
- {wafer_core-0.1.21.dist-info → wafer_core-0.1.23.dist-info}/RECORD +20 -17
- {wafer_core-0.1.21.dist-info → wafer_core-0.1.23.dist-info}/WHEEL +0 -0
wafer_core/targets/runpod.py
CHANGED
|
@@ -14,7 +14,6 @@ from __future__ import annotations
|
|
|
14
14
|
|
|
15
15
|
import json
|
|
16
16
|
import logging
|
|
17
|
-
import os
|
|
18
17
|
import time
|
|
19
18
|
from contextlib import asynccontextmanager
|
|
20
19
|
from dataclasses import dataclass
|
|
@@ -250,10 +249,17 @@ async def provision_pod(target: RunPodTarget) -> tuple[str, str, int, str]:
|
|
|
250
249
|
"ports": "22/tcp",
|
|
251
250
|
"startSsh": True,
|
|
252
251
|
"startJupyter": False,
|
|
253
|
-
"imageName": target.image,
|
|
254
252
|
"env": [],
|
|
255
253
|
}
|
|
256
254
|
|
|
255
|
+
if target.template_id:
|
|
256
|
+
# Template defines image, dockerArgs (sshd setup), and ports.
|
|
257
|
+
# Required for non-RunPod images (e.g. rocm/pytorch) that don't
|
|
258
|
+
# have RunPod's built-in SSH handler.
|
|
259
|
+
pod_input["templateId"] = target.template_id
|
|
260
|
+
else:
|
|
261
|
+
pod_input["imageName"] = target.image
|
|
262
|
+
|
|
257
263
|
variables = {"input": pod_input}
|
|
258
264
|
|
|
259
265
|
logger.info(f"Provisioning RunPod pod: {pod_name}")
|
|
@@ -334,7 +340,8 @@ async def _wait_for_ssh(pod_id: str, timeout_seconds: int) -> tuple[str, int, st
|
|
|
334
340
|
# Check for SSH port
|
|
335
341
|
runtime = pod.get("runtime")
|
|
336
342
|
if runtime and status == "running":
|
|
337
|
-
|
|
343
|
+
# ports can be null in JSON response, so use 'or []' instead of default
|
|
344
|
+
for port in runtime.get("ports") or []:
|
|
338
345
|
if (
|
|
339
346
|
port.get("privatePort") == 22
|
|
340
347
|
and port.get("isIpPublic")
|
|
@@ -378,6 +385,55 @@ async def terminate_pod(pod_id: str) -> bool:
|
|
|
378
385
|
return False
|
|
379
386
|
|
|
380
387
|
|
|
388
|
+
# =============================================================================
|
|
389
|
+
# Template Management (not yet implemented)
|
|
390
|
+
# =============================================================================
|
|
391
|
+
#
|
|
392
|
+
# The saveTemplate mutation allows creating reusable pod templates with custom
|
|
393
|
+
# configurations. Templates can specify docker images, environment setup,
|
|
394
|
+
# container disk size, and other pod settings.
|
|
395
|
+
#
|
|
396
|
+
# Example mutation:
|
|
397
|
+
#
|
|
398
|
+
# mutation saveTemplate($input: SaveTemplateInput) {
|
|
399
|
+
# saveTemplate(input: $input) {
|
|
400
|
+
# id
|
|
401
|
+
# name
|
|
402
|
+
# imageName
|
|
403
|
+
# containerDiskInGb
|
|
404
|
+
# ports
|
|
405
|
+
# dockerArgs
|
|
406
|
+
# startSsh
|
|
407
|
+
# startJupyter
|
|
408
|
+
# }
|
|
409
|
+
# }
|
|
410
|
+
#
|
|
411
|
+
# Example variables:
|
|
412
|
+
#
|
|
413
|
+
# {
|
|
414
|
+
# "input": {
|
|
415
|
+
# "containerDiskInGb": 50,
|
|
416
|
+
# "dockerArgs": "bash -c \"apt-get update && apt-get install -y openssh-server && ...\"",
|
|
417
|
+
# "env": [],
|
|
418
|
+
# "isPublic": false,
|
|
419
|
+
# "isServerless": false,
|
|
420
|
+
# "name": "template-name",
|
|
421
|
+
# "ports": "22/tcp",
|
|
422
|
+
# "portsConfig": [{"name": "SSH", "port": "22"}],
|
|
423
|
+
# "readme": "",
|
|
424
|
+
# "volumeInGb": 0,
|
|
425
|
+
# "volumeMountPath": "",
|
|
426
|
+
# "config": {},
|
|
427
|
+
# "category": "AMD",
|
|
428
|
+
# "imageName": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1"
|
|
429
|
+
# }
|
|
430
|
+
# }
|
|
431
|
+
#
|
|
432
|
+
# Note: Template creation is not currently implemented in this module.
|
|
433
|
+
# If needed, implement a save_template() function following the pattern of
|
|
434
|
+
# provision_pod() and terminate_pod() above.
|
|
435
|
+
|
|
436
|
+
|
|
381
437
|
# =============================================================================
|
|
382
438
|
# Context Manager
|
|
383
439
|
# =============================================================================
|
|
@@ -482,20 +538,103 @@ async def cleanup_target(target_name: str) -> bool:
|
|
|
482
538
|
return success
|
|
483
539
|
|
|
484
540
|
|
|
541
|
+
async def sync_pods_from_api() -> list[PodState]:
    """Query the RunPod API for running pods and reconcile the local state file.

    Discovers pods that exist on the account but are missing from the local
    state file (e.g. created manually or from another machine). Pods whose
    name follows the ``wafer-{target_name}-{timestamp}`` convention are
    written back to the state file under ``target_name``.

    Returns:
        A ``PodState`` for every pod whose desired status is ``running`` and
        that exposes a public SSH endpoint (private port 22). Pods without a
        wafer-style name use the raw pod name as ``target_name``. Returns an
        empty list if the API query fails.
    """
    query = """
    query {
        myself {
            pods {
                id
                name
                desiredStatus
                runtime {
                    ports {
                        ip
                        isIpPublic
                        privatePort
                        publicPort
                    }
                }
            }
        }
    }
    """

    try:
        data = await _graphql_request_async(query)
    except Exception as e:
        # Best-effort: an API failure yields an empty listing rather than an error.
        logger.warning(f"Failed to query pods from API: {e}")
        return []

    # GraphQL fields can come back as JSON null; `.get(key, default)` only
    # covers *missing* keys, so guard every nullable level with `or`.
    pods = ((data or {}).get("myself") or {}).get("pods") or []
    running_pods: list[PodState] = []

    for pod in pods:
        status = (pod.get("desiredStatus") or "").lower()
        if status != "running":
            continue

        pod_id = pod["id"]
        pod_name = pod.get("name") or ""

        # Extract SSH info; pods that are still starting may have no runtime.
        runtime = pod.get("runtime")
        if not runtime:
            continue

        public_ip = None
        ssh_port = None
        for port in runtime.get("ports") or []:
            if port.get("privatePort") == 22 and port.get("isIpPublic"):
                public_ip = port.get("ip")
                ssh_port = port.get("publicPort")
                break

        if not public_ip or not ssh_port:
            continue

        # Extract target name from pod name (wafer-{target_name}-{timestamp}).
        # Target names may contain hyphens: wafer-runpod-mi300x-1234567.
        target_name = None
        if pod_name.startswith("wafer-"):
            parts = pod_name.split("-")
            if len(parts) >= 3:
                target_name = "-".join(parts[1:-1])

        pod_state = PodState(
            pod_id=pod_id,
            target_name=target_name or pod_name,
            public_ip=public_ip,
            ssh_port=ssh_port,
            ssh_username="root",
            created_at=datetime.now(timezone.utc).isoformat(),
        )
        running_pods.append(pod_state)

        # Only wafer-created pods (those with a parsed target name) are persisted.
        if target_name:
            existing = get_pod_state(target_name)
            if not existing or existing.pod_id != pod_id:
                logger.info(f"Syncing pod {pod_id} to state for target {target_name}")
                _add_pod_to_state(target_name, pod_id, public_ip, ssh_port, "root")

    return running_pods
|
|
629
|
+
|
|
630
|
+
|
|
485
631
|
async def list_running_pods() -> list[PodState]:
|
|
486
|
-
"""List all pods
|
|
487
|
-
state = _load_state()
|
|
488
|
-
running = []
|
|
632
|
+
"""List all running pods by querying the RunPod API.
|
|
489
633
|
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
# Clean up stale entry
|
|
495
|
-
logger.info(f"Removing stale state for {name} (pod {pod_state.pod_id})")
|
|
496
|
-
_remove_pod_from_state(name)
|
|
497
|
-
|
|
498
|
-
return running
|
|
634
|
+
Syncs state file with API to discover pods not in local state.
|
|
635
|
+
Returns list of running pods with SSH info.
|
|
636
|
+
"""
|
|
637
|
+
return await sync_pods_from_api()
|
|
499
638
|
|
|
500
639
|
|
|
501
640
|
async def cleanup_all_pods() -> int:
|
wafer_core/tools/__init__.py
CHANGED
|
@@ -49,6 +49,10 @@ from wafer_core.tools.rocprof_systems_tools import (
|
|
|
49
49
|
exec_rocprof_systems_query,
|
|
50
50
|
exec_rocprof_systems_sample,
|
|
51
51
|
)
|
|
52
|
+
from wafer_core.tools.skill_tool import (
|
|
53
|
+
SKILL_TOOL,
|
|
54
|
+
exec_skill,
|
|
55
|
+
)
|
|
52
56
|
from wafer_core.tools.tracelens_tools import (
|
|
53
57
|
TRACELENS_COLLECTIVE_TOOL,
|
|
54
58
|
TRACELENS_COMPARE_TOOL,
|
|
@@ -68,6 +72,10 @@ from wafer_core.tools.write_kernel_tool import (
|
|
|
68
72
|
KernelSubmission,
|
|
69
73
|
exec_write_kernel,
|
|
70
74
|
)
|
|
75
|
+
from wafer_core.tools.search_docs_tool import (
|
|
76
|
+
SEARCH_DOCS_TOOL,
|
|
77
|
+
exec_search_docs,
|
|
78
|
+
)
|
|
71
79
|
|
|
72
80
|
__all__ = [
|
|
73
81
|
# File tools
|
|
@@ -88,6 +96,9 @@ __all__ = [
|
|
|
88
96
|
"BashPermissionResult",
|
|
89
97
|
"check_bash_permissions",
|
|
90
98
|
"exec_bash",
|
|
99
|
+
# Skill tool
|
|
100
|
+
"SKILL_TOOL",
|
|
101
|
+
"exec_skill",
|
|
91
102
|
# Wafer tool
|
|
92
103
|
"WAFER_TOOL",
|
|
93
104
|
"WAFER_SUBCOMMANDS",
|
|
@@ -126,4 +137,7 @@ __all__ = [
|
|
|
126
137
|
"exec_tracelens_report",
|
|
127
138
|
"exec_tracelens_compare",
|
|
128
139
|
"exec_tracelens_collective",
|
|
140
|
+
# Search docs tool
|
|
141
|
+
"SEARCH_DOCS_TOOL",
|
|
142
|
+
"exec_search_docs",
|
|
129
143
|
]
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Grep tool using ripgrep
|
|
1
|
+
"""Grep tool using ripgrep (with fallback to standard grep)."""
|
|
2
2
|
|
|
3
3
|
from pathlib import Path
|
|
4
4
|
|
|
@@ -15,7 +15,7 @@ GREP_TOOL = Tool(
|
|
|
15
15
|
function=ToolFunction(
|
|
16
16
|
name="grep",
|
|
17
17
|
description=(
|
|
18
|
-
"Search for a pattern in files
|
|
18
|
+
"Search for a pattern in files. "
|
|
19
19
|
"Returns matching lines with file paths and line numbers. "
|
|
20
20
|
"Supports regex patterns by default."
|
|
21
21
|
),
|
|
@@ -54,7 +54,7 @@ GREP_TOOL = Tool(
|
|
|
54
54
|
|
|
55
55
|
|
|
56
56
|
async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
|
|
57
|
-
"""Execute grep using ripgrep."""
|
|
57
|
+
"""Execute grep using ripgrep (preferred) or standard grep (fallback)."""
|
|
58
58
|
import shutil
|
|
59
59
|
import subprocess
|
|
60
60
|
|
|
@@ -74,35 +74,55 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
|
|
|
74
74
|
error="'pattern' is required",
|
|
75
75
|
)
|
|
76
76
|
|
|
77
|
-
#
|
|
77
|
+
# Try ripgrep first, fall back to standard grep
|
|
78
78
|
rg_path = shutil.which("rg")
|
|
79
|
-
|
|
79
|
+
grep_path = shutil.which("grep")
|
|
80
|
+
|
|
81
|
+
if rg_path:
|
|
82
|
+
# Use ripgrep (faster, better defaults)
|
|
83
|
+
cmd = [rg_path, "--line-number", "--no-heading", "--color=never"]
|
|
84
|
+
|
|
85
|
+
if case_insensitive:
|
|
86
|
+
cmd.append("--ignore-case")
|
|
87
|
+
|
|
88
|
+
if context_lines:
|
|
89
|
+
cmd.extend(["--context", str(context_lines)])
|
|
90
|
+
|
|
91
|
+
if glob_pattern:
|
|
92
|
+
cmd.extend(["--glob", glob_pattern])
|
|
93
|
+
|
|
94
|
+
# Limit results
|
|
95
|
+
cmd.extend(["--max-count", str(max_results)])
|
|
96
|
+
|
|
97
|
+
cmd.append(pattern)
|
|
98
|
+
cmd.append(search_path)
|
|
99
|
+
use_ripgrep = True
|
|
100
|
+
elif grep_path:
|
|
101
|
+
# Fallback to standard grep
|
|
102
|
+
cmd = [grep_path, "-r", "-n", "--color=never"]
|
|
103
|
+
|
|
104
|
+
if case_insensitive:
|
|
105
|
+
cmd.append("-i")
|
|
106
|
+
|
|
107
|
+
if context_lines:
|
|
108
|
+
cmd.extend(["-C", str(context_lines)])
|
|
109
|
+
|
|
110
|
+
if glob_pattern:
|
|
111
|
+
# Standard grep uses --include for glob patterns
|
|
112
|
+
cmd.extend(["--include", glob_pattern])
|
|
113
|
+
|
|
114
|
+
cmd.append(pattern)
|
|
115
|
+
cmd.append(search_path)
|
|
116
|
+
use_ripgrep = False
|
|
117
|
+
else:
|
|
80
118
|
return ToolResult(
|
|
81
119
|
tool_call_id=tool_call.id,
|
|
82
120
|
is_error=True,
|
|
83
121
|
content="",
|
|
84
|
-
error="ripgrep (rg)
|
|
122
|
+
error="Neither ripgrep (rg) nor grep found. Please install one.",
|
|
85
123
|
)
|
|
86
124
|
|
|
87
|
-
#
|
|
88
|
-
cmd = [rg_path, "--line-number", "--no-heading", "--color=never"]
|
|
89
|
-
|
|
90
|
-
if case_insensitive:
|
|
91
|
-
cmd.append("--ignore-case")
|
|
92
|
-
|
|
93
|
-
if context_lines:
|
|
94
|
-
cmd.extend(["--context", str(context_lines)])
|
|
95
|
-
|
|
96
|
-
if glob_pattern:
|
|
97
|
-
cmd.extend(["--glob", glob_pattern])
|
|
98
|
-
|
|
99
|
-
# Limit results
|
|
100
|
-
cmd.extend(["--max-count", str(max_results)])
|
|
101
|
-
|
|
102
|
-
cmd.append(pattern)
|
|
103
|
-
cmd.append(search_path)
|
|
104
|
-
|
|
105
|
-
# Run ripgrep
|
|
125
|
+
# Run the search
|
|
106
126
|
try:
|
|
107
127
|
result = subprocess.run(
|
|
108
128
|
cmd,
|
|
@@ -126,13 +146,14 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
|
|
|
126
146
|
error=f"Search failed: {e}",
|
|
127
147
|
)
|
|
128
148
|
|
|
129
|
-
# ripgrep
|
|
149
|
+
# Both ripgrep and grep return exit code 1 for no matches (not an error)
|
|
130
150
|
if result.returncode not in (0, 1):
|
|
151
|
+
tool_name = "ripgrep" if use_ripgrep else "grep"
|
|
131
152
|
return ToolResult(
|
|
132
153
|
tool_call_id=tool_call.id,
|
|
133
154
|
is_error=True,
|
|
134
155
|
content="",
|
|
135
|
-
error=result.stderr or f"
|
|
156
|
+
error=result.stderr or f"{tool_name} exited with code {result.returncode}",
|
|
136
157
|
)
|
|
137
158
|
|
|
138
159
|
output = result.stdout.strip()
|
|
@@ -143,8 +164,14 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
|
|
|
143
164
|
content=f"No matches found for pattern: {pattern}",
|
|
144
165
|
)
|
|
145
166
|
|
|
146
|
-
# Count matches
|
|
147
|
-
|
|
167
|
+
# Count matches and limit output for standard grep
|
|
168
|
+
lines = output.split("\n")
|
|
169
|
+
if not use_ripgrep and len(lines) > max_results:
|
|
170
|
+
lines = lines[:max_results]
|
|
171
|
+
output = "\n".join(lines)
|
|
172
|
+
output += f"\n... (truncated to {max_results} results)"
|
|
173
|
+
|
|
174
|
+
match_count = min(len(lines), max_results)
|
|
148
175
|
header = f"Found {match_count} matches:\n\n"
|
|
149
176
|
|
|
150
177
|
return ToolResult(
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
"""Search documentation tool for GPU programming corpora.
|
|
2
|
+
|
|
3
|
+
Provides keyword search over documentation for CUTLASS, CuTeDSL, CUDA, HIP, and AMD GPU programming.
|
|
4
|
+
|
|
5
|
+
Corpora are downloaded via `wafer corpus download <name>` and stored in ~/.cache/wafer/corpora/.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
from wafer_core.rollouts.dtypes import Tool, ToolCall, ToolFunction, ToolFunctionParameter, ToolResult
|
|
12
|
+
|
|
13
|
+
# Root of the on-disk corpus cache populated by `wafer corpus download`.
CACHE_DIR = Path.home().joinpath(".cache", "wafer", "corpora")

# Corpus names accepted by search_docs; each matches a name understood by
# `wafer corpus download`.
AVAILABLE_CORPORA = [
    "cutlass",
    "cutedsl",
    "cuda",
    "hip",
    "amd",
]
|
|
18
|
+
|
|
19
|
+
# Model-facing description of the search_docs tool.
_SEARCH_DOCS_DESCRIPTION = """Search GPU programming documentation for relevant information.

Use this tool to find documentation about:
- CUTLASS C++ (cute:: namespace, gemm tutorials, tensor cores, TMA, Blackwell)
- CuTeDSL Python API (@cute.kernel, @cute.jit, cute.arch functions)
- CUDA programming concepts
- GPU kernel optimization techniques
- Code examples and patterns

Available corpora:
- 'cutlass' - NVIDIA CUTLASS C++ docs + GitHub examples (gemm, hopper, blackwell)
- 'cutedsl' - CuTeDSL Python documentation
- 'cuda' - General CUDA programming docs
- 'hip' - AMD HIP programming docs
- 'amd' - AMD GPU kernel development (rocWMMA, CK, etc.)

Note: Corpora must be downloaded first with `wafer corpus download <name>`.
Returns relevant documentation snippets with file paths."""

# JSON-schema properties for the tool arguments; only "query" is required.
_SEARCH_DOCS_PROPERTIES = {
    "query": {
        "type": "string",
        "description": "Search query - describe what you're looking for",
    },
    "corpus": {
        "type": "string",
        "description": "Which docs to search: 'cutlass', 'cutedsl', 'cuda', 'hip', 'amd' (default: cutlass)",
    },
    "max_results": {
        "type": "integer",
        "description": "Maximum number of results to return (default: 5)",
    },
}

SEARCH_DOCS_TOOL = Tool(
    type="function",
    function=ToolFunction(
        name="search_docs",
        description=_SEARCH_DOCS_DESCRIPTION,
        parameters=ToolFunctionParameter(
            type="object",
            properties=_SEARCH_DOCS_PROPERTIES,
        ),
        required=["query"],
    ),
)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _get_corpus_path(corpus_name: str) -> Path | None:
    """Return the cached directory for ``corpus_name`` if it can be searched.

    Corpora live at ``~/.cache/wafer/corpora/<corpus_name>/``. Returns None
    both when the name is not listed in AVAILABLE_CORPORA and when the
    directory does not exist on disk.
    """
    # Membership is checked before building the path so non-string names
    # never reach the `/` operator.
    if corpus_name in AVAILABLE_CORPORA:
        candidate = CACHE_DIR / corpus_name
        if candidate.exists():
            return candidate
    return None
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _search_files(corpus_path: Path, query: str, max_results: int = 5) -> list[dict]:
|
|
79
|
+
"""Simple keyword search through documentation files."""
|
|
80
|
+
results = []
|
|
81
|
+
query_terms = query.lower().split()
|
|
82
|
+
|
|
83
|
+
# Search .md, .py, .cu, .hpp, and .h files (for CUTLASS examples)
|
|
84
|
+
for pattern in ["**/*.md", "**/*.py", "**/*.cu", "**/*.hpp", "**/*.h", "**/*.cuh"]:
|
|
85
|
+
for file_path in corpus_path.glob(pattern):
|
|
86
|
+
if file_path.is_file():
|
|
87
|
+
try:
|
|
88
|
+
content = file_path.read_text(encoding="utf-8", errors="ignore")
|
|
89
|
+
content_lower = content.lower()
|
|
90
|
+
|
|
91
|
+
# Score based on term matches
|
|
92
|
+
score = sum(content_lower.count(term) for term in query_terms)
|
|
93
|
+
|
|
94
|
+
if score > 0:
|
|
95
|
+
# Extract relevant snippets
|
|
96
|
+
snippets = _extract_snippets(content, query_terms)
|
|
97
|
+
results.append({
|
|
98
|
+
"file": str(file_path), # Return absolute path so read tool can access it
|
|
99
|
+
"score": score,
|
|
100
|
+
"snippets": snippets[:3], # Top 3 snippets
|
|
101
|
+
})
|
|
102
|
+
except Exception:
|
|
103
|
+
continue
|
|
104
|
+
|
|
105
|
+
# Sort by score and return top results
|
|
106
|
+
results.sort(key=lambda x: x["score"], reverse=True)
|
|
107
|
+
return results[:max_results]
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _extract_snippets(content: str, terms: list[str], context_lines: int = 5) -> list[str]:
|
|
111
|
+
"""Extract snippets containing search terms."""
|
|
112
|
+
snippets = []
|
|
113
|
+
lines = content.split("\n")
|
|
114
|
+
|
|
115
|
+
for i, line in enumerate(lines):
|
|
116
|
+
line_lower = line.lower()
|
|
117
|
+
if any(term in line_lower for term in terms):
|
|
118
|
+
# Get context around the match
|
|
119
|
+
start = max(0, i - context_lines)
|
|
120
|
+
end = min(len(lines), i + context_lines + 1)
|
|
121
|
+
snippet = "\n".join(lines[start:end])
|
|
122
|
+
|
|
123
|
+
# Skip very short snippets
|
|
124
|
+
if len(snippet.strip()) > 50:
|
|
125
|
+
snippets.append(snippet)
|
|
126
|
+
|
|
127
|
+
return snippets
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
async def exec_search_docs(
    tool_call: ToolCall,
    corpus_override: str | None = None,
) -> ToolResult:
    """Execute the search_docs tool.

    Args:
        tool_call: The tool call. ``args`` may contain ``query`` (required),
            ``corpus`` (default ``"cutlass"``) and ``max_results`` (default 5).
        corpus_override: Directory to search instead of the cached corpus
            (for testing). When set, the corpus-name check is skipped.

    Returns:
        A ToolResult listing matching files with snippets. On failure it is an
        error result (``is_error=True``): missing query, unknown corpus name,
        or corpus not downloaded.
    """
    import asyncio

    query = tool_call.args.get("query") or ""
    # A null or empty corpus falls back to the documented default.
    corpus_name = tool_call.args.get("corpus") or "cutlass"

    # Tool arguments are not schema-validated here; normalise max_results to
    # an int >= 1 so the result slice behaves predictably.
    try:
        max_results = int(tool_call.args.get("max_results", 5))
    except (TypeError, ValueError):
        max_results = 5
    max_results = max(1, max_results)

    if not isinstance(query, str) or not query.strip():
        return ToolResult(
            tool_call_id=tool_call.id,
            is_error=True,
            content="",
            error="query parameter is required",
        )

    # Find corpus path
    if corpus_override:
        corpus_path = Path(corpus_override)
    else:
        # Validate the name first: _get_corpus_path returns None both for
        # unknown names and for missing directories. Without this check, a
        # known-but-undownloaded corpus would get the "Unknown corpus" error
        # instead of the actionable download hint below.
        if corpus_name not in AVAILABLE_CORPORA:
            return ToolResult(
                tool_call_id=tool_call.id,
                is_error=True,
                content="",
                error=f"Unknown corpus: {corpus_name}. Available: {AVAILABLE_CORPORA}",
            )
        corpus_path = _get_corpus_path(corpus_name)

    if corpus_path is None or not corpus_path.exists():
        return ToolResult(
            tool_call_id=tool_call.id,
            is_error=True,
            content="",
            error=f"Corpus '{corpus_name}' not downloaded. Run: wafer corpus download {corpus_name}",
        )

    # _search_files walks and reads the whole corpus synchronously; run it in
    # a worker thread so the event loop is not blocked during the scan.
    results = await asyncio.to_thread(_search_files, corpus_path, query, max_results)

    if not results:
        return ToolResult(
            tool_call_id=tool_call.id,
            content=f"No results found for query: {query}",
            error=None,
        )

    # Format output
    output_parts = [f"Found {len(results)} results for: {query}\n"]

    for i, result in enumerate(results, 1):
        output_parts.append(f"\n{'='*60}")
        output_parts.append(f"[{i}] {result['file']} (score: {result['score']})")
        output_parts.append("=" * 60)

        for snippet in result["snippets"]:
            output_parts.append(snippet)
            output_parts.append("-" * 40)

    return ToolResult(
        tool_call_id=tool_call.id,
        content="\n".join(output_parts),
        error=None,
    )
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
"""Skill tool.
|
|
2
|
+
|
|
3
|
+
Loads skill content on demand from ~/.wafer/skills/ or bundled locations.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from wafer_core.rollouts.dtypes import (
|
|
7
|
+
Tool,
|
|
8
|
+
ToolCall,
|
|
9
|
+
ToolFunction,
|
|
10
|
+
ToolFunctionParameter,
|
|
11
|
+
ToolResult,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# ── Tool Definition ──────────────────────────────────────────────────────────
|
|
15
|
+
|
|
16
|
+
# JSON-schema for the single required ``name`` argument of the skill tool.
_SKILL_NAME_PROPERTY = {
    "type": "string",
    "description": "Name of the skill to load (e.g., 'wafer-guide')",
}

SKILL_TOOL = Tool(
    type="function",
    function=ToolFunction(
        name="skill",
        description=(
            "Load a skill's full instructions. "
            "Skills provide domain-specific knowledge and workflows. "
            "Use this when you need detailed guidance for a task "
            "mentioned in your available skills."
        ),
        parameters=ToolFunctionParameter(
            type="object",
            properties={"name": _SKILL_NAME_PROPERTY},
        ),
        required=["name"],
    ),
)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ── Pure Function Executor ───────────────────────────────────────────────────
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
async def exec_skill(tool_call: ToolCall) -> ToolResult:
    """Load a skill's full instructions.

    Args:
        tool_call: The tool call; ``args["name"]`` names the skill to load.

    Returns:
        A ToolResult whose content is the skill text prefixed with a
        ``# Skill: <name>`` header. If the name is missing, or no skill by
        that name is found, returns an error result instead; the not-found
        error lists the available skills.
    """
    from wafer_core.rollouts.skills import discover_skills, load_skill

    # Validate instead of indexing so a malformed call yields an error result
    # rather than an uncaught KeyError (mirrors the grep tool's check).
    skill_name = tool_call.args.get("name")
    if not isinstance(skill_name, str) or not skill_name:
        return ToolResult(
            tool_call_id=tool_call.id,
            is_error=True,
            content="",
            error="'name' is required",
        )

    skill = load_skill(skill_name)

    if skill is None:
        available = discover_skills()
        available_names = [s.name for s in available]
        return ToolResult(
            tool_call_id=tool_call.id,
            is_error=True,
            content="",
            error=f"Skill not found: {skill_name}. Available skills: {', '.join(available_names) or 'none'}",
        )

    header = f"# Skill: {skill.name}\n\n"
    return ToolResult(
        tool_call_id=tool_call.id,
        is_error=False,
        content=header + skill.content,
    )
|
wafer_core/utils/backend.py
CHANGED
|
@@ -33,6 +33,9 @@ def get_auth_token() -> str | None:
|
|
|
33
33
|
Note:
|
|
34
34
|
In local dev mode (localhost), no token is required.
|
|
35
35
|
The API will use LOCAL_DEV_MODE to bypass auth.
|
|
36
|
+
|
|
37
|
+
Callers (like wevin-extension) should pass WAFER_AUTH_TOKEN
|
|
38
|
+
as an environment variable when spawning Python processes.
|
|
36
39
|
"""
|
|
37
40
|
return os.environ.get("WAFER_AUTH_TOKEN")
|
|
38
41
|
|