wafer-core 0.1.26__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wafer_core/lib/trace_compare/PERFORMANCE.md +148 -0
- wafer_core/lib/trace_compare/__init__.py +22 -9
- wafer_core/lib/trace_compare/aligner.py +376 -0
- wafer_core/lib/trace_compare/analyzer.py +558 -159
- wafer_core/lib/trace_compare/api.py +225 -0
- wafer_core/lib/trace_compare/architecture.py +77 -0
- wafer_core/lib/trace_compare/classifier.py +307 -13
- wafer_core/lib/trace_compare/fusion_analyzer.py +280 -706
- wafer_core/lib/trace_compare/kernel_registry.yaml +349 -0
- wafer_core/lib/trace_compare/layer_segmentation.py +114 -0
- wafer_core/lib/trace_compare/loader.py +526 -227
- wafer_core/lib/trace_compare/same_kernel_analyzer.py +119 -0
- wafer_core/lib/trace_compare/warnings.py +99 -0
- wafer_core/targets/__init__.py +47 -21
- wafer_core/targets/pool.py +181 -0
- wafer_core/targets/probe.py +113 -0
- wafer_core/targets/providers/__init__.py +46 -0
- wafer_core/targets/providers/baremetal.py +72 -0
- wafer_core/targets/providers/digitalocean.py +164 -0
- wafer_core/targets/providers/runpod.py +250 -0
- wafer_core/targets/reconcile.py +90 -0
- wafer_core/targets/spec_store.py +200 -0
- wafer_core/targets/state_cache.py +150 -0
- wafer_core/targets/types.py +141 -0
- wafer_core/utils/kernel_utils/targets/config.py +8 -24
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/METADATA +3 -1
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/RECORD +28 -10
- {wafer_core-0.1.26.dist-info → wafer_core-0.1.28.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# Trace Compare Performance
|
|
2
|
+
|
|
3
|
+
## Current Performance (v2 - 4.6x speedup)
|
|
4
|
+
|
|
5
|
+
| Trace Size | Load Time |
|
|
6
|
+
|------------|-----------|
|
|
7
|
+
| 919MB | ~17s |
|
|
8
|
+
| 1.1GB | ~22s |
|
|
9
|
+
| 2GB | ~38s |
|
|
10
|
+
|
|
11
|
+
## Implemented Optimizations
|
|
12
|
+
|
|
13
|
+
1. **orjson** - Rust-based JSON parser (1.4x faster than stdlib)
|
|
14
|
+
2. **Binary search for phases** - O(log n) instead of O(n) per kernel
|
|
15
|
+
3. **CPU op caching by kernel name** - 779k lookups → 48 lookups
|
|
16
|
+
4. **Pre-computed parent chains** - O(1) instead of O(50) walk per kernel
|
|
17
|
+
5. **String interning** - Reduced memory, faster dict lookups
|
|
18
|
+
|
|
19
|
+
## Next Optimizations (Target: <5s)
|
|
20
|
+
|
|
21
|
+
### High Impact
|
|
22
|
+
|
|
23
|
+
#### 1. Remove Pandas DataFrame (~3-5s savings)
|
|
24
|
+
|
|
25
|
+
The DataFrame is only used for groupby aggregations. Replace with dict-based aggregation during kernel processing:
|
|
26
|
+
|
|
27
|
+
```python
|
|
28
|
+
# Current: Build DataFrame, then groupby
|
|
29
|
+
df = pd.DataFrame(kernel_data) # ~3s for 810k rows
|
|
30
|
+
by_op = df.groupby("op").agg(...) # ~2s
|
|
31
|
+
|
|
32
|
+
# Better: Aggregate during processing
|
|
33
|
+
by_op = defaultdict(lambda: {"total_us": 0, "count": 0, "kernels": []})
|
|
34
|
+
for kernel in kernels:
|
|
35
|
+
by_op[kernel["op"]]["total_us"] += kernel["dur_us"]
|
|
36
|
+
by_op[kernel["op"]]["count"] += 1
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
**Estimated savings**: 3-5s (DataFrame creation + groupby overhead)
|
|
40
|
+
|
|
41
|
+
#### 2. Lazy/Streaming JSON Parsing (~2-3s savings)
|
|
42
|
+
|
|
43
|
+
Don't parse the entire JSON upfront. Use ijson or orjson's streaming mode to process events as they're parsed:
|
|
44
|
+
|
|
45
|
+
```python
|
|
46
|
+
import ijson
|
|
47
|
+
|
|
48
|
+
def stream_events(file_path):
|
|
49
|
+
with open(file_path, "rb") as f:
|
|
50
|
+
for event in ijson.items(f, "traceEvents.item"):
|
|
51
|
+
yield event
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
**Estimated savings**: 2-3s (avoid holding full JSON in memory)
|
|
55
|
+
|
|
56
|
+
#### 3. Parallel Processing (~40% savings)
|
|
57
|
+
|
|
58
|
+
Kernel classification and stack resolution are embarrassingly parallel:
|
|
59
|
+
|
|
60
|
+
```python
|
|
61
|
+
from concurrent.futures import ProcessPoolExecutor
|
|
62
|
+
|
|
63
|
+
# Split kernel_events into chunks
|
|
64
|
+
chunks = [kernel_events[i::8] for i in range(8)]
|
|
65
|
+
|
|
66
|
+
with ProcessPoolExecutor(max_workers=8) as pool:
|
|
67
|
+
results = pool.map(process_chunk, chunks)
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
**Estimated savings**: 40% of processing time with 8 cores
|
|
71
|
+
|
|
72
|
+
### Medium Impact
|
|
73
|
+
|
|
74
|
+
#### 4. Skip Unused Data
|
|
75
|
+
|
|
76
|
+
Most use cases don't need full `python_stack` lists. Only resolve what's actually needed:
|
|
77
|
+
|
|
78
|
+
```python
|
|
79
|
+
def load_trace(file_path, include_full_stacks=False):
|
|
80
|
+
# Only compute full stacks when explicitly requested
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
#### 5. Memory-Mapped File Reading
|
|
84
|
+
|
|
85
|
+
Use mmap for faster file I/O on large traces:
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
import mmap
|
|
89
|
+
|
|
90
|
+
with open(file_path, "rb") as f:
|
|
91
|
+
mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
|
92
|
+
trace = orjson.loads(mm)
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
#### 6. Pre-compiled Regex for Classification
|
|
96
|
+
|
|
97
|
+
Compile regex patterns once at module load instead of per-call.
|
|
98
|
+
|
|
99
|
+
### Long-term (Target: <1s repeat loads)
|
|
100
|
+
|
|
101
|
+
#### 7. Binary Index Format
|
|
102
|
+
|
|
103
|
+
Create a pre-processed `.trace.idx` file on first load:
|
|
104
|
+
|
|
105
|
+
```python
|
|
106
|
+
# First load: Parse JSON (~17s), save index
|
|
107
|
+
# Subsequent loads: Load index (<1s)
|
|
108
|
+
|
|
109
|
+
def load_trace_with_index(trace_path):
|
|
110
|
+
index_path = trace_path.with_suffix(".trace.idx")
|
|
111
|
+
if index_path.exists() and is_valid(index_path, trace_path):
|
|
112
|
+
return load_index(index_path) # <1s
|
|
113
|
+
|
|
114
|
+
data = load_trace(trace_path) # ~17s
|
|
115
|
+
save_index(data, index_path)
|
|
116
|
+
return data
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
#### 8. Rust Extension (PyO3)
|
|
120
|
+
|
|
121
|
+
Port hot paths to Rust for 10-50x speedup:
|
|
122
|
+
- JSON event iteration
|
|
123
|
+
- Stack resolution
|
|
124
|
+
- Kernel classification
|
|
125
|
+
|
|
126
|
+
## Profiling Commands
|
|
127
|
+
|
|
128
|
+
```bash
|
|
129
|
+
# Run benchmark
|
|
130
|
+
cd packages/wafer-core
|
|
131
|
+
python -m tests.trace_compare.benchmark_trace_compare
|
|
132
|
+
|
|
133
|
+
# Profile with py-spy
|
|
134
|
+
py-spy record -o profile.svg -- python -c "
|
|
135
|
+
from wafer_core.lib.trace_compare.loader import load_trace
|
|
136
|
+
load_trace('/path/to/trace.json')
|
|
137
|
+
"
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
## Test Commands
|
|
141
|
+
|
|
142
|
+
```bash
|
|
143
|
+
# Run correctness tests
|
|
144
|
+
pytest tests/trace_compare/test_trace_compare_correctness.py -v
|
|
145
|
+
|
|
146
|
+
# Regenerate golden file (after expected changes)
|
|
147
|
+
python -m tests.trace_compare.generate_golden_file
|
|
148
|
+
```
|
|
@@ -5,28 +5,41 @@ identifying kernel-level performance differences and fusion opportunities.
|
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
7
|
from .analyzer import analyze_traces
|
|
8
|
+
from .api import TraceComparisonResult, analyze_trace_pair
|
|
9
|
+
from .architecture import ArchitectureType, detect_architecture
|
|
8
10
|
from .classifier import Op, classify
|
|
11
|
+
from .aligner import align_traces, TraceAlignment, LayerAlignment, KernelPair
|
|
9
12
|
from .formatter import (
|
|
10
13
|
format_csv,
|
|
11
|
-
format_fusion_csv,
|
|
12
|
-
format_fusion_json,
|
|
13
|
-
format_fusion_text,
|
|
14
14
|
format_json,
|
|
15
15
|
format_text,
|
|
16
16
|
)
|
|
17
|
-
from .fusion_analyzer import
|
|
18
|
-
from .
|
|
17
|
+
from .fusion_analyzer import analyze_fusion_from_alignment
|
|
18
|
+
from .same_kernel_analyzer import analyze_same_kernels_from_alignment
|
|
19
|
+
from .loader import load_trace, load_trace_streaming, StreamingMetadata
|
|
20
|
+
from .warnings import TraceWarning, detect_warnings
|
|
19
21
|
|
|
20
22
|
__all__ = [
|
|
21
23
|
"Op",
|
|
22
24
|
"classify",
|
|
23
25
|
"load_trace",
|
|
26
|
+
"load_trace_streaming",
|
|
27
|
+
"StreamingMetadata",
|
|
24
28
|
"analyze_traces",
|
|
25
|
-
"analyze_fusion_differences",
|
|
26
29
|
"format_text",
|
|
27
30
|
"format_csv",
|
|
28
31
|
"format_json",
|
|
29
|
-
|
|
30
|
-
"
|
|
31
|
-
"
|
|
32
|
+
# New alignment exports
|
|
33
|
+
"TraceComparisonResult",
|
|
34
|
+
"analyze_trace_pair",
|
|
35
|
+
"ArchitectureType",
|
|
36
|
+
"detect_architecture",
|
|
37
|
+
"TraceWarning",
|
|
38
|
+
"detect_warnings",
|
|
39
|
+
"align_traces",
|
|
40
|
+
"TraceAlignment",
|
|
41
|
+
"LayerAlignment",
|
|
42
|
+
"KernelPair",
|
|
43
|
+
"analyze_fusion_from_alignment",
|
|
44
|
+
"analyze_same_kernels_from_alignment",
|
|
32
45
|
]
|
|
@@ -0,0 +1,376 @@
|
|
|
1
|
+
"""Kernel alignment system for comparing AMD and NVIDIA traces.
|
|
2
|
+
|
|
3
|
+
Aligns kernels at the layer level using positional matching (same model = same layer structure).
|
|
4
|
+
Provides kernel-to-kernel mapping for exact performance comparison.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import bisect
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
from dataclasses import dataclass, field
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from .classifier import Op, classify
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class KernelPair:
    """A pair of aligned kernels from AMD and NVIDIA traces."""

    # Launch position of the kernel within its layer's (operation, position) bucketing.
    position: int
    # Classified operation name; fused ops carry the pattern string (e.g. "RMSNorm+GEMM").
    operation: str
    # Sub-role for Dense GEMMs ("QKV"/"O"/"FFN_up"/"FFN_down"), else None.
    operation_detail: str | None

    # AMD-side aggregates for this bucket; amd_kernel is the most frequent
    # kernel name seen across forward-pass instances ("" when AMD has none).
    amd_kernel: str
    amd_avg_us: float
    amd_count: int
    amd_total_us: float

    # NVIDIA-side aggregates; nvidia_kernel is None when NVIDIA has no match.
    nvidia_kernel: str | None
    nvidia_avg_us: float
    nvidia_count: int
    nvidia_total_us: float

    # amd_avg_us / nvidia_avg_us (inf when the NVIDIA average is 0).
    ratio: float
    # amd_avg_us - nvidia_avg_us (positive = AMD slower).
    gap_us: float
    # Human-readable fusion hypothesis, if one was detected.
    fusion_note: str | None = None
    # True when both platforms ran the identical kernel name.
    is_same_kernel: bool = False
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class LayerAlignment:
    """Alignment data for a single layer."""

    # Zero-based layer index within the forward pass.
    layer: int
    # Sum of amd_total_us over kernel_pairs.
    amd_total_us: float
    # Sum of nvidia_total_us over kernel_pairs.
    nvidia_total_us: float
    # amd_total_us / nvidia_total_us (inf when the NVIDIA total is 0).
    ratio: float
    # amd_total_us - nvidia_total_us (positive = AMD slower).
    gap_us: float
    # Aligned (operation, position) kernel pairs for this layer.
    kernel_pairs: list[KernelPair] = field(default_factory=list)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass
class TraceAlignment:
    """Complete alignment result for two traces."""

    # One entry per aligned layer, in layer order.
    layer_alignments: list[LayerAlignment]
    # Max layer count observed across either platform's forward passes.
    num_layers: int
    # Number of AMD forward passes detected from phase annotations.
    num_forward_passes: int
    # Counts of "prefill" and "decode" phases (taken from the AMD trace).
    phase_breakdown: dict[str, int]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def split_by_forward_pass(
    kernels: list[dict[str, Any]], phases: list[dict[str, Any]]
) -> list[list[dict[str, Any]]]:
    """Partition kernels into forward passes using phase annotation windows.

    Args:
        kernels: Kernel events carrying a 'ts' timestamp field
        phases: Phase annotations carrying 'ts_start' and 'ts_end'

    Returns:
        One kernel list per phase window that contains at least one kernel;
        all kernels as a single pass when no phases are given
    """
    # Without phase annotations the whole trace counts as one pass.
    if not phases:
        return [kernels]

    ordered = sorted(kernels, key=lambda k: k.get("ts", 0))
    timestamps = [k.get("ts", 0) for k in ordered]

    passes: list[list[dict[str, Any]]] = []
    for phase in sorted(phases, key=lambda p: p["ts_start"]):
        # Binary-search the kernel window [ts_start, ts_end] in O(log n).
        lo = bisect.bisect_left(timestamps, phase["ts_start"])
        hi = bisect.bisect_right(timestamps, phase["ts_end"])
        window = ordered[lo:hi]
        if window:
            passes.append(window)

    return passes
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def split_into_layers(
    forward_pass: list[dict[str, Any]], platform: str
) -> list[list[dict[str, Any]]]:
    """Split a forward pass into layers, using attention kernels as boundaries.

    Args:
        forward_pass: Kernels from a single forward pass
        platform: 'AMD' or 'NVIDIA'

    Returns:
        List of layers, each containing the kernels launched in that layer
    """
    if not forward_pass:
        return []

    ordered = sorted(forward_pass, key=lambda k: k.get("ts", 0))

    def _is_attention(name: str) -> bool:
        # Platform-specific heuristic for attention-kernel names.
        lowered = name.lower()
        if platform == "AMD":
            return "attention" in lowered and ("2d" in lowered or "3d" in lowered)
        return "fmha" in lowered or "attention" in lowered

    markers = [i for i, k in enumerate(ordered) if _is_attention(k.get("name", ""))]

    # No attention boundary found: the whole pass is one layer.
    if not markers:
        return [ordered]

    # Each attention marker opens a layer that runs until the next marker
    # (or the end of the pass). Kernels before the first marker are dropped.
    bounds = markers + [len(ordered)]
    return [
        ordered[start:stop]
        for start, stop in zip(bounds, bounds[1:])
        if ordered[start:stop]
    ]
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def align_kernels_within_layer(
    amd_layer_instances: list[list[dict[str, Any]]],
    nvidia_layer_instances: list[list[dict[str, Any]]],
    platform_amd: str = "AMD",
    platform_nvidia: str = "NVIDIA",
) -> list[KernelPair]:
    """Align kernels within a layer across multiple forward pass instances.

    Kernels are bucketed by (operation, position-in-layer); matching buckets
    from the two platforms become one KernelPair.

    Args:
        amd_layer_instances: List of layer kernels from each AMD forward pass
        nvidia_layer_instances: List of layer kernels from each NVIDIA forward pass
        platform_amd: Platform name for AMD
        platform_nvidia: Platform name for NVIDIA

    Returns:
        List of aligned kernel pairs, sorted by (operation, position)
    """
    if not amd_layer_instances and not nvidia_layer_instances:
        return []

    def _bucket(
        instances: list[list[dict[str, Any]]], platform: str
    ) -> dict[tuple[str, int], list[dict[str, Any]]]:
        # Group kernels by (operation, position within the layer).
        buckets: dict[tuple[str, int], list[dict[str, Any]]] = defaultdict(list)
        for instance in instances:
            ordered = sorted(instance, key=lambda k: k.get("ts", 0))
            for pos, kernel in enumerate(ordered):
                op, pattern = classify(kernel.get("name", ""), platform)
                # For FUSED_UNKNOWN, use the pattern (e.g., "RMSNorm+GEMM") as
                # the operation so fusion detection works correctly.
                op_str = pattern if op.value == "Fused (Unknown)" else op.value
                buckets[(op_str, pos)].append(kernel)
        return buckets

    def _dominant_name(kernels: list[dict[str, Any]]) -> str:
        # Most frequent kernel name in the bucket; first-seen wins ties.
        counts: dict[str, int] = defaultdict(int)
        for k in kernels:
            counts[k.get("name", "")] += 1
        return max(counts.items(), key=lambda item: item[1])[0]

    amd_by_op_pos = _bucket(amd_layer_instances, platform_amd)
    nvidia_by_op_pos = _bucket(nvidia_layer_instances, platform_nvidia)

    # Hoisted out of the loop (was rebuilt per key): operations that can't be
    # "fused away" - absence means alignment issue, not fusion.
    non_fusable_ops = {
        "Attention (Prefill)", "Attention (Decode)", "Dense GEMM",
        "KV Cache", "MoE GEMM", "MoE Routing",
    }

    kernel_pairs: list[KernelPair] = []

    for op_str, pos in sorted(set(amd_by_op_pos) | set(nvidia_by_op_pos)):
        amd_kernels = amd_by_op_pos.get((op_str, pos), [])
        nvidia_kernels = nvidia_by_op_pos.get((op_str, pos), [])

        amd_total_us = sum(k.get("dur", 0) for k in amd_kernels)
        amd_count = len(amd_kernels)
        amd_avg_us = amd_total_us / amd_count if amd_count > 0 else 0.0
        amd_kernel_name = _dominant_name(amd_kernels) if amd_kernels else ""

        nvidia_total_us = sum(k.get("dur", 0) for k in nvidia_kernels)
        nvidia_count = len(nvidia_kernels)
        nvidia_avg_us = nvidia_total_us / nvidia_count if nvidia_count > 0 else 0.0
        nvidia_kernel_name: str | None = (
            _dominant_name(nvidia_kernels) if nvidia_kernels else None
        )

        ratio = amd_avg_us / nvidia_avg_us if nvidia_avg_us > 0 else float("inf")
        gap_us = amd_avg_us - nvidia_avg_us

        # Detect fusion notes.
        # Key insight: If operation has '+' (e.g., "RMSNorm+GEMM"), it's already
        # a fused operation. The platform that HAS the kernel IS fusing; the
        # other runs components separately.
        is_fused_op = "+" in op_str
        is_non_fusable = op_str in non_fusable_ops

        fusion_note = None
        if amd_count > 0 and nvidia_count == 0:
            if is_fused_op:
                # AMD has a fused kernel like "RMSNorm+GEMM" → AMD IS fusing
                fusion_note = f"AMD fuses {op_str} into {amd_kernel_name}"
            elif not is_non_fusable:
                # Only mark as fusion for ops that can legitimately be fused
                fusion_note = f"AMD runs {amd_kernel_name}, NVIDIA may fuse into another kernel"
        elif amd_count == 0 and nvidia_count > 0:
            if is_fused_op:
                # NVIDIA has a fused kernel → NVIDIA IS fusing
                fusion_note = f"NVIDIA fuses {op_str} into {nvidia_kernel_name}"
            elif not is_non_fusable:
                # Only mark as fusion for ops that can legitimately be fused
                fusion_note = f"NVIDIA runs {nvidia_kernel_name}, AMD may fuse into another kernel"
        elif amd_count > nvidia_count * 1.5 and nvidia_count > 0:
            # AMD runs more kernels = NVIDIA is fusing some
            fusion_note = f"AMD runs {amd_kernel_name} {amd_count / nvidia_count:.1f}x more → NVIDIA fuses"
        elif nvidia_count > amd_count * 1.5 and amd_count > 0:
            # NVIDIA runs more kernels = AMD is fusing some
            fusion_note = f"NVIDIA runs {nvidia_kernel_name} {nvidia_count / amd_count:.1f}x more → AMD fuses"

        is_same = (
            amd_kernel_name != ""
            and nvidia_kernel_name is not None
            and amd_kernel_name == nvidia_kernel_name
        )

        # Sub-classify Dense GEMMs by role when either kernel name reveals it.
        # Lowercase once instead of per-check (was up to 8 .lower() calls).
        operation_detail = None
        if op_str == "Dense GEMM":
            amd_lower = amd_kernel_name.lower()
            nvidia_lower = (nvidia_kernel_name or "").lower()
            if "qkv" in amd_lower or "qkv" in nvidia_lower:
                operation_detail = "QKV"
            elif "out" in amd_lower or "out" in nvidia_lower:
                operation_detail = "O"
            elif "up" in amd_lower or "up" in nvidia_lower:
                operation_detail = "FFN_up"
            elif "down" in amd_lower or "down" in nvidia_lower:
                operation_detail = "FFN_down"

        kernel_pairs.append(
            KernelPair(
                position=pos,
                operation=op_str,
                operation_detail=operation_detail,
                amd_kernel=amd_kernel_name,
                amd_avg_us=amd_avg_us,
                amd_count=amd_count,
                amd_total_us=amd_total_us,
                nvidia_kernel=nvidia_kernel_name,
                nvidia_avg_us=nvidia_avg_us,
                nvidia_count=nvidia_count,
                nvidia_total_us=nvidia_total_us,
                ratio=ratio,
                gap_us=gap_us,
                fusion_note=fusion_note,
                is_same_kernel=is_same,
            )
        )

    return kernel_pairs
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def align_traces(
    amd_kernels: list[dict[str, Any]],
    nvidia_kernels: list[dict[str, Any]],
    amd_phases: list[dict[str, Any]],
    nvidia_phases: list[dict[str, Any]],
    platform_amd: str = "AMD",
    platform_nvidia: str = "NVIDIA",
) -> TraceAlignment:
    """Align two traces at the layer level.

    Args:
        amd_kernels: Kernel events from AMD trace
        nvidia_kernels: Kernel events from NVIDIA trace
        amd_phases: Phase annotations from AMD trace
        nvidia_phases: Phase annotations from NVIDIA trace
        platform_amd: Platform name for AMD (default: "AMD")
        platform_nvidia: Platform name for NVIDIA (default: "NVIDIA")

    Returns:
        TraceAlignment with layer-by-layer kernel pairs
    """
    amd_passes = split_by_forward_pass(amd_kernels, amd_phases)
    nvidia_passes = split_by_forward_pass(nvidia_kernels, nvidia_phases)

    # Either side produced no passes: nothing can be aligned.
    if not amd_passes or not nvidia_passes:
        return TraceAlignment(
            layer_alignments=[],
            num_layers=0,
            num_forward_passes=0,
            phase_breakdown={"prefill": 0, "decode": 0},
        )

    amd_layered = [split_into_layers(p, platform_amd) for p in amd_passes]
    nvidia_layered = [split_into_layers(p, platform_nvidia) for p in nvidia_passes]

    # A layer index is valid if any pass on either platform produced it.
    layer_count = max(
        max((len(ls) for ls in amd_layered), default=0),
        max((len(ls) for ls in nvidia_layered), default=0),
    )

    if layer_count == 0:
        return TraceAlignment(
            layer_alignments=[],
            num_layers=0,
            num_forward_passes=len(amd_passes),
            phase_breakdown={"prefill": 0, "decode": 0},
        )

    alignments: list[LayerAlignment] = []
    for idx in range(layer_count):
        # Gather this layer from every pass that reached it.
        amd_instances = [ls[idx] for ls in amd_layered if idx < len(ls)]
        nvidia_instances = [ls[idx] for ls in nvidia_layered if idx < len(ls)]

        pairs = align_kernels_within_layer(
            amd_instances, nvidia_instances, platform_amd, platform_nvidia
        )

        amd_us = sum(p.amd_total_us for p in pairs)
        nvidia_us = sum(p.nvidia_total_us for p in pairs)
        alignments.append(
            LayerAlignment(
                layer=idx,
                amd_total_us=amd_us,
                nvidia_total_us=nvidia_us,
                ratio=amd_us / nvidia_us if nvidia_us > 0 else float("inf"),
                gap_us=amd_us - nvidia_us,
                kernel_pairs=pairs,
            )
        )

    # Phase breakdown is reported from the AMD side only.
    breakdown = {
        "prefill": sum(1 for p in amd_phases if p.get("type") == "prefill"),
        "decode": sum(1 for p in amd_phases if p.get("type") == "decode"),
    }

    return TraceAlignment(
        layer_alignments=alignments,
        num_layers=layer_count,
        num_forward_passes=len(amd_passes),
        phase_breakdown=breakdown,
    )
|