dsalt 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dsalt/__init__.py +5 -0
- dsalt/kernels/__init__.py +2 -0
- dsalt/kernels/hybrid_energy.py +228 -0
- dsalt/kernels/sparse_attn.py +670 -0
- dsalt/kernels/window_utils.py +100 -0
- dsalt/model/__init__.py +3 -0
- dsalt/model/dsalt_lm.py +89 -0
- dsalt/modules/__init__.py +0 -0
- dsalt/modules/dsalt_attention.py +197 -0
- dsalt/modules/dsalt_transformer.py +253 -0
- dsalt/py.typed +1 -0
- dsalt/training/__init__.py +0 -0
- dsalt/training/trainer.py +361 -0
- dsalt-0.1.0.dist-info/METADATA +243 -0
- dsalt-0.1.0.dist-info/RECORD +18 -0
- dsalt-0.1.0.dist-info/WHEEL +5 -0
- dsalt-0.1.0.dist-info/licenses/LICENSE +200 -0
- dsalt-0.1.0.dist-info/top_level.txt +1 -0
dsalt/__init__.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
"""
|
|
2
|
+
dsalt/kernels/hybrid_energy.py
|
|
3
|
+
-------------------------------
|
|
4
|
+
Hybrid Energy scoring and landmark selection for DSALT.
|
|
5
|
+
|
|
6
|
+
The score for token j (candidate landmark) is:
|
|
7
|
+
|
|
8
|
+
s_j = α * z(‖x_j W_V‖₂) + (1-α) * z(‖x_j‖₂)
|
|
9
|
+
|
|
10
|
+
where z(·) is standard-normalization across candidates at the current layer:
|
|
11
|
+
|
|
12
|
+
z(x) = (x - mean(x)) / std(x)
|
|
13
|
+
|
|
14
|
+
Two stages:
|
|
15
|
+
1. _hybrid_energy_kernel (Triton) — compute ‖x_j‖₂ and ‖x_j W_V‖₂ for all j in parallel
|
|
16
|
+
2. Top-k selection — via PyTorch topk (already O(n log k), GPU-native)
|
|
17
|
+
|
|
18
|
+
The landmark selection is **global** (shared across all query tokens in the
|
|
19
|
+
sequence), consistent with DSALT's design where each token attends to the
|
|
20
|
+
same top-k globally informative tokens.
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import torch
|
|
24
|
+
import torch.nn.functional as F
|
|
25
|
+
from typing import Tuple
|
|
26
|
+
|
|
27
|
+
try:
|
|
28
|
+
import triton
|
|
29
|
+
import triton.language as tl
|
|
30
|
+
_TRITON_AVAILABLE = True
|
|
31
|
+
except ImportError:
|
|
32
|
+
_TRITON_AVAILABLE = False
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# ═════════════════════════════════════════════════════════════════════════════
|
|
36
|
+
# Triton kernel: compute norms in a single fused pass
|
|
37
|
+
# ═════════════════════════════════════════════════════════════════════════════
|
|
38
|
+
|
|
39
|
+
if _TRITON_AVAILABLE:

    @triton.jit
    def _hybrid_energy_kernel(
        X_ptr,  # [N, D] input hidden states (one head, one batch), row stride D
        WV_ptr,  # [D, D_head] value projection matrix, row stride D_head
        XNorm_ptr,  # [N] output: ‖x_j‖₂
        XVNorm_ptr,  # [N] output: ‖x_j W_V‖₂
        N,  # number of tokens; a runtime value so a new length does not recompile
        D: tl.constexpr,
        D_head: tl.constexpr,
        BLOCK_N: tl.constexpr,  # tokens per program
        BLOCK_D: tl.constexpr,  # must be >= D, power-of-2 (and >= 16 for tl.dot)
        BLOCK_Dh: tl.constexpr,  # must be >= D_head, power-of-2 (and >= 16 for tl.dot)
    ):
        """
        Compute ‖x_j‖₂ and ‖x_j W_V‖₂ for a tile of BLOCK_N tokens.

        Grid: (cdiv(N, BLOCK_N),). Program `pid` handles rows
        [pid * BLOCK_N, pid * BLOCK_N + BLOCK_N), masked to < N. The full
        feature dimension is loaded in one tile, so BLOCK_D / BLOCK_Dh must
        cover D / D_head. Padded lanes load as 0 and do not change either norm.
        """
        pid = tl.program_id(0)
        offs_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)
        mask_n = offs_n < N
        offs_d = tl.arange(0, BLOCK_D)
        mask_d = offs_d < D
        offs_dh = tl.arange(0, BLOCK_Dh)
        mask_dh = offs_dh < D_head

        # X tile [BLOCK_N, BLOCK_D]; out-of-range rows/columns read as 0.
        x = tl.load(
            X_ptr + offs_n[:, None] * D + offs_d[None, :],
            mask=mask_n[:, None] & mask_d[None, :],
            other=0.0,
        )

        # ‖x_j‖₂² accumulated in fp32. This matches the CPU fallback's
        # X.float() and avoids fp16 overflow when squaring large activations.
        x32 = x.to(tl.float32)
        x_norm_sq = tl.sum(x32 * x32, axis=1)  # [BLOCK_N]

        # Whole W_V in one tile [BLOCK_D, BLOCK_Dh]; padding reads as 0.
        wv = tl.load(
            WV_ptr + offs_d[:, None] * D_head + offs_dh[None, :],
            mask=mask_d[:, None] & mask_dh[None, :],
            other=0.0,
        )
        # x_j W_V : [BLOCK_N, BLOCK_D] @ [BLOCK_D, BLOCK_Dh]. tl.dot accumulates
        # in fp32 by default. NOTE(review): X and W_V presumably must share a
        # dtype for tl.dot, and fp32 inputs may use TF32 precision — confirm
        # this tolerance is acceptable.
        xv = tl.dot(x, wv)  # [BLOCK_N, BLOCK_Dh]
        xv_norm_sq = tl.sum(xv * xv, axis=1)  # [BLOCK_N]

        tl.store(XNorm_ptr + offs_n, tl.sqrt(x_norm_sq), mask=mask_n)
        tl.store(XVNorm_ptr + offs_n, tl.sqrt(xv_norm_sq), mask=mask_n)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# ═════════════════════════════════════════════════════════════════════════════
|
|
99
|
+
# CPU fallback
|
|
100
|
+
# ═════════════════════════════════════════════════════════════════════════════
|
|
101
|
+
|
|
102
|
+
def _cpu_compute_norms(
|
|
103
|
+
X: torch.Tensor, # [N, D_model]
|
|
104
|
+
WV: torch.Tensor, # [D_model, D_head] (already oriented for right-multiply)
|
|
105
|
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
106
|
+
"""Returns (x_norm [N], xv_norm [N])."""
|
|
107
|
+
x_norm = X.float().norm(dim=-1) # [N]
|
|
108
|
+
xv_norm = (X.float() @ WV.float()).norm(dim=-1) # [N, D_head] → norm → [N]
|
|
109
|
+
return x_norm, xv_norm
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# ═════════════════════════════════════════════════════════════════════════════
|
|
113
|
+
# Main entry point
|
|
114
|
+
# ═════════════════════════════════════════════════════════════════════════════
|
|
115
|
+
|
|
116
|
+
def compute_hybrid_energy_scores(
    X: torch.Tensor,  # [B, H, N, D] hidden states
    WV: torch.Tensor,  # [H, D, D_head] per-head value projections
    alpha: float = 0.6,
) -> torch.Tensor:
    """
    Computes the Hybrid Energy score for every token in the sequence.

        s_j = α * z(‖x_j W_V‖₂) + (1-α) * z(‖x_j‖₂)

    where z(·) is standard-normalisation over j. It uses the sample standard
    deviation (``Tensor.std`` default), clamped below at 1e-6. A sequence
    with fewer than two tokens has no spread, so its z-scores are defined as
    0 rather than NaN.

    Args:
        X: hidden states, shape [B, H, N, D].
        WV: per-head value projections, shape [H, D, D_head], right-multiplied
            as x_j @ WV[h].
        alpha: weight of the value-projection term; (1 - alpha) weights the
            raw hidden-state norm.

    Returns:
        float32 scores of shape [B, H, N] on X.device (higher = more likely
        landmark).
    """
    B, H, N, D = X.shape
    _, _, D_head = WV.shape
    x_norms = torch.empty(B, H, N, dtype=torch.float32, device=X.device)
    xv_norms = torch.empty(B, H, N, dtype=torch.float32, device=X.device)

    if X.is_cuda and _TRITON_AVAILABLE:
        BLOCK_N = 64
        # tl.dot rejects tile dimensions below 16. Padded lanes are masked to
        # zero inside the kernel, so rounding the blocks up is harmless.
        BLOCK_D = max(16, triton.next_power_of_2(D))
        BLOCK_Dh = max(16, triton.next_power_of_2(D_head))
        grid = (triton.cdiv(N, BLOCK_N),)

        for b in range(B):
            for h in range(H):
                _hybrid_energy_kernel[grid](
                    X[b, h].contiguous(),
                    WV[h].contiguous(),
                    # Rows of freshly allocated contiguous buffers, so each
                    # slice is itself contiguous and safe to write through.
                    x_norms[b, h],
                    xv_norms[b, h],
                    N=N, D=D, D_head=D_head,
                    BLOCK_N=BLOCK_N,
                    BLOCK_D=BLOCK_D,
                    BLOCK_Dh=BLOCK_Dh,
                )
    else:
        for b in range(B):
            for h in range(H):
                xn, xvn = _cpu_compute_norms(X[b, h], WV[h])
                x_norms[b, h] = xn
                xv_norms[b, h] = xvn

    # Z-normalise each (b, h) independently over the N dimension.
    def _znorm(t: torch.Tensor) -> torch.Tensor:
        if t.shape[-1] < 2:
            # The sample std of a single value is NaN, and clamp() keeps NaN.
            # A lone value equals its own mean, so its z-score is 0.
            return torch.zeros_like(t)
        mu = t.mean(dim=-1, keepdim=True)
        std = t.std(dim=-1, keepdim=True).clamp(min=1e-6)
        return (t - mu) / std

    scores = alpha * _znorm(xv_norms) + (1.0 - alpha) * _znorm(x_norms)
    return scores  # [B, H, N]
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def select_landmarks(
    scores: torch.Tensor,  # [B, H, N]
    k: int,  # number of landmarks
    window_sizes: torch.Tensor,  # [B, H, N] int — exclude in-window tokens
    exclude_last: int = 0,  # never select the last `exclude_last` tokens
    # (avoids selecting current token as own landmark)
) -> torch.Tensor:
    """
    Selects k landmark token indices per (batch, head) via top-k on Hybrid Energy.

    The last ``max(window_sizes)`` positions of the sequence are excluded from
    landmark candidacy (they are already covered by the local window pass),
    but only when that window is shorter than the sequence. The last
    ``exclude_last`` positions are also excluded. An ``exclude_last`` of N or
    more excludes every position.

    Returns landmark_idx of shape [B, H, N, min(k, N)] (int32), sorted
    ascending along the last dimension. Each query token i gets the SAME
    global top-k landmarks (the standard DSALT design); the per-query mask is
    applied inside the attention kernel.

    NOTE(review): when fewer than min(k, N) positions survive the exclusions,
    torch.topk fills the remaining slots with excluded (-inf-scored)
    positions. Callers should confirm the attention kernel tolerates that.
    """
    B, H, N = scores.shape

    # Conservative: use the global max window. An empty tensor contributes none.
    max_w = int(window_sizes.max()) if window_sizes.numel() > 0 else 0

    cand_scores = scores.clone()
    if 0 < max_w < N:
        # Exclusion by -inf so the tail is never picked as a "global" landmark
        # (it is covered locally).
        cand_scores[..., N - max_w:] = float("-inf")
    if exclude_last > 0:
        # Clamp at 0: a negative start would wrap around and mask only part
        # of the tail when exclude_last > N.
        cand_scores[..., max(N - exclude_last, 0):] = float("-inf")

    # Top-k selection [B, H, k]
    k_safe = min(k, N)
    _, top_idx = torch.topk(cand_scores, k=k_safe, dim=-1)

    # Sort indices for more cache-friendly access in the attention kernel
    top_idx, _ = top_idx.sort(dim=-1)

    # Broadcast to [B, H, N, k] — same landmarks for every query token.
    # .to() on the expanded view materialises a dense int32 tensor.
    landmark_idx = top_idx.unsqueeze(2).expand(B, H, N, k_safe)
    return landmark_idx.to(torch.int32)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def compute_landmark_idx(
    X: torch.Tensor,  # [B, H, N, D]
    WV: torch.Tensor,  # [H, D, D_head]
    window_sizes: torch.Tensor,  # [B, H, N]
    k: int,
    alpha: float = 0.6,
    *,
    exclude_last: int = 0,
) -> torch.Tensor:
    """
    Convenience function: score + select in one call.

    Computes Hybrid Energy scores with ``compute_hybrid_energy_scores(X, WV,
    alpha)`` and passes them to ``select_landmarks``. ``exclude_last`` is
    forwarded unchanged. It is keyword-only so it cannot absorb an extra
    positional argument, and its default of 0 keeps the previous behaviour.

    Returns landmark_idx [B, H, N, min(k, N)] int32.
    """
    scores = compute_hybrid_energy_scores(X, WV, alpha)
    return select_landmarks(scores, k, window_sizes, exclude_last=exclude_last)
|