wafer-core 0.1.24__py3-none-any.whl → 0.1.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -34,7 +34,6 @@ from wafer_core.tools import (
34
34
  GLOB_TOOL,
35
35
  GREP_TOOL,
36
36
  READ_TOOL,
37
- SEARCH_DOCS_TOOL,
38
37
  SKILL_TOOL,
39
38
  WRITE_TOOL,
40
39
  ApprovalCallback,
@@ -43,7 +42,6 @@ from wafer_core.tools import (
43
42
  exec_glob,
44
43
  exec_grep,
45
44
  exec_read,
46
- exec_search_docs,
47
45
  exec_skill,
48
46
  exec_write,
49
47
  )
@@ -65,7 +63,6 @@ ALL_TOOLS = {
65
63
  "glob": GLOB_TOOL,
66
64
  "grep": GREP_TOOL,
67
65
  "bash": BASH_TOOL,
68
- "search_docs": SEARCH_DOCS_TOOL,
69
66
  "skill": SKILL_TOOL,
70
67
  # TODO(wafer-tool): "wafer": WAFER_TOOL,
71
68
  }
@@ -214,7 +211,6 @@ class CodingEnvironment:
214
211
  self.bash_approval_callback,
215
212
  self._sandbox_policy,
216
213
  ),
217
- "search_docs": lambda tc: exec_search_docs(tc),
218
214
  "skill": lambda tc: exec_skill(tc),
219
215
  # TODO(wafer-tool): "wafer": lambda tc: exec_wafer(
220
216
  # tc, self.working_dir, self.enabled_tools, self.allow_spawn, cancel_scope
@@ -0,0 +1,32 @@
1
+ """Trace comparison library for analyzing GPU traces across platforms.
2
+
3
+ This module provides functionality to compare performance traces from AMD and NVIDIA GPUs,
4
+ identifying kernel-level performance differences and fusion opportunities.
5
+ """
6
+
7
+ from .analyzer import analyze_traces
8
+ from .classifier import Op, classify
9
+ from .formatter import (
10
+ format_csv,
11
+ format_fusion_csv,
12
+ format_fusion_json,
13
+ format_fusion_text,
14
+ format_json,
15
+ format_text,
16
+ )
17
+ from .fusion_analyzer import analyze_fusion_differences
18
+ from .loader import load_trace
19
+
20
# Public API of the trace-comparison package; mirrors the re-exports above.
__all__ = [
    "Op",
    "classify",
    "load_trace",
    "analyze_traces",
    "analyze_fusion_differences",
    "format_text",
    "format_csv",
    "format_json",
    "format_fusion_text",
    "format_fusion_csv",
    "format_fusion_json",
]
@@ -0,0 +1,339 @@
1
+ """Main trace comparison analysis logic.
2
+
3
+ Compares GPU traces from AMD and NVIDIA platforms, identifying performance differences
4
+ at the operation level and layer level.
5
+ """
6
+
7
+ from collections import defaultdict
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ import pandas as pd
12
+
13
+ from .loader import load_trace
14
+
15
+
16
+ def analyze_traces(
17
+ trace1_path: str | Path,
18
+ trace2_path: str | Path,
19
+ phase_filter: str = "all",
20
+ max_stacks: int = 3,
21
+ ) -> dict[str, Any]:
22
+ """Analyze two traces and return comparison data.
23
+
24
+ Args:
25
+ trace1_path: Path to first trace file
26
+ trace2_path: Path to second trace file
27
+ phase_filter: Filter by phase ('all', 'prefill', or 'decode')
28
+ max_stacks: Maximum number of Python stack traces to collect per operation (0 for unlimited)
29
+
30
+ Returns:
31
+ Dictionary containing:
32
+ - metadata: trace info (GPUs, kernel counts, total times, etc.)
33
+ - operations: per-operation comparison data
34
+ - layers: per-layer comparison data (if layers detected)
35
+ """
36
+ # Load traces
37
+ p1, gpu1, dev1, df1, patterns1, layers1 = load_trace(trace1_path)
38
+ p2, gpu2, dev2, df2, patterns2, layers2 = load_trace(trace2_path)
39
+
40
+ # Apply phase filter
41
+ if phase_filter != "all":
42
+ df1_filtered = df1[df1["phase"] == phase_filter]
43
+ df2_filtered = df2[df2["phase"] == phase_filter]
44
+
45
+ if len(df1_filtered) == 0 and len(df2_filtered) == 0:
46
+ # No data in requested phase - return early with error info
47
+ trace1_phases = {k: int(v) for k, v in df1["phase"].value_counts().items()}
48
+ trace2_phases = {k: int(v) for k, v in df2["phase"].value_counts().items()}
49
+ raise ValueError(
50
+ f"No {phase_filter} phase found. "
51
+ f"Trace1 phases: {trace1_phases}, Trace2 phases: {trace2_phases}"
52
+ )
53
+
54
+ df1, df2 = df1_filtered, df2_filtered
55
+
56
+ # Pre-compute aggregations for both operations and layers in single pass
57
+ # This is much faster than iterating through filtered dataframes multiple times
58
+
59
+ # Group by operation for operation-level analysis
60
+ trace1_by_op = df1.groupby("op").agg({
61
+ "dur_us": ["sum", "mean", "count"],
62
+ "phase": lambda x: set(x.dropna().unique()),
63
+ "cpu_op": lambda x: x.dropna().mode()[0] if len(x.dropna()) > 0 else None,
64
+ })
65
+ trace1_by_op.columns = ["total_us", "avg_us", "count", "phases", "cpu_op"]
66
+
67
+ trace2_by_op = df2.groupby("op").agg({
68
+ "dur_us": ["sum", "mean", "count"],
69
+ "phase": lambda x: set(x.dropna().unique()),
70
+ "cpu_op": lambda x: x.dropna().mode()[0] if len(x.dropna()) > 0 else None,
71
+ })
72
+ trace2_by_op.columns = ["total_us", "avg_us", "count", "phases", "cpu_op"]
73
+
74
+ # Group by layer for layer-level analysis (only for kernels with layer info)
75
+ df1_layered = df1[df1["layer"].notna()]
76
+ df2_layered = df2[df2["layer"].notna()]
77
+
78
+ trace1_by_layer = df1_layered.groupby("layer").agg({
79
+ "dur_us": ["sum", "count"],
80
+ }) if len(df1_layered) > 0 else pd.DataFrame()
81
+ if len(trace1_by_layer) > 0:
82
+ trace1_by_layer.columns = ["total_us", "count"]
83
+
84
+ trace2_by_layer = df2_layered.groupby("layer").agg({
85
+ "dur_us": ["sum", "count"],
86
+ }) if len(df2_layered) > 0 else pd.DataFrame()
87
+ if len(trace2_by_layer) > 0:
88
+ trace2_by_layer.columns = ["total_us", "count"]
89
+
90
+ # Calculate per-operation statistics
91
+ results: dict[str, Any] = {
92
+ "metadata": {
93
+ "trace1_name": str(trace1_path),
94
+ "trace2_name": str(trace2_path),
95
+ "trace1_platform": p1,
96
+ "trace1_gpu": gpu1,
97
+ "trace1_device": dev1,
98
+ "trace2_platform": p2,
99
+ "trace2_gpu": gpu2,
100
+ "trace2_device": dev2,
101
+ "trace1_kernels": len(df1),
102
+ "trace2_kernels": len(df2),
103
+ "trace1_total_ms": df1["dur_us"].sum() / 1000,
104
+ "trace2_total_ms": df2["dur_us"].sum() / 1000,
105
+ "phase": phase_filter,
106
+ "trace1_layers": len(layers1),
107
+ "trace2_layers": len(layers2),
108
+ },
109
+ "operations": [],
110
+ "layers": [],
111
+ }
112
+
113
+ # Per-operation comparison using pre-computed aggregations
114
+ all_ops = set(trace1_by_op.index) | set(trace2_by_op.index)
115
+
116
+ # Track if we've already compared RMSNorm variants to avoid duplicate comparisons
117
+ rmsnorm_compared = False
118
+
119
+ for op in sorted(all_ops):
120
+ # Use pre-computed aggregations instead of filtering entire dataframes
121
+ has_trace1 = op in trace1_by_op.index
122
+ has_trace2 = op in trace2_by_op.index
123
+
124
+ # Handle RMSNorm fusion differences: AMD does RMSNorm+GEMM, NVIDIA does separate RMSNorm
125
+ trace1_op_for_pattern = op # Operation name to use for AMD pattern lookup
126
+ trace2_op_for_pattern = op # Operation name to use for NVIDIA pattern lookup
127
+ skip_comparison = False
128
+
129
+ if op == "RMSNorm+GEMM" and not has_trace2:
130
+ # Compare AMD's fused version to NVIDIA's separate RMSNorm
131
+ has_trace2 = "RMSNorm" in trace2_by_op.index
132
+ trace2_op_for_pattern = "RMSNorm" # NVIDIA kernels are stored under 'RMSNorm'
133
+ rmsnorm_compared = True # Mark that we've compared RMSNorm
134
+ elif op == "RMSNorm" and not has_trace1:
135
+ # Skip this comparison if we already handled it in RMSNorm+GEMM
136
+ if rmsnorm_compared:
137
+ skip_comparison = True
138
+ else:
139
+ # Compare NVIDIA's RMSNorm to AMD's fused version
140
+ has_trace1 = "RMSNorm+GEMM" in trace1_by_op.index
141
+ trace1_op_for_pattern = (
142
+ "RMSNorm+GEMM" # AMD kernels are stored under 'RMSNorm+GEMM'
143
+ )
144
+ rmsnorm_compared = True
145
+
146
+ if skip_comparison or not (has_trace1 and has_trace2):
147
+ continue
148
+
149
+ # Get pre-computed aggregations
150
+ trace1_agg = trace1_by_op.loc[trace1_op_for_pattern]
151
+ trace2_agg = trace2_by_op.loc[trace2_op_for_pattern]
152
+
153
+ trace1_avg = trace1_agg["avg_us"]
154
+ trace2_avg = trace2_agg["avg_us"]
155
+ trace1_total = trace1_agg["total_us"] / 1000
156
+ trace2_total = trace2_agg["total_us"] / 1000
157
+ trace1_count = int(trace1_agg["count"])
158
+ trace2_count = int(trace2_agg["count"])
159
+ ratio = trace1_avg / trace2_avg if trace2_avg > 0 else 1
160
+ gap_ms = trace1_total - trace2_total
161
+
162
+ # Get kernel patterns using the correct operation names for each platform
163
+ trace1_pattern = list(
164
+ patterns1.get(
165
+ (trace1_op_for_pattern, "decode"),
166
+ patterns1.get((trace1_op_for_pattern, "prefill"), {"unknown"}),
167
+ )
168
+ )[0]
169
+ trace2_pattern = list(
170
+ patterns2.get(
171
+ (trace2_op_for_pattern, "decode"),
172
+ patterns2.get((trace2_op_for_pattern, "prefill"), {"unknown"}),
173
+ )
174
+ )[0]
175
+
176
+ # Get CPU operators from pre-computed aggregations
177
+ trace1_cpu_op = trace1_agg["cpu_op"]
178
+ trace2_cpu_op = trace2_agg["cpu_op"]
179
+
180
+ # For detailed kernel data and python stacks, we still need to filter (but only when needed)
181
+ trace1_data = df1[df1["op"] == trace1_op_for_pattern]
182
+ trace2_data = df2[df2["op"] == trace2_op_for_pattern]
183
+
184
+ # Collect example Python stacks for this operation (for JSON output)
185
+ trace1_python_stacks = []
186
+ stack_limit = None if max_stacks == 0 else max_stacks
187
+ for stack_list in trace1_data["python_stack"].head(stack_limit):
188
+ if stack_list and len(stack_list) > 0:
189
+ trace1_python_stacks.append(stack_list)
190
+
191
+ trace2_python_stacks = []
192
+ for stack_list in trace2_data["python_stack"].head(stack_limit):
193
+ if stack_list and len(stack_list) > 0:
194
+ trace2_python_stacks.append(stack_list)
195
+
196
+ # Aggregate individual kernels by name for detailed view
197
+ # Group by kernel name and calculate sum/count/avg
198
+ trace1_kernels = trace1_data.groupby("name").agg({"dur_us": ["sum", "count", "mean"]}).reset_index()
199
+ trace1_kernels.columns = ["name", "total_us", "count", "avg_us"]
200
+ trace1_kernels = trace1_kernels.sort_values("total_us", ascending=False)
201
+ trace1_kernels_list = trace1_kernels.to_dict("records")
202
+
203
+ trace2_kernels = trace2_data.groupby("name").agg({"dur_us": ["sum", "count", "mean"]}).reset_index()
204
+ trace2_kernels.columns = ["name", "total_us", "count", "avg_us"]
205
+ trace2_kernels = trace2_kernels.sort_values("total_us", ascending=False)
206
+ trace2_kernels_list = trace2_kernels.to_dict("records")
207
+
208
+ # Determine status based on TOTAL TIME (gap), not per-call ratio
209
+ # This handles cases where AMD runs fewer operations via fusion.
210
+ # 5ms threshold chosen because:
211
+ # - Filters out measurement noise and minor variations
212
+ # - Represents meaningful performance impact (0.5% of typical 1s inference)
213
+ # - Aligns with human perception of "noticeable" difference
214
+ # - Too small (1ms) creates false positives from variance
215
+ # - Too large (20ms) misses real optimization opportunities
216
+ if gap_ms > 5.0: # AMD spends >5ms more total time
217
+ status = "slower"
218
+ elif gap_ms < -5.0: # AMD spends >5ms less total time
219
+ status = "faster"
220
+ else:
221
+ status = "similar"
222
+
223
+ # Get phases from pre-computed aggregations
224
+ phases = trace1_agg["phases"] | trace2_agg["phases"]
225
+
226
+ results["operations"].append(
227
+ {
228
+ "operation": op,
229
+ "trace1_count": trace1_count,
230
+ "trace2_count": trace2_count,
231
+ "trace1_avg_us": trace1_avg,
232
+ "trace2_avg_us": trace2_avg,
233
+ "trace1_total_ms": trace1_total,
234
+ "trace2_total_ms": trace2_total,
235
+ "ratio": ratio,
236
+ "gap_ms": gap_ms,
237
+ "status": status,
238
+ "trace1_kernel": trace1_pattern,
239
+ "trace2_kernel": trace2_pattern,
240
+ "trace1_cpu_op": trace1_cpu_op,
241
+ "trace2_cpu_op": trace2_cpu_op,
242
+ "trace1_python_stacks": trace1_python_stacks, # Full stacks for JSON
243
+ "trace2_python_stacks": trace2_python_stacks,
244
+ "trace1_kernels": trace1_kernels_list, # All individual kernels for JSON
245
+ "trace2_kernels": trace2_kernels_list, # All individual kernels for JSON
246
+ "phases": sorted(list(phases)) if phases else ["all"], # For client-side filtering
247
+ }
248
+ )
249
+
250
+ # Sort by absolute gap
251
+ results["operations"].sort(key=lambda x: abs(x["gap_ms"]), reverse=True)
252
+
253
+ # Layer-wise analysis using pre-computed aggregations
254
+ if len(trace1_by_layer) > 0 or len(trace2_by_layer) > 0:
255
+ # Get all unique layers present in either trace
256
+ all_layers = sorted(set(trace1_by_layer.index) | set(trace2_by_layer.index))
257
+
258
+ for layer_num in all_layers:
259
+ has_trace1 = layer_num in trace1_by_layer.index
260
+ has_trace2 = layer_num in trace2_by_layer.index
261
+
262
+ if has_trace1 and has_trace2:
263
+ # Layer present in both traces - compare them
264
+ trace1_agg = trace1_by_layer.loc[layer_num]
265
+ trace2_agg = trace2_by_layer.loc[layer_num]
266
+
267
+ trace1_total = trace1_agg["total_us"] / 1000
268
+ trace2_total = trace2_agg["total_us"] / 1000
269
+ trace1_count = int(trace1_agg["count"])
270
+ trace2_count = int(trace2_agg["count"])
271
+ ratio = trace1_total / trace2_total if trace2_total > 0 else 1
272
+ gap_ms = trace1_total - trace2_total
273
+
274
+ # Determine status (use smaller threshold for layers: 0.1ms or 20% difference)
275
+ threshold_ms = 0.1
276
+ threshold_ratio = 1.2
277
+ if gap_ms > threshold_ms and ratio > threshold_ratio:
278
+ status = "slower"
279
+ elif gap_ms < -threshold_ms and ratio < (1.0 / threshold_ratio):
280
+ status = "faster"
281
+ else:
282
+ status = "similar"
283
+
284
+ results["layers"].append(
285
+ {
286
+ "layer": int(layer_num),
287
+ "trace1_kernels": trace1_count,
288
+ "trace2_kernels": trace2_count,
289
+ "trace1_total_ms": trace1_total,
290
+ "trace2_total_ms": trace2_total,
291
+ "ratio": ratio,
292
+ "gap_ms": gap_ms,
293
+ "status": status,
294
+ "in_both": True,
295
+ }
296
+ )
297
+ elif has_trace1:
298
+ # Layer only in trace1
299
+ trace1_agg = trace1_by_layer.loc[layer_num]
300
+ trace1_total = trace1_agg["total_us"] / 1000
301
+ trace1_count = int(trace1_agg["count"])
302
+
303
+ results["layers"].append(
304
+ {
305
+ "layer": int(layer_num),
306
+ "trace1_kernels": trace1_count,
307
+ "trace2_kernels": 0,
308
+ "trace1_total_ms": trace1_total,
309
+ "trace2_total_ms": 0.0,
310
+ "ratio": 0.0,
311
+ "gap_ms": trace1_total,
312
+ "status": "trace1_only",
313
+ "in_both": False,
314
+ }
315
+ )
316
+ elif has_trace2:
317
+ # Layer only in trace2
318
+ trace2_agg = trace2_by_layer.loc[layer_num]
319
+ trace2_total = trace2_agg["total_us"] / 1000
320
+ trace2_count = int(trace2_agg["count"])
321
+
322
+ results["layers"].append(
323
+ {
324
+ "layer": int(layer_num),
325
+ "trace1_kernels": 0,
326
+ "trace2_kernels": trace2_count,
327
+ "trace1_total_ms": 0.0,
328
+ "trace2_total_ms": trace2_total,
329
+ "ratio": 0.0,
330
+ "gap_ms": -trace2_total,
331
+ "status": "trace2_only",
332
+ "in_both": False,
333
+ }
334
+ )
335
+
336
+ # Sort: comparable layers first (by absolute gap), then trace-unique layers
337
+ results["layers"].sort(key=lambda x: (not x["in_both"], abs(x["gap_ms"])), reverse=True)
338
+
339
+ return results
@@ -0,0 +1,192 @@
1
+ """Kernel classification logic for trace comparison.
2
+
3
+ Classifies GPU kernels into operation categories (attention, GEMM, normalization, etc.)
4
+ based on kernel name patterns and platform-specific conventions.
5
+ """
6
+
7
+ from enum import Enum
8
+
9
+
10
class Op(Enum):
    """Kernel operation categories."""

    # Attention kernels, split by inference phase.
    ATTN_PREFILL = "Attention (Prefill)"
    ATTN_DECODE = "Attention (Decode)"
    # KV-cache maintenance (reshape_and_cache kernels).
    KV_CACHE = "KV Cache"
    # Mixture-of-experts pipeline stages.
    MOE_ROUTING = "MoE Routing"
    MOE_GEMM = "MoE GEMM"
    MOE_GEMM_SWIGLU = "MoE GEMM+SwiGLU"
    MOE_FINALIZE = "MoE Finalize"
    # Dense (non-MoE) matrix multiplication from vendor BLAS libraries.
    DENSE_GEMM = "Dense GEMM"
    # Normalization, optionally fused with a following GEMM.
    RMSNORM = "RMSNorm"
    RMSNORM_GEMM = "RMSNorm+GEMM"
    # Generic Triton-generated fused kernels.
    TRITON_FUSED = "Triton Fused"
    ELEMENTWISE = "Elementwise"
    SORTING = "Sorting"
    REDUCE = "Reduce"
    COPY_MEMORY = "Copy/Memory"
    # Fallback bucket for anything unrecognized.
    OTHER = "Other"
29
+
30
+
31
def classify(name: str, platform: str) -> tuple[Op, str]:
    """Classify kernel by operation type.

    Args:
        name: Kernel name from trace
        platform: 'AMD' or 'NVIDIA'

    Returns:
        Tuple of (operation type, pattern name)
    """
    lowered = name.lower()

    # Attention kernels have platform-specific naming conventions.
    if "attention" in lowered or "fmha" in lowered:
        if platform == "AMD":
            if "2d" in lowered:
                return Op.ATTN_PREFILL, "kernel_unified_attention_2d"
            if "3d" in lowered:
                return Op.ATTN_DECODE, "kernel_unified_attention_3d"
        else:
            # NVIDIA uses fmhaSm100 with 'a' (prefill/context) and 'f' (decode/forgen)
            if "fmhasm100a" in lowered or "context" in lowered:
                return Op.ATTN_PREFILL, "fmhaSm100a*_Context"
            if "fmhasm100f" in lowered or "forgen" in lowered:
                return Op.ATTN_DECODE, "fmhaSm100f*_ForGen"
        # Unrecognized attention variant: default to prefill, keep a name prefix.
        return Op.ATTN_PREFILL, name[:40]

    # KV-cache maintenance.
    if "reshape_and_cache" in lowered:
        return Op.KV_CACHE, "reshape_and_cache_*"

    # MoE GEMMs: Triton OGS matmuls and dynamic-batch BMM kernels.
    if "_matmul_ogs_" in lowered:
        if "swiglu" in lowered:
            return Op.MOE_GEMM_SWIGLU, "_matmul_ogs_*_swiglu"
        return Op.MOE_GEMM, "_matmul_ogs_*"

    if name.startswith("bmm_") and "dynbatch" in lowered:
        if "swiglu" in lowered:
            return Op.MOE_GEMM_SWIGLU, "bmm_*_swiGlu_dynBatch"
        return Op.MOE_GEMM, "bmm_*_dynBatch"

    # MoE routing / finalize helpers.
    if any(tok in lowered for tok in ("topk", "routing", "bitmatrix", "moe_forward", "_combined_routing")):
        return Op.MOE_ROUTING, "moe_routing_*"
    if "finalize" in lowered or ("scatter" in lowered and "moe" in lowered):
        return Op.MOE_FINALIZE, "moe_finalize_*"

    # RMSNorm - match actual patterns from traces (Triton rsqrt / mean*mul*pow fusions).
    if "triton" in lowered and ("rsqrt" in lowered or ("mean" in lowered and "mul" in lowered and "pow" in lowered)):
        if "gemm" in lowered or "addmm" in lowered:
            return Op.RMSNORM_GEMM, "triton_*_rmsnorm_gemm"
        return Op.RMSNORM, "triton_*_rsqrt"

    # Dense GEMMs from vendor BLAS stacks - the most common kernels.
    if name.startswith(("Cijk_", "Custom_Cijk_")):
        return Op.DENSE_GEMM, "Cijk_* (Tensile)"
    if name.startswith("nvjet_") or "cublaslt" in lowered:
        return Op.DENSE_GEMM, "nvjet_* (cuBLASLt)"
    if "wvsplitk" in lowered or name.startswith("void wvSplitK"):
        return Op.DENSE_GEMM, "wvSplitK_* (hipBLASLt)"

    # Generic Triton fused kernels, with a dedicated bucket for SiLU fusions.
    if any(tok in lowered for tok in ("triton_poi", "triton_red", "triton_per")):
        if "silu" in lowered or "swiglu" in lowered:
            return Op.TRITON_FUSED, "triton_*_silu"
        return Op.TRITON_FUSED, "triton_*"

    # PyTorch native operations.
    if "at::native::" in name:
        return Op.ELEMENTWISE, "at::native::*"

    # Sorting operations (common in sampling/topk); pattern label is per-platform.
    if "sort" in lowered or "radixsort" in lowered or "merge" in lowered:
        if platform == "AMD":
            return Op.SORTING, "rocprim::sort/merge_*"
        return Op.SORTING, "cub::DeviceRadixSort*"

    # Reduce operations; pattern label is per-platform.
    if "reduce" in lowered and ("reduce_segments" in lowered or "devicereduce" in lowered or "devicescan" in lowered):
        if platform == "AMD":
            return Op.REDUCE, "reduce_segments"
        return Op.REDUCE, "cub::DeviceReduce*"

    # Memory copy operations.
    if "copy" in lowered or "memcpy" in lowered or "_copy_page_indices" in lowered:
        return Op.COPY_MEMORY, "copy_*"

    # Remaining ROCm/CUDA library kernels.
    if "rocprim::" in name or "cub::" in name:
        return Op.OTHER, "rocprim/cub_*"

    return Op.OTHER, name[:40]
125
+
126
+
127
def classify_kernel(name: str) -> str:
    """Simplified kernel classification for fusion analysis.

    Args:
        name: Kernel name from trace

    Returns:
        Simple category name consistent across platforms
    """
    lowered = name.lower()

    # GEMM (matrix multiplication) kernels from any vendor BLAS stack.
    gemm_markers = ("cijk_", "nvjet", "wvsplitk", "cublas", "hipblas", "tensile")
    if any(marker in lowered for marker in gemm_markers):
        return "GEMM"

    # Attention.
    if "attention" in lowered or "fmha" in lowered:
        return "Attention"

    # KV cache maintenance.
    if "reshape_and_cache" in lowered:
        return "KV_Cache"

    # Normalization: Triton rsqrt fusions plus explicit layer/rms-norm names.
    if "triton" in lowered and "rsqrt" in lowered:
        return "RMSNorm"
    if "layernorm" in lowered or "rmsnorm" in lowered:
        return "RMSNorm"

    # Activations (checked after norms so rsqrt fusions win).
    if "silu" in lowered or "swiglu" in lowered:
        return "SwiGLU"
    if "gelu" in lowered:
        return "GELU"
    if "relu" in lowered and "gelu" not in lowered:
        return "ReLU"

    # Generic Triton fusion buckets by scheduling flavor.
    if "triton_poi" in lowered:
        return "Triton_Pointwise"
    if "triton_red" in lowered:
        return "Triton_Reduce"
    if "triton_per" in lowered:
        return "Triton_Persistent"

    # Reductions.
    if "reduce_segments" in lowered or "devicereduce" in lowered:
        return "Reduce"

    # Sorting.
    if "sort" in lowered or "radixsort" in lowered or "merge" in lowered:
        return "Sort"

    # Softmax.
    if "softmax" in lowered:
        return "Softmax"

    # Elementwise kernels ("unrolled_elementwise" is covered by the substring).
    if "elementwise" in lowered:
        return "Elementwise"

    # Memory copies.
    if "copy" in lowered or "memcpy" in lowered:
        return "MemCopy"

    return "Other"