sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py:

@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
         seq_lens = self.seq_lens[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         topk_p = self.topk_p[:num_seqs]
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=global_num_tokens,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
             global_dp_buffer_len=global_dp_buffer_len,
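
Note (editorial context, not part of the diff): the new mrope_positions buffers support M-RoPE, the multimodal rotary scheme used by Qwen-style VL models, which tracks three position components per token (temporal, height, width); that is why the buffer is shaped (3, max_num_token) rather than flat. A minimal sketch of the layout, with illustrative values:

    import torch

    # Sketch: for text-only tokens the three M-RoPE components advance
    # together, so all three rows hold the same positions.
    num_tokens = 8
    mrope_positions = torch.zeros((3, num_tokens), dtype=torch.int64)
    mrope_positions[:] = torch.arange(num_tokens)  # broadcast into all 3 rows

    # A CUDA-graph replay keeps all 3 components and truncates to the live
    # token count, mirroring `self.mrope_positions[:, :num_tokens]` above.
    view = mrope_positions[:, :5]
    assert view.shape == (3, 5)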
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py:

@@ -80,6 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )

         if self.eagle_worker.speculative_algorithm.is_eagle3():
             self.hidden_states = torch.zeros(
@@ -189,6 +192,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:bs]

@@ -247,6 +251,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
@@ -336,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
-        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+        if (
+            forward_batch.spec_info.hidden_states.shape[1]
+            == self.hidden_states.shape[1]
+        ):
+            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
         if forward_batch.spec_info.accept_length is not None:
             self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
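
Note on the new shape guard (an illustration, not from the diff): Tensor.copy_ only accepts a source that broadcasts to the destination, so when the incoming spec_info hidden size differs from the preallocated buffer's, an unconditional copy would raise during capture or replay; the guard skips the copy instead:

    import torch

    dst = torch.zeros(4, 16)   # preallocated buffer, hidden size 16
    src = torch.randn(4, 32)   # incoming hidden states, hidden size 32
    if src.shape[1] == dst.shape[1]:
        dst.copy_(src)         # skipped here: (4, 32) cannot broadcast to (4, 16)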
sglang/srt/speculative/eagle_worker.py:

@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,
@@ -46,6 +47,7 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
+    get_bool_env_var,
     is_cuda,
     next_power_of_2,
 )
@@ -54,6 +56,7 @@ if is_cuda():
     from sgl_kernel import segment_packbits

 logger = logging.getLogger(__name__)
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")


 @contextmanager
@@ -137,8 +140,15 @@ class EAGLEWorker(TpModelWorker):
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()

         if self.speculative_algorithm.is_eagle3():
-            # EAGLE3 models don't share lm_head
-            self.draft_model_runner.model.set_embed(embed)
+            # In most cases EAGLE3 models don't share the lm_head,
+            # but some (e.g. nvidia/gpt-oss-120b-Eagle3) do.
+            if (
+                hasattr(self.draft_model_runner.model, "load_lm_head_from_target")
+                and self.draft_model_runner.model.load_lm_head_from_target
+            ):
+                self.draft_model_runner.model.set_embed_and_head(embed, head)
+            else:
+                self.draft_model_runner.model.set_embed(embed)

         # grab hot token ids
         if self.draft_model_runner.model.hot_token_id is not None:
@@ -178,137 +188,189 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None

-        if self.server_args.attention_backend == "flashinfer":
-            if not global_server_args_dict["use_mla_backend"]:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                    FlashInferMultiStepDraftBackend,
-                )
+        # Initialize decode attention backend
+        self.draft_attn_backend = self._create_decode_backend()

-                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                    FlashInferMLAMultiStepDraftBackend,
-                )
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
+        self.draft_extend_attn_backend = self._create_draft_extend_backend()

-                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "triton":
-            from sglang.srt.layers.attention.triton_backend import (
-                TritonAttnBackend,
-                TritonMultiStepDraftBackend,
-            )
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

-            self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TritonAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import (
-                AiterAttnBackend,
-                AiterMultiStepDraftBackend,
-            )
+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def _create_decode_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": self._create_fa3_decode_backend,
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+        }
+
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )

-            self.draft_attn_backend = AiterMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = AiterAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = False
-        elif self.server_args.attention_backend == "fa3":
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-                FlashAttentionMultiStepBackend,
-            )
+    def _create_draft_extend_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": self._create_fa3_prefill_backend,
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+        }
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
+        return self._create_backend(
+            backend_name,
+            backend_map,
+            "EAGLE is not supported in attention backend {backend_type}",
+        )

-            self.draft_attn_backend = FlashAttentionMultiStepBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = FlashAttentionBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import (
-                FlashMLAMultiStepDraftBackend,
+    def _create_flashinfer_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
             )

-            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
             )
-        elif self.server_args.attention_backend == "trtllm_mha":
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-                TRTLLMHAAttnMultiStepDraftBackend,
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
             )

-            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
             self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "trtllm_mla":
-            if not global_server_args_dict["use_mla_backend"]:
-                raise ValueError(
-                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-                )
-
-            from sglang.srt.layers.attention.trtllm_mla_backend import (
-                TRTLLMMLABackend,
-                TRTLLMMLAMultiStepDraftBackend,
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
             )

-            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
+    def _create_triton_decode_backend(self):
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
             )
-            self.draft_extend_attn_backend = TRTLLMMLABackend(
-                self.draft_model_runner,
-                skip_prefill=False,
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
             )
-            self.has_prefill_wrapper_verify = True
+
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
         else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mha_prefill_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
             raise ValueError(
-                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
             )

-        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -674,6 +736,14 @@ class EAGLEWorker(TpModelWorker):

         # Set inputs
         forward_batch.input_ids = input_ids
+        # This is a temporary fix for the case where the user is running standalone
+        # speculative decoding and the draft model architecture is gpt-oss; the
+        # gpt-oss rope kernel needs cache_loc to be contiguous.
+        if (
+            self.server_args.speculative_algorithm == "STANDALONE"
+            and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+        ):
+            out_cache_loc = out_cache_loc.contiguous()
         forward_batch.out_cache_loc = out_cache_loc[i]
         forward_batch.positions.add_(1)
         forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
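
Background on the .contiguous() call (an illustration, not from the diff): slicing a strided view yields tensors whose elements are not densely packed, and kernels that assume unit stride, such as the gpt-oss rope kernel mentioned in the comment, can misread them:

    import torch

    buf = torch.arange(12).reshape(3, 4).t()  # transposed view, shape (4, 3)
    print(buf.is_contiguous())                # False
    print(buf[0].is_contiguous())             # False: row elements are 4 apart

    packed = buf.contiguous()                 # materialize a densely packed copy
    print(packed[0].is_contiguous())          # True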
@@ -758,6 +828,21 @@
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

+        # QQ: can be optimized
+        if self.target_worker.model_runner.is_hybrid_gdn:
+            # res.draft_input.accept_length is on GPU but may be empty for last verify?
+            accepted_length = (
+                torch.tensor(
+                    res.accept_length_per_req_cpu,
+                    device=logits_output.hidden_states.device,
+                    dtype=torch.int32,
+                )
+                + 1
+            )
+            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
+                accepted_length, self.target_worker.model_runner.model
+            )
+
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
@@ -781,15 +866,20 @@
         token_ids_logprobs = batch.token_ids_logprobs
         accepted_indices = res.accepted_indices
         assert len(accepted_indices) == len(logits_output.next_token_logits)
+
         temperatures = batch.sampling_info.temperatures
         num_draft_tokens = batch.spec_info.draft_token_num
         # acceptance indices are the indices in a "flattened" batch.
         # dividing by num_draft_tokens will yield the actual batch index.
         temperatures = temperatures[accepted_indices // num_draft_tokens]
-
-        logprobs = torch.nn.functional.log_softmax(
-            logits_output.next_token_logits / temperatures, dim=-1
-        )
+        if RETURN_ORIGINAL_LOGPROB:
+            logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits, dim=-1
+            )
+        else:
+            logprobs = torch.nn.functional.log_softmax(
+                logits_output.next_token_logits / temperatures, dim=-1
+            )
         batch_next_token_ids = res.verified_id
         num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

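The RETURN_ORIGINAL_LOGPROB branch returns logprobs of the raw distribution rather than the temperature-scaled one; the two differ whenever the temperature is not 1. A quick check of the two paths:

    import torch

    logits = torch.tensor([2.0, 1.0, 0.5])
    temperature = 0.7

    original = torch.log_softmax(logits, dim=-1)              # RETURN_ORIGINAL_LOGPROB path
    scaled = torch.log_softmax(logits / temperature, dim=-1)  # default path

    # Temperature < 1 sharpens the distribution, so the top token's logprob
    # is higher under the scaled path than under the original one.
    assert scaled[0] > original[0]
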
@@ -806,13 +896,19 @@
             (
                 logits_output.next_token_top_logprobs_val,
                 logits_output.next_token_top_logprobs_idx,
-            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
+            ) = get_top_logprobs(
+                logprobs,
+                top_logprobs_nums_repeat_interleaved,
+            )

         if any(x is not None for x in token_ids_logprobs):
             (
                 logits_output.next_token_token_ids_logprobs_val,
                 logits_output.next_token_token_ids_logprobs_idx,
-            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
+            ) = get_token_ids_logprobs(
+                logprobs,
+                token_ids_logprobs_repeat_interleaved,
+            )

         logits_output.next_token_logprobs = logprobs[
             torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
sglang/srt/speculative/spec_info.py:

@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
     EAGLE3 = auto()
+    STANDALONE = auto()

     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
     def is_eagle3(self):
         return self == SpeculativeAlgorithm.EAGLE3

+    def is_standalone(self):
+        return self == SpeculativeAlgorithm.STANDALONE
+
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
+            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
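
With the new enum member, resolving the algorithm from a server-args string works the same way as for EAGLE. A small usage sketch (assuming sglang 0.5.2 is installed):

    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

    algo = SpeculativeAlgorithm.from_string("STANDALONE")
    assert algo.is_standalone() and not algo.is_eagle3()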
sglang/srt/speculative/standalone_worker.py (new file):

@@ -0,0 +1,109 @@
+import logging
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
+
+if is_cuda():
+    from sgl_kernel import segment_packbits
+
+logger = logging.getLogger(__name__)
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+
+
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # The draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp for now because it doesn't support 2 comm groups.
+    with patch_tensor_parallel_group(tp_group):
+        yield
+
+
+class StandaloneWorker(EAGLEWorker):
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        moe_ep_rank: int,
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+        self.page_size = server_args.page_size
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.padded_static_len = -1
+
+        # Override the context length of the draft model to be the same as the target model.
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+
+        # Do not capture the cuda graph in `super().__init__()`;
+        # it will be captured later.
+        backup_disable_cuda_graph = server_args.disable_cuda_graph
+        server_args.disable_cuda_graph = True
+        # Share the allocator with the target worker.
+        # The draft and target workers own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Load hot token ids
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # Init draft worker
+        with empty_context():
+            TpModelWorker.__init__(
+                self,
+                server_args=server_args,
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                pp_rank=0,  # FIXME
+                dp_rank=dp_rank,
+                moe_ep_rank=moe_ep_rank,
+                nccl_port=nccl_port,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+
+        # Init attention backend and cuda graphs
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+
+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
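
StandaloneWorker reuses the EAGLE speculative-decoding machinery with an independently trained draft model rather than an EAGLE head. A hypothetical offline-engine sketch follows; the argument names mirror the speculative options visible in this diff, but the model paths are placeholders and the exact flags may differ on your build:

    import sglang as sgl

    # Hypothetical example: pair a large target model with a small draft model.
    llm = sgl.Engine(
        model_path="org/target-model",                   # placeholder
        speculative_algorithm="STANDALONE",
        speculative_draft_model_path="org/draft-model",  # placeholder
        speculative_num_steps=3,
        speculative_eagle_topk=1,
        speculative_num_draft_tokens=4,
    )
    print(llm.generate("The capital of France is")["text"])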