sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +6 -1
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +8 -7
  6. sglang/srt/disaggregation/decode.py +8 -4
  7. sglang/srt/disaggregation/mooncake/conn.py +43 -25
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  9. sglang/srt/distributed/parallel_state.py +4 -2
  10. sglang/srt/entrypoints/context.py +3 -20
  11. sglang/srt/entrypoints/engine.py +13 -8
  12. sglang/srt/entrypoints/harmony_utils.py +2 -0
  13. sglang/srt/entrypoints/http_server.py +68 -5
  14. sglang/srt/entrypoints/openai/protocol.py +2 -9
  15. sglang/srt/entrypoints/openai/serving_chat.py +60 -265
  16. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +4 -3
  18. sglang/srt/function_call/ebnf_composer.py +1 -0
  19. sglang/srt/function_call/function_call_parser.py +2 -0
  20. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  21. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  22. sglang/srt/function_call/kimik2_detector.py +3 -3
  23. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  24. sglang/srt/jinja_template_utils.py +6 -0
  25. sglang/srt/layers/attention/aiter_backend.py +370 -107
  26. sglang/srt/layers/attention/ascend_backend.py +3 -0
  27. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  28. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  29. sglang/srt/layers/attention/flashinfer_backend.py +55 -13
  30. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
  31. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  32. sglang/srt/layers/attention/triton_backend.py +24 -27
  33. sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
  34. sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
  35. sglang/srt/layers/attention/vision.py +9 -1
  36. sglang/srt/layers/attention/wave_backend.py +627 -0
  37. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  38. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  39. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  40. sglang/srt/layers/communicator.py +11 -13
  41. sglang/srt/layers/dp_attention.py +118 -27
  42. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  43. sglang/srt/layers/linear.py +1 -0
  44. sglang/srt/layers/logits_processor.py +12 -18
  45. sglang/srt/layers/moe/cutlass_moe.py +11 -16
  46. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  47. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  48. sglang/srt/layers/moe/ep_moe/layer.py +60 -2
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
  63. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  64. sglang/srt/layers/moe/topk.py +4 -1
  65. sglang/srt/layers/multimodal.py +156 -40
  66. sglang/srt/layers/quantization/__init__.py +10 -35
  67. sglang/srt/layers/quantization/awq.py +15 -16
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
  69. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  70. sglang/srt/layers/quantization/fp8_utils.py +22 -10
  71. sglang/srt/layers/quantization/gptq.py +12 -17
  72. sglang/srt/layers/quantization/marlin_utils.py +15 -5
  73. sglang/srt/layers/quantization/modelopt_quant.py +58 -41
  74. sglang/srt/layers/quantization/mxfp4.py +20 -3
  75. sglang/srt/layers/quantization/utils.py +52 -2
  76. sglang/srt/layers/quantization/w4afp8.py +20 -11
  77. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  78. sglang/srt/layers/rotary_embedding.py +281 -2
  79. sglang/srt/layers/sampler.py +5 -2
  80. sglang/srt/lora/backend/base_backend.py +3 -23
  81. sglang/srt/lora/layers.py +66 -116
  82. sglang/srt/lora/lora.py +17 -62
  83. sglang/srt/lora/lora_manager.py +12 -48
  84. sglang/srt/lora/lora_registry.py +20 -9
  85. sglang/srt/lora/mem_pool.py +20 -63
  86. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  87. sglang/srt/lora/utils.py +25 -58
  88. sglang/srt/managers/cache_controller.py +24 -29
  89. sglang/srt/managers/detokenizer_manager.py +1 -1
  90. sglang/srt/managers/io_struct.py +20 -6
  91. sglang/srt/managers/mm_utils.py +1 -2
  92. sglang/srt/managers/multimodal_processor.py +1 -1
  93. sglang/srt/managers/schedule_batch.py +43 -49
  94. sglang/srt/managers/schedule_policy.py +6 -6
  95. sglang/srt/managers/scheduler.py +18 -11
  96. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  97. sglang/srt/managers/tokenizer_manager.py +53 -44
  98. sglang/srt/mem_cache/allocator.py +39 -214
  99. sglang/srt/mem_cache/allocator_ascend.py +158 -0
  100. sglang/srt/mem_cache/chunk_cache.py +1 -1
  101. sglang/srt/mem_cache/hicache_storage.py +1 -1
  102. sglang/srt/mem_cache/hiradix_cache.py +34 -24
  103. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  104. sglang/srt/mem_cache/memory_pool_host.py +33 -35
  105. sglang/srt/mem_cache/radix_cache.py +2 -5
  106. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  107. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  108. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  109. sglang/srt/model_executor/cuda_graph_runner.py +29 -23
  110. sglang/srt/model_executor/forward_batch_info.py +33 -14
  111. sglang/srt/model_executor/model_runner.py +179 -81
  112. sglang/srt/model_loader/loader.py +18 -6
  113. sglang/srt/models/deepseek_nextn.py +2 -1
  114. sglang/srt/models/deepseek_v2.py +79 -38
  115. sglang/srt/models/gemma2.py +0 -34
  116. sglang/srt/models/gemma3n_mm.py +8 -9
  117. sglang/srt/models/glm4.py +6 -0
  118. sglang/srt/models/glm4_moe.py +11 -11
  119. sglang/srt/models/glm4_moe_nextn.py +2 -1
  120. sglang/srt/models/glm4v.py +589 -0
  121. sglang/srt/models/glm4v_moe.py +400 -0
  122. sglang/srt/models/gpt_oss.py +142 -20
  123. sglang/srt/models/granite.py +0 -25
  124. sglang/srt/models/llama.py +10 -27
  125. sglang/srt/models/llama4.py +19 -6
  126. sglang/srt/models/qwen2.py +2 -2
  127. sglang/srt/models/qwen2_5_vl.py +7 -3
  128. sglang/srt/models/qwen2_audio.py +10 -9
  129. sglang/srt/models/qwen2_moe.py +20 -5
  130. sglang/srt/models/qwen3.py +0 -24
  131. sglang/srt/models/qwen3_classification.py +78 -0
  132. sglang/srt/models/qwen3_moe.py +18 -5
  133. sglang/srt/models/registry.py +1 -1
  134. sglang/srt/models/step3_vl.py +6 -2
  135. sglang/srt/models/torch_native_llama.py +0 -24
  136. sglang/srt/multimodal/processors/base_processor.py +23 -13
  137. sglang/srt/multimodal/processors/glm4v.py +132 -0
  138. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  139. sglang/srt/operations.py +17 -2
  140. sglang/srt/reasoning_parser.py +316 -0
  141. sglang/srt/sampling/sampling_batch_info.py +7 -4
  142. sglang/srt/server_args.py +142 -140
  143. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
  144. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  145. sglang/srt/speculative/eagle_worker.py +16 -0
  146. sglang/srt/two_batch_overlap.py +16 -12
  147. sglang/srt/utils.py +3 -3
  148. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  149. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  150. sglang/test/doc_patch.py +59 -0
  151. sglang/test/few_shot_gsm8k.py +1 -1
  152. sglang/test/few_shot_gsm8k_engine.py +1 -1
  153. sglang/test/run_eval.py +4 -1
  154. sglang/test/simple_eval_common.py +6 -0
  155. sglang/test/simple_eval_gpqa.py +2 -0
  156. sglang/test/test_fp4_moe.py +118 -36
  157. sglang/test/test_marlin_moe.py +1 -1
  158. sglang/test/test_marlin_utils.py +1 -1
  159. sglang/utils.py +1 -1
  160. sglang/version.py +1 -1
  161. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
  162. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
  163. sglang/lang/backend/__init__.py +0 -0
  164. sglang/srt/function_call/harmony_tool_parser.py +0 -130
  165. sglang/srt/layers/quantization/scalar_type.py +0 -352
  166. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  167. /sglang/{api.py → lang/api.py} +0 -0
  168. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
  169. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
  170. {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py
@@ -75,6 +75,9 @@ class AscendAttnBackend(AttentionBackend):
         )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
     def forward_extend(
         self,
         q,
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py
@@ -483,7 +483,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
         ).squeeze(1)
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
-    def init_cuda_graph_state(self, max_bs: int):
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         """Initialize CUDA graph state for the attention backend.
 
         Args:
sglang/srt/layers/attention/flashattention_backend.py
@@ -629,6 +629,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ):
         if k is not None:
             assert v is not None
@@ -687,6 +688,11 @@ class FlashAttentionBackend(AttentionBackend):
             forward_batch.forward_mode.is_target_verify() and self.topk > 1
         )
 
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         # Get the appropriate page table based on whether we're using local attention
         if use_local_attn:
             local_metadata = metadata.local_attn_metadata
@@ -737,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
 
             if use_cascade_attn:
@@ -757,6 +764,7 @@ class FlashAttentionBackend(AttentionBackend):
                     k_descale=k_descale,
                     v_descale=v_descale,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
                 o, _ = merge_state_v2_wrapper(
                     o,
@@ -898,6 +906,7 @@ class FlashAttentionBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if k is not None:
             assert v is not None
@@ -943,6 +952,11 @@ class FlashAttentionBackend(AttentionBackend):
         )
         causal = not layer.is_cross_attention
 
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+        if sinks is not None:
+            kwargs["sinks"] = sinks
+
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
@@ -985,6 +999,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         elif use_local_attn:
             # Use chunked (local) attention batching for self-attention
@@ -1003,6 +1018,7 @@ class FlashAttentionBackend(AttentionBackend):
                 softcap=layer.logit_cap,
                 k_descale=k_descale,
                 v_descale=v_descale,
+                **kwargs,
             )
         else:
             page_table = metadata.page_table
@@ -1030,6 +1046,7 @@ class FlashAttentionBackend(AttentionBackend):
                 k_descale=k_descale,
                 v_descale=v_descale,
                 return_softmax_lse=use_cascade_attn,
+                **kwargs,
             )
             if use_cascade_attn:
                 o, softmax_lse, *rest = result
@@ -1050,6 +1067,7 @@ class FlashAttentionBackend(AttentionBackend):
                         k_descale=k_descale,
                         v_descale=v_descale,
                         return_softmax_lse=True,
+                        **kwargs,
                     )
                 )
                 o, _ = merge_state_v2(
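
The recurring pattern in the flashattention_backend.py hunks is worth calling out: the new optional sinks tensor is only added to the keyword arguments when it is not None, so the underlying kernel call keeps working against FA3 builds whose signature does not declare the field yet. A minimal sketch of that conditional-kwargs pattern, with a hypothetical attention_kernel standing in for the real call:

from typing import Optional

import torch


def attention_kernel(q, k, v, **kwargs):
    # Hypothetical stand-in for the real kernel; it may or may not accept `sinks`.
    return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v


def run_attention(q, k, v, sinks: Optional[torch.Tensor] = None):
    # Only forward new fields when they are set, so older kernel signatures
    # that do not know about `sinks` remain callable.
    kwargs = {}
    if sinks is not None:
        kwargs["sinks"] = sinks
    return attention_kernel(q, k, v, **kwargs)


if __name__ == "__main__":
    q = k = v = torch.randn(2, 4, 8)
    print(run_attention(q, k, v).shape)  # torch.Size([2, 4, 8])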
sglang/srt/layers/attention/flashinfer_backend.py
@@ -66,6 +66,10 @@ class PrefillMetadata:
 # Reuse this workspace buffer across all flashinfer wrappers
 global_workspace_buffer = None
 
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+
 
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
@@ -118,6 +122,7 @@ class FlashInferAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
@@ -205,6 +210,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
@@ -215,6 +221,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_wrappers_paged,
@@ -229,6 +236,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_wrappers_verify,
@@ -252,6 +260,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens,
                 prefill_wrappers=self.prefill_wrappers_paged,
@@ -327,6 +336,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
@@ -358,6 +368,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=prefill_wrappers,
@@ -387,6 +398,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=prefill_wrappers,
@@ -414,6 +426,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
@@ -423,6 +436,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -434,6 +448,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -581,7 +596,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
 
 class FlashInferIndicesUpdaterDecode:
-    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -614,6 +629,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -626,6 +642,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -640,30 +657,39 @@ class FlashInferIndicesUpdaterDecode:
             self.kv_indptr[0],
             None,
             spec_info,
+            seq_lens_cpu,
         )
 
     def update_sliding_window(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
+        assert self.sliding_window_size is not None
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
-                paged_kernel_lens_tmp = torch.minimum(  # TODO: replace this with clamp
-                    seq_lens,
-                    torch.tensor(self.sliding_window_size + 1),
+                paged_kernel_lens_tmp = torch.clamp(
+                    seq_lens, max=self.sliding_window_size + 1
                 )
-                paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
+                if seq_lens_cpu is not None:
+                    seq_lens_cpu_tmp = torch.clamp(
+                        seq_lens_cpu, max=self.sliding_window_size + 1
+                    )
+                    paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
+                else:
+                    paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
                 kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
             else:
                 # Full attention
                 paged_kernel_lens_tmp = seq_lens
                 paged_kernel_lens_sum_tmp = seq_lens_sum
+                seq_lens_cpu_tmp = seq_lens_cpu
                 kv_start_idx_tmp = None
 
             use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
@@ -678,6 +704,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu_tmp,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )
 
@@ -685,6 +712,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -709,6 +737,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu,
             )
 
     def call_begin_forward(
@@ -720,6 +749,7 @@ class FlashInferIndicesUpdaterDecode:
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
@@ -756,6 +786,14 @@ class FlashInferIndicesUpdaterDecode:
                 )
             )
 
+        global global_override_indptr_cpu
+        locally_override = False
+        if seq_lens_cpu is not None and global_override_indptr_cpu is None:
+            locally_override = True
+            global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
+            global_override_indptr_cpu[0] = 0
+            global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -769,9 +807,12 @@ class FlashInferIndicesUpdaterDecode:
             non_blocking=True,
         )
 
+        if locally_override:
+            global_override_indptr_cpu = None
+
 
 class FlashInferIndicesUpdaterPrefill:
-    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -806,6 +847,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -820,6 +862,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -828,6 +871,8 @@ class FlashInferIndicesUpdaterPrefill:
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if use_ragged:
+            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
+            # and forward_batch.extend_seq_lens_cpu
            paged_kernel_lens = prefix_lens
            paged_kernel_lens_sum = paged_kernel_lens.sum().item()
         else:
@@ -853,6 +898,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -898,6 +944,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -1020,11 +1067,6 @@ class FlashInferIndicesUpdaterPrefill:
         )
 
 
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global global_override_indptr_cpu
-
-
 class FlashInferMultiStepDraftBackend:
     """
     Wrap multiple flashinfer attention backends as one for multiple consecutive
@@ -1056,7 +1098,7 @@ class FlashInferMultiStepDraftBackend:
         self.kv_last_page_len = torch.ones(
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
-        self.attn_backends = []
+        self.attn_backends: List[FlashInferAttnBackend] = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 FlashInferAttnBackend(
@@ -1176,7 +1218,7 @@ class FlashInferMultiStepDraftBackend:
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
-                seq_lens_cpu=None,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
             )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
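
Most of the flashinfer_backend.py changes thread a host-side seq_lens_cpu tensor through the indices updaters so that the kv indptr can be built with torch.cumsum on the CPU (the global_override_indptr_cpu fast path), which the source comments describe as removing host-to-device copy overhead from flashinfer's plan step. A rough sketch of the host-side construction, independent of flashinfer's actual plan API:

import torch


def build_indptr_cpu(seq_lens_cpu: torch.Tensor) -> torch.Tensor:
    # indptr[i] holds the total KV length of the first i requests;
    # building it from the host copy avoids any device synchronization.
    bs = seq_lens_cpu.numel()
    indptr = torch.zeros(bs + 1, dtype=torch.int32)
    indptr[1:] = torch.cumsum(seq_lens_cpu, dim=0)
    return indptr


if __name__ == "__main__":
    seq_lens_cpu = torch.tensor([5, 3, 7])          # host copy carried by the batch
    print(build_indptr_cpu(seq_lens_cpu).tolist())  # [0, 5, 8, 15]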
sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -81,6 +81,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
+            # different from flashinfer zero_init_global_workspace_buffer
             global_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
sglang/srt/layers/attention/triton_backend.py
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Optional, Union
+from typing import Optional, Union
 
 import torch
 
@@ -57,16 +57,36 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
         self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
 
+        # Parse args
         self.skip_prefill = skip_prefill
-
         max_bs = model_runner.req_to_token_pool.size
+        self.sliding_window_size = model_runner.sliding_window_size
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
+        self.max_context_len = model_runner.model_config.context_len
+        self.device = model_runner.device
+        self.device_core_count = get_device_core_count(model_runner.gpu_id)
+        self.static_kv_splits = get_bool_env_var(
+            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
+        )
+        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
 
+        # Check arguments
         assert not (
             model_runner.sliding_window_size is not None
             and model_runner.model_config.is_encoder_decoder
         ), "Sliding window and cross attention are not supported together"
-        self.sliding_window_size = model_runner.sliding_window_size
 
+        # Initialize buffers
         # TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
         if kv_indptr_buf is None:
             self.kv_indptr = torch.zeros(
@@ -87,9 +107,6 @@ class TritonAttnBackend(AttentionBackend):
             # When provided a buffer, create a clone for the second buffer
             self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
 
-        self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
-
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(
                 (max_bs + 1,), dtype=torch.int32, device=model_runner.device
@@ -99,29 +116,9 @@ class TritonAttnBackend(AttentionBackend):
                 (max_bs + 1,), dtype=torch.int64, device=model_runner.device
             )
 
-        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
-        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
-
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // get_attention_tp_size()
-        )
-        self.num_kv_head = model_runner.model_config.get_num_kv_heads(
-            get_attention_tp_size()
-        )
-
-        self.static_kv_splits = get_bool_env_var(
-            "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
-        )
-        self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
-        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
-
+        # Initialize forward metadata
         self.forward_metadata: ForwardMetadata = None
 
-        self.max_context_len = model_runner.model_config.context_len
-
-        self.device = model_runner.device
-        self.device_core_count = get_device_core_count(model_runner.gpu_id)
-
     def get_num_kv_splits(
         self,
         num_kv_splits: torch.Tensor,
@@ -333,7 +330,7 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             attn_logits = None
             attn_lse = None
-            max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
+            max_extend_len = max(forward_batch.extend_seq_lens_cpu)
             num_kv_splits = None
 
         self.forward_metadata = ForwardMetadata(
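
The triton_backend.py constructor changes are mostly a reshuffle of attribute initialization into one commented block, but the max_extend_len change is substantive: taking max() over the already-materialized forward_batch.extend_seq_lens_cpu values avoids the torch.max(...).item() call, which forces a GPU-to-CPU synchronization per prefill batch. A small illustration of the difference (CUDA optional, falls back to CPU):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
extend_seq_lens = torch.tensor([12, 4, 9], device=device)  # device-side metadata
extend_seq_lens_cpu = [12, 4, 9]                           # host-side copy of the same values

# Before: .item() blocks until the device finishes and copies the scalar back.
max_extend_len_synced = torch.max(extend_seq_lens).item()

# After: a pure host-side max over data the batch already carries.
max_extend_len = max(extend_seq_lens_cpu)

assert max_extend_len == max_extend_len_synced == 12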
sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -23,10 +23,12 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo
 
 # Constants
-DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+DEFAULT_WORKSPACE_SIZE_MB = (
+    512  # Memory workspace size in MB, todo(Yingyi): read from config
+)
 
 # Reuse this workspace buffer across all TRTLLM MHA wrappers
-global_workspace_buffer = None
+global_zero_init_workspace_buffer = None
 
 
 @dataclass
@@ -73,14 +75,14 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
         # Allocate buffers
-        global global_workspace_buffer
-        if global_workspace_buffer is None:
-            global_workspace_buffer = torch.empty(
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
                 self.workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        self.workspace_buffer = global_zero_init_workspace_buffer
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
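
The trtllm_mha_backend.py hunks switch to a module-level workspace that is 512 MB instead of 128 MB, zero-initialized with torch.zeros rather than torch.empty, and renamed to global_zero_init_workspace_buffer so it is not confused with the flashinfer buffer of the old name. A condensed sketch of that lazily created, shared-singleton shape (sizes reduced here so the sketch is cheap to run; not the backend's real API):

import torch

_zero_init_workspace_buffer = None  # shared across all wrappers in the process


def get_workspace(size_mb: int = 1, device: str = "cpu") -> torch.Tensor:
    """Lazily allocate one zero-initialized workspace buffer and reuse it."""
    global _zero_init_workspace_buffer
    if _zero_init_workspace_buffer is None:
        _zero_init_workspace_buffer = torch.zeros(
            size_mb * 1024 * 1024, dtype=torch.uint8, device=device
        )
    return _zero_init_workspace_buffer


if __name__ == "__main__":
    a = get_workspace()
    b = get_workspace()
    assert a is b  # every backend instance reuses the same allocation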