sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend):
  )
  metadata.page_table = self.decode_cuda_graph_metadata[
  "page_table_draft_decode"
- ][req_pool_indices, :]
+ ][:bs, :]
  self.decode_cuda_graph_metadata[bs] = metadata
  else:
  # When top k > 1, we need two specific draft decode metadata, and then merge states
@@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend):
  ][: bs + 1]
  metadata.page_table = self.draft_decode_metadata_topk_normal[
  "page_table"
- ][req_pool_indices, :]
+ ][:bs, :]

  # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
  metadata_expand.cache_seqlens_int32 = (
@@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k = seq_lens.max().item()
  # Precompute page table
  metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
- req_pool_indices, :
+ :bs, :
  ]
  # Precompute cumulative sequence lengths
  metadata.cu_seqlens_q = torch.arange(
@@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend):
  : (bs + 1)
  ]

- metadata.page_table = self.target_verify_metadata["page_table"][
- req_pool_indices, :
- ]
+ metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]

  self.target_verify_metadata[bs] = metadata
  else:
@@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend):
  ][: bs + 1]
  metadata.page_table = self.target_verify_metadata_topk_normal[
  "page_table"
- ][req_pool_indices, :]
+ ][:bs, :]

  # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
  metadata_expand.cache_seqlens_int32 = (
@@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
  : (bs + 1)
  ]
- metadata.page_table = self.draft_extend_metadata["page_table"][
- req_pool_indices, :
- ]
+ metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]

  self.draft_extend_metadata[bs] = metadata

@@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend):
  ][: (encoder_bs + 1)]

  metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
- req_pool_indices, :
+ :bs, :
  ]

  self.forward_metadata = metadata
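Note on the flashattention_backend.py hunks above: during CUDA-graph capture the page-table metadata is now taken as the first `bs` rows of the preallocated capture buffers (`[:bs, :]`) instead of gathering rows by `req_pool_indices`. The standalone sketch below (not part of the diff; tensor names are illustrative) shows the difference between the two indexing forms on a preallocated buffer; both yield a `(bs, max_pages)` view, only the row selection changes.

import torch

max_requests, max_pages = 8, 4
page_table = torch.arange(max_requests * max_pages).reshape(max_requests, max_pages)

bs = 3
req_pool_indices = torch.tensor([5, 0, 2])

gathered = page_table[req_pool_indices, :]  # old form: rows 5, 0, 2, in request order
sliced = page_table[:bs, :]                 # new form: the first bs rows of the buffer

assert gathered.shape == sliced.shape == (3, 4)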
@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
  self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

  self.req_to_token = model_runner.req_to_token_pool.req_to_token
+ self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

  if not self.skip_prefill:
  self.qo_indptr = torch.zeros(
@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
  forward_batch.req_pool_indices,
  bs,
  self.device,
+ self.token_to_kv_pool_allocator,
  )
  )
  window_num_kv_splits = torch.empty(
@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
  mask_indptr = None
  max_extend_len = None
  elif forward_batch.forward_mode.is_target_verify():
- # TODO: Support sliding window in spec inference
  bs = len(forward_batch.req_pool_indices)
  qo_indptr = torch.arange(
  0,
@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
  self.req_to_token.stride(0),
  )

+ if self.sliding_window_size is not None and self.sliding_window_size > 0:
+ window_kv_indptr, window_kv_indices, window_kv_lens = (
+ update_sliding_window_buffer(
+ self.window_kv_indptr,
+ self.req_to_token,
+ self.sliding_window_size,
+ forward_batch.seq_lens,
+ forward_batch.req_pool_indices,
+ bs,
+ self.device,
+ self.token_to_kv_pool_allocator,
+ )
+ )
+
  custom_mask = spec_info.custom_mask
  seq_mask_len = self.num_draft_tokens * (
  forward_batch.seq_lens + self.num_draft_tokens
@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
  forward_batch.req_pool_indices,
  bs,
  self.device,
+ self.token_to_kv_pool_allocator,
  )

  qo_indptr = self.qo_indptr
@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
  ):
  window_kv_indices = self.cuda_graph_window_kv_indices
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
- window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
- self.window_kv_indptr,
- window_kv_indices,
- self.req_to_token,
- self.sliding_window_size,
- seq_lens[:bs],
- req_pool_indices,
- bs,
+ window_kv_indptr, window_kv_indices, _ = (
+ update_sliding_window_buffer_cuda_graph(
+ self.window_kv_indptr,
+ window_kv_indices,
+ self.req_to_token,
+ self.sliding_window_size,
+ seq_lens[:bs],
+ req_pool_indices,
+ bs,
+ self.token_to_kv_pool_allocator,
+ )
  )
  else:
  kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
  self.req_to_token.stride(0),
  )

+ if self.sliding_window_size is not None and self.sliding_window_size > 0:
+ window_kv_indices = self.cuda_graph_window_kv_indices
+ window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+ window_kv_indptr, window_kv_indices, _ = (
+ update_sliding_window_buffer_cuda_graph(
+ self.window_kv_indptr,
+ window_kv_indices,
+ self.req_to_token,
+ self.sliding_window_size,
+ seq_lens,
+ req_pool_indices,
+ bs,
+ self.token_to_kv_pool_allocator,
+ )
+ )
+
  custom_mask = self.cuda_graph_custom_mask
  custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
  seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
  ):
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
  window_kv_indices = self.cuda_graph_window_kv_indices
- _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+ _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
  self.window_kv_indptr,
  window_kv_indices,
  self.req_to_token,
@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
  seq_lens[:bs],
  req_pool_indices[:bs],
  bs,
+ self.token_to_kv_pool_allocator,
  )
  self.get_num_kv_splits(
  window_num_kv_splits[:num_token], window_kv_lens[:bs]
@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
  kv_indices,
  self.req_to_token.stride(0),
  )
+ if self.sliding_window_size is not None and self.sliding_window_size > 0:
+ window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+ window_kv_indices = self.cuda_graph_window_kv_indices
+ _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+ self.window_kv_indptr,
+ window_kv_indices,
+ self.req_to_token,
+ self.sliding_window_size,
+ seq_lens,
+ req_pool_indices,
+ bs,
+ self.token_to_kv_pool_allocator,
+ )
  custom_mask = self.cuda_graph_custom_mask
  custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
  seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache=True,
+ sinks=None,
  ):
  # TODO: reuse the buffer across layers
  if layer.qk_head_dim != layer.v_head_dim:
@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
  self.forward_metadata.max_extend_len,
  layer.scaling,
  layer.logit_cap,
- sliding_window_size,
+ sliding_window_size=sliding_window_size,
+ sinks=sinks,
  )
  return o

@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache=True,
+ sinks=None,
  ):
  # During torch.compile, there is a bug in rotary_emb that causes the
  # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
  self.max_kv_splits,
  layer.scaling,
  layer.logit_cap,
+ sinks=sinks,
  )
  return o

@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
  req_pool_indices,
  bs,
  device,
+ token_to_kv_pool_allocator=None,
  ):
  window_kv_lens = torch.minimum(
  seq_lens,
- torch.tensor(sliding_window_size + 1),
+ torch.tensor(sliding_window_size),
  )
  window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
  window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
  window_kv_indices,
  req_to_token.stride(0),
  )
+ # full to swa index mapping
+ if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+ kv_last_index = window_kv_indptr[-1]
+ window_kv_indices[:kv_last_index] = (
+ token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+ window_kv_indices[:kv_last_index]
+ )
+ )
  return window_kv_indptr, window_kv_indices, window_kv_lens


@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
  seq_lens,
  req_pool_indices,
  bs,
+ token_to_kv_pool_allocator=None,
  ):
  window_kv_lens = torch.minimum(
  seq_lens,
- torch.tensor(sliding_window_size + 1),
+ torch.tensor(sliding_window_size),
  )
  window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
  window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
  window_kv_indices,
  req_to_token.stride(0),
  )
- return window_kv_indptr, window_kv_lens
+ # full to swa index mapping
+ if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+ kv_last_index = window_kv_indptr[-1]
+ window_kv_indices[:kv_last_index] = (
+ token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+ window_kv_indices[:kv_last_index]
+ )
+ )
+ return window_kv_indptr, window_kv_indices, window_kv_lens
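Note on the triton_backend.py hunks above: `update_sliding_window_buffer` and `update_sliding_window_buffer_cuda_graph` gain a `token_to_kv_pool_allocator` argument and now also return `window_kv_indices`; when the allocator exposes `translate_loc_from_full_to_swa`, the window's KV indices are remapped from full-KV-pool slots to sliding-window-attention (SWA) pool slots. Below is a rough standalone sketch of that guarded remapping, using a hypothetical allocator backed by a plain lookup table (all names outside the diff are illustrative, not sglang APIs).

import torch

class HybridAllocatorSketch:
    # Hypothetical stand-in: translates full-KV-pool slot ids to SWA-pool slot ids
    # through a precomputed 1D lookup table.
    def __init__(self, full_to_swa: torch.Tensor):
        self.full_to_swa = full_to_swa

    def translate_loc_from_full_to_swa(self, loc: torch.Tensor) -> torch.Tensor:
        return self.full_to_swa[loc]

def remap_window_indices(window_kv_indptr, window_kv_indices, allocator=None):
    # Same guard as the diff: remap only if the allocator supports translation,
    # and only the slots actually referenced (up to the last indptr entry).
    if hasattr(allocator, "translate_loc_from_full_to_swa"):
        kv_last_index = window_kv_indptr[-1]
        window_kv_indices[:kv_last_index] = allocator.translate_loc_from_full_to_swa(
            window_kv_indices[:kv_last_index]
        )
    return window_kv_indices

# Two requests with 2 and 3 in-window tokens; trailing buffer slots are unused.
allocator = HybridAllocatorSketch(full_to_swa=torch.arange(100) % 10)
indptr = torch.tensor([0, 2, 5])
indices = torch.tensor([40, 41, 73, 74, 75, 0, 0])
print(remap_window_indices(indptr, indices, allocator))  # tensor([0, 1, 3, 4, 5, 0, 0])

Passing `None`, or an allocator without that method, leaves the indices untouched, which mirrors the helpers' default of `token_to_kv_pool_allocator=None`.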
@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
  O,
  kv_indptr,
  num_kv_splits,
+ sink_ptr,
  stride_mid_ob,
  stride_mid_oh,
  stride_mid_os,
@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
  MIN_BLOCK_KV: tl.constexpr,
  BLOCK_DV: tl.constexpr,
  Lv: tl.constexpr,
+ HAS_SINK: tl.constexpr,
  ):
  cur_batch = tl.program_id(0)
  cur_head = tl.program_id(1)
@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
  e_sum = e_sum * old_scale + exp_logic
  e_max = n_e_max

+ if HAS_SINK:
+ cur_sink = tl.load(sink_ptr + cur_head)
+ e_sum += tl.exp(cur_sink - e_max)
+
  tl.store(
  O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
  acc / e_sum,
@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
  kv_indptr,
  num_kv_splits,
  max_kv_splits,
+ sinks=None,
  ):
  batch, head_num = q.shape[0], q.shape[1]
  Lv = v_buffer.shape[-1]
  BLOCK_DV = triton.next_power_of_2(Lv)

  MAX_KV_SPLITS = max_kv_splits
+ HAS_SINK = sinks is not None

  extra_kargs = {}
  if _is_hip:
@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
  o,
  kv_indptr,
  num_kv_splits,
+ sinks,
  logits.stride(0),
  logits.stride(1),
  logits.stride(2),
@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
  MIN_BLOCK_KV=_MIN_BLOCK_KV,
  BLOCK_DV=BLOCK_DV,
  Lv=Lv,
+ HAS_SINK=HAS_SINK,
  num_warps=4,
  num_stages=2,
  **extra_kargs,
@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
  max_kv_splits,
  sm_scale,
  logit_cap=0.0,
+ sinks=None,
  ):
  _decode_att_m_fwd(
  q,
@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
  kv_indptr,
  num_kv_splits,
  max_kv_splits,
+ sinks,
  )


@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
  max_kv_splits,
  sm_scale,
  logit_cap=0.0,
+ sinks=None,
  ):
  _decode_grouped_att_m_fwd(
  q,
@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
  kv_indptr,
  num_kv_splits,
  max_kv_splits,
+ sinks,
  )


@@ -687,6 +701,7 @@ def decode_attention_fwd(
  max_kv_splits,
  sm_scale,
  logit_cap=0.0,
+ sinks=None,
  ):
  assert max_kv_splits == attn_logits.shape[2]
  assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -709,6 +724,7 @@ def decode_attention_fwd(
  max_kv_splits,
  sm_scale,
  logit_cap=logit_cap,
+ sinks=sinks,
  )
  else:
  # GQA/MQA/MLA
@@ -725,4 +741,5 @@ def decode_attention_fwd(
  max_kv_splits,
  sm_scale,
  logit_cap=logit_cap,
+ sinks=sinks,
  )
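Note on the decode_attention.py hunks above: the optional `sinks` tensor is threaded down to the stage-2 reduction kernel, where each head's sink logit adds `exp(sink - e_max)` to the softmax denominator (`e_sum`) without contributing a value vector. Below is a small eager-PyTorch reference of that normalization, consistent with the kernel change but not the Triton code itself (shapes and names are illustrative).

import torch

def decode_attention_with_sink(q, k, v, sink, sm_scale):
    # q: (H, D); k, v: (H, T, D); sink: (H,) -- one extra per-head logit that
    # enlarges the softmax denominator but has no associated value vector.
    logits = torch.einsum("hd,htd->ht", q, k) * sm_scale   # (H, T)
    e_max = logits.max(dim=-1).values                      # max over KV logits only, as in the kernel
    p = torch.exp(logits - e_max[:, None])                 # (H, T)
    e_sum = p.sum(dim=-1) + torch.exp(sink - e_max)        # sink joins e_sum only
    return torch.einsum("ht,htd->hd", p, v) / e_sum[:, None]

H, T, D = 4, 16, 8
q, k, v = torch.randn(H, D), torch.randn(H, T, D), torch.randn(H, T, D)
out = decode_attention_with_sink(q, k, v, sink=torch.zeros(H), sm_scale=D**-0.5)
print(out.shape)  # torch.Size([4, 8])

With `sinks=None`, the kernel is launched with `HAS_SINK=False` and the denominator is unchanged, so existing callers keep their previous behavior.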