sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_5_vl.py +29 -5

@@ -40,7 +40,6 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionRotaryEmbedding,
 )
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -113,12 +113,13 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = RMSNorm(dim, eps=1e-6)
-        self.norm2 = RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -264,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
-        mlp_hidden_size: int = vision_config.intermediate_size
+        mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
@@ -517,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -587,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
            )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -643,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
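Note (not part of the package diff): the qwen2_5_vl.py changes add EAGLE3 hooks — a capture_aux_hidden_states flag, a set_eagle3_layers_to_capture method that selects a few layer indices, and a forward path that unpacks (hidden_states, aux_hidden_states) and passes the extra states to the logits processor. A minimal, self-contained sketch of that capture pattern; ToyModel and everything in it are hypothetical stand-ins, not the sglang implementation:

from typing import List

import torch
from torch import nn


class ToyModel(nn.Module):
    """Toy decoder stack that can capture intermediate ("aux") hidden states."""

    def __init__(self, num_layers: int = 8, hidden: int = 16):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.capture_aux_hidden_states = False
        self.layers_to_capture: List[int] = []

    def forward(self, x: torch.Tensor):
        aux: List[torch.Tensor] = []
        for i, layer in enumerate(self.layers):
            # Capture the running hidden state at the selected layer indices.
            if self.capture_aux_hidden_states and i in self.layers_to_capture:
                aux.append(x)
            x = layer(x)
        if self.capture_aux_hidden_states:
            # Same convention as the diff: return (hidden_states, aux_hidden_states).
            return x, torch.cat(aux, dim=-1)
        return x


model = ToyModel()
# Mirrors the default choice in the diff: layers 2, num_layers // 2, num_layers - 3.
model.capture_aux_hidden_states = True
model.layers_to_capture = [2, len(model.layers) // 2, len(model.layers) - 3]
out, aux = model(torch.randn(4, 16))
print(out.shape, aux.shape)  # torch.Size([4, 16]) torch.Size([4, 48])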
sglang/srt/models/qwen2_audio.py +1 -1

@@ -39,7 +39,6 @@ from transformers.models.qwen2_audio.modeling_qwen2_audio import (
     Qwen2AudioMultiModalProjector,
 )
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
sglang/srt/models/qwen2_moe.py +120 -13

@@ -25,12 +25,14 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -62,13 +65,16 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
 
 logger = logging.getLogger(__name__)
 
+_is_cuda = is_cuda()
+
 
 class Qwen2MoeMLP(nn.Module):
     def __init__(
@@ -79,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -87,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -95,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -122,11 +134,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         layer_id: int,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.layer_id = layer_id
+        self.alt_stream = alt_stream
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
@@ -138,10 +152,11 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             renormalize=config.norm_topk_prob,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
-            num_experts=config.num_experts,
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
@@ -163,19 +178,32 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_expert", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    else {}
+                ),
             )
         else:
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        forward_batch: Optional[ForwardBatch] = None,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        if get_moe_a2a_backend().is_deepep():
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_moe_expert_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
+    def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
             shared_output = self.shared_expert(hidden_states)
@@ -183,11 +211,85 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             shared_output = (
                 F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
             )
+        return shared_output
+
+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
 
+        return final_hidden_states
+
+    def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        return self.experts(hidden_states, topk_output)
+
+    def forward_normal_dual_stream(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        shared_output = self._forward_shared_experts(hidden_states.clone())
+
+        with torch.cuda.stream(self.alt_stream):
+            router_output = self._forward_router_experts(hidden_states)
+
+        current_stream.wait_stream(self.alt_stream)
+
+        return router_output, shared_output
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        if get_moe_a2a_backend().is_deepep():
+            return self._forward_deepep(hidden_states, forward_batch)
+
+        DUAL_STREAM_TOKEN_THRESHOLD = 1024
+        if (
+            self.alt_stream is not None
+            and hidden_states.shape[0] > 0
+            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
+            and get_is_capture_mode()
+        ):
+            final_hidden_states, shared_output = self.forward_normal_dual_stream(
+                hidden_states
+            )
+        else:
+            shared_output = self._forward_shared_experts(hidden_states)
+            final_hidden_states = self._forward_router_experts(hidden_states)
+
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1 and not use_reduce_scatter:
@@ -346,6 +448,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
                 layer_id=layer_id,
                 config=config,
                 quant_config=quant_config,
+                alt_stream=alt_stream,
                 prefix=add_prefix("mlp", prefix),
             )
         else:
@@ -528,8 +631,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.model = Qwen2MoeModel(
-            config, quant_config, prefix=add_prefix("model", prefix)
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+            alt_stream=alt_stream,
        )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
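Note (not part of the package diff): the alt_stream plumbing above lets the shared expert and the routed experts run on two CUDA streams for small batches during CUDA-graph capture. The synchronization pattern is the standard one — the side stream waits for the current stream, does its work, and the current stream waits for the side stream before consuming both results. A self-contained sketch of that pattern, with plain matmuls standing in for the expert computations (all names below are illustrative only):

import torch

assert torch.cuda.is_available()

x = torch.randn(1024, 1024, device="cuda")
w_shared = torch.randn(1024, 1024, device="cuda")
w_routed = torch.randn(1024, 1024, device="cuda")

current = torch.cuda.current_stream()
alt = torch.cuda.Stream()

# Make the side stream see all work already queued on the current stream.
alt.wait_stream(current)

# "Shared expert" runs on the current stream (the diff clones the input here)...
shared_out = x.clone() @ w_shared

# ...while the "routed experts" run concurrently on the side stream.
with torch.cuda.stream(alt):
    routed_out = x @ w_routed

# Rejoin before combining the two results.
current.wait_stream(alt)
out = shared_out + routed_out
torch.cuda.synchronize()
print(out.shape)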
sglang/srt/models/qwen2_vl.py +1 -1

@@ -33,7 +33,6 @@ from einops import rearrange
 from transformers import Qwen2VLConfig
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -50,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
sglang/srt/models/qwen3.py +18 -3

@@ -1,6 +1,5 @@
 # Adapted from qwen2.py
 import logging
-from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -30,12 +29,19 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix, is_cuda
+from sglang.srt.utils import (
+    add_prefix,
+    get_cmo_stream,
+    is_cuda,
+    is_npu,
+    wait_cmo_stream,
+)
 
 Qwen3Config = None
 
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 
 
 class Qwen3Attention(nn.Module):
@@ -235,9 +241,18 @@ class Qwen3DecoderLayer(nn.Module):
 
         # Fully Connected
         hidden_states, residual = self.layer_communicator.prepare_mlp(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            cache=(
+                [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
+                if _is_npu
+                else None
+            ),
         )
         hidden_states = self.mlp(hidden_states)
+        if _is_npu and get_cmo_stream():
+            wait_cmo_stream()
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
sglang/srt/models/qwen3_moe.py +32 -4

@@ -51,7 +51,7 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -60,6 +60,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,
@@ -98,7 +102,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             use_grouped_topk=False,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.num_experts
             + global_server_args_dict["ep_num_redundant_experts"],
             top_k=config.num_experts_per_tok,
@@ -354,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        self.compatible_with_fused_kv_buffer = (
+            False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
+        )
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -412,7 +420,21 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -420,7 +442,13 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state)
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not (
+                enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+            ),
+        )
         output, _ = self.o_proj(attn_output)
         return output
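Note (not part of the package diff): the qwen3_moe.py changes route the value tensor into the rotary-embedding call via fused_set_kv_buffer_arg so that, on the supported path, the KV-cache write happens together with RoPE, and the subsequent attention call is told save_kv_cache=False so the same entries are not written twice. A rough sketch of that invariant with a plain tensor standing in for the KV pool; apply_rope_maybe_fused, attention, and the cache layout are hypothetical, not sglang's API:

import torch

num_slots, num_heads, head_dim = 64, 4, 8
k_cache = torch.zeros(num_slots, num_heads, head_dim)
v_cache = torch.zeros(num_slots, num_heads, head_dim)


def apply_rope_maybe_fused(q, k, v, slots, fuse_kv_write: bool):
    # Stand-in for the real rotary embedding; a fused kernel would also
    # scatter k/v into the cache in the same pass.
    if fuse_kv_write:
        k_cache[slots] = k
        v_cache[slots] = v
    return q, k


def attention(q, k, v, slots, save_kv_cache: bool):
    # Only write the cache here when the RoPE step did not already do it.
    if save_kv_cache:
        k_cache[slots] = k
        v_cache[slots] = v
    return q  # attention math omitted


q = torch.randn(2, num_heads, head_dim)
k = torch.randn(2, num_heads, head_dim)
v = torch.randn(2, num_heads, head_dim)
slots = torch.tensor([3, 7])

fused = True  # e.g. the fused path is enabled and the layer is compatible
q, k = apply_rope_maybe_fused(q, k, v, slots, fuse_kv_write=fused)
attention(q, k, v, slots, save_kv_cache=not fused)
assert torch.equal(k_cache[slots], k) and torch.equal(v_cache[slots], v)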