wafer-core 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/PERFORMANCE.md +148 -0
- wafer_core/lib/trace_compare/__init__.py +22 -9
- wafer_core/lib/trace_compare/aligner.py +376 -0
- wafer_core/lib/trace_compare/analyzer.py +558 -159
- wafer_core/lib/trace_compare/api.py +225 -0
- wafer_core/lib/trace_compare/architecture.py +77 -0
- wafer_core/lib/trace_compare/classifier.py +307 -13
- wafer_core/lib/trace_compare/fusion_analyzer.py +280 -706
- wafer_core/lib/trace_compare/kernel_registry.yaml +349 -0
- wafer_core/lib/trace_compare/layer_segmentation.py +114 -0
- wafer_core/lib/trace_compare/loader.py +526 -227
- wafer_core/lib/trace_compare/same_kernel_analyzer.py +119 -0
- wafer_core/lib/trace_compare/warnings.py +99 -0
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/METADATA +3 -1
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/RECORD +28 -10
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/WHEEL +0 -0
wafer_core/lib/trace_compare/classifier.py

@@ -2,9 +2,21 @@
 
 
 Classifies GPU kernels into operation categories (attention, GEMM, normalization, etc.)
 based on kernel name patterns and platform-specific conventions.
+
+Can optionally load patterns from kernel_registry.yaml, but falls back to hardcoded patterns
+for comprehensive coverage.
 """
 
+import fnmatch
 from enum import Enum
+from functools import lru_cache
+from pathlib import Path
+from typing import Any
+
+try:
+    import yaml
+except ImportError:
+    yaml = None
 
 
 class Op(Enum):
@@ -20,16 +32,211 @@ class Op(Enum):
     DENSE_GEMM = "Dense GEMM"
     RMSNORM = "RMSNorm"
     RMSNORM_GEMM = "RMSNorm+GEMM"
+    SWIGLU = "SwiGLU"
+    SWIGLU_GEMM = "SwiGLU+GEMM"
+    EMBEDDING_RMSNORM_GEMM = "Embedding+RMSNorm+GEMM"
+    SOFTMAX = "SoftMax"
     TRITON_FUSED = "Triton Fused"
     ELEMENTWISE = "Elementwise"
     SORTING = "Sorting"
     REDUCE = "Reduce"
+    INDEXING = "Indexing"
     COPY_MEMORY = "Copy/Memory"
+    FUSED_UNKNOWN = "Fused (Unknown)"  # Heuristically detected fusion
     OTHER = "Other"
 
 
+# Keywords that indicate specific operations - used for heuristic fusion detection
+FUSION_KEYWORDS: dict[str, str] = {
+    # Normalization
+    "rsqrt": "Norm",
+    "rmsnorm": "RMSNorm",
+    "layernorm": "LayerNorm",
+    # GEMM
+    "gemm": "GEMM",
+    "matmul": "GEMM",
+    "mm_": "GEMM",
+    # Activations
+    "silu": "SiLU",
+    "swiglu": "SwiGLU",
+    "gelu": "GELU",
+    "relu": "ReLU",
+    # Other ops
+    "softmax": "Softmax",
+    "attention": "Attention",
+    "embedding": "Embedding",
+    "reduce": "Reduce",
+}
+
+
+# Load kernel registry YAML if available
+_KERNEL_REGISTRY: dict[str, Any] | None = None
+
+
+def _load_kernel_registry() -> dict[str, Any] | None:
+    """Load kernel pattern registry from YAML file."""
+    global _KERNEL_REGISTRY
+
+    if _KERNEL_REGISTRY is not None:
+        return _KERNEL_REGISTRY
+
+    if yaml is None:
+        return None
+
+    registry_path = Path(__file__).parent / "kernel_registry.yaml"
+    if not registry_path.exists():
+        return None
+
+    try:
+        with open(registry_path) as f:
+            _KERNEL_REGISTRY = yaml.safe_load(f)
+        return _KERNEL_REGISTRY
+    except Exception:
+        return None
+
+
+def _match_registry_pattern(name: str, category: str, platform: str) -> tuple[Op, str] | None:
+    """Try to match kernel name against registry patterns.
+
+    Returns (Op, pattern_name) if match found, None otherwise.
+    """
+    registry = _load_kernel_registry()
+    if not registry:
+        return None
+
+    nl = name.lower()
+    category_data = registry.get(category, {})
+
+    platform_key = "amd" if platform.lower() == "amd" else "nvidia"
+    patterns = category_data.get(platform_key, [])
+    both_patterns = category_data.get("both", [])
+    patterns = patterns + both_patterns
+
+    for pattern_entry in patterns:
+        pattern = pattern_entry.get("pattern", "")
+        if not pattern:
+            continue
+
+        if fnmatch.fnmatch(name, pattern) or fnmatch.fnmatch(nl, pattern.lower()):
+            if category == "attention":
+                phase = pattern_entry.get("phase")
+                if phase == "prefill":
+                    return Op.ATTN_PREFILL, pattern
+                elif phase == "decode":
+                    return Op.ATTN_DECODE, pattern
+                return Op.ATTN_PREFILL, pattern
+            elif category == "gemm":
+                return Op.DENSE_GEMM, pattern
+            elif category == "rmsnorm":
+                # Detect fused operations (AMD Triton often fuses RMSNorm with GEMM)
+                has_gemm = "gemm" in nl or "unquantized_gemm" in nl
+                has_embedding = "embedding" in nl
+                if has_gemm and has_embedding:
+                    # Embedding + RMSNorm + GEMM all fused together
+                    return Op.EMBEDDING_RMSNORM_GEMM, pattern
+                elif has_gemm:
+                    # RMSNorm + GEMM fused
+                    return Op.RMSNORM_GEMM, pattern
+                return Op.RMSNORM, pattern
+            elif category == "moe":
+                # Distinguish between MoE sub-operations
+                if "swiglu" in nl:
+                    return Op.MOE_GEMM_SWIGLU, pattern
+                if any(x in nl for x in ["routing", "topk", "align_block", "count_and_sort", "gating"]):
+                    return Op.MOE_ROUTING, pattern
+                if "finalize" in nl or "scatter" in nl:
+                    return Op.MOE_FINALIZE, pattern
+                return Op.MOE_GEMM, pattern
+            elif category == "kv_cache":
+                return Op.KV_CACHE, pattern
+            elif category == "softmax":
+                return Op.SOFTMAX, pattern
+            elif category == "reduce":
+                return Op.REDUCE, pattern
+            elif category == "sorting":
+                return Op.SORTING, pattern
+            elif category == "memory":
+                return Op.COPY_MEMORY, pattern
+            elif category == "indexing":
+                return Op.INDEXING, pattern
+            elif category == "elementwise":
+                return Op.ELEMENTWISE, pattern
+            elif category == "triton":
+                return Op.TRITON_FUSED, pattern
+            elif category == "activation":
+                # Check for fused SwiGLU+GEMM (AMD Triton)
+                has_gemm = "gemm" in nl or "unquantized_gemm" in nl
+                has_silu = "silu" in nl or "swiglu" in nl
+                if has_gemm and has_silu:
+                    return Op.SWIGLU_GEMM, pattern
+                elif has_silu:
+                    return Op.SWIGLU, pattern
+                # Other activations (GELU, etc.)
+                return Op.TRITON_FUSED, pattern
+
+    return None
+
+
+def _detect_heuristic_fusion(name: str) -> tuple[Op, str] | None:
+    """Heuristically detect potential fusions based on multiple operation keywords.
+
+    This is a fallback for kernels we haven't explicitly classified.
+    If a kernel name contains 2+ distinct operation keywords, it's likely fused.
+
+    Returns (Op.FUSED_UNKNOWN, "Component1+Component2+...") if suspected fusion.
+    The pattern name contains the fused components for display.
+    """
+    nl = name.lower()
+
+    # Only check Triton kernels - these are most likely to be fused
+    if "triton" not in nl:
+        return None
+
+    # Find all operation keywords present in the name
+    # Use ordered list to maintain consistent ordering
+    found_ops: list[str] = []
+    keyword_priority = [
+        # Order matters - more specific first
+        ("embedding", "Embedding"),
+        ("rmsnorm", "RMSNorm"),
+        ("layernorm", "LayerNorm"),
+        ("rsqrt", "Norm"),  # Generic norm indicator
+        ("swiglu", "SwiGLU"),
+        ("silu", "SiLU"),
+        ("gelu", "GELU"),
+        ("relu", "ReLU"),
+        ("gemm", "GEMM"),
+        ("matmul", "GEMM"),
+        ("mm_", "GEMM"),
+        ("softmax", "Softmax"),
+        ("attention", "Attention"),
+        ("reduce", "Reduce"),
+    ]
+
+    for keyword, op_name in keyword_priority:
+        if keyword in nl and op_name not in found_ops:
+            # Avoid duplicates like "RMSNorm" and "Norm"
+            if op_name == "Norm" and any(n in found_ops for n in ["RMSNorm", "LayerNorm"]):
+                continue
+            # Avoid duplicates like "SwiGLU" and "SiLU"
+            if op_name == "SiLU" and "SwiGLU" in found_ops:
+                continue
+            found_ops.append(op_name)
+
+    # If 2+ operations detected, it's likely a fusion
+    if len(found_ops) >= 2:
+        fused_name = "+".join(found_ops)
+        # The pattern name IS the fused operation name for display
+        return Op.FUSED_UNKNOWN, fused_name
+
+    return None
+
+
+@lru_cache(maxsize=4096)
 def classify(name: str, platform: str) -> tuple[Op, str]:
     """Classify kernel by operation type.
+
+    Cached because PyTorch traces have ~48 unique kernel names repeated 810k times.
 
     Args:
         name: Kernel name from trace
@@ -39,8 +246,27 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         Tuple of (operation type, pattern name)
     """
     nl = name.lower()
-
-    #
+
+    # Check registry patterns first (order matters - more specific categories first)
+    registry_categories = [
+        "attention",  # Attention ops (prefill/decode)
+        "gemm",  # Dense GEMM
+        "rmsnorm",  # RMSNorm
+        "moe",  # MoE operations
+        "kv_cache",  # KV Cache
+        "activation",  # SwiGLU, SiLU, GELU
+        "softmax",  # Softmax
+        "reduce",  # Reduce/Scan
+        "sorting",  # Sorting
+        "memory",  # Memory/Copy
+        "indexing",  # Index/Scatter-Gather
+        "triton",  # Triton fused ops
+        "elementwise",  # Elementwise (last - most generic)
+    ]
+    for category in registry_categories:
+        result = _match_registry_pattern(name, category, platform)
+        if result:
+            return result
     if "attention" in nl or "fmha" in nl:
         if platform == "AMD":
             if "2d" in nl:
@@ -55,10 +281,35 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
                 return Op.ATTN_DECODE, "fmhaSm100f*_ForGen"
         return Op.ATTN_PREFILL, name[:40]
 
+    # Flash Attention variants (vLLM)
+    if "flash::flash_fwd_kernel" in name or "flash_fwd" in nl:
+        # Could distinguish prefill/decode if needed, defaulting to prefill
+        return Op.ATTN_PREFILL, "flash::flash_fwd_kernel"
+
     if "reshape_and_cache" in nl:
         return Op.KV_CACHE, "reshape_and_cache_*"
 
+    # KV Cache variants (vLLM)
+    if "concat_and_cache" in nl or "cache_mla" in nl:
+        return Op.KV_CACHE, "vllm::concat_and_cache_*"
+
     # MoE
+    # vLLM MoE kernels - these are very common in MoE models
+    if "fused_moe_kernel" in nl:
+        return Op.MOE_GEMM, "fused_moe_kernel"
+
+    if "vllm::moe::" in name:
+        if "moe_align_block_size" in nl:
+            if "small_batch" in nl:
+                return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_small_batch_*"
+            return Op.MOE_ROUTING, "vllm::moe::moe_align_block_size_*"
+        if "moe_sum" in nl:
+            return Op.MOE_FINALIZE, "vllm::moe::moe_sum_*"
+
+    # vLLM act_and_mul (can be mangled C++ name)
+    if "vllm::act_and_mul_kernel" in name or ("act_and_mul_kernel" in nl and "vllm" in nl):
+        return Op.MOE_GEMM_SWIGLU, "vllm::act_and_mul_kernel"
+
     if "_matmul_ogs_" in nl:
         if "swiglu" in nl:
            return Op.MOE_GEMM_SWIGLU, "_matmul_ogs_*_swiglu"
@@ -69,8 +320,13 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
             return Op.MOE_GEMM_SWIGLU, "bmm_*_swiGlu_dynBatch"
         return Op.MOE_GEMM, "bmm_*_dynBatch"
 
+    # Generic MoE routing patterns (check before finalize)
     if any(x in nl for x in ["topk", "routing", "bitmatrix", "moe_forward", "_combined_routing"]):
+        if "moe::dev::routing::" in name or "moe::" in name:
+            return Op.MOE_ROUTING, "moe::dev::routing::*"
         return Op.MOE_ROUTING, "moe_routing_*"
+
+    # MoE finalize patterns
     if "finalize" in nl or ("scatter" in nl and "moe" in nl):
         return Op.MOE_FINALIZE, "moe_finalize_*"
 
@@ -87,6 +343,18 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         return Op.DENSE_GEMM, "nvjet_* (cuBLASLt)"
     if "wvsplitk" in nl or name.startswith("void wvSplitK"):
         return Op.DENSE_GEMM, "wvSplitK_* (hipBLASLt)"
+
+    # CUTLASS GEMM variants
+    if "cutlass" in nl and ("sgemm" in nl or "gemm" in nl or "cutlass3x" in name.lower()):
+        return Op.DENSE_GEMM, "cutlass*_gemm"
+
+    # GEMV (matrix-vector) operations - treat as GEMM
+    if "gemv" in nl or "gemvx" in nl:
+        return Op.DENSE_GEMM, "gemv*"
+
+    # Generic GEMM patterns
+    if "gemmsn" in nl or name.startswith("void gemmSN"):
+        return Op.DENSE_GEMM, "gemmSN_*"
 
     # Triton fused operations - very common
     if "triton_poi" in nl or "triton_red" in nl or "triton_per" in nl:
@@ -95,9 +363,21 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
             return Op.TRITON_FUSED, "triton_*_silu"
         return Op.TRITON_FUSED, "triton_*"
 
-    #
-    if "
-        return Op.
+    # SoftMax operations
+    if "softmax" in nl:
+        return Op.SOFTMAX, "softmax_*"
+
+    # Reduce operations - catch more patterns
+    if "reduce" in nl:
+        if "reduce_segments" in nl or "devicereduce" in nl:
+            return Op.REDUCE, "reduce_segments"
+        if "reduce_kernel" in nl:
+            return Op.REDUCE, "reduce_kernel"
+        return Op.REDUCE, "reduce_*"
+
+    # Scan operations (library internals - similar to reduce)
+    if "scan" in nl and ("cub::" in name or "rocprim::" in name or "at_cuda_detail::cub::" in name):
+        return Op.REDUCE, "cub/rocprim::scan_*"
 
     # Sorting operations (common in sampling/topk)
     if "sort" in nl or "radixsort" in nl or "merge" in nl:
@@ -106,21 +386,31 @@ def classify(name: str, platform: str) -> tuple[Op, str]:
         else:
             return Op.SORTING, "cub::DeviceRadixSort*"
 
-    #
-    if
-
-
-        else:
-            return Op.REDUCE, "cub::DeviceReduce*"
-
+    # Indexing/Scatter/Gather operations
+    if any(x in nl for x in ["indices", "scatter", "gather", "index_select", "embedding"]):
+        return Op.INDEXING, "index/scatter_*"
+
     # Memory copy operations
     if "copy" in nl or "memcpy" in nl or "_copy_page_indices" in nl:
         return Op.COPY_MEMORY, "copy_*"
 
+    # PyTorch native operations (catch-all for at::native)
+    if "at::native::" in name:
+        # Try to be more specific
+        if "fill" in nl:
+            return Op.ELEMENTWISE, "at::native::fill_*"
+        return Op.ELEMENTWISE, "at::native::*"
+
     # ROCm/CUDA library kernels (other)
     if "rocprim::" in name or "cub::" in name:
         return Op.OTHER, "rocprim/cub_*"
 
+    # Fallback: Heuristic fusion detection for unknown Triton kernels
+    # If a kernel has multiple operation keywords, it's likely fused
+    heuristic_result = _detect_heuristic_fusion(name)
+    if heuristic_result:
+        return heuristic_result
+
     return Op.OTHER, name[:40]
 
 
@@ -170,7 +460,7 @@ def classify_kernel(name: str) -> str:
         return "Triton_Persistent"
 
     # Reduce operations
-    if "
+    if "reduce" in nl:
         return "Reduce"
 
     # Sort operations
@@ -180,6 +470,10 @@ def classify_kernel(name: str) -> str:
     # Softmax
     if "softmax" in nl:
         return "Softmax"
+
+    # Indexing/Scatter/Gather
+    if any(x in nl for x in ["indices", "scatter", "gather", "index_select", "embedding"]):
+        return "Indexing"
 
     # Elementwise operations
     if any(x in nl for x in ["elementwise", "unrolled_elementwise"]):
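The new classifier flow above is registry-first with hardcoded fallbacks: classify() walks registry_categories through _match_registry_pattern(), then falls back to name-based heuristics and finally _detect_heuristic_fusion(). As a reading aid, the sketch below shows the in-memory shape _match_registry_pattern() expects from yaml.safe_load() of kernel_registry.yaml (category → "nvidia"/"amd"/"both" → entries with "pattern" and optional "phase"), and traces a few fallback classifications through the branches visible in this diff. The schema keys come from the code above; the concrete glob strings and kernel names are hypothetical illustrations, not the shipped registry contents.

```python
# Illustrative only: the nested dict shape _match_registry_pattern() walks
# after yaml.safe_load() -- category -> platform key ("nvidia"/"amd"/"both")
# -> list of {"pattern": <fnmatch glob>, "phase": optional}.
# The glob strings here are hypothetical, not the packaged registry contents.
example_registry = {
    "attention": {
        "nvidia": [
            {"pattern": "*fmha*ForGen*", "phase": "decode"},
            {"pattern": "*fmha*", "phase": "prefill"},
        ],
        "both": [
            {"pattern": "*flash_fwd*"},  # no "phase" -> defaults to ATTN_PREFILL
        ],
    },
    "gemm": {
        "nvidia": [{"pattern": "nvjet_*"}],
        "amd": [{"pattern": "*wvSplitK*"}],
    },
}

# Hardcoded fallback behaviour (when PyYAML or kernel_registry.yaml is absent),
# traced through the branches visible in this diff:
#   classify("void flash::flash_fwd_kernel<...>", "NVIDIA")
#       -> (Op.ATTN_PREFILL, "flash::flash_fwd_kernel")
#   classify("vllm::reshape_and_cache_flash<...>", "NVIDIA")
#       -> (Op.KV_CACHE, "reshape_and_cache_*")
#   _detect_heuristic_fusion("triton_fused_rmsnorm_matmul_kernel")
#       -> (Op.FUSED_UNKNOWN, "RMSNorm+GEMM")
```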
|