sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py CHANGED
@@ -37,8 +37,11 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.dp_attention import get_attention_tp_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
-from sglang.srt.utils import find_local_repo_dir, print_warning_once
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
+from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
 from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)
@@ -110,6 +113,9 @@ def convert_bin_to_safetensor_file(
 
     dirname = os.path.dirname(sf_filename)
     os.makedirs(dirname, exist_ok=True)
+
+    from safetensors.torch import save_file
+
     save_file(loaded, sf_filename, metadata={"format": "pt"})
 
     # check file size
@@ -132,11 +138,26 @@ def convert_bin_to_safetensor_file(
         raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
+def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
+    for prefix, new_prefix in prefix_mapping.items():
+        if key.startswith(prefix):
+            key = key.replace(prefix, new_prefix, 1)
+    return key
+
+
+def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
+    for substr, new_substr in substring_mapping.items():
+        if substr in key:
+            key = key.replace(substr, new_substr)
+    return key
+
+
 # TODO(woosuk): Move this to other place.
 def get_quant_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
+    remap_prefix: Dict[str, str] | None = None,
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)
 
@@ -206,35 +227,33 @@ def get_quant_config(
     quant_config_file = quant_config_files[0]
     with open(quant_config_file) as f:
         config = json.load(f)
+        if remap_prefix is not None:
+            exclude_modules = [
+                replace_prefix(key, remap_prefix)
+                for key in config["quantization"]["exclude_modules"]
+            ]
+            config["quantization"]["exclude_modules"] = exclude_modules
+        config["packed_modules_mapping"] = packed_modules_mapping
 
         if model_config.quantization == "bitsandbytes":
             config["adapter_name_or_path"] = model_name_or_path
-        elif model_config.quantization == "modelopt":
-            if config["producer"]["name"] == "modelopt":
+        elif model_config.quantization.startswith("modelopt") and (
+            config["producer"]["name"].startswith("modelopt")
+        ):
+            quant_algo = config["quantization"]["quant_algo"]
+            if quant_algo is None:
                 # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
-                if config["quantization"]["quant_algo"] is None:
-                    if (
-                        model_config.hf_config.architectures[0]
-                        != "LlamaForCausalLMEagle3"
-                    ):
-                        raise ValueError(
-                            f"Invalid quant_config, quantization method: {model_config.quantization},"
-                            f"hf architectures: {model_config.hf_config.architectures[0]}. "
-                        )
-                    return None
-                if "FP4" in config["quantization"]["quant_algo"]:
-                    return ModelOptFp4Config.from_config(config)
-                else:
-                    return quant_cls.from_config(config)
-            else:
-                raise ValueError(
-                    f"Unsupported quantization config"
-                    f" found for {model_config.quantization} in {f}."
-                )
-        elif model_config.quantization == "w8a8_int8":
-            config["packed_modules_mapping"] = packed_modules_mapping
-
-    return quant_cls.from_config(config)
+                if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3":
+                    raise ValueError(
+                        f"Invalid quant_config, quantization method: {model_config.quantization},"
+                        f"hf architectures: {model_config.hf_config.architectures[0]}. "
+                    )
+                return None
+            elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8":
+                return ModelOptFp8Config.from_config(config)
+            elif "FP4" in quant_algo:
+                return ModelOptFp4Config.from_config(config)
+    return quant_cls.from_config(config)
 
 
 def find_local_hf_snapshot_dir(
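Editor's note on the two hunks above: the new replace_prefix/replace_substrings helpers let loaders remap checkpoint key names, and get_quant_config now threads an optional remap_prefix mapping through to the quant config's exclude_modules list before dispatching on the ModelOpt quant_algo (None triggers the Eagle3 workaround, FP8 selects ModelOptFp8Config, FP4 selects ModelOptFp4Config). A minimal, self-contained sketch of the helper behavior; the mapping and key below are invented for illustration and are not from the package:

def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
    # Copied from the hunk above: rewrite only the first matching prefix of a key.
    for prefix, new_prefix in prefix_mapping.items():
        if key.startswith(prefix):
            key = key.replace(prefix, new_prefix, 1)
    return key

# Hypothetical remap, e.g. stripping a wrapper prefix from an exclude_modules entry.
print(replace_prefix("language_model.lm_head", {"language_model.": "model."}))
# -> "model.lm_head"

replace_substrings works the same way but rewrites every occurrence of each substring rather than only a leading prefix.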
@@ -426,7 +445,7 @@ def download_weights_from_hf(
             allow_patterns = [pattern]
             break
 
-    logger.info("Using model weights format %s", allow_patterns)
+    log_info_on_rank0(logger, f"Using model weights format {allow_patterns}")
     # Use file lock to prevent multiple processes from
     # downloading the same model weights at the same time.
     with get_lock(model_name_or_path, cache_dir):
sglang/srt/models/apertus.py CHANGED
@@ -46,15 +46,14 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
-from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
@@ -447,7 +446,7 @@ class ApertusForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
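The lm_head change above is one instance of a release-wide migration: dictionary lookups in global_server_args_dict are replaced by attribute access on the global ServerArgs object. The same substitution appears in arcee.py, bailing_moe.py, bailing_moe_nextn.py, and deepseek_nextn.py below. A minimal sketch of the new access pattern, assuming an SGLang engine has already initialized the global server args in the current process:

# Old (0.5.3rc2) style, shown for contrast:
#     from sglang.srt.managers.schedule_batch import global_server_args_dict
#     use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]

# New (0.5.4.post1) style, as used in the hunks above and below.
from sglang.srt.server_args import get_global_server_args

use_attn_tp_group = get_global_server_args().enable_dp_lm_head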
sglang/srt/models/arcee.py CHANGED
@@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 
 logger = logging.getLogger(__name__)
@@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
sglang/srt/models/bailing_moe.py CHANGED
@@ -17,9 +17,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" SGLang BailingMoE model."""
+"""SGLang BailingMoE model."""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -54,12 +54,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -68,7 +67,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -76,6 +74,7 @@ from sglang.srt.models.utils import (
     create_fused_set_kv_buffer_arg,
     enable_fused_set_kv_buffer,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
 
 LoraConfig = None
@@ -204,8 +203,8 @@ class BailingMoESparseMoeBlock(nn.Module):
         else:
             self.router_dtype = torch.bfloat16
 
-        # TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
-        assert global_server_args_dict["ep_num_redundant_experts"] == 0
+        # TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported now
+        assert get_global_server_args().ep_num_redundant_experts == 0
         # check group topk
         self.num_expert_group = getattr(config, "n_group", 0)
         self.topk_group = getattr(config, "topk_group", 0)
@@ -220,7 +219,7 @@ class BailingMoESparseMoeBlock(nn.Module):
         self.use_grouped_topk = False
 
         self.num_experts = (
-            config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            config.num_experts + get_global_server_args().ep_num_redundant_experts
        )
 
         self.gate = BailingMoEGate(
@@ -293,7 +292,7 @@ class BailingMoESparseMoeBlock(nn.Module):
                 num_local_experts=config.num_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,  # TODO
                 return_recv_hook=True,
             )
@@ -381,7 +380,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             if self.num_shared_experts > 0:
                 shared_output = self.shared_experts(hidden_states)
 
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -390,53 +389,15 @@ class BailingMoESparseMoeBlock(nn.Module):
                 ),
             )
         else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
-
-        if self.ep_size > 1:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_batch=forward_batch,
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
 
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
         )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_batch=forward_batch,
-            )
-
-        final_hidden_states *= self.routed_scaling_factor
 
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            final_hidden_states += shared_output
         return final_hidden_states
 
 
@@ -824,7 +785,7 @@ class BailingMoEForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
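In the forward refactor above, BailingMoESparseMoeBlock no longer builds topk_idx/topk_weights tensors or calls deepep_dispatcher.dispatch/combine itself; it hands a single topk_output to the experts layer, which now owns expert-parallel dispatch. The toy classes below are invented to mirror only the call shape of the new code (they are not the real sglang TopK or MoE layers):

import torch

class ToyTopK:
    # Stand-in for sglang's TopK layer; the pair below stands in for its TopKOutput.
    def __call__(self, hidden_states, router_logits, k=2):
        weights, ids = torch.topk(torch.softmax(router_logits, dim=-1), k=k, dim=-1)
        return weights, ids

    def empty_topk_output(self, device, k=2):
        # Mirrors the else-branch above, which previously built zero-row topk tensors by hand.
        return (
            torch.empty(0, k, dtype=torch.float32, device=device),
            torch.full((0, k), -1, dtype=torch.int64, device=device),
        )

def toy_experts(hidden_states, topk_output):
    # The real layer performs dispatch, expert GEMMs, and combine internally.
    weights, ids = topk_output
    return hidden_states

hidden = torch.randn(4, 8)
logits = torch.randn(4, 16)
topk = ToyTopK()
topk_output = topk(hidden, logits) if hidden.shape[0] > 0 else topk.empty_topk_output(hidden.device)
out = toy_experts(hidden_states=hidden, topk_output=topk_output)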
sglang/srt/models/bailing_moe_nextn.py CHANGED
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" SGLang BailingMoENextN model."""
+"""SGLang BailingMoENextN model."""
 import logging
 from typing import Iterable, Optional, Tuple
 
@@ -29,15 +29,14 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix
 
 LoraConfig = None
@@ -145,7 +144,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
 
sglang/srt/models/bert.py CHANGED
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Any, Dict, Iterable, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 from torch import nn
sglang/srt/models/deepseek_nextn.py CHANGED
@@ -25,14 +25,19 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_r
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import Fp8Config
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
+from sglang.srt.models.deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    DeepseekV3ForCausalLM,
+    enable_nextn_moe_bf16_cast_to_fp8,
+)
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
 
 logger = logging.getLogger(__name__)
@@ -49,6 +54,16 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        if enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+            # refer to real DeepSeek V3 quant config
+            moe_quant_config = Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            moe_quant_config = None
+
         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             logger.warning(
                 "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
@@ -74,6 +89,7 @@ class DeepseekModelNextN(nn.Module):
             config,
             0,
             quant_config=quant_config,
+            moe_quant_config=moe_quant_config,
             is_nextn=True,
             prefix=add_prefix("decoder", prefix),
             alt_stream=self.alt_stream,
@@ -152,7 +168,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
 