sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
The excerpted hunks below cover the Triton attention backend and its kernels (items 25, 26, and 27 in the list above).

sglang/srt/layers/attention/triton_backend.py
@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
         self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
                         forward_batch.req_pool_indices,
                         bs,
                         self.device,
+                        self.token_to_kv_pool_allocator,
                     )
                 )
                 window_num_kv_splits = torch.empty(
@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             max_extend_len = None
         elif forward_batch.forward_mode.is_target_verify():
-            # TODO: Support sliding window in spec inference
             bs = len(forward_batch.req_pool_indices)
             qo_indptr = torch.arange(
                 0,
@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )

+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indptr, window_kv_indices, window_kv_lens = (
+                    update_sliding_window_buffer(
+                        self.window_kv_indptr,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        forward_batch.seq_lens,
+                        forward_batch.req_pool_indices,
+                        bs,
+                        self.device,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
+
             custom_mask = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (
                 forward_batch.seq_lens + self.num_draft_tokens
@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
                     forward_batch.req_pool_indices,
                     bs,
                     self.device,
+                    self.token_to_kv_pool_allocator,
                 )

             qo_indptr = self.qo_indptr
@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_kv_indices = self.cuda_graph_window_kv_indices
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
-                    self.window_kv_indptr,
-                    window_kv_indices,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    seq_lens[:bs],
-                    req_pool_indices,
-                    bs,
+                window_kv_indptr, window_kv_indices, _ = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )
             else:
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )

+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indices = self.cuda_graph_window_kv_indices
+                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                window_kv_indptr, window_kv_indices, _ = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens,
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
+
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                 window_kv_indices = self.cuda_graph_window_kv_indices
-                _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
                     self.window_kv_indptr,
                     window_kv_indices,
                     self.req_to_token,
@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
                     seq_lens[:bs],
                     req_pool_indices[:bs],
                     bs,
+                    self.token_to_kv_pool_allocator,
                 )
                 self.get_num_kv_splits(
                     window_num_kv_splits[:num_token], window_kv_lens[:bs]
@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indices,
                 self.req_to_token.stride(0),
             )
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                window_kv_indices = self.cuda_graph_window_kv_indices
+                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                    self.window_kv_indptr,
+                    window_kv_indices,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    seq_lens,
+                    req_pool_indices,
+                    bs,
+                    self.token_to_kv_pool_allocator,
+                )
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        sinks=None,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
-            sliding_window_size,
+            sliding_window_size=sliding_window_size,
+            sinks=sinks,
         )
         return o

@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        sinks=None,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
             self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
+            sinks=sinks,
        )
        return o

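The `sinks` parameter added to `forward_extend` and `forward_decode` above carries per-head attention-sink logits into the Triton kernels (this release also adds gpt-oss support, which uses attention sinks). As the kernel hunks further down show, a sink only joins the softmax normalization; it has no value row. A minimal dense-attention sketch of that semantics, with illustrative shapes and a hypothetical helper name, not the kernel's tiled layout:

```python
import torch

# Hedged sketch: a per-head sink logit enlarges the softmax denominator but
# contributes no value vector, equivalent to softmaxing over [logits, sink]
# and dropping the sink column afterwards.
def attention_with_sink(q, k, v, sink, scale):
    # q: [heads, d], k/v: [heads, n, d], sink: [heads]
    logits = torch.einsum("hd,hnd->hn", q, k) * scale
    logits = torch.cat([logits, sink[:, None]], dim=-1)  # append the sink logit
    probs = torch.softmax(logits, dim=-1)[:, :-1]        # drop the sink column
    return torch.einsum("hn,hnd->hd", probs, v)
```

Dropping the sink column after the softmax reproduces the kernels' `e_sum += tl.exp(cur_sink - e_max)` adjustment: the denominator grows while the accumulator is untouched.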
@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
     req_pool_indices,
     bs,
     device,
+    token_to_kv_pool_allocator=None,
 ):
     window_kv_lens = torch.minimum(
         seq_lens,
-        torch.tensor(sliding_window_size + 1),
+        torch.tensor(sliding_window_size),
     )
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]
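Behavioral note on the hunk above: the per-request window length is now capped at `sliding_window_size` rather than `sliding_window_size + 1`, so the window holds one fewer KV entry. A quick check of the arithmetic with illustrative values:

```python
import torch

seq_lens = torch.tensor([2, 4096, 10000])
sliding_window_size = 4096

old_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))
new_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))
print(old_lens.tolist(), new_lens.tolist())  # [2, 4096, 4097] vs [2, 4096, 4096]
```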
@@ -952,6 +1006,14 @@
         window_kv_indices,
         req_to_token.stride(0),
     )
+    # full to swa index mapping
+    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+        kv_last_index = window_kv_indptr[-1]
+        window_kv_indices[:kv_last_index] = (
+            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                window_kv_indices[:kv_last_index]
+            )
+        )
     return window_kv_indptr, window_kv_indices, window_kv_lens


@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
     seq_lens,
     req_pool_indices,
     bs,
+    token_to_kv_pool_allocator=None,
 ):
     window_kv_lens = torch.minimum(
         seq_lens,
-        torch.tensor(sliding_window_size + 1),
+        torch.tensor(sliding_window_size),
    )
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -980,4 +1043,12 @@
         window_kv_indices,
         req_to_token.stride(0),
     )
-    return window_kv_indptr, window_kv_lens
+    # full to swa index mapping
+    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+        kv_last_index = window_kv_indptr[-1]
+        window_kv_indices[:kv_last_index] = (
+            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                window_kv_indices[:kv_last_index]
+            )
+        )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
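Both helpers now optionally remap indices through the allocator: `window_kv_indices` is first gathered from the full KV pool, then translated into sliding-window-pool slots. The hook is duck-typed, so only allocators that expose `translate_loc_from_full_to_swa` (i.e. allocators backed by a separate sliding-window KV pool) trigger the remap. A toy stand-in, assuming the mapping is a plain lookup table; the class name is hypothetical, not sglang's real allocator:

```python
import torch

class ToySWAAllocator:
    """Hypothetical allocator satisfying the hasattr() hook above."""

    def __init__(self, full_to_swa: torch.Tensor):
        # index = slot in the full KV pool, value = slot in the SWA pool
        self._full_to_swa = full_to_swa

    def translate_loc_from_full_to_swa(self, loc: torch.Tensor) -> torch.Tensor:
        return self._full_to_swa[loc]

alloc = ToySWAAllocator(torch.tensor([0, 5, 6, 7, 1]))
print(alloc.translate_loc_from_full_to_swa(torch.tensor([1, 3])))  # tensor([5, 7])
```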
sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
     O,
     kv_indptr,
     num_kv_splits,
+    sink_ptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -504,6 +505,7 @@
     MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -545,6 +547,10 @@
             e_sum = e_sum * old_scale + exp_logic
             e_max = n_e_max

+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        e_sum += tl.exp(cur_sink - e_max)
+
     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
         acc / e_sum,
@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
     kv_indptr,
     num_kv_splits,
     max_kv_splits,
+    sinks=None,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)

     MAX_KV_SPLITS = max_kv_splits
+    HAS_SINK = sinks is not None

     extra_kargs = {}
     if _is_hip:
@@ -581,6 +589,7 @@
         o,
         kv_indptr,
         num_kv_splits,
+        sinks,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -590,6 +599,7 @@
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
+        HAS_SINK=HAS_SINK,
         num_warps=4,
         num_stages=2,
         **extra_kargs,
@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
+    sinks=None,
 ):
     _decode_att_m_fwd(
         q,
@@ -632,6 +643,7 @@
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
+        sinks,
     )


@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
+    sinks=None,
 ):
     _decode_grouped_att_m_fwd(
         q,
@@ -671,6 +684,7 @@
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
+        sinks,
     )


@@ -687,6 +701,7 @@ def decode_attention_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
+    sinks=None,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -709,6 +724,7 @@
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
+            sinks=sinks,
         )
     else:
         # GQA/MQA/MLA
@@ -725,4 +741,5 @@
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
+            sinks=sinks,
         )
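All of these hunks plumb the optional `sinks` tensor through the split-KV decode path, and the placement matters: `_fwd_kernel_stage2` is the cross-split reduction, so the per-head sink logit is folded into the denominator exactly once after all partial maxima and sums are merged, not once per split. A schematic of that merge in plain PyTorch (simplified bookkeeping, assuming each split saw at least one key and reports its running max, sum, and unnormalized accumulator; the real kernel stores per-split log-sum-exp statistics):

```python
import torch

def reduce_splits_with_sink(split_max, split_sum, split_acc, sink):
    """split_max/split_sum: [splits]; split_acc: [splits, d]; sink: scalar tensor."""
    e_max = torch.tensor(float("-inf"))
    e_sum = torch.tensor(0.0)
    acc = torch.zeros_like(split_acc[0])
    for m, s, a in zip(split_max, split_sum, split_acc):
        n_e_max = torch.maximum(m, e_max)
        old_scale = torch.exp(e_max - n_e_max)  # rescale previous partials
        cur_scale = torch.exp(m - n_e_max)      # rescale the incoming split
        acc = acc * old_scale + a * cur_scale
        e_sum = e_sum * old_scale + s * cur_scale
        e_max = n_e_max
    e_sum = e_sum + torch.exp(sink - e_max)     # the sink enters exactly once
    return acc / e_sum
```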
sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -51,6 +51,7 @@ def _fwd_kernel(
     kv_indices,
     mask_ptr,
     mask_indptr,
+    sink_ptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -78,6 +79,7 @@
     IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -132,38 +134,6 @@
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix

-        offs_kv_loc = tl.load(
-            kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
-        )
-
-        # load k in transposed way
-        offs_buf_k = (
-            offs_kv_loc[None, :] * stride_buf_kbs
-            + cur_kv_head * stride_buf_kh
-            + offs_d[:, None]
-        )
-        k = tl.load(
-            K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
-        )
-
-        qk = tl.dot(q.to(k.dtype), k)
-        if BLOCK_DPE > 0:
-            offs_kpe = (
-                offs_kv_loc[None, :] * stride_buf_kbs
-                + cur_kv_head * stride_buf_kh
-                + offs_dpe[:, None]
-            )
-            kpe = tl.load(
-                K_Buffer + offs_kpe,
-                mask=mask_n[None, :],
-                other=0.0,
-            )
-            qk += tl.dot(qpe.to(kpe.dtype), kpe)
-        qk *= sm_scale
-
-        if logit_cap > 0:
-            qk = logit_cap * tanh(qk / logit_cap)
-
         final_mask = mask_m[:, None] & mask_n[None, :]
         if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
             custom_mask = tl.load(
@@ -178,29 +148,77 @@
             final_mask &= custom_mask
         if SLIDING_WINDOW_SIZE > 0:
             # Add mask where q_id <= kv_id + sliding_window_size
-            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
-                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
-            )
+            # q_id = prefix_len + cur_m, kv_id = cur_n
+            window_mask = (
+                cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
+            ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
             final_mask &= window_mask
-        qk = tl.where(final_mask, qk, float("-inf"))

-        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
-        re_scale = tl.exp(e_max - n_e_max)
-        p = tl.exp(qk - n_e_max[:, None])
-        deno = deno * re_scale + tl.sum(p, 1)
+        SKIP_TILE = False
+        if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0:
+            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

-        offs_buf_v = (
-            offs_kv_loc[:, None] * stride_buf_vbs
-            + cur_kv_head * stride_buf_vh
-            + offs_dv[None, :]
-        )
-        v = tl.load(
-            V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
-        )
-        p = p.to(v.dtype)
-        acc = acc * re_scale[:, None] + tl.dot(p, v)
+        if not SKIP_TILE:
+            offs_kv_loc = tl.load(
+                kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
+                mask=mask_n,
+                other=0,
+            )

-        e_max = n_e_max
+            # load k in transposed way
+            offs_buf_k = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_d[:, None]
+            )
+            k = tl.load(
+                K_Buffer + offs_buf_k,
+                mask=(mask_n[None, :]) & (mask_d[:, None]),
+                other=0.0,
+            )
+
+            qk = tl.dot(q.to(k.dtype), k)
+            if BLOCK_DPE > 0:
+                offs_kpe = (
+                    offs_kv_loc[None, :] * stride_buf_kbs
+                    + cur_kv_head * stride_buf_kh
+                    + offs_dpe[:, None]
+                )
+                kpe = tl.load(
+                    K_Buffer + offs_kpe,
+                    mask=mask_n[None, :],
+                    other=0.0,
+                )
+                qk += tl.dot(qpe.to(kpe.dtype), kpe)
+            qk *= sm_scale
+
+            if logit_cap > 0:
+                qk = logit_cap * tanh(qk / logit_cap)
+
+            qk = tl.where(final_mask, qk, float("-inf"))
+
+            row_max = tl.max(qk, 1)
+            row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+            n_e_max = tl.maximum(row_max_fixed, e_max)
+
+            re_scale = tl.exp(e_max - n_e_max)
+            p = tl.exp(qk - n_e_max[:, None])
+            deno = deno * re_scale + tl.sum(p, 1)
+
+            offs_buf_v = (
+                offs_kv_loc[:, None] * stride_buf_vbs
+                + cur_kv_head * stride_buf_vh
+                + offs_dv[None, :]
+            )
+            v = tl.load(
+                V_Buffer + offs_buf_v,
+                mask=mask_n[:, None] & mask_dv[None, :],
+                other=0.0,
+            )
+            p = p.to(v.dtype)
+            acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+            e_max = n_e_max

     # stage 2: compute the triangle part

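The window mask in this prefix-stage hunk also picks up a correctness fix: query ids are now offset by `cur_seq_len_prefix`, making them absolute positions (`prefix_len + cur_m`) comparable with the prefix KV ids (`cur_n`), which are absolute already. Without the offset, small local query ids made the window test keep essentially the whole prefix. A quick numeric check with illustrative sizes:

```python
import torch

prefix_len, window = 100, 16
m = torch.arange(4)[:, None]           # local query rows in the extend part
n = torch.arange(prefix_len)[None, :]  # absolute prefix KV positions

old_mask = m <= n + window                 # misses the prefix offset: keeps all 100
new_mask = (prefix_len + m) <= n + window  # keeps only the last `window` tokens
print(old_mask[0].sum().item(), new_mask[0].sum().item())  # 100 16
```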
@@ -213,35 +231,7 @@
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_block_m_end

-        # load k in transposed way
-        offs_k = (
-            (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
-            + cur_kv_head * stride_kh
-            + offs_d[:, None]
-        )
-        k = tl.load(
-            K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
-        )
-
-        qk = tl.dot(q, k, out_dtype=tl.float32)
-        if BLOCK_DPE > 0:
-            offs_kpe = (
-                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
-                + cur_kv_head * stride_kh
-                + offs_dpe[:, None]
-            )
-            kpe = tl.load(
-                K_Extend + offs_kpe,
-                mask=mask_n[None, :],
-                other=0.0,
-            )
-            qk += tl.dot(qpe, kpe)
-
-        qk *= sm_scale
-
-        if logit_cap > 0:
-            qk = logit_cap * tanh(qk / logit_cap)
-
+        final_mask = mask_m[:, None] & mask_n[None, :]
         if USE_CUSTOM_MASK:
             custom_mask = tl.load(
                 mask_ptr
@@ -254,34 +244,84 @@
                 other=0,
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(custom_mask, qk, float("-inf"))
+            final_mask &= custom_mask
         elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
             )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(mask_causual, qk, float("-inf"))
+            final_mask &= mask_causual
         else:
             mask_non_causal = mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(mask_non_causal, qk, float("-inf"))
+            final_mask &= mask_non_causal
+
+        if SLIDING_WINDOW_SIZE > 0:
+            # Add mask where q_id <= kv_id + sliding_window_size
+            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+            )
+            final_mask &= window_mask

-        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
-        re_scale = tl.exp(e_max - n_e_max)
-        p = tl.exp(qk - n_e_max[:, None])
-        deno = deno * re_scale + tl.sum(p, 1)
+        SKIP_TILE = False
+        if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
+            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0

-        offs_v = (
-            (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
-            + cur_kv_head * stride_vh
-            + offs_dv[None, :]
-        )
-        v = tl.load(
-            V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
-        )
-        p = p.to(v.dtype)
-        acc = acc * re_scale[:, None] + tl.dot(p, v)
+        if not SKIP_TILE:
+            # load k in transposed way
+            offs_k = (
+                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_d[:, None]
+            )
+            k = tl.load(
+                K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
+            )

-        e_max = n_e_max
+            qk = tl.dot(q, k, out_dtype=tl.float32)
+            if BLOCK_DPE > 0:
+                offs_kpe = (
+                    (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+                    + cur_kv_head * stride_kh
+                    + offs_dpe[:, None]
+                )
+                kpe = tl.load(
+                    K_Extend + offs_kpe,
+                    mask=mask_n[None, :],
+                    other=0.0,
+                )
+                qk += tl.dot(qpe, kpe)
+
+            qk *= sm_scale
+
+            if logit_cap > 0:
+                qk = logit_cap * tanh(qk / logit_cap)
+
+            qk = tl.where(final_mask, qk, float("-inf"))
+
+            row_max = tl.max(qk, 1)
+            row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+            n_e_max = tl.maximum(row_max_fixed, e_max)
+
+            re_scale = tl.exp(e_max - n_e_max)
+            p = tl.exp(qk - n_e_max[:, None])
+            deno = deno * re_scale + tl.sum(p, 1)
+
+            offs_v = (
+                (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
+                + cur_kv_head * stride_vh
+                + offs_dv[None, :]
+            )
+            v = tl.load(
+                V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
+            )
+            p = p.to(v.dtype)
+            acc = acc * re_scale[:, None] + tl.dot(p, v)
+
+            e_max = n_e_max
+
+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        deno += tl.exp(cur_sink - e_max)

     offs_o = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
@@ -321,6 +361,7 @@ def extend_attention_fwd(
     logit_cap=0.0,
     skip_prefix_custom_mask=True,
     sliding_window_size=-1,
+    sinks=None,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -386,6 +427,8 @@
     # Skip custom mask for prefix part
     SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

+    HAS_SINK = sinks is not None
+
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1

@@ -405,6 +448,7 @@
         kv_indices,
         custom_mask,
         mask_indptr,
+        sinks,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -431,6 +475,7 @@
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
+        HAS_SINK=HAS_SINK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
         num_stages=num_stages,
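Two details make the restructured kernel both faster and safe. Fully masked tiles are detected from `final_mask` and skipped before any K/V loads (`SKIP_TILE`), which avoids wasted work when the sliding window or a custom mask rules out whole tiles. And because a surviving tile can still contain fully masked rows, the row maximum is clamped from `-inf` to `-1e20` before the online-softmax update; otherwise `exp(qk - (-inf))` would produce NaNs. A minimal reproduction of the guard in plain PyTorch (illustrative values):

```python
import torch

qk = torch.full((2, 4), float("-inf"))  # row 1 is fully masked
qk[0, 1] = 3.0

row_max = qk.max(dim=1).values
row_max_fixed = torch.where(row_max == float("-inf"), torch.tensor(-1e20), row_max)
p = torch.exp(qk - row_max_fixed[:, None])
print(p)  # row 0: [0., 1., 0., 0.]; row 1: all zeros instead of NaN
```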