wafer-core 0.1.31__py3-none-any.whl → 0.1.32__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in its public registry. It is provided for informational purposes only.
@@ -0,0 +1,217 @@
+ """Data types for kernel trace tool."""
+
+ from dataclasses import dataclass, field
+ from typing import Literal
+
+
+ @dataclass(frozen=True)
+ class TensorSpec:
+     """Specification for a tensor operand."""
+
+     name: str
+     shape: tuple[int, ...]
+     dtype: str = "float16"  # e.g., "float16", "float32", "bfloat16", "int8"
+     device: str = "cuda"
+
+     def __str__(self) -> str:
+         shape_str = "x".join(str(d) for d in self.shape)
+         return f"{self.name}[{shape_str}] ({self.dtype})"
+
+
+ @dataclass(frozen=True)
+ class OpSpec:
+     """Specification for a PyTorch operation to trace.
+
+     Example:
+         OpSpec(
+             op="torch.matmul",
+             inputs=[
+                 TensorSpec("A", (4096, 4096), "float16"),
+                 TensorSpec("B", (4096, 4096), "float16"),
+             ],
+         )
+     """
+
+     op: str  # e.g., "torch.matmul", "torch.nn.functional.softmax", "torch.conv2d"
+     inputs: list[TensorSpec]
+     kwargs: dict[str, str] = field(default_factory=dict)  # e.g., {"dim": "-1"}
+
+     def __str__(self) -> str:
+         inputs_str = ", ".join(t.name for t in self.inputs)
+         kwargs_str = ", ".join(f"{k}={v}" for k, v in self.kwargs.items())
+         if kwargs_str:
+             return f"{self.op}({inputs_str}, {kwargs_str})"
+         return f"{self.op}({inputs_str})"
+
+
+ @dataclass(frozen=True)
+ class KernelInfo:
+     """Information about a dispatched kernel."""
+
+     name: str  # e.g., "sm90_xmma_gemm_f16f16f16_f32..."
+     duration_us: float = 0.0  # Duration in microseconds
+     grid_size: tuple[int, int, int] | None = None  # (X, Y, Z)
+     block_size: tuple[int, int, int] | None = None  # (X, Y, Z)
+     registers_per_thread: int | None = None
+     shared_memory_bytes: int | None = None
+     # Performance metrics (from NCU, if available)
+     compute_throughput_tflops: float | None = None
+     memory_throughput_tbps: float | None = None  # TB/s
+     achieved_occupancy_pct: float | None = None
+
+     def __str__(self) -> str:
+         return f"{self.name} {self.duration_us:.1f} µs"
+
+
+ @dataclass(frozen=True)
+ class HardwareSpec:
+     """Hardware specifications for roofline analysis."""
+
+     name: str  # e.g., "H100", "MI300X"
+     peak_fp16_tflops: float
+     peak_fp32_tflops: float
+     peak_memory_bw_tbps: float  # TB/s
+     # Optional extras
+     peak_fp8_tflops: float | None = None
+     peak_int8_tops: float | None = None
+     shared_memory_per_sm_kb: float | None = None
+
+
+ @dataclass(frozen=True)
+ class RooflineAnalysis:
+     """Roofline analysis result."""
+
+     # Raw metrics
+     achieved_tflops: float
+     achieved_memory_bw_tbps: float
+     # Percentages of peak
+     compute_pct_of_peak: float
+     memory_bw_pct_of_peak: float
+     # Bottleneck identification
+     bottleneck: Literal["compute", "memory", "balanced"]
+     arithmetic_intensity: float  # FLOPS per byte
+
+
+ @dataclass(frozen=True)
+ class KernelTraceConfig:
+     """Configuration for kernel tracing."""
+
+     op_spec: OpSpec
+     hardware: str  # e.g., "H100", "MI300X"
+     num_warmup: int = 10
+     num_runs: int = 100
+     use_ncu: bool = False  # Use NCU for detailed metrics (slower)
+     timeout_seconds: int = 120
+
+
+ @dataclass(frozen=True)
+ class KernelTraceResult:
+     """Result of tracing a PyTorch operation."""
+
+     op_spec: OpSpec
+     hardware: str
+     # Kernels dispatched (may be multiple for fused ops)
+     kernels: list[KernelInfo]
+     # Primary kernel (typically the longest-running one)
+     primary_kernel: KernelInfo | None
+     # Roofline analysis (if metrics available)
+     roofline: RooflineAnalysis | None = None
+     # Raw profiler output (for debugging)
+     raw_output: str | None = None
+     # Error, if any
+     error: str | None = None
+
+     @property
+     def total_duration_us(self) -> float:
+         """Total duration of all kernels."""
+         return sum(k.duration_us for k in self.kernels)
+
+     def summary(self) -> str:
+         """Generate human-readable summary."""
+         if self.error:
+             return f"Error: {self.error}"
+
+         # Build operation description with shapes
+         op_desc = self._format_op_description()
+
+         lines = [
+             "═" * 65,
+             f" BASELINE: {op_desc}",
+             "═" * 65,
+             "",
+         ]
+
+         # Primary kernel section
+         if self.primary_kernel:
+             lines.append(" Primary Kernel:")
+             kernel_name = self._truncate_kernel_name(self.primary_kernel.name, 55)
+             lines.append(f" → {kernel_name}")
+             lines.append(f" Duration: {self._format_duration(self.primary_kernel.duration_us)}")
+             lines.append("")
+
+         # Roofline section
+         if self.roofline:
+             lines.append(" Roofline Analysis:")
+
+             # Format TFLOPS with commas for readability
+             tflops_str = f"{self.roofline.achieved_tflops:,.1f}"
+             bw_str = f"{self.roofline.achieved_memory_bw_tbps:.2f}"
+
+             lines.append(f" Achieved: {tflops_str} TFLOPS ({self.roofline.compute_pct_of_peak:.1f}% of {self.hardware} peak)")
+             lines.append(f" Memory BW: {bw_str} TB/s ({self.roofline.memory_bw_pct_of_peak:.1f}% of peak)")
+             lines.append(f" Bottleneck: {self.roofline.bottleneck.upper()}")
+             lines.append("")
+
+         # All kernels section (if more than one)
+         if len(self.kernels) > 1:
+             lines.append(" All Kernels:")
+             total_dur = self.total_duration_us
+             for i, k in enumerate(self.kernels[:5]):  # Show top 5
+                 pct = (k.duration_us / total_dur * 100) if total_dur > 0 else 0
+                 name = self._truncate_kernel_name(k.name, 45)
+                 dur_str = self._format_duration(k.duration_us)
+                 marker = "→" if k == self.primary_kernel else " "
+                 lines.append(f" {marker} {i+1}. {name:<45} {dur_str:>10} ({pct:>4.1f}%)")
+             if len(self.kernels) > 5:
+                 lines.append(f" ... and {len(self.kernels) - 5} more kernels")
+             lines.append("")
+
+         lines.append("─" * 65)
+
+         return "\n".join(lines)
+
+     def _format_op_description(self) -> str:
+         """Format operation with shape info."""
+         op_name = self.op_spec.op.split(".")[-1]  # e.g., "matmul" from "torch.matmul"
+
+         # Format input shapes
+         shape_strs = []
+         for t in self.op_spec.inputs:
+             shape_str = "×".join(str(d) for d in t.shape)
+             shape_strs.append(shape_str)
+
+         # Get dtype from first input
+         dtype = self.op_spec.inputs[0].dtype if self.op_spec.inputs else "float16"
+         dtype_upper = dtype.upper().replace("BFLOAT", "BF").replace("FLOAT", "FP")
+
+         if len(shape_strs) == 2:
+             return f"{op_name} ({shape_strs[0]} @ {shape_strs[1]}) {dtype_upper}"
+         elif len(shape_strs) == 1:
+             return f"{op_name} ({shape_strs[0]}) {dtype_upper}"
+         else:
+             return f"{op_name} {dtype_upper}"
+
+     def _truncate_kernel_name(self, name: str, max_len: int) -> str:
+         """Truncate long kernel names."""
+         if len(name) <= max_len:
+             return name
+         return name[:max_len - 3] + "..."
+
+     def _format_duration(self, us: float) -> str:
+         """Format duration nicely."""
+         if us >= 1000:
+             return f"{us / 1000:.2f} ms"
+         elif us >= 1:
+             return f"{us:.1f} µs"
+         else:
+             return f"{us * 1000:.1f} ns"
@@ -0,0 +1,360 @@
+ """Executor for kernel trace operations.
+
+ Runs profiling scripts on remote GPUs and returns structured results.
+ """
+
+ import logging
+ import tempfile
+ from dataclasses import dataclass, replace
+ from pathlib import Path
+
+ from wafer_core.ssh import ExecResult, SSHClient
+ from wafer_core.tools.dispatch_baseline.analyzer import ParsedTraceResult, parse_trace_output
+ from wafer_core.tools.dispatch_baseline.codegen import generate_trace_script
+ from wafer_core.tools.dispatch_baseline.dtypes import (
+     KernelTraceConfig,
+     KernelTraceResult,
+     OpSpec,
+     TensorSpec,
+ )
+ from wafer_core.tools.dispatch_baseline.roofline import (
+     compute_roofline,
+     estimate_attention_bytes,
+     estimate_attention_flops,
+     estimate_matmul_bytes,
+     estimate_matmul_flops,
+     estimate_softmax_bytes,
+     estimate_softmax_flops,
+     get_hardware_spec,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass(frozen=True)
+ class TraceExecutionResult:
+     """Result of trace execution with environment info for caching."""
+
+     result: KernelTraceResult
+     pytorch_version: str
+     runtime_version: str
+     gpu_arch: str
+     from_cache: bool = False
+
+
+ def trace_kernel_local(config: KernelTraceConfig) -> TraceExecutionResult:
+     """Trace a kernel on the local GPU.
+
+     Args:
+         config: Kernel trace configuration
+
+     Returns:
+         TraceExecutionResult with kernel information and environment info
+     """
+     import subprocess
+     import sys
+
+     script = generate_trace_script(config)
+
+     # Write script to temp file
+     with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
+         f.write(script)
+         script_path = f.name
+
+     try:
+         # Run script
+         result = subprocess.run(
+             [sys.executable, script_path],
+             capture_output=True,
+             text=True,
+             timeout=config.timeout_seconds,
+         )
+
+         output = result.stdout + result.stderr
+
+         if result.returncode != 0:
+             return TraceExecutionResult(
+                 result=KernelTraceResult(
+                     op_spec=config.op_spec,
+                     hardware=config.hardware,
+                     kernels=[],
+                     primary_kernel=None,
+                     raw_output=output,
+                     error=f"Script failed with exit code {result.returncode}: {result.stderr}",
+                 ),
+                 pytorch_version="unknown",
+                 runtime_version="unknown",
+                 gpu_arch="unknown",
+             )
+
+         # Parse output (now includes environment info)
+         parsed = parse_trace_output(output, config.op_spec, config.hardware)
+
+         # Add roofline analysis if we can estimate FLOPS
+         trace_result = _add_roofline_analysis(parsed.result, config)
+
+         return TraceExecutionResult(
+             result=trace_result,
+             pytorch_version=parsed.pytorch_version,
+             runtime_version=parsed.runtime_version,
+             gpu_arch=parsed.gpu_arch,
+         )
+
+     except subprocess.TimeoutExpired:
+         return TraceExecutionResult(
+             result=KernelTraceResult(
+                 op_spec=config.op_spec,
+                 hardware=config.hardware,
+                 kernels=[],
+                 primary_kernel=None,
+                 error=f"Script timed out after {config.timeout_seconds}s",
+             ),
+             pytorch_version="unknown",
+             runtime_version="unknown",
+             gpu_arch="unknown",
+         )
+
+     finally:
+         # Cleanup
+         Path(script_path).unlink(missing_ok=True)
+
+
+ def trace_kernel_remote(
+     config: KernelTraceConfig,
+     ssh_client: SSHClient,
+     workspace_path: str = "/tmp",
+ ) -> TraceExecutionResult:
+     """Trace a kernel on a remote GPU via SSH.
+
+     Args:
+         config: Kernel trace configuration
+         ssh_client: Connected SSH client
+         workspace_path: Remote directory to use for temporary files
+
+     Returns:
+         TraceExecutionResult with kernel information and environment info
+     """
+     script = generate_trace_script(config)
+     script_filename = "dispatch_baseline_script.py"
+     remote_script_path = f"{workspace_path}/{script_filename}"
+
+     try:
+         # Upload script
+         logger.debug(f"Uploading trace script to {remote_script_path}")
+         ssh_client.upload_content(script, remote_script_path)
+
+         # Run script
+         # TODO: Add timeout to SSHClient.exec() to prevent hanging on remote failures
+         logger.debug("Running trace script...")
+         run_cmd = f"cd {workspace_path} && python {script_filename}"
+         result: ExecResult = ssh_client.exec(run_cmd)
+
+         output = result.stdout + result.stderr
+
+         if result.exit_code != 0:
+             return TraceExecutionResult(
+                 result=KernelTraceResult(
+                     op_spec=config.op_spec,
+                     hardware=config.hardware,
+                     kernels=[],
+                     primary_kernel=None,
+                     raw_output=output,
+                     error=f"Script failed with exit code {result.exit_code}: {result.stderr}",
+                 ),
+                 pytorch_version="unknown",
+                 runtime_version="unknown",
+                 gpu_arch="unknown",
+             )
+
+         # Parse output (now includes environment info)
+         parsed = parse_trace_output(output, config.op_spec, config.hardware)
+
+         # Add roofline analysis
+         trace_result = _add_roofline_analysis(parsed.result, config)
+
+         return TraceExecutionResult(
+             result=trace_result,
+             pytorch_version=parsed.pytorch_version,
+             runtime_version=parsed.runtime_version,
+             gpu_arch=parsed.gpu_arch,
+         )
+
+     except Exception as e:
+         logger.exception("Failed to trace kernel remotely")
+         return TraceExecutionResult(
+             result=KernelTraceResult(
+                 op_spec=config.op_spec,
+                 hardware=config.hardware,
+                 kernels=[],
+                 primary_kernel=None,
+                 error=f"Remote execution failed: {e}",
+             ),
+             pytorch_version="unknown",
+             runtime_version="unknown",
+             gpu_arch="unknown",
+         )
+
+     finally:
+         # Cleanup remote script
+         try:
+             ssh_client.exec(f"rm -f {remote_script_path}")
+         except Exception:
+             pass
+
+
+ def _add_roofline_analysis(
+     result: KernelTraceResult, config: KernelTraceConfig
+ ) -> KernelTraceResult:
+     """Add roofline analysis to trace result if possible.
+
+     Supports: matmul, softmax, attention, and generic elementwise ops.
+     """
+     if result.primary_kernel is None:
+         return result
+
+     if config.hardware is None:
+         return result
+
+     # Try to estimate FLOPS based on operation type
+     op_lower = config.op_spec.op.lower()
+     shapes = [t.shape for t in config.op_spec.inputs]
+     dtype_bytes = _get_dtype_bytes(config.op_spec.inputs[0].dtype) if config.op_spec.inputs else 2
+
+     flops = 0.0
+     bytes_transferred = 0.0
+
+     if "matmul" in op_lower or "mm" in op_lower or "linear" in op_lower:
+         # Matrix multiplication: C[M,N] = A[M,K] @ B[K,N]
+         if len(shapes) >= 2 and len(shapes[0]) >= 2 and len(shapes[1]) >= 2:
+             # Handle batched matmul by taking last 2 dims
+             m, k1 = shapes[0][-2], shapes[0][-1]
+             k2, n = shapes[1][-2], shapes[1][-1]
+             # Validate batch dimensions match (if present)
+             batch_dims_0 = shapes[0][:-2]
+             batch_dims_1 = shapes[1][:-2]
+             if k1 != k2:
+                 logger.warning(f"Matmul inner dims mismatch: {k1} vs {k2}, skipping roofline")
+             elif batch_dims_0 != batch_dims_1:
+                 logger.warning(f"Matmul batch dims mismatch: {batch_dims_0} vs {batch_dims_1}, skipping roofline")
+             else:
+                 # Account for batch dimensions
+                 batch_size = 1
+                 for dim in batch_dims_0:
+                     batch_size *= dim
+                 flops = batch_size * estimate_matmul_flops(m, n, k1)
+                 bytes_transferred = batch_size * estimate_matmul_bytes(m, n, k1, dtype_bytes)
+
+     elif "softmax" in op_lower:
+         # Softmax: read input, write output
+         if shapes:
+             elements = 1
+             for dim in shapes[0]:
+                 elements *= dim
+             flops = estimate_softmax_flops(elements)
+             bytes_transferred = estimate_softmax_bytes(elements, dtype_bytes)
+
+     elif "attention" in op_lower or "sdpa" in op_lower:
+         # Scaled dot-product attention
+         # Expect inputs like Q[B,H,S,D], K[B,H,S,D], V[B,H,S,D]
+         if len(shapes) >= 3 and len(shapes[0]) == 4:
+             batch, heads, seq_len, head_dim = shapes[0]
+             flops = estimate_attention_flops(batch, heads, seq_len, head_dim)
+             bytes_transferred = estimate_attention_bytes(batch, heads, seq_len, head_dim, dtype_bytes)
+
+     elif any(op in op_lower for op in ["relu", "gelu", "silu", "tanh", "sigmoid", "exp", "log"]):
+         # Elementwise activation: ~1-5 ops per element, read+write
+         if shapes:
+             elements = 1
+             for dim in shapes[0]:
+                 elements *= dim
+             flops = 5.0 * elements  # Conservative estimate for transcendentals
+             bytes_transferred = 2.0 * elements * dtype_bytes  # Read + write
+
+     elif any(op in op_lower for op in ["add", "sub", "mul", "div"]):
+         # Binary elementwise: 1 op per element
+         if shapes:
+             elements = 1
+             for dim in shapes[0]:
+                 elements *= dim
+             flops = float(elements)
+             bytes_transferred = 3.0 * elements * dtype_bytes  # Read 2 inputs + write 1 output
+
+     elif "layernorm" in op_lower or "layer_norm" in op_lower or "rmsnorm" in op_lower:
+         # Normalization: mean, variance, normalize (~10 ops per element)
+         if shapes:
+             elements = 1
+             for dim in shapes[0]:
+                 elements *= dim
+             flops = 10.0 * elements
+             bytes_transferred = 2.0 * elements * dtype_bytes
+
+     elif "conv" in op_lower:
+         # Convolution is complex, skip for now (would need kernel size, stride, etc.)
+         pass
+
+     if flops > 0 and bytes_transferred > 0:
+         roofline = compute_roofline(
+             result.primary_kernel,
+             config.hardware,
+             flops,
+             bytes_transferred,
+         )
+         if roofline:
+             return replace(result, roofline=roofline)
+
+     return result
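To make the elementwise estimate above concrete (illustrative arithmetic only, mirroring the relu/gelu/silu branch): applying a transcendental activation such as torch.nn.functional.gelu to an (8192, 8192) float16 tensor would be counted as roughly 0.34 GFLOP against about 0.27 GB of traffic, an arithmetic intensity of 5 / (2 * 2) = 1.25 FLOPs per byte, low enough that compute_roofline should report the kernel as memory-bound on any current GPU.

    # Illustrative numbers only; mirrors the elementwise-activation branch above.
    elements = 8192 * 8192                             # 67,108,864 elements
    dtype_bytes = 2                                    # float16
    flops = 5.0 * elements                             # ~3.4e8 FLOPs
    bytes_transferred = 2.0 * elements * dtype_bytes   # ~2.7e8 bytes (read + write)
    intensity = flops / bytes_transferred              # 1.25 FLOPs per byte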
+
+
+ def _get_dtype_bytes(dtype: str) -> int:
+     """Get bytes per element for a dtype."""
+     dtype_map = {
+         "float16": 2,
+         "float32": 4,
+         "float64": 8,
+         "bfloat16": 2,
+         "int8": 1,
+         "int16": 2,
+         "int32": 4,
+         "int64": 8,
+     }
+     return dtype_map.get(dtype, 2)
+
+
+ def quick_trace(
+     op: str,
+     shapes: dict[str, tuple[int, ...]],
+     hardware: str = "H100",
+     dtype: str = "float16",
+ ) -> KernelTraceResult:
+     """Quick helper to trace an operation locally.
+
+     Args:
+         op: Operation string like "torch.matmul(A, B)"
+         shapes: Dict mapping tensor names to shapes
+         hardware: Hardware name (for roofline analysis)
+         dtype: Data type for tensors
+
+     Returns:
+         KernelTraceResult
+
+     Example:
+         result = quick_trace(
+             "torch.matmul(A, B)",
+             {"A": (4096, 4096), "B": (4096, 4096)},
+             hardware="H100",
+         )
+         print(result.summary())
+     """
+     from wafer_core.tools.dispatch_baseline.codegen import parse_op_string, update_dtypes, update_shapes
+
+     op_spec = parse_op_string(op)
+     op_spec = update_shapes(op_spec, shapes)
+     op_spec = update_dtypes(op_spec, dtype)
+
+     config = KernelTraceConfig(
+         op_spec=op_spec,
+         hardware=hardware,
+     )
+
+     execution_result = trace_kernel_local(config)
+     return execution_result.result
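As a rough check on what the roofline section of summary() would report for the quick_trace docstring example (a 4096x4096x4096 float16 matmul), assuming estimate_matmul_flops and estimate_matmul_bytes use the conventional 2*M*N*K FLOP count and a single pass over A, B, and C (the roofline module itself is not shown in this diff, so those formulas are an assumption):

    # Back-of-the-envelope arithmetic for quick_trace("torch.matmul(A, B)", ...).
    M = N = K = 4096
    dtype_bytes = 2                                       # float16
    flops = 2 * M * N * K                                 # ~1.37e11 FLOPs
    bytes_moved = (M * K + K * N + M * N) * dtype_bytes   # ~1.0e8 bytes
    intensity = flops / bytes_moved                       # ~1365 FLOPs per byte

An intensity in the thousands sits well past the ridge point of current accelerators, so RooflineAnalysis.bottleneck should come out as "compute" for this shape.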