sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0

sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -41,6 +41,7 @@ class EAGLEDraftCudaGraphRunner:
         # Parse args
         self.eagle_worker = eagle_worker
         self.model_runner = model_runner = eagle_worker.model_runner
+        self.model_runner: EAGLEWorker
         self.graphs = {}
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -105,30 +106,15 @@ class EAGLEDraftCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         # Capture
         try:
@@ -193,7 +179,7 @@ class EAGLEDraftCudaGraphRunner:
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
@@ -211,11 +197,11 @@ class EAGLEDraftCudaGraphRunner:
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
             global_num_tokens = None
-            gathered_buffer = None
+            global_dp_buffer_len = None
             global_num_tokens_for_logprob = None
 
         spec_info = EagleDraftInput(
@@ -239,8 +225,8 @@ class EAGLEDraftCudaGraphRunner:
             return_logprob=False,
             positions=positions,
             global_num_tokens_gpu=global_num_tokens,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
@@ -258,6 +244,7 @@ class EAGLEDraftCudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc

sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Callable
 
 import torch
 
-from sglang.srt.layers.dp_attention import DPPaddingMode
+from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
 from sglang.srt.model_executor.cuda_graph_runner import (
     CUDA_GRAPH_CAPTURE_FAILED_MSG,
     CudaGraphRunner,
@@ -117,30 +117,15 @@ class EAGLEDraftExtendCudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         if hasattr(
             self.model_runner.model_config.hf_config, "draft_vocab_size"
@@ -222,7 +207,7 @@ class EAGLEDraftExtendCudaGraphRunner:
                     device=self.input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -238,9 +223,9 @@ class EAGLEDraftExtendCudaGraphRunner:
                     device=self.input_ids.device,
                 )
             )
-            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_dp_buffer_len = num_tokens
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
@@ -264,8 +249,8 @@ class EAGLEDraftExtendCudaGraphRunner:
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
-            gathered_buffer=gathered_buffer,
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -288,6 +273,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
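
Both draft CUDA-graph runners follow the same pattern: instead of preallocating a `gathered_buffer` tensor of shape `(max_num_token * dp_size, hidden_size)` at init time, they now record only the required row count (`global_dp_buffer_len`) and publish it through `set_dp_buffer_len()` right before the captured body runs. The sketch below illustrates the idea with a toy module-level registry; `set_dp_buffer_len` is the helper imported in the hunks above, but `get_dp_buffer` and the sizes are illustrative assumptions, not sglang's API.

# Toy sketch of deferring DP gather-buffer allocation to a recorded length.
# `get_dp_buffer` and the shapes below are illustrative, not sglang's API.
import torch

_dp_buffer_len = None  # number of rows needed by the next forward


def set_dp_buffer_len(global_dp_buffer_len: int, num_tokens: int) -> None:
    # Mirrors the call made in run_once(): only the length is stored here.
    global _dp_buffer_len
    _dp_buffer_len = global_dp_buffer_len


def get_dp_buffer(hidden_size: int, dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # A consumer materializes the buffer lazily from the recorded length.
    assert _dp_buffer_len is not None, "set_dp_buffer_len() was not called"
    return torch.zeros(_dp_buffer_len, hidden_size, dtype=dtype)


# Capture path: dp_size=4 ranks with 128 tokens each -> 512 rows, allocated on demand.
set_dp_buffer_len(global_dp_buffer_len=4 * 128, num_tokens=128)
print(get_dp_buffer(hidden_size=1024).shape)  # torch.Size([512, 1024])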

sglang/srt/speculative/eagle_utils.py

@@ -49,6 +49,8 @@ SIMULATE_ACC_METHOD = os.environ.get("SIMULATE_ACC_METHOD", "multinomial")
 
 TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
 
+TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
+
 
 @dataclass
 class EagleDraftInput:
@@ -177,11 +179,24 @@ class EagleDraftInput:
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, None
 
-    def filter_batch(self, new_indices: torch.Tensor):
-        self.topk_p = self.topk_p[: len(new_indices)]
-        self.topk_index = self.topk_index[: len(new_indices)]
-        self.hidden_states = self.hidden_states[: len(new_indices)]
-        self.verified_id = self.verified_id[: len(new_indices)]
+    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+        if has_been_filtered:
+            # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
+            # therefore, we don't need to filter the batch again in scheduler
+            if len(new_indices) != len(self.topk_p):
+                logger.warning(
+                    f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
+                )
+            self.topk_p = self.topk_p[: len(new_indices)]
+            self.topk_index = self.topk_index[: len(new_indices)]
+            self.hidden_states = self.hidden_states[: len(new_indices)]
+            self.verified_id = self.verified_id[: len(new_indices)]
+        else:
+            # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
+            self.topk_p = self.topk_p[new_indices]
+            self.topk_index = self.topk_index[new_indices]
+            self.hidden_states = self.hidden_states[new_indices]
+            self.verified_id = self.verified_id[new_indices]
 
     def merge_batch(self, spec_info: EagleDraftInput):
         if self.hidden_states is None:
@@ -410,8 +425,15 @@ class EagleVerifyInput:
             logits=logits_output.next_token_logits, vocab_mask=vocab_mask
         )
 
-        # Sample tokens
-        if batch.sampling_info.is_all_greedy:
+        # Sample tokens. Force greedy sampling on AMD
+        is_all_greedy = sampling_info.is_all_greedy
+        if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
+            logger.warning(
+                "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
+                "Falling back to greedy verification."
+            )
+
+        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
             target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
             target_predict = target_predict.reshape(bs, self.draft_token_num)
 
@@ -440,12 +462,13 @@ class EagleVerifyInput:
                     sampling_info.top_ks, self.draft_token_num, dim=0
                 ),
             ) # (bs * draft_token_num, vocab_size)
-            target_probs = top_p_renorm_prob(
-                target_probs,
-                torch.repeat_interleave(
-                    sampling_info.top_ps, self.draft_token_num, dim=0
-                ),
-            )
+            if not torch.all(sampling_info.top_ps == 1.0):
+                target_probs = top_p_renorm_prob(
+                    target_probs,
+                    torch.repeat_interleave(
+                        sampling_info.top_ps, self.draft_token_num, dim=0
+                    ),
+                )
             target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
 
             draft_probs = torch.zeros(
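
The new `top_ps` guard is safe because top-p renormalization with p = 1.0 keeps every token, so the kernel call would be an identity. A plain-torch reference (not sglang's fused `top_p_renorm_prob` kernel; the prefix-mass formulation below is one common convention) makes that concrete:

# Reference top-p renormalization; with top_p == 1.0 it returns probs unchanged,
# which is why the hunk above skips the fused kernel in that case.
import torch


def top_p_renorm_reference(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_probs, order = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    # Keep tokens whose preceding cumulative mass is below top_p (first token always kept).
    keep_sorted = (cumsum - sorted_probs) < top_p
    keep = torch.zeros_like(probs).scatter(-1, order, keep_sorted.to(probs.dtype))
    truncated = probs * keep
    return truncated / truncated.sum(dim=-1, keepdim=True)


probs = torch.softmax(torch.randn(4, 32), dim=-1)
assert torch.allclose(top_p_renorm_reference(probs, 1.0), probs, atol=1e-6)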

sglang/srt/speculative/eagle_worker.py

@@ -9,7 +9,6 @@ from huggingface_hub import snapshot_download
 
 from sglang.srt.distributed import (
     GroupCoordinator,
-    get_tensor_model_parallel_world_size,
     get_tp_group,
     patch_tensor_parallel_group,
 )
@@ -92,7 +91,7 @@ class EAGLEWorker(TpModelWorker):
         )
         self.padded_static_len = -1
 
-        # Override context length with target model's context length
+        # Override the context length of the draft model to be the same as the target model.
         server_args.context_length = target_worker.model_runner.model_config.context_len
 
         # Do not capture cuda graph in `super().__init__()`
@@ -267,6 +266,43 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "trtllm_mha":
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+                TRTLLMHAAttnMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
+        elif self.server_args.attention_backend == "trtllm_mla":
+            if not global_server_args_dict["use_mla_backend"]:
+                raise ValueError(
+                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+                )
+
+            from sglang.srt.layers.attention.trtllm_mla_backend import (
+                TRTLLMMLABackend,
+                TRTLLMMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMMLABackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -836,6 +872,21 @@ class EAGLEWorker(TpModelWorker):
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
+        has_finished, unfinished_req_index = False, []
+        for i, req in enumerate(batch.reqs):
+            if req.finished():
+                has_finished = True
+            else:
+                unfinished_req_index.append(i)
+        if has_finished:
+            unfinished_index_device = torch.tensor(
+                unfinished_req_index,
+                dtype=torch.int64,
+                device=batch.spec_info.topk_p.device,
+            )
+            batch.spec_info.filter_batch(
+                unfinished_index_device, has_been_filtered=False
+            )
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         assert isinstance(batch.spec_info, EagleDraftInput)
@@ -966,7 +1017,9 @@ def get_last_loc_large_page_size_top_k_1(
     return prefix_lens, seq_lens, last_loc
 
 
-@torch.compile(dynamic=True)
+# Disable torch.compile for this function because it will be
+# even slower.
+# @torch.compile(dynamic=True)
 def get_last_loc_large_page_size_large_top_k(
     req_to_token: torch.Tensor,
     req_pool_indices: torch.Tensor,
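
The block added to the draft-extend path collects the indices of unfinished requests and calls the new `filter_batch(..., has_been_filtered=False)` mode, because after draft-extend the spec-info tensors still contain rows for finished requests. The difference between the two modes is just slice-by-count versus index-select; a small illustrative sketch with toy tensors (not sglang objects):

# Toy illustration of EagleDraftInput.filter_batch's two modes.
import torch

topk_p = torch.tensor([0.9, 0.8, 0.7, 0.6])   # one row per request in the batch
unfinished = torch.tensor([0, 2, 3])          # request 1 has finished

# has_been_filtered=True: rows were already compacted in verify(), keep the first k.
compacted = topk_p[: len(unfinished)]         # tensor([0.9, 0.8, 0.7])

# has_been_filtered=False (draft-extend path): finished rows are still present,
# so the surviving rows must be gathered by index.
selected = topk_p[unfinished]                 # tensor([0.9, 0.7, 0.6])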

sglang/srt/tokenizer/tiktoken_tokenizer.py (new file)

@@ -0,0 +1,161 @@
+import functools
+import json
+from typing import AbstractSet, Collection, List, Literal, Union
+
+
+class TiktokenProcessor:
+    def __init__(self, name: str):
+        self.tokenizer = TiktokenTokenizer(name)
+
+    def image_processor(self, image):
+        return {"pixel_values": [image]}
+
+
+RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
+CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]
+
+
+PAD = "<|pad|>"
+EOS = "<|eos|>"
+SEP = "<|separator|>"
+
+DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
+DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
+
+# default + separate each single digit
+PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import tiktoken
+        from jinja2 import Template
+
+        # Read the JSON
+        with open(tokenizer_path, "rb") as fin:
+            xtok_dict = json.load(fin)
+
+        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in xtok_dict["special_tokens"]
+        }
+        if xtok_dict["word_split"] == "V1":
+            pad_str = PAT_STR_B
+        else:
+            assert False, f"Unknown word_split: {xtok_dict['word_split']}"
+        pad_str = xtok_dict.get("pat_str", pad_str)
+
+        kwargs = {
+            "name": tokenizer_path,
+            "pat_str": pad_str,
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in xtok_dict:
+            default_allowed_special = set(
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in xtok_dict["default_allowed_special"]
+                ]
+            )
+        if "vocab_size" in xtok_dict:
+            kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]
+
+        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__
+        default_allowed_special = None
+        control_tokens = DEFAULT_CONTROL_TOKENS
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._control_tokens = control_tokens
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[
+                Literal["all"], AbstractSet[str]
+            ] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> List[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self,
+                text,
+                allowed_special=allowed_special,
+                disallowed_special=(),
+            )
+
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Allow more tokens to prevent crash
+        tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
+        tokenizer._default_allowed_special |= set(
+            CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
+        )
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.bos_token_id = None
+        self.eos_token_id = tokenizer._special_tokens[EOS]
+        self.vocab_size = tokenizer.n_vocab
+        self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        self.chat_template_jinja = Template(self.chat_template)
+        self.additional_stop_token_ids = None
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x, *args, **kwargs):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if len(batch) > 0 and isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode_batch(batch)
+
+    def apply_chat_template(
+        self, messages, tokenize, add_generation_prompt, tools=None
+    ):
+        ret = self.chat_template_jinja.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
+
+    def __call__(self, text, **kwargs):
+        return {
+            "input_ids": self.encode(text),
+        }
+
+    def init_xgrammar(self):
+        from xgrammar import TokenizerInfo
+
+        XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>"
+
+        enc = self.tokenizer
+        encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens}
+        encoded_vocab = [
+            token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
+        ]
+        override_stop_tokens = [2]  # eos
+        # These are treated as special tokens in xgrammar; we want to avoid them
+        # For now, xgrammar treats anything starting with b'\x00' as a special token
+        xgrammar_special_token_ids = []
+        for i, token in enumerate(encoded_vocab):
+            if isinstance(token, bytes) and token.startswith(b"\x00"):
+                xgrammar_special_token_ids.append(i)
+
+        for i, id in enumerate(xgrammar_special_token_ids):
+            encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i)
+        tokenizer_info = TokenizerInfo(
+            encoded_vocab, stop_token_ids=override_stop_tokens
+        )
+        assert len(tokenizer_info.special_token_ids) == 0
+
+        return tokenizer_info, override_stop_tokens
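
The new `sglang/srt/tokenizer/tiktoken_tokenizer.py` exposes a small HF-like surface (`encode`, `decode`, `batch_decode`, `apply_chat_template`, `__call__`). A usage sketch, assuming a tokenizer JSON in the `regular_tokens` / `special_tokens` / `word_split` layout the constructor reads (the file path below is made up):

# Hypothetical usage of the new tokenizer wrapper; only the JSON path is invented.
from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

tok = TiktokenTokenizer("/path/to/xtok_tokenizer.json")  # assumed file location
ids = tok.encode("Hello world")
text = tok.decode(ids)
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Hi"}],
    tokenize=False,
    add_generation_prompt=True,
)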

sglang/srt/two_batch_overlap.py

@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import (
     CommunicateSummableTensorPairFn,
     ScatterMode,
 )
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_tbo_token_distribution_threshold,
+    is_tbo_enabled,
+)
 from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
     vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
     left_sum = sum(extend_lens[:vanilla_split_seq_index])
     overall_sum = sum(extend_lens)
-    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
+    threshold = get_tbo_token_distribution_threshold()
     assert threshold <= 0.5, f"{threshold=}"
     return left_sum < overall_sum * threshold or left_sum > overall_sum * (
         1 - threshold
@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin:
         self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
 
     def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
-        if not global_server_args_dict["enable_two_batch_overlap"]:
+        if not is_tbo_enabled():
             return
         token_num_per_seq = get_token_num_per_seq(
             forward_mode=batch.forward_mode, spec_info=batch.spec_info
@@ -353,10 +358,12 @@ class TboDPAttentionPreparer:
     def prepare_all_gather(
         self,
         local_batch: ScheduleBatch,
-        deepep_mode: DeepEPMode,
-        enable_deepep_moe: bool,
-        enable_two_batch_overlap: bool,
     ):
+
+        deepep_mode = get_deepep_mode()
+        enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+        enable_two_batch_overlap = is_tbo_enabled()
+
         self.enable_two_batch_overlap = enable_two_batch_overlap
 
         if local_batch is not None:
@@ -384,7 +391,7 @@ class TboDPAttentionPreparer:
                     and not local_batch.forward_mode.is_target_verify()
                 )
                 and enable_deepep_moe
-                and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
+                and (resolved_deepep_mode.is_low_latency())
             )
         else:
             self.local_tbo_split_seq_index = 0
@@ -657,6 +664,7 @@ class TboForwardBatchPreparer:
             "req_to_token_pool",
             "token_to_kv_pool",
             "can_run_dp_cuda_graph",
+            "dp_padding_mode",
             "global_forward_mode",
             "spec_algorithm",
             "capture_hidden_mode",
@@ -678,16 +686,12 @@ class TboForwardBatchPreparer:
         # TODO improve, e.g. unify w/ `init_raw`
         if (
             global_server_args_dict["moe_dense_tp_size"] == 1
-            and batch.gathered_buffer is not None
+            and batch.global_dp_buffer_len is not None
         ):
             sum_len = end_token_index - start_token_index
-            gathered_buffer = torch.zeros(
-                (sum_len, batch.gathered_buffer.shape[1]),
-                dtype=batch.gathered_buffer.dtype,
-                device=batch.gathered_buffer.device,
-            )
+            global_dp_buffer_len = sum_len
         else:
-            gathered_buffer = None
+            global_dp_buffer_len = None
 
         output_dict.update(
             dict(
@@ -705,8 +709,7 @@ class TboForwardBatchPreparer:
                 tbo_children=None,
                 global_num_tokens_gpu=None,
                 global_num_tokens_cpu=None,
-                dp_padding_mode=None,
-                gathered_buffer=gathered_buffer,
+                global_dp_buffer_len=global_dp_buffer_len,
                 global_num_tokens_for_logprob_gpu=None,
                 global_num_tokens_for_logprob_cpu=None,
                 sampling_info=None,
@@ -959,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
 
 class MaybeTboDeepEPDispatcher:
     def __init__(self, **kwargs):
-        num_inner_dispatchers = (
-            2 if global_server_args_dict["enable_two_batch_overlap"] else 1
-        )
+        num_inner_dispatchers = 2 if is_tbo_enabled() else 1
         self._inners = [
             DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
        ]
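
The two_batch_overlap.py hunks replace direct `global_server_args_dict[...]` lookups with module-level getters imported from `sglang.srt.layers.moe` (`is_tbo_enabled`, `get_deepep_mode`, `get_moe_a2a_backend`, `get_tbo_token_distribution_threshold`). A minimal sketch of that getter pattern follows; the config dataclass and the `initialize_moe_config` signature here are assumptions for illustration, not sglang's actual implementation.

# Minimal sketch of module-level MoE config getters replacing dict lookups.
from dataclasses import dataclass


@dataclass
class _MoeConfig:
    tbo_enabled: bool = False
    tbo_token_distribution_threshold: float = 0.48


_CONFIG = _MoeConfig()


def initialize_moe_config(tbo_enabled: bool, threshold: float) -> None:
    # Populated once from server args at startup (illustrative initializer).
    global _CONFIG
    _CONFIG = _MoeConfig(tbo_enabled, threshold)


def is_tbo_enabled() -> bool:
    return _CONFIG.tbo_enabled


def get_tbo_token_distribution_threshold() -> float:
    return _CONFIG.tbo_token_distribution_threshold


# Call sites read typed getters instead of global_server_args_dict:
initialize_moe_config(tbo_enabled=True, threshold=0.48)
assert is_tbo_enabled() and get_tbo_token_distribution_threshold() <= 0.5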