wafer-core 0.1.22-py3-none-any.whl → 0.1.24-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/environments/coding.py +4 -0
- wafer_core/targets/runpod.py +154 -15
- wafer_core/tools/__init__.py +7 -0
- wafer_core/tools/file_tools/grep_tool.py +56 -29
- wafer_core/tools/search_docs_tool.py +196 -0
- wafer_core/utils/kernel_utils/static_checker.py +175 -3
- wafer_core/utils/kernel_utils/targets/config.py +17 -3
- wafer_core/utils/modal_execution/modal_app.py +14 -27
- {wafer_core-0.1.22.dist-info → wafer_core-0.1.24.dist-info}/METADATA +1 -1
- {wafer_core-0.1.22.dist-info → wafer_core-0.1.24.dist-info}/RECORD +11 -10
- {wafer_core-0.1.22.dist-info → wafer_core-0.1.24.dist-info}/WHEEL +0 -0
wafer_core/environments/coding.py CHANGED

@@ -34,6 +34,7 @@ from wafer_core.tools import (
     GLOB_TOOL,
     GREP_TOOL,
     READ_TOOL,
+    SEARCH_DOCS_TOOL,
     SKILL_TOOL,
     WRITE_TOOL,
     ApprovalCallback,
@@ -42,6 +43,7 @@ from wafer_core.tools import (
     exec_glob,
     exec_grep,
     exec_read,
+    exec_search_docs,
     exec_skill,
     exec_write,
 )
@@ -63,6 +65,7 @@ ALL_TOOLS = {
     "glob": GLOB_TOOL,
     "grep": GREP_TOOL,
     "bash": BASH_TOOL,
+    "search_docs": SEARCH_DOCS_TOOL,
     "skill": SKILL_TOOL,
     # TODO(wafer-tool): "wafer": WAFER_TOOL,
 }
@@ -211,6 +214,7 @@ class CodingEnvironment:
                 self.bash_approval_callback,
                 self._sandbox_policy,
             ),
+            "search_docs": lambda tc: exec_search_docs(tc),
             "skill": lambda tc: exec_skill(tc),
             # TODO(wafer-tool): "wafer": lambda tc: exec_wafer(
             #     tc, self.working_dir, self.enabled_tools, self.allow_spawn, cancel_scope
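Note: the executor table above follows a simple name-to-callable dispatch pattern. A minimal self-contained sketch of that pattern (names are illustrative, not the actual wafer_core API):

import asyncio

async def exec_search_docs(tc):  # stand-in for the real async executor
    return f"searched docs for {tc['args']['query']}"

# Executors are keyed by tool name; each value adapts an exec_* function
# to a uniform (tool_call) -> awaitable interface, as in the diff above.
EXECUTORS = {"search_docs": lambda tc: exec_search_docs(tc)}

async def dispatch(tc):
    return await EXECUTORS[tc["name"]](tc)

print(asyncio.run(dispatch({"name": "search_docs", "args": {"query": "tensor cores"}})))
# -> searched docs for tensor cores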
wafer_core/targets/runpod.py CHANGED

@@ -14,7 +14,6 @@ from __future__ import annotations
 
 import json
 import logging
-import os
 import time
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
@@ -250,10 +249,17 @@ async def provision_pod(target: RunPodTarget) -> tuple[str, str, int, str]:
         "ports": "22/tcp",
         "startSsh": True,
         "startJupyter": False,
-        "imageName": target.image,
         "env": [],
     }
 
+    if target.template_id:
+        # Template defines image, dockerArgs (sshd setup), and ports.
+        # Required for non-RunPod images (e.g. rocm/pytorch) that don't
+        # have RunPod's built-in SSH handler.
+        pod_input["templateId"] = target.template_id
+    else:
+        pod_input["imageName"] = target.image
+
     variables = {"input": pod_input}
 
     logger.info(f"Provisioning RunPod pod: {pod_name}")
@@ -334,7 +340,8 @@ async def _wait_for_ssh(pod_id: str, timeout_seconds: int) -> tuple[str, int, str]:
         # Check for SSH port
         runtime = pod.get("runtime")
         if runtime and status == "running":
-            for port in runtime.get("ports", []):
+            # ports can be null in JSON response, so use 'or []' instead of default
+            for port in runtime.get("ports") or []:
                 if (
                     port.get("privatePort") == 22
                     and port.get("isIpPublic")
@@ -378,6 +385,55 @@ async def terminate_pod(pod_id: str) -> bool:
     return False
 
 
+# =============================================================================
+# Template Management (not yet implemented)
+# =============================================================================
+#
+# The saveTemplate mutation allows creating reusable pod templates with custom
+# configurations. Templates can specify docker images, environment setup,
+# container disk size, and other pod settings.
+#
+# Example mutation:
+#
+# mutation saveTemplate($input: SaveTemplateInput) {
+#   saveTemplate(input: $input) {
+#     id
+#     name
+#     imageName
+#     containerDiskInGb
+#     ports
+#     dockerArgs
+#     startSsh
+#     startJupyter
+#   }
+# }
+#
+# Example variables:
+#
+# {
+#   "input": {
+#     "containerDiskInGb": 50,
+#     "dockerArgs": "bash -c \"apt-get update && apt-get install -y openssh-server && ...\"",
+#     "env": [],
+#     "isPublic": false,
+#     "isServerless": false,
+#     "name": "template-name",
+#     "ports": "22/tcp",
+#     "portsConfig": [{"name": "SSH", "port": "22"}],
+#     "readme": "",
+#     "volumeInGb": 0,
+#     "volumeMountPath": "",
+#     "config": {},
+#     "category": "AMD",
+#     "imageName": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1"
+#   }
+# }
+#
+# Note: Template creation is not currently implemented in this module.
+# If needed, implement a save_template() function following the pattern of
+# provision_pod() and terminate_pod() above.
+
+
 # =============================================================================
 # Context Manager
 # =============================================================================
@@ -482,20 +538,103 @@ async def cleanup_target(target_name: str) -> bool:
     return success
 
 
+async def sync_pods_from_api() -> list[PodState]:
+    """Query RunPod API for all running pods and update local state.
+
+    This discovers pods that exist on the account but aren't in our state file
+    (e.g., created manually or by another machine). Updates the state file with
+    any wafer-created pods found.
+
+    Returns list of all running pods with SSH info.
+    """
+    query = """
+    query {
+      myself {
+        pods {
+          id
+          name
+          desiredStatus
+          runtime {
+            ports {
+              ip
+              isIpPublic
+              privatePort
+              publicPort
+            }
+          }
+        }
+      }
+    }
+    """
+
+    try:
+        data = await _graphql_request_async(query)
+    except Exception as e:
+        logger.warning(f"Failed to query pods from API: {e}")
+        return []
+
+    pods = data.get("myself", {}).get("pods", [])
+    running_pods = []
+
+    for pod in pods:
+        status = pod.get("desiredStatus", "").lower()
+        if status != "running":
+            continue
+
+        pod_id = pod["id"]
+        pod_name = pod.get("name", "")
+
+        # Extract SSH info
+        runtime = pod.get("runtime")
+        if not runtime:
+            continue
+
+        public_ip = None
+        ssh_port = None
+        for port in runtime.get("ports") or []:
+            if port.get("privatePort") == 22 and port.get("isIpPublic"):
+                public_ip = port.get("ip")
+                ssh_port = port.get("publicPort")
+                break
+
+        if not public_ip or not ssh_port:
+            continue
+
+        # Extract target name from pod name (wafer-{target_name}-{timestamp})
+        target_name = None
+        if pod_name.startswith("wafer-"):
+            parts = pod_name.split("-")
+            if len(parts) >= 3:
+                # Handle target names with hyphens: wafer-runpod-mi300x-1234567
+                target_name = "-".join(parts[1:-1])
+
+        pod_state = PodState(
+            pod_id=pod_id,
+            target_name=target_name or pod_name,
+            public_ip=public_ip,
+            ssh_port=ssh_port,
+            ssh_username="root",
+            created_at=datetime.now(timezone.utc).isoformat(),
+        )
+        running_pods.append(pod_state)
+
+        # Update state file if this is a wafer-created pod
+        if target_name:
+            existing = get_pod_state(target_name)
+            if not existing or existing.pod_id != pod_id:
+                logger.info(f"Syncing pod {pod_id} to state for target {target_name}")
+                _add_pod_to_state(target_name, pod_id, public_ip, ssh_port, "root")
+
+    return running_pods
+
+
 async def list_running_pods() -> list[PodState]:
-    """List all pods
-    state = _load_state()
-    running = []
+    """List all running pods by querying the RunPod API.
 
-
-
-
-        # Clean up stale entry
-        logger.info(f"Removing stale state for {name} (pod {pod_state.pod_id})")
-        _remove_pod_from_state(name)
-
-    return running
+    Syncs state file with API to discover pods not in local state.
+    Returns list of running pods with SSH info.
+    """
+    return await sync_pods_from_api()
 
 
 async def cleanup_all_pods() -> int:
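Note: a worked example of the pod-name parsing used by sync_pods_from_api(). Pod names follow wafer-{target_name}-{timestamp}, and target names may themselves contain hyphens, so everything between the first and last hyphen-separated segment is the target name:

for pod_name in ["wafer-runpod-mi300x-1234567", "wafer-dev-42"]:
    parts = pod_name.split("-")
    if pod_name.startswith("wafer-") and len(parts) >= 3:
        target_name = "-".join(parts[1:-1])  # drop "wafer" prefix and timestamp suffix
    else:
        target_name = None
    print(pod_name, "->", target_name)
# wafer-runpod-mi300x-1234567 -> runpod-mi300x
# wafer-dev-42 -> dev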
wafer_core/tools/__init__.py CHANGED

@@ -72,6 +72,10 @@ from wafer_core.tools.write_kernel_tool import (
     KernelSubmission,
     exec_write_kernel,
 )
+from wafer_core.tools.search_docs_tool import (
+    SEARCH_DOCS_TOOL,
+    exec_search_docs,
+)
 
 __all__ = [
     # File tools
@@ -133,4 +137,7 @@ __all__ = [
     "exec_tracelens_report",
     "exec_tracelens_compare",
     "exec_tracelens_collective",
+    # Search docs tool
+    "SEARCH_DOCS_TOOL",
+    "exec_search_docs",
 ]
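Note: after this change the new tool is re-exported at the package root, so downstream code can import it directly (assuming the wheel is installed):

from wafer_core.tools import SEARCH_DOCS_TOOL, exec_search_docs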
wafer_core/tools/file_tools/grep_tool.py CHANGED

@@ -1,4 +1,4 @@
-"""Grep tool using ripgrep."""
+"""Grep tool using ripgrep (with fallback to standard grep)."""
 
 from pathlib import Path
 
@@ -15,7 +15,7 @@ GREP_TOOL = Tool(
     function=ToolFunction(
         name="grep",
         description=(
-            "Search for a pattern in files
+            "Search for a pattern in files. "
             "Returns matching lines with file paths and line numbers. "
             "Supports regex patterns by default."
         ),
@@ -54,7 +54,7 @@ GREP_TOOL = Tool(
 
 
 async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
-    """Execute grep using ripgrep."""
+    """Execute grep using ripgrep (preferred) or standard grep (fallback)."""
     import shutil
     import subprocess
 
@@ -74,35 +74,55 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
             error="'pattern' is required",
         )
 
-    #
+    # Try ripgrep first, fall back to standard grep
    rg_path = shutil.which("rg")
-
+    grep_path = shutil.which("grep")
+
+    if rg_path:
+        # Use ripgrep (faster, better defaults)
+        cmd = [rg_path, "--line-number", "--no-heading", "--color=never"]
+
+        if case_insensitive:
+            cmd.append("--ignore-case")
+
+        if context_lines:
+            cmd.extend(["--context", str(context_lines)])
+
+        if glob_pattern:
+            cmd.extend(["--glob", glob_pattern])
+
+        # Limit results
+        cmd.extend(["--max-count", str(max_results)])
+
+        cmd.append(pattern)
+        cmd.append(search_path)
+        use_ripgrep = True
+    elif grep_path:
+        # Fallback to standard grep
+        cmd = [grep_path, "-r", "-n", "--color=never"]
+
+        if case_insensitive:
+            cmd.append("-i")
+
+        if context_lines:
+            cmd.extend(["-C", str(context_lines)])
+
+        if glob_pattern:
+            # Standard grep uses --include for glob patterns
+            cmd.extend(["--include", glob_pattern])
+
+        cmd.append(pattern)
+        cmd.append(search_path)
+        use_ripgrep = False
+    else:
         return ToolResult(
             tool_call_id=tool_call.id,
             is_error=True,
             content="",
-            error="ripgrep (rg)
+            error="Neither ripgrep (rg) nor grep found. Please install one.",
         )
 
-    #
-    cmd = [rg_path, "--line-number", "--no-heading", "--color=never"]
-
-    if case_insensitive:
-        cmd.append("--ignore-case")
-
-    if context_lines:
-        cmd.extend(["--context", str(context_lines)])
-
-    if glob_pattern:
-        cmd.extend(["--glob", glob_pattern])
-
-    # Limit results
-    cmd.extend(["--max-count", str(max_results)])
-
-    cmd.append(pattern)
-    cmd.append(search_path)
-
-    # Run ripgrep
+    # Run the search
     try:
         result = subprocess.run(
             cmd,
@@ -126,13 +146,14 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
             error=f"Search failed: {e}",
         )
 
-    # ripgrep
+    # Both ripgrep and grep return exit code 1 for no matches (not an error)
     if result.returncode not in (0, 1):
+        tool_name = "ripgrep" if use_ripgrep else "grep"
         return ToolResult(
             tool_call_id=tool_call.id,
             is_error=True,
             content="",
-            error=result.stderr or f"
+            error=result.stderr or f"{tool_name} exited with code {result.returncode}",
         )
 
     output = result.stdout.strip()
@@ -143,8 +164,14 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
             content=f"No matches found for pattern: {pattern}",
         )
 
-    # Count matches
-
+    # Count matches and limit output for standard grep
+    lines = output.split("\n")
+    if not use_ripgrep and len(lines) > max_results:
+        lines = lines[:max_results]
+        output = "\n".join(lines)
+        output += f"\n... (truncated to {max_results} results)"
+
+    match_count = min(len(lines), max_results)
     header = f"Found {match_count} matches:\n\n"
 
     return ToolResult(
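Note: an illustration of the two command shapes the branch above builds for the same request; the flags are taken from the diff (ripgrep options vs. their GNU grep equivalents):

pattern, search_path, max_results = "TODO", ".", 50

rg_cmd = ["rg", "--line-number", "--no-heading", "--color=never",
          "--ignore-case", "--context", "2", "--glob", "*.py",
          "--max-count", str(max_results), pattern, search_path]

grep_cmd = ["grep", "-r", "-n", "--color=never",
            "-i", "-C", "2", "--include", "*.py",
            pattern, search_path]

# Asymmetry the diff handles later: rg's --max-count caps matches per file,
# while plain grep has no equivalent here, so grep output is truncated
# after the fact in the "Count matches" block above.
print(rg_cmd)
print(grep_cmd)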
wafer_core/tools/search_docs_tool.py ADDED

@@ -0,0 +1,196 @@
+"""Search documentation tool for GPU programming corpora.
+
+Provides semantic and keyword search over documentation for CuTeDSL, CUDA, etc.
+
+Corpora are downloaded via `wafer corpus download <name>` and stored in ~/.cache/wafer/corpora/.
+"""
+
+import re
+from pathlib import Path
+
+from wafer_core.rollouts.dtypes import Tool, ToolCall, ToolFunction, ToolFunctionParameter, ToolResult
+
+# Cache directory where wafer corpus download stores files
+CACHE_DIR = Path.home() / ".cache" / "wafer" / "corpora"
+
+# Available corpora (names match wafer corpus download)
+AVAILABLE_CORPORA = ["cutlass", "cutedsl", "cuda", "hip", "amd"]
+
+SEARCH_DOCS_TOOL = Tool(
+    type="function",
+    function=ToolFunction(
+        name="search_docs",
+        description="""Search GPU programming documentation for relevant information.
+
+Use this tool to find documentation about:
+- CUTLASS C++ (cute:: namespace, gemm tutorials, tensor cores, TMA, Blackwell)
+- CuTeDSL Python API (@cute.kernel, @cute.jit, cute.arch functions)
+- CUDA programming concepts
+- GPU kernel optimization techniques
+- Code examples and patterns
+
+Available corpora:
+- 'cutlass' - NVIDIA CUTLASS C++ docs + GitHub examples (gemm, hopper, blackwell)
+- 'cutedsl' - CuTeDSL Python documentation
+- 'cuda' - General CUDA programming docs
+- 'hip' - AMD HIP programming docs
+- 'amd' - AMD GPU kernel development (rocWMMA, CK, etc.)
+
+Note: Corpora must be downloaded first with `wafer corpus download <name>`.
+Returns relevant documentation snippets with file paths.""",
+        parameters=ToolFunctionParameter(
+            type="object",
+            properties={
+                "query": {
+                    "type": "string",
+                    "description": "Search query - describe what you're looking for",
+                },
+                "corpus": {
+                    "type": "string",
+                    "description": "Which docs to search: 'cutlass', 'cutedsl', 'cuda', 'hip', 'amd' (default: cutlass)",
+                },
+                "max_results": {
+                    "type": "integer",
+                    "description": "Maximum number of results to return (default: 5)",
+                },
+            },
+        ),
+        required=["query"],
+    )
+)
+
+
+def _get_corpus_path(corpus_name: str) -> Path | None:
+    """Get the path to a corpus in the cache directory.
+
+    Corpora are stored at ~/.cache/wafer/corpora/<corpus_name>/
+    """
+    if corpus_name not in AVAILABLE_CORPORA:
+        return None
+
+    corpus_path = CACHE_DIR / corpus_name
+    if corpus_path.exists():
+        return corpus_path
+
+    return None
+
+
+def _search_files(corpus_path: Path, query: str, max_results: int = 5) -> list[dict]:
+    """Simple keyword search through documentation files."""
+    results = []
+    query_terms = query.lower().split()
+
+    # Search .md, .py, .cu, .hpp, and .h files (for CUTLASS examples)
+    for pattern in ["**/*.md", "**/*.py", "**/*.cu", "**/*.hpp", "**/*.h", "**/*.cuh"]:
+        for file_path in corpus_path.glob(pattern):
+            if file_path.is_file():
+                try:
+                    content = file_path.read_text(encoding="utf-8", errors="ignore")
+                    content_lower = content.lower()
+
+                    # Score based on term matches
+                    score = sum(content_lower.count(term) for term in query_terms)
+
+                    if score > 0:
+                        # Extract relevant snippets
+                        snippets = _extract_snippets(content, query_terms)
+                        results.append({
+                            "file": str(file_path),  # Return absolute path so read tool can access it
+                            "score": score,
+                            "snippets": snippets[:3],  # Top 3 snippets
+                        })
+                except Exception:
+                    continue
+
+    # Sort by score and return top results
+    results.sort(key=lambda x: x["score"], reverse=True)
+    return results[:max_results]
+
+
+def _extract_snippets(content: str, terms: list[str], context_lines: int = 5) -> list[str]:
+    """Extract snippets containing search terms."""
+    snippets = []
+    lines = content.split("\n")
+
+    for i, line in enumerate(lines):
+        line_lower = line.lower()
+        if any(term in line_lower for term in terms):
+            # Get context around the match
+            start = max(0, i - context_lines)
+            end = min(len(lines), i + context_lines + 1)
+            snippet = "\n".join(lines[start:end])
+
+            # Skip very short snippets
+            if len(snippet.strip()) > 50:
+                snippets.append(snippet)
+
+    return snippets
+
+
+async def exec_search_docs(
+    tool_call: ToolCall,
+    corpus_override: str | None = None,
+) -> ToolResult:
+    """Execute search_docs tool.
+
+    Args:
+        tool_call: The tool call with query and optional corpus
+        corpus_override: Override corpus path (for testing)
+    """
+    query = tool_call.args.get("query", "")
+    corpus_name = tool_call.args.get("corpus", "cutlass")
+    max_results = tool_call.args.get("max_results", 5)
+
+    if not query:
+        return ToolResult(
+            tool_call_id=tool_call.id,
+            content="",
+            error="query parameter is required",
+        )
+
+    # Find corpus path
+    if corpus_override:
+        corpus_path = Path(corpus_override)
+    else:
+        corpus_path = _get_corpus_path(corpus_name)
+        if corpus_path is None:
+            return ToolResult(
+                tool_call_id=tool_call.id,
+                content="",
+                error=f"Unknown corpus: {corpus_name}. Available: {AVAILABLE_CORPORA}",
+            )
+
+    if not corpus_path.exists():
+        return ToolResult(
+            tool_call_id=tool_call.id,
+            content="",
+            error=f"Corpus '{corpus_name}' not downloaded. Run: wafer corpus download {corpus_name}",
+        )
+
+    # Search
+    results = _search_files(corpus_path, query, max_results)
+
+    if not results:
+        return ToolResult(
+            tool_call_id=tool_call.id,
+            content=f"No results found for query: {query}",
+            error=None,
+        )
+
+    # Format output
+    output_parts = [f"Found {len(results)} results for: {query}\n"]
+
+    for i, result in enumerate(results, 1):
+        output_parts.append(f"\n{'='*60}")
+        output_parts.append(f"[{i}] {result['file']} (score: {result['score']})")
+        output_parts.append("=" * 60)
+
+        for snippet in result["snippets"]:
+            output_parts.append(snippet)
+            output_parts.append("-" * 40)
+
+    return ToolResult(
+        tool_call_id=tool_call.id,
+        content="\n".join(output_parts),
+        error=None,
+    )
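Note: despite the docstring's mention of semantic search, the scoring in _search_files() is plain term counting. A tiny worked example of the metric:

doc = """TMA loads move tiles into shared memory.
Tensor cores consume the tiles after the TMA barrier.""".lower()

query_terms = "tma tensor".split()
# Score is the total occurrence count of every query term in the document.
score = sum(doc.count(term) for term in query_terms)
print(score)  # 3 -> 'tma' appears twice, 'tensor' once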
wafer_core/utils/kernel_utils/static_checker.py CHANGED

@@ -155,6 +155,161 @@ def check_torch_computation_ops(code: str) -> tuple[bool, str]:
     return (False, "")
 
 
+# =============================================================================
+# NN.MODULE FORWARD CALL CHECKS (Reward Hacking in forward())
+# =============================================================================
+
+# These patterns detect calling PyTorch nn.Module forward methods inside forward()
+# e.g., self.conv(x), self.linear(x), self.bn(x) - these invoke cuBLAS/cuDNN
+#
+# This is different from:
+# - nn.Linear(...) in __init__ = OK (just creates parameter container)
+# - self.linear.weight in forward() = OK (accessing weights for custom kernel)
+# - self.linear(x) in forward() = BAD (invokes PyTorch's matmul via cuBLAS)
+
+NN_MODULE_FORWARD_PATTERNS = [
+    # Common layer types being called as functions
+    r"self\.(conv\d*d?|linear|bn|batch_norm|layer_norm|group_norm|instance_norm)\s*\(",
+    # More generic pattern: self.<name>(x) or self.<name>(input)
+    # But we need to be careful not to match custom module calls
+]
+
+# =============================================================================
+# TORCH.NN.FUNCTIONAL CHECKS (Reward Hacking)
+# =============================================================================
+
+# Patterns for torch.nn.functional / F.* calls that bypass custom kernel requirement
+# These call into cuBLAS/cuDNN under the hood
+TORCH_FUNCTIONAL_PATTERNS = [
+    # F.linear, F.conv*, F.batch_norm etc. (common alias)
+    r"\bF\.(linear|conv[123]d|conv_transpose[123]d|batch_norm|layer_norm|group_norm|instance_norm)\s*\(",
+    # Full path torch.nn.functional.*
+    r"\btorch\.nn\.functional\.(linear|conv[123]d|conv_transpose[123]d|batch_norm|layer_norm|group_norm|instance_norm)\s*\(",
+]
+
+
+def check_torch_functional_calls(code: str) -> tuple[bool, str]:
+    """Check for torch.nn.functional / F.* calls in forward() method (reward hacking).
+
+    Detects patterns like F.linear(x, weight), F.conv2d(x, weight) which invoke
+    PyTorch's built-in operations (backed by cuBLAS/cuDNN) instead of custom kernels.
+
+    Note: We only check inside forward() to avoid false positives from imports or
+    training-mode fallbacks that aren't used during inference.
+    """
+    # Only check inside forward() method
+    forward_code = _extract_forward_method(code)
+    if not forward_code:
+        return (False, "")
+
+    forward_code = _strip_comments(forward_code)
+
+    for pattern in TORCH_FUNCTIONAL_PATTERNS:
+        match = re.search(pattern, forward_code)
+        if match:
+            return (True, f"Uses torch.nn.functional in forward(): {match.group(0)} (reward hacking - must use custom kernel)")
+
+    return (False, "")
+
+
+def _extract_forward_method(code: str) -> str:
+    """Extract only the forward() method content from code."""
+    lines = code.split('\n')
+    result = []
+    in_forward = False
+    base_indent = 0
+
+    for i, line in enumerate(lines):
+        # Look for forward method definition
+        if re.search(r'\bdef\s+forward\s*\(\s*self', line):
+            in_forward = True
+            # Get the indentation level of the def line
+            base_indent = len(line) - len(line.lstrip())
+            result.append(line)
+            continue
+
+        if in_forward:
+            # Check if we've exited the forward method (new method/class at same or lower indent)
+            stripped = line.strip()
+            if stripped and not stripped.startswith('#') and not stripped.startswith('"""') and not stripped.startswith("'''"):
+                current_indent = len(line) - len(line.lstrip())
+                # If we hit a new def/class at the same or lower indentation, we're done
+                if current_indent <= base_indent and (stripped.startswith('def ') or stripped.startswith('class ')):
+                    break
+            result.append(line)
+
+    return '\n'.join(result)
+
+
+def check_nn_module_forward_call(code: str) -> tuple[bool, str]:
+    """Check for nn.Module forward calls inside forward() method (reward hacking).
+
+    Detects patterns like self.conv(x), self.linear(x) which invoke PyTorch's
+    built-in layers (backed by cuBLAS/cuDNN) instead of custom kernels.
+    """
+    # Only check inside forward() method
+    forward_code = _extract_forward_method(code)
+    if not forward_code:
+        return (False, "")
+
+    forward_code = _strip_comments(forward_code)
+
+    for pattern in NN_MODULE_FORWARD_PATTERNS:
+        match = re.search(pattern, forward_code)
+        if match:
+            return (True, f"Calls PyTorch nn.Module in forward(): {match.group(0)} (reward hacking - must use custom kernel)")
+
+    return (False, "")
+
+
+# =============================================================================
+# CUBLAS/CUDNN DIRECT USAGE CHECKS (Reward Hacking)
+# =============================================================================
+
+# Direct cuBLAS calls bypass custom kernel requirement
+CUBLAS_PATTERNS = [
+    r"\bcublas[A-Z]\w+\s*\(",  # cublasSgemm, cublasGemmEx, etc.
+    r"\bcublasCreate\b",
+    r"\bcublasDestroy\b",
+    r"\bcublasSetStream\b",
+    r"\bcublasSetMathMode\b",
+    r"#include\s*[<\"]cublas",  # #include <cublas_v2.h>
+    r"CUBLAS_TENSOR_OP_MATH",
+]
+
+# Direct cuDNN calls bypass custom kernel requirement
+CUDNN_PATTERNS = [
+    r"\bcudnn[A-Z]\w+\s*\(",  # cudnnConvolutionForward, etc.
+    r"\bcudnnCreate\b",
+    r"\bcudnnDestroy\b",
+    r"#include\s*[<\"]cudnn",  # #include <cudnn.h>
+]
+
+
+def check_cublas_usage(code: str) -> tuple[bool, str]:
+    """Check for direct cuBLAS API usage (reward hacking)."""
+    code = _strip_comments(code)
+
+    for pattern in CUBLAS_PATTERNS:
+        match = re.search(pattern, code)
+        if match:
+            return (True, f"Uses cuBLAS directly: {match.group(0)} (reward hacking - must write custom kernel)")
+
+    return (False, "")
+
+
+def check_cudnn_usage(code: str) -> tuple[bool, str]:
+    """Check for direct cuDNN API usage (reward hacking)."""
+    code = _strip_comments(code)
+
+    for pattern in CUDNN_PATTERNS:
+        match = re.search(pattern, code)
+        if match:
+            return (True, f"Uses cuDNN directly: {match.group(0)} (reward hacking - must write custom kernel)")
+
+    return (False, "")
+
+
 # =============================================================================
 # TIMING MANIPULATION CHECKS
 # =============================================================================
@@ -310,8 +465,16 @@ def check_tk_impl(code: str) -> tuple[bool, str]:
 def check_cute_impl(code: str) -> tuple[bool, str]:
     """Check for valid CUTLASS/CuTe kernel implementation."""
     code = _strip_comments(code)
-
-
+    # Accept explicit namespace qualifiers OR using namespace declarations
+    valid_patterns = [
+        "cute::",
+        "cutlass::",
+        "from cutlass",
+        "using namespace cute",
+        "using namespace cutlass",
+    ]
+    if not any(p in code for p in valid_patterns):
+        return (True, "Missing cute:: or cutlass:: namespace (or 'using namespace')")
     return (False, "")
 
 
@@ -334,6 +497,11 @@ CHECK_FUNCTIONS: dict[str, Callable[[str], tuple[bool, str]]] = {
     "timing_event_patch": check_timing_event_patch,
     # Torch ops
     "torch_computation_ops": check_torch_computation_ops,
+    # Reward hacking checks
+    "cublas_usage": check_cublas_usage,
+    "cudnn_usage": check_cudnn_usage,
+    "nn_module_forward_call": check_nn_module_forward_call,
+    "torch_functional_calls": check_torch_functional_calls,
     # Timing manipulation
     "stream_injection": check_stream_injection,
     "thread_injection": check_thread_injection,
@@ -363,12 +531,16 @@ STRICT_CHECKS = [
     "timing_event_patch",
     "thread_injection",
     "lazy_eval",
+    "cublas_usage",  # Direct cuBLAS is reward hacking
+    "cudnn_usage",  # Direct cuDNN is reward hacking
+    "nn_module_forward_call",  # Calling self.conv(x), self.linear(x) in forward() is reward hacking
+    "torch_functional_calls",  # Calling F.linear(), F.conv2d() in forward() is reward hacking
+    "torch_computation_ops",  # torch.mm, torch.matmul, torch.conv* etc. are reward hacking
 ]
 
 # Checks that emit warnings but don't fail
 WARNING_CHECKS = [
     "pytorch_wrap",
-    "torch_computation_ops",
     "stream_injection",
 ]
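Note: a quick check that the TORCH_FUNCTIONAL_PATTERNS regexes above catch a typical reward-hacking forward(); the pattern string is copied verbatim from the diff:

import re

pattern = r"\bF\.(linear|conv[123]d|conv_transpose[123]d|batch_norm|layer_norm|group_norm|instance_norm)\s*\("

forward_src = """
def forward(self, x):
    return F.linear(x, self.weight)
"""

# The check reports the exact offending call site in its error message.
match = re.search(pattern, forward_src)
print(match.group(0))  # F.linear(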
wafer_core/utils/kernel_utils/targets/config.py CHANGED

@@ -290,7 +290,7 @@ class RunPodTarget:
             ssh_key="~/.ssh/id_ed25519",
             gpu_type_id="AMD Instinct MI300X OAM",
             gpu_count=1,
-            image="
+            image="rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
             keep_alive=True,  # Don't terminate after eval
         )
 
@@ -304,7 +304,21 @@ class RunPodTarget:
     gpu_type_id: str = AMD_MI300X_GPU_ID  # RunPod GPU type identifier
     gpu_count: int = 1
     container_disk_gb: int = 50
-
+    # TODO: Consider creating a custom Docker image with HipKittens pre-installed
+    # to avoid needing `wafer config targets install <target> hipkittens`.
+    # HipKittens repo: https://github.com/HazyResearch/hipkittens
+    # CK (Composable Kernel) is already included in ROCm 7.0.
+    #
+    # WARNING: PyTorch's hipify can corrupt /opt/rocm/include/thrust/ headers.
+    # If you see "cuda/__cccl_config not found" errors, run:
+    #   apt-get install --reinstall -y rocthrust
+    # See docker/rocm7-runpod/README.md for details.
+    image: str = "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1"
+
+    # RunPod template ID — required for non-RunPod images that need custom
+    # dockerArgs (e.g. to install and start sshd). When set, takes priority
+    # over `image` in the deploy mutation.
+    template_id: str | None = None
 
     # Timeouts
     provision_timeout: int = 900  # 15 min for SSH to be ready
@@ -426,7 +440,7 @@ class DigitalOceanTarget:
     # DigitalOcean instance configuration
     region: str = "atl1"  # Atlanta (AMD GPUs available here)
     size_slug: str = "gpu-mi300x1-192gb-devcloud"  # Single MI300X GPU
-    image: str = "
+    image: str = "amd-pytorchrocm7"  # PyTorch (ROCm7) marketplace image
 
     # Timeouts
     provision_timeout: int = 600  # 10 min for droplet to be ready
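Note: hypothetical usage of the new field, mirroring the constructor example from the docstring above; the template ID is illustrative, not a real one, and other required fields may apply:

from wafer_core.utils.kernel_utils.targets.config import RunPodTarget

target = RunPodTarget(
    ssh_key="~/.ssh/id_ed25519",
    gpu_type_id="AMD Instinct MI300X OAM",
    gpu_count=1,
    template_id="my-rocm-sshd-template",  # illustrative; takes priority over `image`
    keep_alive=True,
)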
wafer_core/utils/modal_execution/modal_app.py CHANGED

@@ -20,35 +20,17 @@ import modal
 
 # Build Modal image with all dependencies
 # This image is cached and reused across function invocations
-def build_modal_image(
-    gpu_type: str = "B200",
-    compute_capability: str = "10.0",
-) -> modal.Image:
+def build_modal_image() -> modal.Image:
     """Build Modal image with PyTorch, CUTLASS, and evaluation dependencies.
 
     Uses explicit local code inclusion to avoid pulling in SSH deployment code.
 
-    Phase 2 solution from MODAL_HANDOFF.md:
-    - Use add_local_dir with ignore parameter to exclude deployment files
-    - Only include files needed for kernel evaluation
-
-    Args:
-        gpu_type: GPU type (determines PyTorch index URL)
-        compute_capability: CUDA compute capability
-
     Returns:
         Modal Image ready for kernel evaluation
     """
-    #
-
-
-        # Blackwell requires PyTorch 2.8+ with CUDA 12.8
-        torch_index = "https://download.pytorch.org/whl/nightly/cu128"
-        torch_version = "torch>=2.8.0"
-    else:
-        # Older GPUs (H100, A100) use stable PyTorch
-        torch_index = "https://download.pytorch.org/whl/cu124"
-        torch_version = "torch>=2.4.0"
+    # Use CUDA 13.0 for all GPUs (H100, A100, B200, GB200)
+    torch_index = "https://download.pytorch.org/whl/cu130"
+    torch_version = "torch>=2.6.0"
 
     # Build image with dependencies
     image = (
@@ -74,6 +56,15 @@ def build_modal_image(
         "scipy",
         "pytest",
     )
+    # Install CUTLASS headers for C++ kernel compilation (v4.3.5)
+    .run_commands(
+        "git clone --depth 1 --branch v4.3.5 https://github.com/NVIDIA/cutlass.git /usr/local/cutlass",
+        # Verify CUTLASS was installed correctly
+        "ls -la /usr/local/cutlass/include/cutlass/util/ | head -20",
+        "test -f /usr/local/cutlass/include/cutlass/util/packed_stride.hpp && echo 'CUTLASS headers OK' || echo 'CUTLASS headers MISSING'",
+    )
+    # Set CUTLASS_PATH environment variable
+    .env({"CUTLASS_PATH": "/usr/local/cutlass/include"})
     # Create empty __init__.py files for proper Python package structure
     # MUST run before add_local_* commands (Modal restriction)
     .run_commands(
@@ -111,20 +102,16 @@ def build_modal_image(
 # Create app (can be customized per target)
 def create_modal_app(
     app_name: str = "test-kernel-eval",  # Match test script default
-    gpu_type: str = "B200",
-    compute_capability: str = "10.0",
 ) -> modal.App:
     """Create Modal app for kernel evaluation.
 
     Args:
         app_name: Modal app name
-        gpu_type: GPU type for image building
-        compute_capability: CUDA compute capability
 
     Returns:
         Modal App instance
     """
-    image = build_modal_image(
+    image = build_modal_image()
     return modal.App(name=app_name, image=image)
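Note: a sketch of how a compile step inside the image could consume the baked-in headers. The env var name matches the .env() call above; the nvcc invocation and file names are illustrative assumptions, not part of the package:

import os
import subprocess

# CUTLASS is header-only, so compilation just needs the include path.
cutlass_include = os.environ.get("CUTLASS_PATH", "/usr/local/cutlass/include")
subprocess.run(
    ["nvcc", f"-I{cutlass_include}", "-o", "kernel", "kernel.cu"],  # hypothetical kernel.cu
    check=True,
)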
{wafer_core-0.1.22.dist-info → wafer_core-0.1.24.dist-info}/RECORD CHANGED

@@ -12,7 +12,7 @@ wafer_core/config/__init__.py,sha256=hKywfjA4YXd4lBeBFEcBoMwFoflPHJTiBnkTq7_JYOQ
 wafer_core/config/loader.py,sha256=k7JnILmO13TWUzIv9Lm8fvmj3UfYHZDgaFurjQ-GXpY,6623
 wafer_core/config/schema.py,sha256=2WhFlnG0VYYX4T-70BLeJK8Janvi4KEa8KKGZA7331w,3898
 wafer_core/environments/__init__.py,sha256=SIsResVtm22tr_d-oHPeeSxrkhFdmPOFico3DqtRqK8,238
-wafer_core/environments/coding.py,sha256=
+wafer_core/environments/coding.py,sha256=N-ELZwJu5vKLCVtwO25c6JSty6fmqf85VR2d3WJ4RXw,8559
 wafer_core/environments/gpumode.py,sha256=8Da08nltvN_YloNyYI6-omN2D4n5C7aptKDCtUgT2bQ,17191
 wafer_core/lib/__init__.py,sha256=4-4p3mhwlquejWGglYXU8_nHdA0LoPaa_jGzcm13USA,1325
 wafer_core/lib/kernel_scope/__init__.py,sha256=WW2vu8jUlqOu-MCpgO40lIYacCA9N2u-uuECIs_JO2w,2817
@@ -585,11 +585,12 @@ wafer_core/sessions/dtypes.py,sha256=K6nOjvL6sjCGY7GTtdEygf1IZY_18R9YkHGqFyMd8wY
 wafer_core/sessions/hooks.py,sha256=A-txm6ufnRGQCdtP3vwh7oEOdlLN9Tv0XsjORMihuAI,4295
 wafer_core/targets/__init__.py,sha256=sHndC7AAOaHXlrmDXFLB53a5Y8DBjuyqS6nwsO2nj-Y,1728
 wafer_core/targets/digitalocean.py,sha256=cvoYpYjtSyy5t2lQAPi7ERruuuibronah_ivOiduAHQ,16550
-wafer_core/targets/runpod.py,sha256=
-wafer_core/tools/__init__.py,sha256=
+wafer_core/targets/runpod.py,sha256=LrVmNvA6qjzL5nbGSWvtw7CHrK6bDu7_o3vKIek00Tc,20286
+wafer_core/tools/__init__.py,sha256=deGQQlcdSD6zQx8JHizfSXgF5-EntdBOF_ngtob1-VU,3506
 wafer_core/tools/bash_tool.py,sha256=daoKOVGSgL0x9X_3l8Apd6-wFH4VMXMGJwVemw2FIfc,16828
 wafer_core/tools/glob_tool.py,sha256=9X5PdOjQJj7kiVNqqCZC0-1LmnE6wHx3Zc9zfMjtXdc,3533
 wafer_core/tools/grep_tool.py,sha256=cStyDz-J47oDLLZCL83yOvYo8Ijv4qu3D372JKT_ptM,4580
+wafer_core/tools/search_docs_tool.py,sha256=WY4hY83sseX8Fpxvw6DZxiG-F95F2t3-4PyfMD1Lpkg,6809
 wafer_core/tools/skill_tool.py,sha256=JXsT5hBTUH5U4tmzHEywU7eHHt5xCEF79tL2tsuk4-c,2067
 wafer_core/tools/wafer_tool.py,sha256=-dgPTHbWXq3I3wFj0mP7-lj5iZqGRoFvFf9IEEo3plQ,6345
 wafer_core/tools/write_kernel_tool.py,sha256=dJjhr-WBhVNe06hcJQVmBZTbS8mid64KF1MwlE2s2R4,21547
@@ -615,7 +616,7 @@ wafer_core/tools/capture_tool/metrics.py,sha256=BFZNmdE-kh3LneYdWXTNZmlLuo-DCrP5
 wafer_core/tools/file_tools/__init__.py,sha256=2H7Rq5bijNQHGO4W6jjQAShkrcmdcHC0EQ8mBpgrApI,632
 wafer_core/tools/file_tools/edit_tool.py,sha256=Efx83pM1Ljb07cJmAGVhPX4YiPJICK70sZM6uCjRWB0,4109
 wafer_core/tools/file_tools/glob_tool.py,sha256=Av4LfC21fHXbnSsgh_9zDxlY9Qhb48aApaGos4j3B4g,3437
-wafer_core/tools/file_tools/grep_tool.py,sha256=
+wafer_core/tools/file_tools/grep_tool.py,sha256=42eFj2pxBBrs5eg_GhyYJ-j2fNWkmGPvrEqXFmi5E10,5539
 wafer_core/tools/file_tools/read_tool.py,sha256=K0Hd8zwyL4Yva5YO9spXDfTRfXvfjqh9ztVrA8s1bJE,3961
 wafer_core/tools/file_tools/utils.py,sha256=HgaqYan2Pky4hTLX2L9d2Gj9oS325H7rFbJj-jryNtc,2576
 wafer_core/tools/file_tools/write_tool.py,sha256=X4N8y8wB-k9d5PcMRmZMRKIXlG9jHJiRdlEFFRLdZzs,2083
@@ -661,18 +662,18 @@ wafer_core/utils/kernel_utils/evaluate.py,sha256=1kxFNMl9VCXfKfk_BIiuA_zFfvDB1sl
 wafer_core/utils/kernel_utils/gpu_validation.py,sha256=LRiDjW_xAK4fXf1Vw1aYHG54B1W0J6b5L0K6PXzM2tI,3759
 wafer_core/utils/kernel_utils/reference_cache.py,sha256=4IQ2gND1StHULRO7geyAElEStbjQxwOeP6X09E5wCB0,11283
 wafer_core/utils/kernel_utils/results.py,sha256=QJGeah_41LSzxyYwGl9VxHPxTVAN2bLtk5bWdWLIpL4,6705
-wafer_core/utils/kernel_utils/static_checker.py,sha256=
+wafer_core/utils/kernel_utils/static_checker.py,sha256=XIQkzAOkGH5xtrOuZM4tNUqVJ0QRkYeJ7_8DosDOtkw,19886
 wafer_core/utils/kernel_utils/task.py,sha256=XcmKxKUWh5It6nX3zGqj77tWgA32uPfQMqNOqyD5T48,2682
 wafer_core/utils/kernel_utils/utils.py,sha256=uDZoJDxh07hJeLNlPdKN2vgB15pqIr1LbXf0YIBHU4E,43056
 wafer_core/utils/kernel_utils/targets/__init__.py,sha256=4NwRLsuJ__S4xKAfda4Ag82C5MQ3Qio-4xA5S-mQGlU,2067
-wafer_core/utils/kernel_utils/targets/config.py,sha256=
+wafer_core/utils/kernel_utils/targets/config.py,sha256=sNXyYTZ9rL9OET4xqbHZ0d4b8ChzST1yUKvNOv8JSQs,19933
 wafer_core/utils/kernel_utils/targets/execution.py,sha256=bZuNXCo0sIdD6hFhetLPrtDC-zMSiIsAx_aml49VVL0,15033
 wafer_core/utils/kernel_utils/targets/selection.py,sha256=5I_RG_7cfhq7uaeR28meC2EeNNKssFsK-Tc3QFG6Ze0,3590
 wafer_core/utils/modal_execution/__init__.py,sha256=jkVqYOLzCT5K73N9Od0UIUsx-99A0m6bpDrxfyXxQZ8,945
-wafer_core/utils/modal_execution/modal_app.py,sha256=
+wafer_core/utils/modal_execution/modal_app.py,sha256=VfS2cX8gHtnlPXemmMcEwDPeQdhiv2tly3CifOyh9f4,11455
 wafer_core/utils/modal_execution/modal_config.py,sha256=7cGX9TGqilQ3qxI3OFGXV5orjtyRU-PEDOJ4vP2oxno,4421
 wafer_core/utils/modal_execution/modal_execution.py,sha256=gChjnV6jqA3A7IRP3DfvV5cSfm_MN0X4f7JZufXgdZE,24594
 wafer_core/utils/modal_execution/test_modal.py,sha256=_jqou_hrLs1Daf1590Pnb0a_lXMMa2rczAPpW9HpoNQ,8153
-wafer_core-0.1.
-wafer_core-0.1.
-wafer_core-0.1.
+wafer_core-0.1.24.dist-info/METADATA,sha256=h2zO5zgoRFyd1aZbWSugm8JWl8RzYYd9w5h0CDQ2pa4,1420
+wafer_core-0.1.24.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+wafer_core-0.1.24.dist-info/RECORD,,

{wafer_core-0.1.22.dist-info → wafer_core-0.1.24.dist-info}/WHEEL
File without changes