dsalt 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dsalt/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .modules.dsalt_attention import DSALTAttention
2
+ from .modules.dsalt_transformer import DSALTTransformer
3
+ from .kernels.sparse_attn import dsalt_attention
4
+
5
+ __version__ = "0.1.0"
@@ -0,0 +1,2 @@
1
+ from .sparse_attn import dsalt_attention
2
+ from .hybrid_energy import compute_hybrid_energy_scores, select_landmarks
@@ -0,0 +1,228 @@
1
+ """
dsalt/kernels/hybrid_energy.py
-------------------------------
Hybrid Energy scoring and landmark selection for DSALT.

The score for token j (candidate landmark) is:

    s_j = α * z(‖x_j W_V‖₂) + (1-α) * z(‖x_j‖₂)

where z(·) standardizes each quantity across the candidate tokens of the
current layer, independently for every (batch, head):

    z(x) = (x - mean(x)) / std(x)

The computation has two stages:
    1. _hybrid_energy_kernel (Triton) — computes ‖x_j‖₂ and ‖x_j W_V‖₂ for all
       j in parallel; a PyTorch path is used when Triton or CUDA is unavailable.
    2. Top-k selection — via PyTorch topk (O(n log k), GPU-native).

The landmark selection is **global** within each (batch, head): every query
token in the sequence is given the same top-k most informative tokens,
consistent with DSALT's design.
+ """
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from typing import Tuple
26
+
27
+ try:
28
+ import triton
29
+ import triton.language as tl
30
+ _TRITON_AVAILABLE = True
31
+ except ImportError:
32
+ _TRITON_AVAILABLE = False
33
+
34
+
35
+ # ═════════════════════════════════════════════════════════════════════════════
36
+ # Triton kernel: compute norms in a single fused pass
37
+ # ═════════════════════════════════════════════════════════════════════════════
38
+
39
if _TRITON_AVAILABLE:

    @triton.jit
    def _hybrid_energy_kernel(
        X_ptr,       # [N, D] row-major hidden states (one head, one batch)
        WV_ptr,      # [D, D_head] row-major value projection matrix
        XNorm_ptr,   # [N] output: ‖x_j‖₂ (float32)
        XVNorm_ptr,  # [N] output: ‖x_j W_V‖₂ (float32)
        N: tl.constexpr,
        D: tl.constexpr,
        D_head: tl.constexpr,
        BLOCK_N: tl.constexpr,   # tokens per program
        BLOCK_D: tl.constexpr,   # chunk of the D axis per iteration, power-of-2, >= 16
        BLOCK_Dh: tl.constexpr,  # must be >= D_head, power-of-2, >= 16
    ):
        """
        Grid: (cdiv(N, BLOCK_N),)
        Each program handles BLOCK_N tokens and writes both norms for them.

        The D axis is walked in BLOCK_D-wide chunks, accumulating ‖x_j‖₂² and
        the product x_j W_V in float32; passing BLOCK_D >= D gives a single
        iteration. Tiles are upcast to float32 before any arithmetic, so
        half-precision inputs neither overflow the sum of squares nor feed
        mismatched dtypes into tl.dot.
        """
        pid = tl.program_id(0)
        offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
        mask_n = offs_n < N
        offs_dh = tl.arange(0, BLOCK_Dh)
        mask_dh = offs_dh < D_head

        x_norm_sq = tl.zeros([BLOCK_N], dtype=tl.float32)
        xv = tl.zeros([BLOCK_N, BLOCK_Dh], dtype=tl.float32)

        for d0 in range(0, D, BLOCK_D):
            offs_d = d0 + tl.arange(0, BLOCK_D)
            mask_d = offs_d < D

            # X tile [BLOCK_N, BLOCK_D]; masked lanes read 0 and add nothing.
            x = tl.load(
                X_ptr + offs_n[:, None] * D + offs_d[None, :],
                mask=mask_n[:, None] & mask_d[None, :],
                other=0.0,
            ).to(tl.float32)
            # Matching rows of W_V [BLOCK_D, BLOCK_Dh].
            wv = tl.load(
                WV_ptr + offs_d[:, None] * D_head + offs_dh[None, :],
                mask=mask_d[:, None] & mask_dh[None, :],
                other=0.0,
            ).to(tl.float32)

            x_norm_sq += tl.sum(x * x, axis=1)  # partial ‖x_j‖₂²
            xv += tl.dot(x, wv)                 # partial x_j W_V

        xv_norm_sq = tl.sum(xv * xv, axis=1)  # [BLOCK_N]

        tl.store(XNorm_ptr + offs_n, tl.sqrt(x_norm_sq), mask=mask_n)
        tl.store(XVNorm_ptr + offs_n, tl.sqrt(xv_norm_sq), mask=mask_n)
96
+
97
+
98
+ # ═════════════════════════════════════════════════════════════════════════════
99
+ # CPU fallback
100
+ # ═════════════════════════════════════════════════════════════════════════════
101
+
102
+ def _cpu_compute_norms(
103
+ X: torch.Tensor, # [N, D_model]
104
+ WV: torch.Tensor, # [D_model, D_head] (already oriented for right-multiply)
105
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
106
+ """Returns (x_norm [N], xv_norm [N])."""
107
+ x_norm = X.float().norm(dim=-1) # [N]
108
+ xv_norm = (X.float() @ WV.float()).norm(dim=-1) # [N, D_head] → norm → [N]
109
+ return x_norm, xv_norm
110
+
111
+
112
+ # ═════════════════════════════════════════════════════════════════════════════
113
+ # Main entry point
114
+ # ═════════════════════════════════════════════════════════════════════════════
115
+
116
def compute_hybrid_energy_scores(
    X: torch.Tensor,  # [B, H, N, D] hidden states
    WV: torch.Tensor,  # [H, D, D_head] per-head value projections
    alpha: float = 0.6,
) -> torch.Tensor:
    """
    Computes the Hybrid Energy score for every token in the sequence.

        s_j = α * z(‖x_j W_V‖₂) + (1-α) * z(‖x_j‖₂)

    where z(·) is standard-normalisation over j, done independently for each
    (b, h) with the unbiased (N-1) standard deviation clamped below at 1e-6.

    Args:
        X: hidden states of shape [B, H, N, D].
        WV: per-head value projections of shape [H, D, D_head], right-multiplied.
        alpha: weight of the value-projection term; (1 - alpha) weights ‖x_j‖₂.

    Returns:
        float32 scores of shape [B, H, N] (higher = more likely landmark).
        With N == 1 the spread is undefined and the score is 0 instead of NaN.
    """
    B, H, N, D = X.shape
    D_head = WV.shape[-1]

    if X.is_cuda and _TRITON_AVAILABLE and N > 0:
        x_norms = torch.empty(B, H, N, dtype=torch.float32, device=X.device)
        xv_norms = torch.empty(B, H, N, dtype=torch.float32, device=X.device)

        BLOCK_N = 64
        # tl.dot requires every tile dimension to be >= 16; masked loads make
        # the extra padding lanes contribute nothing.
        BLOCK_D = max(16, triton.next_power_of_2(D))
        BLOCK_Dh = max(16, triton.next_power_of_2(D_head))
        grid = (triton.cdiv(N, BLOCK_N),)

        for b in range(B):
            for h in range(H):
                _hybrid_energy_kernel[grid](
                    X[b, h].contiguous(),
                    WV[h].contiguous(),
                    x_norms[b, h],
                    xv_norms[b, h],
                    N=N, D=D, D_head=D_head,
                    BLOCK_N=BLOCK_N,
                    BLOCK_D=BLOCK_D,
                    BLOCK_Dh=BLOCK_Dh,
                )
    else:
        # PyTorch path: one batched matmul over all (b, h) instead of B*H
        # per-slice products. WV broadcasts over the batch axis.
        x32 = X.float()
        x_norms = x32.norm(dim=-1)  # [B, H, N]
        xv_norms = torch.matmul(x32, WV.float().unsqueeze(0)).norm(dim=-1)  # [B, H, N]

    # Z-normalise each (b, h) independently over the N dimension
    def _znorm(t: torch.Tensor) -> torch.Tensor:
        mu = t.mean(dim=-1, keepdim=True)
        if t.shape[-1] < 2:
            # Unbiased std of fewer than two samples is NaN (and clamp keeps
            # NaN), so return the centred values (all zero) instead.
            return t - mu
        std = t.std(dim=-1, keepdim=True).clamp(min=1e-6)
        return (t - mu) / std

    scores = alpha * _znorm(xv_norms) + (1.0 - alpha) * _znorm(x_norms)
    return scores  # [B, H, N]
168
+
169
+
170
def select_landmarks(
    scores: torch.Tensor,  # [B, H, N]
    k: int,  # number of landmarks
    window_sizes: torch.Tensor,  # [B, H, N] int — exclude in-window tokens
    exclude_last: int = 0,  # never select the last `exclude_last` tokens
    # (avoids selecting current token as own landmark)
) -> torch.Tensor:
    """
    Selects landmark token indices per (batch, head) via top-k on Hybrid Energy.

    Candidates are a prefix of the sequence:
      * the trailing ``max(window_sizes)`` positions are excluded (they are
        covered by the local window pass) when that maximum is positive and
        smaller than N; if the window spans the whole sequence no window
        exclusion is applied;
      * the trailing ``exclude_last`` positions are always excluded.

    Args:
        scores: Hybrid Energy scores of shape [B, H, N].
        k: requested number of landmarks.
        window_sizes: per-token window sizes; only the global maximum is used.
        exclude_last: number of trailing positions that are never selected.

    Returns:
        landmark_idx of shape [B, H, N, k_eff] (int32), where
        k_eff = min(k, number of candidate positions). Excluded positions are
        never returned, so k_eff can be smaller than k (or 0). Each query token
        i gets the SAME sorted global top-k landmarks (the standard DSALT
        design); the per-query mask is applied inside the attention kernel.
    """
    B, H, N = scores.shape

    # Conservative: use the global max window. The last max_w positions are
    # then treated as inside every query's window.
    max_w = int(window_sizes.max()) if window_sizes.numel() > 0 else 0
    n_excluded = max_w if 0 < max_w < N else 0
    # Clip to [0, N] so an oversized exclude_last excludes everything rather
    # than wrapping into a negative slice start.
    n_excluded = max(n_excluded, min(max(exclude_last, 0), N))
    n_cand = N - n_excluded

    # Clamp k to the candidate count so topk can never fall back on excluded
    # positions. A negative k still reaches topk and raises as before.
    k_eff = min(k, n_cand)
    if k_eff == 0:
        top_idx = torch.empty(B, H, 0, dtype=torch.long, device=scores.device)
    else:
        # Candidates form a prefix, so prefix indices equal sequence indices.
        _, top_idx = torch.topk(scores[..., :n_cand], k=k_eff, dim=-1)  # [B, H, k_eff]
        # Sort indices for more cache-friendly access in the attention kernel
        top_idx, _ = top_idx.sort(dim=-1)

    # Broadcast to [B, H, N, k_eff] — same landmarks for every query token
    landmark_idx = top_idx.unsqueeze(2).expand(B, H, N, k_eff)

    return landmark_idx.to(torch.int32)
214
+
215
+
216
def compute_landmark_idx(
    X: torch.Tensor,  # [B, H, N, D]
    WV: torch.Tensor,  # [H, D, D_head]
    window_sizes: torch.Tensor,  # [B, H, N]
    k: int,
    alpha: float = 0.6,
    exclude_last: int = 0,
) -> torch.Tensor:
    """
    Convenience function: score + select in one call.

    Args:
        X: hidden states of shape [B, H, N, D].
        WV: per-head value projections of shape [H, D, D_head].
        window_sizes: per-token window sizes, forwarded to ``select_landmarks``.
        k: requested number of landmarks.
        alpha: Hybrid Energy mixing weight, forwarded to
            ``compute_hybrid_energy_scores``.
        exclude_last: number of trailing positions never selected, forwarded
            to ``select_landmarks`` (default 0 keeps the previous behaviour).

    Returns:
        landmark_idx [B, H, N, k'] int32, as returned by ``select_landmarks``.
    """
    scores = compute_hybrid_energy_scores(X, WV, alpha)
    return select_landmarks(scores, k, window_sizes, exclude_last=exclude_last)