wafer-core 0.1.33 → 0.1.35 (py3-none-any.whl)
This diff compares the contents of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- wafer_core/lib/trace_compare/__init__.py +9 -22
- wafer_core/lib/trace_compare/analyzer.py +160 -584
- wafer_core/lib/trace_compare/classifier.py +18 -321
- wafer_core/lib/trace_compare/fusion_analyzer.py +753 -329
- wafer_core/lib/trace_compare/loader.py +220 -413
- wafer_core/targets/__init__.py +21 -47
- wafer_core/utils/kernel_utils/defense.py +1 -813
- wafer_core/utils/kernel_utils/targets/config.py +24 -8
- {wafer_core-0.1.33.dist-info → wafer_core-0.1.35.dist-info}/METADATA +1 -1
- {wafer_core-0.1.33.dist-info → wafer_core-0.1.35.dist-info}/RECORD +11 -11
- {wafer_core-0.1.33.dist-info → wafer_core-0.1.35.dist-info}/WHEEL +0 -0
wafer_core/lib/trace_compare/classifier.py

@@ -2,21 +2,9 @@
 
 Classifies GPU kernels into operation categories (attention, GEMM, normalization, etc.)
 based on kernel name patterns and platform-specific conventions.
-
-Can optionally load patterns from kernel_registry.yaml, but falls back to hardcoded patterns
-for comprehensive coverage.
 """
 
-import fnmatch
 from enum import Enum
-from functools import lru_cache
-from pathlib import Path
-from typing import Any
-
-try:
-    import yaml
-except ImportError:
-    yaml = None
 
 
 class Op(Enum):
@@ -24,223 +12,24 @@ class Op(Enum):
 
     ATTN_PREFILL = "Attention (Prefill)"
     ATTN_DECODE = "Attention (Decode)"
-    # NVIDIA Flash Attention fuses QKV projection + Softmax + Attention
-    FLASH_ATTN_FUSED = "FlashAttention (QKV+Softmax+Attn)"
     KV_CACHE = "KV Cache"
     MOE_ROUTING = "MoE Routing"
     MOE_GEMM = "MoE GEMM"
     MOE_GEMM_SWIGLU = "MoE GEMM+SwiGLU"
     MOE_FINALIZE = "MoE Finalize"
     DENSE_GEMM = "Dense GEMM"
-    # NVIDIA cuBLASLt/CUTLASS can fuse GEMM with epilogue (bias + activation)
-    GEMM_BIAS_ACT = "GEMM+Bias+Activation"
     RMSNORM = "RMSNorm"
     RMSNORM_GEMM = "RMSNorm+GEMM"
-    SWIGLU = "SwiGLU"
-    SWIGLU_GEMM = "SwiGLU+GEMM"
-    EMBEDDING_RMSNORM_GEMM = "Embedding+RMSNorm+GEMM"
-    SOFTMAX = "SoftMax"
     TRITON_FUSED = "Triton Fused"
     ELEMENTWISE = "Elementwise"
     SORTING = "Sorting"
     REDUCE = "Reduce"
-    INDEXING = "Indexing"
     COPY_MEMORY = "Copy/Memory"
-    FUSED_UNKNOWN = "Fused (Unknown)"  # Heuristically detected fusion
     OTHER = "Other"
 
 
-# Keywords that indicate specific operations - used for heuristic fusion detection
-FUSION_KEYWORDS: dict[str, str] = {
-    # Normalization
-    "rsqrt": "Norm",
-    "rmsnorm": "RMSNorm",
-    "layernorm": "LayerNorm",
-    # GEMM
-    "gemm": "GEMM",
-    "matmul": "GEMM",
-    "mm_": "GEMM",
-    # Activations
-    "silu": "SiLU",
-    "swiglu": "SwiGLU",
-    "gelu": "GELU",
-    "relu": "ReLU",
-    # Other ops
-    "softmax": "Softmax",
-    "attention": "Attention",
-    "embedding": "Embedding",
-    "reduce": "Reduce",
-}
-
-
-# Load kernel registry YAML if available
-_KERNEL_REGISTRY: dict[str, Any] | None = None
-
-
-def _load_kernel_registry() -> dict[str, Any] | None:
-    """Load kernel pattern registry from YAML file."""
-    global _KERNEL_REGISTRY
-
-    if _KERNEL_REGISTRY is not None:
-        return _KERNEL_REGISTRY
-
-    if yaml is None:
-        return None
-
-    registry_path = Path(__file__).parent / "kernel_registry.yaml"
-    if not registry_path.exists():
-        return None
-
-    try:
-        with open(registry_path) as f:
-            _KERNEL_REGISTRY = yaml.safe_load(f)
-        return _KERNEL_REGISTRY
-    except Exception:
-        return None
-
-
-def _match_registry_pattern(name: str, category: str, platform: str) -> tuple[Op, str] | None:
-    """Try to match kernel name against registry patterns.
-
-    Returns (Op, pattern_name) if match found, None otherwise.
-    """
-    registry = _load_kernel_registry()
-    if not registry:
-        return None
-
-    nl = name.lower()
-    category_data = registry.get(category, {})
-
-    platform_key = "amd" if platform.lower() == "amd" else "nvidia"
-    patterns = category_data.get(platform_key, [])
-    both_patterns = category_data.get("both", [])
-    patterns = patterns + both_patterns
-
-    for pattern_entry in patterns:
-        pattern = pattern_entry.get("pattern", "")
-        if not pattern:
-            continue
-
-        if fnmatch.fnmatch(name, pattern) or fnmatch.fnmatch(nl, pattern.lower()):
-            if category == "attention":
-                phase = pattern_entry.get("phase")
-                if phase == "prefill":
-                    return Op.ATTN_PREFILL, pattern
-                elif phase == "decode":
-                    return Op.ATTN_DECODE, pattern
-                return Op.ATTN_PREFILL, pattern
-            elif category == "gemm":
-                return Op.DENSE_GEMM, pattern
-            elif category == "rmsnorm":
-                # Detect fused operations (AMD Triton often fuses RMSNorm with GEMM)
-                has_gemm = "gemm" in nl or "unquantized_gemm" in nl
-                has_embedding = "embedding" in nl
-                if has_gemm and has_embedding:
-                    # Embedding + RMSNorm + GEMM all fused together
-                    return Op.EMBEDDING_RMSNORM_GEMM, pattern
-                elif has_gemm:
-                    # RMSNorm + GEMM fused
-                    return Op.RMSNORM_GEMM, pattern
-                return Op.RMSNORM, pattern
-            elif category == "moe":
-                # Distinguish between MoE sub-operations
-                if "swiglu" in nl:
-                    return Op.MOE_GEMM_SWIGLU, pattern
-                if any(x in nl for x in ["routing", "topk", "align_block", "count_and_sort", "gating"]):
-                    return Op.MOE_ROUTING, pattern
-                if "finalize" in nl or "scatter" in nl:
-                    return Op.MOE_FINALIZE, pattern
-                return Op.MOE_GEMM, pattern
-            elif category == "kv_cache":
-                return Op.KV_CACHE, pattern
-            elif category == "softmax":
-                return Op.SOFTMAX, pattern
-            elif category == "reduce":
-                return Op.REDUCE, pattern
-            elif category == "sorting":
-                return Op.SORTING, pattern
-            elif category == "memory":
-                return Op.COPY_MEMORY, pattern
-            elif category == "indexing":
-                return Op.INDEXING, pattern
-            elif category == "elementwise":
-                return Op.ELEMENTWISE, pattern
-            elif category == "triton":
-                return Op.TRITON_FUSED, pattern
-            elif category == "activation":
-                # Check for fused SwiGLU+GEMM (AMD Triton)
-                has_gemm = "gemm" in nl or "unquantized_gemm" in nl
-                has_silu = "silu" in nl or "swiglu" in nl
-                if has_gemm and has_silu:
-                    return Op.SWIGLU_GEMM, pattern
-                elif has_silu:
-                    return Op.SWIGLU, pattern
-                # Other activations (GELU, etc.)
-                return Op.TRITON_FUSED, pattern
-
-    return None
-
-
-def _detect_heuristic_fusion(name: str) -> tuple[Op, str] | None:
-    """Heuristically detect potential fusions based on multiple operation keywords.
-
-    This is a fallback for kernels we haven't explicitly classified.
-    If a kernel name contains 2+ distinct operation keywords, it's likely fused.
-
-    Returns (Op.FUSED_UNKNOWN, "Component1+Component2+...") if suspected fusion.
-    The pattern name contains the fused components for display.
-    """
-    nl = name.lower()
-
-    # Only check Triton kernels - these are most likely to be fused
-    if "triton" not in nl:
-        return None
-
-    # Find all operation keywords present in the name
-    # Use ordered list to maintain consistent ordering
-    found_ops: list[str] = []
-    keyword_priority = [
-        # Order matters - more specific first
-        ("embedding", "Embedding"),
-        ("rmsnorm", "RMSNorm"),
-        ("layernorm", "LayerNorm"),
-        ("rsqrt", "Norm"),  # Generic norm indicator
-        ("swiglu", "SwiGLU"),
-        ("silu", "SiLU"),
-        ("gelu", "GELU"),
-        ("relu", "ReLU"),
-        ("gemm", "GEMM"),
-        ("matmul", "GEMM"),
-        ("mm_", "GEMM"),
-        ("softmax", "Softmax"),
-        ("attention", "Attention"),
-        ("reduce", "Reduce"),
-    ]
-
-    for keyword, op_name in keyword_priority:
-        if keyword in nl and op_name not in found_ops:
-            # Avoid duplicates like "RMSNorm" and "Norm"
-            if op_name == "Norm" and any(n in found_ops for n in ["RMSNorm", "LayerNorm"]):
-                continue
-            # Avoid duplicates like "SwiGLU" and "SiLU"
-            if op_name == "SiLU" and "SwiGLU" in found_ops:
-                continue
-            found_ops.append(op_name)
-
-    # If 2+ operations detected, it's likely a fusion
-    if len(found_ops) >= 2:
-        fused_name = "+".join(found_ops)
-        # The pattern name IS the fused operation name for display
-        return Op.FUSED_UNKNOWN, fused_name
-
-    return None
-
-
-@lru_cache(maxsize=4096)
 def classify(name: str, platform: str) -> tuple[Op, str]:
     """Classify kernel by operation type.
-
-    Cached because PyTorch traces have ~48 unique kernel names repeated 810k times.
 
     Args:
         name: Kernel name from trace
@@ -250,27 +39,8 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         Tuple of (operation type, pattern name)
     """
     nl = name.lower()
-
-    #
-    registry_categories = [
-        "attention",  # Attention ops (prefill/decode)
-        "gemm",  # Dense GEMM
-        "rmsnorm",  # RMSNorm
-        "moe",  # MoE operations
-        "kv_cache",  # KV Cache
-        "activation",  # SwiGLU, SiLU, GELU
-        "softmax",  # Softmax
-        "reduce",  # Reduce/Scan
-        "sorting",  # Sorting
-        "memory",  # Memory/Copy
-        "indexing",  # Index/Scatter-Gather
-        "triton",  # Triton fused ops
-        "elementwise",  # Elementwise (last - most generic)
-    ]
-    for category in registry_categories:
-        result = _match_registry_pattern(name, category, platform)
-        if result:
-            return result
+
+    # Attention
     if "attention" in nl or "fmha" in nl:
         if platform == "AMD":
             if "2d" in nl:
@@ -278,47 +48,17 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
             if "3d" in nl:
                 return Op.ATTN_DECODE, "kernel_unified_attention_3d"
         else:
-            # NVIDIA
-
-
-
-
-            if "fmhasm100f" in nl or "forgen" in nl:
-                return Op.FLASH_ATTN_FUSED, "fmhaSm100f*_ForGen (QKV+Softmax+Attn)"
-            return Op.FLASH_ATTN_FUSED, "fmhaSm100* (QKV+Softmax+Attn)"
+            # NVIDIA uses fmhaSm100 with 'a' (prefill/context) and 'f' (decode/forgen)
+            if "fmhasm100a" in nl or "context" in nl:
+                return Op.ATTN_PREFILL, "fmhaSm100a*_Context"
+            if "fmhasm100f" in nl or "forgen" in nl:
+                return Op.ATTN_DECODE, "fmhaSm100f*_ForGen"
         return Op.ATTN_PREFILL, name[:40]
 
-    # Flash Attention variants (vLLM) - these are fused on NVIDIA
-    if "flash::flash_fwd_kernel" in name or "flash_fwd" in nl:
-        if platform != "AMD":
-            return Op.FLASH_ATTN_FUSED, "flash::flash_fwd_kernel (QKV+Softmax+Attn)"
-        return Op.ATTN_PREFILL, "flash::flash_fwd_kernel"
-
     if "reshape_and_cache" in nl:
         return Op.KV_CACHE, "reshape_and_cache_*"
 
-    # KV Cache variants (vLLM)
-    if "concat_and_cache" in nl or "cache_mla" in nl:
-        return Op.KV_CACHE, "vllm::concat_and_cache_*"
-
     # MoE
-    # vLLM MoE kernels - these are very common in MoE models
-    if "fused_moe_kernel" in nl:
-        return Op.MOE_GEMM, "fused_moe_kernel"
-
-    if "vllm::moe::" in name:
-        if "moe_align_block_size" in nl:
-            if "small_batch" in nl:
-                return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_small_batch_*"
-            return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_*"
-        if "moe_sum" in nl:
-            return Op.MOE_FINALIZE, "vllm::moe::moe_sum_*"
-
-    # vLLM act_and_mul - fuses activation with element-wise multiply (SiLU * x)
-    # This is a fused operation used in SwiGLU/MoE
-    if "vllm::act_and_mul_kernel" in name or ("act_and_mul_kernel" in nl and "vllm" in nl):
-        return Op.SWIGLU_GEMM, "vllm::act_and_mul_kernel (SwiGLU+Mul)"
-
     if "_matmul_ogs_" in nl:
         if "swiglu" in nl:
             return Op.MOE_GEMM_SWIGLU, "_matmul_ogs_*_swiglu"
@@ -329,13 +69,8 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
             return Op.MOE_GEMM_SWIGLU, "bmm_*_swiGlu_dynBatch"
         return Op.MOE_GEMM, "bmm_*_dynBatch"
 
-    # Generic MoE routing patterns (check before finalize)
     if any(x in nl for x in ["topk", "routing", "bitmatrix", "moe_forward", "_combined_routing"]):
-        if "moe::dev::routing::" in name or "moe::" in name:
-            return Op.MOE_ROUTING, "moe::dev::routing::*"
         return Op.MOE_ROUTING, "moe_routing_*"
-
-    # MoE finalize patterns
     if "finalize" in nl or ("scatter" in nl and "moe" in nl):
         return Op.MOE_FINALIZE, "moe_finalize_*"
 
@@ -352,18 +87,6 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         return Op.DENSE_GEMM, "nvjet_* (cuBLASLt)"
     if "wvsplitk" in nl or name.startswith("void wvSplitK"):
        return Op.DENSE_GEMM, "wvSplitK_* (hipBLASLt)"
-
-    # CUTLASS GEMM variants
-    if "cutlass" in nl and ("sgemm" in nl or "gemm" in nl or "cutlass3x" in name.lower()):
-        return Op.DENSE_GEMM, "cutlass*_gemm"
-
-    # GEMV (matrix-vector) operations - treat as GEMM
-    if "gemv" in nl or "gemvx" in nl:
-        return Op.DENSE_GEMM, "gemv*"
-
-    # Generic GEMM patterns
-    if "gemmsn" in nl or name.startswith("void gemmSN"):
-        return Op.DENSE_GEMM, "gemmSN_*"
 
     # Triton fused operations - very common
     if "triton_poi" in nl or "triton_red" in nl or "triton_per" in nl:
@@ -372,21 +95,9 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
             return Op.TRITON_FUSED, "triton_*_silu"
         return Op.TRITON_FUSED, "triton_*"
 
-    #
-    if "
-        return Op.
-
-    # Reduce operations - catch more patterns
-    if "reduce" in nl:
-        if "reduce_segments" in nl or "devicereduce" in nl:
-            return Op.REDUCE, "reduce_segments"
-        if "reduce_kernel" in nl:
-            return Op.REDUCE, "reduce_kernel"
-        return Op.REDUCE, "reduce_*"
-
-    # Scan operations (library internals - similar to reduce)
-    if "scan" in nl and ("cub::" in name or "rocprim::" in name or "at_cuda_detail::cub::" in name):
-        return Op.REDUCE, "cub/rocprim::scan_*"
+    # PyTorch native operations
+    if "at::native::" in name:
+        return Op.ELEMENTWISE, "at::native::*"
 
     # Sorting operations (common in sampling/topk)
     if "sort" in nl or "radixsort" in nl or "merge" in nl:
@@ -395,31 +106,21 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         else:
             return Op.SORTING, "cub::DeviceRadixSort*"
 
-    #
-    if
-
-
+    # Reduce operations
+    if "reduce" in nl and ("reduce_segments" in nl or "devicereduce" in nl or "devicescan" in nl):
+        if platform == "AMD":
+            return Op.REDUCE, "reduce_segments"
+        else:
+            return Op.REDUCE, "cub::DeviceReduce*"
+
     # Memory copy operations
     if "copy" in nl or "memcpy" in nl or "_copy_page_indices" in nl:
         return Op.COPY_MEMORY, "copy_*"
 
-    # PyTorch native operations (catch-all for at::native)
-    if "at::native::" in name:
-        # Try to be more specific
-        if "fill" in nl:
-            return Op.ELEMENTWISE, "at::native::fill_*"
-        return Op.ELEMENTWISE, "at::native::*"
-
     # ROCm/CUDA library kernels (other)
     if "rocprim::" in name or "cub::" in name:
         return Op.OTHER, "rocprim/cub_*"
 
-    # Fallback: Heuristic fusion detection for unknown Triton kernels
-    # If a kernel has multiple operation keywords, it's likely fused
-    heuristic_result = _detect_heuristic_fusion(name)
-    if heuristic_result:
-        return heuristic_result
-
     return Op.OTHER, name[:40]
 
 
@@ -469,7 +170,7 @@ def classify_kernel(name: str) -> str:
         return "Triton_Persistent"
 
     # Reduce operations
-    if "
+    if "reduce_segments" in nl or "devicereduce" in nl:
         return "Reduce"
 
     # Sort operations
@@ -479,10 +180,6 @@ def classify_kernel(name: str) -> str:
     # Softmax
     if "softmax" in nl:
         return "Softmax"
-
-    # Indexing/Scatter/Gather
-    if any(x in nl for x in ["indices", "scatter", "gather", "index_select", "embedding"]):
-        return "Indexing"
 
     # Elementwise operations
     if any(x in nl for x in ["elementwise", "unrolled_elementwise"]):