sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
         seq_lens = self.seq_lens[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         topk_p = self.topk_p[:num_seqs]
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=global_num_tokens,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
             global_dp_buffer_len=global_dp_buffer_len,
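
The three hunks above thread mrope (multimodal rotary) position ids through the draft CUDA graph runner: a static (3, max_num_token) buffer is allocated once before capture, and each run takes a per-batch view of it. The sketch below is illustrative, not sglang's actual class; it shows why the buffer must be pre-allocated and sliced rather than re-created per step:

import torch

class GraphBufferSketch:
    # CUDA graphs replay fixed memory addresses, so inputs must live in
    # static tensors allocated before capture; per-step data is copied in.
    def __init__(self, max_num_token: int = 4096):
        # One row per mrope axis (temporal / height / width), as in the diff.
        self.mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)

    def prepare(self, mrope: torch.Tensor) -> torch.Tensor:
        num_tokens = mrope.shape[1]
        # copy_ writes into the static storage; the returned slice is a view,
        # so a captured graph keeps reading the same addresses on replay.
        self.mrope_positions[:, :num_tokens].copy_(mrope)
        return self.mrope_positions[:, :num_tokens]

runner = GraphBufferSketch()
view = runner.prepare(torch.arange(24, dtype=torch.int64).reshape(3, 8))
assert view.data_ptr() == runner.mrope_positions.data_ptr()  # same storage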
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -80,6 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )

         if self.eagle_worker.speculative_algorithm.is_eagle3():
             self.hidden_states = torch.zeros(
@@ -189,6 +192,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:bs]

@@ -247,6 +251,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
@@ -336,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
-        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+        if (
+            forward_batch.spec_info.hidden_states.shape[1]
+            == self.hidden_states.shape[1]
+        ):
+            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
         if forward_batch.spec_info.accept_length is not None:
            self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
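
The last hunk also swaps an unconditional copy_ for a shape guard, so the hidden-state copy is skipped when the incoming width differs from the captured buffer's. A minimal repro of the failure mode the guard avoids (the sizes here are made up):

import torch

captured = torch.zeros((16, 4096))   # static buffer baked into the cuda graph
incoming = torch.randn((8, 2048))    # draft hidden states of a different width

# Unguarded, captured[:8].copy_(incoming) raises a size-mismatch RuntimeError;
# the patched runner only copies when the widths agree.
if incoming.shape[1] == captured.shape[1]:
    captured[: incoming.shape[0]].copy_(incoming)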
sglang/srt/speculative/eagle_utils.py
@@ -26,8 +26,6 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2

-logger = logging.getLogger(__name__)
-
 if is_cuda():
     from sgl_kernel import (
         fast_topk,
sglang/srt/speculative/eagle_worker.py
@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,
@@ -47,6 +48,7 @@ from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
+    is_blackwell,
     is_cuda,
     next_power_of_2,
 )
@@ -187,137 +189,197 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None

-        if self.server_args.attention_backend == "flashinfer":
-            if not global_server_args_dict["use_mla_backend"]:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                    FlashInferMultiStepDraftBackend,
-                )
+        # Initialize decode attention backend
+        self.draft_attn_backend = self._create_decode_backend()

-                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                    FlashInferMLAMultiStepDraftBackend,
-                )
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
+        self.draft_extend_attn_backend = self._create_draft_extend_backend()

-                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "triton":
-            from sglang.srt.layers.attention.triton_backend import (
-                TritonAttnBackend,
-                TritonMultiStepDraftBackend,
-            )
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

-            self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TritonAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import (
-                AiterAttnBackend,
-                AiterMultiStepDraftBackend,
-            )
+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def _create_decode_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+        }
+
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )

-            self.draft_attn_backend = AiterMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = AiterAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = False
-        elif self.server_args.attention_backend == "fa3":
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-                FlashAttentionMultiStepBackend,
-            )
+    def _create_draft_extend_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+        }
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
+        return self._create_backend(
+            backend_name,
+            backend_map,
+            "EAGLE is not supported in attention backend {backend_type}",
+        )

-            self.draft_attn_backend = FlashAttentionMultiStepBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = FlashAttentionBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import (
-                FlashMLAMultiStepDraftBackend,
+    def _create_flashinfer_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
             )

-            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
             )
-        elif self.server_args.attention_backend == "trtllm_mha":
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-                TRTLLMHAAttnMultiStepDraftBackend,
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
             )

-            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
             self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "trtllm_mla":
-            if not global_server_args_dict["use_mla_backend"]:
-                raise ValueError(
-                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-                )
-
-            from sglang.srt.layers.attention.trtllm_mla_backend import (
-                TRTLLMMLABackend,
-                TRTLLMMLAMultiStepDraftBackend,
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
             )

-            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
+    def _create_triton_decode_backend(self):
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
             )
-            self.draft_extend_attn_backend = TRTLLMMLABackend(
-                self.draft_model_runner,
-                skip_prefill=False,
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
             )
-            self.has_prefill_wrapper_verify = True
+
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
         else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mha_prefill_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        if not global_server_args_dict["use_mla_backend"]:
             raise ValueError(
-                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
             )

-        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -683,6 +745,14 @@ class EAGLEWorker(TpModelWorker):

             # Set inputs
             forward_batch.input_ids = input_ids
+            # This is a temporary fix for the case that the user is using standalone
+            # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
+            # rope kernel needs cache_loc to be contiguous.
+            if (
+                self.server_args.speculative_algorithm == "STANDALONE"
+                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+            ):
+                out_cache_loc = out_cache_loc.contiguous()
             forward_batch.out_cache_loc = out_cache_loc[i]
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
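
The .contiguous() call matters because a strided out_cache_loc view would hand the gpt-oss rope kernel non-dense memory; forcing a dense copy sidesteps that. A two-line illustration of the property being enforced:

import torch

loc = torch.arange(12).reshape(3, 4).t()  # transposed view -> strided layout
assert not loc.is_contiguous()
assert loc.contiguous().is_contiguous()   # dense copy, safe for strict kernels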
@@ -767,6 +837,21 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

+        # QQ: can be optimized
+        if self.target_worker.model_runner.is_hybrid_gdn:
+            # res.draft_input.accept_length is on GPU but may be empty for last verify?
+            accepted_length = (
+                torch.tensor(
+                    res.accept_length_per_req_cpu,
+                    device=logits_output.hidden_states.device,
+                    dtype=torch.int32,
+                )
+                + 1
+            )
+            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
+                accepted_length, self.target_worker.model_runner.model
+            )
+
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
sglang/srt/speculative/spec_info.py
@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
     EAGLE3 = auto()
+    STANDALONE = auto()

     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
     def is_eagle3(self):
         return self == SpeculativeAlgorithm.EAGLE3

+    def is_standalone(self):
+        return self == SpeculativeAlgorithm.STANDALONE
+
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
+            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
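
The enum changes are small and mechanical. A self-contained version of the extended enum; note the .upper() normalization inside from_string is an assumption, since the hunk is truncated right after the if:

from enum import IntEnum, auto

class SpeculativeAlgorithm(IntEnum):
    NONE = auto()
    EAGLE = auto()
    EAGLE3 = auto()
    STANDALONE = auto()

    def is_standalone(self):
        return self == SpeculativeAlgorithm.STANDALONE

    @staticmethod
    def from_string(name):
        name_map = {
            "EAGLE": SpeculativeAlgorithm.EAGLE,
            "EAGLE3": SpeculativeAlgorithm.EAGLE3,
            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
            None: SpeculativeAlgorithm.NONE,
        }
        if name is not None:
            name = name.upper()  # assumed: normalize case before the lookup
        return name_map[name]

assert SpeculativeAlgorithm.from_string("standalone").is_standalone()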
sglang/srt/speculative/standalone_worker.py
@@ -0,0 +1,109 @@
+import logging
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
+
+if is_cuda():
+    from sgl_kernel import segment_packbits
+
+logger = logging.getLogger(__name__)
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+
+
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with patch_tensor_parallel_group(tp_group):
+        yield
+
+
+class StandaloneWorker(EAGLEWorker):
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        moe_ep_rank: int,
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+        self.page_size = server_args.page_size
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.padded_static_len = -1
+
+        # Override the context length of the draft model to be the same as the target model.
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+
+        # Do not capture cuda graph in `super().__init__()`
+        # It will be captured later.
+        backup_disable_cuda_graph = server_args.disable_cuda_graph
+        server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Load hot token ids
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # Init draft worker
+        with empty_context():
+            TpModelWorker.__init__(
+                self,
+                server_args=server_args,
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                pp_rank=0,  # FIXME
+                dp_rank=dp_rank,
+                moe_ep_rank=moe_ep_rank,
+                nccl_port=nccl_port,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+
+        # Init attention backend and cuda graphs
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+
+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
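
StandaloneWorker subclasses EAGLEWorker but skips the EAGLE-specific head wiring: the draft model is an ordinary causal LM that shares the target's req-to-token pool and context length, and defers CUDA-graph capture until after the attention backend is initialized. The diff does not show where the worker is instantiated; a hypothetical selection helper, named and placed here purely by assumption, might look like:

from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

def pick_draft_worker_cls(algorithm: SpeculativeAlgorithm):
    # Hypothetical: the real dispatch lives in scheduler/tp_worker code
    # that is not part of the hunks shown above.
    if algorithm.is_standalone():
        from sglang.srt.speculative.standalone_worker import StandaloneWorker
        return StandaloneWorker
    from sglang.srt.speculative.eagle_worker import EAGLEWorker
    return EAGLEWorker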