wafer-core 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/PERFORMANCE.md +148 -0
- wafer_core/lib/trace_compare/__init__.py +45 -0
- wafer_core/lib/trace_compare/aligner.py +369 -0
- wafer_core/lib/trace_compare/analyzer.py +729 -0
- wafer_core/lib/trace_compare/api.py +225 -0
- wafer_core/lib/trace_compare/architecture.py +77 -0
- wafer_core/lib/trace_compare/classifier.py +486 -0
- wafer_core/lib/trace_compare/formatter.py +951 -0
- wafer_core/lib/trace_compare/fusion_analyzer.py +356 -0
- wafer_core/lib/trace_compare/kernel_registry.yaml +349 -0
- wafer_core/lib/trace_compare/layer_segmentation.py +114 -0
- wafer_core/lib/trace_compare/loader.py +635 -0
- wafer_core/lib/trace_compare/same_kernel_analyzer.py +119 -0
- wafer_core/lib/trace_compare/warnings.py +99 -0
- wafer_core/problem_config.py +3 -3
- wafer_core/rollouts/agent_presets/rlm_01_01.py +2 -2
- wafer_core/rollouts/dtypes.py +18 -3
- wafer_core/rollouts/providers/anthropic.py +35 -3
- wafer_core/utils/kernel_utils/defense.py +10 -0
- wafer_core/utils/kernel_utils/targets/config.py +10 -0
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.27.dist-info}/METADATA +3 -1
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.27.dist-info}/RECORD +23 -9
- {wafer_core-0.1.25.dist-info → wafer_core-0.1.27.dist-info}/WHEEL +0 -0
wafer_core/lib/trace_compare/fusion_analyzer.py
@@ -0,0 +1,356 @@
"""Fusion analysis using aligned kernel pairs.

Detects kernel fusion patterns at the kernel-pair level within aligned layers.

Key insight: When one platform has an operation type like "RMSNorm+GEMM" (containing '+'),
that platform IS fusing. The OTHER platform runs the components separately.
"""

from collections import Counter, defaultdict
from dataclasses import dataclass, field
from typing import Any

from .aligner import KernelPair, LayerAlignment


# Maps fused operation names to their component operations
FUSED_OPERATION_COMPONENTS: dict[str, list[str]] = {
    "RMSNorm+GEMM": ["RMSNorm", "Dense GEMM"],
    "SwiGLU+GEMM": ["SwiGLU", "Dense GEMM"],
    "MoE GEMM+SwiGLU": ["MoE GEMM", "SwiGLU"],
    "Embedding+RMSNorm+GEMM": ["Elementwise", "RMSNorm", "Dense GEMM"],  # Embedding is often Elementwise
}


@dataclass
class FusionPattern:
    """A detected fusion pattern."""

    layer: int
    operation: str  # The fused operation name (e.g., "RMSNorm+GEMM")
    fused_platform: str  # Platform that fuses (has fewer kernels)
    fused_kernel: str  # The actual fused kernel name
    unfused_kernels: list[str]  # List of separate kernels on the other platform
    count: int
    evidence: str


@dataclass
class FusionAnalysis:
    """Complete fusion analysis result."""

    patterns: list[FusionPattern] = field(default_factory=list)
    summary: dict[str, Any] = field(default_factory=dict)


def _is_fused_operation(operation: str) -> bool:
    """Check if an operation type represents a fused operation.

    Operations like 'RMSNorm+GEMM', 'MoE GEMM+SwiGLU' indicate fusion.
    Also handles 'Fused (Unknown)' from heuristic detection.
    """
    return "+" in operation or operation == "Fused (Unknown)"


def _get_component_operations(fused_op: str) -> list[str]:
    """Get the component operations for a fused operation type."""
    if fused_op in FUSED_OPERATION_COMPONENTS:
        return FUSED_OPERATION_COMPONENTS[fused_op]

    # Parse from the operation name itself (e.g., "RMSNorm+GEMM" -> ["RMSNorm", "GEMM"])
    if "+" in fused_op:
        parts = [p.strip() for p in fused_op.split("+")]
        # Map shorthand to full names
        mapping = {
            "GEMM": "Dense GEMM",
            "SwiGLU": "Triton Fused",
        }
        return [mapping.get(p, p) for p in parts]

    return []


def _find_component_kernels(
    layer_alignment: LayerAlignment,
    component_ops: list[str],
    platform: str,
) -> list[str]:
    """Find kernels for component operations on the specified platform.

    Args:
        layer_alignment: Layer alignment data
        component_ops: List of operation types to look for (e.g., ["RMSNorm", "Dense GEMM"])
        platform: "AMD" or "NVIDIA"

    Returns:
        List of kernel names found for these operations
    """
    found_kernels: list[str] = []

    for pair in layer_alignment.kernel_pairs:
        if pair.operation in component_ops:
            if platform == "AMD" and pair.amd_kernel and pair.amd_count > 0:
                found_kernels.append(pair.amd_kernel)
            elif platform == "NVIDIA" and pair.nvidia_kernel and pair.nvidia_count > 0:
                found_kernels.append(pair.nvidia_kernel)

    return found_kernels


def detect_fusion_in_aligned_layers(
    layer_alignments: list[LayerAlignment],
) -> FusionAnalysis:
    """Detect fusion patterns from aligned layer data.

    Args:
        layer_alignments: List of aligned layers

    Returns:
        FusionAnalysis with detected patterns
    """
    fusion_patterns: list[FusionPattern] = []

    # Track patterns we've already seen to avoid duplicates
    seen_patterns: set[tuple[int, str, str]] = set()  # (layer, operation, fused_platform)

    for layer_alignment in layer_alignments:
        for pair in layer_alignment.kernel_pairs:
            is_fused_op = _is_fused_operation(pair.operation)

            # Case 1: AMD has a kernel for this operation, NVIDIA doesn't
            if pair.amd_kernel and pair.amd_count > 0 and (pair.nvidia_kernel is None or pair.nvidia_count == 0):
                if is_fused_op:
                    # AMD HAS the fused kernel → AMD is fusing
                    # Find what NVIDIA runs separately
                    component_ops = _get_component_operations(pair.operation)
                    unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "NVIDIA")

                    pattern_key = (layer_alignment.layer, pair.operation, "AMD")
                    if pattern_key not in seen_patterns:
                        seen_patterns.add(pattern_key)

                        if unfused_kernels:
                            evidence = f"AMD fuses {' + '.join(component_ops)} into {pair.amd_kernel}, NVIDIA runs {len(unfused_kernels)} separate kernels"
                        else:
                            evidence = f"AMD fuses into {pair.amd_kernel}, NVIDIA runs components separately"

                        fusion_patterns.append(
                            FusionPattern(
                                layer=layer_alignment.layer,
                                operation=pair.operation,
                                fused_platform="AMD",
                                fused_kernel=pair.amd_kernel,
                                unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
                                count=pair.amd_count,
                                evidence=evidence,
                            )
                        )
                else:
                    # Regular case: AMD has a kernel that NVIDIA fuses into something else
                    # This means NVIDIA is fusing (it doesn't need this separate kernel)
                    pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
                    if pattern_key not in seen_patterns:
                        seen_patterns.add(pattern_key)

                        evidence = f"AMD runs {pair.amd_kernel} separately ({pair.amd_count}x), NVIDIA fuses this operation"

                        fusion_patterns.append(
                            FusionPattern(
                                layer=layer_alignment.layer,
                                operation=pair.operation,
                                fused_platform="NVIDIA",
                                fused_kernel="(fused into nearby kernel)",
                                unfused_kernels=[pair.amd_kernel],
                                count=pair.amd_count,
                                evidence=evidence,
                            )
                        )

            # Case 2: NVIDIA has a kernel for this operation, AMD doesn't
            elif pair.nvidia_kernel and pair.nvidia_count > 0 and (pair.amd_kernel is None or pair.amd_count == 0):
                if is_fused_op:
                    # NVIDIA HAS the fused kernel → NVIDIA is fusing
                    # Find what AMD runs separately
                    component_ops = _get_component_operations(pair.operation)
                    unfused_kernels = _find_component_kernels(layer_alignment, component_ops, "AMD")

                    pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
                    if pattern_key not in seen_patterns:
                        seen_patterns.add(pattern_key)

                        if unfused_kernels:
                            evidence = f"NVIDIA fuses {' + '.join(component_ops)} into {pair.nvidia_kernel}, AMD runs {len(unfused_kernels)} separate kernels"
                        else:
                            evidence = f"NVIDIA fuses into {pair.nvidia_kernel}, AMD runs components separately"

                        fusion_patterns.append(
                            FusionPattern(
                                layer=layer_alignment.layer,
                                operation=pair.operation,
                                fused_platform="NVIDIA",
                                fused_kernel=pair.nvidia_kernel,
                                unfused_kernels=unfused_kernels if unfused_kernels else component_ops,
                                count=pair.nvidia_count,
                                evidence=evidence,
                            )
                        )
                else:
                    # Regular case: NVIDIA has a kernel that AMD fuses into something else
                    pattern_key = (layer_alignment.layer, pair.operation, "AMD")
                    if pattern_key not in seen_patterns:
                        seen_patterns.add(pattern_key)

                        evidence = f"NVIDIA runs {pair.nvidia_kernel} separately ({pair.nvidia_count}x), AMD fuses this operation"

                        fusion_patterns.append(
                            FusionPattern(
                                layer=layer_alignment.layer,
                                operation=pair.operation,
                                fused_platform="AMD",
                                fused_kernel="(fused into nearby kernel)",
                                unfused_kernels=[pair.nvidia_kernel],
                                count=pair.nvidia_count,
                                evidence=evidence,
                            )
                        )

            # Case 3: Both have kernels but with very different counts (partial fusion)
            elif (
                pair.amd_kernel
                and pair.nvidia_kernel
                and pair.amd_count > 0
                and pair.nvidia_count > 0
            ):
                count_ratio = pair.amd_count / pair.nvidia_count if pair.nvidia_count > 0 else float('inf')

                if count_ratio > 1.5:
                    # AMD runs more → NVIDIA fuses some instances
                    pattern_key = (layer_alignment.layer, pair.operation, "NVIDIA")
                    if pattern_key not in seen_patterns:
                        seen_patterns.add(pattern_key)

                        evidence = (
                            f"AMD runs {pair.amd_kernel} {count_ratio:.1f}x more "
                            f"({pair.amd_count} vs {pair.nvidia_count}), NVIDIA partially fuses"
                        )

                        fusion_patterns.append(
                            FusionPattern(
                                layer=layer_alignment.layer,
                                operation=pair.operation,
                                fused_platform="NVIDIA",
                                fused_kernel=pair.nvidia_kernel,
                                unfused_kernels=[pair.amd_kernel],
                                count=pair.amd_count - pair.nvidia_count,
                                evidence=evidence,
                            )
                        )

                elif count_ratio < 0.67:
                    # NVIDIA runs more → AMD fuses some instances
                    inverse_ratio = pair.nvidia_count / pair.amd_count if pair.amd_count > 0 else float('inf')
                    pattern_key = (layer_alignment.layer, pair.operation, "AMD")
                    if pattern_key not in seen_patterns:
                        seen_patterns.add(pattern_key)

                        evidence = (
                            f"NVIDIA runs {pair.nvidia_kernel} {inverse_ratio:.1f}x more "
                            f"({pair.nvidia_count} vs {pair.amd_count}), AMD partially fuses"
                        )

                        fusion_patterns.append(
                            FusionPattern(
                                layer=layer_alignment.layer,
                                operation=pair.operation,
                                fused_platform="AMD",
                                fused_kernel=pair.amd_kernel,
                                unfused_kernels=[pair.nvidia_kernel],
                                count=pair.nvidia_count - pair.amd_count,
                                evidence=evidence,
                            )
                        )

    # Aggregate patterns by operation type and fused platform
    aggregated: dict[tuple[str, str], FusionPattern] = {}
    for pattern in fusion_patterns:
        key = (pattern.operation, pattern.fused_platform)
        if key in aggregated:
            existing = aggregated[key]
            existing.count += pattern.count
            # Merge unfused kernels (deduplicate)
            for k in pattern.unfused_kernels:
                if k not in existing.unfused_kernels:
                    existing.unfused_kernels.append(k)
        else:
            aggregated[key] = FusionPattern(
                layer=pattern.layer,
                operation=pattern.operation,
                fused_platform=pattern.fused_platform,
                fused_kernel=pattern.fused_kernel,
                unfused_kernels=list(pattern.unfused_kernels),
                count=pattern.count,
                evidence=pattern.evidence,
            )

    amd_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "AMD")
    nvidia_fuses = sum(1 for p in aggregated.values() if p.fused_platform == "NVIDIA")

    return FusionAnalysis(
        patterns=list(aggregated.values()),
        summary={
            "amd_fuses": amd_fuses,
            "nvidia_fuses": nvidia_fuses,
            "total_fusion_opportunities": len(aggregated),
        },
    )


def analyze_fusion_from_alignment(
    layer_alignments: list[LayerAlignment],
) -> dict[str, Any]:
    """Analyze fusion from alignment data (for API compatibility).

    Args:
        layer_alignments: List of aligned layers

    Returns:
        Dictionary with fusion analysis results (compatible with old API)
    """
    fusion_analysis = detect_fusion_in_aligned_layers(layer_alignments)

    fusion_opportunities = []
    fusion_mappings = []

    for pattern in fusion_analysis.patterns:
        unfused_platform = "NVIDIA" if pattern.fused_platform == "AMD" else "AMD"

        fusion_opportunities.append(
            {
                "kernel_type": pattern.operation,
                "layer": pattern.layer,
                "fused_by": pattern.fused_platform,
                "fused_kernel": pattern.fused_kernel,
                "unfused_kernels": pattern.unfused_kernels,
                "count": pattern.count,
                "evidence": pattern.evidence,
            }
        )

        fusion_mappings.append(
            {
                "fused_platform": pattern.fused_platform,
                "fused_kernel_type": pattern.operation,
                "fused_kernel_name": pattern.fused_kernel,
                "unfused_platform": unfused_platform,
                "unfused_sequence": pattern.unfused_kernels,
                "pattern_count": pattern.count,
                "evidence": pattern.evidence,
                "layer": pattern.layer,
            }
        )

    return {
        "fusion_opportunities": fusion_opportunities,
        "fusion_mappings": fusion_mappings,
        "summary": fusion_analysis.summary,
    }
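For orientation, here is a minimal usage sketch of the new analyzer, not part of the package diff. It assumes the wheel is installed and feeds duck-typed SimpleNamespace stand-ins exposing only the attributes the analyzer reads (layer, kernel_pairs, operation, amd_kernel/amd_count, nvidia_kernel/nvidia_count); the real entry points would receive LayerAlignment objects from aligner.py, whose constructor is not shown in this diff, and the kernel names below are illustrative.

# Hypothetical sketch: drive detect_fusion_in_aligned_layers() with stand-in objects.
from types import SimpleNamespace

from wafer_core.lib.trace_compare.fusion_analyzer import detect_fusion_in_aligned_layers

# AMD runs a fused RMSNorm+GEMM kernel; NVIDIA runs the two components separately.
pairs = [
    SimpleNamespace(operation="RMSNorm+GEMM",
                    amd_kernel="triton_rsqrt_rocm_unquantized_gemm", amd_count=4,
                    nvidia_kernel=None, nvidia_count=0),
    SimpleNamespace(operation="RMSNorm",
                    amd_kernel=None, amd_count=0,
                    nvidia_kernel="triton_red_fused_rsqrt", nvidia_count=4),
    SimpleNamespace(operation="Dense GEMM",
                    amd_kernel=None, amd_count=0,
                    nvidia_kernel="nvjet_tst_128x128", nvidia_count=4),
]
layers = [SimpleNamespace(layer=0, kernel_pairs=pairs)]

analysis = detect_fusion_in_aligned_layers(layers)
for p in analysis.patterns:
    print(p.fused_platform, p.operation, "->", p.evidence)
# Expected: an AMD-fuses pattern for "RMSNorm+GEMM", plus AMD-fuses patterns
# recording that NVIDIA runs the RMSNorm and Dense GEMM kernels separately.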
wafer_core/lib/trace_compare/kernel_registry.yaml
@@ -0,0 +1,349 @@
# Kernel Pattern Registry
# Version: 2025-01
# Last updated: 2025-01-28
# Update when: New GPU architecture, new library version, new model architecture

version: "2025-01"

# ============================================================================
# SUPPORTED HARDWARE
# ============================================================================
# NVIDIA:
# - SM100 (Blackwell): B200, B100
# - SM90 (Hopper): H100, H200
# - SM89 (Ada Lovelace): L40, RTX 4090
# - SM80 (Ampere): A100, A10, A30
#
# AMD:
# - CDNA 4 (gfx950): MI355X
# - CDNA 3 (gfx942): MI300X, MI300A, MI325X
# - CDNA 2 (gfx90a): MI250X, MI210
#
# Note: MI325X uses same gfx942 ISA as MI300X but with 256GB HBM3e memory
# ============================================================================

attention:
  nvidia:
    # SM100 (Blackwell B200/B100) - 'a' suffix = prefill/context, 'f' suffix = decode/forgen
    - pattern: "fmhaSm100a*"
      hardware: "SM100 (Blackwell)"
      library: "Flash Attention 3"
      phase: prefill
    - pattern: "fmhaSm100f*"
      hardware: "SM100 (Blackwell)"
      library: "Flash Attention 3"
      phase: decode
    # SM90 (Hopper H100/H200) - Flash Attention 2/3
    - pattern: "fmhaSm90*"
      hardware: "SM90 (Hopper)"
      library: "Flash Attention 3"
    - pattern: "flash::flash_fwd_kernel*"
      hardware: "SM90 (Hopper)"
      library: "Flash Attention 2"
      phase: prefill
    - pattern: "flash_fwd_*"
      hardware: "SM90 (Hopper)"
      library: "Flash Attention 2"
    - pattern: "fmha_v2_*flash_attention_forward*"
      hardware: "SM90 (Hopper)"
      library: "Flash Attention 2"
      phase: prefill
    - pattern: "fmha_v2_*"
      hardware: "SM90 (Hopper)"
      library: "Flash Attention 2"
    # SM89 (Ada Lovelace L40/RTX 4090)
    - pattern: "fmhaSm89*"
      hardware: "SM89 (Ada Lovelace)"
      library: "Flash Attention"
    # SM80 (Ampere A100/A10)
    - pattern: "fmhaSm80*"
      hardware: "SM80 (Ampere)"
      library: "Flash Attention"
    - pattern: "fmha_*"
      hardware: "SM80 (Ampere)"
      library: "Flash Attention"
    # Generic phase patterns (fallback)
    - pattern: "*Context*"
      phase: prefill
    - pattern: "*context*"
      phase: prefill
    - pattern: "*ForGen*"
      phase: decode
    - pattern: "*forgen*"
      phase: decode
  amd:
    # CDNA 4 (MI355X - gfx950) - Composable Kernel v2
    - pattern: "*ck_fmha_*"
      hardware: "CDNA 4 (MI355X)"
      library: "Composable Kernel"
    - pattern: "*flash_attn_ck*"
      hardware: "CDNA 4 (MI355X)"
      library: "Composable Kernel"
    # CDNA 3 (MI300X/MI325X - gfx942) - Composable Kernel unified attention
    - pattern: "*unified_attention_2d*"
      hardware: "CDNA 3 (MI300X/MI325X)"
      phase: prefill
      library: "Composable Kernel"
    - pattern: "*unified_attention_3d*"
      hardware: "CDNA 3 (MI300X/MI325X)"
      phase: decode
      library: "Composable Kernel"
    - pattern: "kernel_unified_attention_2d*"
      hardware: "CDNA 3 (MI300X/MI325X)"
      phase: prefill
      library: "Composable Kernel"
    - pattern: "kernel_unified_attention_3d*"
      hardware: "CDNA 3 (MI300X/MI325X)"
      phase: decode
      library: "Composable Kernel"
    - pattern: "attention_2d*"
      phase: prefill
      library: "Composable Kernel"
    - pattern: "attention_3d*"
      phase: decode
      library: "Composable Kernel"
    # Triton Flash Attention (works on all AMD GPUs)
    - pattern: "triton_*flash*"
      library: "Triton Flash Attention"
    - pattern: "triton_*attention*"
      library: "Triton"

gemm:
  nvidia:
    # cuBLASLt (H100/H200 optimized)
    - pattern: "nvjet_*"
      library: "cuBLASLt"
      hardware: "SM90+ (Hopper/Blackwell)"
    - pattern: "void cublasLt*"
      library: "cuBLASLt"
    # CUTLASS (all architectures)
    - pattern: "cutlass*gemm*"
      library: "CUTLASS 3.x"
    - pattern: "cutlass_*"
      library: "CUTLASS"
    # cuBLAS legacy
    - pattern: "cublas*"
      library: "cuBLAS"
    # FP8 GEMM (H100+ specific)
    - pattern: "*fp8*gemm*"
      library: "cuBLASLt FP8"
      hardware: "SM90+ (Hopper)"
    - pattern: "*e4m3*"
      library: "cuBLASLt FP8"
      hardware: "SM90+ (Hopper)"
  amd:
    # Tensile (all CDNA architectures)
    - pattern: "Cijk_*"
      library: "Tensile"
    - pattern: "Custom_Cijk_*"
      library: "Tensile"
    # hipBLASLt (MI300X/MI325X/MI355X optimized)
    - pattern: "wvSplitK*"
      library: "hipBLASLt"
      hardware: "CDNA 3/4 (MI300X/MI325X/MI355X)"
    - pattern: "hipblaslt*"
      library: "hipBLASLt"
    - pattern: "hipblas*"
      library: "hipBLAS"
    # FP8 GEMM (MI300X+ specific)
    - pattern: "*fp8*"
      library: "hipBLASLt FP8"
      hardware: "CDNA 3+ (MI300X/MI325X/MI355X)"
    # CDNA 4 specific (MI355X - gfx950)
    - pattern: "*gfx950*"
      library: "Tensile"
      hardware: "CDNA 4 (MI355X)"
    # ISA-specific patterns (gfx942 = MI300X/MI325X, gfx950 = MI355X)
    - pattern: "*ISA942*"
      library: "Tensile"
      hardware: "CDNA 3 (MI300X/MI325X)"
    - pattern: "*ISA950*"
      library: "Tensile"
      hardware: "CDNA 4 (MI355X)"

ssm:
  both:
    - pattern: "selective_scan*"
      model: "Mamba"
    - pattern: "ssd_*"
      model: "Mamba-2"
    - pattern: "causal_conv1d*"
      model: "Mamba"
    - pattern: "mamba_*"
      model: "Mamba"

rmsnorm:
  both:
    # Fused RMSNorm+GEMM patterns (AMD Triton fuses these)
    # Key indicator: *rocm_unquantized_gemm* in kernel name
    - pattern: "triton_*rocm_unquantized_gemm*rsqrt*"
      library: "Triton"
      fused_with: "GEMM"
    - pattern: "triton_*rsqrt*rocm_unquantized_gemm*"
      library: "Triton"
      fused_with: "GEMM"
    - pattern: "triton_*rsqrt*gemm*"
      library: "Triton"
      fused_with: "GEMM"
    - pattern: "triton_*gemm*rsqrt*"
      library: "Triton"
      fused_with: "GEMM"
    # Non-fused RMSNorm (no gemm in name)
    - pattern: "triton_*rsqrt*"
      library: "Triton"
    - pattern: "*rmsnorm*"
      library: "Various"

moe:
  both:
    - pattern: "_matmul_ogs_*"
      library: "Triton"
    - pattern: "bmm_*dynbatch*"
      library: "Triton"
    - pattern: "*routing*"
      library: "Various"
    - pattern: "*topk*"
      library: "Various"
    - pattern: "fused_moe_kernel*"
      library: "vLLM"
    - pattern: "*vllm::moe::*"
      library: "vLLM"
    - pattern: "*moe_align_block_size*"
      library: "vLLM"
    - pattern: "*count_and_sort_expert*"
      library: "vLLM"
    - pattern: "*topkGatingSoftmax*"
      library: "vLLM"

# Activation functions (SwiGLU, SiLU, etc.)
activation:
  both:
    # Fused SwiGLU+GEMM (AMD Triton fuses these)
    - pattern: "triton_*rocm_unquantized_gemm*silu*"
      operation: "SwiGLU+GEMM"
      library: "Triton"
      fused_with: "GEMM"
    - pattern: "triton_*silu*rocm_unquantized_gemm*"
      operation: "SwiGLU+GEMM"
      library: "Triton"
      fused_with: "GEMM"
    - pattern: "triton_*gemm*silu*"
      operation: "SwiGLU+GEMM"
      library: "Triton"
      fused_with: "GEMM"
    - pattern: "triton_*silu*gemm*"
      operation: "SwiGLU+GEMM"
      library: "Triton"
      fused_with: "GEMM"
    # Non-fused activation
    - pattern: "*act_and_mul_kernel*"
      operation: "SwiGLU"
      library: "vLLM"
    - pattern: "triton_*silu*"
      operation: "SiLU"
      library: "Triton"
    - pattern: "*silu_kernel*"
      operation: "SiLU"
      library: "vLLM"
    - pattern: "*gelu*"
      operation: "GELU"
      library: "Various"

# KV Cache operations
kv_cache:
  both:
    - pattern: "*reshape_and_cache*"
      library: "vLLM"
    - pattern: "*concat_and_cache*"
      library: "vLLM"
    - pattern: "*cache_mla*"
      library: "vLLM"

# Softmax operations
softmax:
  both:
    - pattern: "*SoftMax*"
      library: "PyTorch"
    - pattern: "*softmax*"
      library: "PyTorch"

# Triton fused operations (more specific patterns)
triton:
  both:
    - pattern: "triton_poi_fused_mul*silu*"
      operation: "SwiGLU"
      library: "Triton"
    - pattern: "triton_poi_fused*"
      operation: "Pointwise"
      library: "Triton"
    - pattern: "triton_red_fused*"
      operation: "Reduction"
      library: "Triton"
    - pattern: "triton_per_fused*"
      operation: "Persistent"
      library: "Triton"

# Reduce/Scan operations
reduce:
  nvidia:
    - pattern: "*cub::*Reduce*"
      library: "CUB"
    - pattern: "*cub::*Scan*"
      library: "CUB"
    - pattern: "*splitKreduce*"
      library: "cuBLASLt"
      note: "GEMM epilogue reduction"
  amd:
    - pattern: "*rocprim::*reduce*"
      library: "rocPRIM"
    - pattern: "*rocprim::*scan*"
      library: "rocPRIM"
    - pattern: "reduce_segments*"
      library: "vLLM"

# Sorting operations
sorting:
  nvidia:
    - pattern: "*RadixSort*"
      library: "CUB"
    - pattern: "*DeviceSort*"
      library: "CUB"
  amd:
    - pattern: "*rocprim::*sort*"
      library: "rocPRIM"
    - pattern: "*rocprim::*merge*"
      library: "rocPRIM"

# Memory/Copy operations
memory:
  both:
    - pattern: "*memcpy*"
      library: "CUDA/HIP Runtime"
    - pattern: "*direct_copy*"
      library: "PyTorch"
    - pattern: "*copy_page_indices*"
      library: "vLLM"
    - pattern: "*rocclr_copyBuffer*"
      library: "AMD ROCclr"
    - pattern: "*rocprim::*transform*"
      library: "rocPRIM"

# Indexing/Scatter-Gather operations
indexing:
  both:
    - pattern: "*scatter_gather*"
      library: "PyTorch"
    - pattern: "*index_elementwise*"
      library: "PyTorch"
    - pattern: "*fill_reverse_indices*"
      library: "PyTorch"

# Elementwise operations (fallback patterns)
elementwise:
  both:
    - pattern: "at::native::*elementwise*"
      library: "PyTorch"
    - pattern: "at::native::*vectorized*"
      library: "PyTorch"
    - pattern: "*distribution_elementwise*"
      library: "PyTorch"