wafer-core 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/PERFORMANCE.md +148 -0
- wafer_core/lib/trace_compare/__init__.py +45 -0
- wafer_core/lib/trace_compare/aligner.py +369 -0
- wafer_core/lib/trace_compare/analyzer.py +729 -0
- wafer_core/lib/trace_compare/api.py +225 -0
- wafer_core/lib/trace_compare/architecture.py +77 -0
- wafer_core/lib/trace_compare/classifier.py +486 -0
- wafer_core/lib/trace_compare/formatter.py +951 -0
- wafer_core/lib/trace_compare/fusion_analyzer.py +356 -0
- wafer_core/lib/trace_compare/kernel_registry.yaml +349 -0
- wafer_core/lib/trace_compare/layer_segmentation.py +114 -0
- wafer_core/lib/trace_compare/loader.py +635 -0
- wafer_core/lib/trace_compare/same_kernel_analyzer.py +119 -0
- wafer_core/lib/trace_compare/warnings.py +99 -0
- wafer_core/problem_config.py +3 -3
- wafer_core/rollouts/agent_presets/rlm_01_01.py +2 -2
- wafer_core/rollouts/dtypes.py +18 -3
- wafer_core/rollouts/providers/anthropic.py +35 -3
- wafer_core/utils/kernel_utils/defense.py +10 -0
- wafer_core/utils/kernel_utils/targets/config.py +10 -0
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.27.dist-info}/METADATA +3 -1
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.27.dist-info}/RECORD +23 -9
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.27.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,486 @@
|
|
|
1
|
+
"""Kernel classification logic for trace comparison.
|
|
2
|
+
|
|
3
|
+
Classifies GPU kernels into operation categories (attention, GEMM, normalization, etc.)
|
|
4
|
+
based on kernel name patterns and platform-specific conventions.
|
|
5
|
+
|
|
6
|
+
Can optionally load patterns from kernel_registry.yaml, but falls back to hardcoded patterns
|
|
7
|
+
for comprehensive coverage.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import fnmatch
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from functools import lru_cache
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
import yaml
|
|
18
|
+
except ImportError:
|
|
19
|
+
yaml = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Op(Enum):
    """Kernel operation categories.

    Each member's value is the human-readable display name used when
    reporting classification results.
    """

    # Attention and KV-cache operations
    ATTN_PREFILL = "Attention (Prefill)"
    ATTN_DECODE = "Attention (Decode)"
    KV_CACHE = "KV Cache"
    # Mixture-of-Experts sub-operations
    MOE_ROUTING = "MoE Routing"
    MOE_GEMM = "MoE GEMM"
    MOE_GEMM_SWIGLU = "MoE GEMM+SwiGLU"
    MOE_FINALIZE = "MoE Finalize"
    # Dense GEMM and fused norm/activation variants
    DENSE_GEMM = "Dense GEMM"
    RMSNORM = "RMSNorm"
    RMSNORM_GEMM = "RMSNorm+GEMM"
    SWIGLU = "SwiGLU"
    SWIGLU_GEMM = "SwiGLU+GEMM"
    EMBEDDING_RMSNORM_GEMM = "Embedding+RMSNorm+GEMM"
    SOFTMAX = "SoftMax"
    TRITON_FUSED = "Triton Fused"
    # Generic / fallback categories
    ELEMENTWISE = "Elementwise"
    SORTING = "Sorting"
    REDUCE = "Reduce"
    INDEXING = "Indexing"
    COPY_MEMORY = "Copy/Memory"
    FUSED_UNKNOWN = "Fused (Unknown)"  # Heuristically detected fusion
    OTHER = "Other"
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
# Keywords that indicate specific operations - used for heuristic fusion detection.
# Keys are substrings searched in the lowercased kernel name; values are the
# human-readable component labels used when building fused-op display names.
# NOTE(review): _detect_heuristic_fusion currently carries its own ordered copy
# of these pairs (ordering matters there); keep the two in sync when editing.
FUSION_KEYWORDS: dict[str, str] = {
    # Normalization
    "rsqrt": "Norm",
    "rmsnorm": "RMSNorm",
    "layernorm": "LayerNorm",
    # GEMM
    "gemm": "GEMM",
    "matmul": "GEMM",
    "mm_": "GEMM",
    # Activations
    "silu": "SiLU",
    "swiglu": "SwiGLU",
    "gelu": "GELU",
    "relu": "ReLU",
    # Other ops
    "softmax": "Softmax",
    "attention": "Attention",
    "embedding": "Embedding",
    "reduce": "Reduce",
}
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Load kernel registry YAML if available (cached after first successful load).
_KERNEL_REGISTRY: dict[str, Any] | None = None


def _load_kernel_registry() -> dict[str, Any] | None:
    """Load the kernel pattern registry from kernel_registry.yaml.

    Returns:
        The parsed registry mapping, or None when PyYAML is unavailable,
        the file is missing or unparseable, or its top level is not a
        mapping. Successful loads are cached in ``_KERNEL_REGISTRY``;
        failures are not cached and will be retried on the next call.
    """
    global _KERNEL_REGISTRY

    if _KERNEL_REGISTRY is not None:
        return _KERNEL_REGISTRY

    # `yaml` is None when the optional PyYAML dependency is not installed.
    if yaml is None:
        return None

    registry_path = Path(__file__).parent / "kernel_registry.yaml"
    if not registry_path.exists():
        return None

    try:
        # Explicit encoding: the registry ships with the package and is UTF-8.
        with registry_path.open(encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
    except Exception:
        # Best-effort: a malformed registry must never break classification;
        # callers fall back to hardcoded patterns.
        return None

    # Guard against a YAML file whose top level is not a mapping — callers
    # index the registry with .get() and assume a dict.
    if not isinstance(loaded, dict):
        return None

    _KERNEL_REGISTRY = loaded
    return _KERNEL_REGISTRY
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# Registry categories that map 1:1 to an Op with no sub-classification.
_SIMPLE_CATEGORY_OPS: dict[str, "Op"] = {
    "gemm": Op.DENSE_GEMM,
    "kv_cache": Op.KV_CACHE,
    "softmax": Op.SOFTMAX,
    "reduce": Op.REDUCE,
    "sorting": Op.SORTING,
    "memory": Op.COPY_MEMORY,
    "indexing": Op.INDEXING,
    "elementwise": Op.ELEMENTWISE,
    "triton": Op.TRITON_FUSED,
}


def _match_registry_pattern(name: str, category: str, platform: str) -> tuple[Op, str] | None:
    """Try to match kernel name against registry patterns.

    Args:
        name: Raw kernel name from the trace.
        category: Registry category key (e.g. "attention", "moe", "rmsnorm").
        platform: 'AMD' or 'NVIDIA' (compared case-insensitively here).

    Returns:
        (Op, pattern_name) if a registry glob matches, None otherwise.
    """
    registry = _load_kernel_registry()
    if not registry:
        return None

    nl = name.lower()
    category_data = registry.get(category, {})

    # Platform-specific patterns first, then patterns valid on both platforms.
    platform_key = "amd" if platform.lower() == "amd" else "nvidia"
    patterns = category_data.get(platform_key, []) + category_data.get("both", [])

    for pattern_entry in patterns:
        pattern = pattern_entry.get("pattern", "")
        if not pattern:
            continue

        # Try case-sensitive match first, then a fully lowercased match.
        if not (fnmatch.fnmatch(name, pattern) or fnmatch.fnmatch(nl, pattern.lower())):
            continue

        if category == "attention":
            # Registry entries may tag the attention phase; default to prefill.
            phase = pattern_entry.get("phase")
            if phase == "decode":
                return Op.ATTN_DECODE, pattern
            return Op.ATTN_PREFILL, pattern
        elif category == "rmsnorm":
            # Detect fused operations (AMD Triton often fuses RMSNorm with GEMM).
            has_gemm = "gemm" in nl
            has_embedding = "embedding" in nl
            if has_gemm and has_embedding:
                # Embedding + RMSNorm + GEMM all fused together
                return Op.EMBEDDING_RMSNORM_GEMM, pattern
            elif has_gemm:
                # RMSNorm + GEMM fused
                return Op.RMSNORM_GEMM, pattern
            return Op.RMSNORM, pattern
        elif category == "moe":
            # Distinguish between MoE sub-operations
            if "swiglu" in nl:
                return Op.MOE_GEMM_SWIGLU, pattern
            if any(x in nl for x in ["routing", "topk", "align_block", "count_and_sort", "gating"]):
                return Op.MOE_ROUTING, pattern
            if "finalize" in nl or "scatter" in nl:
                return Op.MOE_FINALIZE, pattern
            return Op.MOE_GEMM, pattern
        elif category == "activation":
            # Check for fused SwiGLU+GEMM (AMD Triton)
            has_gemm = "gemm" in nl
            has_silu = "silu" in nl or "swiglu" in nl
            if has_gemm and has_silu:
                return Op.SWIGLU_GEMM, pattern
            elif has_silu:
                return Op.SWIGLU, pattern
            # Other activations (GELU, etc.)
            return Op.TRITON_FUSED, pattern

        # All remaining known categories map directly to a single Op.
        simple_op = _SIMPLE_CATEGORY_OPS.get(category)
        if simple_op is not None:
            return simple_op, pattern
        # Unknown category: ignore this match and keep scanning patterns
        # (mirrors the original fall-through behavior).

    return None
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _detect_heuristic_fusion(name: str) -> tuple[Op, str] | None:
    """Heuristically flag likely-fused kernels by counting operation keywords.

    Fallback for kernels not explicitly classified: a Triton kernel whose
    name mentions two or more distinct operations is probably a fusion.

    Returns:
        (Op.FUSED_UNKNOWN, "Component1+Component2+...") for a suspected
        fusion — the joined component string doubles as the display name —
        or None otherwise.
    """
    lowered = name.lower()

    # Restrict the heuristic to Triton kernels; they fuse most aggressively.
    if "triton" not in lowered:
        return None

    # (substring, display label) pairs, most specific first, so the
    # resulting component ordering is deterministic.
    ordered_keywords = (
        ("embedding", "Embedding"),
        ("rmsnorm", "RMSNorm"),
        ("layernorm", "LayerNorm"),
        ("rsqrt", "Norm"),  # generic norm indicator
        ("swiglu", "SwiGLU"),
        ("silu", "SiLU"),
        ("gelu", "GELU"),
        ("relu", "ReLU"),
        ("gemm", "GEMM"),
        ("matmul", "GEMM"),
        ("mm_", "GEMM"),
        ("softmax", "Softmax"),
        ("attention", "Attention"),
        ("reduce", "Reduce"),
    )

    components: list[str] = []
    for needle, label in ordered_keywords:
        if needle not in lowered or label in components:
            continue
        # Generic "Norm" is subsumed by a specific norm label.
        if label == "Norm" and ("RMSNorm" in components or "LayerNorm" in components):
            continue
        # "SiLU" is subsumed by "SwiGLU".
        if label == "SiLU" and "SwiGLU" in components:
            continue
        components.append(label)

    # Two or more distinct operations => likely fused.
    if len(components) < 2:
        return None

    return Op.FUSED_UNKNOWN, "+".join(components)
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
@lru_cache(maxsize=4096)
def classify(name: str, platform: str) -> tuple[Op, str]:
    """Classify kernel by operation type.

    Cached because PyTorch traces have ~48 unique kernel names repeated 810k times.

    Registry patterns (kernel_registry.yaml) are consulted first; everything
    below the registry loop is a hardcoded fallback chain whose ORDER IS
    SIGNIFICANT — earlier, more specific checks must win.

    Args:
        name: Kernel name from trace
        platform: 'AMD' or 'NVIDIA'

    Returns:
        Tuple of (operation type, pattern name)
    """
    nl = name.lower()

    # Check registry patterns first (order matters - more specific categories first)
    registry_categories = [
        "attention",  # Attention ops (prefill/decode)
        "gemm",  # Dense GEMM
        "rmsnorm",  # RMSNorm
        "moe",  # MoE operations
        "kv_cache",  # KV Cache
        "activation",  # SwiGLU, SiLU, GELU
        "softmax",  # Softmax
        "reduce",  # Reduce/Scan
        "sorting",  # Sorting
        "memory",  # Memory/Copy
        "indexing",  # Index/Scatter-Gather
        "triton",  # Triton fused ops
        "elementwise",  # Elementwise (last - most generic)
    ]
    for category in registry_categories:
        result = _match_registry_pattern(name, category, platform)
        if result:
            return result

    # --- Hardcoded fallbacks below; reached only when no registry pattern hit ---

    # Attention: platform-specific naming conventions.
    if "attention" in nl or "fmha" in nl:
        if platform == "AMD":
            # vLLM unified-attention kernels: 2d = prefill, 3d = decode.
            if "2d" in nl:
                return Op.ATTN_PREFILL, "kernel_unified_attention_2d"
            if "3d" in nl:
                return Op.ATTN_DECODE, "kernel_unified_attention_3d"
        else:
            # NVIDIA uses fmhaSm100 with 'a' (prefill/context) and 'f' (decode/forgen)
            if "fmhasm100a" in nl or "context" in nl:
                return Op.ATTN_PREFILL, "fmhaSm100a*_Context"
            if "fmhasm100f" in nl or "forgen" in nl:
                return Op.ATTN_DECODE, "fmhaSm100f*_ForGen"
        # Unrecognized attention variant: default to prefill, truncate name.
        return Op.ATTN_PREFILL, name[:40]

    # Flash Attention variants (vLLM)
    if "flash::flash_fwd_kernel" in name or "flash_fwd" in nl:
        # Could distinguish prefill/decode if needed, defaulting to prefill
        return Op.ATTN_PREFILL, "flash::flash_fwd_kernel"

    if "reshape_and_cache" in nl:
        return Op.KV_CACHE, "reshape_and_cache_*"

    # KV Cache variants (vLLM)
    if "concat_and_cache" in nl or "cache_mla" in nl:
        return Op.KV_CACHE, "vllm::concat_and_cache_*"

    # MoE
    # vLLM MoE kernels - these are very common in MoE models
    if "fused_moe_kernel" in nl:
        return Op.MOE_GEMM, "fused_moe_kernel"

    if "vllm::moe::" in name:
        if "moe_align_block_size" in nl:
            if "small_batch" in nl:
                return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_small_batch_*"
            return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_*"
        if "moe_sum" in nl:
            return Op.MOE_FINALIZE, "vllm::moe::moe_sum_*"
        # NOTE: other vllm::moe:: kernels intentionally fall through.

    # vLLM act_and_mul (can be mangled C++ name)
    if "vllm::act_and_mul_kernel" in name or ("act_and_mul_kernel" in nl and "vllm" in nl):
        return Op.MOE_GEMM_SWIGLU, "vllm::act_and_mul_kernel"

    # Triton matmul_ogs MoE GEMMs (optionally fused with SwiGLU).
    if "_matmul_ogs_" in nl:
        if "swiglu" in nl:
            return Op.MOE_GEMM_SWIGLU, "_matmul_ogs_*_swiglu"
        return Op.MOE_GEMM, "_matmul_ogs_*"

    # Batched dynamic-batch MoE GEMMs.
    if name.startswith("bmm_") and "dynbatch" in nl:
        if "swiglu" in nl:
            return Op.MOE_GEMM_SWIGLU, "bmm_*_swiGlu_dynBatch"
        return Op.MOE_GEMM, "bmm_*_dynBatch"

    # Generic MoE routing patterns (check before finalize)
    if any(x in nl for x in ["topk", "routing", "bitmatrix", "moe_forward", "_combined_routing"]):
        if "moe::dev::routing::" in name or "moe::" in name:
            return Op.MOE_ROUTING, "moe::dev::routing::*"
        return Op.MOE_ROUTING, "moe_routing_*"

    # MoE finalize patterns
    if "finalize" in nl or ("scatter" in nl and "moe" in nl):
        return Op.MOE_FINALIZE, "moe_finalize_*"

    # RMSNorm - match actual patterns from traces
    # (Triton-generated names expose either rsqrt or the mean/mul/pow trio.)
    if "triton" in nl and ("rsqrt" in nl or ("mean" in nl and "mul" in nl and "pow" in nl)):
        if "gemm" in nl or "addmm" in nl:
            return Op.RMSNORM_GEMM, "triton_*_rmsnorm_gemm"
        return Op.RMSNORM, "triton_*_rsqrt"

    # Dense GEMM - these are the most common kernels
    if name.startswith("Cijk_") or name.startswith("Custom_Cijk_"):
        return Op.DENSE_GEMM, "Cijk_* (Tensile)"
    if name.startswith("nvjet_") or "cublaslt" in nl:
        return Op.DENSE_GEMM, "nvjet_* (cuBLASLt)"
    if "wvsplitk" in nl or name.startswith("void wvSplitK"):
        return Op.DENSE_GEMM, "wvSplitK_* (hipBLASLt)"

    # CUTLASS GEMM variants
    if "cutlass" in nl and ("sgemm" in nl or "gemm" in nl or "cutlass3x" in name.lower()):
        return Op.DENSE_GEMM, "cutlass*_gemm"

    # GEMV (matrix-vector) operations - treat as GEMM
    if "gemv" in nl or "gemvx" in nl:
        return Op.DENSE_GEMM, "gemv*"

    # Generic GEMM patterns
    if "gemmsn" in nl or name.startswith("void gemmSN"):
        return Op.DENSE_GEMM, "gemmSN_*"

    # Triton fused operations - very common
    if "triton_poi" in nl or "triton_red" in nl or "triton_per" in nl:
        # Distinguish between different fusion patterns
        if "silu" in nl or "swiglu" in nl:
            return Op.TRITON_FUSED, "triton_*_silu"
        return Op.TRITON_FUSED, "triton_*"

    # SoftMax operations
    if "softmax" in nl:
        return Op.SOFTMAX, "softmax_*"

    # Reduce operations - catch more patterns
    if "reduce" in nl:
        if "reduce_segments" in nl or "devicereduce" in nl:
            return Op.REDUCE, "reduce_segments"
        if "reduce_kernel" in nl:
            return Op.REDUCE, "reduce_kernel"
        return Op.REDUCE, "reduce_*"

    # Scan operations (library internals - similar to reduce)
    if "scan" in nl and ("cub::" in name or "rocprim::" in name or "at_cuda_detail::cub::" in name):
        return Op.REDUCE, "cub/rocprim::scan_*"

    # Sorting operations (common in sampling/topk)
    if "sort" in nl or "radixsort" in nl or "merge" in nl:
        if platform == "AMD":
            return Op.SORTING, "rocprim::sort/merge_*"
        else:
            return Op.SORTING, "cub::DeviceRadixSort*"

    # Indexing/Scatter/Gather operations
    if any(x in nl for x in ["indices", "scatter", "gather", "index_select", "embedding"]):
        return Op.INDEXING, "index/scatter_*"

    # Memory copy operations
    if "copy" in nl or "memcpy" in nl or "_copy_page_indices" in nl:
        return Op.COPY_MEMORY, "copy_*"

    # PyTorch native operations (catch-all for at::native)
    if "at::native::" in name:
        # Try to be more specific
        if "fill" in nl:
            return Op.ELEMENTWISE, "at::native::fill_*"
        return Op.ELEMENTWISE, "at::native::*"

    # ROCm/CUDA library kernels (other)
    if "rocprim::" in name or "cub::" in name:
        return Op.OTHER, "rocprim/cub_*"

    # Fallback: Heuristic fusion detection for unknown Triton kernels
    # If a kernel has multiple operation keywords, it's likely fused
    heuristic_result = _detect_heuristic_fusion(name)
    if heuristic_result:
        return heuristic_result

    # Last resort: pass through a truncated kernel name for display.
    return Op.OTHER, name[:40]
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def classify_kernel(name: str) -> str:
    """Simplified kernel classification for fusion analysis.

    Unlike :func:`classify`, this returns a flat string label and checks a
    much smaller, ordered set of substring rules (earlier rules win).

    Args:
        name: Kernel name from trace

    Returns:
        Simple category name consistent across platforms
    """
    lowered = name.lower()

    def has(*needles: str) -> bool:
        # True when any needle occurs in the lowercased kernel name.
        return any(n in lowered for n in needles)

    # GEMM operations (matrix multiplication) - vendor BLAS/Tensile kernels.
    if has("cijk_", "nvjet", "wvsplitk", "cublas", "hipblas", "tensile"):
        return "GEMM"

    # Attention
    if has("attention", "fmha"):
        return "Attention"

    # KV Cache
    if has("reshape_and_cache"):
        return "KV_Cache"

    # RMSNorm / LayerNorm: Triton rsqrt-based names, or explicit norm names.
    if "triton" in lowered and "rsqrt" in lowered:
        return "RMSNorm"
    if has("layernorm", "rmsnorm"):
        return "RMSNorm"

    # SwiGLU / Activations
    if has("silu", "swiglu"):
        return "SwiGLU"
    if has("gelu"):
        return "GELU"
    if "relu" in lowered and not has("gelu"):
        return "ReLU"

    # Triton fused operations (generic), split by kernel flavor.
    if has("triton_poi"):
        return "Triton_Pointwise"
    if has("triton_red"):
        return "Triton_Reduce"
    if has("triton_per"):
        return "Triton_Persistent"

    # Reduce operations
    if has("reduce"):
        return "Reduce"

    # Sort operations
    if has("sort", "radixsort", "merge"):
        return "Sort"

    # Softmax
    if has("softmax"):
        return "Softmax"

    # Indexing/Scatter/Gather
    if has("indices", "scatter", "gather", "index_select", "embedding"):
        return "Indexing"

    # Elementwise operations
    if has("elementwise", "unrolled_elementwise"):
        return "Elementwise"

    # Copy/Memory operations
    if has("copy", "memcpy"):
        return "MemCopy"

    return "Other"
|