sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
@@ -25,17 +25,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
+import tqdm
 from torch import nn
 from transformers import PretrainedConfig
 
-from sglang.srt import single_batch_overlap
 from sglang.srt.configs.model_config import (
     get_nsa_index_head_dim,
     get_nsa_index_n_heads,
     get_nsa_index_topk,
     is_deepseek_nsa,
 )
-from sglang.srt.debug_utils.dumper import dumper
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -46,9 +45,11 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     use_symmetric_memory,
 )
+from sglang.srt.environ import envs
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
@@ -75,7 +76,6 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
-    get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
@@ -83,8 +83,12 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
-from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers.quantization import CompressedTensorsConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+    CompressedTensorsWNA16AMXEPMoEMethod,
+)
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
@@ -95,7 +99,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
+    quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
+    transform_scale_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -107,14 +113,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.single_batch_overlap import SboFlags
-from sglang.srt.two_batch_overlap import (
-    MaybeTboDeepEPDispatcher,
-    model_forward_maybe_tbo,
-)
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
@@ -131,6 +135,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
+    is_nvidia_cublas_cu12_version_ge_12_9,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -181,18 +186,31 @@ elif _is_hip:
         awq_dequantize_triton as awq_dequantize,
     )
 elif _is_npu:
-    import custom_ops
-    import sgl_kernel_npu
-    import torch_npu
+    import custom_ops  # noqa: F401
+    import sgl_kernel_npu  # noqa: F401
+    import torch_npu  # noqa: F401
+
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_decomposition as awq_dequantize,
+    )
 else:
     pass
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
-
+_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
 logger = logging.getLogger(__name__)
 
+
+def enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+    return (
+        quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+        and get_moe_a2a_backend().is_deepep()
+    )
+
+
 FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
     "fa3",
     "nsa",
@@ -517,12 +535,13 @@ class DeepseekV2MoE(nn.Module):
         self.n_shared_experts = config.n_shared_experts
         self.num_fused_shared_experts = (
             0
-            if global_server_args_dict["disable_shared_experts_fusion"]
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
         self.config = config
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.is_nextn = is_nextn
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -546,7 +565,7 @@ class DeepseekV2MoE(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
@@ -589,6 +608,7 @@ class DeepseekV2MoE(nn.Module):
             **(
                 dict(tp_rank=0, tp_size=1)
                 if get_moe_a2a_backend().is_deepep()
+                or get_moe_a2a_backend().is_mooncake()
                 or should_use_flashinfer_cutlass_moe_fp4_allgather()
                 else {}
             ),
@@ -619,12 +639,12 @@ class DeepseekV2MoE(nn.Module):
 
         self.top_k = config.num_experts_per_tok
 
-        if get_moe_a2a_backend().is_deepep():
+        if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
                 config.n_routed_experts
-                + global_server_args_dict["ep_num_redundant_experts"]
+                + get_global_server_args().ep_num_redundant_experts
             )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
@@ -635,20 +655,10 @@ class DeepseekV2MoE(nn.Module):
                 else None
             )
 
-            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-                group=parallel_state.get_tp_group().device_group,
-                router_topk=self.top_k,
-                permute_fusion=True,
-                num_experts=self.num_experts,
-                num_local_experts=config.n_routed_experts // self.tp_size,
-                hidden_size=config.hidden_size,
-                params_dtype=config.torch_dtype,
-                deepep_mode=get_deepep_mode(),
-                async_finish=True,
-                return_recv_hook=True,
-            )
-
-        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+        self._enable_a2a_moe = (
+            get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
+        )
+        self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo()
 
     def get_moe_weights(self):
         return [
@@ -665,7 +675,7 @@ class DeepseekV2MoE(nn.Module):
         use_reduce_scatter: bool = False,
         gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
-        if not self._enable_deepep_moe:
+        if not self._enable_a2a_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
             if (
                 self.alt_stream is not None
@@ -707,6 +717,10 @@ class DeepseekV2MoE(nn.Module):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
+            if isinstance(
+                self.experts.quant_method, CompressedTensorsWNA16AMXEPMoEMethod
+            ):
+                topk_output.topk_weights.mul_(self.routed_scaling_factor)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
@@ -740,9 +754,10 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(
-                hidden_states, gemm_output_zero_allocator
-            )
+            if not self._fuse_shared_experts_inside_sbo:
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
@@ -750,7 +765,27 @@ class DeepseekV2MoE(nn.Module):
             shared_output = None
             topk_output = self.topk.empty_topk_output(hidden_states.device)
 
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
+
+        final_hidden_states = self.experts(
+            hidden_states,
+            topk_output,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
+        )
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
@@ -834,9 +869,9 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            if not SboFlags.fuse_shared_experts_inside_sbo():
+            if not self._fuse_shared_experts_inside_sbo:
                 shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -845,22 +880,29 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
-                hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
+
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(hidden_states)
 
-        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
+        final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
-            # SBO args
-            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
-            experts=self.experts,
-            alt_stream=self.alt_stream,
+            topk_output=topk_output,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                    # SBO is not yet implemented for NextN
+                    disable_sbo=self.is_nextn,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
         )
-        if sbo_shared_output is not None:
-            shared_output = sbo_shared_output
 
         if shared_output is not None:
             x = shared_output
@@ -911,7 +953,7 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
+                state.topk_output = self.topk(
                     hidden_states=hidden_states,
                     router_logits=router_logits,
                     num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -920,21 +962,13 @@ class DeepseekV2MoE(nn.Module):
                     ),
                 )
         else:
-            state.topk_idx_local = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            state.topk_weights_local = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            state.topk_output = self.topk.empty_topk_output(hidden_states.device)
 
     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.dispatch_a(
+            self.experts.dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
-                input_global_scale=None,
-                topk_idx=state.pop("topk_idx_local"),
-                topk_weights=state.pop("topk_weights_local"),
-                forward_batch=state.forward_batch,
+                topk_output=state.pop("topk_output"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
@@ -943,32 +977,29 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
+                state.dispatch_output = self.experts.dispatcher.dispatch_b(
                     tbo_subbatch_index=state.get("tbo_subbatch_index"),
                 )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.moe_impl(
+        state.hidden_states_experts_output = self.experts.run_moe_core(
            dispatch_output=state.dispatch_output,
        )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.combine_a(
+            self.experts.dispatcher.combine_a(
                 hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.dispatch_output.topk_idx,
+                topk_ids=state.dispatch_output.topk_ids,
                 topk_weights=state.dispatch_output.topk_weights,
-                forward_batch=state.forward_batch,
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")
 
     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = (
-                self.experts.deepep_dispatcher.combine_b(
-                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
-                )
+            state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
     def op_output(self, state):
@@ -1050,7 +1081,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_lora_rank,
                 self.num_heads * self.qk_head_dim,
                 bias=False,
-                quant_config=quant_config,
+                quant_config=self._get_q_b_proj_quant_config(quant_config),
                 prefix=add_prefix("q_b_proj", prefix),
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
@@ -1122,7 +1153,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
-            device=global_server_args_dict["device"],
+            device=get_global_server_args().device,
         )
 
         if rope_scaling:
@@ -1166,12 +1197,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_scale_v = None
         self.use_deep_gemm_bmm = False
 
-        self.flashinfer_mla_disable_ragged = global_server_args_dict[
-            "flashinfer_mla_disable_ragged"
-        ]
-        self.disable_chunked_prefix_cache = global_server_args_dict[
-            "disable_chunked_prefix_cache"
-        ]
+        self.flashinfer_mla_disable_ragged = (
+            get_global_server_args().flashinfer_mla_disable_ragged
+        )
+        self.disable_chunked_prefix_cache = (
+            get_global_server_args().disable_chunked_prefix_cache
+        )
 
         self.current_attention_backend = (
             None  # Attention backend used by current forward batch
@@ -1250,18 +1281,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     ) -> AttnForwardMethod:
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
-            attention_backend = global_server_args_dict["decode_attention_backend"]
+            attention_backend = get_global_server_args().decode_attention_backend
         elif (
            forward_batch.forward_mode.is_target_verify()
            or forward_batch.forward_mode.is_draft_extend()
         ):
             # Use the specified backend for speculative operations (both verify and draft extend)
-            if global_server_args_dict["speculative_attention_mode"] == "decode":
-                attention_backend = global_server_args_dict["decode_attention_backend"]
+            if get_global_server_args().speculative_attention_mode == "decode":
+                attention_backend = get_global_server_args().decode_attention_backend
             else:  # default to prefill
-                attention_backend = global_server_args_dict["prefill_attention_backend"]
+                attention_backend = get_global_server_args().prefill_attention_backend
         else:
-            attention_backend = global_server_args_dict["prefill_attention_backend"]
+            attention_backend = get_global_server_args().prefill_attention_backend
         self.current_attention_backend = attention_backend
 
         handler = AttentionBackendRegistry.get_handler(attention_backend)
@@ -1351,6 +1382,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.mla_preprocess.forward(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+            inner_state = (*inner_state, None)  # add a position for topk_indices
         elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
             inner_state = self.forward_npu_sparse_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
@@ -1572,9 +1604,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
+            # fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
-                zero_allocator.allocate(1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -1718,7 +1755,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
-                zero_allocator.allocate(1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=attn_output.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -2335,6 +2376,17 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    @staticmethod
+    def _get_q_b_proj_quant_config(quant_config):
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            # refer to real DeepSeek V3 quant config
+            return Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            return quant_config
+
 
 class DeepseekV2DecoderLayer(nn.Module):
 
@@ -2343,6 +2395,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        moe_quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
@@ -2353,7 +2406,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            get_global_server_args().speculative_algorithm
+        )
         self.layer_id = layer_id
         self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
@@ -2390,7 +2445,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.is_layer_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
-                quant_config=quant_config,
+                quant_config=moe_quant_config or quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
                 alt_stream=alt_stream,
@@ -2796,6 +2851,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
+            CompressedTensorsConfig.DeepSeekFP8Config = Fp8Config(
+                True, "dynamic", None, [128, 128]
+            )
         self.determine_num_fused_shared_experts()
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -2805,7 +2864,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
 
@@ -2825,7 +2884,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
         self.num_fused_shared_experts = 0
-        if global_server_args_dict["disable_shared_experts_fusion"]:
+        if get_global_server_args().disable_shared_experts_fusion:
             return
 
         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -2844,7 +2903,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
 
         if disable_reason is not None:
-            global_server_args_dict["disable_shared_experts_fusion"] = True
+            get_global_server_args().disable_shared_experts_fusion = True
             self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
@@ -2909,7 +2968,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
-                if _is_cuda or _is_hip:
+                if _is_cuda or _is_hip or _is_npu:
                     w = awq_dequantize(
                         self_attn.kv_b_proj.qweight,
                         self_attn.kv_b_proj.scales,
@@ -2935,11 +2994,13 @@ class DeepseekV2ForCausalLM(nn.Module):
                 torch.float8_e4m3fn,
                 torch.float8_e4m3fnuz,
             ):
-                if (
-                    hasattr(self.quant_config, "weight_block_size")
-                    and self.quant_config.weight_block_size is not None
-                ):
-                    weight_block_size = self.quant_config.weight_block_size
+                selected_quant_config = getattr(
+                    self.quant_config, "DeepSeekFP8Config", self.quant_config
+                )
+                weight_block_size = getattr(
+                    selected_quant_config, "weight_block_size", None
+                )
+                if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                     if _is_fp8_fnuz:
                         weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
@@ -3069,6 +3130,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self._weight_requant_ue8m0(is_nextn)
 
+        # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN")
+        ):
+            self._transform_scale_ue8m0(is_nextn)
+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            self._transform_scale_nextn_moe_ue8m0()
+
     def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size
 
@@ -3134,6 +3205,47 @@ class DeepseekV2ForCausalLM(nn.Module):
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
 
+    # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+    def _transform_scale_ue8m0(self, is_nextn=False):
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
+
+            module_list = []
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.q_b_proj)
+
+            for module in module_list:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
+    # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
+    def _transform_scale_nextn_moe_ue8m0(self):
+        layer = self.model.decoder
+
+        shared_experts = getattr(layer.mlp, "shared_experts", None)
+        if shared_experts is not None:
+            for module in [
+                shared_experts.gate_up_proj,
+                shared_experts.down_proj,
+            ]:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
+        experts = layer.mlp.experts
+        if isinstance(experts, DeepEPMoE):
+            for w in [
+                experts.w13_weight_fp8,
+                experts.w2_weight_fp8,
+            ]:
+                transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
 
         if is_nextn:
@@ -3149,6 +3261,13 @@ class DeepseekV2ForCausalLM(nn.Module):
             else:
                 raise ValueError("num_nextn_predict_layers is not in the config")
 
+            if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+                weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
+            if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+                weights = self._quant_nextn_moe_to_fp8_ue8m0(
+                    weights, nextn_layer_id=nextn_layer_id
+                )
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -3378,6 +3497,62 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
+    def _quant_attn_to_fp8_ue8m0(self, weights, is_nextn):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in tqdm.trange(
+            self.config.num_hidden_layers + int(is_nextn),
+            desc="quant attn to fp8 ue8m0",
+        ):
+            for stem in [
+                # may put tensors like `o_proj` here for DeepSeek FP4 ckpt v1
+                "q_b_proj",
+            ]:
+                partial_name = f"model.layers.{layer_id}.self_attn.{stem}"
+                original_weight = weights_dict[f"{partial_name}.weight"]
+                out_w, out_s = quant_weight_ue8m0(
+                    original_weight, weight_block_size=weight_block_size
+                )
+                weights_dict[f"{partial_name}.weight"] = out_w
+                weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
+    # TODO avoid code dup
+    def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in [nextn_layer_id]:
+            for expert_sub_name in [
+                "shared_experts",
+                *[
+                    f"experts.{expert_id}"
+                    for expert_id in range(self.config.n_routed_experts)
+                ],
+            ]:
+                for stem in [
+                    "gate_proj",
+                    "up_proj",
+                    "down_proj",
+                ]:
+                    partial_name = (
+                        f"model.layers.{layer_id}.mlp.{expert_sub_name}.{stem}"
+                    )
+                    original_weight = weights_dict[f"{partial_name}.weight"]
+                    out_w, out_s = quant_weight_ue8m0(
+                        original_weight, weight_block_size=weight_block_size
+                    )
+                    weights_dict[f"{partial_name}.weight"] = out_w
+                    weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight