sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77):
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,772 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
6
+
7
# Choose the reduction precision once, at import time, from the server
# arguments: fp32 when "attention_reduce_in_fp32" is set, fp16 otherwise.
REDUCE_TRITON_TYPE, REDUCE_TORCH_TYPE = (
    (tl.float32, torch.float32)
    if global_server_args_dict.get("attention_reduce_in_fp32", False)
    else (tl.float16, torch.float16)
)
13
+
14
+
15
@triton.jit
def tanh(x):
    # Evaluates tanh through the identity tanh(x) = 2 * sigmoid(2x) - 1.
    doubled = 2 * x
    return 2 * tl.sigmoid(doubled) - 1
19
+
20
+
21
@triton.jit
def _fwd_kernel_flash_decode_stage1(
    Q,
    K,
    V,
    sm_scale,
    Req_to_tokens,
    B_req_idx,
    B_Seqlen,
    Mid_O,  # [batch, head, seq_block_num, head_dim]
    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
    stride_req_to_tokens_b,
    stride_req_to_tokens_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_mid_od,
    stride_mid_o_eb,
    stride_mid_o_eh,
    stride_mid_o_es,
    gqa_group_size,
    BLOCK_SEQ: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Stage 1 of split-KV decode attention. The launch grid is
    # (batch, head, sequence block): each program attends one query vector over
    # at most BLOCK_SEQ cached tokens of one request and writes a locally
    # normalised partial output plus the log-sum-exp of its scores. The partial
    # results are combined by _fwd_kernel_flash_decode_stage2.
    #
    # Only the leading strides are used below. The innermost dimension of Q,
    # K/V, Mid_O, Mid_O_LogExpSum and Req_to_tokens is addressed with an
    # implicit unit stride, so stride_qd, stride_kd, stride_vd, stride_mid_od,
    # stride_mid_o_es and stride_req_to_tokens_s are never read.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    seq_start_block = tl.program_id(2)
    # Grouped-query attention: gqa_group_size query heads share one KV head.
    cur_kv_head = cur_head // gqa_group_size

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
    # Token range [start, end) owned by this program, clipped to the sequence.
    cur_batch_start_index = seq_start_block * BLOCK_SEQ
    cur_batch_end_index = tl.minimum(
        cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ
    )

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

    # Number of BLOCK_N-sized tiles in the range (ceil division); zero when the
    # range is empty, i.e. this sequence block lies past the end of the sequence.
    block_n_size = (
        tl.where(
            cur_batch_end_index - cur_batch_start_index <= 0,
            0,
            cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,
        )
        // BLOCK_N
    )

    offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)

    q = tl.load(Q + off_q)

    # Running state of the online softmax: sum of exponentials, running maximum
    # score, and the un-normalised weighted sum of V rows.
    sum_exp = 0.0
    max_logic = -float("inf")
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(0, block_n_size, 1):
        offs_n_new = start_n * BLOCK_N + offs_n
        # Translate sequence positions into KV-cache slots through the
        # request's row of Req_to_tokens.
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]
        k = tl.load(
            K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
        )
        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale
        # Out-of-range lanes get -inf so they contribute exp(-inf) = 0.
        att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf"))
        # NOTE(review): V is addressed with the offsets computed from the K
        # strides; stride_vbs/stride_vh are unused, so K and V presumably
        # share the same layout — confirm.
        v = tl.load(
            V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
        )

        cur_max_logic = tl.max(att_value, axis=0)
        new_max_logic = tl.maximum(cur_max_logic, max_logic)

        # Rescale what was accumulated so far to the new maximum before adding
        # this tile's contribution.
        exp_logic = tl.exp(att_value - new_max_logic)
        logic_scale = tl.exp(max_logic - new_max_logic)
        acc *= logic_scale
        acc += tl.sum(exp_logic[:, None] * v, axis=0)

        sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)
        max_logic = new_max_logic

    # The zero-or-one-iteration loop acts as a conditional: nothing is written
    # when this program had no tokens to process.
    need_store = tl.where(block_n_size == 0, 0, 1)
    for _ in range(0, need_store, 1):
        off_mid_o = (
            cur_batch * stride_mid_ob
            + cur_head * stride_mid_oh
            + seq_start_block * stride_mid_os
            + offs_d
        )
        off_mid_o_logexpsum = (
            cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block
        )
        # Store the locally normalised output and log(sum(exp(scores))).
        tl.store(Mid_O + off_mid_o, acc / sum_exp)
        tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))
    return
130
+
131
+
132
@triton.jit
def _fwd_kernel_flash_decode_stage2(
    B_Seqlen,
    Mid_O,  # [batch, head, seq_block_num, head_dim]
    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
    O,  # [batch, head, head_dim]
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_mid_od,
    stride_mid_o_eb,
    stride_mid_o_eh,
    stride_mid_o_es,
    stride_obs,
    stride_oh,
    stride_od,
    BLOCK_SEQ: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    # Stage 2 of split-KV decode attention. The launch grid is (batch, head):
    # each program combines the per-sequence-block partial outputs in Mid_O,
    # weighting them by their log-sum-exp values, into one output vector in O.
    #
    # The innermost dimensions are addressed with an implicit unit stride;
    # stride_mid_od, stride_mid_o_es and stride_od are never read.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    # Number of sequence blocks holding data for this request:
    # ceil(seq_len / BLOCK_SEQ), or zero for a non-positive length.
    block_n_size = (
        tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1)
        // BLOCK_SEQ
    )

    sum_exp = 0.0
    max_logic = -float("inf")
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh
    for block_seq_n in range(0, block_n_size, 1):
        tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)
        tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)
        new_max_logic = tl.maximum(tlogic, max_logic)

        # Same rescaling scheme as stage 1, with each block's log-sum-exp
        # playing the role of a single score.
        old_scale = tl.exp(max_logic - new_max_logic)
        acc *= old_scale
        exp_logic = tl.exp(tlogic - new_max_logic)
        acc += exp_logic * tv
        sum_exp = sum_exp * old_scale + exp_logic
        max_logic = new_max_logic

    # NOTE(review): when block_n_size is 0 the loop never runs and this stores
    # 0 / 0 — confirm callers never pass a non-positive sequence length.
    tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)
    return
182
+
183
+
184
@torch.no_grad()
def flash_decode_stage1(
    q,
    k,
    v,
    Req_to_tokens,
    B_req_idx,
    B_Seqlen,
    max_len_in_batch,
    mid_out,
    mid_out_logsumexp,
    block_seq,
):
    """Launch the stage-1 split-KV decode kernel.

    One program is launched per (request, query head, sequence block). The
    kernel writes its partial outputs into ``mid_out`` and the matching
    log-sum-exp values into ``mid_out_logsumexp``. The softmax scale is
    derived from the head dimension as ``1 / sqrt(head_dim)``.
    """
    kv_tile = 16  # keys handled per inner-loop iteration inside the kernel
    assert block_seq % kv_tile == 0

    # shape constraints
    q_dim = q.shape[-1]
    k_dim = k.shape[-1]
    assert q_dim == k_dim
    assert k_dim in {16, 32, 64, 128}

    num_q_heads = q.shape[1]
    launch_grid = (
        B_req_idx.shape[0],
        num_q_heads,
        triton.cdiv(max_len_in_batch, block_seq),
    )

    _fwd_kernel_flash_decode_stage1[launch_grid](
        q,
        k,
        v,
        1.0 / (k_dim**0.5),
        Req_to_tokens,
        B_req_idx,
        B_Seqlen,
        mid_out,
        mid_out_logsumexp,
        Req_to_tokens.stride(0),
        Req_to_tokens.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        mid_out.stride(0),
        mid_out.stride(1),
        mid_out.stride(2),
        mid_out.stride(3),
        mid_out_logsumexp.stride(0),
        mid_out_logsumexp.stride(1),
        mid_out_logsumexp.stride(2),
        num_q_heads // k.shape[1],  # query heads per KV head
        BLOCK_SEQ=block_seq,
        BLOCK_DMODEL=k_dim,
        BLOCK_N=kv_tile,
        num_warps=1,
        num_stages=2,
    )
245
+
246
+
247
@torch.no_grad()
def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
    """Launch the stage-2 kernel that merges the per-block partial outputs.

    One program is launched per (request, head). It reads ``mid_out`` and
    ``mid_out_logexpsum`` and writes the combined attention output into ``O``.
    """
    head_dim = mid_out.shape[-1]
    assert head_dim in {16, 32, 64, 128}

    launch_grid = (mid_out.shape[0], mid_out.shape[1])

    _fwd_kernel_flash_decode_stage2[launch_grid](
        B_Seqlen,
        mid_out,
        mid_out_logexpsum,
        O,
        mid_out.stride(0),
        mid_out.stride(1),
        mid_out.stride(2),
        mid_out.stride(3),
        mid_out_logexpsum.stride(0),
        mid_out_logexpsum.stride(1),
        mid_out_logexpsum.stride(2),
        O.stride(0),
        O.stride(1),
        O.stride(2),
        BLOCK_SEQ=block_seq,
        BLOCK_DMODEL=head_dim,
        num_warps=4,
        num_stages=2,
    )
275
+
276
+
277
+ import torch
278
+
279
+
280
def flash_decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_req_idx,
    b_start_loc,
    b_seq_len,
    attn_logits,
    max_len_in_batch,
    sm_scale,
    logit_cap=0.0,
):
    """Run dense two-stage split-KV decode attention and write the result to ``o``.

    Stage 1 computes one partial output per block of ``BLOCK_SEQ`` cached
    tokens; stage 2 merges the blocks into ``o``.

    ``b_start_loc``, ``attn_logits``, ``sm_scale`` and ``logit_cap`` are
    accepted but not used by this function: ``flash_decode_stage1`` derives
    the softmax scale from the head dimension itself.
    """
    BLOCK_SEQ = 256

    # Number of sequence blocks needed to cover the longest request.
    block_seq_num = (max_len_in_batch + BLOCK_SEQ - 1) // BLOCK_SEQ

    # Scratch buffers live on the query's device, not on whichever CUDA device
    # happens to be current, so the kernels never mix devices.
    mid_o = torch.empty(
        [q.shape[0], q.shape[1], block_seq_num, q.shape[-1]],
        dtype=torch.float32,
        device=q.device,
    )
    mid_o_logexpsum = torch.empty(
        [q.shape[0], q.shape[1], block_seq_num],
        dtype=torch.float32,
        device=q.device,
    )

    flash_decode_stage1(
        q,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        b_seq_len,
        max_len_in_batch,
        mid_o,
        mid_o_logexpsum,
        BLOCK_SEQ,
    )
    flash_decode_stage2(mid_o, mid_o_logexpsum, b_seq_len, o, BLOCK_SEQ)
322
+
323
+
324
@triton.jit
def _sparse_fwd_kernel_flash_decode_stage1(  # Double Sparsity's approximate attention
    Q_Label,
    K_Label_Buffer,
    sm_scale,
    Req_to_tokens,  # shape: [B, S]
    B_Seqlen,
    Att_Out,  # shape: [H, B, S] easier for topk
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    att_stride_h,
    att_stride_b,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    logit_cap: tl.constexpr,
):
    # Computes approximate attention scores from the "label" projections of
    # Q and K. The launch grid is (batch, head, key block): each program scores
    # BLOCK_N consecutive sequence positions of one request for one head and
    # writes them to Att_Out[head, batch, position].
    #
    # Req_to_tokens is indexed by the batch index directly (there is no
    # request-index indirection here), and the innermost dimension of every
    # tensor is addressed with an implicit unit stride.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_n = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    offs_d = tl.arange(0, BLOCK_DMODEL)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    cur_batch_start_index = 0
    cur_batch_end_index = cur_batch_seq_len

    # Default score for positions that are never computed: -inf, so they
    # sort last in a later top-k.
    min_val = -float("inf")
    att_value = tl.full([BLOCK_N], min_val, dtype=tl.float32)

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # block_mask is 1 when this key block starts inside the sequence, else 0.
    block_index = start_n * BLOCK_N
    block_mask = tl.where(block_index < cur_batch_seq_len, 1, 0)

    # Zero-or-one-iteration loop used as a conditional; start_mark is 0 in the
    # single iteration, so it does not shift the query offset.
    for start_mark in range(0, block_mask, 1):
        q = tl.load(Q_Label + off_q + start_mark).to(REDUCE_TRITON_TYPE)
        offs_n_new = cur_batch_start_index + offs_n
        # Map sequence positions to slots in the label cache.
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch + offs_n_new,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        offs_buf_k = (
            k_loc[:, None] * stride_buf_kbs
            + cur_kv_head * stride_buf_kh
            + offs_d[None, :]
        )
        k = tl.load(
            K_Label_Buffer + offs_buf_k,
            mask=offs_n_new[:, None] < cur_batch_end_index,
            other=0.0,
        ).to(REDUCE_TRITON_TYPE)

        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale

        # Optional soft capping of the logits: cap * tanh(score / cap).
        if logit_cap > 0:
            att_value = logit_cap * tanh(att_value / logit_cap)

    # Positions at or past the sequence end are forced to -inf.
    att_value = tl.where(offs_n < cur_batch_end_index, att_value, min_val)
    off_o = cur_head * att_stride_h + (cur_batch * att_stride_b + offs_n)
    # NOTE(review): the store is unmasked, so all BLOCK_N lanes are written even
    # when offs_n runs past the last sequence position of the Att_Out row —
    # confirm the row length is padded to a multiple of BLOCK_N or that the
    # spill-over is harmless.
    tl.store(Att_Out + off_o, att_value)
394
+
395
+
396
@triton.jit
def _sparse_fwd_kernel_flash_decode_stage2(
    Q,
    K,
    V,
    sm_scale,
    Req_to_tokens,  # shape: [B, S]
    Topk_token_indices,  # shape: [H, B, k]
    Mid_O,  # [batch, head, seq_block_num, head_dim]
    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
    Heavy_token_num,  # NOTE: This can be used as constexpr but we may support dynamic heavy token number in the future
    stride_req_to_tokens_b,
    stride_topk_token_indices_h,
    stride_topk_token_indices_b,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_mid_o_eb,
    stride_mid_o_eh,
    gqa_group_size,
    BLOCK_SEQ: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Exact attention restricted to the selected ("heavy") tokens. The launch
    # grid is (batch, head, block of selected tokens): each program attends
    # over at most BLOCK_SEQ entries of Topk_token_indices[head, batch, :] and
    # writes a locally normalised partial output plus its log-sum-exp.
    #
    # Req_to_tokens is indexed by the batch index directly, and innermost
    # dimensions are addressed with an implicit unit stride.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    seq_start_block = tl.program_id(2)
    cur_kv_head = cur_head // gqa_group_size

    offs_d = tl.arange(0, BLOCK_DMODEL)
    # Range [start, end) of positions within the top-k list owned by this
    # program, clipped to Heavy_token_num.
    cur_batch_start_index = seq_start_block * BLOCK_SEQ
    cur_batch_end_index = tl.minimum(Heavy_token_num, cur_batch_start_index + BLOCK_SEQ)

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

    # Tile count for the range; computed but not used by the loop below, which
    # iterates over the index range directly.
    block_n_size = (
        tl.where(
            cur_batch_end_index - cur_batch_start_index <= 0,
            0,
            cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,
        )
        // BLOCK_N
    )

    # offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)
    offs_n = tl.arange(0, BLOCK_N)

    q = tl.load(Q + off_q)

    # Running state of the online softmax.
    sum_exp = 0.0
    max_logic = -float("inf")
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    for start_n in range(cur_batch_start_index, cur_batch_end_index, BLOCK_N):
        # for start_n in range(0, block_n_size, 1):
        #     offs_n_new = start_n * BLOCK_N + offs_n
        offs_n_new = start_n + offs_n
        # offs_n_new = cur_batch_start_index + start_n * BLOCK_N + offs_n
        # Two-level gather: top-k slot -> sequence position -> KV-cache slot.
        topk_token_indices = tl.load(
            Topk_token_indices
            + stride_topk_token_indices_h * cur_head
            + stride_topk_token_indices_b * cur_batch
            + offs_n_new,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        k_loc = tl.load(
            Req_to_tokens + stride_req_to_tokens_b * cur_batch + topk_token_indices,
            mask=offs_n_new < cur_batch_end_index,
            other=0,
        )
        off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :]
        k = tl.load(
            K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
        )
        att_value = tl.sum(q[None, :] * k, 1)
        att_value *= sm_scale
        # Out-of-range lanes get -inf so they contribute exp(-inf) = 0.
        att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf"))
        # NOTE(review): V is addressed with the offsets computed from the K
        # strides; stride_vbs/stride_vh are unused, so K and V presumably
        # share the same layout — confirm.
        v = tl.load(
            V + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0
        )

        cur_max_logic = tl.max(att_value, axis=0)
        new_max_logic = tl.maximum(cur_max_logic, max_logic)

        # Rescale the running sums to the new maximum, then add this tile.
        exp_logic = tl.exp(att_value - new_max_logic)
        logic_scale = tl.exp(max_logic - new_max_logic)
        acc *= logic_scale
        acc += tl.sum(exp_logic[:, None] * v, axis=0)

        sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0)
        max_logic = new_max_logic

    # need_store = tl.where(block_n_size == 0, 0, 1)
    # The result is always written here (need_store is the constant 1).
    need_store = 1
    for _ in range(0, need_store, 1):
        off_mid_o = (
            cur_batch * stride_mid_ob
            + cur_head * stride_mid_oh
            + seq_start_block * stride_mid_os
            + offs_d
        )
        off_mid_o_logexpsum = (
            cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block
        )
        tl.store(Mid_O + off_mid_o, acc / sum_exp)
        tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp))
    return
510
+
511
+
512
@triton.jit
def _sparse_fwd_kernel_flash_decode_stage3(
    Mid_O,  # [batch, head, seq_block_num, head_dim]
    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
    O,  # [batch, head, head_dim]
    seq_len,  # NOTE: This can be used as constexpr but we may support dynamic heavy token number in the future
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_mid_o_eb,
    stride_mid_o_eh,
    stride_obs,
    stride_oh,
    BLOCK_SEQ: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    # Merges the per-block partial outputs of the sparse stage-2 kernel into
    # the final output. The launch grid is (batch, head). Unlike the dense
    # stage-2 kernel, seq_len is one scalar shared by every request rather
    # than a per-request length loaded from memory.
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    offs_d = tl.arange(0, BLOCK_DMODEL)

    # ceil(seq_len / BLOCK_SEQ) blocks to merge, or zero if seq_len <= 0.
    block_n_size = tl.where(seq_len <= 0, 0, seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ

    sum_exp = 0.0
    max_logic = -float("inf")
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh
    for block_seq_n in range(0, block_n_size, 1):
        tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)
        tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)
        new_max_logic = tl.maximum(tlogic, max_logic)

        # Weight each block's output by exp(its log-sum-exp - running max).
        old_scale = tl.exp(max_logic - new_max_logic)
        acc *= old_scale
        exp_logic = tl.exp(tlogic - new_max_logic)
        acc += exp_logic * tv
        sum_exp = sum_exp * old_scale + exp_logic
        max_logic = new_max_logic

    tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)
    return
555
+
556
+
557
@torch.no_grad()
def sparse_flash_decode_stage1(
    q_label,
    k_label_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    max_len_in_batch,
    sm_scale,
    logit_cap,
):
    """Launch the approximate-attention kernel of Double Sparsity.

    Scores every cached position of every request against the label
    projection of its query and writes the scores into ``att_out``. One
    program is launched per (request, head, block of ``BLOCK`` positions).

    Raises:
        AssertionError: if the label dimensions of ``q_label`` and
            ``k_label_buffer`` differ or are not a supported power of two.
    """
    BLOCK = 32
    # shape constraints
    Lq, Lk = q_label.shape[-1], k_label_buffer.shape[-1]
    assert Lq == Lk
    # The kernel builds its channel offsets with tl.arange(0, BLOCK_DMODEL),
    # which only accepts a power-of-two extent. 576 used to be allowed here but
    # could never compile, so it is rejected up front with a clear message.
    assert Lk in {16, 32, 64, 128, 256}, (
        f"unsupported label dimension {Lk}: must be a power of two in [16, 256]"
    )

    BLOCK_DMODEL = Lk

    batch, head_num = q_label.shape[0], q_label.shape[1]

    grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK))
    kv_group_num = q_label.shape[1] // k_label_buffer.shape[1]

    # Fewer warps are used when several query heads share one KV head.
    if kv_group_num == 1:
        num_warps = 4
    else:
        num_warps = 2

    _sparse_fwd_kernel_flash_decode_stage1[grid](
        q_label,
        k_label_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q_label.stride(0),
        q_label.stride(1),
        k_label_buffer.stride(0),
        k_label_buffer.stride(1),
        att_out.stride(0),
        att_out.stride(1),
        kv_group_num,
        BLOCK_DMODEL,
        BLOCK,
        logit_cap,
        num_warps=num_warps,
        num_stages=1,
    )
606
+
607
+
608
@torch.no_grad()
def sparse_flash_decode_stage2(
    q,
    k,
    v,
    Req_to_tokens,
    Topk_token_indices,
    heavy_token_num,
    mid_out,
    mid_out_logsumexp,
    block_seq,
    sm_scale,
):
    """Launch exact attention over the selected (top-k) tokens only.

    One program is launched per (request, head, block of ``block_seq``
    selected tokens). Partial outputs go to ``mid_out`` and their log-sum-exp
    values to ``mid_out_logsumexp``. The softmax scale is taken from the
    caller rather than derived from the head dimension.
    """
    kv_tile = 16  # selected tokens handled per inner-loop iteration
    assert block_seq % kv_tile == 0

    # shape constraints
    q_dim = q.shape[-1]
    k_dim = k.shape[-1]
    assert q_dim == k_dim
    assert k_dim in {16, 32, 64, 128}
    assert heavy_token_num == Topk_token_indices.shape[-1]

    num_q_heads = q.shape[1]
    launch_grid = (
        q.shape[0],
        num_q_heads,
        triton.cdiv(heavy_token_num, block_seq),
    )

    _sparse_fwd_kernel_flash_decode_stage2[launch_grid](
        q,
        k,
        v,
        sm_scale,
        Req_to_tokens,
        Topk_token_indices,
        mid_out,
        mid_out_logsumexp,
        heavy_token_num,
        Req_to_tokens.stride(0),
        Topk_token_indices.stride(0),
        Topk_token_indices.stride(1),
        q.stride(0),
        q.stride(1),
        k.stride(0),
        k.stride(1),
        v.stride(0),
        v.stride(1),
        mid_out.stride(0),
        mid_out.stride(1),
        mid_out.stride(2),
        mid_out_logsumexp.stride(0),
        mid_out_logsumexp.stride(1),
        num_q_heads // k.shape[1],  # query heads per KV head
        BLOCK_SEQ=block_seq,
        BLOCK_DMODEL=k_dim,
        BLOCK_N=kv_tile,
        num_warps=1,
        num_stages=2,
    )
667
+
668
+
669
@torch.no_grad()
def sparse_flash_decode_stage3(Seqlen, mid_out, mid_out_logexpsum, O, block_seq):
    """Launch the kernel that merges the sparse partial outputs into ``O``.

    One program is launched per (request, head). ``Seqlen`` is handed to the
    kernel as-is and determines how many blocks of ``block_seq`` are merged.
    """
    head_dim = mid_out.shape[-1]
    assert head_dim in {16, 32, 64, 128}

    launch_grid = (mid_out.shape[0], mid_out.shape[1])

    _sparse_fwd_kernel_flash_decode_stage3[launch_grid](
        mid_out,
        mid_out_logexpsum,
        O,
        Seqlen,
        mid_out.stride(0),
        mid_out.stride(1),
        mid_out.stride(2),
        mid_out_logexpsum.stride(0),
        mid_out_logexpsum.stride(1),
        O.stride(0),
        O.stride(1),
        BLOCK_SEQ=block_seq,
        BLOCK_DMODEL=head_dim,
        num_warps=4,
        num_stages=2,
    )
694
+
695
+
696
def flash_decode_sparse_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    q_label,
    k_label_buffer,
    req_to_token,
    b_seq_len,
    max_len_in_batch,
    sm_scale,
    logit_cap,
    heavy_token_num=32,
    att_out_approx=None,
    mid_out=None,
    mid_o_logexpsum=None,
    BLOCK_SEQ=256,
):
    """Run Double Sparsity decode attention and write the result to ``o``.

    The computation has three steps:

    1. Approximate scores for every cached position are computed from the
       label projections ``q_label`` / ``k_label_buffer``.
    2. The ``heavy_token_num`` highest-scoring positions per (head, request)
       are selected with ``torch.topk``.
    3. Exact attention is evaluated over the selected positions only and
       merged into ``o``.

    ``att_out_approx``, ``mid_out`` and ``mid_o_logexpsum`` are optional
    scratch buffers. Each one is allocated on ``q.device`` when it is not
    supplied, independently of the others, with these shapes:

    * ``att_out_approx``: ``[q.shape[1], q.shape[0], max_len_in_batch]``
    * ``mid_out``: ``[q.shape[0], q.shape[1], block_seq_num, q.shape[-1]]``
    * ``mid_o_logexpsum``: ``[q.shape[0], q.shape[1], block_seq_num]``

    where ``block_seq_num = ceil(heavy_token_num / BLOCK_SEQ)``.
    """
    # TODO(Andy): Tune BLOCK_SEQ & BLOCK_D

    # Step 1: BGEMV approximate attention (page implementation)
    if att_out_approx is None:
        att_out_approx = torch.empty(
            [q.shape[1], q.shape[0], max_len_in_batch],
            dtype=REDUCE_TORCH_TYPE,
            device=q.device,
        )

    # Each scratch buffer is checked on its own, so a caller may pre-allocate
    # either one without the other being left as None or silently replaced.
    block_seq_num = (heavy_token_num + BLOCK_SEQ - 1) // BLOCK_SEQ
    if mid_out is None:
        mid_out = torch.empty(
            [q.shape[0], q.shape[1], block_seq_num, q.shape[-1]],
            dtype=torch.float32,
            device=q.device,
        )
    if mid_o_logexpsum is None:
        mid_o_logexpsum = torch.empty(
            [q.shape[0], q.shape[1], block_seq_num],
            dtype=torch.float32,
            device=q.device,
        )

    sparse_flash_decode_stage1(
        q_label,
        k_label_buffer,
        att_out_approx,
        req_to_token,
        b_seq_len,
        max_len_in_batch,
        sm_scale,
        logit_cap,
    )

    # Step 2: TopK token selection
    # NOTE(Andy): Apply sparse decoding when min > heavy_token_num and max > sparse decoding threshold
    # TODO(Andy): Change a faster topk implementation
    # topk_token_indices: [H, B, k], Req_to_tokens: [B, S]
    topk_token_indices = torch.topk(att_out_approx, heavy_token_num, dim=-1).indices

    # Step 3: exact attention over the selected tokens, then merge the blocks.
    sparse_flash_decode_stage2(
        q,
        k_buffer,
        v_buffer,
        req_to_token,
        topk_token_indices,
        heavy_token_num,
        mid_out,
        mid_o_logexpsum,
        BLOCK_SEQ,
        sm_scale,
    )

    sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)