wafer-core 0.1.25__py3-none-any.whl → 0.1.26__py3-none-any.whl
- wafer_core/lib/trace_compare/__init__.py +32 -0
- wafer_core/lib/trace_compare/analyzer.py +339 -0
- wafer_core/lib/trace_compare/classifier.py +192 -0
- wafer_core/lib/trace_compare/formatter.py +951 -0
- wafer_core/lib/trace_compare/fusion_analyzer.py +890 -0
- wafer_core/lib/trace_compare/loader.py +336 -0
- wafer_core/problem_config.py +3 -3
- wafer_core/rollouts/agent_presets/rlm_01_01.py +2 -2
- wafer_core/rollouts/dtypes.py +18 -3
- wafer_core/rollouts/providers/anthropic.py +35 -3
- wafer_core/utils/kernel_utils/defense.py +10 -0
- wafer_core/utils/kernel_utils/targets/config.py +10 -0
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.26.dist-info}/METADATA +1 -1
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.26.dist-info}/RECORD +15 -9
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.26.dist-info}/WHEEL +0 -0
wafer_core/lib/trace_compare/loader.py
ADDED
@@ -0,0 +1,336 @@
+"""Trace loading and parsing logic.
+
+Loads JSON trace files from AMD/NVIDIA profilers and extracts kernel execution data,
+Python call stacks, CPU operator mappings, and layer correlations.
+"""
+
+import bisect
+import json
+from collections import defaultdict
+from pathlib import Path
+from typing import Any
+
+import pandas as pd
+
+from .classifier import classify
+
+
+def extract_layer_mapping(events: list[dict[str, Any]], platform: str) -> dict[int, int]:
+    """Extract correlation ID to layer number mapping.
+
+    vLLM's execution graph creates large correlation groups for full transformer layers.
+    Each layer's forward pass (norm + attention + FFN) gets grouped under one correlation ID,
+    containing 200-400 kernels depending on batch size and sequence length.
+
+    We identify layers as correlation groups with many kernels (70+), which filters out
+    individual operations like sampling, logit processing, etc.
+
+    Args:
+        events: List of trace events
+        platform: 'AMD' or 'NVIDIA'
+
+    Returns:
+        Dict mapping correlation ID to layer number
+    """
+    # Group kernels by correlation ID
+    correlation_groups = defaultdict(
+        lambda: {"count": 0, "has_attention": False, "has_ffn": False}
+    )
+
+    for ev in events:
+        if ev.get("cat") != "kernel":
+            continue
+
+        corr_id = ev.get("args", {}).get("correlation")
+        if corr_id is None:
+            continue
+
+        kernel_name = ev.get("name", "").lower()
+
+        # Track what operations this correlation contains
+        correlation_groups[corr_id]["count"] += 1
+        if "attention" in kernel_name or "fmha" in kernel_name:
+            correlation_groups[corr_id]["has_attention"] = True
+        if any(x in kernel_name for x in ["cijk_", "nvjet", "wvsplitk", "gemm"]):
+            correlation_groups[corr_id]["has_ffn"] = True
+
+    # Map correlation IDs to layer numbers
+    # Transformer layers have many kernels AND contain both attention and FFN ops
+    correlation_to_layer = {}
+    layer_num = 0
+
+    for corr_id in sorted(correlation_groups.keys()):
+        group = correlation_groups[corr_id]
+
+        # Identify complete transformer layers by their characteristics:
+        # - Has attention operations (self-attention or cross-attention)
+        # - Has FFN operations (feed-forward network)
+        # - Has sufficient kernel count (70+): typical transformer block has ~80-100 kernels
+        #   including attention QKV projections, softmax, output projection, FFN layers,
+        #   normalization, and elementwise ops. This threshold filters out:
+        #   - Individual operations (1-10 kernels)
+        #   - Sampling/generation steps (20-40 kernels)
+        #   - Partial layer executions
+        is_layer = (
+            group["count"] >= 70 and group["has_attention"] and group["has_ffn"]
+        )
+
+        if is_layer:
+            correlation_to_layer[corr_id] = layer_num
+            layer_num += 1
+
+    return correlation_to_layer
+
+
+def _build_python_stack_index(
+    events: list[dict[str, Any]],
+) -> tuple[list[tuple[int, int, int, int | None, str]], dict[int, dict[str, Any]]]:
+    """Build Python call stack index for kernels.
+
+    Args:
+        events: List of trace events
+
+    Returns:
+        Tuple of (python_intervals, python_by_id)
+    """
+    python_by_id: dict[int, dict[str, Any]] = {}
+    python_intervals: list[tuple[int, int, int, int | None, str]] = []
+
+    for ev in events:
+        if ev.get("cat") == "python_function":
+            py_id = ev.get("args", {}).get("Python id")
+            name = ev["name"]
+            ts_start = ev["ts"]
+            ts_end = ts_start + ev.get("dur", 0)
+            duration = ev.get("dur", 0)
+            parent_id = ev.get("args", {}).get("Python parent id")
+
+            python_intervals.append((ts_start, ts_end, duration, py_id, name))
+
+            if py_id is not None:
+                python_by_id[py_id] = {
+                    "name": name,
+                    "parent_id": parent_id,
+                    "ts_start": ts_start,
+                    "ts_end": ts_end,
+                    "duration": duration,
+                }
+
+    # Sort by start time for efficient binary search
+    python_intervals.sort()
+
+    return python_intervals, python_by_id
+
+
+def _get_python_stack_full(
+    timestamp: int,
+    python_intervals: list[tuple[int, int, int, int | None, str]],
+    python_by_id: dict[int, dict[str, Any]],
+) -> tuple[str | None, list[str]]:
+    """Get full Python call stack for a kernel launch.
+
+    Args:
+        timestamp: Kernel launch timestamp
+        python_intervals: Sorted list of Python function intervals
+        python_by_id: Mapping of Python ID to function info
+
+    Returns:
+        Tuple of (summary_string, full_stack_list)
+    """
+    # Binary search for Python functions active at this timestamp
+    idx = bisect.bisect_right(
+        python_intervals, (timestamp, float("inf"), float("inf"), None, "")
+    )
+
+    # Find active functions
+    active_funcs = []
+    for i in range(idx - 1, max(0, idx - 1000), -1):
+        ts_start, ts_end, duration, py_id, name = python_intervals[i]
+        if ts_start <= timestamp <= ts_end:
+            active_funcs.append((duration, py_id, name))
+        if ts_end < timestamp - 1000000:  # 1 second before
+            break
+
+    if not active_funcs:
+        return None, []
+
+    # Get the innermost (most specific) function
+    active_funcs.sort()
+    leaf_duration, leaf_id, leaf_name = active_funcs[0]
+
+    # Walk up parent chain to get FULL stack
+    full_stack = []
+    current_id = leaf_id
+    visited = set()
+
+    while (
+        current_id is not None
+        and current_id not in visited
+        and current_id in python_by_id
+    ):
+        func = python_by_id[current_id]
+        name = func["name"]
+        full_stack.append(name)
+
+        visited.add(current_id)
+        current_id = func["parent_id"]
+
+        # Safety limit: prevent infinite loops from circular parent references
+        # and bound memory usage. 50 frames is deeper than typical Python stacks.
+        if len(full_stack) >= 50:
+            break
+
+    # Reverse so it's outermost -> innermost
+    full_stack.reverse()
+
+    # Create summary for text output: show the most informative vLLM/model function
+    summary = None
+    vllm_funcs = [
+        f
+        for f in full_stack
+        if any(x in f.lower() for x in ["vllm/", "model", "<eval_with_key>"])
+    ]
+
+    if vllm_funcs:
+        # Get innermost vLLM function (most specific)
+        summary = vllm_funcs[-1]
+
+        # Check if it's a CUDA graph - add annotation
+        if any("torch/cuda/graphs" in f for f in full_stack):
+            # Shorten if too long
+            if len(summary) > 45:
+                parts = summary.split("/")[-1]
+                summary = "vllm/..." + parts
+            summary = f"{summary} [CUDA graph]"
+        elif len(summary) > 53:
+            parts = summary.split("/")[-1]
+            summary = "vllm/..." + parts
+    else:
+        # Fallback to innermost function
+        summary = leaf_name
+
+    return summary, full_stack
+
+
+def load_trace(
+    file_path: str | Path,
+) -> tuple[str, str, dict[str, Any], pd.DataFrame, dict[tuple[str, str], set[str]], dict[int, int]]:
+    """Load trace and return platform info, device properties, kernels, patterns, and layer mapping.
+
+    Args:
+        file_path: Path to JSON trace file
+
+    Returns:
+        Tuple of (platform, gpu_name, device_props, kernel_df, kernel_patterns, layer_mapping)
+    """
+    with open(file_path, "rb") as f:
+        trace = json.load(f)
+
+    props = trace.get("deviceProperties", [{}])[0]
+    is_amd = trace.get("roctracer_version") or props.get("warpSize") == 64
+    platform = "AMD" if is_amd else "NVIDIA"
+    gpu_name = props.get("name", "MI300X" if is_amd else "Unknown GPU")
+
+    # Extract relevant device properties
+    device_props = {
+        "name": gpu_name,
+        "compute_capability": f"{props.get('computeMajor', 0)}.{props.get('computeMinor', 0)}",
+        "total_memory_gb": props.get("totalGlobalMem", 0) / (1024**3),
+        "sm_count": props.get("numSms", 0),
+        "warp_size": props.get("warpSize", 32),
+        "max_threads_per_block": props.get("maxThreadsPerBlock", 0),
+        "shared_mem_per_block_kb": props.get("sharedMemPerBlock", 0) / 1024,
+    }
+
+    events = trace.get("traceEvents", [])
+
+    # Build mapping: external_id -> CPU operator name
+    external_to_cpu = {}
+    for ev in events:
+        if ev.get("cat") == "cpu_op":
+            ext_id = ev.get("args", {}).get("External id")
+            cpu_op_name = ev.get("name", "")
+            if ext_id is not None:
+                external_to_cpu[ext_id] = cpu_op_name
+
+    # Build Python call stack index for kernels without External IDs
+    python_intervals, python_by_id = _build_python_stack_index(events)
+
+    # Extract phases
+    phases = []
+    for ev in events:
+        if ev.get("cat") == "user_annotation" and ev.get("name", "").startswith(
+            "execute_context"
+        ):
+            name = ev["name"]
+            # Parse execute_context_X(TOKENS)_generation_Y(Y)
+            # We want the TOKENS from execute_context, not the generation number
+            tokens = 0
+            parts = name.split("_")
+            for i, p in enumerate(parts):
+                # Look for execute_context_X(TOKENS) specifically
+                if i > 0 and parts[i-1] == "context" and "(" in p and ")" in p:
+                    try:
+                        tokens = int(p.split("(")[1].split(")")[0])
+                        break  # Stop after finding context tokens
+                    except Exception:
+                        pass
+            is_prefill = tokens >= 1024 and "generation_0" in name
+            phases.append(
+                {
+                    "type": "prefill" if is_prefill else "decode",
+                    "ts_start": ev["ts"],
+                    "ts_end": ev["ts"] + ev["dur"],
+                }
+            )
+
+    # Extract layer mapping from correlation IDs
+    layer_mapping = extract_layer_mapping(events, platform)
+
+    kernel_data = []
+    kernel_patterns: dict[tuple[str, str], set[str]] = defaultdict(set)
+
+    for ev in events:
+        if ev.get("cat") != "kernel":
+            continue
+        name, dur, ts = ev["name"], ev.get("dur", 0), ev["ts"]
+        corr_id = ev.get("args", {}).get("correlation")
+        ext_id = ev.get("args", {}).get("External id")
+
+        phase = "decode"
+        for p in phases:
+            if p["ts_start"] <= ts <= p["ts_end"]:
+                phase = p["type"]
+                break
+
+        op, pattern = classify(name, platform)
+        kernel_patterns[(op.value, phase)].add(pattern)
+
+        # Assign layer number from correlation ID
+        layer = layer_mapping.get(corr_id) if corr_id is not None else None
+
+        # Get CPU operator name from external ID, or fallback to Python stack
+        cpu_op = external_to_cpu.get(ext_id) if ext_id is not None else None
+        python_stack: list[str] = []
+
+        # If no CPU op via External ID, try Python stack trace
+        if cpu_op is None:
+            cpu_op, python_stack = _get_python_stack_full(
+                ts, python_intervals, python_by_id
+            )
+
+        kernel_data.append(
+            {
+                "name": name,
+                "dur_us": dur,
+                "phase": phase,
+                "op": op.value,
+                "pattern": pattern,
+                "layer": layer,
+                "correlation": corr_id,
+                "cpu_op": cpu_op,
+                "python_stack": python_stack,  # Full stack for JSON output
+            }
+        )
+
+    return platform, gpu_name, device_props, pd.DataFrame(kernel_data), dict(kernel_patterns), layer_mapping
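For orientation, a minimal usage sketch of the new loader. The trace path is a placeholder (not a file shipped with the package); the DataFrame columns are the keys of the kernel_data dicts built above:

    from wafer_core.lib.trace_compare.loader import load_trace

    platform, gpu_name, props, kernels, patterns, layers = load_trace("trace.json")

    # GPU summary from deviceProperties
    print(platform, gpu_name, props["sm_count"], props["warp_size"])
    # Kernel time broken down by phase (prefill/decode) and classified op
    print(kernels.groupby(["phase", "op"])["dur_us"].sum())
    # Correlation groups recognized as full transformer layers
    print(f"{len(layers)} correlation groups mapped to transformer layers")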
wafer_core/problem_config.py
CHANGED
@@ -84,7 +84,7 @@ class ProblemConfig:
     benchmarks: list[dict[str, Any]]
 
     # Optional with defaults
-    model: str = "claude-
+    model: str = "claude-opus-4-5-20251101"
     temperature: float = 0.2
     max_tokens: int = 8192
     max_turns: int = 10
@@ -219,7 +219,7 @@ def _parse_config(data: dict[str, Any], base_dir: Path) -> tuple[ProblemConfig |
         reference_code=reference_code,
         tests=tests,
         benchmarks=benchmarks,
-        model=data.get("model", "claude-
+        model=data.get("model", "claude-opus-4-5-20251101"),
         temperature=data.get("temperature", 0.2),
         max_tokens=data.get("max_tokens", 8192),
         max_turns=data.get("max_turns", 10),
@@ -269,7 +269,7 @@ def create_problem_config_from_cli(
         reference_code=reference_code,
         tests=tests,
         benchmarks=benchmarks or tests,  # Use tests as benchmarks if not specified
-        model=kwargs.get("model", "claude-
+        model=kwargs.get("model", "claude-opus-4-5-20251101"),
        temperature=kwargs.get("temperature", 0.2),
         max_tokens=kwargs.get("max_tokens", 8192),
         max_turns=kwargs.get("max_turns", 10),

wafer_core/rollouts/agent_presets/rlm_01_01.py
CHANGED
@@ -119,7 +119,7 @@ FINAL(42)
 
 config = AgentPresetConfig(
     name="rlm",
-    model="anthropic/claude-
+    model="anthropic/claude-opus-4-5-20251101",
     env="repl",  # Uses REPLEnvironment
     thinking=True,
     system_prompt=RLM_TOOL_SYSTEM_PROMPT,
@@ -128,7 +128,7 @@ config = AgentPresetConfig(
 # Variant for message-parsing mode
 config_block_mode = AgentPresetConfig(
     name="rlm_blocks",
-    model="anthropic/claude-
+    model="anthropic/claude-opus-4-5-20251101",
     env="repl_blocks",  # Uses MessageParsingREPLEnvironment
     thinking=True,
     system_prompt=RLM_BLOCK_SYSTEM_PROMPT,
wafer_core/rollouts/dtypes.py
CHANGED
@@ -1238,6 +1238,12 @@ class Endpoint(JsonSerializable):
     api_base: str = ""
     api_key: str = ""
     oauth_token: str = ""  # OAuth bearer token (takes precedence over api_key for Anthropic)
+    # TODO: Callbacks on a frozen dataclass are a code smell. This exists because wafer-core
+    # can't depend on wafer-cli (where the Supabase refresh logic lives). A cleaner approach
+    # would be a TokenProvider protocol that Endpoint delegates to, keeping the dataclass pure.
+    api_key_refresh: Callable[[], Awaitable[str | None]] | None = field(
+        default=None, repr=False, compare=False
+    )
     is_claude_code_api_key: bool = (
         False  # API key created via Claude Code OAuth (requires special headers)
     )
@@ -1300,6 +1306,7 @@ class Endpoint(JsonSerializable):
             exclude_secrets: If True (default), omits api_key and oauth_token.
         """
        d = asdict(self)
+        d.pop("api_key_refresh", None)  # Callable, not serializable
         if exclude_secrets:
             d.pop("api_key", None)
             d.pop("oauth_token", None)
@@ -1307,7 +1314,11 @@ class Endpoint(JsonSerializable):
 
     @classmethod
     def from_dict(
-        cls,
+        cls,
+        data: dict[str, Any],
+        api_key: str = "",
+        oauth_token: str = "",
+        api_key_refresh: "Callable[[], Awaitable[str | None]] | None" = None,
     ) -> "Endpoint":
         """Deserialize from dict, injecting secrets at runtime.
 
@@ -1315,12 +1326,16 @@ class Endpoint(JsonSerializable):
             data: Dict from to_dict()
             api_key: API key to inject (not stored in session)
             oauth_token: OAuth token to inject (not stored in session)
+            api_key_refresh: Callback to refresh api_key mid-session (not stored)
         """
-        # Remove secrets if present (they shouldn't be, but be safe)
+        # Remove secrets/callables if present (they shouldn't be, but be safe)
         data = data.copy()
         data.pop("api_key", None)
         data.pop("oauth_token", None)
-
+        data.pop("api_key_refresh", None)
+        return cls(
+            **data, api_key=api_key, oauth_token=oauth_token, api_key_refresh=api_key_refresh
+        )
 
 
 @dataclass(frozen=True)
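A minimal sketch of the TokenProvider protocol the TODO above floats as the cleaner design; this is an assumption about a possible future shape, not an API in this release:

    from typing import Protocol

    class TokenProvider(Protocol):
        """Supplies a fresh API key on demand (hypothetical interface)."""

        async def fresh_api_key(self) -> str | None: ...

    # Endpoint would delegate to a TokenProvider instead of storing a raw callable,
    # keeping the frozen dataclass free of function-typed fields and making
    # to_dict()/from_dict() symmetric without the pop() special case above.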
wafer_core/rollouts/providers/anthropic.py
CHANGED
@@ -725,9 +725,16 @@ async def rollout_anthropic(
             oauth_token = fresh_token
         # If refresh failed, continue with existing token - it might still work
 
+    # Get fresh wafer proxy token if refresh callback is available
+    api_key = actor.endpoint.api_key
+    if actor.endpoint.api_key_refresh:
+        fresh_key = await actor.endpoint.api_key_refresh()
+        if fresh_key:
+            api_key = fresh_key
+
     client = _create_anthropic_client(
         oauth_token=oauth_token,
-        api_key=
+        api_key=api_key,
         api_base=actor.endpoint.api_base,
         max_retries=actor.endpoint.max_retries,
         timeout=actor.endpoint.timeout,
@@ -973,7 +980,7 @@ async def rollout_anthropic(
                     f"Model not found: {e}\nCheck your model ID is correct."
                 ) from e
 
-            #
+            # Try to refresh token and retry once on auth errors
            if isinstance(e, anthropic.AuthenticationError):
                 if oauth_token and attempt == 0:
                     # Emit retry event for OAuth refresh
@@ -993,12 +1000,37 @@ async def rollout_anthropic(
                     await client.close()
                     client = _create_anthropic_client(
                         oauth_token=oauth_token,
-                        api_key=
+                        api_key=api_key,
                         api_base=actor.endpoint.api_base,
                         max_retries=actor.endpoint.max_retries,
                         timeout=actor.endpoint.timeout,
                     )
                     continue
+
+                # Wafer proxy token refresh (Supabase JWTs expire after ~1hr)
+                if actor.endpoint.api_key_refresh and attempt == 0:
+                    await on_chunk(
+                        RetryStart(
+                            attempt=1,
+                            max_attempts=2,
+                            delay_seconds=0,
+                            error_message="Wafer proxy token expired, refreshing",
+                            provider="anthropic",
+                        )
+                    )
+                    fresh_key = await actor.endpoint.api_key_refresh()
+                    if fresh_key and fresh_key != api_key:
+                        api_key = fresh_key
+                        await client.close()
+                        client = _create_anthropic_client(
+                            oauth_token=oauth_token,
+                            api_key=api_key,
+                            api_base=actor.endpoint.api_base,
+                            max_retries=actor.endpoint.max_retries,
+                            timeout=actor.endpoint.timeout,
+                        )
+                        continue
+
             raise FatalEvalError(
                 f"Authentication failed: {e}\nCheck your API key or OAuth token."
             ) from e
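For context, a hedged sketch of how a caller might wire the new hook through Endpoint.from_dict. The refresh body is a stub and the endpoint dict is illustrative; the real Supabase exchange lives in wafer-cli, outside this package:

    from wafer_core.rollouts.dtypes import Endpoint

    async def refresh_wafer_token() -> str | None:
        # Stub: exchange a stored refresh token for a fresh short-lived JWT.
        return None

    endpoint = Endpoint.from_dict(
        {"api_base": "https://proxy.example/v1"},  # illustrative to_dict() output
        api_key="initial-jwt",                     # injected at runtime, never stored
        api_key_refresh=refresh_wafer_token,
    )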
wafer_core/utils/kernel_utils/defense.py
CHANGED
@@ -12,6 +12,16 @@ Attack types defended against:
 5. Monkey-patching - Replacing CUDA timing functions with fake implementations
 
 Reference: "Hacks and Defenses in Automatic GPU Kernel Generation" by Jiwei Li (Dec 2025)
+
+TODO: Memory guard buffers (from CUDA-L2's zero_one_correctness_check.py) — wrap
+input/output tensors with guard regions and check for out-of-bounds writes after
+kernel execution. Catches shared memory overflow and buffer overrun at the memory
+boundary, rather than inferring from output non-determinism.
+
+TODO: Exact correctness for GEMM kernels (from CUDA-L2) — use {0,1} input matrices
+where FP16 results ≤2048 are exactly representable, enabling zero-tolerance
+validation (torch.equal instead of torch.allclose). Eliminates the "bounded garbage
+passes tolerance check" failure mode for matmul kernels entirely.
 """
 
 import random
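The second TODO is concrete enough to sketch. With {0,1} inputs every partial sum is an integer, and FP16 represents integers exactly up to 2048 (11-bit significand), so for K <= 2048 a correct kernel must match the reference bit for bit. A minimal sketch, assuming a kernel_fn with an (A, B) -> C matmul signature:

    import torch

    def exact_gemm_check(kernel_fn, m=1024, n=1024, k=1024) -> bool:
        assert k <= 2048, "keep integer dot products exactly representable in FP16"
        a = torch.randint(0, 2, (m, k), device="cuda").half()  # {0,1} inputs
        b = torch.randint(0, 2, (k, n), device="cuda").half()
        ref = (a.float() @ b.float()).half()  # integer-valued, exact after the cast
        out = kernel_fn(a, b)
        # Zero tolerance: any mismatch is a real bug, not rounding noise.
        return torch.equal(out, ref)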
wafer_core/utils/kernel_utils/targets/config.py
CHANGED
@@ -21,6 +21,12 @@ if TYPE_CHECKING:
     from wafer_core.utils.kernel_utils.deployment import DeploymentConfig
 
 
+# TODO: Split BaremetalTarget into BaremetalTarget (persistent servers like Vultr,
+# never auto-removed) and SSHTarget (ephemeral SSH endpoints from providers like
+# RunPod/DO, safe to auto-clean when unreachable). Currently the pool bridge creates
+# ephemeral pod endpoints as type="baremetal", losing provenance. SSHTarget should
+# subclass BaremetalTarget so existing isinstance() checks still work. The `provider`
+# field is a stopgap until this split happens.
 @dataclass(frozen=True)
 class BaremetalTarget:
     """Configuration for baremetal GPU server.
@@ -59,6 +65,9 @@ class BaremetalTarget:
     gpu_type: str = "B200"
     compute_capability: str = "10.0"
     ncu_available: bool = True  # Baremetal typically has NCU
+    provider: str | None = (
+        None  # Source provider ("runpod", "digitalocean") — enables auto-cleanup when instance is gone
+    )
 
     # Docker execution config (Modal-like). If docker_image is set, run in container.
     docker_image: str | None = (
@@ -314,6 +323,7 @@ class RunPodTarget:
     # apt-get install --reinstall -y rocthrust
     # See docker/rocm7-runpod/README.md for details.
     image: str = "rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.7.1"
+    template_id: str | None = None  # RunPod template ID for custom pod configuration
 
     # RunPod template ID — required for non-RunPod images that need custom
     # dockerArgs (e.g. to install and start sshd). When set, takes priority
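For illustration, the shape of the split the TODO above describes; a hypothetical sketch, not code in this release:

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class BaremetalTarget:
        """Persistent server (e.g. Vultr); never auto-removed."""

        host: str
        gpu_type: str = "B200"

    @dataclass(frozen=True)
    class SSHTarget(BaremetalTarget):
        """Ephemeral SSH endpoint from a pool provider; safe to auto-clean."""

        provider: str | None = None  # "runpod", "digitalocean", ...

    def maybe_auto_cleanup(target: BaremetalTarget) -> bool:
        # Existing isinstance(target, BaremetalTarget) checks keep working;
        # only the ephemeral subclass opts in to auto-cleanup.
        return isinstance(target, SSHTarget)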
{wafer_core-0.1.25.dist-info → wafer_core-0.1.26.dist-info}/RECORD
CHANGED
@@ -3,7 +3,7 @@ wafer_core/async_ssh.py,sha256=ocw2Gh5p8ltKeoqG_q32DXOBfu5q-IE7jCnzMbQN9WI,28713
 wafer_core/auth.py,sha256=JpUkZ3bROIsgexayak5TLiGqUAR5kqGjekwqQRvIXH0,7235
 wafer_core/gpu.py,sha256=ENa92btjXsx6ldpoyKfRrAmfy-LHG2KpA5k7SWd6Q_s,28627
 wafer_core/gpu_detect.py,sha256=kpD8Q_G6GA9j-WnnnTNA3BBPulkGcWnZiogOmjKDao0,13650
-wafer_core/problem_config.py,sha256=
+wafer_core/problem_config.py,sha256=IM4ZRul4306dF7yo8wwyxXYORUZ7nz5wnphG59HN6fo,10907
 wafer_core/remote_env.py,sha256=0ACTL-A_qn2B43qgQakqGaern-pspvwBGB9iebz199k,15354
 wafer_core/remote_jobs.py,sha256=7HdBDCigSxfp32BreWoljzG5xjK6fp25rwC_6D7D04s,8306
 wafer_core/retry.py,sha256=OIvSElJZbSm4-SFBpOFuYtoX2DWGiANomCmb3qxsirM,14821
@@ -318,6 +318,12 @@ wafer_core/lib/rocprofiler/systems/run/analyzer.py,sha256=Qg3M8-kCKdV82ehn6Ta20N
 wafer_core/lib/rocprofiler/systems/run/profiler.py,sha256=aiQLsDnfQHSeCM5zLnO4VlbTmREYnAtiuT50Eq6uWfg,8387
 wafer_core/lib/rocprofiler/systems/sample/__init__.py,sha256=31rNmLPQ7OVhvlOEEOwPKgk8_qrCidj6AmzDXexQJ_o,288
 wafer_core/lib/rocprofiler/systems/sample/profiler.py,sha256=CYZPTzNXd48LoCfmY6h_5RSYEdWYccuv3-t4YncHJLE,7384
+wafer_core/lib/trace_compare/__init__.py,sha256=G5vmiQnuweiF9vjK1FC4ZIy-tzuHiaLMs7QBnir8OJw,800
+wafer_core/lib/trace_compare/analyzer.py,sha256=o0SI1PsehpgxeUPQEB9708W_Q_ILiO5apgqVLe2xE8A,14541
+wafer_core/lib/trace_compare/classifier.py,sha256=sE1K007GVk_Up2g59SVUIZ7BThf0yHNjGsZ9AyM_Ah8,6028
+wafer_core/lib/trace_compare/formatter.py,sha256=GNrCZ45ueBN05CEXjOtTuKvTI8z-g-ZZFil-ni3sWVY,37962
+wafer_core/lib/trace_compare/fusion_analyzer.py,sha256=LwYTBjL_gHCvydfgFp-L9f_qfXq3GenJHRemygly4H8,36482
+wafer_core/lib/trace_compare/loader.py,sha256=E7-OS4uMqvJhGLyxKQNnAgK33YECrSjuCssUT_X0LQA,11728
 wafer_core/lib/tracelens/__init__.py,sha256=AkHdmOnKlBO4RpsAqVVGe7MOfv6E6uhEaC_iKrYeMPI,2002
 wafer_core/lib/tracelens/comparator.py,sha256=71YEPfjBi7_24u1oQuPerNtFsN0sDQ5CT_uBi0XLllw,3460
 wafer_core/lib/tracelens/finder.py,sha256=HpbN8TuRNbbBytPYOmkBkfsFVBReQqVgsvFX-mBrln4,2459
@@ -336,7 +342,7 @@ wafer_core/rollouts/agents.py,sha256=Uv1kjYogUfdPl18YfkVxVqFTbmWfuJQrxem_iHTUgdw
 wafer_core/rollouts/cli.py,sha256=2NqgegKdlmxD0eJzGOMB5o_1Hb5t7O5JpP_32uvF2BE,80117
 wafer_core/rollouts/cli_agents.py,sha256=e4qqqYBzWLsbw8FsNnddGApWp_on9Cvzrfd1amiAyvI,20641
 wafer_core/rollouts/deploy.py,sha256=3t88fM_BMyAPkxIl8pS4r5ogHJvrlqWQDuIaltDZBRc,40924
-wafer_core/rollouts/dtypes.py,sha256=
+wafer_core/rollouts/dtypes.py,sha256=oRWjpbUOTf4uyXvnO9QThcSzD1fBrDQnAfRhGbxdgrg,61916
 wafer_core/rollouts/eval_helpers.py,sha256=OE7uQZRcbqQhpFqb4zOj8zafc9Gr6xZJpSrMvxXKVUw,1699
 wafer_core/rollouts/evaluation.py,sha256=fk-pGZ5vpocVmw1iBbHtxMK0K6l8pYTLHCpDNvRY1Xo,69142
 wafer_core/rollouts/events.py,sha256=z85J8kq0LXPj5CiUk4RkiTQg--r9xiO7QeeJwkyUOto,7505
@@ -371,7 +377,7 @@ wafer_core/rollouts/agent_presets/gpt_5_1_codex_04_04.py,sha256=42NIBBYAnVoy5mbu
 wafer_core/rollouts/agent_presets/gpt_5_2_03_03.py,sha256=lEsHRUhhr8UbP5wSVKMOVDVOOtH_bQMRRgZ0dRGZMVc,1166
 wafer_core/rollouts/agent_presets/loader.py,sha256=WSkTbL7QhgMamZR5sXxep1n4cuy8LC3a4MN2phYTm-4,3666
 wafer_core/rollouts/agent_presets/opus_4_01_01.py,sha256=rurZMI-Df7O-Q-uVJj2zfY_DSjdNbMKBDZlRg9MLADc,3568
-wafer_core/rollouts/agent_presets/rlm_01_01.py,sha256=
+wafer_core/rollouts/agent_presets/rlm_01_01.py,sha256=jsjwDgACQxxJj4GYOUCcQvYjcICAaKV3eccQu9oyEcw,4781
 wafer_core/rollouts/agent_presets/sonnet_4_02_02.py,sha256=ZdHNxioki3wsfD6ficgB2r7HkgQDH_trCR-baGFgoHk,1269
 wafer_core/rollouts/agent_presets/sonnet_4_subagent_03_02.py,sha256=nxyjs4HWAPOAYLmPknSQr3viBXhboKC7wQ76LWB-jA0,2165
 wafer_core/rollouts/config/README.md,sha256=i0r0a3sKLkc1Eq3EqqR2Gahsgo-c8O3OZ0cCh7rp8Uw,9899
@@ -495,7 +501,7 @@ wafer_core/rollouts/prompt_optimization/adapters/system_prompt.py,sha256=CWFox1N
 wafer_core/rollouts/prompt_optimization/adapters/system_user_prompt.py,sha256=8JsSirihgZ5gacyYhn31GagyIxG0xQ7f7i4PnEupWz8,12090
 wafer_core/rollouts/prompt_optimization/adapters/terminal_bench.py,sha256=Etswuqf5dBIZQ2x2p26AXz4LT33YxT2qEeHqKXTJy18,12273
 wafer_core/rollouts/providers/__init__.py,sha256=Xu8PPDHOmF97ylMJXfE9JX2FJCanNVh7LXkHMmg0vWs,3121
-wafer_core/rollouts/providers/anthropic.py,sha256=
+wafer_core/rollouts/providers/anthropic.py,sha256=9x1GIL6JE8gutxVrLNiyAkymknIEKtl-98TnIUpFxoI,45223
 wafer_core/rollouts/providers/base.py,sha256=2ADu6pDz6yEcazo4j6-O12rs19bPewAfycjK_N03ZkY,14544
 wafer_core/rollouts/providers/google.py,sha256=IbqdXOpzSuMdI7eOZqRtzni85ysKby13PGe482Fq13w,22073
 wafer_core/rollouts/providers/openai_completions.py,sha256=3vUA74qjrxG-aOjyngtnZp0MzIhnzW5kudwxmOGxXfs,28820
@@ -655,7 +661,7 @@ wafer_core/utils/remote_execution.py,sha256=z7nLiOgmDiM_VmElLnT2LF-aKNeeKFYjXigT
 wafer_core/utils/submission_selection.py,sha256=LucdMTAbkqZA-GitSb3ZJ2pAeJ36wKqt5cTeS8xuAQ4,5655
 wafer_core/utils/kernel_utils/__init__.py,sha256=NsfKpbfpIsfupWIpIjWLGCjGAVqaONiwiWil5zXbrRc,2015
 wafer_core/utils/kernel_utils/backends.py,sha256=t3wY73Y-pVc_wALNu_bPsaFkqJ2dp2pf38KQ5ofP_go,1143
-wafer_core/utils/kernel_utils/defense.py,sha256=
+wafer_core/utils/kernel_utils/defense.py,sha256=8tHVTZlJfFcB_FWjNZfeGHwReSjG191OmFXtWXa07OM,20124
 wafer_core/utils/kernel_utils/deployment.py,sha256=-tMb3qWmAoXHWCmmT7SQBH7KBKyyLP0e5Dk6lOrTPW8,55957
 wafer_core/utils/kernel_utils/evaluate.py,sha256=1kxFNMl9VCXfKfk_BIiuA_zFfvDB1sl_feS2OEIJA1k,72346
 wafer_core/utils/kernel_utils/gpu_validation.py,sha256=LRiDjW_xAK4fXf1Vw1aYHG54B1W0J6b5L0K6PXzM2tI,3759
@@ -665,7 +671,7 @@ wafer_core/utils/kernel_utils/static_checker.py,sha256=XIQkzAOkGH5xtrOuZM4tNUqVJ
 wafer_core/utils/kernel_utils/task.py,sha256=XcmKxKUWh5It6nX3zGqj77tWgA32uPfQMqNOqyD5T48,2682
 wafer_core/utils/kernel_utils/utils.py,sha256=uDZoJDxh07hJeLNlPdKN2vgB15pqIr1LbXf0YIBHU4E,43056
 wafer_core/utils/kernel_utils/targets/__init__.py,sha256=4NwRLsuJ__S4xKAfda4Ag82C5MQ3Qio-4xA5S-mQGlU,2067
-wafer_core/utils/kernel_utils/targets/config.py,sha256=
+wafer_core/utils/kernel_utils/targets/config.py,sha256=V587DYkisEFoWwkmLQUW6I0mXkMEwA52sM7ZINslkK8,20625
 wafer_core/utils/kernel_utils/targets/execution.py,sha256=bZuNXCo0sIdD6hFhetLPrtDC-zMSiIsAx_aml49VVL0,15033
 wafer_core/utils/kernel_utils/targets/selection.py,sha256=5I_RG_7cfhq7uaeR28meC2EeNNKssFsK-Tc3QFG6Ze0,3590
 wafer_core/utils/modal_execution/__init__.py,sha256=jkVqYOLzCT5K73N9Od0UIUsx-99A0m6bpDrxfyXxQZ8,945
@@ -673,6 +679,6 @@ wafer_core/utils/modal_execution/modal_app.py,sha256=VfS2cX8gHtnlPXemmMcEwDPeQdh
 wafer_core/utils/modal_execution/modal_config.py,sha256=7cGX9TGqilQ3qxI3OFGXV5orjtyRU-PEDOJ4vP2oxno,4421
 wafer_core/utils/modal_execution/modal_execution.py,sha256=gChjnV6jqA3A7IRP3DfvV5cSfm_MN0X4f7JZufXgdZE,24594
 wafer_core/utils/modal_execution/test_modal.py,sha256=_jqou_hrLs1Daf1590Pnb0a_lXMMa2rczAPpW9HpoNQ,8153
-wafer_core-0.1.
-wafer_core-0.1.
-wafer_core-0.1.
+wafer_core-0.1.26.dist-info/METADATA,sha256=xzTIIcsmbJkA06hTdoRb4uXZj2ud1-wnV7EXdLOSOe4,1420
+wafer_core-0.1.26.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+wafer_core-0.1.26.dist-info/RECORD,,

{wafer_core-0.1.25.dist-info → wafer_core-0.1.26.dist-info}/WHEEL
File without changes