wafer-core 0.1.31-py3-none-any.whl → 0.1.32-py3-none-any.whl
This diff compares the contents of two package versions as published to a supported public registry. It is provided for informational purposes only and reflects the packages as they appear in that registry.
- wafer_core/tools/dispatch_baseline/__init__.py +73 -0
- wafer_core/tools/dispatch_baseline/analyzer.py +174 -0
- wafer_core/tools/dispatch_baseline/client.py +196 -0
- wafer_core/tools/dispatch_baseline/codegen.py +246 -0
- wafer_core/tools/dispatch_baseline/dtypes.py +217 -0
- wafer_core/tools/dispatch_baseline/executor.py +360 -0
- wafer_core/tools/dispatch_baseline/roofline.py +165 -0
- wafer_core/utils/kernel_utils/defense.py +812 -10
- wafer_core/utils/kernel_utils/test_reward_hacks.py +140 -0
- {wafer_core-0.1.31.dist-info → wafer_core-0.1.32.dist-info}/METADATA +1 -1
- {wafer_core-0.1.31.dist-info → wafer_core-0.1.32.dist-info}/RECORD +12 -4
- {wafer_core-0.1.31.dist-info → wafer_core-0.1.32.dist-info}/WHEEL +0 -0
wafer_core/tools/dispatch_baseline/dtypes.py
@@ -0,0 +1,217 @@
+"""Data types for kernel trace tool."""
+
+from dataclasses import dataclass, field
+from typing import Literal
+
+
+@dataclass(frozen=True)
+class TensorSpec:
+    """Specification for a tensor operand."""
+
+    name: str
+    shape: tuple[int, ...]
+    dtype: str = "float16"  # e.g., "float16", "float32", "bfloat16", "int8"
+    device: str = "cuda"
+
+    def __str__(self) -> str:
+        shape_str = "x".join(str(d) for d in self.shape)
+        return f"{self.name}[{shape_str}] ({self.dtype})"
+
+
+@dataclass(frozen=True)
+class OpSpec:
+    """Specification for a PyTorch operation to trace.
+
+    Example:
+        OpSpec(
+            op="torch.matmul",
+            inputs=[
+                TensorSpec("A", (4096, 4096), "float16"),
+                TensorSpec("B", (4096, 4096), "float16"),
+            ],
+        )
+    """
+
+    op: str  # e.g., "torch.matmul", "torch.nn.functional.softmax", "torch.conv2d"
+    inputs: list[TensorSpec]
+    kwargs: dict[str, str] = field(default_factory=dict)  # e.g., {"dim": "-1"}
+
+    def __str__(self) -> str:
+        inputs_str = ", ".join(t.name for t in self.inputs)
+        kwargs_str = ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
+        if kwargs_str:
+            return f"{self.op}({inputs_str}, {kwargs_str})"
+        return f"{self.op}({inputs_str})"
+
+
+@dataclass(frozen=True)
+class KernelInfo:
+    """Information about a dispatched kernel."""
+
+    name: str  # e.g., "sm90_xmma_gemm_f16f16f16_f32..."
+    duration_us: float = 0.0  # Duration in microseconds
+    grid_size: tuple[int, int, int] | None = None  # (X, Y, Z)
+    block_size: tuple[int, int, int] | None = None  # (X, Y, Z)
+    registers_per_thread: int | None = None
+    shared_memory_bytes: int | None = None
+    # Performance metrics (from NCU, if available)
+    compute_throughput_tflops: float | None = None
+    memory_throughput_tbps: float | None = None  # TB/s
+    achieved_occupancy_pct: float | None = None
+
+    def __str__(self) -> str:
+        return f"{self.name} {self.duration_us:.1f} µs"
+
+
+@dataclass(frozen=True)
+class HardwareSpec:
+    """Hardware specifications for roofline analysis."""
+
+    name: str  # e.g., "H100", "MI300X"
+    peak_fp16_tflops: float
+    peak_fp32_tflops: float
+    peak_memory_bw_tbps: float  # TB/s
+    # Optional extras
+    peak_fp8_tflops: float | None = None
+    peak_int8_tops: float | None = None
+    shared_memory_per_sm_kb: float | None = None
+
+
+@dataclass(frozen=True)
+class RooflineAnalysis:
+    """Roofline analysis result."""
+
+    # Raw metrics
+    achieved_tflops: float
+    achieved_memory_bw_tbps: float
+    # Percentages of peak
+    compute_pct_of_peak: float
+    memory_bw_pct_of_peak: float
+    # Bottleneck identification
+    bottleneck: Literal["compute", "memory", "balanced"]
+    arithmetic_intensity: float  # FLOPs per byte
+
+
+@dataclass(frozen=True)
+class KernelTraceConfig:
+    """Configuration for kernel tracing."""
+
+    op_spec: OpSpec
+    hardware: str  # e.g., "H100", "MI300X"
+    num_warmup: int = 10
+    num_runs: int = 100
+    use_ncu: bool = False  # Use NCU for detailed metrics (slower)
+    timeout_seconds: int = 120
+
+
+@dataclass(frozen=True)
+class KernelTraceResult:
+    """Result of tracing a PyTorch operation."""
+
+    op_spec: OpSpec
+    hardware: str
+    # Kernels dispatched (may be multiple for fused ops)
+    kernels: list[KernelInfo]
+    # Primary kernel (typically the longest-running one)
+    primary_kernel: KernelInfo | None
+    # Roofline analysis (if metrics available)
+    roofline: RooflineAnalysis | None = None
+    # Raw profiler output (for debugging)
+    raw_output: str | None = None
+    # Error, if any
+    error: str | None = None
+
+    @property
+    def total_duration_us(self) -> float:
+        """Total duration of all kernels."""
+        return sum(k.duration_us for k in self.kernels)
+
+    def summary(self) -> str:
+        """Generate human-readable summary."""
+        if self.error:
+            return f"Error: {self.error}"
+
+        # Build operation description with shapes
+        op_desc = self._format_op_description()
+
+        lines = [
+            "═" * 65,
+            f" BASELINE: {op_desc}",
+            "═" * 65,
+            "",
+        ]
+
+        # Primary kernel section
+        if self.primary_kernel:
+            lines.append(" Primary Kernel:")
+            kernel_name = self._truncate_kernel_name(self.primary_kernel.name, 55)
+            lines.append(f" → {kernel_name}")
+            lines.append(f" Duration: {self._format_duration(self.primary_kernel.duration_us)}")
+            lines.append("")
+
+        # Roofline section
+        if self.roofline:
+            lines.append(" Roofline Analysis:")
+
+            # Format TFLOPS with commas for readability
+            tflops_str = f"{self.roofline.achieved_tflops:,.1f}"
+            bw_str = f"{self.roofline.achieved_memory_bw_tbps:.2f}"
+
+            lines.append(f" Achieved: {tflops_str} TFLOPS ({self.roofline.compute_pct_of_peak:.1f}% of {self.hardware} peak)")
+            lines.append(f" Memory BW: {bw_str} TB/s ({self.roofline.memory_bw_pct_of_peak:.1f}% of peak)")
+            lines.append(f" Bottleneck: {self.roofline.bottleneck.upper()}")
+            lines.append("")
+
+        # All kernels section (if more than one)
+        if len(self.kernels) > 1:
+            lines.append(" All Kernels:")
+            total_dur = self.total_duration_us
+            for i, k in enumerate(self.kernels[:5]):  # Show top 5
+                pct = (k.duration_us / total_dur * 100) if total_dur > 0 else 0
+                name = self._truncate_kernel_name(k.name, 45)
+                dur_str = self._format_duration(k.duration_us)
+                marker = "→" if k == self.primary_kernel else " "
+                lines.append(f" {marker} {i+1}. {name:<45} {dur_str:>10} ({pct:>4.1f}%)")
+            if len(self.kernels) > 5:
+                lines.append(f" ... and {len(self.kernels) - 5} more kernels")
+            lines.append("")
+
+        lines.append("─" * 65)
+
+        return "\n".join(lines)
+
+    def _format_op_description(self) -> str:
+        """Format operation with shape info."""
+        op_name = self.op_spec.op.split(".")[-1]  # e.g., "matmul" from "torch.matmul"
+
+        # Format input shapes
+        shape_strs = []
+        for t in self.op_spec.inputs:
+            shape_str = "×".join(str(d) for d in t.shape)
+            shape_strs.append(shape_str)
+
+        # Get dtype from first input
+        dtype = self.op_spec.inputs[0].dtype if self.op_spec.inputs else "float16"
+        dtype_upper = dtype.upper().replace("BFLOAT", "BF").replace("FLOAT", "FP")  # BFLOAT first, else "bfloat16" renders as "BFP16"
+
+        if len(shape_strs) == 2:
+            return f"{op_name} ({shape_strs[0]} @ {shape_strs[1]}) {dtype_upper}"
+        elif len(shape_strs) == 1:
+            return f"{op_name} ({shape_strs[0]}) {dtype_upper}"
+        else:
+            return f"{op_name} {dtype_upper}"
+
+    def _truncate_kernel_name(self, name: str, max_len: int) -> str:
+        """Truncate long kernel names."""
+        if len(name) <= max_len:
+            return name
+        return name[:max_len - 3] + "..."
+
+    def _format_duration(self, us: float) -> str:
+        """Format duration nicely."""
+        if us >= 1000:
+            return f"{us / 1000:.2f} ms"
+        elif us >= 1:
+            return f"{us:.1f} µs"
+        else:
+            return f"{us * 1000:.1f} ns"
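These types form a small, immutable result model: everything is `frozen=True`, so derived data is attached via `dataclasses.replace` rather than mutation (see `_add_roofline_analysis` in executor.py below). A minimal sketch of how they compose; the values are hand-built and illustrative rather than measured, and the kernel name is borrowed from the `KernelInfo` docstring:

    from wafer_core.tools.dispatch_baseline.dtypes import (
        KernelInfo,
        KernelTraceResult,
        OpSpec,
        TensorSpec,
    )

    # The op to trace: C = A @ B in FP16.
    op = OpSpec(
        op="torch.matmul",
        inputs=[
            TensorSpec("A", (4096, 4096), "float16"),
            TensorSpec("B", (4096, 4096), "float16"),
        ],
    )

    # A hand-built kernel record; in practice this comes from the profiler.
    gemm = KernelInfo(name="sm90_xmma_gemm_f16f16f16_f32...", duration_us=312.0)

    result = KernelTraceResult(op_spec=op, hardware="H100", kernels=[gemm], primary_kernel=gemm)
    print(result.summary())
    # Prints the banner "BASELINE: matmul (4096×4096 @ 4096×4096) FP16",
    # the primary kernel name, and "Duration: 312.0 µs"; roofline is omitted here.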
wafer_core/tools/dispatch_baseline/executor.py
@@ -0,0 +1,360 @@
+"""Executor for kernel trace operations.
+
+Runs profiling scripts on remote GPUs and returns structured results.
+"""
+
+import logging
+import tempfile
+from dataclasses import dataclass, replace
+from pathlib import Path
+
+from wafer_core.ssh import ExecResult, SSHClient
+from wafer_core.tools.dispatch_baseline.analyzer import ParsedTraceResult, parse_trace_output
+from wafer_core.tools.dispatch_baseline.codegen import generate_trace_script
+from wafer_core.tools.dispatch_baseline.dtypes import (
+    KernelTraceConfig,
+    KernelTraceResult,
+    OpSpec,
+    TensorSpec,
+)
+from wafer_core.tools.dispatch_baseline.roofline import (
+    compute_roofline,
+    estimate_attention_bytes,
+    estimate_attention_flops,
+    estimate_matmul_bytes,
+    estimate_matmul_flops,
+    estimate_softmax_bytes,
+    estimate_softmax_flops,
+    get_hardware_spec,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class TraceExecutionResult:
+    """Result of trace execution with environment info for caching."""
+
+    result: KernelTraceResult
+    pytorch_version: str
+    runtime_version: str
+    gpu_arch: str
+    from_cache: bool = False
+
+
+def trace_kernel_local(config: KernelTraceConfig) -> TraceExecutionResult:
+    """Trace a kernel on the local GPU.
+
+    Args:
+        config: Kernel trace configuration
+
+    Returns:
+        TraceExecutionResult with kernel information and environment info
+    """
+    import subprocess
+    import sys
+
+    script = generate_trace_script(config)
+
+    # Write script to temp file
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
+        f.write(script)
+        script_path = f.name
+
+    try:
+        # Run script
+        result = subprocess.run(
+            [sys.executable, script_path],
+            capture_output=True,
+            text=True,
+            timeout=config.timeout_seconds,
+        )
+
+        output = result.stdout + result.stderr
+
+        if result.returncode != 0:
+            return TraceExecutionResult(
+                result=KernelTraceResult(
+                    op_spec=config.op_spec,
+                    hardware=config.hardware,
+                    kernels=[],
+                    primary_kernel=None,
+                    raw_output=output,
+                    error=f"Script failed with exit code {result.returncode}: {result.stderr}",
+                ),
+                pytorch_version="unknown",
+                runtime_version="unknown",
+                gpu_arch="unknown",
+            )
+
+        # Parse output (now includes environment info)
+        parsed = parse_trace_output(output, config.op_spec, config.hardware)
+
+        # Add roofline analysis if we can estimate FLOPS
+        trace_result = _add_roofline_analysis(parsed.result, config)
+
+        return TraceExecutionResult(
+            result=trace_result,
+            pytorch_version=parsed.pytorch_version,
+            runtime_version=parsed.runtime_version,
+            gpu_arch=parsed.gpu_arch,
+        )
+
+    except subprocess.TimeoutExpired:
+        return TraceExecutionResult(
+            result=KernelTraceResult(
+                op_spec=config.op_spec,
+                hardware=config.hardware,
+                kernels=[],
+                primary_kernel=None,
+                error=f"Script timed out after {config.timeout_seconds}s",
+            ),
+            pytorch_version="unknown",
+            runtime_version="unknown",
+            gpu_arch="unknown",
+        )
+
+    finally:
+        # Cleanup
+        Path(script_path).unlink(missing_ok=True)
+
+
+def trace_kernel_remote(
+    config: KernelTraceConfig,
+    ssh_client: SSHClient,
+    workspace_path: str = "/tmp",
+) -> TraceExecutionResult:
+    """Trace a kernel on a remote GPU via SSH.
+
+    Args:
+        config: Kernel trace configuration
+        ssh_client: Connected SSH client
+        workspace_path: Remote directory to use for temporary files
+
+    Returns:
+        TraceExecutionResult with kernel information and environment info
+    """
+    script = generate_trace_script(config)
+    script_filename = "dispatch_baseline_script.py"
+    remote_script_path = f"{workspace_path}/{script_filename}"
+
+    try:
+        # Upload script
+        logger.debug(f"Uploading trace script to {remote_script_path}")
+        ssh_client.upload_content(script, remote_script_path)
+
+        # Run script
+        # TODO: Add timeout to SSHClient.exec() to prevent hanging on remote failures
+        logger.debug("Running trace script...")
+        run_cmd = f"cd {workspace_path} && python {script_filename}"
+        result: ExecResult = ssh_client.exec(run_cmd)
+
+        output = result.stdout + result.stderr
+
+        if result.exit_code != 0:
+            return TraceExecutionResult(
+                result=KernelTraceResult(
+                    op_spec=config.op_spec,
+                    hardware=config.hardware,
+                    kernels=[],
+                    primary_kernel=None,
+                    raw_output=output,
+                    error=f"Script failed with exit code {result.exit_code}: {result.stderr}",
+                ),
+                pytorch_version="unknown",
+                runtime_version="unknown",
+                gpu_arch="unknown",
+            )
+
+        # Parse output (now includes environment info)
+        parsed = parse_trace_output(output, config.op_spec, config.hardware)
+
+        # Add roofline analysis
+        trace_result = _add_roofline_analysis(parsed.result, config)
+
+        return TraceExecutionResult(
+            result=trace_result,
+            pytorch_version=parsed.pytorch_version,
+            runtime_version=parsed.runtime_version,
+            gpu_arch=parsed.gpu_arch,
+        )
+
+    except Exception as e:
+        logger.exception("Failed to trace kernel remotely")
+        return TraceExecutionResult(
+            result=KernelTraceResult(
+                op_spec=config.op_spec,
+                hardware=config.hardware,
+                kernels=[],
+                primary_kernel=None,
+                error=f"Remote execution failed: {e}",
+            ),
+            pytorch_version="unknown",
+            runtime_version="unknown",
+            gpu_arch="unknown",
+        )
+
+    finally:
+        # Cleanup remote script
+        try:
+            ssh_client.exec(f"rm -f {remote_script_path}")
+        except Exception:
+            pass
+
+
+def _add_roofline_analysis(
+    result: KernelTraceResult, config: KernelTraceConfig
+) -> KernelTraceResult:
+    """Add roofline analysis to trace result if possible.
+
+    Supports: matmul, softmax, attention, and generic elementwise ops.
+    """
+    if result.primary_kernel is None:
+        return result
+
+    if config.hardware is None:
+        return result
+
+    # Try to estimate FLOPS based on operation type
+    op_lower = config.op_spec.op.lower()
+    shapes = [t.shape for t in config.op_spec.inputs]
+    dtype_bytes = _get_dtype_bytes(config.op_spec.inputs[0].dtype) if config.op_spec.inputs else 2
+
+    flops = 0.0
+    bytes_transferred = 0.0
+
+    if "matmul" in op_lower or "mm" in op_lower or "linear" in op_lower:
+        # Matrix multiplication: C[M,N] = A[M,K] @ B[K,N]
+        if len(shapes) >= 2 and len(shapes[0]) >= 2 and len(shapes[1]) >= 2:
+            # Handle batched matmul by taking last 2 dims
+            m, k1 = shapes[0][-2], shapes[0][-1]
+            k2, n = shapes[1][-2], shapes[1][-1]
+            # Validate batch dimensions match (if present)
+            batch_dims_0 = shapes[0][:-2]
+            batch_dims_1 = shapes[1][:-2]
+            if k1 != k2:
+                logger.warning(f"Matmul inner dims mismatch: {k1} vs {k2}, skipping roofline")
+            elif batch_dims_0 != batch_dims_1:
+                logger.warning(f"Matmul batch dims mismatch: {batch_dims_0} vs {batch_dims_1}, skipping roofline")
+            else:
+                # Account for batch dimensions
+                batch_size = 1
+                for dim in batch_dims_0:
+                    batch_size *= dim
+                flops = batch_size * estimate_matmul_flops(m, n, k1)
+                bytes_transferred = batch_size * estimate_matmul_bytes(m, n, k1, dtype_bytes)
+
+    elif "softmax" in op_lower:
+        # Softmax: read input, write output
+        if shapes:
+            elements = 1
+            for dim in shapes[0]:
+                elements *= dim
+            flops = estimate_softmax_flops(elements)
+            bytes_transferred = estimate_softmax_bytes(elements, dtype_bytes)
+
+    elif "attention" in op_lower or "sdpa" in op_lower:
+        # Scaled dot-product attention
+        # Expect inputs like Q[B,H,S,D], K[B,H,S,D], V[B,H,S,D]
+        if len(shapes) >= 3 and len(shapes[0]) == 4:
+            batch, heads, seq_len, head_dim = shapes[0]
+            flops = estimate_attention_flops(batch, heads, seq_len, head_dim)
+            bytes_transferred = estimate_attention_bytes(batch, heads, seq_len, head_dim, dtype_bytes)
+
+    elif any(op in op_lower for op in ["relu", "gelu", "silu", "tanh", "sigmoid", "exp", "log"]):
+        # Elementwise activation: ~1-5 ops per element, read+write
+        if shapes:
+            elements = 1
+            for dim in shapes[0]:
+                elements *= dim
+            flops = 5.0 * elements  # Conservative estimate for transcendentals
+            bytes_transferred = 2.0 * elements * dtype_bytes  # Read + write
+
+    elif any(op in op_lower for op in ["add", "sub", "mul", "div"]):
+        # Binary elementwise: 1 op per element
+        if shapes:
+            elements = 1
+            for dim in shapes[0]:
+                elements *= dim
+            flops = float(elements)
+            bytes_transferred = 3.0 * elements * dtype_bytes  # Read 2 inputs + write 1 output
+
+    elif "layernorm" in op_lower or "layer_norm" in op_lower or "rmsnorm" in op_lower:
+        # Normalization: mean, variance, normalize (~10 ops per element)
+        if shapes:
+            elements = 1
+            for dim in shapes[0]:
+                elements *= dim
+            flops = 10.0 * elements
+            bytes_transferred = 2.0 * elements * dtype_bytes
+
+    elif "conv" in op_lower:
+        # Convolution is complex, skip for now (would need kernel size, stride, etc.)
+        pass
+
+    if flops > 0 and bytes_transferred > 0:
+        roofline = compute_roofline(
+            result.primary_kernel,
+            config.hardware,
+            flops,
+            bytes_transferred,
+        )
+        if roofline:
+            return replace(result, roofline=roofline)
+
+    return result
+
+
+def _get_dtype_bytes(dtype: str) -> int:
+    """Get bytes per element for a dtype."""
+    dtype_map = {
+        "float16": 2,
+        "float32": 4,
+        "float64": 8,
+        "bfloat16": 2,
+        "int8": 1,
+        "int16": 2,
+        "int32": 4,
+        "int64": 8,
+    }
+    return dtype_map.get(dtype, 2)
+
+
+def quick_trace(
+    op: str,
+    shapes: dict[str, tuple[int, ...]],
+    hardware: str = "H100",
+    dtype: str = "float16",
+) -> KernelTraceResult:
+    """Quick helper to trace an operation locally.
+
+    Args:
+        op: Operation string like "torch.matmul(A, B)"
+        shapes: Dict mapping tensor names to shapes
+        hardware: Hardware name (for roofline analysis)
+        dtype: Data type for tensors
+
+    Returns:
+        KernelTraceResult
+
+    Example:
+        result = quick_trace(
+            "torch.matmul(A, B)",
+            {"A": (4096, 4096), "B": (4096, 4096)},
+            hardware="H100",
+        )
+        print(result.summary())
+    """
+    from wafer_core.tools.dispatch_baseline.codegen import parse_op_string, update_dtypes, update_shapes
+
+    op_spec = parse_op_string(op)
+    op_spec = update_shapes(op_spec, shapes)
+    op_spec = update_dtypes(op_spec, dtype)
+
+    config = KernelTraceConfig(
+        op_spec=op_spec,
+        hardware=hardware,
+    )
+
+    execution_result = trace_kernel_local(config)
+    return execution_result.result