superlinear 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- apps/__init__.py +4 -0
- apps/cli/__init__.py +8 -0
- apps/cli/bm25_rag.py +471 -0
- apps/cli/chat_repl.py +1497 -0
- apps/cli/client.py +195 -0
- apps/cli/docs_repl.py +2275 -0
- apps/cli/light_rag.py +729 -0
- apps/cli/local_snapshots.py +139 -0
- apps/cli/locks.py +214 -0
- apps/cli/main.py +457 -0
- apps/cli/output.py +32 -0
- apps/cli/server_cmds.py +516 -0
- apps/cli/session_cmds.py +491 -0
- apps/cli/snapshot_cmds.py +303 -0
- apps/cli/state.py +265 -0
- apps/server/__init__.py +4 -0
- apps/server/app.py +1363 -0
- apps/server/main.py +313 -0
- superlinear/__init__.py +114 -0
- superlinear/_version.py +3 -0
- superlinear/engine/__init__.py +10 -0
- superlinear/engine/adapters/__init__.py +12 -0
- superlinear/engine/adapters/base.py +91 -0
- superlinear/engine/adapters/superlinear.py +1233 -0
- superlinear/engine/chat_engine.py +1173 -0
- superlinear/engine/chat_types.py +130 -0
- superlinear/engine/registry.py +51 -0
- superlinear/engine/repetition.py +203 -0
- superlinear/engine/session_snapshots.py +451 -0
- superlinear/engine/tool_parser.py +83 -0
- superlinear/engine/types.py +42 -0
- superlinear/kernels/__init__.py +2 -0
- superlinear/kernels/common/__init__.py +21 -0
- superlinear/kernels/common/adjustment.py +106 -0
- superlinear/kernels/common/power.py +154 -0
- superlinear/kernels/superlinear/__init__.py +10 -0
- superlinear/kernels/superlinear/attention/__init__.py +78 -0
- superlinear/kernels/superlinear/attention/_prefill.py +940 -0
- superlinear/kernels/superlinear/attention/_sliding_window.py +1167 -0
- superlinear/kernels/superlinear/attention/api.py +433 -0
- superlinear/kernels/superlinear/search/__init__.py +33 -0
- superlinear/kernels/superlinear/search/_reference.py +204 -0
- superlinear/kernels/superlinear/search/_triton.py +488 -0
- superlinear/kernels/superlinear/search/_triton_gqa.py +534 -0
- superlinear/kernels/superlinear/search/api.py +200 -0
- superlinear/kernels/superlinear/span/__init__.py +41 -0
- superlinear/kernels/superlinear/span/_triton_bucketed_gqa.py +1461 -0
- superlinear/kernels/superlinear/span/_triton_forward.py +22 -0
- superlinear/kernels/superlinear/span/_triton_gqa.py +1226 -0
- superlinear/kernels/superlinear/span/_triton_impl.py +928 -0
- superlinear/kernels/superlinear/span/_triton_precomputed_sw.py +460 -0
- superlinear/kernels/superlinear/span/_triton_precomputed_sw_gqa.py +598 -0
- superlinear/kernels/superlinear/span/api.py +296 -0
- superlinear/kernels/superlinear/span/masks.py +187 -0
- superlinear/py.typed +0 -0
- superlinear/runtime.py +71 -0
- superlinear-0.1.0.dist-info/METADATA +469 -0
- superlinear-0.1.0.dist-info/RECORD +62 -0
- superlinear-0.1.0.dist-info/WHEEL +5 -0
- superlinear-0.1.0.dist-info/entry_points.txt +2 -0
- superlinear-0.1.0.dist-info/licenses/LICENSE +202 -0
- superlinear-0.1.0.dist-info/top_level.txt +2 -0
superlinear/kernels/superlinear/span/_triton_gqa.py
@@ -0,0 +1,1226 @@
import math
import os

import torch
import triton
import triton.language as tl

from superlinear.kernels.common.power import derive_stripe_power_params, max_stripe_index_for_token_pos, window_len_from_sw_index
from superlinear.kernels.common.adjustment import compute_qend_from_qanchor
from ._triton_impl import (
    _assert_no_span_sw_overlap,
    _next_power_of_two,
    decode_span_attention_staged,
    fused_span_attention,
)

@triton.jit
def fused_span_forward_kernel_gqa(
    Q_ptr, K_ptr, V_ptr,
    qstart_ptr, qend_ptr, cache_position_ptr,
    attn_mask_ptr,
    Out_ptr,
    B, H_Q, H_KV, L_Q, L_KV,
    kv_repeat,
    window_size,
    sm_scale,
    K_VAL: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
    D_HEAD: tl.constexpr,
    SPAN_MAX_BLOCKS: tl.constexpr,
    SW_MAX_BLOCKS: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
):
    # One program per (batch, query head, query position, span index).
    pid = tl.program_id(0)
    span_index = pid % K_VAL
    q_idx = (pid // K_VAL) % L_Q
    q_head_idx = (pid // (K_VAL * L_Q)) % H_Q
    batch_idx = pid // (K_VAL * L_Q * H_Q)

    # GQA: each group of kv_repeat query heads shares one KV head.
    kv_head_idx = q_head_idx // kv_repeat
    kv_head_idx = tl.minimum(kv_head_idx, H_KV - 1)

    d_range = tl.arange(0, BLOCK_D)
    d_mask = d_range < D_HEAD

    q_base = Q_ptr + ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * D_HEAD
    q = tl.load(q_base + d_range, mask=d_mask, other=0.0).to(tl.float32)

    cache_pos = tl.load(cache_position_ptr + q_idx, mask=True, other=0).to(tl.int32)
    cache_pos = tl.minimum(cache_pos, L_KV - 1)

    span_offset = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_VAL + span_index
    span_start = tl.load(qstart_ptr + span_offset, mask=True, other=-1).to(tl.int32)
    span_end = tl.load(qend_ptr + span_offset, mask=True, other=-1).to(tl.int32)
    span_start = tl.maximum(span_start, 0)
    span_end = tl.minimum(span_end, L_KV - 1)
    span_valid = (span_end >= span_start) & (span_end >= 0)

    window = window_size
    sw_end = tl.minimum(cache_pos, L_KV - 1)
    sw_start = sw_end - (window - 1)
    sw_start = tl.maximum(sw_start, 0)
    sw_valid = (window > 0) & (sw_start <= sw_end) & (sw_end >= 0)

    seg1_start, seg1_end, seg1_valid = span_start, span_end, span_valid
    seg2_start, seg2_end, seg2_valid = sw_start, sw_end, sw_valid

    seg1_start = tl.maximum(seg1_start, 0)
    seg1_end = tl.minimum(seg1_end, L_KV - 1)
    seg1_valid = seg1_valid & (seg1_start <= seg1_end)

    seg2_start = tl.maximum(seg2_start, 0)
    seg2_end = tl.minimum(seg2_end, L_KV - 1)
    seg2_valid = seg2_valid & (seg2_start <= seg2_end)

    k_head_offset = ((batch_idx * H_KV + kv_head_idx) * L_KV).to(tl.int64) * D_HEAD
    attn_base = batch_idx * L_KV

    # Online-softmax running state: max, normalizer, weighted-V accumulator.
    m_i = -float('inf')
    l_i = 0.0
    acc = tl.zeros((BLOCK_D,), dtype=tl.float32)
    scale = tl.full((1,), sm_scale, tl.float32)

    # Segment 1: the selected span.
    for block_idx in range(SPAN_MAX_BLOCKS):
        block_start = seg1_start + block_idx * BLOCK_K
        k_pos = block_start + tl.arange(0, BLOCK_K)
        in_range = seg1_valid & (k_pos >= seg1_start) & (k_pos <= seg1_end)
        k_pos_safe = tl.where(in_range, k_pos, 0)

        k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
        v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
        k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
        v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)

        logits = tl.sum(k_block * q[None, :], axis=1) * scale

        if HAS_ATTN_MASK:
            attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
            in_range = in_range & attn_mask_vals
            logits = tl.where(in_range, logits, float('-inf'))

        if tl.sum(in_range, axis=0) > 0:
            block_max = tl.max(logits, axis=0)
            m_new = tl.maximum(m_i, block_max)
            alpha = tl.exp(m_i - m_new)
            p = tl.exp(logits - m_new)
            l_i = l_i * alpha + tl.sum(p, axis=0)
            acc = acc * alpha + tl.sum(p[:, None] * v_block, axis=0)
            m_i = m_new

    # Segment 2: the trailing sliding window (same softmax, continued).
    if SW_MAX_BLOCKS > 0:
        for block_idx in range(SW_MAX_BLOCKS):
            block_start = seg2_start + block_idx * BLOCK_K
            k_pos = block_start + tl.arange(0, BLOCK_K)
            in_range = seg2_valid & (k_pos >= seg2_start) & (k_pos <= seg2_end)
            k_pos_safe = tl.where(in_range, k_pos, 0)

            k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
            v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)

            logits = tl.sum(k_block * q[None, :], axis=1) * scale

            if HAS_ATTN_MASK:
                attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
                in_range = in_range & attn_mask_vals
                logits = tl.where(in_range, logits, float('-inf'))

            if tl.sum(in_range, axis=0) > 0:
                block_max = tl.max(logits, axis=0)
                m_new = tl.maximum(m_i, block_max)
                alpha = tl.exp(m_i - m_new)
                p = tl.exp(logits - m_new)
                l_i = l_i * alpha + tl.sum(p, axis=0)
                acc = acc * alpha + tl.sum(p[:, None] * v_block, axis=0)
                m_i = m_new

    acc = tl.where(l_i > 0, acc / l_i, 0.0)
    out_index = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_VAL + span_index
    out_base = Out_ptr + out_index.to(tl.int64) * D_HEAD
    tl.store(out_base + d_range, acc, mask=d_mask)

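
# ---------------------------------------------------------------------------
# Reference sketch (added for exposition; not part of the original module).
# Per (batch, query head, query position, span), the kernel above computes a
# single softmax over the union of the selected [span_start, span_end] segment
# and the trailing sliding window. A minimal eager-mode equivalent, assuming
# a 1-D q and already-clamped, non-overlapping bounds, would be:
#
#     def span_plus_window_attention(q, k, v, start, end, sw_start, sw_end, scale):
#         idx = torch.cat([torch.arange(start, end + 1),
#                          torch.arange(sw_start, sw_end + 1)])
#         logits = (k[idx] @ q) * scale          # [n_keys]
#         w = torch.softmax(logits, dim=0)       # one softmax across both segments
#         return w @ v[idx]                      # [d_head]
#
# The two block loops reproduce this exactly: each block rescales the running
# (m_i, l_i, acc) state by alpha = exp(m_i - m_new) before folding in its own
# contribution, so the blockwise result matches the single full softmax.
# ---------------------------------------------------------------------------
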
@triton.jit
def fused_span_backward_kernel_gqa(
    Q_ptr, K_ptr, V_ptr, dOut_ptr,
    qstart_ptr, qend_ptr, cache_position_ptr,
    attn_mask_ptr,
    dQ_ptr, dK_ptr, dV_ptr,
    B, H_Q, H_KV, L_Q, L_KV,
    kv_repeat,
    window_size,
    sm_scale,
    K_VAL: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
    D_HEAD: tl.constexpr,
    SPAN_MAX_BLOCKS: tl.constexpr,
    SW_MAX_BLOCKS: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    WRITE_DKV_PER_QHEAD: tl.constexpr,
):
    pid = tl.program_id(0)
    span_index = pid % K_VAL
    q_idx = (pid // K_VAL) % L_Q
    q_head_idx = (pid // (K_VAL * L_Q)) % H_Q
    batch_idx = pid // (K_VAL * L_Q * H_Q)

    kv_head_idx = q_head_idx // kv_repeat
    kv_head_idx = tl.minimum(kv_head_idx, H_KV - 1)

    d_range = tl.arange(0, BLOCK_D)
    d_mask = d_range < D_HEAD

    q_base = Q_ptr + ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * D_HEAD
    q = tl.load(q_base + d_range, mask=d_mask, other=0.0).to(tl.float32)

    cache_pos = tl.load(cache_position_ptr + q_idx, mask=True, other=0).to(tl.int32)
    cache_pos = tl.minimum(cache_pos, L_KV - 1)

    span_offset = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_VAL + span_index
    span_start = tl.load(qstart_ptr + span_offset, mask=True, other=-1).to(tl.int32)
    span_end = tl.load(qend_ptr + span_offset, mask=True, other=-1).to(tl.int32)
    span_start = tl.maximum(span_start, 0)
    span_end = tl.minimum(span_end, L_KV - 1)
    span_valid = (span_end >= span_start) & (span_end >= 0)

    window = window_size
    sw_end = tl.minimum(cache_pos, L_KV - 1)
    sw_start = sw_end - (window - 1)
    sw_start = tl.maximum(sw_start, 0)
    sw_valid = (window > 0) & (sw_start <= sw_end) & (sw_end >= 0)

    seg1_start, seg1_end, seg1_valid = span_start, span_end, span_valid
    seg2_start, seg2_end, seg2_valid = sw_start, sw_end, sw_valid

    seg1_start = tl.maximum(seg1_start, 0)
    seg1_end = tl.minimum(seg1_end, L_KV - 1)
    seg1_valid = seg1_valid & (seg1_start <= seg1_end)

    seg2_start = tl.maximum(seg2_start, 0)
    seg2_end = tl.minimum(seg2_end, L_KV - 1)
    seg2_valid = seg2_valid & (seg2_start <= seg2_end)

    k_head_offset = ((batch_idx * H_KV + kv_head_idx) * L_KV).to(tl.int64) * D_HEAD
    if WRITE_DKV_PER_QHEAD:
        dkv_head_offset = ((batch_idx * H_Q + q_head_idx) * L_KV).to(tl.int64) * D_HEAD
    else:
        dkv_head_offset = k_head_offset
    attn_base = batch_idx * L_KV

    # Pass 1: recompute the forward softmax statistics (m_i, l_i).
    m_i = -float('inf')
    l_i = 0.0
    scale = tl.full((1,), sm_scale, tl.float32)

    for block_idx in range(SPAN_MAX_BLOCKS):
        block_start = seg1_start + block_idx * BLOCK_K
        k_pos = block_start + tl.arange(0, BLOCK_K)
        in_range = seg1_valid & (k_pos >= seg1_start) & (k_pos <= seg1_end)
        k_pos_safe = tl.where(in_range, k_pos, 0)
        k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
        k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
        logits = tl.sum(k_block * q[None, :], axis=1) * scale
        if HAS_ATTN_MASK:
            attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
            in_range = in_range & attn_mask_vals
            logits = tl.where(in_range, logits, float('-inf'))
        if tl.sum(in_range, axis=0) > 0:
            block_max = tl.max(logits, axis=0)
            m_new = tl.maximum(m_i, block_max)
            alpha = tl.exp(m_i - m_new)
            p = tl.exp(logits - m_new)
            l_i = l_i * alpha + tl.sum(p, axis=0)
            m_i = m_new

    if SW_MAX_BLOCKS > 0:
        for block_idx in range(SW_MAX_BLOCKS):
            block_start = seg2_start + block_idx * BLOCK_K
            k_pos = block_start + tl.arange(0, BLOCK_K)
            in_range = seg2_valid & (k_pos >= seg2_start) & (k_pos <= seg2_end)
            k_pos_safe = tl.where(in_range, k_pos, 0)
            k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
            logits = tl.sum(k_block * q[None, :], axis=1) * scale
            if HAS_ATTN_MASK:
                attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
                in_range = in_range & attn_mask_vals
                logits = tl.where(in_range, logits, float('-inf'))
            if tl.sum(in_range, axis=0) > 0:
                block_max = tl.max(logits, axis=0)
                m_new = tl.maximum(m_i, block_max)
                alpha = tl.exp(m_i - m_new)
                p = tl.exp(logits - m_new)
                l_i = l_i * alpha + tl.sum(p, axis=0)
                m_i = m_new

    dO_index = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_VAL + span_index
    dO_base = dOut_ptr + dO_index.to(tl.int64) * D_HEAD
    dO = tl.load(dO_base + d_range, mask=d_mask, other=0.0).to(tl.float32)

    grad_q = tl.zeros((BLOCK_D,), dtype=tl.float32)

    # Pass 2: dot_total = sum_k p_k * (v_k . dO), needed by the softmax gradient.
    dot_total = 0.0
    for block_idx in range(SPAN_MAX_BLOCKS):
        block_start = seg1_start + block_idx * BLOCK_K
        k_pos = block_start + tl.arange(0, BLOCK_K)
        in_range = seg1_valid & (k_pos >= seg1_start) & (k_pos <= seg1_end)
        k_pos_safe = tl.where(in_range, k_pos, 0)

        k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
        v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
        k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
        v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)

        logits = tl.sum(k_block * q[None, :], axis=1) * scale
        if HAS_ATTN_MASK:
            attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
            in_range = in_range & attn_mask_vals
            logits = tl.where(in_range, logits, float('-inf'))
        if tl.sum(in_range, axis=0) > 0:
            weights = tl.exp(logits - m_i) / l_i
            weights = tl.where(in_range, weights, 0.0)
            grad_w = tl.sum(v_block * dO[None, :], axis=1)
            dot_total += tl.sum(grad_w * weights, axis=0)

    if SW_MAX_BLOCKS > 0:
        for block_idx in range(SW_MAX_BLOCKS):
            block_start = seg2_start + block_idx * BLOCK_K
            k_pos = block_start + tl.arange(0, BLOCK_K)
            in_range = seg2_valid & (k_pos >= seg2_start) & (k_pos <= seg2_end)
            k_pos_safe = tl.where(in_range, k_pos, 0)

            k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
            v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)

            logits = tl.sum(k_block * q[None, :], axis=1) * scale
            if HAS_ATTN_MASK:
                attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
                in_range = in_range & attn_mask_vals
                logits = tl.where(in_range, logits, float('-inf'))
            if tl.sum(in_range, axis=0) > 0:
                weights = tl.exp(logits - m_i) / l_i
                weights = tl.where(in_range, weights, 0.0)
                grad_w = tl.sum(v_block * dO[None, :], axis=1)
                dot_total += tl.sum(grad_w * weights, axis=0)

    # Pass 3: emit dQ into a local accumulator; dK/dV via atomic adds.
    for block_idx in range(SPAN_MAX_BLOCKS):
        block_start = seg1_start + block_idx * BLOCK_K
        k_pos = block_start + tl.arange(0, BLOCK_K)
        in_range = seg1_valid & (k_pos >= seg1_start) & (k_pos <= seg1_end)
        k_pos_safe = tl.where(in_range, k_pos, 0)

        k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
        v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
        k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
        v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)

        logits = tl.sum(k_block * q[None, :], axis=1) * scale
        if HAS_ATTN_MASK:
            attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
            in_range = in_range & attn_mask_vals
            logits = tl.where(in_range, logits, float('-inf'))
        if tl.sum(in_range, axis=0) > 0:
            weights = tl.exp(logits - m_i) / l_i
            weights = tl.where(in_range, weights, 0.0)
            grad_w = tl.sum(v_block * dO[None, :], axis=1)
            grad_s = (grad_w - dot_total) * weights * sm_scale
            grad_s = tl.where(in_range, grad_s, 0.0)

            grad_q = grad_q + tl.sum(grad_s[:, None] * k_block, axis=0)

            dk = grad_s[:, None] * q[None, :]
            dkv_offsets = dkv_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            tl.atomic_add(dK_ptr + dkv_offsets, dk, mask=in_range[:, None] & d_mask[None, :])

            dv = weights[:, None] * dO[None, :]
            tl.atomic_add(dV_ptr + dkv_offsets, dv, mask=in_range[:, None] & d_mask[None, :])

    if SW_MAX_BLOCKS > 0:
        for block_idx in range(SW_MAX_BLOCKS):
            block_start = seg2_start + block_idx * BLOCK_K
            k_pos = block_start + tl.arange(0, BLOCK_K)
            in_range = seg2_valid & (k_pos >= seg2_start) & (k_pos <= seg2_end)
            k_pos_safe = tl.where(in_range, k_pos, 0)

            k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
            v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)

            logits = tl.sum(k_block * q[None, :], axis=1) * scale
            if HAS_ATTN_MASK:
                attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
                in_range = in_range & attn_mask_vals
                logits = tl.where(in_range, logits, float('-inf'))
            if tl.sum(in_range, axis=0) > 0:
                weights = tl.exp(logits - m_i) / l_i
                weights = tl.where(in_range, weights, 0.0)
                grad_w = tl.sum(v_block * dO[None, :], axis=1)
                grad_s = (grad_w - dot_total) * weights * sm_scale
                grad_s = tl.where(in_range, grad_s, 0.0)

                grad_q = grad_q + tl.sum(grad_s[:, None] * k_block, axis=0)

                dk = grad_s[:, None] * q[None, :]
                dkv_offsets = dkv_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
                tl.atomic_add(dK_ptr + dkv_offsets, dk, mask=in_range[:, None] & d_mask[None, :])

                dv = weights[:, None] * dO[None, :]
                tl.atomic_add(dV_ptr + dkv_offsets, dv, mask=in_range[:, None] & d_mask[None, :])

    dq_base = dQ_ptr + ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * D_HEAD
    tl.atomic_add(dq_base + d_range, grad_q, mask=d_mask)

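
# ---------------------------------------------------------------------------
# Gradient identity behind the kernel above (exposition, not original code).
# With p = softmax(s) and o = sum_k p_k * v_k, the logit gradient is
#     ds_k = p_k * ((v_k . dO) - sum_j p_j * (v_j . dO)),
# which is exactly grad_s = (grad_w - dot_total) * weights; the extra sm_scale
# factor maps d/d(logits) back to d/d(q . k). Hence the three passes:
# recompute (m_i, l_i), accumulate dot_total, then emit dQ into a local
# accumulator and dK/dV via tl.atomic_add (several programs can touch the
# same key position, so plain stores would race).
# ---------------------------------------------------------------------------
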
@triton.jit
def fused_span_backward_kernel_gqa_fused_spans(
    Q_ptr, K_ptr, V_ptr, dOut_ptr,
    qstart_ptr, qend_ptr, cache_position_ptr,
    attn_mask_ptr,
    dQ_ptr, dK_ptr, dV_ptr,
    B, H_Q, H_KV, L_Q, L_KV,
    kv_repeat,
    window_size,
    sm_scale,
    K_VAL: tl.constexpr,  # Power of 2 for tl.arange
    K_ACTUAL: tl.constexpr,  # Actual number of spans
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
    D_HEAD: tl.constexpr,
    SPAN_MAX_BLOCKS: tl.constexpr,
    SW_MAX_BLOCKS: tl.constexpr,
    HAS_ATTN_MASK: tl.constexpr,
    WRITE_DKV_PER_QHEAD: tl.constexpr,
):
    # One program per (batch, query head, query position); all spans fused.
    pid = tl.program_id(0)
    q_idx = pid % L_Q
    q_head_idx = (pid // L_Q) % H_Q
    batch_idx = pid // (L_Q * H_Q)

    kv_head_idx = q_head_idx // kv_repeat
    kv_head_idx = tl.minimum(kv_head_idx, H_KV - 1)

    span_range = tl.arange(0, K_VAL)
    span_mask = span_range < K_ACTUAL  # Mask for valid spans
    k_range = tl.arange(0, BLOCK_K)
    d_range = tl.arange(0, BLOCK_D)
    d_mask = d_range < D_HEAD

    q_base = Q_ptr + ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * D_HEAD
    q = tl.load(q_base + d_range, mask=d_mask, other=0.0).to(tl.float32)

    cache_pos = tl.load(cache_position_ptr + q_idx, mask=True, other=0).to(tl.int32)
    cache_pos = tl.minimum(cache_pos, L_KV - 1)

    # Use K_ACTUAL for the actual data layout.
    span_offset_base = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_ACTUAL
    span_offsets = span_offset_base + span_range
    span_start = tl.load(qstart_ptr + span_offsets, mask=span_mask, other=-1).to(tl.int32)
    span_end = tl.load(qend_ptr + span_offsets, mask=span_mask, other=-1).to(tl.int32)
    span_start = tl.maximum(span_start, 0)
    span_end = tl.minimum(span_end, L_KV - 1)
    span_valid = (span_end >= span_start) & (span_end >= 0) & span_mask

    window = window_size
    sw_end = tl.minimum(cache_pos, L_KV - 1)
    sw_start = sw_end - (window - 1)
    sw_start = tl.maximum(sw_start, 0)
    sw_valid = (window > 0) & (sw_start <= sw_end) & (sw_end >= 0)

    k_head_offset = ((batch_idx * H_KV + kv_head_idx) * L_KV).to(tl.int64) * D_HEAD
    if WRITE_DKV_PER_QHEAD:
        dkv_head_offset = ((batch_idx * H_Q + q_head_idx) * L_KV).to(tl.int64) * D_HEAD
    else:
        dkv_head_offset = k_head_offset
    attn_base = batch_idx * L_KV

    # Use K_ACTUAL for the actual data layout.
    dO_index_base = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_ACTUAL
    dO_offsets = (dO_index_base + span_range)[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
    dO = tl.load(dOut_ptr + dO_offsets, mask=span_mask[:, None] & d_mask[None, :], other=0.0).to(tl.float32)

    m_i = tl.full((K_VAL,), -float('inf'), tl.float32)
    l_i = tl.zeros((K_VAL,), tl.float32)
    dot_num = tl.zeros((K_VAL,), tl.float32)
    scale = tl.full((1,), sm_scale, tl.float32)

    # Pass 1: compute m_i, l_i, dot_num (online) for each span.
    for block_idx in range(SPAN_MAX_BLOCKS):
        block_start = span_start + block_idx * BLOCK_K
        k_pos = block_start[:, None] + k_range[None, :]
        in_range = span_valid[:, None] & (k_pos >= span_start[:, None]) & (k_pos <= span_end[:, None])
        k_pos_safe = tl.where(in_range, k_pos, 0)

        kv_offsets = k_head_offset + k_pos_safe[:, :, None].to(tl.int64) * D_HEAD + d_range[None, None, :]
        k_block = tl.load(K_ptr + kv_offsets, mask=in_range[:, :, None] & d_mask[None, None, :], other=0.0).to(tl.float32)
        v_block = tl.load(V_ptr + kv_offsets, mask=in_range[:, :, None] & d_mask[None, None, :], other=0.0).to(tl.float32)

        logits = tl.sum(k_block * q[None, None, :], axis=2) * scale

        if HAS_ATTN_MASK:
            attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
            in_range = in_range & attn_mask_vals
            logits = tl.where(in_range, logits, float('-inf'))

        grad_w = tl.sum(v_block * dO[:, None, :], axis=2)

        block_max = tl.max(logits, axis=1)
        m_new = tl.maximum(m_i, block_max)
        alpha = tl.exp(m_i - m_new)
        alpha = tl.where(m_new == -float('inf'), 0.0, alpha)

        logits_shift = logits - m_new[:, None]
        logits_shift = tl.where(m_new[:, None] == -float('inf'), float('-inf'), logits_shift)
        p = tl.exp(logits_shift)

        l_i = l_i * alpha + tl.sum(p, axis=1)
        dot_num = dot_num * alpha + tl.sum(p * grad_w, axis=1)
        m_i = m_new

    if SW_MAX_BLOCKS > 0:
        for block_idx in range(SW_MAX_BLOCKS):
            block_start = sw_start + block_idx * BLOCK_K
            k_pos = block_start + k_range
            in_range = sw_valid & (k_pos >= sw_start) & (k_pos <= sw_end)
            k_pos_safe = tl.where(in_range, k_pos, 0)

            kv_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            k_block = tl.load(K_ptr + kv_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
            v_block = tl.load(V_ptr + kv_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)

            logits = tl.sum(k_block * q[None, :], axis=1) * scale

            if HAS_ATTN_MASK:
                attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
                in_range = in_range & attn_mask_vals
                logits = tl.where(in_range, logits, float('-inf'))

            grad_w = tl.sum(v_block[None, :, :] * dO[:, None, :], axis=2)

            block_max = tl.max(logits, axis=0)
            m_new = tl.maximum(m_i, block_max)
            alpha = tl.exp(m_i - m_new)
            alpha = tl.where(m_new == -float('inf'), 0.0, alpha)

            logits_shift = logits[None, :] - m_new[:, None]
            logits_shift = tl.where(m_new[:, None] == -float('inf'), float('-inf'), logits_shift)
            p = tl.exp(logits_shift)

            l_i = l_i * alpha + tl.sum(p, axis=1)
            dot_num = dot_num * alpha + tl.sum(p * grad_w, axis=1)
            m_i = m_new

    inv_l = tl.where(l_i > 0, 1.0 / l_i, 0.0)
    dot_total = dot_num * inv_l

    grad_q = tl.zeros((BLOCK_D,), dtype=tl.float32)

    # Pass 2: accumulate gradients for all spans.
    for block_idx in range(SPAN_MAX_BLOCKS):
        block_start = span_start + block_idx * BLOCK_K
        k_pos = block_start[:, None] + k_range[None, :]
        in_range = span_valid[:, None] & (k_pos >= span_start[:, None]) & (k_pos <= span_end[:, None])
        k_pos_safe = tl.where(in_range, k_pos, 0)

        kv_offsets = k_head_offset + k_pos_safe[:, :, None].to(tl.int64) * D_HEAD + d_range[None, None, :]
        k_block = tl.load(K_ptr + kv_offsets, mask=in_range[:, :, None] & d_mask[None, None, :], other=0.0).to(tl.float32)
        v_block = tl.load(V_ptr + kv_offsets, mask=in_range[:, :, None] & d_mask[None, None, :], other=0.0).to(tl.float32)

        logits = tl.sum(k_block * q[None, None, :], axis=2) * scale

        if HAS_ATTN_MASK:
            attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
            in_range = in_range & attn_mask_vals
            logits = tl.where(in_range, logits, float('-inf'))

        weights = tl.exp(logits - m_i[:, None]) * inv_l[:, None]
        weights = tl.where(inv_l[:, None] > 0, weights, 0.0)
        weights = tl.where(in_range, weights, 0.0)

        grad_w = tl.sum(v_block * dO[:, None, :], axis=2)
        grad_s = (grad_w - dot_total[:, None]) * weights * sm_scale
        grad_s = tl.where(in_range, grad_s, 0.0)

        grad_q_span = tl.sum(grad_s[:, :, None] * k_block, axis=1)
        grad_q += tl.sum(grad_q_span, axis=0)

        dk = grad_s[:, :, None] * q[None, None, :]
        dv = weights[:, :, None] * dO[:, None, :]

        dkv_offsets = dkv_head_offset + k_pos_safe[:, :, None].to(tl.int64) * D_HEAD + d_range[None, None, :]
        tl.atomic_add(dK_ptr + dkv_offsets, dk, mask=in_range[:, :, None] & d_mask[None, None, :])
        tl.atomic_add(dV_ptr + dkv_offsets, dv, mask=in_range[:, :, None] & d_mask[None, None, :])

    if SW_MAX_BLOCKS > 0:
        for block_idx in range(SW_MAX_BLOCKS):
            block_start = sw_start + block_idx * BLOCK_K
            k_pos = block_start + k_range
            in_range = sw_valid & (k_pos >= sw_start) & (k_pos <= sw_end)
            k_pos_safe = tl.where(in_range, k_pos, 0)

            kv_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            k_block = tl.load(K_ptr + kv_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
            v_block = tl.load(V_ptr + kv_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)

            logits = tl.sum(k_block * q[None, :], axis=1) * scale

            if HAS_ATTN_MASK:
                attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
                in_range = in_range & attn_mask_vals
                logits = tl.where(in_range, logits, float('-inf'))

            weights = tl.exp(logits[None, :] - m_i[:, None]) * inv_l[:, None]
            weights = tl.where(inv_l[:, None] > 0, weights, 0.0)
            weights = tl.where(in_range[None, :], weights, 0.0)

            grad_w = tl.sum(v_block[None, :, :] * dO[:, None, :], axis=2)
            grad_s = (grad_w - dot_total[:, None]) * weights * sm_scale
            grad_s = tl.where(in_range[None, :], grad_s, 0.0)

            grad_q_span = tl.sum(grad_s[:, :, None] * k_block[None, :, :], axis=1)
            grad_q += tl.sum(grad_q_span, axis=0)

            grad_s_sum = tl.sum(grad_s, axis=0)
            dk_total = grad_s_sum[:, None] * q[None, :]
            dv_total = tl.sum(weights[:, :, None] * dO[:, None, :], axis=0)

            dkv_offsets = dkv_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
            tl.atomic_add(dK_ptr + dkv_offsets, dk_total, mask=in_range[:, None] & d_mask[None, :])
            tl.atomic_add(dV_ptr + dkv_offsets, dv_total, mask=in_range[:, None] & d_mask[None, :])

    dq_base = dQ_ptr + ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * D_HEAD
    tl.store(dq_base + d_range, grad_q, mask=d_mask)

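
# ---------------------------------------------------------------------------
# Note on the fused-spans variant above (exposition, not original code).
# tl.arange needs a power-of-two extent, so the launcher pads the span axis:
#     k_val_padded = 1 << (num_spans - 1).bit_length()   # e.g. 3 spans -> K_VAL=4
# and masks the padding lanes with span_range < K_ACTUAL. Folding all spans of
# one query into a single program lets pass 1 and pass 2 share the recomputed
# (m_i, l_i, dot_total) vectors instead of rebuilding them once per span, at
# the cost of a 2-D [K_VAL, BLOCK_K] working set per block.
# ---------------------------------------------------------------------------
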
def fused_span_triton_gqa(
    Q2,
    K,
    V,
    qstart,
    qend,
    cache_position,
    attention_mask=None,
    sw_index=0,
    block_k: int = 64,
    span_len_factor: float = 2.0,
    span_power: float = 0.5,
    search_power: float | None = None,
    inv_search_power_int: int | None = 2,
):
    assert Q2.is_cuda, "CUDA required for fused Triton path"
    B, H_q, L_Q, D = Q2.shape
    _, H_kv, L_KV, _ = K.shape
    num_spans = qstart.shape[-1]

    span_len_factor = float(span_len_factor)
    if not math.isfinite(span_len_factor) or span_len_factor <= 0.0:
        raise ValueError(f"span_len_factor must be finite and > 0 (got {span_len_factor})")

    span_power_f = float(span_power)
    if not math.isfinite(span_power_f) or not (0.0 < span_power_f < 1.0):
        raise ValueError(f"span_power must be finite and in (0, 1) (got {span_power})")

    kv_repeat = H_q // H_kv

    # Ensure all tensors are on the same device as Q2 for multi-GPU compatibility.
    device = Q2.device
    K = K.to(device)
    V = V.to(device)
    qstart = qstart.to(device)
    qend = qend.to(device)
    cache_position = cache_position.to(device)
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    Q2c = Q2.contiguous()
    Kc = K.contiguous()
    Vc = V.contiguous()
    qstartc = qstart.contiguous()
    qendc = qend.contiguous()
    cachec = cache_position.to(torch.int32).contiguous()

    if attention_mask is not None:
        attn_mask = attention_mask[:, :L_KV].contiguous().to(torch.int8)
        has_mask = True
    else:
        attn_mask = torch.empty((1,), device=device, dtype=torch.int8)
        has_mask = False

    window = window_len_from_sw_index(
        int(sw_index), search_power=search_power, inv_search_power_int=inv_search_power_int
    )
    _assert_no_span_sw_overlap(qendc, cachec.view(1, 1, -1, 1), sw_index, L_KV, window_len=window)

    out = torch.empty((B, H_q, L_Q, num_spans, D), device=device, dtype=Q2.dtype)

    max_span_len = int(span_len_factor * math.ceil(float(L_KV) ** span_power_f) + 2)
    span_max_blocks = triton.cdiv(max_span_len, block_k)
    span_max_blocks = max(1, span_max_blocks)
    sw_max_len = min(window, L_KV) if window > 0 else 0
    sw_max_blocks = triton.cdiv(sw_max_len, block_k) if sw_max_len > 0 else 0
    block_d = min(256, _next_power_of_two(D))

    grid = (B * H_q * L_Q * num_spans,)
    # Use torch.cuda.device context to ensure the kernel launches on the correct GPU.
    with torch.cuda.device(device):
        fused_span_forward_kernel_gqa[grid](
            Q2c, Kc, Vc,
            qstartc, qendc, cachec,
            attn_mask,
            out,
            B, H_q, H_kv, L_Q, L_KV,
            kv_repeat,
            window,
            1.0 / math.sqrt(D),
            K_VAL=num_spans,
            BLOCK_K=block_k,
            BLOCK_D=block_d,
            D_HEAD=D,
            SPAN_MAX_BLOCKS=span_max_blocks,
            SW_MAX_BLOCKS=sw_max_blocks,
            HAS_ATTN_MASK=has_mask,
        )
    # Synchronize to ensure the kernel completes before the output is used elsewhere.
    # Skip during CUDA graph capture (synchronize is not allowed during capture).
    # NOTE: This sync was added in commit 9dcb65f to fix multi-GPU race conditions
    # with device_map='auto'. However, notebook 39.4 shows that the
    # `with torch.cuda.device(device):` context is sufficient for correctness,
    # and HuggingFace patterns don't use per-kernel sync. Consider removing
    # this sync entirely if multi-GPU inference remains stable without it.
    if not torch.cuda.is_current_stream_capturing():
        torch.cuda.current_stream().synchronize()
    return out

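
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; the shapes and values below are assumptions for
# the example, not documented requirements). The launcher expects Q2 as
# [B, H_q, L_Q, D], K/V as [B, H_kv, L_KV, D], and span bounds as
# [B, H_q, L_Q, num_spans]; spans must not overlap the sliding window, which
# _assert_no_span_sw_overlap checks at launch time.
#
#     B, H_q, H_kv, L_Q, L_KV, D, n_spans = 1, 8, 2, 4, 1024, 64, 3
#     Q2 = torch.randn(B, H_q, L_Q, D, device="cuda", dtype=torch.float16)
#     K = torch.randn(B, H_kv, L_KV, D, device="cuda", dtype=torch.float16)
#     V = torch.randn_like(K)
#     qstart = torch.zeros(B, H_q, L_Q, n_spans, device="cuda", dtype=torch.int32)
#     qend = qstart + 16
#     cache_position = torch.arange(L_KV - L_Q, L_KV, device="cuda")
#     out = fused_span_triton_gqa(Q2, K, V, qstart, qend, cache_position)
#     # out: [B, H_q, L_Q, n_spans, D], one attention output per selected span
# ---------------------------------------------------------------------------
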
class FusedSpanGQATriton(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        Q2,
        K,
        V,
        qstart,
        qend,
        cache_position,
        attention_mask=None,
        sw_index: int = 0,
        block_k: int = 64,
        span_len_factor: float = 2.0,
        span_power: float = 0.5,
        search_power: float | None = None,
        inv_search_power_int: int | None = 2,
    ):
        B, H_q, L_Q, D = Q2.shape
        _, H_kv, _, _ = K.shape
        assert H_q % H_kv == 0, "Query heads must be divisible by KV heads"

        kv_repeat = H_q // H_kv
        out = fused_span_triton_gqa(
            Q2,
            K,
            V,
            qstart,
            qend,
            cache_position,
            attention_mask,
            sw_index,
            block_k,
            span_len_factor=span_len_factor,
            span_power=span_power,
            search_power=search_power,
            inv_search_power_int=inv_search_power_int,
        )
        saved_mask = attention_mask if attention_mask is not None else torch.tensor([], device=Q2.device)
        ctx.save_for_backward(Q2, K, V, qstart, qend, cache_position.to(torch.int32), saved_mask)
        ctx.sw_index = sw_index
        ctx.block_k = block_k
        ctx.kv_repeat = kv_repeat
        ctx.span_len_factor = float(span_len_factor)
        ctx.span_power = float(span_power)
        ctx.window_len = window_len_from_sw_index(
            int(sw_index), search_power=search_power, inv_search_power_int=inv_search_power_int
        )
        return out

    @staticmethod
    def backward(ctx, grad_out):
        Q2, K, V, qstart, qend, cache_position, attention_mask_saved = ctx.saved_tensors
        attention_mask = None if attention_mask_saved.numel() == 0 else attention_mask_saved
        sw_index = ctx.sw_index
        block_k = ctx.block_k
        kv_repeat = ctx.kv_repeat
        span_len_factor = float(getattr(ctx, "span_len_factor", 2.0))
        span_power = float(getattr(ctx, "span_power", 0.5))

        B, H_q, L_Q, D = Q2.shape
        _, H_kv, L_KV, _ = K.shape
        num_spans = qstart.shape[-1]

        window = int(getattr(ctx, "window_len", (sw_index + 1) ** 2 - 1))
        max_span_len = int(span_len_factor * math.ceil(float(L_KV) ** span_power) + 2)
        span_max_blocks = triton.cdiv(max_span_len, block_k)
        span_max_blocks = max(1, span_max_blocks)
        sw_max_len = min(window, L_KV) if window > 0 else 0
        sw_max_blocks = triton.cdiv(sw_max_len, block_k) if sw_max_len > 0 else 0
        block_d = min(256, _next_power_of_two(D))

        Q2c = Q2.contiguous()
        Kc = K.contiguous()
        Vc = V.contiguous()
        qstartc = qstart.contiguous()
        qendc = qend.contiguous()
        cachec = cache_position.to(torch.int32).contiguous()
        grad_out_c = grad_out.contiguous()

        if attention_mask is not None:
            attn_mask = attention_mask[:, :L_KV].contiguous().to(torch.int8)
            has_mask = True
        else:
            attn_mask = torch.empty((1,), device=Q2.device, dtype=torch.int8)
            has_mask = False

        dQ = torch.zeros_like(Q2c, dtype=torch.float32)
        split_dkv = os.getenv("SPAN_ATTN_GQA_BACKWARD_SPLIT_DKV", "1") != "0"
        if split_dkv and kv_repeat > 1:
            dK_rep = torch.zeros((B, H_q, L_KV, D), device=Kc.device, dtype=torch.float32)
            dV_rep = torch.zeros((B, H_q, L_KV, D), device=Vc.device, dtype=torch.float32)
            dK_out = dK_rep
            dV_out = dV_rep
        else:
            dK = torch.zeros_like(Kc, dtype=torch.float32)
            dV = torch.zeros_like(Vc, dtype=torch.float32)
            dK_out = dK
            dV_out = dV

        write_dkv_per_qhead = split_dkv and kv_repeat > 1
        fuse_spans = os.getenv("SPAN_ATTN_GQA_BACKWARD_FUSE_SPANS", "1") != "0"
        # K_VAL must be a power of 2 for tl.arange.
        k_val_padded = 1 << (num_spans - 1).bit_length() if num_spans > 0 else 1
        if fuse_spans and num_spans > 1:
            grid = (B * H_q * L_Q,)
            fused_span_backward_kernel_gqa_fused_spans[grid](
                Q2c, Kc, Vc, grad_out_c,
                qstartc, qendc, cachec,
                attn_mask,
                dQ, dK_out, dV_out,
                B, H_q, H_kv, L_Q, L_KV,
                kv_repeat,
                window,
                1.0 / math.sqrt(D),
                K_VAL=k_val_padded,
                K_ACTUAL=num_spans,
                BLOCK_K=block_k,
                BLOCK_D=block_d,
                D_HEAD=D,
                SPAN_MAX_BLOCKS=span_max_blocks,
                SW_MAX_BLOCKS=sw_max_blocks,
                HAS_ATTN_MASK=has_mask,
                WRITE_DKV_PER_QHEAD=write_dkv_per_qhead,
            )
        else:
            grid = (B * H_q * L_Q * num_spans,)
            fused_span_backward_kernel_gqa[grid](
                Q2c, Kc, Vc, grad_out_c,
                qstartc, qendc, cachec,
                attn_mask,
                dQ, dK_out, dV_out,
                B, H_q, H_kv, L_Q, L_KV,
                kv_repeat,
                window,
                1.0 / math.sqrt(D),
                K_VAL=num_spans,
                BLOCK_K=block_k,
                BLOCK_D=block_d,
                D_HEAD=D,
                SPAN_MAX_BLOCKS=span_max_blocks,
                SW_MAX_BLOCKS=sw_max_blocks,
                HAS_ATTN_MASK=has_mask,
                WRITE_DKV_PER_QHEAD=write_dkv_per_qhead,
            )

        if split_dkv and kv_repeat > 1:
            dK = dK_rep.view(B, H_kv, kv_repeat, L_KV, D).sum(dim=2)
            dV = dV_rep.view(B, H_kv, kv_repeat, L_KV, D).sum(dim=2)

        return (
            dQ.to(Q2.dtype),
            dK.to(K.dtype),
            dV.to(V.dtype),
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )

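
# ---------------------------------------------------------------------------
# Training-path sketch (illustrative). Routing through the autograd.Function
# exercises the backward kernels above; with the shapes from the previous
# sketch:
#
#     Q2.requires_grad_(); K.requires_grad_(); V.requires_grad_()
#     out = FusedSpanGQATriton.apply(Q2, K, V, qstart, qend, cache_position,
#                                    None, 0, 64, 2.0, 0.5, None, 2)
#     out.float().sum().backward()   # fills Q2.grad, K.grad, V.grad
#
# The SPAN_ATTN_GQA_BACKWARD_SPLIT_DKV and SPAN_ATTN_GQA_BACKWARD_FUSE_SPANS
# environment variables (read in backward above, both defaulting to "1")
# toggle the per-query-head dK/dV staging buffer and the fused-spans kernel.
# ---------------------------------------------------------------------------
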
def fused_span_attention_gqa(
    Q2,
    K,
    V,
    qstart,
    qend,
    cache_position,
    attention_mask=None,
    sw_index=0,
    block_k=64,
    span_len_factor: float = 2.0,
    span_power: float = 0.5,
    search_power: float | None = None,
    inv_search_power_int: int | None = 2,
):
    B, H_q, L_Q, _ = Q2.shape
    _, H_kv, _, _ = K.shape
    assert H_q % H_kv == 0, "Query heads must be divisible by KV heads when using GQA"

    if not Q2.is_cuda:
        params = derive_stripe_power_params(
            search_power=search_power, inv_search_power_int=inv_search_power_int
        )
        if params.triton_inv_n != 2:
            raise NotImplementedError(
                "Non-CUDA fused_span_attention_gqa only supports p=0.5 for now "
                "(use inv_search_power_int=2 or search_power=0.5)."
            )
        kv_repeat = H_q // H_kv
        K_rep = K.repeat_interleave(kv_repeat, dim=1)
        V_rep = V.repeat_interleave(kv_repeat, dim=1)
        return fused_span_attention(
            Q2,
            K_rep,
            V_rep,
            qstart,
            qend,
            cache_position,
            attention_mask=attention_mask,
            sw_index=sw_index,
            block_k=block_k,
            span_len_factor=span_len_factor,
            span_power=span_power,
            search_power=search_power,
            inv_search_power_int=inv_search_power_int,
        )

    return FusedSpanGQATriton.apply(
        Q2,
        K,
        V,
        qstart,
        qend,
        cache_position,
        attention_mask,
        sw_index,
        block_k,
        span_len_factor,
        float(span_power),
        search_power,
        inv_search_power_int,
    )

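
# ---------------------------------------------------------------------------
# Design note (exposition): the non-CUDA path above materializes GQA with
# K.repeat_interleave(kv_repeat, dim=1), which realizes the same head map the
# Triton kernels apply implicitly (kv_head_idx = q_head_idx // kv_repeat), at
# the cost of kv_repeat copies of K/V. The CUDA path keeps K/V shared and
# lets each query head index into its KV head directly.
# ---------------------------------------------------------------------------
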
def decode_span_attention_staged_gqa_kernel(
    Q1,
    Q2,
    K,
    V,
    cache_position,
    attention_mask=None,
    sw_index=0,
    topk=3,
    enable_gqa=False,
    block_k=64,
    backward_factor: float = 2.0,
    forward_factor: float = 0.0,
    span_power: float = 0.5,
    search_power: float | None = None,
    inv_search_power_int: int | None = 2,
    force_mode=None,
):
    return decode_span_attention_staged_gqa_kernel_v2(
        Q1,
        Q2,
        K,
        V,
        cache_position,
        attention_mask=attention_mask,
        sw_index=sw_index,
        topk=topk,
        enable_gqa=enable_gqa,
        block_k=block_k,
        backward_factor=backward_factor,
        forward_factor=forward_factor,
        span_power=span_power,
        search_power=search_power,
        inv_search_power_int=inv_search_power_int,
        force_mode=force_mode,
    )

def decode_span_attention_staged_gqa_kernel_v2(
    Q1,
    Q2,
    K,
    V,
    cache_position,
    attention_mask=None,
    sw_index=0,
    topk=3,
    enable_gqa=False,
    block_k=64,
    backward_factor: float = 2.0,
    forward_factor: float = 0.0,
    span_power: float = 0.5,
    search_power: float | None = None,
    inv_search_power_int: int | None = 2,
    force_mode=None,
):
    """Staged decode variant that repeats K/V for small-window SDPA to avoid GQA overhead."""
    if (not enable_gqa) or (K.shape[1] == Q1.shape[1]):
        params = derive_stripe_power_params(
            search_power=search_power, inv_search_power_int=inv_search_power_int
        )
        if params.triton_inv_n != 2:
            raise NotImplementedError(
                "decode_span_attention_staged_gqa_kernel_v2 only supports p=0.5 when enable_gqa=False. "
                "Use enable_gqa=True or keep p=0.5 (inv_search_power_int=2 or search_power=0.5)."
            )
        return decode_span_attention_staged(
            Q1,
            Q2,
            K,
            V,
            cache_position,
            attention_mask=attention_mask,
            sw_index=sw_index,
            topk=topk,
            backward_factor=backward_factor,
            forward_factor=forward_factor,
            span_power=span_power,
            search_power=search_power,
            inv_search_power_int=inv_search_power_int,
        )

    B, H_q, L_Q, D = Q1.shape
    _, H_kv, L_KV, D_kv = K.shape
    assert L_Q == 1, "GQA decode only supports decoding (L_Q=1)"
    assert D_kv == D, "Q/K/V head dimensions must match"
    assert H_q % H_kv == 0, "Query heads must be divisible by KV heads when GQA is enabled"

    device = Q1.device
    kv_repeat = H_q // H_kv

    window_len = window_len_from_sw_index(
        int(sw_index), search_power=search_power, inv_search_power_int=inv_search_power_int
    )
    # ------------------------------------------------------------------------
    # IMPORTANT (StaticCache correctness):
    # StaticCache returns full fixed-size K/V buffers during decode (L_KV = max_seq_len) to
    # keep shapes stable for CUDA graphs. For short prefixes, we must still take the same
    # "full-attention" SDPA fallback that DynamicCache would take (based on the *effective*
    # prefix length, i.e., cache_position[-1] + 1), otherwise small numeric diffs can
    # accumulate and change argmax tokens.
    #
    # To keep CUDA-graph safety (no data-dependent Python branching), we compute both:
    #   - the SDPA fallback over a small fixed prefix window (<= window_len)
    #   - the staged span-attention output
    # and select with a tensor mask.
    # ------------------------------------------------------------------------
    token_pos = cache_position[-1] + 1  # effective prefix length (1-indexed)
    use_sdpa = token_pos <= window_len

    if force_mode not in (None, "sdpa", "span"):
        force_mode = None

    out_sdpa = None
    if force_mode != "span":
        # SDPA fallback: only needs the first `min(L_KV, window_len)` keys. For StaticCache,
        # attention_mask is expected to mask out positions beyond the current prefix.
        kv_slice_len = min(L_KV, max(window_len, 1))
        K_sdpa = K[:, :, :kv_slice_len, :]
        V_sdpa = V[:, :, :kv_slice_len, :]
        if attention_mask is not None:
            sdpa_mask = attention_mask[:, :kv_slice_len]
        else:
            # Synthesize a prefix mask so StaticCache doesn't attend to the unused tail.
            pos = torch.arange(kv_slice_len, device=device, dtype=cache_position.dtype)
            sdpa_mask = (pos <= cache_position[-1]).unsqueeze(0).expand(B, -1)

        attn_mask = sdpa_mask.unsqueeze(1).unsqueeze(2)
        # CUDA graph safe: use arithmetic instead of torch.tensor().
        # True (1.0) -> 0.0, False (0.0) -> -1e9
        attn_mask_float = attn_mask.to(Q2.dtype)
        attn_mask = (attn_mask_float - 1.0) * 1e9

        K_rep = K_sdpa.repeat_interleave(kv_repeat, dim=1)
        V_rep = V_sdpa.repeat_interleave(kv_repeat, dim=1)
        out_sdpa = torch.nn.functional.scaled_dot_product_attention(
            Q2,
            K_rep,
            V_rep,
            attn_mask=attn_mask,
            dropout_p=0.0,
            is_causal=False,
            enable_gqa=False,
        )
        if force_mode == "sdpa":
            return out_sdpa

    # ========================================================================
    # CUDA Graph Compatible Version: use fixed-size tensors with masking
    # instead of dynamic filtering operations (no stripe_loc[mask] indexing).
    # ========================================================================

    num_stripes = max(
        int(
            max_stripe_index_for_token_pos(
                int(L_KV),
                search_power=search_power,
                inv_search_power_int=inv_search_power_int,
            )
        ),
        sw_index + 1,
    )
    stripe_idx = torch.arange(sw_index + 1, num_stripes + 1, device=device)
    # IMPORTANT: For StaticCache decode, K/V are full fixed-size buffers (L_KV=max_seq_len),
    # so we must anchor stripes relative to the *effective* prefix length (token_pos),
    # not the allocation length. For DynamicCache, token_pos == L_KV so this is unchanged.
    power_params = derive_stripe_power_params(
        search_power=search_power, inv_search_power_int=inv_search_power_int
    )
    if power_params.triton_inv_n != 0:
        stripe_floor_power = stripe_idx.to(torch.int64) ** int(power_params.triton_inv_n)
    else:
        stripe_floor_power = torch.floor(
            stripe_idx.to(torch.float32) ** float(power_params.inv_p)
        ).to(torch.int64)
    stripe_loc_all = token_pos - stripe_floor_power

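    # Worked example (illustrative, added for clarity): with p = 0.5
    # (triton_inv_n = 2), sw_index = 0, and an effective prefix of
    # token_pos = 101, stripe i anchors at 101 - i**2, so stripes 1..9 land at
    # 100, 97, 92, 85, 76, 65, 52, 37, 20: denser near the current token,
    # sparser further back in the cache.
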
    # Create a validity mask instead of filtering.
    # Mask 1: stripe_loc >= 0
    valid_mask = stripe_loc_all >= 0

    # Mask 2: attention_mask check (if provided)
    if attention_mask is not None:
        # Use safe indexing with clamping.
        safe_indices = stripe_loc_all.clamp(min=0, max=L_KV - 1)
        attn_valid = attention_mask[0, safe_indices].to(torch.bool)
        valid_mask = valid_mask & attn_valid

    # Mask 3: no overlap with the sliding window.
    # Use tensor operations instead of .item() for CUDA graph compatibility.
    decode_pos_tensor = cache_position[-1]
    sw_start_tensor = torch.clamp(decode_pos_tensor - (window_len - 1), min=0)
    sw_valid = stripe_loc_all < sw_start_tensor
    valid_mask = valid_mask & sw_valid

    # Convert stripe_loc to int64 and mark invalid entries as -1.
    stripe_loc = torch.where(valid_mask, stripe_loc_all, -1).to(torch.int64)

    # Initialize topk_vals and topk_idx with proper dimensions.
    topk_vals = torch.full((B, H_q, topk), float('-inf'), device=device, dtype=torch.float32)
    topk_idx = torch.full((B, H_q, topk), -1, device=device, dtype=torch.int64)

    # CUDA graph compatible: always compute logits and use masking to handle
    # empty cases; no conditional branches based on .item() calls.
    max_stripes = stripe_loc.shape[0]

    # Use safe indexing with clamping to avoid out-of-bounds for invalid stripes.
    safe_stripe_locs = stripe_loc.clamp(min=0, max=L_KV - 1)
    K_stripe = K.index_select(2, safe_stripe_locs)

    logits = torch.einsum(
        'bhrld,bhsd->bhrls',
        Q1.detach().float().reshape(B, H_kv, kv_repeat, L_Q, D),
        K_stripe.detach().float()
    ).squeeze(3)
    logits = logits.reshape(B, H_q, -1)

    # Set invalid stripe logits to -inf so they don't get selected in topk.
    invalid_stripe_mask = ~valid_mask.unsqueeze(0).unsqueeze(0).expand(B, H_q, -1)
    logits = torch.where(invalid_stripe_mask, float('-inf'), logits)

    # Top-k over all stripes (invalid ones have -inf logits).
    actual_k = min(topk, max_stripes)
    actual_topk_vals, actual_topk_idx = torch.topk(logits, k=actual_k, dim=-1)

    topk_vals[:, :, :actual_k] = actual_topk_vals
    topk_idx[:, :, :actual_k] = actual_topk_idx

    # Map indices to stripe locations using the full stripe_loc tensor.
    # Invalid topk_idx (-1) will produce qanchor = -1.
    qanchor = torch.full_like(topk_idx, -1, dtype=torch.int64)
    safe_idx = topk_idx.clamp(min=0, max=stripe_loc.shape[0] - 1)
    valid_topk_mask = topk_idx >= 0
    selected_locs = stripe_loc[safe_idx]
    qanchor = torch.where(valid_topk_mask, selected_locs, qanchor)

    # Use tensor operations instead of .item() for CUDA graph compatibility:
    # compute span_len with tensor ops.
    span_len = torch.ceil(float(backward_factor) * (token_pos.float() ** float(span_power))).to(torch.int64)
    qstart = qanchor - span_len
    qstart = torch.clamp(qstart, min=0)
    # For invalid spans (qanchor = -1), set qstart to -1 as well.
    qstart = torch.where(qanchor < 0, -1, qstart)

    qstart = qstart.unsqueeze(2)
    qanchor = qanchor.unsqueeze(2)

    # For decode, pass only the last cache position (shape [1]) to avoid a broadcast
    # mismatch in _assert_no_span_sw_overlap, which computes sw_start per query position.
    decode_cache_pos = cache_position[-1:]
    qend = compute_qend_from_qanchor(
        qanchor,
        cache_position=decode_cache_pos,
        key_length=L_KV,
        sw_index=sw_index,
        attention_mask=attention_mask,
        forward_factor=forward_factor,
        span_power=span_power,
        search_power=search_power,
        inv_search_power_int=inv_search_power_int,
    )
    span_len_factor = backward_factor + forward_factor
    O_span = fused_span_attention_gqa(
        Q2,
        K,
        V,
        qstart,
        qend,
        decode_cache_pos,
        attention_mask=attention_mask,
        sw_index=sw_index,
        block_k=block_k,
        span_len_factor=span_len_factor,
        span_power=span_power,
        search_power=search_power,
        inv_search_power_int=inv_search_power_int,
    )

    # Reuse topk_vals directly for gating (no need to recompute K[qend] @ Q1).
    span_values = topk_vals.unsqueeze(2)  # [B, H_q, 1, topk]
    # Invalid spans already have -inf in topk_vals, but ensure consistency.
    span_values = torch.where(qanchor < 0, float("-inf"), span_values)

    span_scores = torch.nan_to_num(torch.softmax(span_values.float(), dim=-1), 1 / topk)
    span_scores = span_scores.to(O_span.dtype)
    O = (span_scores.unsqueeze(-1) * O_span).sum(dim=3)

    if force_mode == "span":
        return O

    # Select SDPA for short prefixes, span attention otherwise.
    use_sdpa_broadcast = use_sdpa.view(1, 1, 1, 1)
    return torch.where(use_sdpa_broadcast, out_sdpa, O)