wafer-core 0.1.22 → 0.1.23 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- wafer_core/environments/coding.py
+++ wafer_core/environments/coding.py
@@ -34,6 +34,7 @@ from wafer_core.tools import (
     GLOB_TOOL,
     GREP_TOOL,
     READ_TOOL,
+    SEARCH_DOCS_TOOL,
     SKILL_TOOL,
     WRITE_TOOL,
     ApprovalCallback,
@@ -42,6 +43,7 @@ from wafer_core.tools import (
     exec_glob,
     exec_grep,
     exec_read,
+    exec_search_docs,
     exec_skill,
     exec_write,
 )
@@ -63,6 +65,7 @@ ALL_TOOLS = {
     "glob": GLOB_TOOL,
     "grep": GREP_TOOL,
     "bash": BASH_TOOL,
+    "search_docs": SEARCH_DOCS_TOOL,
    "skill": SKILL_TOOL,
     # TODO(wafer-tool): "wafer": WAFER_TOOL,
 }
@@ -211,6 +214,7 @@ class CodingEnvironment:
                 self.bash_approval_callback,
                 self._sandbox_policy,
             ),
+            "search_docs": lambda tc: exec_search_docs(tc),
             "skill": lambda tc: exec_skill(tc),
             # TODO(wafer-tool): "wafer": lambda tc: exec_wafer(
             #     tc, self.working_dir, self.enabled_tools, self.allow_spawn, cancel_scope
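These coding.py hunks register the new tool end to end: import, ALL_TOOLS entry, and dispatch lambda. A minimal sketch of exercising the handler directly, assuming ToolCall from wafer_core.rollouts.dtypes accepts id and args as keyword arguments (the handler reads exactly those fields later in this diff):

    import asyncio
    from wafer_core.rollouts.dtypes import ToolCall
    from wafer_core.tools import exec_search_docs

    # Hypothetical direct invocation, outside the agent loop; the corpus
    # must already exist via `wafer corpus download cutlass`.
    call = ToolCall(id="tc-1", args={"query": "hopper gemm", "corpus": "cutlass"})
    result = asyncio.run(exec_search_docs(call))
    print(result.error or result.content)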
--- wafer_core/targets/runpod.py
+++ wafer_core/targets/runpod.py
@@ -14,7 +14,6 @@ from __future__ import annotations

 import json
 import logging
-import os
 import time
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
@@ -250,10 +249,17 @@ async def provision_pod(target: RunPodTarget) -> tuple[str, str, int, str]:
         "ports": "22/tcp",
         "startSsh": True,
         "startJupyter": False,
-        "imageName": target.image,
         "env": [],
     }

+    if target.template_id:
+        # Template defines image, dockerArgs (sshd setup), and ports.
+        # Required for non-RunPod images (e.g. rocm/pytorch) that don't
+        # have RunPod's built-in SSH handler.
+        pod_input["templateId"] = target.template_id
+    else:
+        pod_input["imageName"] = target.image
+
     variables = {"input": pod_input}

     logger.info(f"Provisioning RunPod pod: {pod_name}")
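The deploy payload now takes one of two shapes; a quick sketch of both (the template id is a placeholder):

    # Template-backed target: image, dockerArgs, and ports come from the template.
    {"ports": "22/tcp", "startSsh": True, "startJupyter": False, "env": [],
     "templateId": "abc123"}  # hypothetical id

    # Plain-image target (previous behavior):
    {"ports": "22/tcp", "startSsh": True, "startJupyter": False, "env": [],
     "imageName": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1"}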
@@ -334,7 +340,8 @@ async def _wait_for_ssh(pod_id: str, timeout_seconds: int) -> tuple[str, int, str]:
         # Check for SSH port
         runtime = pod.get("runtime")
         if runtime and status == "running":
-            for port in runtime.get("ports", []):
+            # ports can be null in JSON response, so use 'or []' instead of default
+            for port in runtime.get("ports") or []:
                 if (
                     port.get("privatePort") == 22
                     and port.get("isIpPublic")
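The one-line fix matters because dict.get only falls back to its default when the key is absent, not when it is present with a null value, which the RunPod API can return for ports:

    runtime = {"ports": None}     # key present, value null in the API response
    runtime.get("ports", [])      # -> None; the default covers missing keys only
    runtime.get("ports") or []    # -> []; also covers an explicit null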
@@ -378,6 +385,55 @@ async def terminate_pod(pod_id: str) -> bool:
     return False


+# =============================================================================
+# Template Management (not yet implemented)
+# =============================================================================
+#
+# The saveTemplate mutation allows creating reusable pod templates with custom
+# configurations. Templates can specify docker images, environment setup,
+# container disk size, and other pod settings.
+#
+# Example mutation:
+#
+#     mutation saveTemplate($input: SaveTemplateInput) {
+#       saveTemplate(input: $input) {
+#         id
+#         name
+#         imageName
+#         containerDiskInGb
+#         ports
+#         dockerArgs
+#         startSsh
+#         startJupyter
+#       }
+#     }
+#
+# Example variables:
+#
+#     {
+#       "input": {
+#         "containerDiskInGb": 50,
+#         "dockerArgs": "bash -c \"apt-get update && apt-get install -y openssh-server && ...\"",
+#         "env": [],
+#         "isPublic": false,
+#         "isServerless": false,
+#         "name": "template-name",
+#         "ports": "22/tcp",
+#         "portsConfig": [{"name": "SSH", "port": "22"}],
+#         "readme": "",
+#         "volumeInGb": 0,
+#         "volumeMountPath": "",
+#         "config": {},
+#         "category": "AMD",
+#         "imageName": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1"
+#       }
+#     }
+#
+# Note: Template creation is not currently implemented in this module.
+# If needed, implement a save_template() function following the pattern of
+# provision_pod() and terminate_pod() above.
+
+
 # =============================================================================
 # Context Manager
 # =============================================================================
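The comment block leaves save_template() unimplemented; a minimal sketch of what it could look like, assuming _graphql_request_async accepts a variables argument alongside the query (the function and its defaults are illustrative, not shipped code):

    async def save_template(name: str, image: str, docker_args: str = "") -> str | None:
        """Create a reusable RunPod template; returns the template id on success."""
        mutation = """
        mutation saveTemplate($input: SaveTemplateInput) {
          saveTemplate(input: $input) { id }
        }
        """
        variables = {"input": {
            "name": name,
            "imageName": image,
            "dockerArgs": docker_args,
            "containerDiskInGb": 50,
            "ports": "22/tcp",
            "env": [],
            "isPublic": False,
            "isServerless": False,
            "volumeInGb": 0,
            "volumeMountPath": "",
            "readme": "",
        }}
        try:
            data = await _graphql_request_async(mutation, variables)
        except Exception as e:  # mirror terminate_pod's defensive handling
            logger.warning(f"Failed to save template {name}: {e}")
            return None
        return data.get("saveTemplate", {}).get("id")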
@@ -482,20 +538,103 @@ async def cleanup_target(target_name: str) -> bool:
     return success


+async def sync_pods_from_api() -> list[PodState]:
+    """Query RunPod API for all running pods and update local state.
+
+    This discovers pods that exist on the account but aren't in our state file
+    (e.g., created manually or by another machine). Updates the state file with
+    any wafer-created pods found.
+
+    Returns list of all running pods with SSH info.
+    """
+    query = """
+    query {
+        myself {
+            pods {
+                id
+                name
+                desiredStatus
+                runtime {
+                    ports {
+                        ip
+                        isIpPublic
+                        privatePort
+                        publicPort
+                    }
+                }
+            }
+        }
+    }
+    """
+
+    try:
+        data = await _graphql_request_async(query)
+    except Exception as e:
+        logger.warning(f"Failed to query pods from API: {e}")
+        return []
+
+    pods = data.get("myself", {}).get("pods", [])
+    running_pods = []
+
+    for pod in pods:
+        status = pod.get("desiredStatus", "").lower()
+        if status != "running":
+            continue
+
+        pod_id = pod["id"]
+        pod_name = pod.get("name", "")
+
+        # Extract SSH info
+        runtime = pod.get("runtime")
+        if not runtime:
+            continue
+
+        public_ip = None
+        ssh_port = None
+        for port in runtime.get("ports") or []:
+            if port.get("privatePort") == 22 and port.get("isIpPublic"):
+                public_ip = port.get("ip")
+                ssh_port = port.get("publicPort")
+                break
+
+        if not public_ip or not ssh_port:
+            continue
+
+        # Extract target name from pod name (wafer-{target_name}-{timestamp})
+        target_name = None
+        if pod_name.startswith("wafer-"):
+            parts = pod_name.split("-")
+            if len(parts) >= 3:
+                # Handle target names with hyphens: wafer-runpod-mi300x-1234567
+                target_name = "-".join(parts[1:-1])
+
+        pod_state = PodState(
+            pod_id=pod_id,
+            target_name=target_name or pod_name,
+            public_ip=public_ip,
+            ssh_port=ssh_port,
+            ssh_username="root",
+            created_at=datetime.now(timezone.utc).isoformat(),
+        )
+        running_pods.append(pod_state)
+
+        # Update state file if this is a wafer-created pod
+        if target_name:
+            existing = get_pod_state(target_name)
+            if not existing or existing.pod_id != pod_id:
+                logger.info(f"Syncing pod {pod_id} to state for target {target_name}")
+                _add_pod_to_state(target_name, pod_id, public_ip, ssh_port, "root")
+
+    return running_pods
+
+
 async def list_running_pods() -> list[PodState]:
-    """List all pods in state file that are still running."""
-    state = _load_state()
-    running = []
+    """List all running pods by querying the RunPod API.

-    for name, pod_state in state.items():
-        if await check_pod_running(pod_state.pod_id):
-            running.append(pod_state)
-        else:
-            # Clean up stale entry
-            logger.info(f"Removing stale state for {name} (pod {pod_state.pod_id})")
-            _remove_pod_from_state(name)
-
-    return running
+    Syncs state file with API to discover pods not in local state.
+    Returns list of running pods with SSH info.
+    """
+    return await sync_pods_from_api()


 async def cleanup_all_pods() -> int:
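The name parsing drops only the trailing timestamp segment, which is what keeps hyphenated target names intact; the same split, step by step:

    pod_name = "wafer-runpod-mi300x-1761234567"
    parts = pod_name.split("-")          # ['wafer', 'runpod', 'mi300x', '1761234567']
    target_name = "-".join(parts[1:-1])  # 'runpod-mi300x'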
--- wafer_core/tools/__init__.py
+++ wafer_core/tools/__init__.py
@@ -72,6 +72,10 @@ from wafer_core.tools.write_kernel_tool import (
     KernelSubmission,
     exec_write_kernel,
 )
+from wafer_core.tools.search_docs_tool import (
+    SEARCH_DOCS_TOOL,
+    exec_search_docs,
+)

 __all__ = [
     # File tools
@@ -133,4 +137,7 @@ __all__ = [
     "exec_tracelens_report",
     "exec_tracelens_compare",
     "exec_tracelens_collective",
+    # Search docs tool
+    "SEARCH_DOCS_TOOL",
+    "exec_search_docs",
 ]
--- wafer_core/tools/file_tools/grep_tool.py
+++ wafer_core/tools/file_tools/grep_tool.py
@@ -1,4 +1,4 @@
-"""Grep tool using ripgrep for fast content search."""
+"""Grep tool using ripgrep (with fallback to standard grep)."""

 from pathlib import Path

@@ -15,7 +15,7 @@ GREP_TOOL = Tool(
     function=ToolFunction(
         name="grep",
         description=(
-            "Search for a pattern in files using ripgrep. "
+            "Search for a pattern in files. "
             "Returns matching lines with file paths and line numbers. "
             "Supports regex patterns by default."
         ),
@@ -54,7 +54,7 @@ GREP_TOOL = Tool(


 async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
-    """Execute grep using ripgrep."""
+    """Execute grep using ripgrep (preferred) or standard grep (fallback)."""
     import shutil
     import subprocess

@@ -74,35 +74,55 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
             error="'pattern' is required",
         )

-    # Find ripgrep
+    # Try ripgrep first, fall back to standard grep
     rg_path = shutil.which("rg")
-    if not rg_path:
+    grep_path = shutil.which("grep")
+
+    if rg_path:
+        # Use ripgrep (faster, better defaults)
+        cmd = [rg_path, "--line-number", "--no-heading", "--color=never"]
+
+        if case_insensitive:
+            cmd.append("--ignore-case")
+
+        if context_lines:
+            cmd.extend(["--context", str(context_lines)])
+
+        if glob_pattern:
+            cmd.extend(["--glob", glob_pattern])
+
+        # Limit results
+        cmd.extend(["--max-count", str(max_results)])
+
+        cmd.append(pattern)
+        cmd.append(search_path)
+        use_ripgrep = True
+    elif grep_path:
+        # Fallback to standard grep
+        cmd = [grep_path, "-r", "-n", "--color=never"]
+
+        if case_insensitive:
+            cmd.append("-i")
+
+        if context_lines:
+            cmd.extend(["-C", str(context_lines)])
+
+        if glob_pattern:
+            # Standard grep uses --include for glob patterns
+            cmd.extend(["--include", glob_pattern])
+
+        cmd.append(pattern)
+        cmd.append(search_path)
+        use_ripgrep = False
+    else:
         return ToolResult(
             tool_call_id=tool_call.id,
             is_error=True,
             content="",
-            error="ripgrep (rg) not found. Please install it: brew install ripgrep",
+            error="Neither ripgrep (rg) nor grep found. Please install one.",
         )

-    # Build command
-    cmd = [rg_path, "--line-number", "--no-heading", "--color=never"]
-
-    if case_insensitive:
-        cmd.append("--ignore-case")
-
-    if context_lines:
-        cmd.extend(["--context", str(context_lines)])
-
-    if glob_pattern:
-        cmd.extend(["--glob", glob_pattern])
-
-    # Limit results
-    cmd.extend(["--max-count", str(max_results)])
-
-    cmd.append(pattern)
-    cmd.append(search_path)
-
-    # Run ripgrep
+    # Run the search
     try:
         result = subprocess.run(
             cmd,
@@ -126,13 +146,14 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
             error=f"Search failed: {e}",
         )

-    # ripgrep returns exit code 1 for no matches (not an error)
+    # Both ripgrep and grep return exit code 1 for no matches (not an error)
     if result.returncode not in (0, 1):
+        tool_name = "ripgrep" if use_ripgrep else "grep"
         return ToolResult(
             tool_call_id=tool_call.id,
             is_error=True,
             content="",
-            error=result.stderr or f"ripgrep exited with code {result.returncode}",
+            error=result.stderr or f"{tool_name} exited with code {result.returncode}",
         )

     output = result.stdout.strip()
@@ -143,8 +164,14 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
             content=f"No matches found for pattern: {pattern}",
         )

-    # Count matches
-    match_count = len(output.split("\n"))
+    # Count matches and limit output for standard grep
+    lines = output.split("\n")
+    if not use_ripgrep and len(lines) > max_results:
+        lines = lines[:max_results]
+        output = "\n".join(lines)
+        output += f"\n... (truncated to {max_results} results)"
+
+    match_count = min(len(lines), max_results)
     header = f"Found {match_count} matches:\n\n"

     return ToolResult(
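For the same request (pattern 'TODO', case-insensitive, two context lines, max_results=5, path '.'), the two branches build these commands; note ripgrep's --max-count caps matches per file, while the grep branch has no cap and instead relies on the post-hoc output truncation added at the end of the hunk:

    # ripgrep branch
    ["rg", "--line-number", "--no-heading", "--color=never",
     "--ignore-case", "--context", "2", "--max-count", "5", "TODO", "."]

    # standard-grep fallback
    ["grep", "-r", "-n", "--color=never", "-i", "-C", "2", "TODO", "."]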
--- /dev/null
+++ wafer_core/tools/search_docs_tool.py
@@ -0,0 +1,196 @@
+"""Search documentation tool for GPU programming corpora.
+
+Provides semantic and keyword search over documentation for CuTeDSL, CUDA, etc.
+
+Corpora are downloaded via `wafer corpus download <name>` and stored in ~/.cache/wafer/corpora/.
+"""
+
+import re
+from pathlib import Path
+
+from wafer_core.rollouts.dtypes import Tool, ToolCall, ToolFunction, ToolFunctionParameter, ToolResult
+
+# Cache directory where wafer corpus download stores files
+CACHE_DIR = Path.home() / ".cache" / "wafer" / "corpora"
+
+# Available corpora (names match wafer corpus download)
+AVAILABLE_CORPORA = ["cutlass", "cutedsl", "cuda", "hip", "amd"]
+
+SEARCH_DOCS_TOOL = Tool(
+    type="function",
+    function=ToolFunction(
+        name="search_docs",
+        description="""Search GPU programming documentation for relevant information.
+
+Use this tool to find documentation about:
+- CUTLASS C++ (cute:: namespace, gemm tutorials, tensor cores, TMA, Blackwell)
+- CuTeDSL Python API (@cute.kernel, @cute.jit, cute.arch functions)
+- CUDA programming concepts
+- GPU kernel optimization techniques
+- Code examples and patterns
+
+Available corpora:
+- 'cutlass' - NVIDIA CUTLASS C++ docs + GitHub examples (gemm, hopper, blackwell)
+- 'cutedsl' - CuTeDSL Python documentation
+- 'cuda' - General CUDA programming docs
+- 'hip' - AMD HIP programming docs
+- 'amd' - AMD GPU kernel development (rocWMMA, CK, etc.)
+
+Note: Corpora must be downloaded first with `wafer corpus download <name>`.
+Returns relevant documentation snippets with file paths.""",
+        parameters=ToolFunctionParameter(
+            type="object",
+            properties={
+                "query": {
+                    "type": "string",
+                    "description": "Search query - describe what you're looking for",
+                },
+                "corpus": {
+                    "type": "string",
+                    "description": "Which docs to search: 'cutlass', 'cutedsl', 'cuda', 'hip', 'amd' (default: cutlass)",
+                },
+                "max_results": {
+                    "type": "integer",
+                    "description": "Maximum number of results to return (default: 5)",
+                },
+            },
+        ),
+        required=["query"],
+    )
+)
+
+
+def _get_corpus_path(corpus_name: str) -> Path | None:
+    """Get the path to a corpus in the cache directory.
+
+    Corpora are stored at ~/.cache/wafer/corpora/<corpus_name>/
+    """
+    if corpus_name not in AVAILABLE_CORPORA:
+        return None
+
+    corpus_path = CACHE_DIR / corpus_name
+    if corpus_path.exists():
+        return corpus_path
+
+    return None
+
+
+def _search_files(corpus_path: Path, query: str, max_results: int = 5) -> list[dict]:
+    """Simple keyword search through documentation files."""
+    results = []
+    query_terms = query.lower().split()
+
+    # Search .md, .py, .cu, .hpp, and .h files (for CUTLASS examples)
+    for pattern in ["**/*.md", "**/*.py", "**/*.cu", "**/*.hpp", "**/*.h", "**/*.cuh"]:
+        for file_path in corpus_path.glob(pattern):
+            if file_path.is_file():
+                try:
+                    content = file_path.read_text(encoding="utf-8", errors="ignore")
+                    content_lower = content.lower()
+
+                    # Score based on term matches
+                    score = sum(content_lower.count(term) for term in query_terms)
+
+                    if score > 0:
+                        # Extract relevant snippets
+                        snippets = _extract_snippets(content, query_terms)
+                        results.append({
+                            "file": str(file_path),  # Return absolute path so read tool can access it
+                            "score": score,
+                            "snippets": snippets[:3],  # Top 3 snippets
+                        })
+                except Exception:
+                    continue
+
+    # Sort by score and return top results
+    results.sort(key=lambda x: x["score"], reverse=True)
+    return results[:max_results]
+
+
+def _extract_snippets(content: str, terms: list[str], context_lines: int = 5) -> list[str]:
+    """Extract snippets containing search terms."""
+    snippets = []
+    lines = content.split("\n")
+
+    for i, line in enumerate(lines):
+        line_lower = line.lower()
+        if any(term in line_lower for term in terms):
+            # Get context around the match
+            start = max(0, i - context_lines)
+            end = min(len(lines), i + context_lines + 1)
+            snippet = "\n".join(lines[start:end])
+
+            # Skip very short snippets
+            if len(snippet.strip()) > 50:
+                snippets.append(snippet)
+
+    return snippets
+
+
+async def exec_search_docs(
+    tool_call: ToolCall,
+    corpus_override: str | None = None,
+) -> ToolResult:
+    """Execute search_docs tool.
+
+    Args:
+        tool_call: The tool call with query and optional corpus
+        corpus_override: Override corpus path (for testing)
+    """
+    query = tool_call.args.get("query", "")
+    corpus_name = tool_call.args.get("corpus", "cutlass")
+    max_results = tool_call.args.get("max_results", 5)
+
+    if not query:
+        return ToolResult(
+            tool_call_id=tool_call.id,
+            content="",
+            error="query parameter is required",
+        )
+
+    # Find corpus path
+    if corpus_override:
+        corpus_path = Path(corpus_override)
+    else:
+        corpus_path = _get_corpus_path(corpus_name)
+        if corpus_path is None:
+            return ToolResult(
+                tool_call_id=tool_call.id,
+                content="",
+                error=f"Unknown corpus: {corpus_name}. Available: {AVAILABLE_CORPORA}",
+            )
+
+    if not corpus_path.exists():
+        return ToolResult(
+            tool_call_id=tool_call.id,
+            content="",
+            error=f"Corpus '{corpus_name}' not downloaded. Run: wafer corpus download {corpus_name}",
+        )
+
+    # Search
+    results = _search_files(corpus_path, query, max_results)
+
+    if not results:
+        return ToolResult(
+            tool_call_id=tool_call.id,
+            content=f"No results found for query: {query}",
+            error=None,
+        )
+
+    # Format output
+    output_parts = [f"Found {len(results)} results for: {query}\n"]
+
+    for i, result in enumerate(results, 1):
+        output_parts.append(f"\n{'='*60}")
+        output_parts.append(f"[{i}] {result['file']} (score: {result['score']})")
+        output_parts.append("=" * 60)
+
+        for snippet in result["snippets"]:
+            output_parts.append(snippet)
+            output_parts.append("-" * 40)
+
+    return ToolResult(
+        tool_call_id=tool_call.id,
+        content="\n".join(output_parts),
+        error=None,
+    )
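The ranking in _search_files is a plain term-frequency sum over lowercased substrings, so multi-word queries favor files that repeat any one term, and matches inside longer words count too:

    query_terms = "tensor core gemm".lower().split()
    content_lower = "gemm gemm tensor scorecard"
    # count('tensor')=1, count('core')=1 (inside 'scorecard'), count('gemm')=2
    score = sum(content_lower.count(term) for term in query_terms)  # -> 4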
--- wafer_core/utils/kernel_utils/static_checker.py
+++ wafer_core/utils/kernel_utils/static_checker.py
@@ -155,6 +155,161 @@ def check_torch_computation_ops(code: str) -> tuple[bool, str]:
     return (False, "")


+# =============================================================================
+# NN.MODULE FORWARD CALL CHECKS (Reward Hacking in forward())
+# =============================================================================
+
+# These patterns detect calling PyTorch nn.Module forward methods inside forward()
+# e.g., self.conv(x), self.linear(x), self.bn(x) - these invoke cuBLAS/cuDNN
+#
+# This is different from:
+# - nn.Linear(...) in __init__ = OK (just creates parameter container)
+# - self.linear.weight in forward() = OK (accessing weights for custom kernel)
+# - self.linear(x) in forward() = BAD (invokes PyTorch's matmul via cuBLAS)
+
+NN_MODULE_FORWARD_PATTERNS = [
+    # Common layer types being called as functions
+    r"self\.(conv\d*d?|linear|bn|batch_norm|layer_norm|group_norm|instance_norm)\s*\(",
+    # More generic pattern: self.<name>(x) or self.<name>(input)
+    # But we need to be careful not to match custom module calls
+]
+
+# =============================================================================
+# TORCH.NN.FUNCTIONAL CHECKS (Reward Hacking)
+# =============================================================================
+
+# Patterns for torch.nn.functional / F.* calls that bypass custom kernel requirement
+# These call into cuBLAS/cuDNN under the hood
+TORCH_FUNCTIONAL_PATTERNS = [
+    # F.linear, F.conv*, F.batch_norm etc. (common alias)
+    r"\bF\.(linear|conv[123]d|conv_transpose[123]d|batch_norm|layer_norm|group_norm|instance_norm)\s*\(",
+    # Full path torch.nn.functional.*
+    r"\btorch\.nn\.functional\.(linear|conv[123]d|conv_transpose[123]d|batch_norm|layer_norm|group_norm|instance_norm)\s*\(",
+]
+
+
+def check_torch_functional_calls(code: str) -> tuple[bool, str]:
+    """Check for torch.nn.functional / F.* calls in forward() method (reward hacking).
+
+    Detects patterns like F.linear(x, weight), F.conv2d(x, weight) which invoke
+    PyTorch's built-in operations (backed by cuBLAS/cuDNN) instead of custom kernels.
+
+    Note: We only check inside forward() to avoid false positives from imports or
+    training-mode fallbacks that aren't used during inference.
+    """
+    # Only check inside forward() method
+    forward_code = _extract_forward_method(code)
+    if not forward_code:
+        return (False, "")
+
+    forward_code = _strip_comments(forward_code)
+
+    for pattern in TORCH_FUNCTIONAL_PATTERNS:
+        match = re.search(pattern, forward_code)
+        if match:
+            return (True, f"Uses torch.nn.functional in forward(): {match.group(0)} (reward hacking - must use custom kernel)")
+
+    return (False, "")
+
+
+def _extract_forward_method(code: str) -> str:
+    """Extract only the forward() method content from code."""
+    lines = code.split('\n')
+    result = []
+    in_forward = False
+    base_indent = 0
+
+    for i, line in enumerate(lines):
+        # Look for forward method definition
+        if re.search(r'\bdef\s+forward\s*\(\s*self', line):
+            in_forward = True
+            # Get the indentation level of the def line
+            base_indent = len(line) - len(line.lstrip())
+            result.append(line)
+            continue
+
+        if in_forward:
+            # Check if we've exited the forward method (new method/class at same or lower indent)
+            stripped = line.strip()
+            if stripped and not stripped.startswith('#') and not stripped.startswith('"""') and not stripped.startswith("'''"):
+                current_indent = len(line) - len(line.lstrip())
+                # If we hit a new def/class at the same or lower indentation, we're done
+                if current_indent <= base_indent and (stripped.startswith('def ') or stripped.startswith('class ')):
+                    break
+            result.append(line)
+
+    return '\n'.join(result)
+
+
+def check_nn_module_forward_call(code: str) -> tuple[bool, str]:
+    """Check for nn.Module forward calls inside forward() method (reward hacking).
+
+    Detects patterns like self.conv(x), self.linear(x) which invoke PyTorch's
+    built-in layers (backed by cuBLAS/cuDNN) instead of custom kernels.
+    """
+    # Only check inside forward() method
+    forward_code = _extract_forward_method(code)
+    if not forward_code:
+        return (False, "")
+
+    forward_code = _strip_comments(forward_code)
+
+    for pattern in NN_MODULE_FORWARD_PATTERNS:
+        match = re.search(pattern, forward_code)
+        if match:
+            return (True, f"Calls PyTorch nn.Module in forward(): {match.group(0)} (reward hacking - must use custom kernel)")
+
+    return (False, "")
+
+
+# =============================================================================
+# CUBLAS/CUDNN DIRECT USAGE CHECKS (Reward Hacking)
+# =============================================================================
+
+# Direct cuBLAS calls bypass custom kernel requirement
+CUBLAS_PATTERNS = [
+    r"\bcublas[A-Z]\w+\s*\(",  # cublasSgemm, cublasGemmEx, etc.
+    r"\bcublasCreate\b",
+    r"\bcublasDestroy\b",
+    r"\bcublasSetStream\b",
+    r"\bcublasSetMathMode\b",
+    r"#include\s*[<\"]cublas",  # #include <cublas_v2.h>
+    r"CUBLAS_TENSOR_OP_MATH",
+]
+
+# Direct cuDNN calls bypass custom kernel requirement
+CUDNN_PATTERNS = [
+    r"\bcudnn[A-Z]\w+\s*\(",  # cudnnConvolutionForward, etc.
+    r"\bcudnnCreate\b",
+    r"\bcudnnDestroy\b",
+    r"#include\s*[<\"]cudnn",  # #include <cudnn.h>
+]
+
+
+def check_cublas_usage(code: str) -> tuple[bool, str]:
+    """Check for direct cuBLAS API usage (reward hacking)."""
+    code = _strip_comments(code)
+
+    for pattern in CUBLAS_PATTERNS:
+        match = re.search(pattern, code)
+        if match:
+            return (True, f"Uses cuBLAS directly: {match.group(0)} (reward hacking - must write custom kernel)")
+
+    return (False, "")
+
+
+def check_cudnn_usage(code: str) -> tuple[bool, str]:
+    """Check for direct cuDNN API usage (reward hacking)."""
+    code = _strip_comments(code)
+
+    for pattern in CUDNN_PATTERNS:
+        match = re.search(pattern, code)
+        if match:
+            return (True, f"Uses cuDNN directly: {match.group(0)} (reward hacking - must write custom kernel)")
+
+    return (False, "")
+
+
 # =============================================================================
 # TIMING MANIPULATION CHECKS
 # =============================================================================
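A small illustration of what the new forward() checks flag versus allow, applying the nn.Module pattern above directly (the real check first narrows to the forward() body via _extract_forward_method and strips comments):

    import re

    code = '''
    class Model:
        def __init__(self):
            self.linear = nn.Linear(4, 4)   # OK: parameter container in __init__
        def forward(self, x):
            w = self.linear.weight          # OK: weight access for a custom kernel
            return self.linear(x)           # flagged: invokes cuBLAS-backed matmul
    '''
    pattern = r"self\.(conv\d*d?|linear|bn|batch_norm|layer_norm|group_norm|instance_norm)\s*\("
    print(re.search(pattern, code).group(0))  # -> self.linear(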
@@ -310,8 +465,16 @@ def check_tk_impl(code: str) -> tuple[bool, str]:
 def check_cute_impl(code: str) -> tuple[bool, str]:
     """Check for valid CUTLASS/CuTe kernel implementation."""
     code = _strip_comments(code)
-    if not any(p in code for p in ["cute::", "cutlass::", "from cutlass"]):
-        return (True, "Missing cute:: or cutlass:: namespace")
+    # Accept explicit namespace qualifiers OR using namespace declarations
+    valid_patterns = [
+        "cute::",
+        "cutlass::",
+        "from cutlass",
+        "using namespace cute",
+        "using namespace cutlass",
+    ]
+    if not any(p in code for p in valid_patterns):
+        return (True, "Missing cute:: or cutlass:: namespace (or 'using namespace')")
     return (False, "")


@@ -334,6 +497,11 @@ CHECK_FUNCTIONS: dict[str, Callable[[str], tuple[bool, str]]] = {
     "timing_event_patch": check_timing_event_patch,
     # Torch ops
     "torch_computation_ops": check_torch_computation_ops,
+    # Reward hacking checks
+    "cublas_usage": check_cublas_usage,
+    "cudnn_usage": check_cudnn_usage,
+    "nn_module_forward_call": check_nn_module_forward_call,
+    "torch_functional_calls": check_torch_functional_calls,
     # Timing manipulation
     "stream_injection": check_stream_injection,
     "thread_injection": check_thread_injection,
@@ -363,12 +531,16 @@ STRICT_CHECKS = [
     "timing_event_patch",
     "thread_injection",
     "lazy_eval",
+    "cublas_usage",  # Direct cuBLAS is reward hacking
+    "cudnn_usage",  # Direct cuDNN is reward hacking
+    "nn_module_forward_call",  # Calling self.conv(x), self.linear(x) in forward() is reward hacking
+    "torch_functional_calls",  # Calling F.linear(), F.conv2d() in forward() is reward hacking
+    "torch_computation_ops",  # torch.mm, torch.matmul, torch.conv* etc. are reward hacking
 ]

 # Checks that emit warnings but don't fail
 WARNING_CHECKS = [
     "pytorch_wrap",
-    "torch_computation_ops",
     "stream_injection",
 ]

--- wafer_core/utils/kernel_utils/targets/config.py
+++ wafer_core/utils/kernel_utils/targets/config.py
@@ -290,7 +290,7 @@ class RunPodTarget:
         ssh_key="~/.ssh/id_ed25519",
         gpu_type_id="AMD Instinct MI300X OAM",
         gpu_count=1,
-        image="runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04",
+        image="rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
         keep_alive=True,  # Don't terminate after eval
     )

@@ -304,7 +304,21 @@ class RunPodTarget:
     gpu_type_id: str = AMD_MI300X_GPU_ID  # RunPod GPU type identifier
     gpu_count: int = 1
     container_disk_gb: int = 50
-    image: str = "runpod/pytorch:2.4.0-py3.10-rocm6.1.0-ubuntu22.04"
+    # TODO: Consider creating a custom Docker image with HipKittens pre-installed
+    # to avoid needing `wafer config targets install <target> hipkittens`.
+    # HipKittens repo: https://github.com/HazyResearch/hipkittens
+    # CK (Composable Kernel) is already included in ROCm 7.0.
+    #
+    # WARNING: PyTorch's hipify can corrupt /opt/rocm/include/thrust/ headers.
+    # If you see "cuda/__cccl_config not found" errors, run:
+    #     apt-get install --reinstall -y rocthrust
+    # See docker/rocm7-runpod/README.md for details.
+    image: str = "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1"
+
+    # RunPod template ID — required for non-RunPod images that need custom
+    # dockerArgs (e.g. to install and start sshd). When set, takes priority
+    # over `image` in the deploy mutation.
+    template_id: str | None = None

     # Timeouts
     provision_timeout: int = 900  # 15 min for SSH to be ready
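A minimal sketch of opting a target into a template under the new field, assuming RunPodTarget is a dataclass accepting these fields as keyword arguments (the template id is a placeholder):

    target = RunPodTarget(
        gpu_type_id="AMD Instinct MI300X OAM",
        image="rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1",
        template_id="abc123",  # hypothetical id; wins over `image` when set
    )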
@@ -426,7 +440,7 @@ class DigitalOceanTarget:
     # DigitalOcean instance configuration
     region: str = "atl1"  # Atlanta (AMD GPUs available here)
     size_slug: str = "gpu-mi300x1-192gb-devcloud"  # Single MI300X GPU
-    image: str = "gpu-amd-base"  # AMD AI/ML Ready image with ROCm
+    image: str = "amd-pytorchrocm7"  # PyTorch (ROCm7) marketplace image

     # Timeouts
     provision_timeout: int = 600  # 10 min for droplet to be ready
--- wafer_core/utils/modal_execution/modal_app.py
+++ wafer_core/utils/modal_execution/modal_app.py
@@ -20,35 +20,17 @@ import modal

 # Build Modal image with all dependencies
 # This image is cached and reused across function invocations
-def build_modal_image(
-    gpu_type: str = "B200",
-    compute_capability: str = "10.0",
-) -> modal.Image:
+def build_modal_image() -> modal.Image:
     """Build Modal image with PyTorch, CUTLASS, and evaluation dependencies.

     Uses explicit local code inclusion to avoid pulling in SSH deployment code.

-    Phase 2 solution from MODAL_HANDOFF.md:
-    - Use add_local_dir with ignore parameter to exclude deployment files
-    - Only include files needed for kernel evaluation
-
-    Args:
-        gpu_type: GPU type (determines PyTorch index URL)
-        compute_capability: CUDA compute capability
-
     Returns:
         Modal Image ready for kernel evaluation
     """
-    # Determine PyTorch index based on GPU type
-    # Match logic from configs/base_config.py
-    if gpu_type in ["B200", "GB200"] or compute_capability.startswith("10."):
-        # Blackwell requires PyTorch 2.8+ with CUDA 12.8
-        torch_index = "https://download.pytorch.org/whl/nightly/cu128"
-        torch_version = "torch>=2.8.0"
-    else:
-        # Older GPUs (H100, A100) use stable PyTorch
-        torch_index = "https://download.pytorch.org/whl/cu124"
-        torch_version = "torch>=2.4.0"
+    # Use CUDA 13.0 for all GPUs (H100, A100, B200, GB200)
+    torch_index = "https://download.pytorch.org/whl/cu130"
+    torch_version = "torch>=2.6.0"

     # Build image with dependencies
     image = (
@@ -74,6 +56,15 @@ def build_modal_image(
             "scipy",
             "pytest",
         )
+        # Install CUTLASS headers for C++ kernel compilation (v4.3.5)
+        .run_commands(
+            "git clone --depth 1 --branch v4.3.5 https://github.com/NVIDIA/cutlass.git /usr/local/cutlass",
+            # Verify CUTLASS was installed correctly
+            "ls -la /usr/local/cutlass/include/cutlass/util/ | head -20",
+            "test -f /usr/local/cutlass/include/cutlass/util/packed_stride.hpp && echo 'CUTLASS headers OK' || echo 'CUTLASS headers MISSING'",
+        )
+        # Set CUTLASS_PATH environment variable
+        .env({"CUTLASS_PATH": "/usr/local/cutlass/include"})
         # Create empty __init__.py files for proper Python package structure
         # MUST run before add_local_* commands (Modal restriction)
         .run_commands(
@@ -111,20 +102,16 @@ def build_modal_image(
 # Create app (can be customized per target)
 def create_modal_app(
     app_name: str = "test-kernel-eval",  # Match test script default
-    gpu_type: str = "B200",
-    compute_capability: str = "10.0",
 ) -> modal.App:
     """Create Modal app for kernel evaluation.

     Args:
         app_name: Modal app name
-        gpu_type: GPU type for image building
-        compute_capability: CUDA compute capability

     Returns:
         Modal App instance
     """
-    image = build_modal_image(gpu_type=gpu_type, compute_capability=compute_capability)
+    image = build_modal_image()
     return modal.App(name=app_name, image=image)

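Since image selection no longer depends on GPU type, callers build one image for every target; a minimal usage sketch, assuming Modal credentials are configured:

    # One image now serves H100/A100/B200/GB200: cu130 PyTorch wheels plus
    # CUTLASS v4.3.5 headers at /usr/local/cutlass (CUTLASS_PATH set accordingly).
    app = create_modal_app(app_name="kernel-eval")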
--- wafer_core-0.1.22.dist-info/METADATA
+++ wafer_core-0.1.23.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: wafer-core
-Version: 0.1.22
+Version: 0.1.23
 Summary: Core utilities and environments for Wafer GPU kernel optimization
 Requires-Python: >=3.10
 Requires-Dist: aiohttp>=3.9.0
--- wafer_core-0.1.22.dist-info/RECORD
+++ wafer_core-0.1.23.dist-info/RECORD
@@ -12,7 +12,7 @@ wafer_core/config/__init__.py,sha256=hKywfjA4YXd4lBeBFEcBoMwFoflPHJTiBnkTq7_JYOQ
 wafer_core/config/loader.py,sha256=k7JnILmO13TWUzIv9Lm8fvmj3UfYHZDgaFurjQ-GXpY,6623
 wafer_core/config/schema.py,sha256=2WhFlnG0VYYX4T-70BLeJK8Janvi4KEa8KKGZA7331w,3898
 wafer_core/environments/__init__.py,sha256=SIsResVtm22tr_d-oHPeeSxrkhFdmPOFico3DqtRqK8,238
-wafer_core/environments/coding.py,sha256=T-_JFU-n5OxPR8xAWp8qar4Y5xyC-TWTIBjRy4PDel8,8418
+wafer_core/environments/coding.py,sha256=N-ELZwJu5vKLCVtwO25c6JSty6fmqf85VR2d3WJ4RXw,8559
 wafer_core/environments/gpumode.py,sha256=8Da08nltvN_YloNyYI6-omN2D4n5C7aptKDCtUgT2bQ,17191
 wafer_core/lib/__init__.py,sha256=4-4p3mhwlquejWGglYXU8_nHdA0LoPaa_jGzcm13USA,1325
 wafer_core/lib/kernel_scope/__init__.py,sha256=WW2vu8jUlqOu-MCpgO40lIYacCA9N2u-uuECIs_JO2w,2817
@@ -585,11 +585,12 @@ wafer_core/sessions/dtypes.py,sha256=K6nOjvL6sjCGY7GTtdEygf1IZY_18R9YkHGqFyMd8wY
 wafer_core/sessions/hooks.py,sha256=A-txm6ufnRGQCdtP3vwh7oEOdlLN9Tv0XsjORMihuAI,4295
 wafer_core/targets/__init__.py,sha256=sHndC7AAOaHXlrmDXFLB53a5Y8DBjuyqS6nwsO2nj-Y,1728
 wafer_core/targets/digitalocean.py,sha256=cvoYpYjtSyy5t2lQAPi7ERruuuibronah_ivOiduAHQ,16550
-wafer_core/targets/runpod.py,sha256=bYTLVRaASrewJLIcZRtPMfMU2en1McE6W5w-Edo_OPQ,15785
-wafer_core/tools/__init__.py,sha256=wBQD45GdSfkxcT6NHzIv0IMeXCc0enwwkpm3T_9j1X8,3341
+wafer_core/targets/runpod.py,sha256=LrVmNvA6qjzL5nbGSWvtw7CHrK6bDu7_o3vKIek00Tc,20286
+wafer_core/tools/__init__.py,sha256=deGQQlcdSD6zQx8JHizfSXgF5-EntdBOF_ngtob1-VU,3506
 wafer_core/tools/bash_tool.py,sha256=daoKOVGSgL0x9X_3l8Apd6-wFH4VMXMGJwVemw2FIfc,16828
 wafer_core/tools/glob_tool.py,sha256=9X5PdOjQJj7kiVNqqCZC0-1LmnE6wHx3Zc9zfMjtXdc,3533
 wafer_core/tools/grep_tool.py,sha256=cStyDz-J47oDLLZCL83yOvYo8Ijv4qu3D372JKT_ptM,4580
+wafer_core/tools/search_docs_tool.py,sha256=WY4hY83sseX8Fpxvw6DZxiG-F95F2t3-4PyfMD1Lpkg,6809
 wafer_core/tools/skill_tool.py,sha256=JXsT5hBTUH5U4tmzHEywU7eHHt5xCEF79tL2tsuk4-c,2067
 wafer_core/tools/wafer_tool.py,sha256=-dgPTHbWXq3I3wFj0mP7-lj5iZqGRoFvFf9IEEo3plQ,6345
 wafer_core/tools/write_kernel_tool.py,sha256=dJjhr-WBhVNe06hcJQVmBZTbS8mid64KF1MwlE2s2R4,21547
@@ -615,7 +616,7 @@ wafer_core/tools/capture_tool/metrics.py,sha256=BFZNmdE-kh3LneYdWXTNZmlLuo-DCrP5
 wafer_core/tools/file_tools/__init__.py,sha256=2H7Rq5bijNQHGO4W6jjQAShkrcmdcHC0EQ8mBpgrApI,632
 wafer_core/tools/file_tools/edit_tool.py,sha256=Efx83pM1Ljb07cJmAGVhPX4YiPJICK70sZM6uCjRWB0,4109
 wafer_core/tools/file_tools/glob_tool.py,sha256=Av4LfC21fHXbnSsgh_9zDxlY9Qhb48aApaGos4j3B4g,3437
-wafer_core/tools/file_tools/grep_tool.py,sha256=FRIYeBfCcywHqkiT8OqVL8xOtMiZDOV6EhUU2LN5fqM,4444
+wafer_core/tools/file_tools/grep_tool.py,sha256=42eFj2pxBBrs5eg_GhyYJ-j2fNWkmGPvrEqXFmi5E10,5539
 wafer_core/tools/file_tools/read_tool.py,sha256=K0Hd8zwyL4Yva5YO9spXDfTRfXvfjqh9ztVrA8s1bJE,3961
 wafer_core/tools/file_tools/utils.py,sha256=HgaqYan2Pky4hTLX2L9d2Gj9oS325H7rFbJj-jryNtc,2576
 wafer_core/tools/file_tools/write_tool.py,sha256=X4N8y8wB-k9d5PcMRmZMRKIXlG9jHJiRdlEFFRLdZzs,2083
@@ -661,18 +662,18 @@ wafer_core/utils/kernel_utils/evaluate.py,sha256=1kxFNMl9VCXfKfk_BIiuA_zFfvDB1sl
 wafer_core/utils/kernel_utils/gpu_validation.py,sha256=LRiDjW_xAK4fXf1Vw1aYHG54B1W0J6b5L0K6PXzM2tI,3759
 wafer_core/utils/kernel_utils/reference_cache.py,sha256=4IQ2gND1StHULRO7geyAElEStbjQxwOeP6X09E5wCB0,11283
 wafer_core/utils/kernel_utils/results.py,sha256=QJGeah_41LSzxyYwGl9VxHPxTVAN2bLtk5bWdWLIpL4,6705
-wafer_core/utils/kernel_utils/static_checker.py,sha256=GWC7RZdwL4WqMLK0nO-wnZd-VmlMahCVKP6l9zVAW30,12909
+wafer_core/utils/kernel_utils/static_checker.py,sha256=XIQkzAOkGH5xtrOuZM4tNUqVJ0QRkYeJ7_8DosDOtkw,19886
 wafer_core/utils/kernel_utils/task.py,sha256=XcmKxKUWh5It6nX3zGqj77tWgA32uPfQMqNOqyD5T48,2682
 wafer_core/utils/kernel_utils/utils.py,sha256=uDZoJDxh07hJeLNlPdKN2vgB15pqIr1LbXf0YIBHU4E,43056
 wafer_core/utils/kernel_utils/targets/__init__.py,sha256=4NwRLsuJ__S4xKAfda4Ag82C5MQ3Qio-4xA5S-mQGlU,2067
-wafer_core/utils/kernel_utils/targets/config.py,sha256=3TvT2Hp3TV-cjsSiZ8NOmAz8epDoVii76p6DDAI2V64,19134
+wafer_core/utils/kernel_utils/targets/config.py,sha256=sNXyYTZ9rL9OET4xqbHZ0d4b8ChzST1yUKvNOv8JSQs,19933
 wafer_core/utils/kernel_utils/targets/execution.py,sha256=bZuNXCo0sIdD6hFhetLPrtDC-zMSiIsAx_aml49VVL0,15033
 wafer_core/utils/kernel_utils/targets/selection.py,sha256=5I_RG_7cfhq7uaeR28meC2EeNNKssFsK-Tc3QFG6Ze0,3590
 wafer_core/utils/modal_execution/__init__.py,sha256=jkVqYOLzCT5K73N9Od0UIUsx-99A0m6bpDrxfyXxQZ8,945
-wafer_core/utils/modal_execution/modal_app.py,sha256=ibBllC59R9bP9w4QweFocVGeSX5kDmUuDl39PPYgCpE,11796
+wafer_core/utils/modal_execution/modal_app.py,sha256=VfS2cX8gHtnlPXemmMcEwDPeQdhiv2tly3CifOyh9f4,11455
 wafer_core/utils/modal_execution/modal_config.py,sha256=7cGX9TGqilQ3qxI3OFGXV5orjtyRU-PEDOJ4vP2oxno,4421
 wafer_core/utils/modal_execution/modal_execution.py,sha256=gChjnV6jqA3A7IRP3DfvV5cSfm_MN0X4f7JZufXgdZE,24594
 wafer_core/utils/modal_execution/test_modal.py,sha256=_jqou_hrLs1Daf1590Pnb0a_lXMMa2rczAPpW9HpoNQ,8153
-wafer_core-0.1.22.dist-info/METADATA,sha256=wV6MLEufRIKPRacW_ErMoqgymAFNI2XgP4wobQoKjnM,1420
-wafer_core-0.1.22.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-wafer_core-0.1.22.dist-info/RECORD,,
+wafer_core-0.1.23.dist-info/METADATA,sha256=HnIqBmqEQ6t_dc54Rnyg_Wyy-HKuAr3XTmsEoJkjJLo,1420
+wafer_core-0.1.23.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+wafer_core-0.1.23.dist-info/RECORD,,