wafer-core 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wafer_core/auth.py CHANGED
@@ -41,6 +41,18 @@ PROVIDERS = {
41
41
  "display_name": "Modal",
42
42
  "key_url": "https://modal.com/settings",
43
43
  },
44
+ "anthropic": {
45
+ "env_var": "ANTHROPIC_API_KEY",
46
+ "alt_env_var": "WAFER_ANTHROPIC_API_KEY", # Check this first
47
+ "display_name": "Anthropic",
48
+ "key_url": "https://console.anthropic.com/settings/keys",
49
+ },
50
+ "openai": {
51
+ "env_var": "OPENAI_API_KEY",
52
+ "alt_env_var": "WAFER_OPENAI_KEY", # Check this first
53
+ "display_name": "OpenAI",
54
+ "key_url": "https://platform.openai.com/api-keys",
55
+ },
44
56
  }
45
57
 
46
58
 
@@ -78,11 +90,12 @@ def get_api_key(provider: str) -> str | None:
78
90
  """Get API key for a provider.
79
91
 
80
92
  Checks in order:
81
- 1. Environment variable (e.g., WAFER_RUNPOD_API_KEY)
82
- 2. ~/.wafer/auth.json
93
+ 1. Alt environment variable if defined (e.g., WAFER_ANTHROPIC_API_KEY)
94
+ 2. Primary environment variable (e.g., ANTHROPIC_API_KEY)
95
+ 3. ~/.wafer/auth.json
83
96
 
84
97
  Args:
85
- provider: Provider name (runpod, digitalocean, modal)
98
+ provider: Provider name (runpod, digitalocean, modal, anthropic, openai)
86
99
 
87
100
  Returns:
88
101
  API key string or None if not found
@@ -92,7 +105,13 @@ def get_api_key(provider: str) -> str | None:
92
105
 
93
106
  config = PROVIDERS[provider]
94
107
 
95
- # Check environment variable first
108
+ # Check alt environment variable first (e.g., WAFER_ANTHROPIC_API_KEY)
109
+ if "alt_env_var" in config:
110
+ alt_key = os.environ.get(config["alt_env_var"], "").strip()
111
+ if alt_key:
112
+ return alt_key
113
+
114
+ # Check primary environment variable
96
115
  env_key = os.environ.get(config["env_var"], "").strip()
97
116
  if env_key:
98
117
  return env_key
@@ -154,7 +173,7 @@ def get_auth_status(provider: str) -> AuthStatus:
154
173
  """Get authentication status for a provider.
155
174
 
156
175
  Args:
157
- provider: Provider name (runpod, digitalocean, modal)
176
+ provider: Provider name (runpod, digitalocean, modal, anthropic, openai)
158
177
 
159
178
  Returns:
160
179
  AuthStatus with details about the auth state
@@ -164,7 +183,20 @@ def get_auth_status(provider: str) -> AuthStatus:
164
183
 
165
184
  config = PROVIDERS[provider]
166
185
 
167
- # Check environment variable first
186
+ # Check alt environment variable first (e.g., WAFER_ANTHROPIC_API_KEY)
187
+ if "alt_env_var" in config:
188
+ alt_key = os.environ.get(config["alt_env_var"], "").strip()
189
+ if alt_key:
190
+ return AuthStatus(
191
+ provider=provider,
192
+ display_name=config["display_name"],
193
+ is_authenticated=True,
194
+ source="env",
195
+ key_preview=_format_key_preview(alt_key),
196
+ key_url=config["key_url"],
197
+ )
198
+
199
+ # Check primary environment variable
168
200
  env_key = os.environ.get(config["env_var"], "").strip()
169
201
  if env_key:
170
202
  return AuthStatus(
@@ -34,6 +34,7 @@ from wafer_core.tools import (
34
34
  GLOB_TOOL,
35
35
  GREP_TOOL,
36
36
  READ_TOOL,
37
+ SKILL_TOOL,
37
38
  WRITE_TOOL,
38
39
  ApprovalCallback,
39
40
  exec_bash,
@@ -41,6 +42,7 @@ from wafer_core.tools import (
41
42
  exec_glob,
42
43
  exec_grep,
43
44
  exec_read,
45
+ exec_skill,
44
46
  exec_write,
45
47
  )
46
48
 
@@ -61,6 +63,7 @@ ALL_TOOLS = {
61
63
  "glob": GLOB_TOOL,
62
64
  "grep": GREP_TOOL,
63
65
  "bash": BASH_TOOL,
66
+ "skill": SKILL_TOOL,
64
67
  # TODO(wafer-tool): "wafer": WAFER_TOOL,
65
68
  }
66
69
 
@@ -208,6 +211,7 @@ class CodingEnvironment:
208
211
  self.bash_approval_callback,
209
212
  self._sandbox_policy,
210
213
  ),
214
+ "skill": lambda tc: exec_skill(tc),
211
215
  # TODO(wafer-tool): "wafer": lambda tc: exec_wafer(
212
216
  # tc, self.working_dir, self.enabled_tools, self.allow_spawn, cancel_scope
213
217
  # ),
@@ -1562,6 +1562,10 @@ class EvalConfig:
1562
1562
  resume_dir: Path | None = None
1563
1563
  report_batch_size: int = 1 # Write report after each sample for best recovery
1564
1564
 
1565
+ # Custom metadata (flows to report.json for dashboard filtering)
1566
+ # e.g., {"waferbench_category": "gemm", "github_runner": "elliot"}
1567
+ metadata: dict[str, Any] | None = None
1568
+
1565
1569
 
1566
1570
  # ── Session Types ──────────────────────────────────────────────────────────────
1567
1571
  # Types for persisting agent sessions (trajectories, config, environment state).
@@ -331,9 +331,9 @@ def generate_diff(old_content: str, new_content: str, context_lines: int = 3) ->
331
331
 
332
332
  # Tool preset configurations
333
333
  TOOL_PRESETS = {
334
- "full": ["read", "write", "edit", "bash", "web_fetch"],
334
+ "full": ["read", "write", "edit", "bash", "web_fetch", "skill"],
335
335
  "readonly": ["read"],
336
- "no-write": ["read", "edit", "bash", "web_fetch"],
336
+ "no-write": ["read", "edit", "bash", "web_fetch", "skill"],
337
337
  }
338
338
 
339
339
 
@@ -630,6 +630,24 @@ class LocalFilesystemEnvironment:
630
630
  required=["url", "prompt"],
631
631
  ),
632
632
  ),
633
+ # skill tool
634
+ Tool(
635
+ type="function",
636
+ function=ToolFunction(
637
+ name="skill",
638
+ description="Load a skill's full instructions. Skills provide domain-specific knowledge and workflows. Use this when you need detailed guidance for a task mentioned in your available skills.",
639
+ parameters=ToolFunctionParameter(
640
+ type="object",
641
+ properties={
642
+ "name": {
643
+ "type": "string",
644
+ "description": "Name of the skill to load (e.g., 'wafer-guide')",
645
+ },
646
+ },
647
+ ),
648
+ required=["name"],
649
+ ),
650
+ ),
633
651
  ]
634
652
 
635
653
  async def on_assistant_message(self, message: Message, state: AgentState) -> AgentState:
@@ -655,6 +673,8 @@ class LocalFilesystemEnvironment:
655
673
  return await self._exec_bash(tool_call, current_state.session_id, cancel_scope)
656
674
  elif tool_call.name == "web_fetch":
657
675
  return await self._exec_web_fetch(tool_call, current_state.session_id)
676
+ elif tool_call.name == "skill":
677
+ return await self._exec_skill(tool_call)
658
678
  else:
659
679
  return ToolResult(
660
680
  tool_call_id=tool_call.id,
@@ -1155,3 +1175,31 @@ class LocalFilesystemEnvironment:
1155
1175
  content=header + final_content,
1156
1176
  details={"output_file": output_file_path} if output_file_path else None,
1157
1177
  )
1178
+
1179
async def _exec_skill(self, tool_call: ToolCall) -> ToolResult:
    """Load a skill's full instructions and return them as the tool output.

    The skill is looked up by the ``name`` argument of the tool call
    (surrounding whitespace is ignored). If the argument is missing, blank or
    not a string, or if no skill with that name is discovered, an error result
    is returned that lists the available skills so the model can retry.

    Returns:
        ToolResult whose content is ``"# Skill: <name>\\n\\n"`` followed by the
        skill body on success, or an error ToolResult otherwise.
    """
    from ..skills import discover_skills, load_skill

    # Models sometimes omit required arguments or send non-string values;
    # subscripting args["name"] directly would raise KeyError instead of
    # producing a tool error the model can react to.
    raw_name = tool_call.args.get("name")
    skill_name = raw_name.strip() if isinstance(raw_name, str) else ""
    skill = load_skill(skill_name) if skill_name else None

    if skill is None:
        # List available skills in the error message to guide a retry.
        available = ", ".join(s.name for s in discover_skills()) or "none"
        if skill_name:
            error = f"Skill not found: {skill_name}. Available skills: {available}"
        else:
            error = f"Missing required argument 'name'. Available skills: {available}"
        return ToolResult(
            tool_call_id=tool_call.id,
            is_error=True,
            content="",
            error=error,
        )

    # Return the full skill content.
    header = f"# Skill: {skill.name}\n\n"
    return ToolResult(
        tool_call_id=tool_call.id,
        is_error=False,
        content=header + skill.content,
    )
@@ -642,6 +642,7 @@ class EvalReport:
642
642
  timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
643
643
  git_info: dict[str, Any] = field(default_factory=_get_git_info)
644
644
  config_path: str | None = None # Path to config file relative to repo root
645
+ metadata: dict[str, Any] | None = None # Custom metadata (waferbench_category, github_runner, etc.)
645
646
 
646
647
  async def save(self, output_dir: Path) -> None:
647
648
  """Save evaluation results to directory."""
@@ -675,6 +676,8 @@ class EvalReport:
675
676
  "config_path": self.config_path,
676
677
  "sample_ids": [s.id for s in self.sample_results],
677
678
  }
679
+ if self.metadata:
680
+ summary["metadata"] = self.metadata
678
681
  # Sanitize API keys in the summary before saving
679
682
  summary = sanitize_api_keys(summary)
680
683
  report_file = output_dir / "report.json"
@@ -761,6 +764,9 @@ def _write_partial_report(
761
764
  "config_path": config.config_path,
762
765
  }
763
766
 
767
+ if config.metadata:
768
+ report["metadata"] = config.metadata
769
+
764
770
  if resume_from:
765
771
  report["resume_from"] = resume_from
766
772
 
@@ -1279,6 +1285,7 @@ async def evaluate(
1279
1285
  summary_metrics={},
1280
1286
  sample_results=[],
1281
1287
  config={"resumed_from": str(config.resume_dir)},
1288
+ metadata=config.metadata,
1282
1289
  )
1283
1290
 
1284
1291
  if config.verbose:
@@ -1489,6 +1496,7 @@ async def evaluate(
1489
1496
  "evaluation_timestamp": datetime.now().isoformat(),
1490
1497
  },
1491
1498
  config_path=config.config_path,
1499
+ metadata=config.metadata,
1492
1500
  )
1493
1501
 
1494
1502
  # Save if output directory specified
@@ -1546,7 +1554,7 @@ def compute_summary_metrics(results: list[Sample]) -> dict[str, float]:
1546
1554
  for m in r.score.metrics:
1547
1555
  all_metric_names.add(m.name)
1548
1556
 
1549
- # Compute mean, min, max, std for each metric
1557
+ # Compute mean, median, min, max, std for each metric
1550
1558
  for metric_name in all_metric_names:
1551
1559
  values = []
1552
1560
  for r in results:
@@ -1557,7 +1565,15 @@ def compute_summary_metrics(results: list[Sample]) -> dict[str, float]:
1557
1565
  break
1558
1566
  if values:
1559
1567
  mean_val = sum(values) / len(values)
1568
+ sorted_values = sorted(values)
1569
+ n = len(sorted_values)
1570
+ if n % 2 == 0:
1571
+ median_val = (sorted_values[n // 2 - 1] + sorted_values[n // 2]) / 2
1572
+ else:
1573
+ median_val = sorted_values[n // 2]
1574
+
1560
1575
  summary[f"mean_{metric_name}"] = mean_val
1576
+ summary[f"median_{metric_name}"] = median_val
1561
1577
  summary[f"min_{metric_name}"] = min(values)
1562
1578
  summary[f"max_{metric_name}"] = max(values)
1563
1579
  summary[f"std_{metric_name}"] = (
@@ -232,7 +232,7 @@ Detailed docs: {docs}
232
232
  If asked about your capabilities, read these files."""
233
233
 
234
234
 
235
- def build_system_prompt(
235
+ def build_system_prompt( # noqa: PLR0913
236
236
  env_name: str,
237
237
  tools: list[Tool],
238
238
  cwd: Path | None = None,
@@ -240,6 +240,7 @@ def build_system_prompt(
240
240
  env_system_prompt: str | None = None,
241
241
  include_self_docs: bool = True,
242
242
  include_project_context: bool = True,
243
+ include_skills: bool = False,
243
244
  ) -> str:
244
245
  """Build complete system prompt with dynamic tool info.
245
246
 
@@ -251,6 +252,7 @@ def build_system_prompt(
251
252
  env_system_prompt: Environment-provided system prompt (from env.get_system_prompt())
252
253
  include_self_docs: Whether to include rollouts documentation paths
253
254
  include_project_context: Whether to load AGENTS.md/ROLLOUTS.md files
255
+ include_skills: Whether to discover and list available skills
254
256
  """
255
257
  # Assertions (Tiger Style: 2+ per function, split compound)
256
258
  assert env_name, "env_name required"
@@ -277,11 +279,19 @@ def build_system_prompt(
277
279
  if guidelines:
278
280
  sections.append("Guidelines:\n" + "\n".join(f"- {g}" for g in guidelines))
279
281
 
280
- # 5. Self-documentation
282
+ # 5. Available skills (metadata only - agent loads full content via skill tool)
283
+ if include_skills:
284
+ from .skills import discover_skills, format_skill_metadata_for_prompt
285
+
286
+ skill_metadata = discover_skills()
287
+ if skill_metadata:
288
+ sections.append(format_skill_metadata_for_prompt(skill_metadata))
289
+
290
+ # 6. Self-documentation
281
291
  if include_self_docs:
282
292
  sections.append(build_self_doc_section())
283
293
 
284
- # 6. Project context files (AGENTS.md, ROLLOUTS.md, etc.)
294
+ # 7. Project context files (AGENTS.md, ROLLOUTS.md, etc.)
285
295
  if include_project_context:
286
296
  context_files = load_project_context(working_dir)
287
297
  if context_files:
@@ -290,7 +300,7 @@ def build_system_prompt(
290
300
  ctx_section += f"\n## {path}\n\n{content}\n"
291
301
  sections.append(ctx_section)
292
302
 
293
- # 7. Runtime context
303
+ # 8. Runtime context
294
304
  now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
295
305
  sections.append(f"Current time: {now}\nWorking directory: {working_dir}")
296
306
 
@@ -0,0 +1,176 @@
1
+ """Skill discovery and loading.
2
+
3
+ Skills are documentation files that agents can load on demand.
4
+ Format follows agentskills.io spec: SKILL.md with YAML frontmatter.
5
+
6
+ Discovery order:
7
+ 1. ~/.wafer/skills/{name}/SKILL.md (user-installed)
8
+ 2. Bundled skills (wafer-cli package)
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from dataclasses import dataclass
14
+ from pathlib import Path
15
+
16
+ from .paths import get_config_dir
17
+
18
+
19
@dataclass(frozen=True)
class SkillMetadata:
    """Name, description and location of a discovered skill.

    Only these fields are listed in the system prompt; the skill body is read
    later, on demand, via the ``skill`` tool.
    """

    name: str  # value of the ``name`` frontmatter key
    description: str  # value of the ``description`` frontmatter key
    path: Path  # location of the skill's SKILL.md file


@dataclass(frozen=True)
class Skill:
    """A fully loaded skill: its frontmatter fields plus the markdown body."""

    name: str
    description: str
    content: str  # markdown following the frontmatter block (frontmatter removed)
    path: Path  # SKILL.md file this skill was read from
36
+
37
+
38
+ def _parse_skill_file(path: Path) -> tuple[dict[str, str], str] | None:
39
+ """Parse SKILL.md file into (frontmatter, content).
40
+
41
+ Returns None if file doesn't exist or is malformed.
42
+ """
43
+ if not path.exists():
44
+ return None
45
+
46
+ try:
47
+ text = path.read_text()
48
+ except (OSError, PermissionError):
49
+ return None
50
+
51
+ # Parse YAML frontmatter (between --- markers)
52
+ if not text.startswith("---"):
53
+ return None
54
+
55
+ # Find closing ---
56
+ end_idx = text.find("---", 3)
57
+ if end_idx == -1:
58
+ return None
59
+
60
+ frontmatter_text = text[3:end_idx].strip()
61
+ content = text[end_idx + 3 :].strip()
62
+
63
+ # Parse YAML (simple key: value format, no dependencies)
64
+ frontmatter: dict[str, str] = {}
65
+ for raw_line in frontmatter_text.split("\n"):
66
+ stripped = raw_line.strip()
67
+ if not stripped or ":" not in stripped:
68
+ continue
69
+ key, _, value = stripped.partition(":")
70
+ frontmatter[key.strip()] = value.strip()
71
+
72
+ # Validate required fields
73
+ if "name" not in frontmatter or "description" not in frontmatter:
74
+ return None
75
+
76
+ return frontmatter, content
77
+
78
+
79
+ def _get_bundled_skills_dir() -> Path | None:
80
+ """Get path to bundled skills in wafer-cli package."""
81
+ # Try to find wafer-cli's skills directory
82
+ try:
83
+ import wafer
84
+
85
+ wafer_cli_path = Path(wafer.__file__).parent
86
+ skills_dir = wafer_cli_path / "skills"
87
+ if skills_dir.exists():
88
+ return skills_dir
89
+ except ImportError:
90
+ pass
91
+
92
+ return None
93
+
94
+
95
def _scan_skills_dir(directory: Path) -> list[SkillMetadata]:
    """Return metadata for each ``<directory>/<skill>/SKILL.md`` that parses.

    Subdirectories are visited in sorted order so the result is deterministic.
    A missing, non-directory, or unreadable ``directory`` yields an empty list.
    """
    try:
        entries = sorted(directory.iterdir())
    except OSError:
        return []

    found: list[SkillMetadata] = []
    for skill_dir in entries:
        if not skill_dir.is_dir():
            continue
        skill_file = skill_dir / "SKILL.md"
        parsed = _parse_skill_file(skill_file)
        if parsed is None:
            continue
        frontmatter, _ = parsed
        found.append(
            SkillMetadata(
                name=frontmatter["name"],
                description=frontmatter["description"],
                path=skill_file,
            )
        )
    return found


def discover_skills() -> list[SkillMetadata]:
    """Discover all available skills.

    Sources, in precedence order:
        1. User-installed skills in ``get_config_dir() / "skills"``.
        2. Bundled skills from the wafer-cli package, if installed.

    When several skills share a name, the first one found wins. User skills
    therefore shadow bundled ones, and within one source the alphabetically
    first folder wins. Folders are scanned in sorted order, so the returned
    list (and the system-prompt section built from it) is stable across runs.

    Returns:
        List of SkillMetadata (name, description, path only; bodies are not kept).
    """
    sources = [get_config_dir() / "skills"]
    bundled_dir = _get_bundled_skills_dir()
    if bundled_dir is not None:
        sources.append(bundled_dir)

    skills: dict[str, SkillMetadata] = {}
    for directory in sources:
        for metadata in _scan_skills_dir(directory):
            skills.setdefault(metadata.name, metadata)

    return list(skills.values())
137
+
138
+
139
def load_skill(name: str) -> Skill | None:
    """Look up a discovered skill by exact name and read its full body.

    Returns:
        The Skill (frontmatter fields, markdown content and file path), or
        None if no discovered skill has that name or its file no longer parses.
    """
    match = next((meta for meta in discover_skills() if meta.name == name), None)
    if match is None:
        return None

    # Re-read the file: discovery only keeps metadata, not the body.
    parsed = _parse_skill_file(match.path)
    if parsed is None:
        return None

    frontmatter, body = parsed
    return Skill(
        name=frontmatter["name"],
        description=frontmatter["description"],
        content=body,
        path=match.path,
    )
157
+
158
+
159
def format_skill_metadata_for_prompt(skills: list[SkillMetadata]) -> str:
    """Render the "Available Skills" section of the system prompt.

    Each skill is listed as one markdown bullet with its name and description.
    An empty input produces an empty string, meaning no section is added.
    """
    if not skills:
        return ""

    preamble = [
        "## Available Skills",
        "",
        "You have access to the following skills. Use the `skill` tool to load full instructions when needed.",
        "",
    ]
    bullets = [f"- **{entry.name}**: {entry.description}" for entry in skills]
    return "\n".join(preamble + bullets)
@@ -222,6 +222,9 @@ class TemplateConfig:
222
222
  # Example: {"corpus": "./docs/", "format": "markdown"}
223
223
  defaults: dict[str, str] = field(default_factory=dict)
224
224
 
225
+ # Skill discovery - if True, discovers skills and adds skill tool
226
+ include_skills: bool = False
227
+
225
228
  def interpolate_prompt(self, args: dict[str, str] | None = None) -> str:
226
229
  """Interpolate template variables into the system prompt.
227
230
 
@@ -49,6 +49,10 @@ from wafer_core.tools.rocprof_systems_tools import (
49
49
  exec_rocprof_systems_query,
50
50
  exec_rocprof_systems_sample,
51
51
  )
52
+ from wafer_core.tools.skill_tool import (
53
+ SKILL_TOOL,
54
+ exec_skill,
55
+ )
52
56
  from wafer_core.tools.tracelens_tools import (
53
57
  TRACELENS_COLLECTIVE_TOOL,
54
58
  TRACELENS_COMPARE_TOOL,
@@ -88,6 +92,9 @@ __all__ = [
88
92
  "BashPermissionResult",
89
93
  "check_bash_permissions",
90
94
  "exec_bash",
95
+ # Skill tool
96
+ "SKILL_TOOL",
97
+ "exec_skill",
91
98
  # Wafer tool
92
99
  "WAFER_TOOL",
93
100
  "WAFER_SUBCOMMANDS",
@@ -0,0 +1,64 @@
1
+ """Skill tool.
2
+
3
+ Loads skill content on demand from ~/.wafer/skills/ or bundled locations.
4
+ """
5
+
6
+ from wafer_core.rollouts.dtypes import (
7
+ Tool,
8
+ ToolCall,
9
+ ToolFunction,
10
+ ToolFunctionParameter,
11
+ ToolResult,
12
+ )
13
+
14
+ # ── Tool Definition ──────────────────────────────────────────────────────────
15
+
16
# Schema exposed to the model: a single required string argument, ``name``.
SKILL_TOOL = Tool(
    function=ToolFunction(
        name="skill",
        required=["name"],
        parameters=ToolFunctionParameter(
            properties={
                "name": {
                    "description": "Name of the skill to load (e.g., 'wafer-guide')",
                    "type": "string",
                },
            },
            type="object",
        ),
        description="Load a skill's full instructions. Skills provide domain-specific knowledge and workflows. Use this when you need detailed guidance for a task mentioned in your available skills.",
    ),
    type="function",
)
33
+
34
+
35
+ # ── Pure Function Executor ───────────────────────────────────────────────────
36
+
37
+
38
async def exec_skill(tool_call: ToolCall) -> ToolResult:
    """Load a skill's full instructions.

    Args:
        tool_call: The tool call; its ``name`` argument selects the skill
            (surrounding whitespace is ignored).

    Returns:
        ToolResult whose content is ``"# Skill: <name>\\n\\n"`` followed by the
        skill body. If ``name`` is missing, blank or not a string, or no skill
        with that name exists, returns an error result listing the available
        skills.
    """
    from wafer_core.rollouts.skills import discover_skills, load_skill

    # Validate instead of args["name"]: a missing argument should become a tool
    # error the model can recover from, not an uncaught KeyError.
    raw_name = tool_call.args.get("name")
    skill_name = raw_name.strip() if isinstance(raw_name, str) else ""
    skill = load_skill(skill_name) if skill_name else None

    if skill is None:
        available = ", ".join(s.name for s in discover_skills()) or "none"
        if skill_name:
            error = f"Skill not found: {skill_name}. Available skills: {available}"
        else:
            error = f"Missing required argument 'name'. Available skills: {available}"
        return ToolResult(
            tool_call_id=tool_call.id,
            is_error=True,
            content="",
            error=error,
        )

    header = f"# Skill: {skill.name}\n\n"
    return ToolResult(
        tool_call_id=tool_call.id,
        is_error=False,
        content=header + skill.content,
    )
@@ -33,6 +33,9 @@ def get_auth_token() -> str | None:
33
33
  Note:
34
34
  In local dev mode (localhost), no token is required.
35
35
  The API will use LOCAL_DEV_MODE to bypass auth.
36
+
37
+ Callers (like wevin-extension) should pass WAFER_AUTH_TOKEN
38
+ as an environment variable when spawning Python processes.
36
39
  """
37
40
  return os.environ.get("WAFER_AUTH_TOKEN")
38
41
 
@@ -61,7 +61,9 @@ class BaremetalTarget:
61
61
  ncu_available: bool = True # Baremetal typically has NCU
62
62
 
63
63
  # Docker execution config (Modal-like). If docker_image is set, run in container.
64
- docker_image: str | None = None # Docker image to use (e.g., "nvcr.io/nvidia/cutlass:4.3-devel")
64
+ docker_image: str | None = (
65
+ None # Docker image to use (e.g., "nvcr.io/nvidia/cutlass:4.3-devel")
66
+ )
65
67
  pip_packages: tuple[str, ...] = () # Packages to install via uv pip install
66
68
  torch_package: str | None = None # Torch package spec (e.g., "torch>=2.8.0")
67
69
  torch_index_url: str | None = None # Custom index for torch (e.g., PyTorch nightly)
@@ -69,7 +71,9 @@ class BaremetalTarget:
69
71
  def __post_init__(self) -> None:
70
72
  """Validate configuration."""
71
73
  assert len(self.gpu_ids) > 0, "Must specify at least one GPU ID"
72
- assert ":" in self.ssh_target, f"ssh_target must include port (user@host:port), got: {self.ssh_target}"
74
+ assert ":" in self.ssh_target, (
75
+ f"ssh_target must include port (user@host:port), got: {self.ssh_target}"
76
+ )
73
77
  # If torch_index_url is set, torch_package must also be set
74
78
  if self.torch_index_url:
75
79
  assert self.torch_package, "torch_package must be set when torch_index_url is provided"
@@ -114,7 +118,9 @@ class VMTarget:
114
118
  ncu_available: bool = False # VMs typically don't have NCU
115
119
 
116
120
  # Docker execution config (Modal-like). If docker_image is set, run in container.
117
- docker_image: str | None = None # Docker image to use (e.g., "nvcr.io/nvidia/pytorch:24.01-py3")
121
+ docker_image: str | None = (
122
+ None # Docker image to use (e.g., "nvcr.io/nvidia/pytorch:24.01-py3")
123
+ )
118
124
  pip_packages: tuple[str, ...] = () # Packages to install via uv pip install
119
125
  torch_package: str | None = None # Torch package spec (e.g., "torch>=2.8.0")
120
126
  torch_index_url: str | None = None # Custom index for torch (e.g., PyTorch nightly)
@@ -122,7 +128,9 @@ class VMTarget:
122
128
  def __post_init__(self) -> None:
123
129
  """Validate configuration."""
124
130
  assert len(self.gpu_ids) > 0, "Must specify at least one GPU ID"
125
- assert ":" in self.ssh_target, f"ssh_target must include port (user@host:port), got: {self.ssh_target}"
131
+ assert ":" in self.ssh_target, (
132
+ f"ssh_target must include port (user@host:port), got: {self.ssh_target}"
133
+ )
126
134
  # If torch_index_url is set, torch_package must also be set
127
135
  if self.torch_index_url:
128
136
  assert self.torch_package, "torch_package must be set when torch_index_url is provided"
@@ -325,12 +333,13 @@ class RunPodTarget:
325
333
 
326
334
  # Check for API key (env var or ~/.wafer/auth.json)
327
335
  api_key = get_api_key("runpod")
328
- assert api_key, (
329
- "RunPod API key not found.\n"
330
- "Set WAFER_RUNPOD_API_KEY environment variable, or run:\n"
331
- " wafer auth login runpod\n"
332
- "Get your API key from: https://runpod.io/console/user/settings"
333
- )
336
+ if not api_key:
337
+ raise ValueError(
338
+ "RunPod API key not found.\n"
339
+ "Set WAFER_RUNPOD_API_KEY environment variable, or run:\n"
340
+ " wafer auth login runpod\n"
341
+ "Get your API key from: https://runpod.io/console/user/settings"
342
+ )
334
343
 
335
344
 
336
345
  @dataclass(frozen=True)
@@ -370,7 +379,9 @@ class LocalTarget:
370
379
  """Validate configuration."""
371
380
  assert self.name, "name cannot be empty"
372
381
  assert len(self.gpu_ids) > 0, "Must specify at least one GPU ID"
373
- assert self.vendor in ("nvidia", "amd"), f"vendor must be 'nvidia' or 'amd', got: {self.vendor}"
382
+ assert self.vendor in ("nvidia", "amd"), (
383
+ f"vendor must be 'nvidia' or 'amd', got: {self.vendor}"
384
+ )
374
385
 
375
386
 
376
387
  @dataclass(frozen=True)
@@ -443,16 +454,25 @@ class DigitalOceanTarget:
443
454
 
444
455
  # Check for API key (env var or ~/.wafer/auth.json)
445
456
  api_key = get_api_key("digitalocean")
446
- assert api_key, (
447
- "DigitalOcean API key not found.\n"
448
- "Set WAFER_AMD_DIGITALOCEAN_API_KEY environment variable, or run:\n"
449
- " wafer auth login digitalocean\n"
450
- "Get your API key from: https://cloud.digitalocean.com/account/api/tokens"
451
- )
457
+ if not api_key:
458
+ raise ValueError(
459
+ "DigitalOcean API key not found.\n"
460
+ "Set WAFER_AMD_DIGITALOCEAN_API_KEY environment variable, or run:\n"
461
+ " wafer auth login digitalocean\n"
462
+ "Get your API key from: https://cloud.digitalocean.com/account/api/tokens"
463
+ )
452
464
 
453
465
 
454
466
  # Union type for target configs
455
- TargetConfig = BaremetalTarget | VMTarget | ModalTarget | WorkspaceTarget | RunPodTarget | DigitalOceanTarget | LocalTarget
467
+ TargetConfig = (
468
+ BaremetalTarget
469
+ | VMTarget
470
+ | ModalTarget
471
+ | WorkspaceTarget
472
+ | RunPodTarget
473
+ | DigitalOceanTarget
474
+ | LocalTarget
475
+ )
456
476
 
457
477
 
458
478
  # Type guard functions for pattern matching
@@ -517,9 +537,9 @@ def target_to_deployment_config(target: TargetConfig, gpu_id: int) -> Deployment
517
537
  from wafer_core.utils.kernel_utils.deployment import DeploymentConfig
518
538
 
519
539
  # Type narrowing: Only SSH-based targets supported (not Modal)
520
- assert not isinstance(
521
- target, ModalTarget
522
- ), f"target_to_deployment_config only supports SSH targets, got {type(target).__name__}"
540
+ assert not isinstance(target, ModalTarget), (
541
+ f"target_to_deployment_config only supports SSH targets, got {type(target).__name__}"
542
+ )
523
543
 
524
544
  return DeploymentConfig(
525
545
  ssh_target=target.ssh_target,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wafer-core
3
- Version: 0.1.20
3
+ Version: 0.1.22
4
4
  Summary: Core utilities and environments for Wafer GPU kernel optimization
5
5
  Requires-Python: >=3.10
6
6
  Requires-Dist: aiohttp>=3.9.0
@@ -1,6 +1,6 @@
1
1
  wafer_core/__init__.py,sha256=syB2JjzvL91otODCKsGCgCPE6gVCINhTUv9j98k6tL0,3209
2
2
  wafer_core/async_ssh.py,sha256=ocw2Gh5p8ltKeoqG_q32DXOBfu5q-IE7jCnzMbQN9WI,28713
3
- wafer_core/auth.py,sha256=U1AeQE_Z5ev0ltXAEIh0STFIv46mie9G8Vy34oOyowE,5936
3
+ wafer_core/auth.py,sha256=JpUkZ3bROIsgexayak5TLiGqUAR5kqGjekwqQRvIXH0,7235
4
4
  wafer_core/gpu.py,sha256=ENa92btjXsx6ldpoyKfRrAmfy-LHG2KpA5k7SWd6Q_s,28627
5
5
  wafer_core/gpu_detect.py,sha256=kpD8Q_G6GA9j-WnnnTNA3BBPulkGcWnZiogOmjKDao0,13650
6
6
  wafer_core/problem_config.py,sha256=8oqxL9-pvgzi8BtFxgDcqZW4e6DV2OCZOYkcPoyXrc8,10913
@@ -12,7 +12,7 @@ wafer_core/config/__init__.py,sha256=hKywfjA4YXd4lBeBFEcBoMwFoflPHJTiBnkTq7_JYOQ
12
12
  wafer_core/config/loader.py,sha256=k7JnILmO13TWUzIv9Lm8fvmj3UfYHZDgaFurjQ-GXpY,6623
13
13
  wafer_core/config/schema.py,sha256=2WhFlnG0VYYX4T-70BLeJK8Janvi4KEa8KKGZA7331w,3898
14
14
  wafer_core/environments/__init__.py,sha256=SIsResVtm22tr_d-oHPeeSxrkhFdmPOFico3DqtRqK8,238
15
- wafer_core/environments/coding.py,sha256=0TU1kggKH4fLaoeDOUKpWwBklDf5Q9dHTAV2b96GVnU,8313
15
+ wafer_core/environments/coding.py,sha256=T-_JFU-n5OxPR8xAWp8qar4Y5xyC-TWTIBjRy4PDel8,8418
16
16
  wafer_core/environments/gpumode.py,sha256=8Da08nltvN_YloNyYI6-omN2D4n5C7aptKDCtUgT2bQ,17191
17
17
  wafer_core/lib/__init__.py,sha256=4-4p3mhwlquejWGglYXU8_nHdA0LoPaa_jGzcm13USA,1325
18
18
  wafer_core/lib/kernel_scope/__init__.py,sha256=WW2vu8jUlqOu-MCpgO40lIYacCA9N2u-uuECIs_JO2w,2817
@@ -336,9 +336,9 @@ wafer_core/rollouts/agents.py,sha256=Uv1kjYogUfdPl18YfkVxVqFTbmWfuJQrxem_iHTUgdw
336
336
  wafer_core/rollouts/cli.py,sha256=2NqgegKdlmxD0eJzGOMB5o_1Hb5t7O5JpP_32uvF2BE,80117
337
337
  wafer_core/rollouts/cli_agents.py,sha256=e4qqqYBzWLsbw8FsNnddGApWp_on9Cvzrfd1amiAyvI,20641
338
338
  wafer_core/rollouts/deploy.py,sha256=3t88fM_BMyAPkxIl8pS4r5ogHJvrlqWQDuIaltDZBRc,40924
339
- wafer_core/rollouts/dtypes.py,sha256=GXB3SaLzlSXs3Gxe6GLdU3_t4E1Pl7EmJamexRuaFDE,60962
339
+ wafer_core/rollouts/dtypes.py,sha256=GUezPTzkd8E-nDlqdGE7idUthyZC-7jTrbpa4ye-v8k,61146
340
340
  wafer_core/rollouts/eval_helpers.py,sha256=OE7uQZRcbqQhpFqb4zOj8zafc9Gr6xZJpSrMvxXKVUw,1699
341
- wafer_core/rollouts/evaluation.py,sha256=Zu2hqkGWDjTRj9UxRVgNQ_FgEnFo6zKCw4ck2htvnpg,68483
341
+ wafer_core/rollouts/evaluation.py,sha256=fk-pGZ5vpocVmw1iBbHtxMK0K6l8pYTLHCpDNvRY1Xo,69142
342
342
  wafer_core/rollouts/events.py,sha256=z85J8kq0LXPj5CiUk4RkiTQg--r9xiO7QeeJwkyUOto,7505
343
343
  wafer_core/rollouts/export.py,sha256=0CfdBB7Du4E3VekKEUcTwTEFS1bOMGZ9GbD5KU3CecQ,11583
344
344
  wafer_core/rollouts/feedback.py,sha256=mu17eQbAinXZWI3hMYLq-LyF4JAdH9SfNRWY-0S8jvQ,6769
@@ -350,11 +350,12 @@ wafer_core/rollouts/paths.py,sha256=9XtrA9ylhb5LttMFe2DE7X0IHeUMjuGUerII9OscYec,
350
350
  wafer_core/rollouts/pipeline.py,sha256=vlJTYE3ZX2XScpF9pmtv91K8Q0g8uLmcbI5jn6b5Hzg,15319
351
351
  wafer_core/rollouts/progress.py,sha256=szA9cvWT2xUxGVhF9BaAqJMmKDqMAUlxImxcOpcnqbY,29228
352
352
  wafer_core/rollouts/progress_display.py,sha256=it-IiI37k9whAuB6T_66GYgsZyidCq5x00URiOcxe2c,15769
353
- wafer_core/rollouts/prompt.py,sha256=KAFSS4HCtYtaPGGCFtC5VU7NrVmhEoQFNGWcA1HVW-Q,10566
353
+ wafer_core/rollouts/prompt.py,sha256=EDmGb0rhWwke7tokIcO8dukc3q5c8x0n5Omi5CpAQmA,11022
354
354
  wafer_core/rollouts/providers.py,sha256=dcGJh1p30hstVbCDDtJ902lyafkg81DKjcOzb0uuKS0,1400
355
355
  wafer_core/rollouts/remote.py,sha256=cAYpRCONlsTeRxzLiegAUfjZWGtqBNwZTHehMhk5ldA,8816
356
356
  wafer_core/rollouts/scoring.py,sha256=qeIT8Z7pK51XRDmN2sGdg_hIPRabWqoQIYKsuytlvRo,8838
357
357
  wafer_core/rollouts/search.py,sha256=5BEDuw9FVbQhei3nvUXEVwBU5ouwgJE6ONhEqvU5Ldc,14696
358
+ wafer_core/rollouts/skills.py,sha256=ATYoG02Cc6_VrtE415TnseBFJrKOMq27z-5YgBgPpZQ,5081
358
359
  wafer_core/rollouts/slice.py,sha256=darOZO53BuSPfvv_KjOSzulGVSWbL4OuoE3k6xXpBFg,20195
359
360
  wafer_core/rollouts/store.py,sha256=UDP9idDOEVs_0Pslx0K_Y8E1i-BeoqVSaxdQiaqtz1E,18051
360
361
  wafer_core/rollouts/transform_messages.py,sha256=yldzdLgugNYb5Zxju7myFBel1tmrHXx9M399ImqPLGI,20891
@@ -394,7 +395,7 @@ wafer_core/rollouts/environments/compose.py,sha256=DlJA_GdzByWjVvGeR4MrcQIB4ucV6
394
395
  wafer_core/rollouts/environments/cuda_grep.py,sha256=o4GPJcnKuB96KwE4UWkxXq5DdaKYMz-QizcZXleOKLs,22007
395
396
  wafer_core/rollouts/environments/git_worktree.py,sha256=f4S-OI-m6OEMyrEl3TfD4oYLXkkNgsiHuX6NOHCVfSQ,22397
396
397
  wafer_core/rollouts/environments/handoff.py,sha256=pvhcSDdltZ1zJ3Y_SAJlBmX6FhXjZLNYneXwjFqb9lE,16941
397
- wafer_core/rollouts/environments/localfs.py,sha256=fpk91XB9GNrtsHiFhSzvkfzRdEyqf8me_pR5PPxDmMg,43848
398
+ wafer_core/rollouts/environments/localfs.py,sha256=xceepFc-eh9RqwTSJd_WiODVCs4Jm-d8rQm_MvutG0A,45766
398
399
  wafer_core/rollouts/environments/no_tools.py,sha256=FRsiMDzi8ma15xePTANteUxAOiOw0s90459XMvSq1_k,2970
399
400
  wafer_core/rollouts/environments/oracle.py,sha256=OtviDKed6DEIy67TdohqSGmvQT8XuuQLGnxJF9B7yrE,7563
400
401
  wafer_core/rollouts/environments/repl.py,sha256=DyvyqEbBBytliqZ2-uAJqm-F7gdROG5-9LDFVyCo6lo,36042
@@ -501,7 +502,7 @@ wafer_core/rollouts/providers/openai_completions.py,sha256=3vUA74qjrxG-aOjyngtnZ
501
502
  wafer_core/rollouts/providers/openai_responses.py,sha256=xlyeI9h7aZEbpFY_8_zQ6IIYbMeNgcsEVRi92PBEAWc,32470
502
503
  wafer_core/rollouts/providers/sglang.py,sha256=kahYlrFG008D4jRA-c6mylBsTe-qTryKMhsjUbidduU,35525
503
504
  wafer_core/rollouts/templates/__init__.py,sha256=8qANHtoWZe9zpAWufkXlo8tQ07_Lw-RhX-lm2i-0ORQ,976
504
- wafer_core/rollouts/templates/base.py,sha256=FP6zx43ba8rx2nTx8dK1qZu-8sbmude2lBlM0xjSyGY,6906
505
+ wafer_core/rollouts/templates/base.py,sha256=aBojLjssj2YaPjgCZMvHYWxtGa83Y4uqt1Cxoy1J3v0,7010
505
506
  wafer_core/rollouts/templates/loader.py,sha256=eNRWP2zMTsygpewBhAO3Co0vLb4o4SwJo74Jp1DWS-0,4726
506
507
  wafer_core/rollouts/tests/test_slice.py,sha256=jUJFkTUzY6TNrimFH5kkfOavn3J5Gg-hsFWUOzs30lk,9148
507
508
  wafer_core/rollouts/tools/__init__.py,sha256=nLEJrlhT1anqInh6u4CeX8CCrjUJpAEMp0jx-G0V6gs,298
@@ -585,10 +586,11 @@ wafer_core/sessions/hooks.py,sha256=A-txm6ufnRGQCdtP3vwh7oEOdlLN9Tv0XsjORMihuAI,
585
586
  wafer_core/targets/__init__.py,sha256=sHndC7AAOaHXlrmDXFLB53a5Y8DBjuyqS6nwsO2nj-Y,1728
586
587
  wafer_core/targets/digitalocean.py,sha256=cvoYpYjtSyy5t2lQAPi7ERruuuibronah_ivOiduAHQ,16550
587
588
  wafer_core/targets/runpod.py,sha256=bYTLVRaASrewJLIcZRtPMfMU2en1McE6W5w-Edo_OPQ,15785
588
- wafer_core/tools/__init__.py,sha256=zdWVdQU8aP66HVoI1AhoUHJ9UKG95rT5MFufD6zV2Q0,3212
589
+ wafer_core/tools/__init__.py,sha256=wBQD45GdSfkxcT6NHzIv0IMeXCc0enwwkpm3T_9j1X8,3341
589
590
  wafer_core/tools/bash_tool.py,sha256=daoKOVGSgL0x9X_3l8Apd6-wFH4VMXMGJwVemw2FIfc,16828
590
591
  wafer_core/tools/glob_tool.py,sha256=9X5PdOjQJj7kiVNqqCZC0-1LmnE6wHx3Zc9zfMjtXdc,3533
591
592
  wafer_core/tools/grep_tool.py,sha256=cStyDz-J47oDLLZCL83yOvYo8Ijv4qu3D372JKT_ptM,4580
593
+ wafer_core/tools/skill_tool.py,sha256=JXsT5hBTUH5U4tmzHEywU7eHHt5xCEF79tL2tsuk4-c,2067
592
594
  wafer_core/tools/wafer_tool.py,sha256=-dgPTHbWXq3I3wFj0mP7-lj5iZqGRoFvFf9IEEo3plQ,6345
593
595
  wafer_core/tools/write_kernel_tool.py,sha256=dJjhr-WBhVNe06hcJQVmBZTbS8mid64KF1MwlE2s2R4,21547
594
596
  wafer_core/tools/autotuner/BENCHMARKING.md,sha256=RkJ2wFhbDFXuMbw0mOW4pRqntT0UirptXwIxyrA1_KM,3825
@@ -640,7 +642,7 @@ wafer_core/tools/tracelens_tools/tracelens_collective_tool.py,sha256=0E3FhfaA1N0
640
642
  wafer_core/tools/tracelens_tools/tracelens_compare_tool.py,sha256=99dUsB4wuYjxbh4X6Nsf2AtDMs94Uzy04tSemDOmKhg,4458
641
643
  wafer_core/tools/tracelens_tools/tracelens_report_tool.py,sha256=unuEx2zXaK42lA3qojS-WzFlBmIFrS75GHSgXUnDXGE,4720
642
644
  wafer_core/utils/__init__.py,sha256=oPHgkMkE7wS2lYKLlXrw4Ia5EHnpVcGHFfpWebIlVKs,354
643
- wafer_core/utils/backend.py,sha256=gNxv4jeOfZwrefTDhyp-UHzkWRyWRhmot76PRZkAwyk,8842
645
+ wafer_core/utils/backend.py,sha256=zt5AX00OXSIstprvQ1_WNf_PQYphD2y53kOXWvg20RY,8986
644
646
  wafer_core/utils/code_validation.py,sha256=UqS4UVDxO-atdbn6i7JygX6IFPITvT56zZn1t-ZNuM8,4692
645
647
  wafer_core/utils/environment_serialization.py,sha256=cVDkapx0JC60CekazgirPEMAeGZhbLdX1WMIkFvId60,5047
646
648
  wafer_core/utils/event_streaming.py,sha256=Sg3-hI043Ofc2b29Z3DWrKgu4HkfJoIqhhbfGRJv70Q,2260
@@ -663,7 +665,7 @@ wafer_core/utils/kernel_utils/static_checker.py,sha256=GWC7RZdwL4WqMLK0nO-wnZd-V
663
665
  wafer_core/utils/kernel_utils/task.py,sha256=XcmKxKUWh5It6nX3zGqj77tWgA32uPfQMqNOqyD5T48,2682
664
666
  wafer_core/utils/kernel_utils/utils.py,sha256=uDZoJDxh07hJeLNlPdKN2vgB15pqIr1LbXf0YIBHU4E,43056
665
667
  wafer_core/utils/kernel_utils/targets/__init__.py,sha256=4NwRLsuJ__S4xKAfda4Ag82C5MQ3Qio-4xA5S-mQGlU,2067
666
- wafer_core/utils/kernel_utils/targets/config.py,sha256=eGpPOKlTxcS16Fs1-L9r9Cf_CDTG2-52bsG7LJllFw4,18900
668
+ wafer_core/utils/kernel_utils/targets/config.py,sha256=3TvT2Hp3TV-cjsSiZ8NOmAz8epDoVii76p6DDAI2V64,19134
667
669
  wafer_core/utils/kernel_utils/targets/execution.py,sha256=bZuNXCo0sIdD6hFhetLPrtDC-zMSiIsAx_aml49VVL0,15033
668
670
  wafer_core/utils/kernel_utils/targets/selection.py,sha256=5I_RG_7cfhq7uaeR28meC2EeNNKssFsK-Tc3QFG6Ze0,3590
669
671
  wafer_core/utils/modal_execution/__init__.py,sha256=jkVqYOLzCT5K73N9Od0UIUsx-99A0m6bpDrxfyXxQZ8,945
@@ -671,6 +673,6 @@ wafer_core/utils/modal_execution/modal_app.py,sha256=ibBllC59R9bP9w4QweFocVGeSX5
671
673
  wafer_core/utils/modal_execution/modal_config.py,sha256=7cGX9TGqilQ3qxI3OFGXV5orjtyRU-PEDOJ4vP2oxno,4421
672
674
  wafer_core/utils/modal_execution/modal_execution.py,sha256=gChjnV6jqA3A7IRP3DfvV5cSfm_MN0X4f7JZufXgdZE,24594
673
675
  wafer_core/utils/modal_execution/test_modal.py,sha256=_jqou_hrLs1Daf1590Pnb0a_lXMMa2rczAPpW9HpoNQ,8153
674
- wafer_core-0.1.20.dist-info/METADATA,sha256=Dndu3eLqq_Gxvoa_tKupaqOmQ17FjVujI_KLTzk0MXQ,1420
675
- wafer_core-0.1.20.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
676
- wafer_core-0.1.20.dist-info/RECORD,,
676
+ wafer_core-0.1.22.dist-info/METADATA,sha256=wV6MLEufRIKPRacW_ErMoqgymAFNI2XgP4wobQoKjnM,1420
677
+ wafer_core-0.1.22.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
678
+ wafer_core-0.1.22.dist-info/RECORD,,