wafer-core 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff covers the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
@@ -0,0 +1,356 @@
+ """Fusion analysis using aligned kernel pairs.
+
+ Detects kernel fusion patterns at the kernel-pair level within aligned layers.
+
+ Key insight: When one platform has an operation type like "RMSNorm+GEMM" (containing '+'),
+ that platform IS fusing. The OTHER platform runs the components separately.
+ """
+
+ from collections import Counter, defaultdict
+ from dataclasses import dataclass, field
+ from typing import Any
+
+ from .aligner import KernelPair, LayerAlignment
+
+
+ # Maps fused operation names to their component operations
+ FUSED_OPERATION_COMPONENTS: dict[str, list[str]] = {
+     "RMSNorm+GEMM": ["RMSNorm", "Dense GEMM"],
+     "SwiGLU+GEMM": ["SwiGLU", "Dense GEMM"],
+     "MoE GEMM+SwiGLU": ["MoE GEMM", "SwiGLU"],
+     "Embedding+RMSNorm+GEMM": ["Elementwise", "RMSNorm", "Dense GEMM"],  # Embedding is often Elementwise
+ }
+
+
+ @dataclass
+ class FusionPattern:
+     """A detected fusion pattern."""
+
+     layer: int
+     operation: str  # The fused operation name (e.g., "RMSNorm+GEMM")
+     fused_platform: str  # Platform that fuses (has fewer kernels)
+     fused_kernel: str  # The actual fused kernel name
+     unfused_kernels: list[str]  # List of separate kernels on the other platform
+     count: int
+     evidence: str
+
+
+ @dataclass
+ class FusionAnalysis:
+     """Complete fusion analysis result."""
+
+     patterns: list[FusionPattern] = field(default_factory=list)
+     summary: dict[str, Any] = field(default_factory=dict)
+
+
+ def _is_fused_operation(operation: str) -> bool:
+     """Check if an operation type represents a fused operation.
+
+     Operations like 'RMSNorm+GEMM', 'MoE GEMM+SwiGLU' indicate fusion.
+     Also handles 'Fused (Unknown)' from heuristic detection.
+     """
+     return "+" in operation or operation == "Fused (Unknown)"
+
+
+ def _get_component_operations(fused_op: str) -> list[str]:
+     """Get the component operations for a fused operation type."""
+     if fused_op in FUSED_OPERATION_COMPONENTS:
+         return FUSED_OPERATION_COMPONENTS[fused_op]
+
+     # Parse from the operation name itself (e.g., "RMSNorm+GEMM" -> ["RMSNorm", "GEMM"])
+     if "+" in fused_op:
+         parts = [p.strip() for p in fused_op.split("+")]
+         # Map shorthand to full names
+         mapping = {
+             "GEMM": "Dense GEMM",
+             "SwiGLU": "Triton Fused",
+         }
+         return [mapping.get(p, p) for p in parts]
+
+     return []
+
+
+ def _find_component_kernels(
+     layer_alignment: LayerAlignment,
+     component_ops: list[str],
+     platform: str,
+ ) -> list[str]:
+     """Find kernels for component operations on the specified platform.
+
+     Args:
+         layer_alignment: Layer alignment data
+         component_ops: List of operation types to look for (e.g., ["RMSNorm", "Dense GEMM"])
+         platform: "AMD" or "NVIDIA"
+
+     Returns:
+         List of kernel names found for these operations
+     """
+     found_kernels: list[str] = []
+
+     for pair in layer_alignment.kernel_pairs:
+         if pair.operation in component_ops:
+             if platform == "AMD" and pair.amd_kernel and pair.amd_count > 0:
+                 found_kernels.append(pair.amd_kernel)
+             elif platform == "NVIDIA" and pair.nvidia_kernel and pair.nvidia_count > 0:
+                 found_kernels.append(pair.nvidia_kernel)
+
+     return found_kernels
+
+
+ def detect_fusion_in_aligned_layers(
+     layer_alignments: list[LayerAlignment],
+ ) -> FusionAnalysis:
+     """Detect fusion patterns from aligned layer data.
+
+     Args:
+         layer_alignments: List of aligned layers
+
+     Returns:
+         FusionAnalysis with detected patterns
+     """
+     fusion_patterns: list[FusionPattern] = []
+
+     # Track patterns we've already seen to avoid duplicates
+     seen_patterns: set[tuple[int, str, str]] = set()  # (layer, operation, fused_platform)
+
+     for layer_alignment in layer_alignments:
+         for pair in layer_alignment.kernel_pairs:
+             is_fused_op = _is_fused_operation(pair.operation)
+
+             # Case 1: AMD has a kernel for this operation, NVIDIA doesn't
+             if pair.amd_kernel and pair.amd_count > 0 and (pair.nvidia_kernel is None or pair.nvidia_count == 0):
+                 if is_fused_op:
+                     # AMD HAS the fused kernel → AMD is fusing
+                     # Find what NVIDIA runs separately
+                     component_ops = _get_component_operations(pair.operation)
+                     unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "NVIDIA")
+
+                     pattern_key = (layer_alignment.layer, pair.operation, "AMD")
+                     if pattern_key not in seen_patterns:
+                         seen_patterns.add(pattern_key)
+
+                         if unfused_kernels:
+                             evidence = f"AMD fuses {' + '.join(component_ops)} into {pair.amd_kernel}, NVIDIA runs {len(unfused_kernels)} separate kernels"
+                         else:
+                             evidence = f"AMD fuses into {pair.amd_kernel}, NVIDIA runs components separately"
+
+                         fusion_patterns.append(
+                             FusionPattern(
+                                 layer=layer_alignment.layer,
+                                 operation=pair.operation,
+                                 fused_platform="AMD",
+                                 fused_kernel=pair.amd_kernel,
+                                 unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
+                                 count=pair.amd_count,
+                                 evidence=evidence,
+                             )
+                         )
+                 else:
+                     # Regular case: AMD has a kernel that NVIDIA fuses into something else
+                     # This means NVIDIA is fusing (it doesn't need this separate kernel)
+                     pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
+                     if pattern_key not in seen_patterns:
+                         seen_patterns.add(pattern_key)
+
+                         evidence = f"AMD runs {pair.amd_kernel} separately ({pair.amd_count}x), NVIDIA fuses this operation"
+
+                         fusion_patterns.append(
+                             FusionPattern(
+                                 layer=layer_alignment.layer,
+                                 operation=pair.operation,
+                                 fused_platform="NVIDIA",
+                                 fused_kernel="(fused into nearby kernel)",
+                                 unfused_kernels=[pair.amd_kernel],
+                                 count=pair.amd_count,
+                                 evidence=evidence,
+                             )
+                         )
+
+             # Case 2: NVIDIA has a kernel for this operation, AMD doesn't
+             elif pair.nvidia_kernel and pair.nvidia_count > 0 and (pair.amd_kernel is None or pair.amd_count == 0):
+                 if is_fused_op:
+                     # NVIDIA HAS the fused kernel → NVIDIA is fusing
+                     # Find what AMD runs separately
+                     component_ops = _get_component_operations(pair.operation)
+                     unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "AMD")
+
+                     pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
+                     if pattern_key not in seen_patterns:
+                         seen_patterns.add(pattern_key)
+
+                         if unfused_kernels:
+                             evidence = f"NVIDIA fuses {' + '.join(component_ops)} into {pair.nvidia_kernel}, AMD runs {len(unfused_kernels)} separate kernels"
+                         else:
+                             evidence = f"NVIDIA fuses into {pair.nvidia_kernel}, AMD runs components separately"
+
+                         fusion_patterns.append(
+                             FusionPattern(
+                                 layer=layer_alignment.layer,
+                                 operation=pair.operation,
+                                 fused_platform="NVIDIA",
+                                 fused_kernel=pair.nvidia_kernel,
+                                 unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
+                                 count=pair.nvidia_count,
+                                 evidence=evidence,
+                             )
+                         )
+                 else:
+                     # Regular case: NVIDIA has a kernel that AMD fuses into something else
+                     pattern_key = (layer_alignment.layer, pair.operation, "AMD")
+                     if pattern_key not in seen_patterns:
+                         seen_patterns.add(pattern_key)
+
+                         evidence = f"NVIDIA runs {pair.nvidia_kernel} separately ({pair.nvidia_count}x), AMD fuses this operation"
+
+                         fusion_patterns.append(
+                             FusionPattern(
+                                 layer=layer_alignment.layer,
+                                 operation=pair.operation,
+                                 fused_platform="AMD",
+                                 fused_kernel="(fused into nearby kernel)",
+                                 unfused_kernels=[pair.nvidia_kernel],
+                                 count=pair.nvidia_count,
+                                 evidence=evidence,
+                             )
+                         )
+
+             # Case 3: Both have kernels but with very different counts (partial fusion)
+             elif (
+                 pair.amd_kernel
+                 and pair.nvidia_kernel
+                 and pair.amd_count > 0
+                 and pair.nvidia_count > 0
+             ):
+                 count_ratio = pair.amd_count / pair.nvidia_count if pair.nvidia_count > 0 else float('inf')
+
+                 if count_ratio > 1.5:
+                     # AMD runs more → NVIDIA fuses some instances
+                     pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
+                     if pattern_key not in seen_patterns:
+                         seen_patterns.add(pattern_key)
+
+                         evidence = (
+                             f"AMD runs {pair.amd_kernel} {count_ratio:.1f}x more "
+                             f"({pair.amd_count} vs {pair.nvidia_count}), NVIDIA partially fuses"
+                         )
+
+                         fusion_patterns.append(
+                             FusionPattern(
+                                 layer=layer_alignment.layer,
+                                 operation=pair.operation,
+                                 fused_platform="NVIDIA",
+                                 fused_kernel=pair.nvidia_kernel,
+                                 unfused_kernels=[pair.amd_kernel],
+                                 count=pair.amd_count - pair.nvidia_count,
+                                 evidence=evidence,
+                             )
+                         )
+
+                 elif count_ratio < 0.67:
+                     # NVIDIA runs more → AMD fuses some instances
+                     inverse_ratio = pair.nvidia_count / pair.amd_count if pair.amd_count > 0 else float('inf')
+                     pattern_key = (layer_alignment.layer, pair.operation, "AMD")
+                     if pattern_key not in seen_patterns:
+                         seen_patterns.add(pattern_key)
+
+                         evidence = (
+                             f"NVIDIA runs {pair.nvidia_kernel} {inverse_ratio:.1f}x more "
+                             f"({pair.nvidia_count} vs {pair.amd_count}), AMD partially fuses"
+                         )
+
+                         fusion_patterns.append(
+                             FusionPattern(
+                                 layer=layer_alignment.layer,
+                                 operation=pair.operation,
+                                 fused_platform="AMD",
+                                 fused_kernel=pair.amd_kernel,
+                                 unfused_kernels=[pair.nvidia_kernel],
+                                 count=pair.nvidia_count - pair.amd_count,
+                                 evidence=evidence,
+                             )
+                         )
+
+     # Aggregate patterns by operation type and fused platform
+     aggregated: dict[tuple[str, str], FusionPattern] = {}
+     for pattern in fusion_patterns:
+         key = (pattern.operation, pattern.fused_platform)
+         if key in aggregated:
+             existing = aggregated[key]
+             existing.count += pattern.count
+             # Merge unfused kernels (deduplicate)
+             for k in pattern.unfused_kernels:
+                 if k not in existing.unfused_kernels:
+                     existing.unfused_kernels.append(k)
+         else:
+             aggregated[key] = FusionPattern(
+                 layer=pattern.layer,
+                 operation=pattern.operation,
+                 fused_platform=pattern.fused_platform,
+                 fused_kernel=pattern.fused_kernel,
+                 unfused_kernels=list(pattern.unfused_kernels),
+                 count=pattern.count,
+                 evidence=pattern.evidence,
+             )
+
+     amd_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "AMD")
+     nvidia_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "NVIDIA")
+
+     return FusionAnalysis(
+         patterns=list(aggregated.values()),
+         summary={
+             "amd_fuses": amd_fuses,
+             "nvidia_fuses": nvidia_fuses,
+             "total_fusion_opportunities": len(aggregated),
+         },
+     )
+
+
+ def analyze_fusion_from_alignment(
+     layer_alignments: list[LayerAlignment],
+ ) -> dict[str, Any]:
+     """Analyze fusion from alignment data (for API compatibility).
+
+     Args:
+         layer_alignments: List of aligned layers
+
+     Returns:
+         Dictionary with fusion analysis results (compatible with old API)
+     """
+     fusion_analysis = detect_fusion_in_aligned_layers(layer_alignments)
+
+     fusion_opportunities = []
+     fusion_mappings = []
+
+     for pattern in fusion_analysis.patterns:
+         unfused_platform = "NVIDIA" if pattern.fused_platform == "AMD" else "AMD"
+
+         fusion_opportunities.append(
+             {
+                 "kernel_type": pattern.operation,
+                 "layer": pattern.layer,
+                 "fused_by": pattern.fused_platform,
+                 "fused_kernel": pattern.fused_kernel,
+                 "unfused_kernels": pattern.unfused_kernels,
+                 "count": pattern.count,
+                 "evidence": pattern.evidence,
+             }
+         )
+
+         fusion_mappings.append(
+             {
+                 "fused_platform": pattern.fused_platform,
+                 "fused_kernel_type": pattern.operation,
+                 "fused_kernel_name": pattern.fused_kernel,
+                 "unfused_platform": unfused_platform,
+                 "unfused_sequence": pattern.unfused_kernels,
+                 "pattern_count": pattern.count,
+                 "evidence": pattern.evidence,
+                 "layer": pattern.layer,
+             }
+         )
+
+     return {
+         "fusion_opportunities": fusion_opportunities,
+         "fusion_mappings": fusion_mappings,
+         "summary": fusion_analysis.summary,
+     }
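
Note on usage: the detector above reads only a handful of attributes from the aligner types (`layer`, `kernel_pairs`, and each pair's `operation`, `amd_kernel`, `amd_count`, `nvidia_kernel`, `nvidia_count`). The following is a minimal, self-contained sketch of that input shape and of the '+'-in-operation rule; the stand-in objects and kernel names are illustrative assumptions, not the package's `aligner.KernelPair`/`aligner.LayerAlignment` classes, whose definitions are not part of this diff.

# Illustrative sketch only: SimpleNamespace stand-ins mimic the attributes the
# detector reads; the kernel names are made up for the example.
from types import SimpleNamespace

layer0 = SimpleNamespace(
    layer=0,
    kernel_pairs=[
        # AMD runs one fused kernel; NVIDIA has no kernel for the fused op itself.
        SimpleNamespace(operation="RMSNorm+GEMM",
                        amd_kernel="triton_rocm_unquantized_gemm_rsqrt",  # hypothetical name
                        amd_count=2,
                        nvidia_kernel=None, nvidia_count=0),
        # NVIDIA runs the two components as separate kernels in the same layer.
        SimpleNamespace(operation="RMSNorm",
                        amd_kernel=None, amd_count=0,
                        nvidia_kernel="vectorized_rms_norm_kernel", nvidia_count=2),  # hypothetical
        SimpleNamespace(operation="Dense GEMM",
                        amd_kernel=None, amd_count=0,
                        nvidia_kernel="nvjet_tst_128x128", nvidia_count=2),  # hypothetical
    ],
)

# The '+' in the operation name marks the platform holding that kernel as the fusing side;
# detect_fusion_in_aligned_layers([layer0]) would report AMD as fused_platform here.
for pair in layer0.kernel_pairs:
    is_fused = "+" in pair.operation or pair.operation == "Fused (Unknown)"
    print(f"{pair.operation:>14} -> {'fused op' if is_fused else 'component op'}")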
@@ -0,0 +1,349 @@
+ # Kernel Pattern Registry
+ # Version: 2025-01
+ # Last updated: 2025-01-28
+ # Update when: New GPU architecture, new library version, new model architecture
+
+ version: "2025-01"
+
+ # ============================================================================
+ # SUPPORTED HARDWARE
+ # ============================================================================
+ # NVIDIA:
+ #   - SM100 (Blackwell): B200, B100
+ #   - SM90 (Hopper): H100, H200
+ #   - SM89 (Ada Lovelace): L40, RTX 4090
+ #   - SM80 (Ampere): A100, A10, A30
+ #
+ # AMD:
+ #   - CDNA 4 (gfx950): MI355X
+ #   - CDNA 3 (gfx942): MI300X, MI300A, MI325X
+ #   - CDNA 2 (gfx90a): MI250X, MI210
+ #
+ # Note: MI325X uses same gfx942 ISA as MI300X but with 256GB HBM3e memory
+ # ============================================================================
+
+ attention:
+   nvidia:
+     # SM100 (Blackwell B200/B100) - 'a' suffix = prefill/context, 'f' suffix = decode/forgen
+     - pattern: "fmhaSm100a*"
+       hardware: "SM100 (Blackwell)"
+       library: "Flash Attention 3"
+       phase: prefill
+     - pattern: "fmhaSm100f*"
+       hardware: "SM100 (Blackwell)"
+       library: "Flash Attention 3"
+       phase: decode
+     # SM90 (Hopper H100/H200) - Flash Attention 2/3
+     - pattern: "fmhaSm90*"
+       hardware: "SM90 (Hopper)"
+       library: "Flash Attention 3"
+     - pattern: "flash::flash_fwd_kernel*"
+       hardware: "SM90 (Hopper)"
+       library: "Flash Attention 2"
+       phase: prefill
+     - pattern: "flash_fwd_*"
+       hardware: "SM90 (Hopper)"
+       library: "Flash Attention 2"
+     - pattern: "fmha_v2_*flash_attention_forward*"
+       hardware: "SM90 (Hopper)"
+       library: "Flash Attention 2"
+       phase: prefill
+     - pattern: "fmha_v2_*"
+       hardware: "SM90 (Hopper)"
+       library: "Flash Attention 2"
+     # SM89 (Ada Lovelace L40/RTX 4090)
+     - pattern: "fmhaSm89*"
+       hardware: "SM89 (Ada Lovelace)"
+       library: "Flash Attention"
+     # SM80 (Ampere A100/A10)
+     - pattern: "fmhaSm80*"
+       hardware: "SM80 (Ampere)"
+       library: "Flash Attention"
+     - pattern: "fmha_*"
+       hardware: "SM80 (Ampere)"
+       library: "Flash Attention"
+     # Generic phase patterns (fallback)
+     - pattern: "*Context*"
+       phase: prefill
+     - pattern: "*context*"
+       phase: prefill
+     - pattern: "*ForGen*"
+       phase: decode
+     - pattern: "*forgen*"
+       phase: decode
+   amd:
+     # CDNA 4 (MI355X - gfx950) - Composable Kernel v2
+     - pattern: "*ck_fmha_*"
+       hardware: "CDNA 4 (MI355X)"
+       library: "Composable Kernel"
+     - pattern: "*flash_attn_ck*"
+       hardware: "CDNA 4 (MI355X)"
+       library: "Composable Kernel"
+     # CDNA 3 (MI300X/MI325X - gfx942) - Composable Kernel unified attention
+     - pattern: "*unified_attention_2d*"
+       hardware: "CDNA 3 (MI300X/MI325X)"
+       phase: prefill
+       library: "Composable Kernel"
+     - pattern: "*unified_attention_3d*"
+       hardware: "CDNA 3 (MI300X/MI325X)"
+       phase: decode
+       library: "Composable Kernel"
+     - pattern: "kernel_unified_attention_2d*"
+       hardware: "CDNA 3 (MI300X/MI325X)"
+       phase: prefill
+       library: "Composable Kernel"
+     - pattern: "kernel_unified_attention_3d*"
+       hardware: "CDNA 3 (MI300X/MI325X)"
+       phase: decode
+       library: "Composable Kernel"
+     - pattern: "attention_2d*"
+       phase: prefill
+       library: "Composable Kernel"
+     - pattern: "attention_3d*"
+       phase: decode
+       library: "Composable Kernel"
+     # Triton Flash Attention (works on all AMD GPUs)
+     - pattern: "triton_*flash*"
+       library: "Triton Flash Attention"
+     - pattern: "triton_*attention*"
+       library: "Triton"
+
+ gemm:
+   nvidia:
+     # cuBLASLt (H100/H200 optimized)
+     - pattern: "nvjet_*"
+       library: "cuBLASLt"
+       hardware: "SM90+ (Hopper/Blackwell)"
+     - pattern: "void cublasLt*"
+       library: "cuBLASLt"
+     # CUTLASS (all architectures)
+     - pattern: "cutlass*gemm*"
+       library: "CUTLASS 3.x"
+     - pattern: "cutlass_*"
+       library: "CUTLASS"
+     # cuBLAS legacy
+     - pattern: "cublas*"
+       library: "cuBLAS"
+     # FP8 GEMM (H100+ specific)
+     - pattern: "*fp8*gemm*"
+       library: "cuBLASLt FP8"
+       hardware: "SM90+ (Hopper)"
+     - pattern: "*e4m3*"
+       library: "cuBLASLt FP8"
+       hardware: "SM90+ (Hopper)"
+   amd:
+     # Tensile (all CDNA architectures)
+     - pattern: "Cijk_*"
+       library: "Tensile"
+     - pattern: "Custom_Cijk_*"
+       library: "Tensile"
+     # hipBLASLt (MI300X/MI325X/MI355X optimized)
+     - pattern: "wvSplitK*"
+       library: "hipBLASLt"
+       hardware: "CDNA 3/4 (MI300X/MI325X/MI355X)"
+     - pattern: "hipblaslt*"
+       library: "hipBLASLt"
+     - pattern: "hipblas*"
+       library: "hipBLAS"
+     # FP8 GEMM (MI300X+ specific)
+     - pattern: "*fp8*"
+       library: "hipBLASLt FP8"
+       hardware: "CDNA 3+ (MI300X/MI325X/MI355X)"
+     # CDNA 4 specific (MI355X - gfx950)
+     - pattern: "*gfx950*"
+       library: "Tensile"
+       hardware: "CDNA 4 (MI355X)"
+     # ISA-specific patterns (gfx942 = MI300X/MI325X, gfx950 = MI355X)
+     - pattern: "*ISA942*"
+       library: "Tensile"
+       hardware: "CDNA 3 (MI300X/MI325X)"
+     - pattern: "*ISA950*"
+       library: "Tensile"
+       hardware: "CDNA 4 (MI355X)"
+
+ ssm:
+   both:
+     - pattern: "selective_scan*"
+       model: "Mamba"
+     - pattern: "ssd_*"
+       model: "Mamba-2"
+     - pattern: "causal_conv1d*"
+       model: "Mamba"
+     - pattern: "mamba_*"
+       model: "Mamba"
+
+ rmsnorm:
+   both:
+     # Fused RMSNorm+GEMM patterns (AMD Triton fuses these)
+     # Key indicator: *rocm_unquantized_gemm* in kernel name
+     - pattern: "triton_*rocm_unquantized_gemm*rsqrt*"
+       library: "Triton"
+       fused_with: "GEMM"
+     - pattern: "triton_*rsqrt*rocm_unquantized_gemm*"
+       library: "Triton"
+       fused_with: "GEMM"
+     - pattern: "triton_*rsqrt*gemm*"
+       library: "Triton"
+       fused_with: "GEMM"
+     - pattern: "triton_*gemm*rsqrt*"
+       library: "Triton"
+       fused_with: "GEMM"
+     # Non-fused RMSNorm (no gemm in name)
+     - pattern: "triton_*rsqrt*"
+       library: "Triton"
+     - pattern: "*rmsnorm*"
+       library: "Various"
+
+ moe:
+   both:
+     - pattern: "_matmul_ogs_*"
+       library: "Triton"
+     - pattern: "bmm_*dynbatch*"
+       library: "Triton"
+     - pattern: "*routing*"
+       library: "Various"
+     - pattern: "*topk*"
+       library: "Various"
+     - pattern: "fused_moe_kernel*"
+       library: "vLLM"
+     - pattern: "*vllm::moe::*"
+       library: "vLLM"
+     - pattern: "*moe_align_block_size*"
+       library: "vLLM"
+     - pattern: "*count_and_sort_expert*"
+       library: "vLLM"
+     - pattern: "*topkGatingSoftmax*"
+       library: "vLLM"
+
+ # Activation functions (SwiGLU, SiLU, etc.)
+ activation:
+   both:
+     # Fused SwiGLU+GEMM (AMD Triton fuses these)
+     - pattern: "triton_*rocm_unquantized_gemm*silu*"
+       operation: "SwiGLU+GEMM"
+       library: "Triton"
+       fused_with: "GEMM"
+     - pattern: "triton_*silu*rocm_unquantized_gemm*"
+       operation: "SwiGLU+GEMM"
+       library: "Triton"
+       fused_with: "GEMM"
+     - pattern: "triton_*gemm*silu*"
+       operation: "SwiGLU+GEMM"
+       library: "Triton"
+       fused_with: "GEMM"
+     - pattern: "triton_*silu*gemm*"
+       operation: "SwiGLU+GEMM"
+       library: "Triton"
+       fused_with: "GEMM"
+     # Non-fused activation
+     - pattern: "*act_and_mul_kernel*"
+       operation: "SwiGLU"
+       library: "vLLM"
+     - pattern: "triton_*silu*"
+       operation: "SiLU"
+       library: "Triton"
+     - pattern: "*silu_kernel*"
+       operation: "SiLU"
+       library: "vLLM"
+     - pattern: "*gelu*"
+       operation: "GELU"
+       library: "Various"
+
+ # KV Cache operations
+ kv_cache:
+   both:
+     - pattern: "*reshape_and_cache*"
+       library: "vLLM"
+     - pattern: "*concat_and_cache*"
+       library: "vLLM"
+     - pattern: "*cache_mla*"
+       library: "vLLM"
+
+ # Softmax operations
+ softmax:
+   both:
+     - pattern: "*SoftMax*"
+       library: "PyTorch"
+     - pattern: "*softmax*"
+       library: "PyTorch"
+
+ # Triton fused operations (more specific patterns)
+ triton:
+   both:
+     - pattern: "triton_poi_fused_mul*silu*"
+       operation: "SwiGLU"
+       library: "Triton"
+     - pattern: "triton_poi_fused*"
+       operation: "Pointwise"
+       library: "Triton"
+     - pattern: "triton_red_fused*"
+       operation: "Reduction"
+       library: "Triton"
+     - pattern: "triton_per_fused*"
+       operation: "Persistent"
+       library: "Triton"
+
+ # Reduce/Scan operations
+ reduce:
+   nvidia:
+     - pattern: "*cub::*Reduce*"
+       library: "CUB"
+     - pattern: "*cub::*Scan*"
+       library: "CUB"
+     - pattern: "*splitKreduce*"
+       library: "cuBLASLt"
+       note: "GEMM epilogue reduction"
+   amd:
+     - pattern: "*rocprim::*reduce*"
+       library: "rocPRIM"
+     - pattern: "*rocprim::*scan*"
+       library: "rocPRIM"
+     - pattern: "reduce_segments*"
+       library: "vLLM"
+
+ # Sorting operations
+ sorting:
+   nvidia:
+     - pattern: "*RadixSort*"
+       library: "CUB"
+     - pattern: "*DeviceSort*"
+       library: "CUB"
+   amd:
+     - pattern: "*rocprim::*sort*"
+       library: "rocPRIM"
+     - pattern: "*rocprim::*merge*"
+       library: "rocPRIM"
+
+ # Memory/Copy operations
+ memory:
+   both:
+     - pattern: "*memcpy*"
+       library: "CUDA/HIP Runtime"
+     - pattern: "*direct_copy*"
+       library: "PyTorch"
+     - pattern: "*copy_page_indices*"
+       library: "vLLM"
+     - pattern: "*rocclr_copyBuffer*"
+       library: "AMD ROCclr"
+     - pattern: "*rocprim::*transform*"
+       library: "rocPRIM"
+
+ # Indexing/Scatter-Gather operations
+ indexing:
+   both:
+     - pattern: "*scatter_gather*"
+       library: "PyTorch"
+     - pattern: "*index_elementwise*"
+       library: "PyTorch"
+     - pattern: "*fill_reverse_indices*"
+       library: "PyTorch"
+
+ # Elementwise operations (fallback patterns)
+ elementwise:
+   both:
+     - pattern: "at::native::*elementwise*"
+       library: "PyTorch"
+     - pattern: "at::native::*vectorized*"
+       library: "PyTorch"
+     - pattern: "*distribution_elementwise*"
+       library: "PyTorch"
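
For context on how a registry like this is typically consumed: each entry is a glob-style name pattern plus metadata, so a loader can walk category → platform → entries and take the first pattern that matches a profiled kernel name. The sketch below is an assumption about that flow (fnmatch-style wildcards, PyYAML for parsing, a made-up kernel name); the actual wafer-core loader is not shown in this diff.

# Minimal matching sketch, not the package's loader. Assumes '*' globs and PyYAML.
from fnmatch import fnmatchcase
from typing import Any

import yaml

REGISTRY_SNIPPET = """
gemm:
  nvidia:
    - pattern: "nvjet_*"
      library: "cuBLASLt"
      hardware: "SM90+ (Hopper/Blackwell)"
  amd:
    - pattern: "Cijk_*"
      library: "Tensile"
"""

def classify(kernel_name: str, platform: str, registry: dict[str, Any]) -> dict[str, Any] | None:
    """Return metadata for the first entry whose pattern matches the kernel name."""
    for category, platforms in registry.items():
        # Check the requested platform's entries, then platform-agnostic "both" entries.
        for plat in (platform.lower(), "both"):
            for entry in platforms.get(plat, []):
                if fnmatchcase(kernel_name, entry["pattern"]):
                    return {"category": category, **entry}
    return None

registry = yaml.safe_load(REGISTRY_SNIPPET)
print(classify("nvjet_tst_128x64_sm90", "nvidia", registry))  # hypothetical kernel name
# -> {'category': 'gemm', 'pattern': 'nvjet_*', 'library': 'cuBLASLt', 'hardware': 'SM90+ (Hopper/Blackwell)'}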