superlinear-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. apps/__init__.py +4 -0
  2. apps/cli/__init__.py +8 -0
  3. apps/cli/bm25_rag.py +471 -0
  4. apps/cli/chat_repl.py +1497 -0
  5. apps/cli/client.py +195 -0
  6. apps/cli/docs_repl.py +2275 -0
  7. apps/cli/light_rag.py +729 -0
  8. apps/cli/local_snapshots.py +139 -0
  9. apps/cli/locks.py +214 -0
  10. apps/cli/main.py +457 -0
  11. apps/cli/output.py +32 -0
  12. apps/cli/server_cmds.py +516 -0
  13. apps/cli/session_cmds.py +491 -0
  14. apps/cli/snapshot_cmds.py +303 -0
  15. apps/cli/state.py +265 -0
  16. apps/server/__init__.py +4 -0
  17. apps/server/app.py +1363 -0
  18. apps/server/main.py +313 -0
  19. superlinear/__init__.py +114 -0
  20. superlinear/_version.py +3 -0
  21. superlinear/engine/__init__.py +10 -0
  22. superlinear/engine/adapters/__init__.py +12 -0
  23. superlinear/engine/adapters/base.py +91 -0
  24. superlinear/engine/adapters/superlinear.py +1233 -0
  25. superlinear/engine/chat_engine.py +1173 -0
  26. superlinear/engine/chat_types.py +130 -0
  27. superlinear/engine/registry.py +51 -0
  28. superlinear/engine/repetition.py +203 -0
  29. superlinear/engine/session_snapshots.py +451 -0
  30. superlinear/engine/tool_parser.py +83 -0
  31. superlinear/engine/types.py +42 -0
  32. superlinear/kernels/__init__.py +2 -0
  33. superlinear/kernels/common/__init__.py +21 -0
  34. superlinear/kernels/common/adjustment.py +106 -0
  35. superlinear/kernels/common/power.py +154 -0
  36. superlinear/kernels/superlinear/__init__.py +10 -0
  37. superlinear/kernels/superlinear/attention/__init__.py +78 -0
  38. superlinear/kernels/superlinear/attention/_prefill.py +940 -0
  39. superlinear/kernels/superlinear/attention/_sliding_window.py +1167 -0
  40. superlinear/kernels/superlinear/attention/api.py +433 -0
  41. superlinear/kernels/superlinear/search/__init__.py +33 -0
  42. superlinear/kernels/superlinear/search/_reference.py +204 -0
  43. superlinear/kernels/superlinear/search/_triton.py +488 -0
  44. superlinear/kernels/superlinear/search/_triton_gqa.py +534 -0
  45. superlinear/kernels/superlinear/search/api.py +200 -0
  46. superlinear/kernels/superlinear/span/__init__.py +41 -0
  47. superlinear/kernels/superlinear/span/_triton_bucketed_gqa.py +1461 -0
  48. superlinear/kernels/superlinear/span/_triton_forward.py +22 -0
  49. superlinear/kernels/superlinear/span/_triton_gqa.py +1226 -0
  50. superlinear/kernels/superlinear/span/_triton_impl.py +928 -0
  51. superlinear/kernels/superlinear/span/_triton_precomputed_sw.py +460 -0
  52. superlinear/kernels/superlinear/span/_triton_precomputed_sw_gqa.py +598 -0
  53. superlinear/kernels/superlinear/span/api.py +296 -0
  54. superlinear/kernels/superlinear/span/masks.py +187 -0
  55. superlinear/py.typed +0 -0
  56. superlinear/runtime.py +71 -0
  57. superlinear-0.1.0.dist-info/METADATA +469 -0
  58. superlinear-0.1.0.dist-info/RECORD +62 -0
  59. superlinear-0.1.0.dist-info/WHEEL +5 -0
  60. superlinear-0.1.0.dist-info/entry_points.txt +2 -0
  61. superlinear-0.1.0.dist-info/licenses/LICENSE +202 -0
  62. superlinear-0.1.0.dist-info/top_level.txt +2 -0
superlinear/kernels/superlinear/span/_triton_gqa.py
@@ -0,0 +1,1226 @@
1
+ import math
2
+ import os
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+ from superlinear.kernels.common.power import derive_stripe_power_params, max_stripe_index_for_token_pos, window_len_from_sw_index
9
+ from superlinear.kernels.common.adjustment import compute_qend_from_qanchor
10
+ from ._triton_impl import (
11
+ _assert_no_span_sw_overlap,
12
+ _next_power_of_two,
13
+ decode_span_attention_staged,
14
+ fused_span_attention,
15
+ )
16
+
17
+
18
+ @triton.jit
19
+ def fused_span_forward_kernel_gqa(
20
+ Q_ptr, K_ptr, V_ptr,
21
+ qstart_ptr, qend_ptr, cache_position_ptr,
22
+ attn_mask_ptr,
23
+ Out_ptr,
24
+ B, H_Q, H_KV, L_Q, L_KV,
25
+ kv_repeat,
26
+ window_size,
27
+ sm_scale,
28
+ K_VAL: tl.constexpr,
29
+ BLOCK_K: tl.constexpr,
30
+ BLOCK_D: tl.constexpr,
31
+ D_HEAD: tl.constexpr,
32
+ SPAN_MAX_BLOCKS: tl.constexpr,
33
+ SW_MAX_BLOCKS: tl.constexpr,
34
+ HAS_ATTN_MASK: tl.constexpr,
35
+ ):
36
+ pid = tl.program_id(0)
37
+ span_index = pid % K_VAL
38
+ q_idx = (pid // K_VAL) % L_Q
39
+ q_head_idx = (pid // (K_VAL * L_Q)) % H_Q
40
+ batch_idx = pid // (K_VAL * L_Q * H_Q)
41
+
42
+ kv_head_idx = q_head_idx // kv_repeat
43
+ kv_head_idx = tl.minimum(kv_head_idx, H_KV - 1)
44
+
45
+ d_range = tl.arange(0, BLOCK_D)
46
+ d_mask = d_range < D_HEAD
47
+
48
+ q_base = Q_ptr + ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * D_HEAD
49
+ q = tl.load(q_base + d_range, mask=d_mask, other=0.0).to(tl.float32)
50
+
51
+ cache_pos = tl.load(cache_position_ptr + q_idx, mask=True, other=0).to(tl.int32)
52
+ cache_pos = tl.minimum(cache_pos, L_KV - 1)
53
+
54
+ span_offset = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_VAL + span_index
55
+ span_start = tl.load(qstart_ptr + span_offset, mask=True, other=-1).to(tl.int32)
56
+ span_end = tl.load(qend_ptr + span_offset, mask=True, other=-1).to(tl.int32)
57
+ span_start = tl.maximum(span_start, 0)
58
+ span_end = tl.minimum(span_end, L_KV - 1)
59
+ span_valid = (span_end >= span_start) & (span_end >= 0)
60
+
61
+ window = window_size
62
+ sw_end = tl.minimum(cache_pos, L_KV - 1)
63
+ sw_start = sw_end - (window - 1)
64
+ sw_start = tl.maximum(sw_start, 0)
65
+ sw_valid = (window > 0) & (sw_start <= sw_end) & (sw_end >= 0)
66
+
67
+ seg1_start, seg1_end, seg1_valid = span_start, span_end, span_valid
68
+ seg2_start, seg2_end, seg2_valid = sw_start, sw_end, sw_valid
69
+
70
+ seg1_start = tl.maximum(seg1_start, 0)
71
+ seg1_end = tl.minimum(seg1_end, L_KV - 1)
72
+ seg1_valid = seg1_valid & (seg1_start <= seg1_end)
73
+
74
+ seg2_start = tl.maximum(seg2_start, 0)
75
+ seg2_end = tl.minimum(seg2_end, L_KV - 1)
76
+ seg2_valid = seg2_valid & (seg2_start <= seg2_end)
77
+
78
+ k_head_offset = ((batch_idx * H_KV + kv_head_idx) * L_KV).to(tl.int64) * D_HEAD
79
+ attn_base = batch_idx * L_KV
80
+
81
+ m_i = -float('inf')
82
+ l_i = 0.0
83
+ acc = tl.zeros((BLOCK_D,), dtype=tl.float32)
84
+ scale = tl.full((1,), sm_scale, tl.float32)
85
+
86
+ for block_idx in range(SPAN_MAX_BLOCKS):
87
+ block_start = seg1_start + block_idx * BLOCK_K
88
+ k_pos = block_start + tl.arange(0, BLOCK_K)
89
+ in_range = seg1_valid & (k_pos >= seg1_start) & (k_pos <= seg1_end)
90
+ k_pos_safe = tl.where(in_range, k_pos, 0)
91
+
92
+ k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
93
+ v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
94
+ k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
95
+ v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
96
+
97
+ logits = tl.sum(k_block * q[None, :], axis=1) * scale
98
+
99
+ if HAS_ATTN_MASK:
100
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
101
+ in_range = in_range & attn_mask_vals
102
+ logits = tl.where(in_range, logits, float('-inf'))
103
+
104
+ if tl.sum(in_range, axis=0) > 0:
105
+ block_max = tl.max(logits, axis=0)
106
+ m_new = tl.maximum(m_i, block_max)
107
+ alpha = tl.exp(m_i - m_new)
108
+ p = tl.exp(logits - m_new)
109
+ l_i = l_i * alpha + tl.sum(p, axis=0)
110
+ acc = acc * alpha + tl.sum(p[:, None] * v_block, axis=0)
111
+ m_i = m_new
112
+
113
+ if SW_MAX_BLOCKS > 0:
114
+ for block_idx in range(SW_MAX_BLOCKS):
115
+ block_start = seg2_start + block_idx * BLOCK_K
116
+ k_pos = block_start + tl.arange(0, BLOCK_K)
117
+ in_range = seg2_valid & (k_pos >= seg2_start) & (k_pos <= seg2_end)
118
+ k_pos_safe = tl.where(in_range, k_pos, 0)
119
+
120
+ k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
121
+ v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
122
+ k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
123
+ v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
124
+
125
+ logits = tl.sum(k_block * q[None, :], axis=1) * scale
126
+
127
+ if HAS_ATTN_MASK:
128
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
129
+ in_range = in_range & attn_mask_vals
130
+ logits = tl.where(in_range, logits, float('-inf'))
131
+
132
+ if tl.sum(in_range, axis=0) > 0:
133
+ block_max = tl.max(logits, axis=0)
134
+ m_new = tl.maximum(m_i, block_max)
135
+ alpha = tl.exp(m_i - m_new)
136
+ p = tl.exp(logits - m_new)
137
+ l_i = l_i * alpha + tl.sum(p, axis=0)
138
+ acc = acc * alpha + tl.sum(p[:, None] * v_block, axis=0)
139
+ m_i = m_new
140
+
141
+ acc = tl.where(l_i > 0, acc / l_i, 0.0)
142
+ out_index = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_VAL + span_index
143
+ out_base = Out_ptr + out_index.to(tl.int64) * D_HEAD
144
+ tl.store(out_base + d_range, acc, mask=d_mask)
145
+
146
+
147
+ @triton.jit
148
+ def fused_span_backward_kernel_gqa(
149
+ Q_ptr, K_ptr, V_ptr, dOut_ptr,
150
+ qstart_ptr, qend_ptr, cache_position_ptr,
151
+ attn_mask_ptr,
152
+ dQ_ptr, dK_ptr, dV_ptr,
153
+ B, H_Q, H_KV, L_Q, L_KV,
154
+ kv_repeat,
155
+ window_size,
156
+ sm_scale,
157
+ K_VAL: tl.constexpr,
158
+ BLOCK_K: tl.constexpr,
159
+ BLOCK_D: tl.constexpr,
160
+ D_HEAD: tl.constexpr,
161
+ SPAN_MAX_BLOCKS: tl.constexpr,
162
+ SW_MAX_BLOCKS: tl.constexpr,
163
+ HAS_ATTN_MASK: tl.constexpr,
164
+ WRITE_DKV_PER_QHEAD: tl.constexpr,
165
+ ):
166
+ pid = tl.program_id(0)
167
+ span_index = pid % K_VAL
168
+ q_idx = (pid // K_VAL) % L_Q
169
+ q_head_idx = (pid // (K_VAL * L_Q)) % H_Q
170
+ batch_idx = pid // (K_VAL * L_Q * H_Q)
171
+
172
+ kv_head_idx = q_head_idx // kv_repeat
173
+ kv_head_idx = tl.minimum(kv_head_idx, H_KV - 1)
174
+
175
+ d_range = tl.arange(0, BLOCK_D)
176
+ d_mask = d_range < D_HEAD
177
+
178
+ q_base = Q_ptr + ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * D_HEAD
179
+ q = tl.load(q_base + d_range, mask=d_mask, other=0.0).to(tl.float32)
180
+
181
+ cache_pos = tl.load(cache_position_ptr + q_idx, mask=True, other=0).to(tl.int32)
182
+ cache_pos = tl.minimum(cache_pos, L_KV - 1)
183
+
184
+ span_offset = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_VAL + span_index
185
+ span_start = tl.load(qstart_ptr + span_offset, mask=True, other=-1).to(tl.int32)
186
+ span_end = tl.load(qend_ptr + span_offset, mask=True, other=-1).to(tl.int32)
187
+ span_start = tl.maximum(span_start, 0)
188
+ span_end = tl.minimum(span_end, L_KV - 1)
189
+ span_valid = (span_end >= span_start) & (span_end >= 0)
190
+
191
+ window = window_size
192
+ sw_end = tl.minimum(cache_pos, L_KV - 1)
193
+ sw_start = sw_end - (window - 1)
194
+ sw_start = tl.maximum(sw_start, 0)
195
+ sw_valid = (window > 0) & (sw_start <= sw_end) & (sw_end >= 0)
196
+
197
+ seg1_start, seg1_end, seg1_valid = span_start, span_end, span_valid
198
+ seg2_start, seg2_end, seg2_valid = sw_start, sw_end, sw_valid
199
+
200
+ seg1_start = tl.maximum(seg1_start, 0)
201
+ seg1_end = tl.minimum(seg1_end, L_KV - 1)
202
+ seg1_valid = seg1_valid & (seg1_start <= seg1_end)
203
+
204
+ seg2_start = tl.maximum(seg2_start, 0)
205
+ seg2_end = tl.minimum(seg2_end, L_KV - 1)
206
+ seg2_valid = seg2_valid & (seg2_start <= seg2_end)
207
+
208
+ k_head_offset = ((batch_idx * H_KV + kv_head_idx) * L_KV).to(tl.int64) * D_HEAD
209
+ if WRITE_DKV_PER_QHEAD:
210
+ dkv_head_offset = ((batch_idx * H_Q + q_head_idx) * L_KV).to(tl.int64) * D_HEAD
211
+ else:
212
+ dkv_head_offset = k_head_offset
213
+ attn_base = batch_idx * L_KV
214
+
215
+ m_i = -float('inf')
216
+ l_i = 0.0
217
+ scale = tl.full((1,), sm_scale, tl.float32)
218
+
219
+ for block_idx in range(SPAN_MAX_BLOCKS):
220
+ block_start = seg1_start + block_idx * BLOCK_K
221
+ k_pos = block_start + tl.arange(0, BLOCK_K)
222
+ in_range = seg1_valid & (k_pos >= seg1_start) & (k_pos <= seg1_end)
223
+ k_pos_safe = tl.where(in_range, k_pos, 0)
224
+ k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
225
+ k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
226
+ logits = tl.sum(k_block * q[None, :], axis=1) * scale
227
+ if HAS_ATTN_MASK:
228
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
229
+ in_range = in_range & attn_mask_vals
230
+ logits = tl.where(in_range, logits, float('-inf'))
231
+ if tl.sum(in_range, axis=0) > 0:
232
+ block_max = tl.max(logits, axis=0)
233
+ m_new = tl.maximum(m_i, block_max)
234
+ alpha = tl.exp(m_i - m_new)
235
+ p = tl.exp(logits - m_new)
236
+ l_i = l_i * alpha + tl.sum(p, axis=0)
237
+ m_i = m_new
238
+
239
+ if SW_MAX_BLOCKS > 0:
240
+ for block_idx in range(SW_MAX_BLOCKS):
241
+ block_start = seg2_start + block_idx * BLOCK_K
242
+ k_pos = block_start + tl.arange(0, BLOCK_K)
243
+ in_range = seg2_valid & (k_pos >= seg2_start) & (k_pos <= seg2_end)
244
+ k_pos_safe = tl.where(in_range, k_pos, 0)
245
+ k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
246
+ k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
247
+ logits = tl.sum(k_block * q[None, :], axis=1) * scale
248
+ if HAS_ATTN_MASK:
249
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
250
+ in_range = in_range & attn_mask_vals
251
+ logits = tl.where(in_range, logits, float('-inf'))
252
+ if tl.sum(in_range, axis=0) > 0:
253
+ block_max = tl.max(logits, axis=0)
254
+ m_new = tl.maximum(m_i, block_max)
255
+ alpha = tl.exp(m_i - m_new)
256
+ p = tl.exp(logits - m_new)
257
+ l_i = l_i * alpha + tl.sum(p, axis=0)
258
+ m_i = m_new
259
+
260
+ dO_index = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_VAL + span_index
261
+ dO_base = dOut_ptr + dO_index.to(tl.int64) * D_HEAD
262
+ dO = tl.load(dO_base + d_range, mask=d_mask, other=0.0).to(tl.float32)
263
+
264
+ grad_q = tl.zeros((BLOCK_D,), dtype=tl.float32)
265
+
266
+ dot_total = 0.0
267
+ for block_idx in range(SPAN_MAX_BLOCKS):
268
+ block_start = seg1_start + block_idx * BLOCK_K
269
+ k_pos = block_start + tl.arange(0, BLOCK_K)
270
+ in_range = seg1_valid & (k_pos >= seg1_start) & (k_pos <= seg1_end)
271
+ k_pos_safe = tl.where(in_range, k_pos, 0)
272
+
273
+ k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
274
+ v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
275
+ k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
276
+ v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
277
+
278
+ logits = tl.sum(k_block * q[None, :], axis=1) * scale
279
+ if HAS_ATTN_MASK:
280
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
281
+ in_range = in_range & attn_mask_vals
282
+ logits = tl.where(in_range, logits, float('-inf'))
283
+ if tl.sum(in_range, axis=0) > 0:
284
+ weights = tl.exp(logits - m_i) / l_i
285
+ weights = tl.where(in_range, weights, 0.0)
286
+ grad_w = tl.sum(v_block * dO[None, :], axis=1)
287
+ dot_total += tl.sum(grad_w * weights, axis=0)
288
+
289
+ if SW_MAX_BLOCKS > 0:
290
+ for block_idx in range(SW_MAX_BLOCKS):
291
+ block_start = seg2_start + block_idx * BLOCK_K
292
+ k_pos = block_start + tl.arange(0, BLOCK_K)
293
+ in_range = seg2_valid & (k_pos >= seg2_start) & (k_pos <= seg2_end)
294
+ k_pos_safe = tl.where(in_range, k_pos, 0)
295
+
296
+ k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
297
+ v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
298
+ k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
299
+ v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
300
+
301
+ logits = tl.sum(k_block * q[None, :], axis=1) * scale
302
+ if HAS_ATTN_MASK:
303
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
304
+ in_range = in_range & attn_mask_vals
305
+ logits = tl.where(in_range, logits, float('-inf'))
306
+ if tl.sum(in_range, axis=0) > 0:
307
+ weights = tl.exp(logits - m_i) / l_i
308
+ weights = tl.where(in_range, weights, 0.0)
309
+ grad_w = tl.sum(v_block * dO[None, :], axis=1)
310
+ dot_total += tl.sum(grad_w * weights, axis=0)
311
+
312
+ for block_idx in range(SPAN_MAX_BLOCKS):
313
+ block_start = seg1_start + block_idx * BLOCK_K
314
+ k_pos = block_start + tl.arange(0, BLOCK_K)
315
+ in_range = seg1_valid & (k_pos >= seg1_start) & (k_pos <= seg1_end)
316
+ k_pos_safe = tl.where(in_range, k_pos, 0)
317
+
318
+ k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
319
+ v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
320
+ k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
321
+ v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
322
+
323
+ logits = tl.sum(k_block * q[None, :], axis=1) * scale
324
+ if HAS_ATTN_MASK:
325
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
326
+ in_range = in_range & attn_mask_vals
327
+ logits = tl.where(in_range, logits, float('-inf'))
328
+ if tl.sum(in_range, axis=0) > 0:
329
+ weights = tl.exp(logits - m_i) / l_i
330
+ weights = tl.where(in_range, weights, 0.0)
331
+ grad_w = tl.sum(v_block * dO[None, :], axis=1)
332
+ grad_s = (grad_w - dot_total) * weights * sm_scale
333
+ grad_s = tl.where(in_range, grad_s, 0.0)
334
+
335
+ grad_q = grad_q + tl.sum(grad_s[:, None] * k_block, axis=0)
336
+
337
+ dk = grad_s[:, None] * q[None, :]
338
+ dkv_offsets = dkv_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
339
+ tl.atomic_add(dK_ptr + dkv_offsets, dk, mask=in_range[:, None] & d_mask[None, :])
340
+
341
+ dv = weights[:, None] * dO[None, :]
342
+ tl.atomic_add(dV_ptr + dkv_offsets, dv, mask=in_range[:, None] & d_mask[None, :])
343
+
344
+ if SW_MAX_BLOCKS > 0:
345
+ for block_idx in range(SW_MAX_BLOCKS):
346
+ block_start = seg2_start + block_idx * BLOCK_K
347
+ k_pos = block_start + tl.arange(0, BLOCK_K)
348
+ in_range = seg2_valid & (k_pos >= seg2_start) & (k_pos <= seg2_end)
349
+ k_pos_safe = tl.where(in_range, k_pos, 0)
350
+
351
+ k_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
352
+ v_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
353
+ k_block = tl.load(K_ptr + k_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
354
+ v_block = tl.load(V_ptr + v_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
355
+
356
+ logits = tl.sum(k_block * q[None, :], axis=1) * scale
357
+ if HAS_ATTN_MASK:
358
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
359
+ in_range = in_range & attn_mask_vals
360
+ logits = tl.where(in_range, logits, float('-inf'))
361
+ if tl.sum(in_range, axis=0) > 0:
362
+ weights = tl.exp(logits - m_i) / l_i
363
+ weights = tl.where(in_range, weights, 0.0)
364
+ grad_w = tl.sum(v_block * dO[None, :], axis=1)
365
+ grad_s = (grad_w - dot_total) * weights * sm_scale
366
+ grad_s = tl.where(in_range, grad_s, 0.0)
367
+
368
+ grad_q = grad_q + tl.sum(grad_s[:, None] * k_block, axis=0)
369
+
370
+ dk = grad_s[:, None] * q[None, :]
371
+ dkv_offsets = dkv_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
372
+ tl.atomic_add(dK_ptr + dkv_offsets, dk, mask=in_range[:, None] & d_mask[None, :])
373
+
374
+ dv = weights[:, None] * dO[None, :]
375
+ tl.atomic_add(dV_ptr + dkv_offsets, dv, mask=in_range[:, None] & d_mask[None, :])
376
+
377
+ dq_base = dQ_ptr + ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * D_HEAD
378
+ tl.atomic_add(dq_base + d_range, grad_q, mask=d_mask)
379
+
380
+
381
+ @triton.jit
382
+ def fused_span_backward_kernel_gqa_fused_spans(
383
+ Q_ptr, K_ptr, V_ptr, dOut_ptr,
384
+ qstart_ptr, qend_ptr, cache_position_ptr,
385
+ attn_mask_ptr,
386
+ dQ_ptr, dK_ptr, dV_ptr,
387
+ B, H_Q, H_KV, L_Q, L_KV,
388
+ kv_repeat,
389
+ window_size,
390
+ sm_scale,
391
+ K_VAL: tl.constexpr, # Power of 2 for tl.arange
392
+ K_ACTUAL: tl.constexpr, # Actual number of spans
393
+ BLOCK_K: tl.constexpr,
394
+ BLOCK_D: tl.constexpr,
395
+ D_HEAD: tl.constexpr,
396
+ SPAN_MAX_BLOCKS: tl.constexpr,
397
+ SW_MAX_BLOCKS: tl.constexpr,
398
+ HAS_ATTN_MASK: tl.constexpr,
399
+ WRITE_DKV_PER_QHEAD: tl.constexpr,
400
+ ):
401
+ pid = tl.program_id(0)
402
+ q_idx = pid % L_Q
403
+ q_head_idx = (pid // L_Q) % H_Q
404
+ batch_idx = pid // (L_Q * H_Q)
405
+
406
+ kv_head_idx = q_head_idx // kv_repeat
407
+ kv_head_idx = tl.minimum(kv_head_idx, H_KV - 1)
408
+
409
+ span_range = tl.arange(0, K_VAL)
410
+ span_mask = span_range < K_ACTUAL # Mask for valid spans
411
+ k_range = tl.arange(0, BLOCK_K)
412
+ d_range = tl.arange(0, BLOCK_D)
413
+ d_mask = d_range < D_HEAD
414
+
415
+ q_base = Q_ptr + ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * D_HEAD
416
+ q = tl.load(q_base + d_range, mask=d_mask, other=0.0).to(tl.float32)
417
+
418
+ cache_pos = tl.load(cache_position_ptr + q_idx, mask=True, other=0).to(tl.int32)
419
+ cache_pos = tl.minimum(cache_pos, L_KV - 1)
420
+
421
+ # Use K_ACTUAL for actual data layout
422
+ span_offset_base = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_ACTUAL
423
+ span_offsets = span_offset_base + span_range
424
+ span_start = tl.load(qstart_ptr + span_offsets, mask=span_mask, other=-1).to(tl.int32)
425
+ span_end = tl.load(qend_ptr + span_offsets, mask=span_mask, other=-1).to(tl.int32)
426
+ span_start = tl.maximum(span_start, 0)
427
+ span_end = tl.minimum(span_end, L_KV - 1)
428
+ span_valid = (span_end >= span_start) & (span_end >= 0) & span_mask
429
+
430
+ window = window_size
431
+ sw_end = tl.minimum(cache_pos, L_KV - 1)
432
+ sw_start = sw_end - (window - 1)
433
+ sw_start = tl.maximum(sw_start, 0)
434
+ sw_valid = (window > 0) & (sw_start <= sw_end) & (sw_end >= 0)
435
+
436
+ k_head_offset = ((batch_idx * H_KV + kv_head_idx) * L_KV).to(tl.int64) * D_HEAD
437
+ if WRITE_DKV_PER_QHEAD:
438
+ dkv_head_offset = ((batch_idx * H_Q + q_head_idx) * L_KV).to(tl.int64) * D_HEAD
439
+ else:
440
+ dkv_head_offset = k_head_offset
441
+ attn_base = batch_idx * L_KV
442
+
443
+ # Use K_ACTUAL for actual data layout
444
+ dO_index_base = ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * K_ACTUAL
445
+ dO_offsets = (dO_index_base + span_range)[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
446
+ dO = tl.load(dOut_ptr + dO_offsets, mask=span_mask[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
447
+
448
+ m_i = tl.full((K_VAL,), -float('inf'), tl.float32)
449
+ l_i = tl.zeros((K_VAL,), tl.float32)
450
+ dot_num = tl.zeros((K_VAL,), tl.float32)
451
+ scale = tl.full((1,), sm_scale, tl.float32)
452
+
453
+ # Pass 1: compute m_i, l_i, dot_num (online) for each span.
454
+ for block_idx in range(SPAN_MAX_BLOCKS):
455
+ block_start = span_start + block_idx * BLOCK_K
456
+ k_pos = block_start[:, None] + k_range[None, :]
457
+ in_range = span_valid[:, None] & (k_pos >= span_start[:, None]) & (k_pos <= span_end[:, None])
458
+ k_pos_safe = tl.where(in_range, k_pos, 0)
459
+
460
+ kv_offsets = k_head_offset + k_pos_safe[:, :, None].to(tl.int64) * D_HEAD + d_range[None, None, :]
461
+ k_block = tl.load(K_ptr + kv_offsets, mask=in_range[:, :, None] & d_mask[None, None, :], other=0.0).to(tl.float32)
462
+ v_block = tl.load(V_ptr + kv_offsets, mask=in_range[:, :, None] & d_mask[None, None, :], other=0.0).to(tl.float32)
463
+
464
+ logits = tl.sum(k_block * q[None, None, :], axis=2) * scale
465
+
466
+ if HAS_ATTN_MASK:
467
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
468
+ in_range = in_range & attn_mask_vals
469
+ logits = tl.where(in_range, logits, float('-inf'))
470
+
471
+ grad_w = tl.sum(v_block * dO[:, None, :], axis=2)
472
+
473
+ block_max = tl.max(logits, axis=1)
474
+ m_new = tl.maximum(m_i, block_max)
475
+ alpha = tl.exp(m_i - m_new)
476
+ alpha = tl.where(m_new == -float('inf'), 0.0, alpha)
477
+
478
+ logits_shift = logits - m_new[:, None]
479
+ logits_shift = tl.where(m_new[:, None] == -float('inf'), float('-inf'), logits_shift)
480
+ p = tl.exp(logits_shift)
481
+
482
+ l_i = l_i * alpha + tl.sum(p, axis=1)
483
+ dot_num = dot_num * alpha + tl.sum(p * grad_w, axis=1)
484
+ m_i = m_new
485
+
486
+ if SW_MAX_BLOCKS > 0:
487
+ for block_idx in range(SW_MAX_BLOCKS):
488
+ block_start = sw_start + block_idx * BLOCK_K
489
+ k_pos = block_start + k_range
490
+ in_range = sw_valid & (k_pos >= sw_start) & (k_pos <= sw_end)
491
+ k_pos_safe = tl.where(in_range, k_pos, 0)
492
+
493
+ kv_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
494
+ k_block = tl.load(K_ptr + kv_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
495
+ v_block = tl.load(V_ptr + kv_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
496
+
497
+ logits = tl.sum(k_block * q[None, :], axis=1) * scale
498
+
499
+ if HAS_ATTN_MASK:
500
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
501
+ in_range = in_range & attn_mask_vals
502
+ logits = tl.where(in_range, logits, float('-inf'))
503
+
504
+ grad_w = tl.sum(v_block[None, :, :] * dO[:, None, :], axis=2)
505
+
506
+ block_max = tl.max(logits, axis=0)
507
+ m_new = tl.maximum(m_i, block_max)
508
+ alpha = tl.exp(m_i - m_new)
509
+ alpha = tl.where(m_new == -float('inf'), 0.0, alpha)
510
+
511
+ logits_shift = logits[None, :] - m_new[:, None]
512
+ logits_shift = tl.where(m_new[:, None] == -float('inf'), float('-inf'), logits_shift)
513
+ p = tl.exp(logits_shift)
514
+
515
+ l_i = l_i * alpha + tl.sum(p, axis=1)
516
+ dot_num = dot_num * alpha + tl.sum(p * grad_w, axis=1)
517
+ m_i = m_new
518
+
519
+ inv_l = tl.where(l_i > 0, 1.0 / l_i, 0.0)
520
+ dot_total = dot_num * inv_l
521
+
522
+ grad_q = tl.zeros((BLOCK_D,), dtype=tl.float32)
523
+
524
+ # Pass 2: accumulate gradients for all spans.
525
+ for block_idx in range(SPAN_MAX_BLOCKS):
526
+ block_start = span_start + block_idx * BLOCK_K
527
+ k_pos = block_start[:, None] + k_range[None, :]
528
+ in_range = span_valid[:, None] & (k_pos >= span_start[:, None]) & (k_pos <= span_end[:, None])
529
+ k_pos_safe = tl.where(in_range, k_pos, 0)
530
+
531
+ kv_offsets = k_head_offset + k_pos_safe[:, :, None].to(tl.int64) * D_HEAD + d_range[None, None, :]
532
+ k_block = tl.load(K_ptr + kv_offsets, mask=in_range[:, :, None] & d_mask[None, None, :], other=0.0).to(tl.float32)
533
+ v_block = tl.load(V_ptr + kv_offsets, mask=in_range[:, :, None] & d_mask[None, None, :], other=0.0).to(tl.float32)
534
+
535
+ logits = tl.sum(k_block * q[None, None, :], axis=2) * scale
536
+
537
+ if HAS_ATTN_MASK:
538
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
539
+ in_range = in_range & attn_mask_vals
540
+ logits = tl.where(in_range, logits, float('-inf'))
541
+
542
+ weights = tl.exp(logits - m_i[:, None]) * inv_l[:, None]
543
+ weights = tl.where(inv_l[:, None] > 0, weights, 0.0)
544
+ weights = tl.where(in_range, weights, 0.0)
545
+
546
+ grad_w = tl.sum(v_block * dO[:, None, :], axis=2)
547
+ grad_s = (grad_w - dot_total[:, None]) * weights * sm_scale
548
+ grad_s = tl.where(in_range, grad_s, 0.0)
549
+
550
+ grad_q_span = tl.sum(grad_s[:, :, None] * k_block, axis=1)
551
+ grad_q += tl.sum(grad_q_span, axis=0)
552
+
553
+ dk = grad_s[:, :, None] * q[None, None, :]
554
+ dv = weights[:, :, None] * dO[:, None, :]
555
+
556
+ dkv_offsets = dkv_head_offset + k_pos_safe[:, :, None].to(tl.int64) * D_HEAD + d_range[None, None, :]
557
+ tl.atomic_add(dK_ptr + dkv_offsets, dk, mask=in_range[:, :, None] & d_mask[None, None, :])
558
+ tl.atomic_add(dV_ptr + dkv_offsets, dv, mask=in_range[:, :, None] & d_mask[None, None, :])
559
+
560
+ if SW_MAX_BLOCKS > 0:
561
+ for block_idx in range(SW_MAX_BLOCKS):
562
+ block_start = sw_start + block_idx * BLOCK_K
563
+ k_pos = block_start + k_range
564
+ in_range = sw_valid & (k_pos >= sw_start) & (k_pos <= sw_end)
565
+ k_pos_safe = tl.where(in_range, k_pos, 0)
566
+
567
+ kv_offsets = k_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
568
+ k_block = tl.load(K_ptr + kv_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
569
+ v_block = tl.load(V_ptr + kv_offsets, mask=in_range[:, None] & d_mask[None, :], other=0.0).to(tl.float32)
570
+
571
+ logits = tl.sum(k_block * q[None, :], axis=1) * scale
572
+
573
+ if HAS_ATTN_MASK:
574
+ attn_mask_vals = tl.load(attn_mask_ptr + attn_base + k_pos_safe, mask=in_range, other=0).to(tl.int1)
575
+ in_range = in_range & attn_mask_vals
576
+ logits = tl.where(in_range, logits, float('-inf'))
577
+
578
+ weights = tl.exp(logits[None, :] - m_i[:, None]) * inv_l[:, None]
579
+ weights = tl.where(inv_l[:, None] > 0, weights, 0.0)
580
+ weights = tl.where(in_range[None, :], weights, 0.0)
581
+
582
+ grad_w = tl.sum(v_block[None, :, :] * dO[:, None, :], axis=2)
583
+ grad_s = (grad_w - dot_total[:, None]) * weights * sm_scale
584
+ grad_s = tl.where(in_range[None, :], grad_s, 0.0)
585
+
586
+ grad_q_span = tl.sum(grad_s[:, :, None] * k_block[None, :, :], axis=1)
587
+ grad_q += tl.sum(grad_q_span, axis=0)
588
+
589
+ grad_s_sum = tl.sum(grad_s, axis=0)
590
+ dk_total = grad_s_sum[:, None] * q[None, :]
591
+ dv_total = tl.sum(weights[:, :, None] * dO[:, None, :], axis=0)
592
+
593
+ dkv_offsets = dkv_head_offset + k_pos_safe[:, None].to(tl.int64) * D_HEAD + d_range[None, :]
594
+ tl.atomic_add(dK_ptr + dkv_offsets, dk_total, mask=in_range[:, None] & d_mask[None, :])
595
+ tl.atomic_add(dV_ptr + dkv_offsets, dv_total, mask=in_range[:, None] & d_mask[None, :])
596
+
597
+ dq_base = dQ_ptr + ((batch_idx * H_Q + q_head_idx) * L_Q + q_idx) * D_HEAD
598
+ tl.store(dq_base + d_range, grad_q, mask=d_mask)
599
+
600
+
601
+
602
+ def fused_span_triton_gqa(
603
+ Q2,
604
+ K,
605
+ V,
606
+ qstart,
607
+ qend,
608
+ cache_position,
609
+ attention_mask=None,
610
+ sw_index=0,
611
+ block_k: int = 64,
612
+ span_len_factor: float = 2.0,
613
+ span_power: float = 0.5,
614
+ search_power: float | None = None,
615
+ inv_search_power_int: int | None = 2,
616
+ ):
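+ """Host-side launcher for the fused span + sliding-window forward kernel.
+ Moves K/V and the index tensors onto Q2's device, derives the window length and
+ per-segment block counts, and launches one program per (batch, query head,
+ query position, span). Returns a tensor of shape [B, H_q, L_Q, num_spans, D]
+ holding each span's attention output."""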
617
+ assert Q2.is_cuda, "CUDA required for fused Triton path"
618
+ B, H_q, L_Q, D = Q2.shape
619
+ _, H_kv, L_KV, _ = K.shape
620
+ num_spans = qstart.shape[-1]
621
+
622
+ span_len_factor = float(span_len_factor)
623
+ if not math.isfinite(span_len_factor) or span_len_factor <= 0.0:
624
+ raise ValueError(f"span_len_factor must be finite and > 0 (got {span_len_factor})")
625
+
626
+ span_power_f = float(span_power)
627
+ if not math.isfinite(span_power_f) or not (0.0 < span_power_f < 1.0):
628
+ raise ValueError(f"span_power must be finite and in (0, 1) (got {span_power})")
629
+
630
+ kv_repeat = H_q // H_kv
631
+
632
+ # Ensure all tensors are on the same device as Q2 for multi-GPU compatibility
633
+ device = Q2.device
634
+ K = K.to(device)
635
+ V = V.to(device)
636
+ qstart = qstart.to(device)
637
+ qend = qend.to(device)
638
+ cache_position = cache_position.to(device)
639
+ if attention_mask is not None:
640
+ attention_mask = attention_mask.to(device)
641
+
642
+ Q2c = Q2.contiguous()
643
+ Kc = K.contiguous()
644
+ Vc = V.contiguous()
645
+ qstartc = qstart.contiguous()
646
+ qendc = qend.contiguous()
647
+ cachec = cache_position.to(torch.int32).contiguous()
648
+
649
+ if attention_mask is not None:
650
+ attn_mask = attention_mask[:, :L_KV].contiguous().to(torch.int8)
651
+ has_mask = True
652
+ else:
653
+ attn_mask = torch.empty((1,), device=device, dtype=torch.int8)
654
+ has_mask = False
655
+
656
+ window = window_len_from_sw_index(
657
+ int(sw_index), search_power=search_power, inv_search_power_int=inv_search_power_int
658
+ )
659
+ _assert_no_span_sw_overlap(qendc, cachec.view(1, 1, -1, 1), sw_index, L_KV, window_len=window)
660
+
661
+ out = torch.empty((B, H_q, L_Q, num_spans, D), device=device, dtype=Q2.dtype)
662
+
663
+ max_span_len = int(span_len_factor * math.ceil(float(L_KV) ** span_power_f) + 2)
664
+ span_max_blocks = triton.cdiv(max_span_len, block_k)
665
+ span_max_blocks = max(1, span_max_blocks)
666
+ sw_max_len = min(window, L_KV) if window > 0 else 0
667
+ sw_max_blocks = triton.cdiv(sw_max_len, block_k) if sw_max_len > 0 else 0
668
+ block_d = min(256, _next_power_of_two(D))
669
+
670
+ grid = (B * H_q * L_Q * num_spans,)
671
+ # Use torch.cuda.device context to ensure kernel launches on correct GPU
672
+ with torch.cuda.device(device):
673
+ fused_span_forward_kernel_gqa[grid](
674
+ Q2c, Kc, Vc,
675
+ qstartc, qendc, cachec,
676
+ attn_mask,
677
+ out,
678
+ B, H_q, H_kv, L_Q, L_KV,
679
+ kv_repeat,
680
+ window,
681
+ 1.0 / math.sqrt(D),
682
+ K_VAL=num_spans,
683
+ BLOCK_K=block_k,
684
+ BLOCK_D=block_d,
685
+ D_HEAD=D,
686
+ SPAN_MAX_BLOCKS=span_max_blocks,
687
+ SW_MAX_BLOCKS=sw_max_blocks,
688
+ HAS_ATTN_MASK=has_mask,
689
+ )
690
+ # Synchronize to ensure kernel completes before output is used elsewhere.
691
+ # Skip during CUDA graph capture (synchronize is not allowed during capture).
692
+ # NOTE: This sync was added in commit 9dcb65f to fix multi-GPU race conditions
693
+ # with device_map='auto'. However, notebook 39.4 shows that the
694
+ # `with torch.cuda.device(device):` context is sufficient for correctness,
695
+ # and HuggingFace patterns don't use per-kernel sync. Consider removing
696
+ # this sync entirely if multi-GPU inference remains stable without it.
697
+ if not torch.cuda.is_current_stream_capturing():
698
+ torch.cuda.current_stream().synchronize()
699
+ return out
700
+
701
+
702
+ class FusedSpanGQATriton(torch.autograd.Function):
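+ """Autograd wrapper for the GQA span kernels: forward launches
+ fused_span_triton_gqa and saves the inputs; backward recomputes the softmax
+ statistics in Triton and returns dQ/dK/dV (all other forward arguments receive
+ None gradients)."""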
703
+ @staticmethod
704
+ def forward(
705
+ ctx,
706
+ Q2,
707
+ K,
708
+ V,
709
+ qstart,
710
+ qend,
711
+ cache_position,
712
+ attention_mask=None,
713
+ sw_index: int = 0,
714
+ block_k: int = 64,
715
+ span_len_factor: float = 2.0,
716
+ span_power: float = 0.5,
717
+ search_power: float | None = None,
718
+ inv_search_power_int: int | None = 2,
719
+ ):
720
+ B, H_q, L_Q, D = Q2.shape
721
+ _, H_kv, _, _ = K.shape
722
+ assert H_q % H_kv == 0, "Query heads must be divisible by KV heads"
723
+
724
+ kv_repeat = H_q // H_kv
725
+ out = fused_span_triton_gqa(
726
+ Q2,
727
+ K,
728
+ V,
729
+ qstart,
730
+ qend,
731
+ cache_position,
732
+ attention_mask,
733
+ sw_index,
734
+ block_k,
735
+ span_len_factor=span_len_factor,
736
+ span_power=span_power,
737
+ search_power=search_power,
738
+ inv_search_power_int=inv_search_power_int,
739
+ )
740
+ saved_mask = attention_mask if attention_mask is not None else torch.tensor([], device=Q2.device)
741
+ ctx.save_for_backward(Q2, K, V, qstart, qend, cache_position.to(torch.int32), saved_mask)
742
+ ctx.sw_index = sw_index
743
+ ctx.block_k = block_k
744
+ ctx.kv_repeat = kv_repeat
745
+ ctx.span_len_factor = float(span_len_factor)
746
+ ctx.span_power = float(span_power)
747
+ ctx.window_len = window_len_from_sw_index(
748
+ int(sw_index), search_power=search_power, inv_search_power_int=inv_search_power_int
749
+ )
750
+ return out
751
+
752
+ @staticmethod
753
+ def backward(ctx, grad_out):
754
+ Q2, K, V, qstart, qend, cache_position, attention_mask_saved = ctx.saved_tensors
755
+ attention_mask = None if attention_mask_saved.numel() == 0 else attention_mask_saved
756
+ sw_index = ctx.sw_index
757
+ block_k = ctx.block_k
758
+ kv_repeat = ctx.kv_repeat
759
+ span_len_factor = float(getattr(ctx, "span_len_factor", 2.0))
760
+ span_power = float(getattr(ctx, "span_power", 0.5))
761
+
762
+ B, H_q, L_Q, D = Q2.shape
763
+ _, H_kv, L_KV, _ = K.shape
764
+ num_spans = qstart.shape[-1]
765
+
766
+ window = int(getattr(ctx, "window_len", (sw_index + 1) ** 2 - 1))
767
+ max_span_len = int(span_len_factor * math.ceil(float(L_KV) ** span_power) + 2)
768
+ span_max_blocks = triton.cdiv(max_span_len, block_k)
769
+ span_max_blocks = max(1, span_max_blocks)
770
+ sw_max_len = min(window, L_KV) if window > 0 else 0
771
+ sw_max_blocks = triton.cdiv(sw_max_len, block_k) if sw_max_len > 0 else 0
772
+ block_d = min(256, _next_power_of_two(D))
773
+
774
+ Q2c = Q2.contiguous()
775
+ Kc = K.contiguous()
776
+ Vc = V.contiguous()
777
+ qstartc = qstart.contiguous()
778
+ qendc = qend.contiguous()
779
+ cachec = cache_position.to(torch.int32).contiguous()
780
+ grad_out_c = grad_out.contiguous()
781
+
782
+ if attention_mask is not None:
783
+ attn_mask = attention_mask[:, :L_KV].contiguous().to(torch.int8)
784
+ has_mask = True
785
+ else:
786
+ attn_mask = torch.empty((1,), device=Q2.device, dtype=torch.int8)
787
+ has_mask = False
788
+
789
+ dQ = torch.zeros_like(Q2c, dtype=torch.float32)
790
+ split_dkv = os.getenv("SPAN_ATTN_GQA_BACKWARD_SPLIT_DKV", "1") != "0"
791
+ if split_dkv and kv_repeat > 1:
792
+ dK_rep = torch.zeros((B, H_q, L_KV, D), device=Kc.device, dtype=torch.float32)
793
+ dV_rep = torch.zeros((B, H_q, L_KV, D), device=Vc.device, dtype=torch.float32)
794
+ dK_out = dK_rep
795
+ dV_out = dV_rep
796
+ else:
797
+ dK = torch.zeros_like(Kc, dtype=torch.float32)
798
+ dV = torch.zeros_like(Vc, dtype=torch.float32)
799
+ dK_out = dK
800
+ dV_out = dV
801
+
802
+ write_dkv_per_qhead = split_dkv and kv_repeat > 1
803
+ fuse_spans = os.getenv("SPAN_ATTN_GQA_BACKWARD_FUSE_SPANS", "1") != "0"
804
+ # K_VAL must be power of 2 for tl.arange
805
+ k_val_padded = 1 << (num_spans - 1).bit_length() if num_spans > 0 else 1
806
+ if fuse_spans and num_spans > 1:
807
+ grid = (B * H_q * L_Q,)
808
+ fused_span_backward_kernel_gqa_fused_spans[grid](
809
+ Q2c, Kc, Vc, grad_out_c,
810
+ qstartc, qendc, cachec,
811
+ attn_mask,
812
+ dQ, dK_out, dV_out,
813
+ B, H_q, H_kv, L_Q, L_KV,
814
+ kv_repeat,
815
+ window,
816
+ 1.0 / math.sqrt(D),
817
+ K_VAL=k_val_padded,
818
+ K_ACTUAL=num_spans,
819
+ BLOCK_K=block_k,
820
+ BLOCK_D=block_d,
821
+ D_HEAD=D,
822
+ SPAN_MAX_BLOCKS=span_max_blocks,
823
+ SW_MAX_BLOCKS=sw_max_blocks,
824
+ HAS_ATTN_MASK=has_mask,
825
+ WRITE_DKV_PER_QHEAD=write_dkv_per_qhead,
826
+ )
827
+ else:
828
+ grid = (B * H_q * L_Q * num_spans,)
829
+ fused_span_backward_kernel_gqa[grid](
830
+ Q2c, Kc, Vc, grad_out_c,
831
+ qstartc, qendc, cachec,
832
+ attn_mask,
833
+ dQ, dK_out, dV_out,
834
+ B, H_q, H_kv, L_Q, L_KV,
835
+ kv_repeat,
836
+ window,
837
+ 1.0 / math.sqrt(D),
838
+ K_VAL=num_spans,
839
+ BLOCK_K=block_k,
840
+ BLOCK_D=block_d,
841
+ D_HEAD=D,
842
+ SPAN_MAX_BLOCKS=span_max_blocks,
843
+ SW_MAX_BLOCKS=sw_max_blocks,
844
+ HAS_ATTN_MASK=has_mask,
845
+ WRITE_DKV_PER_QHEAD=write_dkv_per_qhead,
846
+ )
847
+
848
+ if split_dkv and kv_repeat > 1:
849
+ dK = dK_rep.view(B, H_kv, kv_repeat, L_KV, D).sum(dim=2)
850
+ dV = dV_rep.view(B, H_kv, kv_repeat, L_KV, D).sum(dim=2)
851
+
852
+ return (
853
+ dQ.to(Q2.dtype),
854
+ dK.to(K.dtype),
855
+ dV.to(V.dtype),
856
+ None,
857
+ None,
858
+ None,
859
+ None,
860
+ None,
861
+ None,
862
+ None,
863
+ None,
864
+ None,
865
+ None,
866
+ )
867
+
868
+
869
+ def fused_span_attention_gqa(
870
+ Q2,
871
+ K,
872
+ V,
873
+ qstart,
874
+ qend,
875
+ cache_position,
876
+ attention_mask=None,
877
+ sw_index=0,
878
+ block_k=64,
879
+ span_len_factor: float = 2.0,
880
+ span_power: float = 0.5,
881
+ search_power: float | None = None,
882
+ inv_search_power_int: int | None = 2,
883
+ ):
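+ """GQA entry point for fused span attention. On CUDA, dispatches to the Triton
+ autograd path (FusedSpanGQATriton); on non-CUDA tensors, repeats K/V across the
+ query heads and falls back to the dense fused_span_attention reference, which
+ currently only supports p = 0.5."""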
884
+ B, H_q, L_Q, _ = Q2.shape
885
+ _, H_kv, _, _ = K.shape
886
+ assert H_q % H_kv == 0, "Query heads must be divisible by KV heads when using GQA"
887
+
888
+ if not Q2.is_cuda:
889
+ params = derive_stripe_power_params(
890
+ search_power=search_power, inv_search_power_int=inv_search_power_int
891
+ )
892
+ if params.triton_inv_n != 2:
893
+ raise NotImplementedError(
894
+ "Non-CUDA fused_span_attention_gqa only supports p=0.5 for now "
895
+ "(use inv_search_power_int=2 or search_power=0.5)."
896
+ )
897
+ kv_repeat = H_q // H_kv
898
+ K_rep = K.repeat_interleave(kv_repeat, dim=1)
899
+ V_rep = V.repeat_interleave(kv_repeat, dim=1)
900
+ return fused_span_attention(
901
+ Q2,
902
+ K_rep,
903
+ V_rep,
904
+ qstart,
905
+ qend,
906
+ cache_position,
907
+ attention_mask=attention_mask,
908
+ sw_index=sw_index,
909
+ block_k=block_k,
910
+ span_len_factor=span_len_factor,
911
+ span_power=span_power,
912
+ search_power=search_power,
913
+ inv_search_power_int=inv_search_power_int,
914
+ )
915
+
916
+ return FusedSpanGQATriton.apply(
917
+ Q2,
918
+ K,
919
+ V,
920
+ qstart,
921
+ qend,
922
+ cache_position,
923
+ attention_mask,
924
+ sw_index,
925
+ block_k,
926
+ span_len_factor,
927
+ float(span_power),
928
+ search_power,
929
+ inv_search_power_int,
930
+ )
931
+
932
+
933
+
934
+ def decode_span_attention_staged_gqa_kernel(
935
+ Q1,
936
+ Q2,
937
+ K,
938
+ V,
939
+ cache_position,
940
+ attention_mask=None,
941
+ sw_index=0,
942
+ topk=3,
943
+ enable_gqa=False,
944
+ block_k=64,
945
+ backward_factor: float = 2.0,
946
+ forward_factor: float = 0.0,
947
+ span_power: float = 0.5,
948
+ search_power: float | None = None,
949
+ inv_search_power_int: int | None = 2,
950
+ force_mode=None,
951
+ ):
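+ # Thin wrapper: forwards every argument unchanged to the v2 implementation below.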
952
+ return decode_span_attention_staged_gqa_kernel_v2(
953
+ Q1,
954
+ Q2,
955
+ K,
956
+ V,
957
+ cache_position,
958
+ attention_mask=attention_mask,
959
+ sw_index=sw_index,
960
+ topk=topk,
961
+ enable_gqa=enable_gqa,
962
+ block_k=block_k,
963
+ backward_factor=backward_factor,
964
+ forward_factor=forward_factor,
965
+ span_power=span_power,
966
+ search_power=search_power,
967
+ inv_search_power_int=inv_search_power_int,
968
+ force_mode=force_mode,
969
+ )
970
+
971
+
972
+ def decode_span_attention_staged_gqa_kernel_v2(
973
+ Q1,
974
+ Q2,
975
+ K,
976
+ V,
977
+ cache_position,
978
+ attention_mask=None,
979
+ sw_index=0,
980
+ topk=3,
981
+ enable_gqa=False,
982
+ block_k=64,
983
+ backward_factor: float = 2.0,
984
+ forward_factor: float = 0.0,
985
+ span_power: float = 0.5,
986
+ search_power: float | None = None,
987
+ inv_search_power_int: int | None = 2,
988
+ force_mode=None,
989
+ ):
990
+ """Staged decode variant that repeats K/V for small-window SDPA to avoid GQA overhead."""
991
+ if (not enable_gqa) or (K.shape[1] == Q1.shape[1]):
992
+ params = derive_stripe_power_params(
993
+ search_power=search_power, inv_search_power_int=inv_search_power_int
994
+ )
995
+ if params.triton_inv_n != 2:
996
+ raise NotImplementedError(
997
+ "decode_span_attention_staged_gqa_kernel_v2 only supports p!=0.5 when enable_gqa=True. "
998
+ "Use enable_gqa=True or keep p=0.5 (inv_search_power_int=2 or search_power=0.5)."
999
+ )
1000
+ return decode_span_attention_staged(
1001
+ Q1,
1002
+ Q2,
1003
+ K,
1004
+ V,
1005
+ cache_position,
1006
+ attention_mask=attention_mask,
1007
+ sw_index=sw_index,
1008
+ topk=topk,
1009
+ backward_factor=backward_factor,
1010
+ forward_factor=forward_factor,
1011
+ span_power=span_power,
1012
+ search_power=search_power,
1013
+ inv_search_power_int=inv_search_power_int,
1014
+ )
1015
+
1016
+ B, H_q, L_Q, D = Q1.shape
1017
+ _, H_kv, L_KV, D_kv = K.shape
1018
+ assert L_Q == 1, "GQA decode path only supports single-token decoding (L_Q=1)"
1019
+ assert D_kv == D, "Q/K/V head dimensions must match"
1020
+ assert H_q % H_kv == 0, "Query heads must be divisible by KV heads when GQA is enabled"
1021
+
1022
+ device = Q1.device
1023
+ kv_repeat = H_q // H_kv
1024
+
1025
+ window_len = window_len_from_sw_index(
1026
+ int(sw_index), search_power=search_power, inv_search_power_int=inv_search_power_int
1027
+ )
1028
+ # ------------------------------------------------------------------------
1029
+ # IMPORTANT (StaticCache correctness):
1030
+ # StaticCache returns full fixed-size K/V buffers during decode (L_KV = max_seq_len) to
1031
+ # keep shapes stable for CUDA graphs. For short prefixes, we must still take the same
1032
+ # "full-attention" SDPA fallback that DynamicCache would take (based on the *effective*
1033
+ # prefix length, i.e., cache_position[-1] + 1), otherwise small numeric diffs can
1034
+ # accumulate and change argmax tokens.
1035
+ #
1036
+ # To keep CUDA-graph safety (no data-dependent Python branching), we compute both:
1037
+ # - SDPA fallback over a small fixed prefix window (<= window_len)
1038
+ # - the staged span-attention output
1039
+ # and select with a tensor mask.
1040
+ # ------------------------------------------------------------------------
1041
+ token_pos = cache_position[-1] + 1 # effective prefix length (1-indexed)
1042
+ use_sdpa = token_pos <= window_len
1043
+
1044
+ if force_mode not in (None, "sdpa", "span"):
1045
+ force_mode = None
1046
+
1047
+ out_sdpa = None
1048
+ if force_mode != "span":
1049
+ # SDPA fallback: only needs the first `min(L_KV, window_len)` keys. For StaticCache,
1050
+ # attention_mask is expected to mask out positions beyond the current prefix.
1051
+ kv_slice_len = min(L_KV, max(window_len, 1))
1052
+ K_sdpa = K[:, :, :kv_slice_len, :]
1053
+ V_sdpa = V[:, :, :kv_slice_len, :]
1054
+ if attention_mask is not None:
1055
+ sdpa_mask = attention_mask[:, :kv_slice_len]
1056
+ else:
1057
+ # Synthesize a prefix mask so StaticCache doesn't attend to the unused tail.
1058
+ pos = torch.arange(kv_slice_len, device=device, dtype=cache_position.dtype)
1059
+ sdpa_mask = (pos <= cache_position[-1]).unsqueeze(0).expand(B, -1)
1060
+
1061
+ attn_mask = sdpa_mask.unsqueeze(1).unsqueeze(2)
1062
+ # CUDA graph safe: use arithmetic instead of torch.tensor()
1063
+ # True (1.0) -> 0.0, False (0.0) -> -1e9
1064
+ attn_mask_float = attn_mask.to(Q2.dtype)
1065
+ attn_mask = (attn_mask_float - 1.0) * 1e9
1066
+
1067
+ K_rep = K_sdpa.repeat_interleave(kv_repeat, dim=1)
1068
+ V_rep = V_sdpa.repeat_interleave(kv_repeat, dim=1)
1069
+ out_sdpa = torch.nn.functional.scaled_dot_product_attention(
1070
+ Q2,
1071
+ K_rep,
1072
+ V_rep,
1073
+ attn_mask=attn_mask,
1074
+ dropout_p=0.0,
1075
+ is_causal=False,
1076
+ enable_gqa=False,
1077
+ )
1078
+ if force_mode == "sdpa":
1079
+ return out_sdpa
1080
+
1081
+ # ========================================================================
1082
+ # CUDA Graph Compatible Version: Use fixed-size tensors with masking
1083
+ # instead of dynamic filtering operations (no stripe_loc[mask] indexing)
1084
+ # ========================================================================
1085
+
1086
+ num_stripes = max(
1087
+ int(
1088
+ max_stripe_index_for_token_pos(
1089
+ int(L_KV),
1090
+ search_power=search_power,
1091
+ inv_search_power_int=inv_search_power_int,
1092
+ )
1093
+ ),
1094
+ sw_index + 1,
1095
+ )
1096
+ stripe_idx = torch.arange(sw_index + 1, num_stripes + 1, device=device)
1097
+ # IMPORTANT: For StaticCache decode, K/V are full fixed-size buffers (L_KV=max_seq_len),
1098
+ # so we must anchor stripes relative to the *effective* prefix length (token_pos),
1099
+ # not the allocation length. For DynamicCache, token_pos == L_KV so this is unchanged.
1100
+ power_params = derive_stripe_power_params(
1101
+ search_power=search_power, inv_search_power_int=inv_search_power_int
1102
+ )
1103
+ if power_params.triton_inv_n != 0:
1104
+ stripe_floor_power = stripe_idx.to(torch.int64) ** int(power_params.triton_inv_n)
1105
+ else:
1106
+ stripe_floor_power = torch.floor(
1107
+ stripe_idx.to(torch.float32) ** float(power_params.inv_p)
1108
+ ).to(torch.int64)
1109
+ stripe_loc_all = token_pos - stripe_floor_power
1110
+
1111
+ # Create validity mask instead of filtering
1112
+ # Mask 1: stripe_loc >= 0
1113
+ valid_mask = stripe_loc_all >= 0
1114
+
1115
+ # Mask 2: attention_mask check (if provided)
1116
+ if attention_mask is not None:
1117
+ # Use safe indexing with clamping
1118
+ safe_indices = stripe_loc_all.clamp(min=0, max=L_KV-1)
1119
+ attn_valid = attention_mask[0, safe_indices].to(torch.bool)
1120
+ valid_mask = valid_mask & attn_valid
1121
+
1122
+ # Mask 3: no overlap with sliding window
1123
+ # Use tensor operations instead of .item() for CUDA graph compatibility
1124
+ decode_pos_tensor = cache_position[-1]
1125
+ sw_start_tensor = torch.clamp(decode_pos_tensor - (window_len - 1), min=0)
1126
+ sw_valid = stripe_loc_all < sw_start_tensor
1127
+ valid_mask = valid_mask & sw_valid
1128
+
1129
+ # Convert stripe_loc to int64 and mark invalid entries as -1
1130
+ stripe_loc = torch.where(valid_mask, stripe_loc_all, -1).to(torch.int64)
1131
+
1132
+ # Initialize topk_vals and topk_idx with proper dimensions
1133
+ topk_vals = torch.full((B, H_q, topk), float('-inf'), device=device, dtype=torch.float32)
1134
+ topk_idx = torch.full((B, H_q, topk), -1, device=device, dtype=torch.int64)
1135
+
1136
+ # CUDA Graph Compatible: Always compute logits, use masking to handle empty cases
1137
+ # No conditional branches based on .item() calls
1138
+ max_stripes = stripe_loc.shape[0]
1139
+
1140
+ # Use safe indexing with clamping to avoid out-of-bounds for invalid stripes
1141
+ safe_stripe_locs = stripe_loc.clamp(min=0, max=L_KV-1)
1142
+ K_stripe = K.index_select(2, safe_stripe_locs)
1143
+
1144
+ logits = torch.einsum(
1145
+ 'bhrld,bhsd->bhrls',
1146
+ Q1.detach().float().reshape(B, H_kv, kv_repeat, L_Q, D),
1147
+ K_stripe.detach().float()
1148
+ ).squeeze(3)
1149
+ logits = logits.reshape(B, H_q, -1)
1150
+
1151
+ # Set invalid stripe logits to -inf so they don't get selected in topk
1152
+ invalid_stripe_mask = ~valid_mask.unsqueeze(0).unsqueeze(0).expand(B, H_q, -1)
1153
+ logits = torch.where(invalid_stripe_mask, float('-inf'), logits)
1154
+
1155
+ # Top-k over all stripes (invalid ones have -inf logits)
1156
+ actual_k = min(topk, max_stripes)
1157
+ actual_topk_vals, actual_topk_idx = torch.topk(logits, k=actual_k, dim=-1)
1158
+
1159
+ topk_vals[:, :, :actual_k] = actual_topk_vals
1160
+ topk_idx[:, :, :actual_k] = actual_topk_idx
1161
+
1162
+ # Map indices to stripe locations using the full stripe_loc tensor
1163
+ # Invalid topk_idx (-1) will produce qanchor=-1
1164
+ qanchor = torch.full_like(topk_idx, -1, dtype=torch.int64)
1165
+ safe_idx = topk_idx.clamp(min=0, max=stripe_loc.shape[0]-1)
1166
+ valid_topk_mask = topk_idx >= 0
1167
+ selected_locs = stripe_loc[safe_idx]
1168
+ qanchor = torch.where(valid_topk_mask, selected_locs, qanchor)
1169
+
1170
+ # Use tensor operations instead of .item() for CUDA graph compatibility
1171
+ # Compute span_len using tensor operations
1172
+ span_len = torch.ceil(float(backward_factor) * (token_pos.float() ** float(span_power))).to(torch.int64)
1173
+ qstart = qanchor - span_len
1174
+ qstart = torch.clamp(qstart, min=0)
1175
+ # For invalid spans (qanchor=-1), set qstart to -1 as well
1176
+ qstart = torch.where(qanchor < 0, -1, qstart)
1177
+
1178
+ qstart = qstart.unsqueeze(2)
1179
+ qanchor = qanchor.unsqueeze(2)
1180
+
1181
+ # For decode, pass only the last cache position (shape [1]) to avoid broadcast mismatch
1182
+ # in _assert_no_span_sw_overlap which computes sw_start per query position
1183
+ decode_cache_pos = cache_position[-1:]
1184
+ qend = compute_qend_from_qanchor(
1185
+ qanchor,
1186
+ cache_position=decode_cache_pos,
1187
+ key_length=L_KV,
1188
+ sw_index=sw_index,
1189
+ attention_mask=attention_mask,
1190
+ forward_factor=forward_factor,
1191
+ span_power=span_power,
1192
+ search_power=search_power,
1193
+ inv_search_power_int=inv_search_power_int,
1194
+ )
1195
+ span_len_factor = backward_factor + forward_factor
1196
+ O_span = fused_span_attention_gqa(
1197
+ Q2,
1198
+ K,
1199
+ V,
1200
+ qstart,
1201
+ qend,
1202
+ decode_cache_pos,
1203
+ attention_mask=attention_mask,
1204
+ sw_index=sw_index,
1205
+ block_k=block_k,
1206
+ span_len_factor=span_len_factor,
1207
+ span_power=span_power,
1208
+ search_power=search_power,
1209
+ inv_search_power_int=inv_search_power_int,
1210
+ )
1211
+
1212
+ # Reuse topk_vals directly for gating (no need to recompute K[qend] @ Q1)
1213
+ span_values = topk_vals.unsqueeze(2) # [B, H_q, 1, topk]
1214
+ # Invalid spans already have -inf in topk_vals, but ensure consistency
1215
+ span_values = torch.where(qanchor < 0, float("-inf"), span_values)
1216
+
1217
+ span_scores = torch.nan_to_num(torch.softmax(span_values.float(), dim=-1), 1 / topk)
1218
+ span_scores = span_scores.to(O_span.dtype)
1219
+ O = (span_scores.unsqueeze(-1) * O_span).sum(dim=3)
1220
+
1221
+ if force_mode == "span":
1222
+ return O
1223
+
1224
+ # Select SDPA for short prefixes, span attention otherwise.
1225
+ use_sdpa_broadcast = use_sdpa.view(1, 1, 1, 1)
1226
+ return torch.where(use_sdpa_broadcast, out_sdpa, O)
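
Usage sketch (editorial note, not part of the wheel): the snippet below shows one plausible way to invoke the decode entry point defined in this file, assuming a CUDA GPU with Triton installed. The tensor sizes, dtypes, and the sw_index/topk values are illustrative assumptions, and the import goes straight to this private module rather than through the package's public span API (span/api.py, not shown in this hunk).

import torch

from superlinear.kernels.superlinear.span._triton_gqa import (
    decode_span_attention_staged_gqa_kernel,
)

# Illustrative sizes (assumed, not taken from the package docs):
# 8 query heads sharing 2 KV heads (kv_repeat = 4), head dim 64, 4096-token prefix.
B, H_q, H_kv, D, L_KV = 1, 8, 2, 64, 4096
device = "cuda"

Q1 = torch.randn(B, H_q, 1, D, device=device, dtype=torch.float16)  # search query
Q2 = torch.randn(B, H_q, 1, D, device=device, dtype=torch.float16)  # attention query
K = torch.randn(B, H_kv, L_KV, D, device=device, dtype=torch.float16)
V = torch.randn(B, H_kv, L_KV, D, device=device, dtype=torch.float16)
cache_position = torch.tensor([L_KV - 1], device=device)  # current decode position

with torch.no_grad():
    out = decode_span_attention_staged_gqa_kernel(
        Q1, Q2, K, V, cache_position,
        sw_index=2, topk=3, enable_gqa=True,
    )

print(out.shape)  # torch.Size([1, 8, 1, 64])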