sglang-0.4.10.post2-py3-none-any.whl → sglang-0.5.0rc1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
  # For multi-head latent attention
  q_rope: Optional[torch.Tensor] = None,
  k_rope: Optional[torch.Tensor] = None,
+ sinks: Optional[torch.Tensor] = None,
  ):
  if k is not None:
  assert v is not None
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
  forward_batch.forward_mode.is_target_verify() and self.topk > 1
  )

+ # For fa3 interface version compatibility, we put new fields into conditional keyword args
+ kwargs = {}
+ if sinks is not None:
+ kwargs["sinks"] = sinks
+
  # Get the appropriate page table based on whether we're using local attention
  if use_local_attn:
  local_metadata = metadata.local_attn_metadata
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=use_cascade_attn,
+ **kwargs,
  )

  if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=True,
+ **kwargs,
  )
  o, _ = merge_state_v2_wrapper(
  o,
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
  # For multi-head latent attention
  q_rope: Optional[torch.Tensor] = None,
  k_rope: Optional[torch.Tensor] = None,
+ sinks: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  if k is not None:
  assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
  )
  causal = not layer.is_cross_attention

+ # For fa3 interface version compatibility, we put new fields into conditional keyword args
+ kwargs = {}
+ if sinks is not None:
+ kwargs["sinks"] = sinks
+
  k_descale, v_descale = None, None
  # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
  # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ **kwargs,
  )
  elif use_local_attn:
  # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
  softcap=layer.logit_cap,
  k_descale=k_descale,
  v_descale=v_descale,
+ **kwargs,
  )
  else:
  page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=use_cascade_attn,
+ **kwargs,
  )
  if use_cascade_attn:
  o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
  k_descale=k_descale,
  v_descale=v_descale,
  return_softmax_lse=True,
+ **kwargs,
  )
  )
  o, _ = merge_state_v2(
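
Note on the hunks above: the new `sinks` argument carries per-head attention-sink logits (used by models such as gpt-oss) down to the FlashAttention-3 call. Because older flash-attention builds do not accept this keyword, the backend only forwards it when it is set. A minimal sketch of that compatibility pattern (the `call_attention` helper and `attn_func` parameter are illustrative, not part of the sglang code):

    from typing import Optional

    import torch


    def call_attention(attn_func, q, k, v, sinks: Optional[torch.Tensor] = None, **fixed_kwargs):
        # Forward `sinks` only when it is provided, so older kernels whose
        # signature has no such parameter keep working unchanged.
        extra_kwargs = {}
        if sinks is not None:
            extra_kwargs["sinks"] = sinks
        return attn_func(q, k, v, **fixed_kwargs, **extra_kwargs)
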
@@ -66,6 +66,10 @@ class PrefillMetadata:
  # Reuse this workspace buffer across all flashinfer wrappers
  global_workspace_buffer = None

+ # Use as a fast path to override the indptr in flashinfer's plan function
+ # This is used to remove some host-to-device copy overhead.
+ global_override_indptr_cpu = None
+

  class FlashInferAttnBackend(AttentionBackend):
  """Flashinfer attention kernels."""
@@ -205,6 +209,7 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_decode.update(
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
+ forward_batch.seq_lens_cpu,
  forward_batch.seq_lens_sum,
  decode_wrappers=self.decode_wrappers,
  encoder_lens=forward_batch.encoder_lens,
@@ -215,6 +220,7 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_prefill.update(
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
+ forward_batch.seq_lens_cpu,
  forward_batch.seq_lens_sum,
  prefix_lens=None,
  prefill_wrappers=self.prefill_wrappers_paged,
@@ -229,6 +235,7 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_prefill.update(
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
+ forward_batch.seq_lens_cpu,
  forward_batch.seq_lens_sum,
  prefix_lens=None,
  prefill_wrappers=self.prefill_wrappers_verify,
@@ -252,6 +259,7 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_prefill.update(
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
+ forward_batch.seq_lens_cpu,
  forward_batch.seq_lens_sum,
  prefix_lens,
  prefill_wrappers=self.prefill_wrappers_paged,
@@ -327,6 +335,7 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_decode.update(
  req_pool_indices,
  seq_lens,
+ seq_lens.cpu(), # may add a little overhead in capture stage
  seq_lens_sum,
  decode_wrappers=decode_wrappers,
  encoder_lens=encoder_lens,
@@ -358,6 +367,7 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_prefill.update(
  req_pool_indices,
  seq_lens,
+ seq_lens.cpu(), # may add a little overhead in capture stage
  seq_lens_sum,
  prefix_lens=None,
  prefill_wrappers=prefill_wrappers,
@@ -387,6 +397,7 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_prefill.update(
  req_pool_indices,
  seq_lens,
+ seq_lens.cpu(), # may add a little overhead in capture stage
  seq_lens_sum,
  prefix_lens=None,
  prefill_wrappers=prefill_wrappers,
@@ -414,6 +425,7 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_decode.update(
  req_pool_indices[:bs],
  seq_lens[:bs],
+ seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
  seq_lens_sum,
  decode_wrappers=self.decode_cuda_graph_metadata[bs],
  encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
@@ -423,6 +435,7 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_prefill.update(
  req_pool_indices[:bs],
  seq_lens[:bs],
+ seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
  seq_lens_sum,
  prefix_lens=None,
  prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -434,6 +447,7 @@ class FlashInferAttnBackend(AttentionBackend):
  self.indices_updater_prefill.update(
  req_pool_indices[:bs],
  seq_lens[:bs],
+ seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
  seq_lens_sum,
  prefix_lens=None,
  prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -581,7 +595,7 @@ class FlashInferAttnBackend(AttentionBackend):


  class FlashInferIndicesUpdaterDecode:
- def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+ def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
  # Parse Constants
  self.num_qo_heads = (
  model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -614,6 +628,7 @@ class FlashInferIndicesUpdaterDecode:
  self,
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
+ seq_lens_cpu: Optional[torch.Tensor],
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
@@ -626,6 +641,7 @@ class FlashInferIndicesUpdaterDecode:
  self,
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
+ seq_lens_cpu: Optional[torch.Tensor],
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
@@ -640,30 +656,39 @@ class FlashInferIndicesUpdaterDecode:
  self.kv_indptr[0],
  None,
  spec_info,
+ seq_lens_cpu,
  )

  def update_sliding_window(
  self,
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
+ seq_lens_cpu: Optional[torch.Tensor],
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
  spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
  ):
+ assert self.sliding_window_size is not None
  for wrapper_id in range(2):
  if wrapper_id == 0:
  # Sliding window attention
- paged_kernel_lens_tmp = torch.minimum( # TODO: replace this with clamp
- seq_lens,
- torch.tensor(self.sliding_window_size + 1),
+ paged_kernel_lens_tmp = torch.clamp(
+ seq_lens, max=self.sliding_window_size + 1
  )
- paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
+ if seq_lens_cpu is not None:
+ seq_lens_cpu_tmp = torch.clamp(
+ seq_lens_cpu, max=self.sliding_window_size + 1
+ )
+ paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
+ else:
+ paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
  kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
  else:
  # Full attention
  paged_kernel_lens_tmp = seq_lens
  paged_kernel_lens_sum_tmp = seq_lens_sum
+ seq_lens_cpu_tmp = seq_lens_cpu
  kv_start_idx_tmp = None

  use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
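
The sliding-window update above replaces `torch.minimum` with `torch.clamp` and, when a host-side copy of the sequence lengths is available, takes the window-length sum on the CPU tensor so the `.item()` call does not have to synchronize with the device. A rough sketch of the idea, assuming `seq_lens_cpu` mirrors `seq_lens` on the host (the helper name is illustrative):

    import torch

    def window_lens_and_sum(seq_lens: torch.Tensor, seq_lens_cpu, window_size: int):
        # Clamp each sequence length to the sliding window (+1 keeps the current token).
        lens = torch.clamp(seq_lens, max=window_size + 1)
        if seq_lens_cpu is not None:
            # Summing the host copy avoids synchronizing on the device tensor.
            total = int(torch.clamp(seq_lens_cpu, max=window_size + 1).sum())
        else:
            total = int(lens.sum().item())
        return lens, total
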
@@ -678,6 +703,7 @@ class FlashInferIndicesUpdaterDecode:
  self.kv_indptr[wrapper_id],
  kv_start_idx_tmp,
  spec_info,
+ seq_lens_cpu=seq_lens_cpu_tmp,
  use_sliding_window_kv_pool=use_sliding_window_kv_pool,
  )

@@ -685,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
  self,
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
+ seq_lens_cpu: Optional[torch.Tensor],
  seq_lens_sum: int,
  decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
  encoder_lens: Optional[torch.Tensor],
@@ -709,6 +736,7 @@ class FlashInferIndicesUpdaterDecode:
  self.kv_indptr[wrapper_id],
  kv_start_idx,
  spec_info,
+ seq_lens_cpu=seq_lens_cpu,
  )

  def call_begin_forward(
@@ -720,6 +748,7 @@ class FlashInferIndicesUpdaterDecode:
  kv_indptr: torch.Tensor,
  kv_start_idx: torch.Tensor,
  spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+ seq_lens_cpu: Optional[torch.Tensor],
  use_sliding_window_kv_pool: bool = False,
  ):
  if spec_info is None:
@@ -756,6 +785,14 @@ class FlashInferIndicesUpdaterDecode:
  )
  )

+ global global_override_indptr_cpu
+ locally_override = False
+ if seq_lens_cpu is not None and global_override_indptr_cpu is None:
+ locally_override = True
+ global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
+ global_override_indptr_cpu[0] = 0
+ global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
+
  wrapper.begin_forward(
  kv_indptr,
  kv_indices,
@@ -769,9 +806,12 @@ class FlashInferIndicesUpdaterDecode:
  non_blocking=True,
  )

+ if locally_override:
+ global_override_indptr_cpu = None
+

  class FlashInferIndicesUpdaterPrefill:
- def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+ def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
  # Parse Constants
  self.num_qo_heads = (
  model_runner.model_config.num_attention_heads // get_attention_tp_size()
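
Here `global_override_indptr_cpu` hands flashinfer a host-side kv_indptr built directly from `seq_lens_cpu`, which, per the comment in the diff, removes some copy overhead between host and device in the plan/begin_forward step. A simplified sketch of how such an indptr can be assembled (the function name is illustrative; it assumes `seq_lens_cpu` is a 1-D integer CPU tensor of per-request KV lengths):

    import torch

    def build_indptr_cpu(seq_lens_cpu: torch.Tensor) -> torch.Tensor:
        # indptr[i] = total number of paged-KV entries owned by requests 0..i-1.
        bs = seq_lens_cpu.numel()
        indptr = torch.zeros(bs + 1, dtype=torch.int32, device="cpu")
        indptr[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
        return indptr

    # build_indptr_cpu(torch.tensor([3, 5, 2])) -> tensor([0, 3, 8, 10], dtype=torch.int32)
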
@@ -806,6 +846,7 @@ class FlashInferIndicesUpdaterPrefill:
  self,
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
+ seq_lens_cpu: Optional[torch.Tensor],
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -820,6 +861,7 @@ class FlashInferIndicesUpdaterPrefill:
  self,
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
+ seq_lens_cpu: Optional[torch.Tensor],
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -853,6 +895,7 @@ class FlashInferIndicesUpdaterPrefill:
  self,
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
+ seq_lens_cpu: Optional[torch.Tensor],
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -898,6 +941,7 @@ class FlashInferIndicesUpdaterPrefill:
  self,
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
+ seq_lens_cpu: Optional[torch.Tensor],
  seq_lens_sum: int,
  prefix_lens: torch.Tensor,
  prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -1020,11 +1064,6 @@ class FlashInferIndicesUpdaterPrefill:
  )


- # Use as a fast path to override the indptr in flashinfer's plan function
- # This is used to remove some host-to-device copy overhead.
- global global_override_indptr_cpu
-
-
  class FlashInferMultiStepDraftBackend:
  """
  Wrap multiple flashinfer attention backends as one for multiple consecutive
@@ -1056,7 +1095,7 @@ class FlashInferMultiStepDraftBackend:
  self.kv_last_page_len = torch.ones(
  (max_bs,), dtype=torch.int32, device=model_runner.device
  )
- self.attn_backends = []
+ self.attn_backends: List[FlashInferAttnBackend] = []
  for i in range(self.speculative_num_steps):
  self.attn_backends.append(
  FlashInferAttnBackend(
@@ -1176,7 +1215,7 @@ class FlashInferMultiStepDraftBackend:
  encoder_lens=None,
  forward_mode=ForwardMode.DECODE,
  spec_info=forward_batch.spec_info,
- seq_lens_cpu=None,
+ seq_lens_cpu=forward_batch.seq_lens_cpu,
  )

  self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
@@ -1,4 +1,4 @@
- from typing import TYPE_CHECKING, Optional, Union
+ from typing import Optional, Union

  import torch

@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
  self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)

  self.req_to_token = model_runner.req_to_token_pool.req_to_token
+ self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

  if not self.skip_prefill:
  self.qo_indptr = torch.zeros(
@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
  forward_batch.req_pool_indices,
  bs,
  self.device,
+ self.token_to_kv_pool_allocator,
  )
  )
  window_num_kv_splits = torch.empty(
@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
  mask_indptr = None
  max_extend_len = None
  elif forward_batch.forward_mode.is_target_verify():
- # TODO: Support sliding window in spec inference
  bs = len(forward_batch.req_pool_indices)
  qo_indptr = torch.arange(
  0,
@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
  self.req_to_token.stride(0),
  )

+ if self.sliding_window_size is not None and self.sliding_window_size > 0:
+ window_kv_indptr, window_kv_indices, window_kv_lens = (
+ update_sliding_window_buffer(
+ self.window_kv_indptr,
+ self.req_to_token,
+ self.sliding_window_size,
+ forward_batch.seq_lens,
+ forward_batch.req_pool_indices,
+ bs,
+ self.device,
+ self.token_to_kv_pool_allocator,
+ )
+ )
+
  custom_mask = spec_info.custom_mask
  seq_mask_len = self.num_draft_tokens * (
  forward_batch.seq_lens + self.num_draft_tokens
@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
  forward_batch.req_pool_indices,
  bs,
  self.device,
+ self.token_to_kv_pool_allocator,
  )

  qo_indptr = self.qo_indptr
@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
  ):
  window_kv_indices = self.cuda_graph_window_kv_indices
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
- window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
- self.window_kv_indptr,
- window_kv_indices,
- self.req_to_token,
- self.sliding_window_size,
- seq_lens[:bs],
- req_pool_indices,
- bs,
+ window_kv_indptr, window_kv_indices, _ = (
+ update_sliding_window_buffer_cuda_graph(
+ self.window_kv_indptr,
+ window_kv_indices,
+ self.req_to_token,
+ self.sliding_window_size,
+ seq_lens[:bs],
+ req_pool_indices,
+ bs,
+ self.token_to_kv_pool_allocator,
+ )
  )
  else:
  kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
  self.req_to_token.stride(0),
  )

+ if self.sliding_window_size is not None and self.sliding_window_size > 0:
+ window_kv_indices = self.cuda_graph_window_kv_indices
+ window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+ window_kv_indptr, window_kv_indices, _ = (
+ update_sliding_window_buffer_cuda_graph(
+ self.window_kv_indptr,
+ window_kv_indices,
+ self.req_to_token,
+ self.sliding_window_size,
+ seq_lens,
+ req_pool_indices,
+ bs,
+ self.token_to_kv_pool_allocator,
+ )
+ )
+
  custom_mask = self.cuda_graph_custom_mask
  custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
  seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
  ):
  window_num_kv_splits = self.cuda_graph_window_num_kv_splits
  window_kv_indices = self.cuda_graph_window_kv_indices
- _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+ _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
  self.window_kv_indptr,
  window_kv_indices,
  self.req_to_token,
@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
  seq_lens[:bs],
  req_pool_indices[:bs],
  bs,
+ self.token_to_kv_pool_allocator,
  )
  self.get_num_kv_splits(
  window_num_kv_splits[:num_token], window_kv_lens[:bs]
@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
  kv_indices,
  self.req_to_token.stride(0),
  )
+ if self.sliding_window_size is not None and self.sliding_window_size > 0:
+ window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+ window_kv_indices = self.cuda_graph_window_kv_indices
+ _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+ self.window_kv_indptr,
+ window_kv_indices,
+ self.req_to_token,
+ self.sliding_window_size,
+ seq_lens,
+ req_pool_indices,
+ bs,
+ self.token_to_kv_pool_allocator,
+ )
  custom_mask = self.cuda_graph_custom_mask
  custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
  seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache=True,
+ sinks=None,
  ):
  # TODO: reuse the buffer across layers
  if layer.qk_head_dim != layer.v_head_dim:
@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
  self.forward_metadata.max_extend_len,
  layer.scaling,
  layer.logit_cap,
- sliding_window_size,
+ sliding_window_size=sliding_window_size,
+ sinks=sinks,
  )
  return o

@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
  layer: RadixAttention,
  forward_batch: ForwardBatch,
  save_kv_cache=True,
+ sinks=None,
  ):
  # During torch.compile, there is a bug in rotary_emb that causes the
  # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
  self.max_kv_splits,
  layer.scaling,
  layer.logit_cap,
+ sinks=sinks,
  )
  return o

@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
  req_pool_indices,
  bs,
  device,
+ token_to_kv_pool_allocator=None,
  ):
  window_kv_lens = torch.minimum(
  seq_lens,
- torch.tensor(sliding_window_size + 1),
+ torch.tensor(sliding_window_size),
  )
  window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
  window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
  window_kv_indices,
  req_to_token.stride(0),
  )
+ # full to swa index mapping
+ if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+ kv_last_index = window_kv_indptr[-1]
+ window_kv_indices[:kv_last_index] = (
+ token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+ window_kv_indices[:kv_last_index]
+ )
+ )
  return window_kv_indptr, window_kv_indices, window_kv_lens


@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
  seq_lens,
  req_pool_indices,
  bs,
+ token_to_kv_pool_allocator=None,
  ):
  window_kv_lens = torch.minimum(
  seq_lens,
- torch.tensor(sliding_window_size + 1),
+ torch.tensor(sliding_window_size),
  )
  window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
  window_kv_indptr = window_kv_indptr[: bs + 1]
@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
  window_kv_indices,
  req_to_token.stride(0),
  )
- return window_kv_indptr, window_kv_lens
+ # full to swa index mapping
+ if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+ kv_last_index = window_kv_indptr[-1]
+ window_kv_indices[:kv_last_index] = (
+ token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+ window_kv_indices[:kv_last_index]
+ )
+ )
+ return window_kv_indptr, window_kv_indices, window_kv_lens
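
With a hybrid full/sliding-window KV pool, the window indices gathered from `req_to_token` point into the full pool, so the hunks above remap them through the allocator's `translate_loc_from_full_to_swa` before the sliding-window kernels run. A condensed sketch of what these buffer helpers compute (the function below is illustrative, not the Triton-backed implementation):

    import torch

    def build_window_indices(req_to_token, req_pool_indices, seq_lens, window, allocator=None):
        # Per-request window lengths and their prefix sums (indptr).
        window_lens = torch.clamp(seq_lens, max=window)
        indptr = torch.zeros(seq_lens.numel() + 1, dtype=torch.int64)
        indptr[1:] = torch.cumsum(window_lens, dim=0)

        # Gather the last window_lens[i] token slots of each request.
        indices = torch.empty(int(indptr[-1]), dtype=torch.int64)
        for i, req in enumerate(req_pool_indices.tolist()):
            lo, hi = int(indptr[i]), int(indptr[i + 1])
            start = int(seq_lens[i]) - int(window_lens[i])
            indices[lo:hi] = req_to_token[req, start : int(seq_lens[i])]

        # Remap full-pool slots to SWA-pool slots when a hybrid allocator is in use.
        if hasattr(allocator, "translate_loc_from_full_to_swa"):
            indices = allocator.translate_loc_from_full_to_swa(indices)
        return indptr, indices, window_lens
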
@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
  O,
  kv_indptr,
  num_kv_splits,
+ sink_ptr,
  stride_mid_ob,
  stride_mid_oh,
  stride_mid_os,
@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
  MIN_BLOCK_KV: tl.constexpr,
  BLOCK_DV: tl.constexpr,
  Lv: tl.constexpr,
+ HAS_SINK: tl.constexpr,
  ):
  cur_batch = tl.program_id(0)
  cur_head = tl.program_id(1)
@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
  e_sum = e_sum * old_scale + exp_logic
  e_max = n_e_max

+ if HAS_SINK:
+ cur_sink = tl.load(sink_ptr + cur_head)
+ e_sum += tl.exp(cur_sink - e_max)
+
  tl.store(
  O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
  acc / e_sum,
@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
  kv_indptr,
  num_kv_splits,
  max_kv_splits,
+ sinks=None,
  ):
  batch, head_num = q.shape[0], q.shape[1]
  Lv = v_buffer.shape[-1]
  BLOCK_DV = triton.next_power_of_2(Lv)

  MAX_KV_SPLITS = max_kv_splits
+ HAS_SINK = sinks is not None

  extra_kargs = {}
  if _is_hip:
@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
  o,
  kv_indptr,
  num_kv_splits,
+ sinks,
  logits.stride(0),
  logits.stride(1),
  logits.stride(2),
@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
  MIN_BLOCK_KV=_MIN_BLOCK_KV,
  BLOCK_DV=BLOCK_DV,
  Lv=Lv,
+ HAS_SINK=HAS_SINK,
  num_warps=4,
  num_stages=2,
  **extra_kargs,
@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
  max_kv_splits,
  sm_scale,
  logit_cap=0.0,
+ sinks=None,
  ):
  _decode_att_m_fwd(
  q,
@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
  kv_indptr,
  num_kv_splits,
  max_kv_splits,
+ sinks,
  )


@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
  max_kv_splits,
  sm_scale,
  logit_cap=0.0,
+ sinks=None,
  ):
  _decode_grouped_att_m_fwd(
  q,
@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
  kv_indptr,
  num_kv_splits,
  max_kv_splits,
+ sinks,
  )


@@ -687,6 +701,7 @@ def decode_attention_fwd(
  max_kv_splits,
  sm_scale,
  logit_cap=0.0,
+ sinks=None,
  ):
  assert max_kv_splits == attn_logits.shape[2]
  assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -709,6 +724,7 @@ def decode_attention_fwd(
  max_kv_splits,
  sm_scale,
  logit_cap=logit_cap,
+ sinks=sinks,
  )
  else:
  # GQA/MQA/MLA
@@ -725,4 +741,5 @@ def decode_attention_fwd(
  max_kv_splits,
  sm_scale,
  logit_cap=logit_cap,
+ sinks=sinks,
  )
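
In the reduce stage above, an attention sink contributes only to the softmax denominator: when HAS_SINK is set, `e_sum` gains `exp(sink - e_max)` for the head, which scales every attention weight down without adding a value vector of its own. A small single-head PyTorch reference of that reduction (illustrative, not the Triton kernel):

    import torch

    def attend_with_sink(q, k, v, sink: float, scale: float) -> torch.Tensor:
        # q: (d,), k and v: (n, d). The sink logit joins the softmax normalization
        # but carries no value row, so it only shrinks the attention weights.
        logits = (k @ q) * scale
        m = torch.maximum(logits.max(), torch.tensor(sink))
        weights = torch.exp(logits - m)
        denom = weights.sum() + torch.exp(torch.tensor(sink) - m)
        return (weights @ v) / denom
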