wafer-core 0.1.21__py3-none-any.whl → 0.1.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,7 +14,6 @@ from __future__ import annotations
14
14
 
15
15
  import json
16
16
  import logging
17
- import os
18
17
  import time
19
18
  from contextlib import asynccontextmanager
20
19
  from dataclasses import dataclass
@@ -250,10 +249,17 @@ async def provision_pod(target: RunPodTarget) -> tuple[str, str, int, str]:
250
249
  "ports": "22/tcp",
251
250
  "startSsh": True,
252
251
  "startJupyter": False,
253
- "imageName": target.image,
254
252
  "env": [],
255
253
  }
256
254
 
255
+ if target.template_id:
256
+ # Template defines image, dockerArgs (sshd setup), and ports.
257
+ # Required for non-RunPod images (e.g. rocm/pytorch) that don't
258
+ # have RunPod's built-in SSH handler.
259
+ pod_input["templateId"] = target.template_id
260
+ else:
261
+ pod_input["imageName"] = target.image
262
+
257
263
  variables = {"input": pod_input}
258
264
 
259
265
  logger.info(f"Provisioning RunPod pod: {pod_name}")
@@ -334,7 +340,8 @@ async def _wait_for_ssh(pod_id: str, timeout_seconds: int) -> tuple[str, int, st
334
340
  # Check for SSH port
335
341
  runtime = pod.get("runtime")
336
342
  if runtime and status == "running":
337
- for port in runtime.get("ports", []):
343
+ # ports can be null in JSON response, so use 'or []' instead of default
344
+ for port in runtime.get("ports") or []:
338
345
  if (
339
346
  port.get("privatePort") == 22
340
347
  and port.get("isIpPublic")
@@ -378,6 +385,55 @@ async def terminate_pod(pod_id: str) -> bool:
378
385
  return False
379
386
 
380
387
 
388
+ # =============================================================================
389
+ # Template Management (not yet implemented)
390
+ # =============================================================================
391
+ #
392
+ # The saveTemplate mutation allows creating reusable pod templates with custom
393
+ # configurations. Templates can specify docker images, environment setup,
394
+ # container disk size, and other pod settings.
395
+ #
396
+ # Example mutation:
397
+ #
398
+ # mutation saveTemplate($input: SaveTemplateInput) {
399
+ # saveTemplate(input: $input) {
400
+ # id
401
+ # name
402
+ # imageName
403
+ # containerDiskInGb
404
+ # ports
405
+ # dockerArgs
406
+ # startSsh
407
+ # startJupyter
408
+ # }
409
+ # }
410
+ #
411
+ # Example variables:
412
+ #
413
+ # {
414
+ # "input": {
415
+ # "containerDiskInGb": 50,
416
+ # "dockerArgs": "bash -c \"apt-get update && apt-get install -y openssh-server && ...\"",
417
+ # "env": [],
418
+ # "isPublic": false,
419
+ # "isServerless": false,
420
+ # "name": "template-name",
421
+ # "ports": "22/tcp",
422
+ # "portsConfig": [{"name": "SSH", "port": "22"}],
423
+ # "readme": "",
424
+ # "volumeInGb": 0,
425
+ # "volumeMountPath": "",
426
+ # "config": {},
427
+ # "category": "AMD",
428
+ # "imageName": "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1"
429
+ # }
430
+ # }
431
+ #
432
+ # Note: Template creation is not currently implemented in this module.
433
+ # If needed, implement a save_template() function following the pattern of
434
+ # provision_pod() and terminate_pod() above.
435
+
436
+
381
437
  # =============================================================================
382
438
  # Context Manager
383
439
  # =============================================================================
@@ -482,20 +538,103 @@ async def cleanup_target(target_name: str) -> bool:
482
538
  return success
483
539
 
484
540
 
541
async def sync_pods_from_api() -> list[PodState]:
    """Query the RunPod API for running pods and sync them into local state.

    Discovers pods that exist on the account but are missing from the local
    state file (e.g. created manually or from another machine). Every pod
    named ``wafer-{target_name}-{suffix}`` is recorded under ``target_name``;
    an existing state entry is rewritten when its pod ID, public IP or SSH
    port no longer matches what the API reports.

    Returns:
        A PodState for every pod whose ``desiredStatus`` is RUNNING and that
        has a public mapping for container port 22. Pods whose name does not
        follow the wafer naming scheme use the raw pod name as
        ``target_name``. ``created_at`` is taken from the state entry when it
        describes the same pod, otherwise it is the current UTC time.
        Returns an empty list if the API request fails.
    """
    query = """
    query {
        myself {
            pods {
                id
                name
                desiredStatus
                runtime {
                    ports {
                        ip
                        isIpPublic
                        privatePort
                        publicPort
                    }
                }
            }
        }
    }
    """

    try:
        data = await _graphql_request_async(query)
    except Exception as e:
        # Best-effort discovery: an API failure is logged, not raised.
        logger.warning(f"Failed to query pods from API: {e}")
        return []

    # GraphQL can return explicit nulls ("myself": null, "pods": null,
    # "name": null, ...). dict.get(key, default) only applies the default for
    # *missing* keys, so coalesce with `or` (as done for "ports" elsewhere).
    myself = (data or {}).get("myself") or {}
    running_pods: list[PodState] = []

    for pod in myself.get("pods") or []:
        if (pod.get("desiredStatus") or "").lower() != "running":
            continue

        pod_id = pod["id"]
        pod_name = pod.get("name") or ""

        # Extract SSH info; without runtime info there are no ports to use.
        runtime = pod.get("runtime")
        if not runtime:
            continue

        public_ip = None
        ssh_port = None
        for port in runtime.get("ports") or []:
            if port.get("privatePort") == 22 and port.get("isIpPublic"):
                public_ip = port.get("ip")
                ssh_port = port.get("publicPort")
                break

        if not public_ip or not ssh_port:
            continue

        # Extract target name from pod name (wafer-{target_name}-{timestamp}).
        # Inner hyphens belong to the target name:
        # wafer-runpod-mi300x-1234567 -> "runpod-mi300x".
        target_name = None
        if pod_name.startswith("wafer-"):
            parts = pod_name.split("-")
            if len(parts) >= 3:
                target_name = "-".join(parts[1:-1])

        existing = get_pod_state(target_name) if target_name else None
        same_pod = bool(existing) and existing.pod_id == pod_id
        # Keep the recorded creation time for a pod we already know about
        # instead of re-stamping it on every listing.
        created_at = (existing.created_at if same_pod else None) or datetime.now(
            timezone.utc
        ).isoformat()

        running_pods.append(
            PodState(
                pod_id=pod_id,
                target_name=target_name or pod_name,
                public_ip=public_ip,
                ssh_port=ssh_port,
                ssh_username="root",
                created_at=created_at,
            )
        )

        # Update the state file for a wafer-created pod whose entry is missing,
        # points at a different pod, or carries stale SSH connection info.
        if target_name and not (
            same_pod
            and existing.public_ip == public_ip
            and existing.ssh_port == ssh_port
        ):
            logger.info(f"Syncing pod {pod_id} to state for target {target_name}")
            _add_pod_to_state(target_name, pod_id, public_ip, ssh_port, "root")

    return running_pods
629
+
630
+
async def list_running_pods() -> list[PodState]:
    """Return the running pods on the RunPod account that expose SSH.

    The RunPod API, not the local state file, is the source of truth here:
    this delegates to ``sync_pods_from_api``, which also records any
    wafer-created pods it finds in the state file as a side effect.
    """
    pods = await sync_pods_from_api()
    return pods
499
638
 
500
639
 
501
640
  async def cleanup_all_pods() -> int:
@@ -49,6 +49,10 @@ from wafer_core.tools.rocprof_systems_tools import (
49
49
  exec_rocprof_systems_query,
50
50
  exec_rocprof_systems_sample,
51
51
  )
52
+ from wafer_core.tools.skill_tool import (
53
+ SKILL_TOOL,
54
+ exec_skill,
55
+ )
52
56
  from wafer_core.tools.tracelens_tools import (
53
57
  TRACELENS_COLLECTIVE_TOOL,
54
58
  TRACELENS_COMPARE_TOOL,
@@ -68,6 +72,10 @@ from wafer_core.tools.write_kernel_tool import (
68
72
  KernelSubmission,
69
73
  exec_write_kernel,
70
74
  )
75
+ from wafer_core.tools.search_docs_tool import (
76
+ SEARCH_DOCS_TOOL,
77
+ exec_search_docs,
78
+ )
71
79
 
72
80
  __all__ = [
73
81
  # File tools
@@ -88,6 +96,9 @@ __all__ = [
88
96
  "BashPermissionResult",
89
97
  "check_bash_permissions",
90
98
  "exec_bash",
99
+ # Skill tool
100
+ "SKILL_TOOL",
101
+ "exec_skill",
91
102
  # Wafer tool
92
103
  "WAFER_TOOL",
93
104
  "WAFER_SUBCOMMANDS",
@@ -126,4 +137,7 @@ __all__ = [
126
137
  "exec_tracelens_report",
127
138
  "exec_tracelens_compare",
128
139
  "exec_tracelens_collective",
140
+ # Search docs tool
141
+ "SEARCH_DOCS_TOOL",
142
+ "exec_search_docs",
129
143
  ]
@@ -1,4 +1,4 @@
1
- """Grep tool using ripgrep for fast content search."""
1
+ """Grep tool using ripgrep (with fallback to standard grep)."""
2
2
 
3
3
  from pathlib import Path
4
4
 
@@ -15,7 +15,7 @@ GREP_TOOL = Tool(
15
15
  function=ToolFunction(
16
16
  name="grep",
17
17
  description=(
18
- "Search for a pattern in files using ripgrep. "
18
+ "Search for a pattern in files. "
19
19
  "Returns matching lines with file paths and line numbers. "
20
20
  "Supports regex patterns by default."
21
21
  ),
@@ -54,7 +54,7 @@ GREP_TOOL = Tool(
54
54
 
55
55
 
56
56
  async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
57
- """Execute grep using ripgrep."""
57
+ """Execute grep using ripgrep (preferred) or standard grep (fallback)."""
58
58
  import shutil
59
59
  import subprocess
60
60
 
@@ -74,35 +74,55 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
74
74
  error="'pattern' is required",
75
75
  )
76
76
 
77
- # Find ripgrep
77
+ # Try ripgrep first, fall back to standard grep
78
78
  rg_path = shutil.which("rg")
79
- if not rg_path:
79
+ grep_path = shutil.which("grep")
80
+
81
+ if rg_path:
82
+ # Use ripgrep (faster, better defaults)
83
+ cmd = [rg_path, "--line-number", "--no-heading", "--color=never"]
84
+
85
+ if case_insensitive:
86
+ cmd.append("--ignore-case")
87
+
88
+ if context_lines:
89
+ cmd.extend(["--context", str(context_lines)])
90
+
91
+ if glob_pattern:
92
+ cmd.extend(["--glob", glob_pattern])
93
+
94
+ # Limit results
95
+ cmd.extend(["--max-count", str(max_results)])
96
+
97
+ cmd.append(pattern)
98
+ cmd.append(search_path)
99
+ use_ripgrep = True
100
+ elif grep_path:
101
+ # Fallback to standard grep
102
+ cmd = [grep_path, "-r", "-n", "--color=never"]
103
+
104
+ if case_insensitive:
105
+ cmd.append("-i")
106
+
107
+ if context_lines:
108
+ cmd.extend(["-C", str(context_lines)])
109
+
110
+ if glob_pattern:
111
+ # Standard grep uses --include for glob patterns
112
+ cmd.extend(["--include", glob_pattern])
113
+
114
+ cmd.append(pattern)
115
+ cmd.append(search_path)
116
+ use_ripgrep = False
117
+ else:
80
118
  return ToolResult(
81
119
  tool_call_id=tool_call.id,
82
120
  is_error=True,
83
121
  content="",
84
- error="ripgrep (rg) not found. Please install it: brew install ripgrep",
122
+ error="Neither ripgrep (rg) nor grep found. Please install one.",
85
123
  )
86
124
 
87
- # Build command
88
- cmd = [rg_path, "--line-number", "--no-heading", "--color=never"]
89
-
90
- if case_insensitive:
91
- cmd.append("--ignore-case")
92
-
93
- if context_lines:
94
- cmd.extend(["--context", str(context_lines)])
95
-
96
- if glob_pattern:
97
- cmd.extend(["--glob", glob_pattern])
98
-
99
- # Limit results
100
- cmd.extend(["--max-count", str(max_results)])
101
-
102
- cmd.append(pattern)
103
- cmd.append(search_path)
104
-
105
- # Run ripgrep
125
+ # Run the search
106
126
  try:
107
127
  result = subprocess.run(
108
128
  cmd,
@@ -126,13 +146,14 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
126
146
  error=f"Search failed: {e}",
127
147
  )
128
148
 
129
- # ripgrep returns exit code 1 for no matches (not an error)
149
+ # Both ripgrep and grep return exit code 1 for no matches (not an error)
130
150
  if result.returncode not in (0, 1):
151
+ tool_name = "ripgrep" if use_ripgrep else "grep"
131
152
  return ToolResult(
132
153
  tool_call_id=tool_call.id,
133
154
  is_error=True,
134
155
  content="",
135
- error=result.stderr or f"ripgrep exited with code {result.returncode}",
156
+ error=result.stderr or f"{tool_name} exited with code {result.returncode}",
136
157
  )
137
158
 
138
159
  output = result.stdout.strip()
@@ -143,8 +164,14 @@ async def exec_grep(tool_call: ToolCall, working_dir: Path) -> ToolResult:
143
164
  content=f"No matches found for pattern: {pattern}",
144
165
  )
145
166
 
146
- # Count matches
147
- match_count = len(output.split("\n"))
167
+ # Count matches and limit output for standard grep
168
+ lines = output.split("\n")
169
+ if not use_ripgrep and len(lines) > max_results:
170
+ lines = lines[:max_results]
171
+ output = "\n".join(lines)
172
+ output += f"\n... (truncated to {max_results} results)"
173
+
174
+ match_count = min(len(lines), max_results)
148
175
  header = f"Found {match_count} matches:\n\n"
149
176
 
150
177
  return ToolResult(
@@ -0,0 +1,196 @@
1
+ """Search documentation tool for GPU programming corpora.
2
+
3
+ Provides keyword (term-frequency) search over documentation for CUTLASS, CuTeDSL, CUDA, HIP, and AMD GPU programming.
4
+
5
+ Corpora are downloaded via `wafer corpus download <name>` and stored in ~/.cache/wafer/corpora/.
6
+ """
7
+
8
+ import re
9
+ from pathlib import Path
10
+
11
+ from wafer_core.rollouts.dtypes import Tool, ToolCall, ToolFunction, ToolFunctionParameter, ToolResult
12
+
13
# Root of the local corpus cache that `wafer corpus download` populates.
CACHE_DIR = Path.home().joinpath(".cache", "wafer", "corpora")

# Corpus names accepted by search_docs; they mirror the names understood by
# `wafer corpus download`.
AVAILABLE_CORPORA = [
    "cutlass",
    "cutedsl",
    "cuda",
    "hip",
    "amd",
]
18
+
19
# Tool schema exposed to the model. Ranking in _search_files counts raw
# occurrences of each whitespace-separated query word, so the descriptions
# steer the model toward short keyword queries rather than full sentences
# (common words in a sentence would otherwise dominate the scores).
SEARCH_DOCS_TOOL = Tool(
    type="function",
    function=ToolFunction(
        name="search_docs",
        description="""Search GPU programming documentation for relevant information.

Use this tool to find documentation about:
- CUTLASS C++ (cute:: namespace, gemm tutorials, tensor cores, TMA, Blackwell)
- CuTeDSL Python API (@cute.kernel, @cute.jit, cute.arch functions)
- CUDA programming concepts
- AMD HIP and ROCm kernel development (rocWMMA, CK, etc.)
- GPU kernel optimization techniques
- Code examples and patterns

Available corpora:
- 'cutlass' - NVIDIA CUTLASS C++ docs + GitHub examples (gemm, hopper, blackwell)
- 'cutedsl' - CuTeDSL Python documentation
- 'cuda' - General CUDA programming docs
- 'hip' - AMD HIP programming docs
- 'amd' - AMD GPU kernel development (rocWMMA, CK, etc.)

Matching is keyword-based: files are ranked by how often the query's words
appear, so use a few specific terms (e.g. 'tma descriptor') rather than a
full sentence.

Note: Corpora must be downloaded first with `wafer corpus download <name>`.
Returns relevant documentation snippets with file paths.""",
        parameters=ToolFunctionParameter(
            type="object",
            properties={
                "query": {
                    "type": "string",
                    "description": (
                        "Space-separated search keywords (e.g. 'tma multicast'); "
                        "prefer specific identifiers over full sentences"
                    ),
                },
                "corpus": {
                    "type": "string",
                    # Advertise the accepted names so the model cannot guess wrong ones.
                    "enum": list(AVAILABLE_CORPORA),
                    "description": "Which docs to search: 'cutlass', 'cutedsl', 'cuda', 'hip', 'amd' (default: cutlass)",
                },
                "max_results": {
                    "type": "integer",
                    "minimum": 1,
                    "description": "Maximum number of results to return (default: 5)",
                },
            },
        ),
        required=["query"],
    )
)
61
+
62
+
63
def _get_corpus_path(corpus_name: str) -> Path | None:
    """Resolve a corpus name to its directory inside ``CACHE_DIR``.

    Returns ``CACHE_DIR / corpus_name`` only when the name is listed in
    ``AVAILABLE_CORPORA`` and that directory exists on disk; in every other
    case returns ``None``.
    """
    if corpus_name in AVAILABLE_CORPORA:
        candidate = CACHE_DIR / corpus_name
        if candidate.exists():
            return candidate
    return None
76
+
77
+
78
def _search_files(corpus_path: Path, query: str, max_results: int = 5) -> list[dict]:
    """Rank corpus files by how often the query's terms occur in them.

    Args:
        corpus_path: Root directory of a downloaded corpus.
        query: Free-text query; split on whitespace into lower-cased terms.
        max_results: Maximum number of files to return (negative values are
            treated as 0).

    Returns:
        Up to ``max_results`` dicts, highest ``score`` first (ties broken by
        path so the output is deterministic), each with:
          - ``file``: the matching file's path as a string (absolute when
            ``corpus_path`` is absolute), so the read tool can open it;
          - ``score``: total case-insensitive occurrences of all terms;
          - ``snippets``: up to three context snippets around matches.
    """
    query_terms = query.lower().split()
    if not query_terms:
        return []

    # Markdown docs plus Python / CUDA / C++ sources (the latter for the
    # CUTLASS examples). Compared case-insensitively.
    searchable_suffixes = {".md", ".py", ".cu", ".hpp", ".h", ".cuh"}

    results = []
    # One recursive walk covers every suffix, instead of one glob per suffix.
    for file_path in corpus_path.rglob("*"):
        if file_path.suffix.lower() not in searchable_suffixes or not file_path.is_file():
            continue
        try:
            content = file_path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            # Unreadable entry (permissions, dangling symlink, ...): skip it.
            continue

        content_lower = content.lower()
        score = sum(content_lower.count(term) for term in query_terms)
        if score > 0:
            results.append(
                {
                    "file": str(file_path),
                    "score": score,
                    "snippets": _extract_snippets(content, query_terms)[:3],
                }
            )

    results.sort(key=lambda r: (-r["score"], r["file"]))
    return results[: max(max_results, 0)]
108
+
109
+
110
+ def _extract_snippets(content: str, terms: list[str], context_lines: int = 5) -> list[str]:
111
+ """Extract snippets containing search terms."""
112
+ snippets = []
113
+ lines = content.split("\n")
114
+
115
+ for i, line in enumerate(lines):
116
+ line_lower = line.lower()
117
+ if any(term in line_lower for term in terms):
118
+ # Get context around the match
119
+ start = max(0, i - context_lines)
120
+ end = min(len(lines), i + context_lines + 1)
121
+ snippet = "\n".join(lines[start:end])
122
+
123
+ # Skip very short snippets
124
+ if len(snippet.strip()) > 50:
125
+ snippets.append(snippet)
126
+
127
+ return snippets
128
+
129
+
130
async def exec_search_docs(
    tool_call: ToolCall,
    corpus_override: str | None = None,
) -> ToolResult:
    """Execute the ``search_docs`` tool.

    Args:
        tool_call: Tool call whose ``args`` may contain ``query`` (required),
            ``corpus`` (defaults to ``"cutlass"``) and ``max_results``
            (defaults to 5; coerced to an int of at least 1).
        corpus_override: Directory to search instead of the cached corpus
            (for testing). When given, the corpus-name check is skipped.

    Returns:
        A ToolResult with formatted snippets, or a "No results found" message.
        A missing query, an unknown corpus name, or a corpus directory that
        does not exist produce a ToolResult with ``is_error=True``.
    """
    import asyncio

    args = tool_call.args
    query = str(args.get("query") or "").strip()
    # `or` also maps an explicit JSON null to the default corpus.
    corpus_name = args.get("corpus") or "cutlass"

    raw_max = args.get("max_results")
    try:
        max_results = 5 if raw_max is None else max(1, int(raw_max))
    except (TypeError, ValueError):
        # Non-numeric value from the model: fall back to the default.
        max_results = 5

    if not query:
        return ToolResult(
            tool_call_id=tool_call.id,
            is_error=True,
            content="",
            error="query parameter is required",
        )

    # Resolve the directory to search. Unknown names and known-but-missing
    # corpora are distinct errors, so the download hint is actually reachable.
    if corpus_override:
        corpus_path = Path(corpus_override)
    elif corpus_name not in AVAILABLE_CORPORA:
        return ToolResult(
            tool_call_id=tool_call.id,
            is_error=True,
            content="",
            error=f"Unknown corpus: {corpus_name}. Available: {AVAILABLE_CORPORA}",
        )
    else:
        corpus_path = CACHE_DIR / corpus_name

    if not corpus_path.exists():
        return ToolResult(
            tool_call_id=tool_call.id,
            is_error=True,
            content="",
            error=f"Corpus '{corpus_name}' not downloaded. Run: wafer corpus download {corpus_name}",
        )

    # The scan reads every candidate file synchronously; run it in a worker
    # thread so it does not block the event loop.
    results = await asyncio.to_thread(_search_files, corpus_path, query, max_results)

    if not results:
        return ToolResult(
            tool_call_id=tool_call.id,
            content=f"No results found for query: {query}",
            error=None,
        )

    # Format output: a banner per file followed by its snippets.
    output_parts = [f"Found {len(results)} results for: {query}\n"]

    for i, result in enumerate(results, 1):
        output_parts.append(f"\n{'='*60}")
        output_parts.append(f"[{i}] {result['file']} (score: {result['score']})")
        output_parts.append("=" * 60)

        for snippet in result["snippets"]:
            output_parts.append(snippet)
            output_parts.append("-" * 40)

    return ToolResult(
        tool_call_id=tool_call.id,
        content="\n".join(output_parts),
        error=None,
    )
@@ -0,0 +1,64 @@
1
+ """Skill tool.
2
+
3
+ Loads skill content on demand from ~/.wafer/skills/ or bundled locations.
4
+ """
5
+
6
+ from wafer_core.rollouts.dtypes import (
7
+ Tool,
8
+ ToolCall,
9
+ ToolFunction,
10
+ ToolFunctionParameter,
11
+ ToolResult,
12
+ )
13
+
14
+ # ── Tool Definition ──────────────────────────────────────────────────────────
15
+
16
# Schema for the `skill` tool: the model passes one string, the skill's name,
# and receives that skill's full instructions.
SKILL_TOOL = Tool(
    type="function",
    function=ToolFunction(
        name="skill",
        description=(
            "Load a skill's full instructions. "
            "Skills provide domain-specific knowledge and workflows. "
            "Use this when you need detailed guidance for a task mentioned in your available skills."
        ),
        required=["name"],
        parameters=ToolFunctionParameter(
            type="object",
            properties={
                "name": {
                    "type": "string",
                    "description": "Name of the skill to load (e.g., 'wafer-guide')",
                },
            },
        ),
    ),
)
33
+
34
+
35
+ # ── Pure Function Executor ───────────────────────────────────────────────────
36
+
37
+
38
async def exec_skill(tool_call: ToolCall) -> ToolResult:
    """Load a skill's full instructions.

    Args:
        tool_call: The tool call; ``args["name"]`` names the skill to load.

    Returns:
        On success, a ToolResult whose content is a ``# Skill: <name>`` header
        followed by the skill's content. If ``name`` is missing or empty, or
        no skill with that name is found, an error ToolResult that lists the
        discovered skill names.
    """
    from wafer_core.rollouts.skills import discover_skills, load_skill

    skill_name = tool_call.args.get("name")
    skill = load_skill(skill_name) if skill_name else None

    if skill is None:
        available_names = [s.name for s in discover_skills()]
        available = ", ".join(available_names) or "none"
        if skill_name:
            error = f"Skill not found: {skill_name}. Available skills: {available}"
        else:
            # The required argument was omitted: report it as a tool error
            # instead of raising KeyError out of the executor.
            error = f"'name' is required. Available skills: {available}"
        return ToolResult(
            tool_call_id=tool_call.id,
            is_error=True,
            content="",
            error=error,
        )

    header = f"# Skill: {skill.name}\n\n"
    return ToolResult(
        tool_call_id=tool_call.id,
        is_error=False,
        content=header + skill.content,
    )
@@ -33,6 +33,9 @@ def get_auth_token() -> str | None:
33
33
  Note:
34
34
  In local dev mode (localhost), no token is required.
35
35
  The API will use LOCAL_DEV_MODE to bypass auth.
36
+
37
+ Callers (like wevin-extension) should pass WAFER_AUTH_TOKEN
38
+ as an environment variable when spawning Python processes.
36
39
  """
37
40
  return os.environ.get("WAFER_AUTH_TOKEN")
38
41