wafer-core 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,114 @@
1
+ """Layer segmentation based on architecture type.
2
+
3
+ Segments kernels into model layers based on architecture-specific markers
4
+ (e.g., attention kernels for transformers, SSM scan kernels for Mamba).
5
+ """
6
+
7
+ import bisect
8
+ from typing import Any
9
+
10
+ from .architecture import ArchitectureType
11
+ from .warnings import TraceWarning
12
+
13
+
14
def segment_layers_by_architecture(
    kernels: list[dict[str, Any]],
    architecture: ArchitectureType,
) -> tuple[dict[int, list[dict[str, Any]]], list[TraceWarning]]:
    """Segment kernels into layers based on architecture.

    Each layer-marker kernel (selected attention kernels for transformers,
    scan-style kernels for SSMs) opens a new layer. A layer contains every
    kernel whose timestamp lies in ``[marker_ts, next_marker_ts)``; the last
    layer runs to the end of the trace. Kernels that start before the first
    marker are not assigned to any layer.

    Args:
        kernels: List of kernel events with 'name', 'ts', and other fields.
            A missing or ``None`` 'name' is treated as "", a missing 'ts' as 0.
        architecture: Detected architecture type

    Returns:
        Tuple of (layer_mapping, warnings)
        layer_mapping: Dict mapping layer_num -> list of kernel events (sorted
            by 'ts'). Layer numbers are consecutive, starting at 0. Empty when
            segmentation is not possible.
        warnings: Warnings explaining why segmentation was skipped, or
            informational notes about segmentation quality.
    """
    warnings: list[TraceWarning] = []

    if architecture == ArchitectureType.HYBRID:
        warnings.append(
            TraceWarning(
                code="HYBRID_ARCHITECTURE",
                severity="info",
                message="Hybrid architecture detected (both attention and SSM kernels). Layer segmentation unavailable.",
                suggestion="Hybrid models require custom segmentation logic. Layer analysis will be skipped.",
            )
        )
        return {}, warnings

    if architecture == ArchitectureType.UNKNOWN:
        warnings.append(
            TraceWarning(
                code="UNKNOWN_ARCHITECTURE",
                severity="warning",
                message="Cannot determine model architecture. Layer segmentation unavailable.",
                suggestion="Ensure trace contains recognizable kernel patterns (attention, SSM, etc.).",
            )
        )
        return {}, warnings

    marker_timestamps = _find_layer_marker_timestamps(kernels, architecture)

    if not marker_timestamps:
        warnings.append(
            TraceWarning(
                code="NO_LAYER_MARKERS",
                severity="warning",
                message=f"No layer marker kernels found for {architecture.value} architecture.",
                suggestion="Ensure trace contains expected kernel patterns for this architecture type.",
            )
        )
        return {}, warnings

    # Sort kernels by timestamp so layer boundaries can be located by binary search.
    sorted_kernels = sorted(kernels, key=lambda k: k.get("ts", 0))
    kernel_timestamps = [k.get("ts", 0) for k in sorted_kernels]

    # Markers sharing a timestamp would otherwise yield an empty [ts, ts) layer
    # and leave a gap in the layer numbering, so keep each boundary only once.
    boundaries = sorted(set(marker_timestamps))

    # Every boundary is itself a kernel timestamp, and boundaries are strictly
    # increasing, so each [start, end) slice below holds at least one kernel.
    starts = [bisect.bisect_left(kernel_timestamps, ts) for ts in boundaries]
    ends = starts[1:] + [len(sorted_kernels)]
    layer_mapping: dict[int, list[dict[str, Any]]] = {
        layer_num: sorted_kernels[start:end]
        for layer_num, (start, end) in enumerate(zip(starts, ends))
    }

    if _has_uneven_layer_sizes(layer_mapping, tolerance=0.3):
        warnings.append(
            TraceWarning(
                code="LAYER_SIZE_VARIANCE",
                severity="info",
                message="Layer kernel counts vary significantly. Segmentation may be inaccurate.",
                suggestion="This is normal for models with varying layer sizes or non-uniform workloads.",
            )
        )

    return layer_mapping, warnings


def _find_layer_marker_timestamps(
    kernels: list[dict[str, Any]],
    architecture: ArchitectureType,
) -> list[Any]:
    """Return the 'ts' of every kernel that marks the start of a layer for *architecture*."""
    timestamps: list[Any] = []
    for kernel in kernels:
        # `or ""` also covers an explicit None name, which would break `.lower()`.
        name_lower = (kernel.get("name") or "").lower()

        if architecture == ArchitectureType.TRANSFORMER:
            # Only a subset of attention kernels is counted, so that each layer
            # is opened by a single marker.
            is_attention = any(p in name_lower for p in ("fmha", "attention", "flash"))
            is_marker = is_attention and (
                "context" in name_lower or "2d" in name_lower or "fmhasm100a" in name_lower
            )
        elif architecture == ArchitectureType.SSM:
            is_marker = any(p in name_lower for p in ("selective_scan", "mamba", "ssd"))
        else:
            is_marker = False

        if is_marker:
            timestamps.append(kernel.get("ts", 0))
    return timestamps


def _has_uneven_layer_sizes(
    layer_mapping: dict[int, list[dict[str, Any]]],
    tolerance: float = 0.3,
) -> bool:
    """Return True if any layer's kernel count deviates from the mean by more than *tolerance* (relative)."""
    counts = [len(layer_kernels) for layer_kernels in layer_mapping.values()]
    if not counts:
        return False
    mean_count = sum(counts) / len(counts)
    if mean_count == 0:
        return False
    return any(abs(count - mean_count) / mean_count > tolerance for count in counts)