sglang 0.4.2.post1__py3-none-any.whl → 0.4.2.post3__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
Files changed (78)
  1. sglang/srt/constrained/outlines_backend.py +9 -1
  2. sglang/srt/custom_op.py +40 -0
  3. sglang/srt/entrypoints/engine.py +2 -2
  4. sglang/srt/function_call_parser.py +96 -69
  5. sglang/srt/layers/activation.py +10 -5
  6. sglang/srt/layers/attention/double_sparsity_backend.py +1 -3
  7. sglang/srt/layers/attention/flashinfer_backend.py +284 -39
  8. sglang/srt/layers/attention/triton_backend.py +124 -12
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +53 -59
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +337 -3
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +70 -42
  12. sglang/srt/layers/layernorm.py +1 -5
  13. sglang/srt/layers/moe/ep_moe/layer.py +1 -3
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +200 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +200 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +178 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +200 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +175 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -13
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -3
  24. sglang/srt/layers/moe/topk.py +4 -0
  25. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  33. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  46. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/fp8_kernel.py +173 -2
  48. sglang/srt/layers/rotary_embedding.py +1 -3
  49. sglang/srt/layers/sampler.py +4 -4
  50. sglang/srt/lora/backend/__init__.py +8 -0
  51. sglang/srt/lora/backend/base_backend.py +95 -0
  52. sglang/srt/lora/backend/flashinfer_backend.py +91 -0
  53. sglang/srt/lora/backend/triton_backend.py +61 -0
  54. sglang/srt/lora/lora.py +127 -112
  55. sglang/srt/lora/lora_manager.py +50 -18
  56. sglang/srt/lora/triton_ops/__init__.py +5 -0
  57. sglang/srt/lora/triton_ops/qkv_lora_b.py +182 -0
  58. sglang/srt/lora/triton_ops/sgemm_lora_a.py +143 -0
  59. sglang/srt/lora/triton_ops/sgemm_lora_b.py +159 -0
  60. sglang/srt/model_executor/cuda_graph_runner.py +77 -80
  61. sglang/srt/model_executor/forward_batch_info.py +58 -59
  62. sglang/srt/model_executor/model_runner.py +2 -2
  63. sglang/srt/models/llama.py +8 -3
  64. sglang/srt/models/qwen2_vl.py +1 -1
  65. sglang/srt/server_args.py +13 -2
  66. sglang/srt/speculative/build_eagle_tree.py +486 -104
  67. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +213 -0
  68. sglang/srt/speculative/eagle_utils.py +420 -401
  69. sglang/srt/speculative/eagle_worker.py +177 -45
  70. sglang/srt/utils.py +7 -0
  71. sglang/test/runners.py +2 -0
  72. sglang/version.py +1 -1
  73. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/METADATA +15 -6
  74. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/RECORD +77 -38
  75. sglang/srt/layers/custom_op_util.py +0 -25
  76. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.2.post1.dist-info → sglang-0.4.2.post3.dist-info}/top_level.txt +0 -0
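
sglang/srt/layers/attention/triton_ops/decode_attention.py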
@@ -49,11 +49,9 @@ def _fwd_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -82,8 +80,9 @@ def _fwd_kernel_stage1(
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
     q = tl.load(Q + off_q, mask=mask_d, other=0.0)
@@ -100,7 +99,7 @@ def _fwd_kernel_stage1(
     for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
         kv_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            kv_indices + cur_batch_kv_start_idx + offs_n,
             mask=offs_n < split_kv_end,
             other=0,
         )
@@ -173,19 +172,21 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     num_kv_splits,
     sm_scale,
     logit_cap,
 ):
     BLOCK = 64
+    # [TODO] work around SGPR limit on MI3xx
+    if is_hip_:
+        BLOCK = 8
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
-    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
 
     grid = (batch, head_num, NUM_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -194,6 +195,8 @@ def _decode_att_m_fwd(
         num_warps = 4
     else:
         num_warps = 2
+        if is_hip_:
+            num_warps = 1
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
     BLOCK_DV = triton.next_power_of_2(Lv)
@@ -203,11 +206,9 @@ def _decode_att_m_fwd(
         k_buffer,
         v_buffer,
         sm_scale,
-        Req_to_tokens,
-        B_req_idx,
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -236,11 +237,9 @@ def _fwd_grouped_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -279,8 +278,9 @@ def _fwd_grouped_kernel_stage1(
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
 
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
     q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
@@ -307,7 +307,7 @@ def _fwd_grouped_kernel_stage1(
     for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
         offs_n = start_n + tl.arange(0, BLOCK_N)
         kv_loc = tl.load(
-            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+            kv_indices + cur_batch_kv_start_idx + offs_n,
             mask=offs_n < split_kv_end,
             other=0,
         )
@@ -395,9 +395,8 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -421,7 +420,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = 16
@@ -433,21 +432,21 @@ def _decode_grouped_att_m_fwd(
     )
 
     extra_kargs = {}
+    num_stages = 2
     if is_hip_:
         # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
-        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
+        num_stages = 1
 
     _fwd_grouped_kernel_stage1[grid](
         q,
         k_buffer,
         v_buffer,
         sm_scale,
-        Req_to_tokens,
-        B_req_idx,
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
        q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -467,7 +466,7 @@ def _decode_grouped_att_m_fwd(
         NUM_KV_SPLITS=NUM_KV_SPLITS,
         logit_cap=logit_cap,
         num_warps=4,
-        num_stages=2,
+        num_stages=num_stages,
         Lk=Lk,
         Lv=Lv,
         **extra_kargs,
@@ -478,7 +477,7 @@ def _decode_grouped_att_m_fwd(
 def _fwd_kernel_stage2(
     Mid_O,
     O,
-    B_Seqlen,
+    kv_indptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -491,7 +490,9 @@ def _fwd_kernel_stage2(
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
 
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
+        kv_indptr + cur_batch
+    )
 
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -535,7 +536,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
-    b_seq_len,
+    kv_indptr,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
@@ -554,7 +555,7 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
-        b_seq_len,
+        kv_indptr,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -574,9 +575,8 @@ def decode_attention_fwd_normal(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -587,14 +587,13 @@ def decode_attention_fwd_normal(
         k_buffer,
         v_buffer,
         attn_logits,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
 
 
 def decode_attention_fwd_grouped(
@@ -602,9 +601,8 @@ def decode_attention_fwd_grouped(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -615,14 +613,13 @@ def decode_attention_fwd_grouped(
         k_buffer,
         v_buffer,
         attn_logits,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
 
 
 def decode_attention_fwd(
@@ -630,9 +627,8 @@ def decode_attention_fwd(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -648,9 +644,8 @@ def decode_attention_fwd(
             k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
@@ -663,9 +658,8 @@ def decode_attention_fwd(
            k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
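
Taken together, the decode_attention.py hunks above replace the per-request page-table lookup (Req_to_tokens, B_req_idx, B_Seqlen) with a flattened CSR-style pair (kv_indptr, kv_indices): kv_indptr holds batch_size + 1 offsets (which is why the launchers now recover the batch size as kv_indptr.shape[0] - 1), and kv_indices concatenates every request's KV-cache slot indices. A minimal host-side sketch of that relationship, assuming req_to_token is a [num_requests, max_context_len] integer page table and b_req_idx / b_seq_len are 1-D per-batch tensors; the helper name below is illustrative and not part of the diff:

    import torch

    def build_kv_indptr_indices(req_to_token, b_req_idx, b_seq_len):
        """Hypothetical helper: flatten the old layout into the new CSR-style one."""
        batch_size = b_seq_len.shape[0]
        device = b_seq_len.device

        # kv_indptr[i] is where request i's KV slot indices start inside kv_indices;
        # it has batch_size + 1 entries, matching the kernels' use of
        # kv_indptr.shape[0] - 1 to recover the batch size.
        kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
        kv_indptr[1:] = torch.cumsum(b_seq_len, dim=0)

        # kv_indices concatenates, per request, the token slots that used to be
        # read as Req_to_tokens[b_req_idx[i], :b_seq_len[i]].
        kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32, device=device)
        for i in range(batch_size):
            start, end = int(kv_indptr[i]), int(kv_indptr[i + 1])
            kv_indices[start:end] = req_to_token[b_req_idx[i], : end - start]
        return kv_indptr, kv_indices

Inside the kernels the same layout is read back as cur_batch_kv_start_idx = kv_indptr[cur_batch] and cur_batch_seq_len = kv_indptr[cur_batch + 1] - cur_batch_kv_start_idx, so the slice kv_indices[start : start + seq_len] carries exactly what Req_to_tokens[B_req_idx[cur_batch], :seq_len] used to.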
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py

@@ -3,6 +3,13 @@ import triton
 import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import is_hip
+
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
+
+is_hip_ = is_hip()
 
 if global_server_args_dict.get("attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
@@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):
     return
 
 
-import torch
-
-
 def flash_decode_attention_fwd(
     q,
     k_buffer,
@@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd(
     )
 
     sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ)
+
+
+# Extend attention kernel for Double Sparsity
+# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+@triton.jit
+def _fwd_kernel(
+    Q_Extend,
+    K_Extend,
+    V_Extend,
+    O_Extend,
+    K_Buffer,
+    V_Buffer,
+    Req_to_tokens,
+    B_req_idx,
+    B_Seq_Len,
+    B_Start_Loc_Extend,
+    B_Seq_Len_Extend,
+    sm_scale,
+    kv_group_num,
+    stride_qbs,
+    stride_qh,
+    stride_kbs,
+    stride_kh,
+    stride_vbs,
+    stride_vh,
+    stride_obs,
+    stride_oh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_req_to_tokens_b,
+    logit_cap: tl.constexpr,
+    Lq: tl.constexpr,
+    Lv: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    cur_seq = tl.program_id(0)
+    cur_head = tl.program_id(1)
+    cur_block_m = tl.program_id(2)
+    cur_kv_head = cur_head // kv_group_num
+
+    cur_seq_len = tl.load(B_Seq_Len + cur_seq)
+    cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq)
+    cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend
+
+    cur_seq_prefix_start_in_loc = 0
+    cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
+    offs_m = tl.arange(0, BLOCK_M)
+    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
+
+    mask_d = offs_d < Lq
+    mask_dv = offs_dv < Lv
+
+    offs_q = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_qbs
+        + cur_head * stride_qh
+        + offs_d[None, :]
+    )
+    q = tl.load(
+        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
+    )
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
+    # stage 1: compute scores with prefix
+    offs_n = tl.arange(0, BLOCK_N)
+
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
+    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
+    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+
+    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_seq_len_prefix
+        offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + (
+            cur_seq_prefix_start_in_loc + start_n + offs_n
+        )
+        offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0)
+
+        # load k in transposed way
+        offs_buf_k = (
+            offs_kv_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q.to(k.dtype), k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_buf_v = (
+            offs_kv_loc[:, None] * stride_buf_vbs
+            + cur_kv_head * stride_buf_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    # stage 2: compute the trianlge part
+
+    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    for start_n in range(0, cur_block_m_end, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        mask_n = (start_n + offs_n) < cur_block_m_end
+
+        # load k in transposed way
+        offs_k = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs
+            + cur_kv_head * stride_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+        )
+
+        qk = tl.dot(q, k, out_dtype=tl.float32)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
+
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
+            start_n + offs_n[None, :]
+        )
+        mask_causual &= mask_m[:, None] & mask_n[None, :]
+        qk = tl.where(mask_causual, qk, float("-inf"))
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        re_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        deno = deno * re_scale + tl.sum(p, 1)
+
+        offs_v = (
+            (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+            + cur_kv_head * stride_vh
+            + offs_dv[None, :]
+        )
+        v = tl.load(
+            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+        )
+        p = p.to(v.dtype)
+        acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+        e_max = n_e_max
+
+    offs_o = (
+        (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+        * stride_obs
+        + cur_head * stride_oh
+        + offs_dv[None, :]
+    )
+    tl.store(
+        O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
+    )
+
+
+def extend_attention_fwd(
+    q_extend,
+    k_extend,
+    v_extend,
+    o_extend,
+    k_buffer,
+    v_buffer,
+    req_to_tokens,
+    b_req_idx,
+    b_seq_len,
+    b_seq_len_extend,
+    b_start_loc_extend,
+    max_len_extend,
+    sm_scale=None,
+    logit_cap=0.0,
+):
+    """
+    q_extend, k_extend, v_extend, o_extend: contiguous tensors
+
+    k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
+    """
+    Lq, Lk, Lv = (
+        q_extend.shape[-1],
+        k_extend.shape[-1],
+        v_extend.shape[-1],
+    )
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    elif Lq == 288:
+        BLOCK_DMODEL = 256
+        BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = triton.next_power_of_2(Lq)
+        BLOCK_DPE = 0
+    BLOCK_DV = triton.next_power_of_2(Lv)
+
+    if is_hip_:
+        BLOCK_M, BLOCK_N = (64, 64)
+        num_warps = 4
+
+    else:
+        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+            if Lq <= 256:
+                BLOCK_M, BLOCK_N = (128, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+            if Lq <= 128:
+                BLOCK_M, BLOCK_N = (128, 128)
+            elif Lq <= 256:
+                BLOCK_M, BLOCK_N = (64, 64)
+            else:
+                BLOCK_M, BLOCK_N = (32, 64)
+        else:
+            BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
+
+        num_warps = 4 if Lk <= 64 else 8
+
+    sm_scale = sm_scale or 1.0 / (Lq**0.5)
+    batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
+    kv_group_num = q_extend.shape[1] // k_extend.shape[1]
+
+    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
+    num_stages = 1
+
+    extra_kargs = {}
+    if is_hip_:
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
+    _fwd_kernel[grid](
+        q_extend,
+        k_extend,
+        v_extend,
+        o_extend,
+        k_buffer,
+        v_buffer,
+        req_to_tokens,
+        b_req_idx,
+        b_seq_len,
+        b_start_loc_extend,
+        b_seq_len_extend,
+        sm_scale,
+        kv_group_num,
+        q_extend.stride(0),
+        q_extend.stride(1),
+        k_extend.stride(0),
+        k_extend.stride(1),
+        v_extend.stride(0),
+        v_extend.stride(1),
+        o_extend.stride(0),
+        o_extend.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        req_to_tokens.stride(0),
+        logit_cap=logit_cap,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
+        BLOCK_M=BLOCK_M,
+        BLOCK_N=BLOCK_N,
+        Lq=Lq,
+        Lv=Lv,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        **extra_kargs,
+    )
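
For context, a hedged sketch of how the relocated extend_attention_fwd might be driven, with the kernel from the hunk above in scope. The packed [num_tokens, num_heads, head_dim] layout matches the strides the launcher passes, but every size, dtype, and the contiguous slot assignment below are illustrative assumptions rather than values from the diff:

    import torch

    # Illustrative sizes (assumptions, not from the diff).
    batch_size, num_heads, num_kv_heads, head_dim = 2, 8, 1, 128
    prefix_lens = torch.tensor([16, 32], device="cuda")
    extend_lens = torch.tensor([4, 8], device="cuda")

    b_seq_len = prefix_lens + extend_lens                             # prefix + extend tokens per request
    b_seq_len_extend = extend_lens
    b_start_loc_extend = torch.cumsum(extend_lens, 0) - extend_lens   # packed offsets of each request
    b_req_idx = torch.arange(batch_size, device="cuda")
    max_len_extend = int(extend_lens.max())

    total_extend = int(extend_lens.sum())
    pool_size = int(b_seq_len.sum())
    dtype = torch.float16
    q_extend = torch.randn(total_extend, num_heads, head_dim, device="cuda", dtype=dtype)
    k_extend = torch.randn(total_extend, num_kv_heads, head_dim, device="cuda", dtype=dtype)
    v_extend = torch.randn_like(k_extend)
    o_extend = torch.empty_like(q_extend)
    k_buffer = torch.randn(pool_size, num_kv_heads, head_dim, device="cuda", dtype=dtype)
    v_buffer = torch.randn_like(k_buffer)

    # Page table: in this toy setup request i simply occupies contiguous pool slots.
    req_to_tokens = torch.zeros(batch_size, int(b_seq_len.max()), dtype=torch.int32, device="cuda")
    offset = 0
    for i in range(batch_size):
        n = int(b_seq_len[i])
        req_to_tokens[i, :n] = torch.arange(offset, offset + n, dtype=torch.int32, device="cuda")
        offset += n

    extend_attention_fwd(
        q_extend, k_extend, v_extend, o_extend,
        k_buffer, v_buffer,
        req_to_tokens, b_req_idx, b_seq_len,
        b_seq_len_extend, b_start_loc_extend, max_len_extend,
    )

Since logit_cap keeps its default of 0.0 here, the tanh capping branch inside the kernel is skipped.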