sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -25,17 +25,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
+import tqdm
 from torch import nn
 from transformers import PretrainedConfig
 
-from sglang.srt import single_batch_overlap
 from sglang.srt.configs.model_config import (
     get_nsa_index_head_dim,
     get_nsa_index_n_heads,
     get_nsa_index_topk,
     is_deepseek_nsa,
 )
-from sglang.srt.debug_utils.dumper import dumper
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -46,9 +45,11 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     use_symmetric_memory,
 )
+from sglang.srt.environ import envs
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
@@ -56,6 +57,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
     is_mla_preprocess_enabled,
 )
 from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
+from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -75,7 +77,6 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe import (
-    get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
@@ -83,8 +84,12 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
-from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers.quantization import CompressedTensorsConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+    CompressedTensorsWNA16AMXEPMoEMethod,
+)
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
@@ -95,7 +100,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
+    quant_weight_ue8m0,
     requant_weight_ue8m0_inplace,
+    transform_scale_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -107,14 +114,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.single_batch_overlap import SboFlags
-from sglang.srt.two_batch_overlap import (
-    MaybeTboDeepEPDispatcher,
-    model_forward_maybe_tbo,
-)
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.two_batch_overlap import model_forward_maybe_tbo
 from sglang.srt.utils import (
     BumpAllocator,
     LazyValue,
@@ -131,6 +136,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
+    is_nvidia_cublas_cu12_version_ge_12_9,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -181,18 +187,31 @@ elif _is_hip:
         awq_dequantize_triton as awq_dequantize,
     )
 elif _is_npu:
-    import custom_ops
-    import sgl_kernel_npu
-    import torch_npu
+    import custom_ops  # noqa: F401
+    import sgl_kernel_npu  # noqa: F401
+    import torch_npu  # noqa: F401
+
+    from sglang.srt.layers.quantization.awq_triton import (
+        awq_dequantize_decomposition as awq_dequantize,
+    )
 else:
     pass
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
-
+_is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
 logger = logging.getLogger(__name__)
 
+
+def enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+    return (
+        quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+        and get_moe_a2a_backend().is_deepep()
+    )
+
+
 FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
     "fa3",
     "nsa",
@@ -223,6 +242,10 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
 
+    # Use multi-head attention, execute the MHA for prefix and extended kv in one shot
+    # when the sequence lengths are below the threshold.
+    MHA_ONE_SHOT = auto()
+
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()
 
@@ -288,6 +311,14 @@ def _is_extend_without_speculative(forward_batch):
     )
 
 
+def _support_mha_one_shot(attn: DeepseekV2AttentionMLA, forward_batch, backend_name):
+    attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"]
+    sum_seq_lens = (
+        sum(forward_batch.seq_lens_cpu) if forward_batch.seq_lens_cpu is not None else 0
+    )
+    return attn_supported and sum_seq_lens <= forward_batch.get_max_chunk_capacity()
+
+
 def _handle_attention_backend(
     attn: DeepseekV2AttentionMLA, forward_batch, backend_name
 ):
@@ -307,6 +338,8 @@ def _handle_attention_backend(
             or sum_extend_prefix_lens == 0
         )
     ):
+        if _support_mha_one_shot(attn, forward_batch, backend_name):
+            return AttnForwardMethod.MHA_ONE_SHOT
         return AttnForwardMethod.MHA_CHUNKED_KV
     else:
         return _dispatch_mla_subtype(attn, forward_batch)
@@ -317,7 +350,11 @@ def handle_attention_flashinfer(attn, forward_batch):
 
 
 def handle_attention_fa3(attn, forward_batch):
-    return _handle_attention_backend(attn, forward_batch, "fa3")
+    # when deterministic inference is enabled, use MLA
+    if get_global_server_args().enable_deterministic_inference:
+        return _dispatch_mla_subtype(attn, forward_batch)
+    else:
+        return _handle_attention_backend(attn, forward_batch, "fa3")
 
 
 def handle_attention_flashmla(attn, forward_batch):
@@ -361,6 +398,10 @@ def handle_attention_nsa(attn, forward_batch):
 
 
 def handle_attention_triton(attn, forward_batch):
+    # when deterministic inference is enabled, use MLA
+    if get_global_server_args().enable_deterministic_inference:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
     if (
         _is_extend_without_speculative(forward_batch)
         and sum(forward_batch.extend_prefix_lens_cpu) == 0
@@ -517,12 +558,13 @@ class DeepseekV2MoE(nn.Module):
         self.n_shared_experts = config.n_shared_experts
         self.num_fused_shared_experts = (
             0
-            if global_server_args_dict["disable_shared_experts_fusion"]
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
         self.config = config
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.is_nextn = is_nextn
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -546,7 +588,7 @@ class DeepseekV2MoE(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             num_fused_shared_experts=self.num_fused_shared_experts,
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
@@ -589,6 +631,7 @@ class DeepseekV2MoE(nn.Module):
             **(
                 dict(tp_rank=0, tp_size=1)
                 if get_moe_a2a_backend().is_deepep()
+                or get_moe_a2a_backend().is_mooncake()
                 or should_use_flashinfer_cutlass_moe_fp4_allgather()
                 else {}
             ),
@@ -619,12 +662,12 @@ class DeepseekV2MoE(nn.Module):
 
         self.top_k = config.num_experts_per_tok
 
-        if get_moe_a2a_backend().is_deepep():
+        if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake():
             # TODO: we will support tp < ep in the future
             self.ep_size = get_moe_expert_parallel_world_size()
             self.num_experts = (
                 config.n_routed_experts
-                + global_server_args_dict["ep_num_redundant_experts"]
+                + get_global_server_args().ep_num_redundant_experts
             )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
@@ -635,20 +678,10 @@ class DeepseekV2MoE(nn.Module):
                 else None
             )
 
-            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
-                group=parallel_state.get_tp_group().device_group,
-                router_topk=self.top_k,
-                permute_fusion=True,
-                num_experts=self.num_experts,
-                num_local_experts=config.n_routed_experts // self.tp_size,
-                hidden_size=config.hidden_size,
-                params_dtype=config.torch_dtype,
-                deepep_mode=get_deepep_mode(),
-                async_finish=True,
-                return_recv_hook=True,
-            )
-
-        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
+        self._enable_a2a_moe = (
+            get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake()
+        )
+        self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo()
 
     def get_moe_weights(self):
         return [
@@ -665,7 +698,7 @@ class DeepseekV2MoE(nn.Module):
         use_reduce_scatter: bool = False,
         gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
-        if not self._enable_deepep_moe:
+        if not self._enable_a2a_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
             if (
                 self.alt_stream is not None
@@ -707,6 +740,10 @@ class DeepseekV2MoE(nn.Module):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
+            if isinstance(
+                self.experts.quant_method, CompressedTensorsWNA16AMXEPMoEMethod
+            ):
+                topk_output.topk_weights.mul_(self.routed_scaling_factor)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
                 final_hidden_states *= self.routed_scaling_factor
@@ -740,9 +777,10 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(
-                hidden_states, gemm_output_zero_allocator
-            )
+            if not self._fuse_shared_experts_inside_sbo:
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
@@ -750,7 +788,27 @@ class DeepseekV2MoE(nn.Module):
             shared_output = None
             topk_output = self.topk.empty_topk_output(hidden_states.device)
 
-        final_hidden_states = self.experts(hidden_states, topk_output)
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
+
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(
+                    hidden_states, gemm_output_zero_allocator
+                )
+
+        final_hidden_states = self.experts(
+            hidden_states,
+            topk_output,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
+        )
         if not _is_cuda and not _use_aiter:
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
@@ -834,9 +892,9 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            if not SboFlags.fuse_shared_experts_inside_sbo():
+            if not self._fuse_shared_experts_inside_sbo:
                 shared_output = self._forward_shared_experts(hidden_states)
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -845,22 +903,29 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
-                hidden_states.device
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
+
+        if self._fuse_shared_experts_inside_sbo:
+            shared_output = None
 
-        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
+            def _forward_shared_experts_and_put_results():
+                nonlocal shared_output
+                shared_output = self._forward_shared_experts(hidden_states)
+
+        final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            forward_batch=forward_batch,
-            # SBO args
-            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
-            experts=self.experts,
-            alt_stream=self.alt_stream,
+            topk_output=topk_output,
+            **(
+                dict(
+                    forward_shared_experts=_forward_shared_experts_and_put_results,
+                    alt_stream=self.alt_stream,
+                    # SBO is not yet implemented for NextN
+                    disable_sbo=self.is_nextn,
+                )
+                if self._fuse_shared_experts_inside_sbo
+                else {}
+            ),
         )
-        if sbo_shared_output is not None:
-            shared_output = sbo_shared_output
 
         if shared_output is not None:
             x = shared_output
@@ -911,7 +976,7 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
+                state.topk_output = self.topk(
                     hidden_states=hidden_states,
                     router_logits=router_logits,
                     num_token_non_padded=state.forward_batch.num_token_non_padded,
@@ -920,21 +985,13 @@ class DeepseekV2MoE(nn.Module):
                     ),
                 )
         else:
-            state.topk_idx_local = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            state.topk_weights_local = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            state.topk_output = self.topk.empty_topk_output(hidden_states.device)
 
     def op_dispatch_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.dispatch_a(
+            self.experts.dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
-                input_global_scale=None,
-                topk_idx=state.pop("topk_idx_local"),
-                topk_weights=state.pop("topk_weights_local"),
-                forward_batch=state.forward_batch,
+                topk_output=state.pop("topk_output"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
@@ -943,32 +1000,27 @@ class DeepseekV2MoE(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
-                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
+                state.dispatch_output = self.experts.dispatcher.dispatch_b(
                     tbo_subbatch_index=state.get("tbo_subbatch_index"),
                 )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.moe_impl(
+        state.combine_input = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
-            self.experts.deepep_dispatcher.combine_a(
-                hidden_states=state.pop("hidden_states_experts_output"),
-                topk_idx=state.dispatch_output.topk_idx,
-                topk_weights=state.dispatch_output.topk_weights,
-                forward_batch=state.forward_batch,
+            self.experts.dispatcher.combine_a(
+                combine_input=state.pop("combine_input"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")
 
     def op_combine_b(self, state):
         if self.ep_size > 1:
-            state.hidden_states_after_combine = (
-                self.experts.deepep_dispatcher.combine_b(
-                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
-                )
+            state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
     def op_output(self, state):
@@ -1031,6 +1083,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
+        self.kv_cache_dtype = get_global_server_args().kv_cache_dtype
 
         # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
         if rope_scaling:
@@ -1050,7 +1103,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_lora_rank,
                 self.num_heads * self.qk_head_dim,
                 bias=False,
-                quant_config=quant_config,
+                quant_config=self._get_q_b_proj_quant_config(quant_config),
                 prefix=add_prefix("q_b_proj", prefix),
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
@@ -1122,7 +1175,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
-            device=global_server_args_dict["device"],
+            device=get_global_server_args().device,
         )
 
         if rope_scaling:
@@ -1166,12 +1219,12 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_scale_v = None
         self.use_deep_gemm_bmm = False
 
-        self.flashinfer_mla_disable_ragged = global_server_args_dict[
-            "flashinfer_mla_disable_ragged"
-        ]
-        self.disable_chunked_prefix_cache = global_server_args_dict[
-            "disable_chunked_prefix_cache"
-        ]
+        self.flashinfer_mla_disable_ragged = (
+            get_global_server_args().flashinfer_mla_disable_ragged
+        )
+        self.disable_chunked_prefix_cache = (
+            get_global_server_args().disable_chunked_prefix_cache
+        )
 
         self.current_attention_backend = (
             None  # Attention backend used by current forward batch
@@ -1250,18 +1303,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     ) -> AttnForwardMethod:
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
-            attention_backend = global_server_args_dict["decode_attention_backend"]
+            attention_backend = get_global_server_args().decode_attention_backend
         elif (
             forward_batch.forward_mode.is_target_verify()
             or forward_batch.forward_mode.is_draft_extend()
         ):
             # Use the specified backend for speculative operations (both verify and draft extend)
-            if global_server_args_dict["speculative_attention_mode"] == "decode":
-                attention_backend = global_server_args_dict["decode_attention_backend"]
+            if get_global_server_args().speculative_attention_mode == "decode":
+                attention_backend = get_global_server_args().decode_attention_backend
             else:  # default to prefill
-                attention_backend = global_server_args_dict["prefill_attention_backend"]
+                attention_backend = get_global_server_args().prefill_attention_backend
         else:
-            attention_backend = global_server_args_dict["prefill_attention_backend"]
+            attention_backend = get_global_server_args().prefill_attention_backend
         self.current_attention_backend = attention_backend
 
         handler = AttentionBackendRegistry.get_handler(attention_backend)
@@ -1328,6 +1381,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_normal_chunked_kv_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            inner_state = self.forward_normal_one_shot_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MLA:
             if not self.is_mla_preprocess_enabled:
                 inner_state = self.forward_absorb_prepare(
@@ -1351,6 +1408,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 inner_state = self.mla_preprocess.forward(
                     positions, hidden_states, forward_batch, zero_allocator
                 )
+                inner_state = (*inner_state, None)  # add a position for topk_indices
         elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
             inner_state = self.forward_npu_sparse_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
@@ -1378,6 +1436,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
             return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            return self.forward_normal_one_shot_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA:
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
@@ -1412,41 +1472,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
         k_pe = latent_cache[:, :, self.kv_lora_rank :]
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
 
-        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
         if (
-            _is_cuda
-            and (self.num_local_heads == 128)
-            and (self.qk_nope_head_dim == 128)
-            and (self.qk_rope_head_dim == 64)
+            forward_batch.mha_one_shot
+            and sum(forward_batch.extend_prefix_lens_cpu) != 0
         ):
-            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
-        else:
-            k[..., : self.qk_nope_head_dim] = k_nope
-            k[..., self.qk_nope_head_dim :] = k_pe
-
-        if not _is_npu:
-            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-            latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-            # Save latent cache
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-            )
-        else:
-            # To reduce a time-costing split operation
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            kv_a, k_pe = self._get_mla_kv_buffer(
+                forward_batch.fetch_mha_one_shot_kv_indices(), q.dtype, forward_batch
             )
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
 
+        k = self._concat_and_cast_mha_k(k_nope, k_pe, forward_batch)
         return q, k, v, forward_batch
 
     def forward_normal_core(self, q, k, v, forward_batch):
@@ -1572,9 +1615,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
+            # fix bmm_fp8 error under cublas12.9 caused by bumpallocator, detail in pr#11612
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
-                zero_allocator.allocate(1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=q_nope.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -1718,7 +1766,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
-                zero_allocator.allocate(1),
+                (
+                    torch.zeros((1,), dtype=torch.float32, device=attn_output.device)
+                    if _is_cublas_ge_129
+                    else zero_allocator.allocate(1)
+                ),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -2247,20 +2299,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         for i in range(forward_batch.num_prefix_chunks):
             forward_batch.set_prefix_chunk_idx(i)
 
+            kv_indices = forward_batch.prefix_chunk_kv_indices[i]
             # Fetch latent cache from memory pool with precomputed chunked kv indices
-            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
-                self.attn_mha.layer_id
-            )
-            latent_cache = (
-                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
-                .contiguous()
-                .to(q.dtype)
-            )
-
-            kv_a_normed, k_pe = latent_cache.split(
-                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            kv_a_normed, k_pe = self._get_mla_kv_buffer(
+                kv_indices, q.dtype, forward_batch
             )
-            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
             kv = self.kv_b_proj(kv_a_normed)[0]
             kv = kv.view(
                 -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
@@ -2335,6 +2378,118 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    def forward_normal_one_shot_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        forward_batch.mha_one_shot = True
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
+        )
+
+    def forward_normal_one_shot_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.num_prefix_chunks = 0
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+        forward_batch.mha_return_lse = False
+        # Do mha for extended part without prefix
+        forward_batch.set_attn_attend_prefix_cache(False)
+        return self.forward_normal_core(q, k, v, forward_batch)
+
+    def _set_mla_kv_buffer(
+        self,
+        latent_cache: torch.Tensor,
+        kv_a: torch.Tensor,
+        k_pe: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        elif _is_npu:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        else:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+
+    def _get_mla_kv_buffer(
+        self,
+        kv_indices: torch.Tensor,
+        dst_dtype: torch.dtype,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
+                self.attn_mha, kv_indices, dst_dtype
+            )
+            kv_a = kv_a.squeeze(1)
+        else:
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[kv_indices].contiguous().to(dst_dtype)
+
+            kv_a, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a = kv_a.squeeze(1).contiguous()
+        return kv_a, k_pe
+
+    def _concat_and_cast_mha_k(self, k_nope, k_pe, forward_batch):
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        k_shape = (k_nope.shape[0], self.num_local_heads, self.qk_head_dim)
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            k = k_nope.new_empty(*k_shape)
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        elif _is_cuda:
+            # fa3 mha support fp8 inputs
+            if (
+                self.current_attention_backend == "fa3"
+                and self.kv_cache_dtype != "auto"
+            ):
+                attn_dtype = forward_batch.token_to_kv_pool.dtype
+            else:
+                attn_dtype = k_nope.dtype
+            k = k_nope.new_empty(*k_shape, dtype=attn_dtype)
+            concat_and_cast_mha_k_triton(k, k_nope, k_pe)
+        else:
+            k = k_nope.new_empty(*k_shape)
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+        return k
+
+    @staticmethod
+    def _get_q_b_proj_quant_config(quant_config):
+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            # refer to real DeepSeek V3 quant config
+            return Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            return quant_config
+
 
 class DeepseekV2DecoderLayer(nn.Module):
 
@@ -2343,6 +2498,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        moe_quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
@@ -2353,7 +2509,9 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            get_global_server_args().speculative_algorithm
+        )
         self.layer_id = layer_id
         self.is_nextn = is_nextn
         self.self_attn = DeepseekV2AttentionMLA(
@@ -2390,7 +2548,7 @@ class DeepseekV2DecoderLayer(nn.Module):
2390
2548
  if self.is_layer_sparse:
2391
2549
  self.mlp = DeepseekV2MoE(
2392
2550
  config=config,
2393
- quant_config=quant_config,
2551
+ quant_config=moe_quant_config or quant_config,
2394
2552
  prefix=add_prefix("mlp", prefix),
2395
2553
  layer_id=self.layer_id,
2396
2554
  alt_stream=alt_stream,
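Note (illustration, not part of the diff): `DeepseekV2DecoderLayer` now accepts an optional `moe_quant_config` and passes it to `DeepseekV2MoE`, falling back to the dense-layer `quant_config` when it is unset. A tiny sketch of the fallback, with hypothetical string configs standing in for the package's QuantizationConfig objects:

# Sketch only: the `or`-fallback used for MoE quantization config selection.
from typing import Optional

def pick_moe_quant_config(
    quant_config: Optional[str], moe_quant_config: Optional[str]
) -> Optional[str]:
    # Use the MoE-specific config when given, otherwise reuse the dense config.
    return moe_quant_config or quant_config

assert pick_moe_quant_config("fp8", None) == "fp8"
assert pick_moe_quant_config("fp8", "w4afp8") == "w4afp8"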
@@ -2796,6 +2954,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
+            CompressedTensorsConfig.DeepSeekFP8Config = Fp8Config(
+                True, "dynamic", None, [128, 128]
+            )
         self.determine_num_fused_shared_experts()
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
@@ -2805,7 +2967,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)

@@ -2825,7 +2987,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
         self.num_fused_shared_experts = 0
-        if global_server_args_dict["disable_shared_experts_fusion"]:
+        if get_global_server_args().disable_shared_experts_fusion:
             return

         # Only Deepseek V3/R1 can use shared experts fusion optimization now.
@@ -2844,7 +3006,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

         if disable_reason is not None:
-            global_server_args_dict["disable_shared_experts_fusion"] = True
+            get_global_server_args().disable_shared_experts_fusion = True
             self.num_fused_shared_experts = 0
             log_info_on_rank0(
                 logger,
@@ -2909,7 +3071,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             )
             if hasattr(self_attn.kv_b_proj, "qweight"):
                 # AWQ compatible
-                if _is_cuda or _is_hip:
+                if _is_cuda or _is_hip or _is_npu:
                     w = awq_dequantize(
                         self_attn.kv_b_proj.qweight,
                         self_attn.kv_b_proj.scales,
@@ -2935,11 +3097,13 @@ class DeepseekV2ForCausalLM(nn.Module):
                 torch.float8_e4m3fn,
                 torch.float8_e4m3fnuz,
             ):
-                if (
-                    hasattr(self.quant_config, "weight_block_size")
-                    and self.quant_config.weight_block_size is not None
-                ):
-                    weight_block_size = self.quant_config.weight_block_size
+                selected_quant_config = getattr(
+                    self.quant_config, "DeepSeekFP8Config", self.quant_config
+                )
+                weight_block_size = getattr(
+                    selected_quant_config, "weight_block_size", None
+                )
+                if weight_block_size is not None:
                     assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                     if _is_fp8_fnuz:
                         weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
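Note (illustration, not part of the diff): the hunk above first resolves `weight_block_size` (preferring an optional `DeepSeekFP8Config` attribute) and then dequantizes the FP8 `kv_b_proj` weight for the MLA path. A standalone sketch of block-wise dequantization under the assumed convention that `weight_scale_inv` is the per-block factor multiplied back in; this is not the package's dequant helper:

# Sketch only: block-wise dequantization with one scale per 128x128 tile.
import torch

def block_dequant(w_fp8: torch.Tensor, scale_inv: torch.Tensor, block=(128, 128)):
    bn, bk = block
    out = w_fp8.to(torch.float32)
    for i in range(scale_inv.shape[0]):
        for j in range(scale_inv.shape[1]):
            # Multiply each tile by its stored (inverse) scale.
            out[i * bn : (i + 1) * bn, j * bk : (j + 1) * bk] *= scale_inv[i, j]
    return out.to(torch.bfloat16)

w = torch.randn(256, 256).to(torch.float8_e4m3fn)
s = torch.rand(2, 2)  # one scale per 128x128 tile
print(block_dequant(w, s).shape)  # torch.Size([256, 256])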
@@ -3069,6 +3233,16 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self._weight_requant_ue8m0(is_nextn)

+        # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN")
+        ):
+            self._transform_scale_ue8m0(is_nextn)
+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            self._transform_scale_nextn_moe_ue8m0()
+
     def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size

@@ -3134,6 +3308,47 @@ class DeepseekV2ForCausalLM(nn.Module):
                         module.weight, module.weight_scale_inv, weight_block_size
                     )

+    # TODO can move weight_requant_ue8m0 and transform_scale_ue8m0 into Fp8LinearMethod.process_weights_after_loading
+    def _transform_scale_ue8m0(self, is_nextn=False):
+        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
+        for layer_id in range(num_hidden_layers):
+            if is_nextn:
+                layer = self.model.decoder
+            else:
+                layer = self.model.layers[layer_id]
+
+            module_list = []
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.q_b_proj)
+
+            for module in module_list:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
+    # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
+    def _transform_scale_nextn_moe_ue8m0(self):
+        layer = self.model.decoder
+
+        shared_experts = getattr(layer.mlp, "shared_experts", None)
+        if shared_experts is not None:
+            for module in [
+                shared_experts.gate_up_proj,
+                shared_experts.down_proj,
+            ]:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
+        experts = layer.mlp.experts
+        if isinstance(experts, DeepEPMoE):
+            for w in [
+                experts.w13_weight_fp8,
+                experts.w2_weight_fp8,
+            ]:
+                transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):

         if is_nextn:
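Note (illustration, not part of the diff): the new `_transform_scale_ue8m0` / `_transform_scale_nextn_moe_ue8m0` paths prepare per-block scales for DeepGEMM's UE8M0 mode. UE8M0 stores only an 8-bit exponent, so every scale must become a power of two; the sketch below assumes a round-up (ceil of log2) so quantized values cannot overflow, and is not the package's `transform_scale_ue8m0_inplace` implementation:

# Sketch only: snap per-block scales to the next power of two (UE8M0-compatible).
import torch

def to_power_of_two_scales(scale_inv: torch.Tensor) -> torch.Tensor:
    return torch.exp2(torch.ceil(torch.log2(scale_inv)))

s = torch.tensor([0.7, 1.3, 0.013])
print(to_power_of_two_scales(s))  # tensor([1.0000, 2.0000, 0.0156])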
@@ -3149,6 +3364,13 @@ class DeepseekV2ForCausalLM(nn.Module):
             else:
                 raise ValueError("num_nextn_predict_layers is not in the config")

+        if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
+            weights = self._quant_attn_to_fp8_ue8m0(weights, is_nextn=is_nextn)
+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            weights = self._quant_nextn_moe_to_fp8_ue8m0(
+                weights, nextn_layer_id=nextn_layer_id
+            )
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -3378,6 +3600,62 @@ class DeepseekV2ForCausalLM(nn.Module):

         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

+    def _quant_attn_to_fp8_ue8m0(self, weights, is_nextn):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in tqdm.trange(
+            self.config.num_hidden_layers + int(is_nextn),
+            desc="quant attn to fp8 ue8m0",
+        ):
+            for stem in [
+                # may put tensors like `o_proj` here for DeepSeek FP4 ckpt v1
+                "q_b_proj",
+            ]:
+                partial_name = f"model.layers.{layer_id}.self_attn.{stem}"
+                original_weight = weights_dict[f"{partial_name}.weight"]
+                out_w, out_s = quant_weight_ue8m0(
+                    original_weight, weight_block_size=weight_block_size
+                )
+                weights_dict[f"{partial_name}.weight"] = out_w
+                weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
+    # TODO avoid code dup
+    def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in [nextn_layer_id]:
+            for expert_sub_name in [
+                "shared_experts",
+                *[
+                    f"experts.{expert_id}"
+                    for expert_id in range(self.config.n_routed_experts)
+                ],
+            ]:
+                for stem in [
+                    "gate_proj",
+                    "up_proj",
+                    "down_proj",
+                ]:
+                    partial_name = (
+                        f"model.layers.{layer_id}.mlp.{expert_sub_name}.{stem}"
+                    )
+                    original_weight = weights_dict[f"{partial_name}.weight"]
+                    out_w, out_s = quant_weight_ue8m0(
+                        original_weight, weight_block_size=weight_block_size
+                    )
+                    weights_dict[f"{partial_name}.weight"] = out_w
+                    weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight