wafer-core 0.1.24__py3-none-any.whl → 0.1.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,336 @@
+ """Trace loading and parsing logic.
+
+ Loads JSON trace files from AMD/NVIDIA profilers and extracts kernel execution data,
+ Python call stacks, CPU operator mappings, and layer correlations.
+ """
+
+ import bisect
+ import json
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import Any
+
+ import pandas as pd
+
+ from .classifier import classify
+
+
+ def extract_layer_mapping(events: list[dict[str, Any]], platform: str) -> dict[int, int]:
+     """Extract correlation ID to layer number mapping.
+
+     vLLM's execution graph creates large correlation groups for full transformer layers.
+     Each layer's forward pass (norm + attention + FFN) gets grouped under one correlation ID,
+     containing 200-400 kernels depending on batch size and sequence length.
+
+     We identify layers as correlation groups with many kernels (70+), which filters out
+     individual operations like sampling, logit processing, etc.
+
+     Args:
+         events: List of trace events
+         platform: 'AMD' or 'NVIDIA'
+
+     Returns:
+         Dict mapping correlation ID to layer number
+     """
+     # Group kernels by correlation ID
+     correlation_groups = defaultdict(
+         lambda: {"count": 0, "has_attention": False, "has_ffn": False}
+     )
+
+     for ev in events:
+         if ev.get("cat") != "kernel":
+             continue
+
+         corr_id = ev.get("args", {}).get("correlation")
+         if corr_id is None:
+             continue
+
+         kernel_name = ev.get("name", "").lower()
+
+         # Track what operations this correlation contains
+         correlation_groups[corr_id]["count"] += 1
+         if "attention" in kernel_name or "fmha" in kernel_name:
+             correlation_groups[corr_id]["has_attention"] = True
+         if any(x in kernel_name for x in ["cijk_", "nvjet", "wvsplitk", "gemm"]):
+             correlation_groups[corr_id]["has_ffn"] = True
+
+     # Map correlation IDs to layer numbers
+     # Transformer layers have many kernels AND contain both attention and FFN ops
+     correlation_to_layer = {}
+     layer_num = 0
+
+     for corr_id in sorted(correlation_groups.keys()):
+         group = correlation_groups[corr_id]
+
+         # Identify complete transformer layers by their characteristics:
+         # - Has attention operations (self-attention or cross-attention)
+         # - Has FFN operations (feed-forward network)
+         # - Has sufficient kernel count (70+): typical transformer block has ~80-100 kernels
+         #   including attention QKV projections, softmax, output projection, FFN layers,
+         #   normalization, and elementwise ops. This threshold filters out:
+         #   - Individual operations (1-10 kernels)
+         #   - Sampling/generation steps (20-40 kernels)
+         #   - Partial layer executions
+         is_layer = (
+             group["count"] >= 70 and group["has_attention"] and group["has_ffn"]
+         )
+
+         if is_layer:
+             correlation_to_layer[corr_id] = layer_num
+             layer_num += 1
+
+     return correlation_to_layer
+
+
+ def _build_python_stack_index(
+     events: list[dict[str, Any]],
+ ) -> tuple[list[tuple[int, int, int, int | None, str]], dict[int, dict[str, Any]]]:
+     """Build Python call stack index for kernels.
+
+     Args:
+         events: List of trace events
+
+     Returns:
+         Tuple of (python_intervals, python_by_id)
+     """
+     python_by_id: dict[int, dict[str, Any]] = {}
+     python_intervals: list[tuple[int, int, int, int | None, str]] = []
+
+     for ev in events:
+         if ev.get("cat") == "python_function":
+             py_id = ev.get("args", {}).get("Python id")
+             name = ev["name"]
+             ts_start = ev["ts"]
+             ts_end = ts_start + ev.get("dur", 0)
+             duration = ev.get("dur", 0)
+             parent_id = ev.get("args", {}).get("Python parent id")
+
+             python_intervals.append((ts_start, ts_end, duration, py_id, name))
+
+             if py_id is not None:
+                 python_by_id[py_id] = {
+                     "name": name,
+                     "parent_id": parent_id,
+                     "ts_start": ts_start,
+                     "ts_end": ts_end,
+                     "duration": duration,
+                 }
+
+     # Sort by start time for efficient binary search
+     python_intervals.sort()
+
+     return python_intervals, python_by_id
+
+
+ def _get_python_stack_full(
+     timestamp: int,
+     python_intervals: list[tuple[int, int, int, int | None, str]],
+     python_by_id: dict[int, dict[str, Any]],
+ ) -> tuple[str | None, list[str]]:
+     """Get full Python call stack for a kernel launch.
+
+     Args:
+         timestamp: Kernel launch timestamp
+         python_intervals: Sorted list of Python function intervals
+         python_by_id: Mapping of Python ID to function info
+
+     Returns:
+         Tuple of (summary_string, full_stack_list)
+     """
+     # Binary search for Python functions active at this timestamp
+     idx = bisect.bisect_right(
+         python_intervals, (timestamp, float("inf"), float("inf"), None, "")
+     )
+
+     # Find active functions
+     active_funcs = []
+     for i in range(idx - 1, max(0, idx - 1000), -1):
+         ts_start, ts_end, duration, py_id, name = python_intervals[i]
+         if ts_start <= timestamp <= ts_end:
+             active_funcs.append((duration, py_id, name))
+         if ts_end < timestamp - 1000000: # 1 second before
+             break
+
+     if not active_funcs:
+         return None, []
+
+     # Get the innermost (most specific) function
+     active_funcs.sort()
+     leaf_duration, leaf_id, leaf_name = active_funcs[0]
+
+     # Walk up parent chain to get FULL stack
+     full_stack = []
+     current_id = leaf_id
+     visited = set()
+
+     while (
+         current_id is not None
+         and current_id not in visited
+         and current_id in python_by_id
+     ):
+         func = python_by_id[current_id]
+         name = func["name"]
+         full_stack.append(name)
+
+         visited.add(current_id)
+         current_id = func["parent_id"]
+
+         # Safety limit: prevent infinite loops from circular parent references
+         # and bound memory usage. 50 frames is deeper than typical Python stacks.
+         if len(full_stack) >= 50:
+             break
+
+     # Reverse so it's outermost -> innermost
+     full_stack.reverse()
+
+     # Create summary for text output: show the most informative vLLM/model function
+     summary = None
+     vllm_funcs = [
+         f
+         for f in full_stack
+         if any(x in f.lower() for x in ["vllm/", "model", "<eval_with_key>"])
+     ]
+
+     if vllm_funcs:
+         # Get innermost vLLM function (most specific)
+         summary = vllm_funcs[-1]
+
+         # Check if it's a CUDA graph - add annotation
+         if any("torch/cuda/graphs" in f for f in full_stack):
+             # Shorten if too long
+             if len(summary) > 45:
+                 parts = summary.split("/")[-1]
+                 summary = "vllm/..." + parts
+             summary = f"{summary} [CUDA graph]"
+         elif len(summary) > 53:
+             parts = summary.split("/")[-1]
+             summary = "vllm/..." + parts
+     else:
+         # Fallback to innermost function
+         summary = leaf_name
+
+     return summary, full_stack
+
+
+ def load_trace(
+     file_path: str | Path,
+ ) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
+     """Load trace and return platform info, device properties, kernels, patterns, and layer mapping.
+
+     Args:
+         file_path: Path to JSON trace file
+
+     Returns:
+         Tuple of (platform, gpu_name, device_props, kernel_df, kernel_patterns, layer_mapping)
+     """
+     with open(file_path, "rb") as f:
+         trace = json.load(f)
+
+     props = trace.get("deviceProperties", [{}])[0]
+     is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
+     platform = "AMD" if is_amd else "NVIDIA"
+     gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
+
+     # Extract relevant device properties
+     device_props = {
+         "name": gpu_name,
+         "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
+         "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
+         "sm_count": props.get("numSms", 0),
+         "warp_size": props.get("warpSize", 32),
+         "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
+         "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
+     }
+
+     events = trace.get("traceEvents", [])
+
+     # Build mapping: external_id -> CPU operator name
+     external_to_cpu = {}
+     for ev in events:
+         if ev.get("cat") == "cpu_op":
+             ext_id = ev.get("args", {}).get("External id")
+             cpu_op_name = ev.get("name", "")
+             if ext_id is not None:
+                 external_to_cpu[ext_id] = cpu_op_name
+
+     # Build Python call stack index for kernels without External IDs
+     python_intervals, python_by_id = _build_python_stack_index(events)
+
+     # Extract phases
+     phases = []
+     for ev in events:
+         if ev.get("cat") == "user_annotation" and ev.get("name", "").startswith(
+             "execute_context"
+         ):
+             name = ev["name"]
+             # Parse execute_context_X(TOKENS)_generation_Y(Y)
+             # We want the TOKENS from execute_context, not the generation number
+             tokens = 0
+             parts = name.split("_")
+             for i, p in enumerate(parts):
+                 # Look for execute_context_X(TOKENS) specifically
+                 if i > 0 and parts[i-1] == "context" and "(" in p and ")" in p:
+                     try:
+                         tokens = int(p.split("(")[1].split(")")[0])
+                         break # Stop after finding context tokens
+                     except Exception:
+                         pass
+             is_prefill = tokens >= 1024 and "generation_0" in name
+             phases.append(
+                 {
+                     "type": "prefill" if is_prefill else "decode",
+                     "ts_start": ev["ts"],
+                     "ts_end": ev["ts"] + ev["dur"],
+                 }
+             )
+
+     # Extract layer mapping from correlation IDs
+     layer_mapping = extract_layer_mapping(events, platform)
+
+     kernel_data = []
+     kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)
+
+     for ev in events:
+         if ev.get("cat") != "kernel":
+             continue
+         name, dur, ts = ev["name"], ev.get("dur", 0), ev["ts"]
+         corr_id = ev.get("args", {}).get("correlation")
+         ext_id = ev.get("args", {}).get("External id")
+
+         phase = "decode"
+         for p in phases:
+             if p["ts_start"] <= ts <= p["ts_end"]:
+                 phase = p["type"]
+                 break
+
+         op, pattern = classify(name, platform)
+         kernel_patterns[(op.value, phase)].add(pattern)
+
+         # Assign layer number from correlation ID
+         layer = layer_mapping.get(corr_id) if corr_id is not None else None
+
+         # Get CPU operator name from external ID, or fallback to Python stack
+         cpu_op = external_to_cpu.get(ext_id) if ext_id is not None else None
+         python_stack: list[str] = []
+
+         # If no CPU op via External ID, try Python stack trace
+         if cpu_op is None:
+             cpu_op, python_stack = _get_python_stack_full(
+                 ts, python_intervals, python_by_id
+             )
+
+         kernel_data.append(
+             {
+                 "name": name,
+                 "dur_us": dur,
+                 "phase": phase,
+                 "op": op.value,
+                 "pattern": pattern,
+                 "layer": layer,
+                 "correlation": corr_id,
+                 "cpu_op": cpu_op,
+                 "python_stack": python_stack, # Full stack for JSON output
+             }
+         )
+
+     return platform, gpu_name, device_props, pd.DataFrame(kernel_data), dict(kernel_patterns), layer_mapping
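For orientation, here is a minimal usage sketch of the loader above. The import path and trace file name are assumptions (the diff does not include file headers), and the aggregation is illustrative rather than part of the package. In the phase-extraction loop, a user_annotation named along the lines of execute_context_3(2048)_generation_0(0) would be read as 2048 context tokens and, being generation_0, marked as prefill.

    # Hypothetical usage sketch; the module path is an assumption, not shown in this diff.
    from wafer_core.tracelens.loader import load_trace  # assumed import location

    platform, gpu_name, device_props, kernel_df, kernel_patterns, layer_mapping = load_trace(
        "vllm_profile.json"
    )

    # Sum kernel time per transformer layer, using the layer numbers assigned from
    # large correlation groups (70+ kernels containing both attention and FFN ops).
    per_layer_us = (
        kernel_df.dropna(subset=["layer"]).groupby("layer")["dur_us"].sum().sort_index()
    )
    print(f"{platform} / {gpu_name}: {len(layer_mapping)} layers detected")
    print(per_layer_us.head())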
@@ -84,7 +84,7 @@ class ProblemConfig:
      benchmarks: list[dict[str, Any]]

      # Optional with defaults
-     model: str = "claude-sonnet-4-5-20250929"
+     model: str = "claude-opus-4-5-20251101"
      temperature: float = 0.2
      max_tokens: int = 8192
      max_turns: int = 10
@@ -219,7 +219,7 @@ def _parse_config(data: dict[str, Any], base_dir: Path) -> tuple[ProblemConfig |
          reference_code=reference_code,
          tests=tests,
          benchmarks=benchmarks,
-         model=data.get("model", "claude-sonnet-4-5-20250929"),
+         model=data.get("model", "claude-opus-4-5-20251101"),
          temperature=data.get("temperature", 0.2),
          max_tokens=data.get("max_tokens", 8192),
          max_turns=data.get("max_turns", 10),
@@ -269,7 +269,7 @@ def create_problem_config_from_cli(
          reference_code=reference_code,
          tests=tests,
          benchmarks=benchmarks or tests, # Use tests as benchmarks if not specified
-         model=kwargs.get("model", "claude-sonnet-4-5-20250929"),
+         model=kwargs.get("model", "claude-opus-4-5-20251101"),
          temperature=kwargs.get("temperature", 0.2),
          max_tokens=kwargs.get("max_tokens", 8192),
          max_turns=kwargs.get("max_turns", 10),
@@ -119,7 +119,7 @@ FINAL(42)

  config = AgentPresetConfig(
      name="rlm",
-     model="anthropic/claude-sonnet-4-5-20250929",
+     model="anthropic/claude-opus-4-5-20251101",
      env="repl", # Uses REPLEnvironment
      thinking=True,
      system_prompt=RLM_TOOL_SYSTEM_PROMPT,
@@ -128,7 +128,7 @@ config = AgentPresetConfig(
  # Variant for message-parsing mode
  config_block_mode = AgentPresetConfig(
      name="rlm_blocks",
-     model="anthropic/claude-sonnet-4-5-20250929",
+     model="anthropic/claude-opus-4-5-20251101",
      env="repl_blocks", # Uses MessageParsingREPLEnvironment
      thinking=True,
      system_prompt=RLM_BLOCK_SYSTEM_PROMPT,
@@ -1238,6 +1238,12 @@ class Endpoint(JsonSerializable):
      api_base: str = ""
      api_key: str = ""
      oauth_token: str = "" # OAuth bearer token (takes precedence over api_key for Anthropic)
+     # TODO: Callbacks on a frozen dataclass are a code smell. This exists because wafer-core
+     # can't depend on wafer-cli (where the Supabase refresh logic lives). A cleaner approach
+     # would be a TokenProvider protocol that Endpoint delegates to, keeping the dataclass pure.
+     api_key_refresh: Callable[[], Awaitable[str | None]] | None = field(
+         default=None, repr=False, compare=False
+     )
      is_claude_code_api_key: bool = (
          False # API key created via Claude Code OAuth (requires special headers)
      )
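The TODO above names a TokenProvider protocol as the cleaner alternative to storing a callback on the frozen dataclass. A minimal sketch of that idea follows; every name below is hypothetical and nothing in it exists in this release. Endpoint would hold a provider instead of a bare callable, and the rollout code would call provider.get_api_key() where it currently calls api_key_refresh().

    # Hypothetical sketch of the TokenProvider protocol mentioned in the TODO (not in the package).
    from typing import Protocol


    class TokenProvider(Protocol):
        async def get_api_key(self) -> str | None:
            """Return a currently valid API key, refreshing it if it has expired."""
            ...


    class SupabaseTokenProvider:
        """Would live in wafer-cli, keeping wafer-core free of the Supabase dependency."""

        def __init__(self, refresh_session):
            self._refresh_session = refresh_session  # coroutine returning a fresh proxy JWT

        async def get_api_key(self) -> str | None:
            return await self._refresh_session()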
@@ -1300,6 +1306,7 @@ class Endpoint(JsonSerializable):
              exclude_secrets: If True (default), omits api_key and oauth_token.
          """
          d = asdict(self)
+         d.pop("api_key_refresh", None) # Callable, not serializable
          if exclude_secrets:
              d.pop("api_key", None)
              d.pop("oauth_token", None)
@@ -1307,7 +1314,11 @@ class Endpoint(JsonSerializable):

      @classmethod
      def from_dict(
-         cls, data: dict[str, Any], api_key: str = "", oauth_token: str = ""
+         cls,
+         data: dict[str, Any],
+         api_key: str = "",
+         oauth_token: str = "",
+         api_key_refresh: "Callable[[], Awaitable[str | None]] | None" = None,
      ) -> "Endpoint":
          """Deserialize from dict, injecting secrets at runtime.

@@ -1315,12 +1326,16 @@ class Endpoint(JsonSerializable):
              data: Dict from to_dict()
              api_key: API key to inject (not stored in session)
              oauth_token: OAuth token to inject (not stored in session)
+             api_key_refresh: Callback to refresh api_key mid-session (not stored)
          """
-         # Remove secrets if present (they shouldn't be, but be safe)
+         # Remove secrets/callables if present (they shouldn't be, but be safe)
          data = data.copy()
          data.pop("api_key", None)
          data.pop("oauth_token", None)
-         return cls(**data, api_key=api_key, oauth_token=oauth_token)
+         data.pop("api_key_refresh", None)
+         return cls(
+             **data, api_key=api_key, oauth_token=oauth_token, api_key_refresh=api_key_refresh
+         )


  @dataclass(frozen=True)
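As a usage sketch, the new parameter would be wired in when an Endpoint is re-hydrated from a saved session. The from_dict signature is the one added above; the refresh helper and session variable are hypothetical.

    # Illustrative call site (get_fresh_proxy_token and saved_session are hypothetical).
    import os


    async def refresh_wafer_key() -> str | None:
        return await get_fresh_proxy_token()  # e.g. wafer-cli's Supabase session refresh


    endpoint = Endpoint.from_dict(
        saved_session["endpoint"],
        api_key=os.environ.get("WAFER_API_KEY", ""),
        api_key_refresh=refresh_wafer_key,
    )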
@@ -725,9 +725,16 @@ async def rollout_anthropic(
              oauth_token = fresh_token
          # If refresh failed, continue with existing token - it might still work

+     # Get fresh wafer proxy token if refresh callback is available
+     api_key = actor.endpoint.api_key
+     if actor.endpoint.api_key_refresh:
+         fresh_key = await actor.endpoint.api_key_refresh()
+         if fresh_key:
+             api_key = fresh_key
+
      client = _create_anthropic_client(
          oauth_token=oauth_token,
-         api_key=actor.endpoint.api_key,
+         api_key=api_key,
          api_base=actor.endpoint.api_base,
          max_retries=actor.endpoint.max_retries,
          timeout=actor.endpoint.timeout,
@@ -973,7 +980,7 @@ async def rollout_anthropic(
                  f"Model not found: {e}\nCheck your model ID is correct."
              ) from e

-             # For OAuth: try to refresh token and retry once on auth errors
+             # Try to refresh token and retry once on auth errors
              if isinstance(e, anthropic.AuthenticationError):
                  if oauth_token and attempt == 0:
                      # Emit retry event for OAuth refresh
@@ -993,12 +1000,37 @@ async def rollout_anthropic(
                      await client.close()
                      client = _create_anthropic_client(
                          oauth_token=oauth_token,
-                         api_key=actor.endpoint.api_key,
+                         api_key=api_key,
                          api_base=actor.endpoint.api_base,
                          max_retries=actor.endpoint.max_retries,
                          timeout=actor.endpoint.timeout,
                      )
                      continue
+
+                 # Wafer proxy token refresh (Supabase JWTs expire after ~1hr)
+                 if actor.endpoint.api_key_refresh and attempt == 0:
+                     await on_chunk(
+                         RetryStart(
+                             attempt=1,
+                             max_attempts=2,
+                             delay_seconds=0,
+                             error_message="Wafer proxy token expired, refreshing",
+                             provider="anthropic",
+                         )
+                     )
+                     fresh_key = await actor.endpoint.api_key_refresh()
+                     if fresh_key and fresh_key != api_key:
+                         api_key = fresh_key
+                         await client.close()
+                         client = _create_anthropic_client(
+                             oauth_token=oauth_token,
+                             api_key=api_key,
+                             api_base=actor.endpoint.api_base,
+                             max_retries=actor.endpoint.max_retries,
+                             timeout=actor.endpoint.timeout,
+                         )
+                         continue
+
                  raise FatalEvalError(
                      f"Authentication failed: {e}\nCheck your API key or OAuth token."
                  ) from e
@@ -7,10 +7,13 @@ import logging
  import os
  from pathlib import Path

+ import httpx
+
  logger = logging.getLogger(__name__)

  SUPABASE_URL = "https://hvlpthcnxlywlquiciqe.supabase.co"
  BUCKET_NAME = "traces"
+ API_BASE = os.environ.get("WAFER_API_URL", "https://api.wafer.ai")


  def upload_results_to_supabase(output_dir: Path, log: logging.Logger | None = None) -> bool:
@@ -95,6 +98,12 @@ def upload_results_to_supabase(output_dir: Path, log: logging.Logger | None = No
          )

          log.info(f"Uploaded {len(uploaded)} files to Supabase: {run_name}")
+
+         # Auto-index in database for trace viewer
+         # Fail if indexing fails - user can re-run (everything is idempotent)
+         if not _index_run_in_database(run_name, report_path, log):
+             return False
+
          return True

      except ImportError:
@@ -103,3 +112,39 @@ def upload_results_to_supabase(output_dir: Path, log: logging.Logger | None = No
      except Exception as e:
          log.error(f"Failed to upload to Supabase: {e}")
          return False
+
+
+ def _index_run_in_database(run_name: str, report_path: Path, log: logging.Logger) -> bool:
+     """Index a run in the trace_runs database table for fast querying.
+
+     Calls POST /v1/eval-traces/runs to upsert the run metadata.
+     This enables the trace viewer to show the run immediately without manual sync.
+
+     Args:
+         run_name: Name of the run (folder name)
+         report_path: Path to the report.json file
+         log: Logger instance
+
+     Returns:
+         True if indexing succeeded, False otherwise
+     """
+     try:
+         with open(report_path) as f:
+             report = json.load(f)
+
+         response = httpx.post(
+             f"{API_BASE}/v1/eval-traces/runs",
+             json={"name": run_name, "report": report},
+             timeout=30.0,
+         )
+
+         if response.status_code == 200:
+             log.info(f"Indexed run in database: {run_name}")
+             return True
+         else:
+             log.error(f"Failed to index run {run_name}: {response.status_code} {response.text}")
+             return False
+
+     except Exception as e:
+         log.error(f"Failed to index run {run_name} in database: {e}")
+         return False
@@ -72,10 +72,6 @@ from wafer_core.tools.write_kernel_tool import (
      KernelSubmission,
      exec_write_kernel,
  )
- from wafer_core.tools.search_docs_tool import (
-     SEARCH_DOCS_TOOL,
-     exec_search_docs,
- )

  __all__ = [
      # File tools
@@ -137,7 +133,4 @@ __all__ = [
      "exec_tracelens_report",
      "exec_tracelens_compare",
      "exec_tracelens_collective",
-     # Search docs tool
-     "SEARCH_DOCS_TOOL",
-     "exec_search_docs",
  ]
@@ -12,6 +12,16 @@ Attack types defended against:
  5. Monkey-patching - Replacing CUDA timing functions with fake implementations

  Reference: "Hacks and Defenses in Automatic GPU Kernel Generation" by Jiwei Li (Dec 2025)
+
+ TODO: Memory guard buffers (from CUDA-L2's zero_one_correctness_check.py) — wrap
+ input/output tensors with guard regions and check for out-of-bounds writes after
+ kernel execution. Catches shared memory overflow and buffer overrun at the memory
+ boundary, rather than inferring from output non-determinism.
+
+ TODO: Exact correctness for GEMM kernels (from CUDA-L2) — use {0,1} input matrices
+ where FP16 results ≤2048 are exactly representable, enabling zero-tolerance
+ validation (torch.equal instead of torch.allclose). Eliminates the "bounded garbage
+ passes tolerance check" failure mode for matmul kernels entirely.
  """

  import random
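Both TODOs describe checks that are easy to sketch. With {0,1} inputs, every output element of a matmul is an integer count bounded by K, and FP16 represents integers up to 2048 exactly, so reference and candidate outputs can be compared bit-for-bit with torch.equal. A guard-buffer check allocates sentinel-filled padding around the output and verifies the kernel never wrote into it. The sketch below is illustrative only; names, shapes, and the sentinel value are assumptions, not CUDA-L2's actual code.

    # Illustrative sketches of the two TODOs above (names and shapes are hypothetical).
    import torch

    GUARD = 1234.0  # sentinel exactly representable in FP16


    def exact_gemm_check(candidate_matmul, m=256, n=256, k=1024, device="cuda") -> bool:
        # {0,1} inputs: every partial sum is an integer <= k <= 2048, so FP16 holds it
        # exactly and torch.equal (zero tolerance) replaces torch.allclose.
        a = (torch.rand(m, k, device=device) < 0.5).half()
        b = (torch.rand(k, n, device=device) < 0.5).half()
        expected = (a.float() @ b.float()).half()  # exact integer-valued reference
        return torch.equal(candidate_matmul(a, b), expected)


    def guard_buffer_check(run_kernel, out_shape, pad=1024, device="cuda") -> bool:
        # Surround the output with sentinel-filled guard regions; any out-of-bounds
        # write by the kernel disturbs them.
        n = int(torch.tensor(out_shape).prod())
        buf = torch.full((pad + n + pad,), GUARD, device=device, dtype=torch.float16)
        out = buf[pad:pad + n].view(out_shape)
        run_kernel(out)  # the kernel is expected to write only into `out`
        return bool((buf[:pad] == GUARD).all()) and bool((buf[-pad:] == GUARD).all())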
@@ -21,6 +21,12 @@ if TYPE_CHECKING:
      from wafer_core.utils.kernel_utils.deployment import DeploymentConfig


+ # TODO: Split BaremetalTarget into BaremetalTarget (persistent servers like Vultr,
+ # never auto-removed) and SSHTarget (ephemeral SSH endpoints from providers like
+ # RunPod/DO, safe to auto-clean when unreachable). Currently the pool bridge creates
+ # ephemeral pod endpoints as type="baremetal", losing provenance. SSHTarget should
+ # subclass BaremetalTarget so existing isinstance() checks still work. The `provider`
+ # field is a stopgap until this split happens.
  @dataclass(frozen=True)
  class BaremetalTarget:
      """Configuration for baremetal GPU server.
@@ -59,6 +65,9 @@ class BaremetalTarget:
      gpu_type: str = "B200"
      compute_capability: str = "10.0"
      ncu_available: bool = True # Baremetal typically has NCU
+     provider: str | None = (
+         None # Source provider ("runpod", "digitalocean") — enables auto-cleanup when instance is gone
+     )

      # Docker execution config (Modal-like). If docker_image is set, run in container.
      docker_image: str | None = (
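The provider field added above is the stopgap named in the earlier TODO. A minimal sketch of the BaremetalTarget/SSHTarget split that the TODO proposes, using the names it suggests; SSHTarget is hypothetical and does not exist in this release.

    # Hypothetical sketch of the BaremetalTarget/SSHTarget split described in the TODO.
    from dataclasses import dataclass


    @dataclass(frozen=True)
    class SSHTarget(BaremetalTarget):
        """Ephemeral SSH endpoint (e.g. RunPod/DO pods); safe to auto-clean when unreachable."""

        ephemeral: bool = True


    # isinstance(target, BaremetalTarget) checks keep matching SSHTarget instances,
    # while cleanup code can gate auto-removal on isinstance(target, SSHTarget).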
@@ -314,6 +323,7 @@ class RunPodTarget:
      # apt-get install --reinstall -y rocthrust
      # See docker/rocm7-runpod/README.md for details.
      image: str = "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1"
+     template_id: str | None = None # RunPod template ID for custom pod configuration

      # RunPod template ID — required for non-RunPod images that need custom
      # dockerArgs (e.g. to install and start sshd). When set, takes priority
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: wafer-core
- Version: 0.1.24
+ Version: 0.1.26
  Summary: Core utilities and environments for Wafer GPU kernel optimization
  Requires-Python: >=3.10
  Requires-Dist: aiohttp>=3.9.0