sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py CHANGED
@@ -8,7 +8,7 @@ import hashlib
 import json
 import logging
 import os
-import queue
+import re
 import tempfile
 from collections import defaultdict
 from typing import (
@@ -37,8 +37,12 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.dp_attention import get_attention_tp_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
-from sglang.srt.utils import print_warning_once
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
+from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
+from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)
 
@@ -109,6 +113,9 @@ def convert_bin_to_safetensor_file(
 
     dirname = os.path.dirname(sf_filename)
     os.makedirs(dirname, exist_ok=True)
+
+    from safetensors.torch import save_file
+
     save_file(loaded, sf_filename, metadata={"format": "pt"})
 
     # check file size
@@ -131,11 +138,26 @@ def convert_bin_to_safetensor_file(
         raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
+def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
+    for prefix, new_prefix in prefix_mapping.items():
+        if key.startswith(prefix):
+            key = key.replace(prefix, new_prefix, 1)
+    return key
+
+
+def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
+    for substr, new_substr in substring_mapping.items():
+        if substr in key:
+            key = key.replace(substr, new_substr)
+    return key
+
+
 # TODO(woosuk): Move this to other place.
 def get_quant_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
+    remap_prefix: Dict[str, str] | None = None,
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)
 
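The two helpers added above are plain string utilities used by the new remap_prefix path in get_quant_config. A quick standalone illustration of replace_prefix (the sample key and mapping below are made up):

    def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
        for prefix, new_prefix in prefix_mapping.items():
            if key.startswith(prefix):
                key = key.replace(prefix, new_prefix, 1)
        return key

    # count=1 rewrites only the leading occurrence, so a repeated
    # substring deeper in the key is left untouched.
    print(replace_prefix("model.layers.0.model.gate", {"model.": "lm."}))
    # -> lm.layers.0.model.gate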
@@ -205,35 +227,176 @@ def get_quant_config(
     quant_config_file = quant_config_files[0]
     with open(quant_config_file) as f:
         config = json.load(f)
+    if remap_prefix is not None:
+        exclude_modules = [
+            replace_prefix(key, remap_prefix)
+            for key in config["quantization"]["exclude_modules"]
+        ]
+        config["quantization"]["exclude_modules"] = exclude_modules
+    config["packed_modules_mapping"] = packed_modules_mapping
 
     if model_config.quantization == "bitsandbytes":
         config["adapter_name_or_path"] = model_name_or_path
-    elif model_config.quantization == "modelopt":
-        if config["producer"]["name"] == "modelopt":
+    elif model_config.quantization.startswith("modelopt") and (
+        config["producer"]["name"].startswith("modelopt")
+    ):
+        quant_algo = config["quantization"]["quant_algo"]
+        if quant_algo is None:
             # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
-            if config["quantization"]["quant_algo"] is None:
-                if (
-                    model_config.hf_config.architectures[0]
-                    != "LlamaForCausalLMEagle3"
-                ):
-                    raise ValueError(
-                        f"Invalid quant_config, quantization method: {model_config.quantization},"
-                        f"hf architectures: {model_config.hf_config.architectures[0]}. "
-                    )
-                return None
-            if "FP4" in config["quantization"]["quant_algo"]:
-                return ModelOptFp4Config.from_config(config)
-            else:
-                return quant_cls.from_config(config)
-        else:
-            raise ValueError(
-                f"Unsupported quantization config"
-                f" found for {model_config.quantization} in {f}."
+            if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3":
+                raise ValueError(
+                    f"Invalid quant_config, quantization method: {model_config.quantization},"
+                    f"hf architectures: {model_config.hf_config.architectures[0]}. "
+                )
+            return None
+        elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8":
+            return ModelOptFp8Config.from_config(config)
+        elif "FP4" in quant_algo:
+            return ModelOptFp4Config.from_config(config)
+    return quant_cls.from_config(config)
+
+
+def find_local_hf_snapshot_dir(
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    allow_patterns: List[str],
+    revision: Optional[str] = None,
+) -> Optional[str]:
+    """If the weights are already local, skip downloading and returns the path."""
+    if os.path.isdir(model_name_or_path):
+        return None
+
+    found_local_snapshot_dir = None
+
+    # Check custom cache_dir (if provided)
+    if cache_dir:
+        try:
+            repo_folder = os.path.join(
+                cache_dir,
+                huggingface_hub.constants.REPO_ID_SEPARATOR.join(
+                    ["models", *model_name_or_path.split("/")]
+                ),
+            )
+            rev_to_use = revision
+            if not rev_to_use:
+                ref_main = os.path.join(repo_folder, "refs", "main")
+                if os.path.isfile(ref_main):
+                    with open(ref_main) as f:
+                        rev_to_use = f.read().strip()
+            if rev_to_use:
+                rev_dir = os.path.join(repo_folder, "snapshots", rev_to_use)
+                if os.path.isdir(rev_dir):
+                    found_local_snapshot_dir = rev_dir
+        except Exception as e:
+            logger.warning(
+                "Failed to find local snapshot in custom cache_dir %s: %s",
+                cache_dir,
+                e,
+            )
+
+    # Check default HF cache as well
+    if not found_local_snapshot_dir:
+        try:
+            rev_dir = find_local_repo_dir(model_name_or_path, revision)
+            if rev_dir and os.path.isdir(rev_dir):
+                found_local_snapshot_dir = rev_dir
+        except Exception as e:
+            logger.warning("Failed to find local snapshot in default HF cache: %s", e)
+
+    # if any incomplete file exists, force re-download by returning None
+    if found_local_snapshot_dir:
+        repo_folder = os.path.abspath(
+            os.path.join(found_local_snapshot_dir, "..", "..")
+        )
+        blobs_dir = os.path.join(repo_folder, "blobs")
+        if os.path.isdir(blobs_dir) and glob.glob(
+            os.path.join(blobs_dir, "*.incomplete")
+        ):
+            logger.info(
+                "Found .incomplete files in %s for %s. "
+                "Considering local snapshot incomplete.",
+                blobs_dir,
+                model_name_or_path,
+            )
+            return None
+
+    # if local snapshot exists, validate it contains at least one weight file
+    # matching allow_patterns before skipping download.
+    if found_local_snapshot_dir is None:
+        return None
+
+    local_weight_files: List[str] = []
+    try:
+        for pattern in allow_patterns:
+            matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern))
+            for f in matched_files:
+                # os.path.exists returns False for broken symlinks.
+                if not os.path.exists(f):
+                    continue
+                local_weight_files.append(f)
+    except Exception as e:
+        logger.warning(
+            "Failed to scan local snapshot %s with patterns %s: %s",
+            found_local_snapshot_dir,
+            allow_patterns,
+            e,
+        )
+        local_weight_files = []
+
+    # After we have a list of valid files, check for sharded model completeness.
+    # Check if all safetensors with name model-{i}-of-{n}.safetensors exists
+    checked_sharded_model = False
+    for f in local_weight_files:
+        if checked_sharded_model:
+            break
+        base_name = os.path.basename(f)
+        # Regex for files like model-00001-of-00009.safetensors
+        match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", base_name)
+        if match:
+            prefix = match.group(1)
+            shard_id_str = match.group(2)
+            total_shards_str = match.group(3)
+            suffix = match.group(4)
+            total_shards = int(total_shards_str)
+
+            # Check if all shards are present
+            missing_shards = []
+            for i in range(1, total_shards + 1):
+                # Reconstruct shard name, preserving padding of original shard id
+                shard_name = (
+                    f"{prefix}-{i:0{len(shard_id_str)}d}-of-{total_shards_str}.{suffix}"
+                )
+                expected_path = os.path.join(found_local_snapshot_dir, shard_name)
+                # os.path.exists returns False for broken symlinks, which is desired.
+                if not os.path.exists(expected_path):
+                    missing_shards.append(shard_name)
+
+            if missing_shards:
+                logger.info(
+                    "Found incomplete sharded model %s. Missing shards: %s. "
+                    "Will attempt download.",
+                    model_name_or_path,
+                    missing_shards,
                 )
-    elif model_config.quantization == "w8a8_int8":
-        config["packed_modules_mapping"] = packed_modules_mapping
+                return None
 
-    return quant_cls.from_config(config)
+            # If we found and verified one set of shards, we are done.
+            checked_sharded_model = True
+
+    if len(local_weight_files) > 0:
+        logger.info(
+            "Found local HF snapshot for %s at %s; skipping download.",
+            model_name_or_path,
+            found_local_snapshot_dir,
+        )
+        return found_local_snapshot_dir
+    else:
+        logger.info(
+            "Local HF snapshot at %s has no files matching %s; will attempt download.",
+            found_local_snapshot_dir,
+            allow_patterns,
+        )
+        return None
 
 
 def download_weights_from_hf(
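The sharded-checkpoint validation in find_local_hf_snapshot_dir above hinges on one regex plus zero-padded name reconstruction. A standalone sketch of the same pattern (the file name below is made up):

    import re

    m = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", "model-00001-of-00003.safetensors")
    prefix, shard_id, total, suffix = m.groups()
    # Rebuild every sibling shard name with the same zero padding, as the
    # loop above does before checking os.path.exists on each shard.
    expected = [
        f"{prefix}-{i:0{len(shard_id)}d}-of-{total}.{suffix}"
        for i in range(1, int(total) + 1)
    ]
    print(expected)
    # ['model-00001-of-00003.safetensors', 'model-00002-of-00003.safetensors',
    #  'model-00003-of-00003.safetensors']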
@@ -260,6 +423,16 @@ def download_weights_from_hf(
     Returns:
         str: The path to the downloaded model weights.
     """
+
+    if is_in_ci():
+        # If the weights are already local, skip downloading and returns the path.
+        # This is used to skip too-many Huggingface API calls in CI.
+        path = find_local_hf_snapshot_dir(
+            model_name_or_path, cache_dir, allow_patterns, revision
+        )
+        if path is not None:
+            return path
+
     if not huggingface_hub.constants.HF_HUB_OFFLINE:
         # Before we download we look at that is available:
         fs = HfFileSystem()
@@ -272,7 +445,7 @@ def download_weights_from_hf(
                 allow_patterns = [pattern]
                 break
 
-    logger.info("Using model weights format %s", allow_patterns)
+    log_info_on_rank0(logger, f"Using model weights format {allow_patterns}")
     # Use file lock to prevent multiple processes from
     # downloading the same model weights at the same time.
     with get_lock(model_name_or_path, cache_dir):
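log_info_on_rank0 replaces a bare logger.info in the hunk above; the diff imports it from sglang.srt.utils but never shows its body. As a rough sketch of what such a helper typically does, assuming a torch.distributed rank check (an assumption, not sglang's actual implementation):

    import logging
    import torch.distributed as dist

    def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
        # Assumed behavior: only rank 0 logs, so a multi-process weight
        # download does not repeat the same line once per rank.
        if not dist.is_initialized() or dist.get_rank() == 0:
            logger.info(msg)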
sglang/srt/models/apertus.py CHANGED
@@ -46,15 +46,14 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
-from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
@@ -447,7 +446,7 @@ class ApertusForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
sglang/srt/models/arcee.py CHANGED
@@ -42,13 +42,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 
 logger = logging.getLogger(__name__)
@@ -407,7 +407,7 @@ class ArceeForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
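The global_server_args_dict to get_global_server_args() change seen in apertus.py and arcee.py recurs across most model files in this release. A self-contained sketch of the access-pattern difference, using stand-in stubs (the real objects live in sglang.srt.managers.schedule_batch and sglang.srt.server_args):

    from dataclasses import dataclass

    # Old style: a process-global, stringly-typed dict.
    global_server_args_dict = {"enable_dp_lm_head": False}

    # New style: a typed args object behind an accessor.
    @dataclass
    class ServerArgs:
        enable_dp_lm_head: bool = False

    _global_server_args = ServerArgs()

    def get_global_server_args() -> ServerArgs:
        return _global_server_args

    # Attribute access replaces key lookup; a typo now fails under a
    # type checker instead of raising KeyError at runtime.
    assert (
        global_server_args_dict["enable_dp_lm_head"]
        == get_global_server_args().enable_dp_lm_head
    )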
sglang/srt/models/bailing_moe.py CHANGED
@@ -17,9 +17,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" SGLang BailingMoE model."""
+"""SGLang BailingMoE model."""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Iterable, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -45,21 +45,20 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
-    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import TopK
-from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -68,10 +67,14 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
 
 LoraConfig = None
@@ -200,8 +203,8 @@ class BailingMoESparseMoeBlock(nn.Module):
         else:
             self.router_dtype = torch.bfloat16
 
-        # TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now
-        assert global_server_args_dict["ep_num_redundant_experts"] == 0
+        # TODO global_server_args.ep_num_redundant_experts is used for eplb, not supported now
+        assert get_global_server_args().ep_num_redundant_experts == 0
         # check group topk
         self.num_expert_group = getattr(config, "n_group", 0)
         self.topk_group = getattr(config, "topk_group", 0)
@@ -216,7 +219,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             self.use_grouped_topk = False
 
         self.num_experts = (
-            config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            config.num_experts + get_global_server_args().ep_num_redundant_experts
         )
 
         self.gate = BailingMoEGate(
@@ -289,7 +292,7 @@ class BailingMoESparseMoeBlock(nn.Module):
                 num_local_experts=config.num_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
-                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                deepep_mode=get_deepep_mode(),
                 async_finish=True,  # TODO
                 return_recv_hook=True,
             )
@@ -377,7 +380,7 @@ class BailingMoESparseMoeBlock(nn.Module):
             if self.num_shared_experts > 0:
                 shared_output = self.shared_experts(hidden_states)
 
-            topk_weights, topk_idx, _ = self.topk(
+            topk_output = self.topk(
                 hidden_states,
                 router_logits,
                 num_token_non_padded=forward_batch.num_token_non_padded,
@@ -386,53 +389,15 @@ class BailingMoESparseMoeBlock(nn.Module):
                 ),
             )
         else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
-
-        if self.ep_size > 1:
-            (
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                reorder_topk_ids,
-                num_recv_tokens_per_expert,
-                seg_indptr,
-                masked_m,
-                expected_m,
-            ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_batch=forward_batch,
-            )
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
 
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
-            topk_idx=topk_idx,
-            topk_weights=topk_weights,
-            reorder_topk_ids=reorder_topk_ids,
-            seg_indptr=seg_indptr,
-            masked_m=masked_m,
-            expected_m=expected_m,
-            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
-            forward_batch=forward_batch,
+            topk_output=topk_output,
        )
-        if self.ep_size > 1:
-            final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_batch=forward_batch,
-            )
-
-        final_hidden_states *= self.routed_scaling_factor
 
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            final_hidden_states += shared_output
         return final_hidden_states
 
 
@@ -555,8 +520,27 @@ class BailingMoEAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
-        context_layer = self.attn(q, k, v, forward_batch)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                else None
+            ),
+        )
+        context_layer = self.attn(
+            q,
+            k,
+            v,
+            forward_batch,
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
+        )
         attn_output, _ = self.dense(context_layer)
         return attn_output
 
@@ -702,7 +686,7 @@ class BailingMoEModel(nn.Module):
                 self.embed_dim,
                 quant_config=quant_config,
                 prefix=add_prefix("word_embeddings", prefix),
-                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+                enable_tp=not is_dp_attention_enabled(),
             )
         else:
             self.word_embeddings = PPMissingLayer()
@@ -801,7 +785,7 @@ class BailingMoEForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
 
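The largest bailing_moe.py hunk above removes the explicit DeepEP dispatch/combine choreography from the forward pass: the caller now produces a single topk_output and hands it to self.experts. A minimal stub of the new call shape (stand-in classes; the real TopK and MoE implementations live in sglang.srt.layers.moe):

    import torch

    class _TopKStub:
        def __init__(self, k: int):
            self.k = k

        def __call__(self, hidden_states, router_logits):
            # Stand-in for sglang's TopKOutput: a (weights, ids) pair.
            return torch.topk(router_logits.float(), self.k, dim=-1)

        def empty_topk_output(self, device):
            # Zero-token batch, mirroring the removed torch.empty/torch.full pair.
            return (torch.empty((0, self.k), device=device),
                    torch.full((0, self.k), -1, dtype=torch.int, device=device))

    def _experts_stub(hidden_states, topk_output):
        # In sglang, dispatch, expert GEMMs, combine (and presumably the
        # removed routed-scaling multiply) now happen behind this call.
        return hidden_states

    x, logits = torch.randn(4, 8), torch.randn(4, 16)
    out = _experts_stub(hidden_states=x, topk_output=_TopKStub(k=2)(x, logits))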
sglang/srt/models/bailing_moe_nextn.py CHANGED
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" SGLang BailingMoENextN model."""
+"""SGLang BailingMoENextN model."""
 import logging
 from typing import Iterable, Optional, Tuple
 
@@ -29,15 +29,14 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix
 
 LoraConfig = None
@@ -145,7 +144,7 @@ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
 
sglang/srt/models/bert.py CHANGED
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-from typing import Any, Dict, Iterable, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 from torch import nn
sglang/srt/models/deepseek_nextn.py CHANGED
@@ -25,19 +25,27 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_r
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import Fp8Config
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import BumpAllocator, add_prefix
+from sglang.srt.models.deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    DeepseekV3ForCausalLM,
+    enable_nextn_moe_bf16_cast_to_fp8,
+)
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
 
 logger = logging.getLogger(__name__)
 
 
+_is_cuda = is_cuda()
+
+
 class DeepseekModelNextN(nn.Module):
     def __init__(
         self,
@@ -46,6 +54,16 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        if enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+            # refer to real DeepSeek V3 quant config
+            moe_quant_config = Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            moe_quant_config = None
+
         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             logger.warning(
                 "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
@@ -66,12 +84,15 @@ class DeepseekModelNextN(nn.Module):
 
         self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
 
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.decoder = DeepseekV2DecoderLayer(
             config,
             0,
             quant_config=quant_config,
+            moe_quant_config=moe_quant_config,
             is_nextn=True,
             prefix=add_prefix("decoder", prefix),
+            alt_stream=self.alt_stream,
         )
 
         self.shared_head = nn.Module()
@@ -147,7 +168,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
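deepseek_nextn.py now allocates an alternate CUDA stream and threads it into the decoder layer (alt_stream=self.alt_stream). The general overlap pattern, sketched with plain torch ops rather than the actual sglang layer code:

    import torch

    if torch.cuda.is_available():
        alt = torch.cuda.Stream()
        a = torch.randn(1024, 1024, device="cuda")
        b = torch.randn(1024, 1024, device="cuda")

        # Order the side stream after any pending default-stream work.
        alt.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(alt):
            y = b @ b  # can run concurrently with default-stream kernels
        x = a @ a
        # Join before consuming results produced on the side stream.
        torch.cuda.current_stream().wait_stream(alt)
        z = x + y
        print(z.shape)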