sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
         seq_lens = self.seq_lens[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         topk_p = self.topk_p[:num_seqs]
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=global_num_tokens,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
             global_dp_buffer_len=global_dp_buffer_len,
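
The three hunks above thread multimodal rotary-embedding (M-RoPE) positions through the draft CUDA-graph runner: a (3, max_num_token) buffer is allocated once, sliced per capture, and passed into the ForwardBatch. A minimal sketch of that fixed-buffer-plus-slice pattern, with illustrative names and sizes rather than the sglang API:

import torch

MAX_NUM_TOKEN = 8192

# One row per M-RoPE axis (e.g. temporal/height/width in Qwen2-VL-style
# multimodal rotary embeddings); allocated once so graph replay always
# sees the same device address.
mrope_positions = torch.zeros((3, MAX_NUM_TOKEN), dtype=torch.int64)

def capture_view(num_tokens: int) -> torch.Tensor:
    # Slicing returns a view into the preallocated buffer; no new
    # allocation happens during capture or replay.
    return mrope_positions[:, :num_tokens]

view = capture_view(16)
assert view.data_ptr() == mrope_positions.data_ptr()  # same storage
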
sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -80,6 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )

         if self.eagle_worker.speculative_algorithm.is_eagle3():
             self.hidden_states = torch.zeros(
@@ -189,6 +192,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:bs]

@@ -247,6 +251,7 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
@@ -336,7 +341,11 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
-        self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
+        if (
+            forward_batch.spec_info.hidden_states.shape[1]
+            == self.hidden_states.shape[1]
+        ):
+            self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
         if forward_batch.spec_info.accept_length is not None:
             self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
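
The new guard skips the hidden-state copy when the incoming hidden width does not match the preallocated buffer (e.g. if a spec-info carries a wider, concatenated hidden state), since Tensor.copy_ raises on shapes it cannot broadcast. A self-contained illustration with made-up sizes:

import torch

dst = torch.zeros(4, 8)       # preallocated graph buffer, hidden size 8
src_ok = torch.randn(4, 8)    # matching hidden size
src_bad = torch.randn(4, 24)  # e.g. a concatenated auxiliary hidden state

if src_ok.shape[1] == dst.shape[1]:
    dst.copy_(src_ok)         # safe: shapes line up

if src_bad.shape[1] == dst.shape[1]:
    dst.copy_(src_bad)        # skipped instead of raising RuntimeError
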
sglang/srt/speculative/eagle_utils.py
@@ -26,8 +26,6 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2

-logger = logging.getLogger(__name__)
-
 if is_cuda():
     from sgl_kernel import (
         fast_topk,
sglang/srt/speculative/eagle_worker.py
@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,
@@ -47,6 +48,7 @@ from sglang.srt.utils import (
     empty_context,
     get_available_gpu_memory,
     get_bool_env_var,
+    is_blackwell,
     is_cuda,
     next_power_of_2,
 )
@@ -190,7 +192,7 @@ class EAGLEWorker(TpModelWorker):
         # Initialize decode attention backend
         self.draft_attn_backend = self._create_decode_backend()

-        # Initialize prefill attention backend
+        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
         self.draft_extend_attn_backend = self._create_draft_extend_backend()

         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
@@ -213,6 +215,11 @@ class EAGLEWorker(TpModelWorker):
             "triton": self._create_triton_decode_backend,
             "aiter": self._create_aiter_decode_backend,
             "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
             "flashmla": self._create_flashmla_decode_backend,
             "trtllm_mha": self._create_trtllm_mha_decode_backend,
             "trtllm_mla": self._create_trtllm_mla_decode_backend,
@@ -230,14 +237,23 @@ class EAGLEWorker(TpModelWorker):
             "triton": self._create_triton_prefill_backend,
             "aiter": self._create_aiter_prefill_backend,
             "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
             "trtllm_mha": self._create_trtllm_mha_prefill_backend,
             "trtllm_mla": self._create_trtllm_mla_prefill_backend,
         }
-
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
         return self._create_backend(
-            "prefill_attention_backend",
+            backend_name,
             backend_map,
-            "EAGLE is not supported in prefill attention backend {backend_type}",
+            "EAGLE is not supported in attention backend {backend_type}",
         )

     def _create_flashinfer_decode_backend(self):
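
Both backend maps follow the same pattern: a plain dict from backend name to a zero-argument factory, with the hybrid_linear_attn entry resolved by a hardware check when the dict is built, and backend_name deciding which server-args key drives the lookup. A reduced sketch of the pattern; the factory names and the is_blackwell stub are stand-ins, not the sglang methods:

def make_fa3_backend():
    return "fa3"

def make_triton_backend():
    return "triton"

def is_blackwell() -> bool:
    # Stand-in for sglang.srt.utils.is_blackwell; pretend we are not on Blackwell.
    return False

backend_map = {
    "fa3": make_fa3_backend,
    "triton": make_triton_backend,
    # hybrid_linear_attn has no dedicated EAGLE backend of its own: it
    # borrows fa3 where supported and falls back to triton on Blackwell.
    "hybrid_linear_attn": make_fa3_backend if not is_blackwell() else make_triton_backend,
}

assert backend_map["hybrid_linear_attn"]() == "fa3"
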
@@ -729,6 +745,14 @@ class EAGLEWorker(TpModelWorker):

         # Set inputs
         forward_batch.input_ids = input_ids
+        # This is a temporary fix for the case that the user is using standalone
+        # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
+        # rope kernel needs cache_loc to be contiguous.
+        if (
+            self.server_args.speculative_algorithm == "STANDALONE"
+            and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
+        ):
+            out_cache_loc = out_cache_loc.contiguous()
         forward_batch.out_cache_loc = out_cache_loc[i]
         forward_batch.positions.add_(1)
         forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
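
For context on the .contiguous() call: slicing or transposing a tensor can yield a view with non-unit strides, and some fused kernels (here, per the comment, the gpt-oss RoPE kernel) require a flat layout. A toy demonstration with made-up shapes:

import torch

steps, tokens = 4, 8
# A (steps, tokens) view produced by transposing; its rows are strided.
out_cache_loc = torch.arange(steps * tokens).reshape(tokens, steps).t()

print(out_cache_loc[1].is_contiguous())               # False: stride 4, not 1
print(out_cache_loc.contiguous()[1].is_contiguous())  # True: copied to flat storage
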
@@ -813,6 +837,21 @@ class EAGLEWorker(TpModelWorker):
         ]
         logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

+        # QQ: can be optimized
+        if self.target_worker.model_runner.is_hybrid_gdn:
+            # res.draft_input.accept_length is on GPU but may be empty for last verify?
+            accepted_length = (
+                torch.tensor(
+                    res.accept_length_per_req_cpu,
+                    device=logits_output.hidden_states.device,
+                    dtype=torch.int32,
+                )
+                + 1
+            )
+            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
+                accepted_length, self.target_worker.model_runner.model
+            )
+
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)

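One reading of the + 1 above (an assumption on my part, consistent with how speculative verification is usually counted): accept_length_per_req_cpu counts accepted draft tokens, while the mamba-state update needs the total tokens committed per request, which includes the bonus token sampled from the target model. In miniature:

import torch

accept_length_per_req_cpu = [0, 2, 1]  # accepted draft tokens per request
accepted_length = torch.tensor(accept_length_per_req_cpu, dtype=torch.int32) + 1
print(accepted_length.tolist())  # [1, 3, 2]: tokens actually committed per request
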
sglang/srt/speculative/spec_info.py
@@ -5,6 +5,7 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()
     EAGLE3 = auto()
+    STANDALONE = auto()

     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
@@ -15,11 +16,15 @@ class SpeculativeAlgorithm(IntEnum):
     def is_eagle3(self):
         return self == SpeculativeAlgorithm.EAGLE3

+    def is_standalone(self):
+        return self == SpeculativeAlgorithm.STANDALONE
+
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
+            "STANDALONE": SpeculativeAlgorithm.STANDALONE,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
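
A short usage sketch of the extended enum (parse the CLI string once, then branch on cheap flag checks); the import path matches this diff:

from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

algo = SpeculativeAlgorithm.from_string("STANDALONE")
assert algo.is_standalone() and not algo.is_eagle3()
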
sglang/srt/speculative/standalone_worker.py (new file)
@@ -0,0 +1,109 @@
+import logging
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.eagle_worker import EAGLEWorker, load_token_map
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import empty_context, get_bool_env_var, is_cuda
+
+if is_cuda():
+    from sgl_kernel import segment_packbits
+
+logger = logging.getLogger(__name__)
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
+
+
+@contextmanager
+def draft_tp_context(tp_group: GroupCoordinator):
+    # Draft model doesn't use dp and has its own tp group.
+    # We disable mscclpp now because it doesn't support 2 comm groups.
+    with patch_tensor_parallel_group(tp_group):
+        yield
+
+
+class StandaloneWorker(EAGLEWorker):
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        moe_ep_rank: int,
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+        self.page_size = server_args.page_size
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.padded_static_len = -1
+
+        # Override the context length of the draft model to be the same as the target model.
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+
+        # Do not capture cuda graph in `super().__init__()`
+        # It will be captured later.
+        backup_disable_cuda_graph = server_args.disable_cuda_graph
+        server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Load hot token ids
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # Init draft worker
+        with empty_context():
+            TpModelWorker.__init__(
+                self,
+                server_args=server_args,
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                pp_rank=0,  # FIXME
+                dp_rank=dp_rank,
+                moe_ep_rank=moe_ep_rank,
+                nccl_port=nccl_port,
+                is_draft_worker=True,
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+
+        # Init attention backend and cuda graphs
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )
+        self.draft_tp_context = (
+            draft_tp_context if server_args.enable_dp_attention else empty_context
+        )
+        with self.draft_tp_context(self.draft_model_runner.tp_group):
+            self.init_attention_backend()
+            self.init_cuda_graphs()
+
+        # Some dummy tensors
+        self.num_new_pages_per_topk = torch.empty(
+            (), dtype=torch.int64, device=self.device
+        )
+        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
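
StandaloneWorker reuses the EAGLE control flow with an independent draft model rather than an attached EAGLE head. A hedged launch sketch through the offline Engine API follows: the keyword names mirror ServerArgs fields referenced in this diff (speculative_algorithm, speculative_eagle_topk, speculative_num_steps, speculative_num_draft_tokens), the model paths are placeholders, and exact flag support should be verified against this release:

import sglang as sgl

engine = sgl.Engine(
    model_path="<target-model>",                   # placeholder
    speculative_algorithm="STANDALONE",            # parsed by SpeculativeAlgorithm.from_string
    speculative_draft_model_path="<draft-model>",  # placeholder: a small standalone drafter
    speculative_num_steps=3,
    speculative_eagle_topk=1,
    speculative_num_draft_tokens=4,
)
print(engine.generate("The capital of France is", {"max_new_tokens": 8}))
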