sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4_moe_nextn.py CHANGED
@@ -12,7 +12,8 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
+"""Inference-only GLM-4.5, GLM-4.6 Speculative Decoding."""
+
 import logging
 from typing import Iterable, Optional, Tuple
 
@@ -30,10 +31,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
-from sglang.srt.utils import BumpAllocator, add_prefix
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
 
@@ -84,14 +85,6 @@ class Glm4MoeModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        zero_allocator = BumpAllocator(
-            buffer_size=2,
-            dtype=torch.float32,
-            device=(
-                input_embeds.device if input_embeds is not None else input_ids.device
-            ),
-        )
-
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -111,7 +104,7 @@ class Glm4MoeModelNextN(nn.Module):
         residual = None
         with get_global_expert_distribution_recorder().disable_this_region():
             hidden_states, residual = self.decoder(
-                positions, hidden_states, forward_batch, residual, zero_allocator
+                positions, hidden_states, forward_batch, residual
             )
 
         if not forward_batch.forward_mode.is_idle():
@@ -124,7 +117,6 @@ class Glm4MoeModelNextN(nn.Module):
 
 
 class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -135,8 +127,6 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
-
         self.model = Glm4MoeModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
@@ -145,7 +135,7 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("model.shared_head.head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
 
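A pattern that recurs throughout this release: dictionary-style lookups on global_server_args_dict (from sglang.srt.managers.schedule_batch) are replaced by attribute access on the object returned by get_global_server_args() (from sglang.srt.server_args). A minimal before/after sketch of the access pattern, using the enable_dp_lm_head flag that appears in the hunks above:

    # 0.5.3rc2: untyped dict lookup; a mistyped key only fails at runtime
    from sglang.srt.managers.schedule_batch import global_server_args_dict
    use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]

    # 0.5.4.post1: attribute access on the global server-args object
    from sglang.srt.server_args import get_global_server_args
    use_attn_tp_group = get_global_server_args().enable_dp_lm_head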
sglang/srt/models/glm4v.py CHANGED
@@ -9,6 +9,7 @@ from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisi
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -434,7 +435,7 @@ class Glm4vVisionModel(nn.Module):
         cu_seqlens = torch.repeat_interleave(
             grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
         ).cumsum(dim=0, dtype=torch.int32)
-        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+        cu_seqlens = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])
 
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         x = self.embeddings(
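Both the removed and the added line above prepend a single zero to the cumulative sequence-length tensor; the torch.cat form uses new_zeros, so the zero inherits the dtype and device of cu_seqlens without going through the padding kernel. A small equivalence check (the tensor values here are illustrative only):

    import torch
    import torch.nn.functional as F

    cu_seqlens = torch.tensor([3, 7, 12], dtype=torch.int32)
    old = F.pad(cu_seqlens, (1, 0), "constant", 0)          # left-pad one zero
    new = torch.cat([cu_seqlens.new_zeros(1), cu_seqlens])  # same values, dtype, device
    assert torch.equal(old, new)  # tensor([0, 3, 7, 12], dtype=torch.int32)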
sglang/srt/models/glm4v_moe.py CHANGED
@@ -6,21 +6,18 @@ import torch
 import torch.nn as nn
 from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
 
-from sglang.srt.distributed import (
-    get_moe_expert_parallel_world_size,
-    get_tensor_model_parallel_world_size,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
-from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.utils import add_prefix, is_cuda
 from sglang.srt.utils.hf_transformers_utils import get_processor
 
 _is_cuda = is_cuda()
@@ -39,15 +36,13 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
     ) -> None:
         nn.Module.__init__(self)
 
-        config.moe_layer_freq = 1
         self.config = config
         vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
         self.num_fused_shared_experts = (
             0
-            if global_server_args_dict["disable_shared_experts_fusion"]
+            if get_global_server_args().disable_shared_experts_fusion
             else config.n_shared_experts
         )
 
@@ -68,7 +63,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -77,38 +72,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False
 
-    def determine_num_fused_shared_experts(
-        self, architecture: str = "Glm4MoeForCausalLM"
-    ):
-        self.num_fused_shared_experts = 0
-        if global_server_args_dict["disable_shared_experts_fusion"]:
-            return
-
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        disable_reason = None
-        if (
-            not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (8, 0)
-            or self.config.architectures[0] != architecture
-            or self.config.n_shared_experts != 1
-        ):
-            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
-
-        if disable_reason is not None:
-            global_server_args_dict["disable_shared_experts_fusion"] = True
-            self.num_fused_shared_experts = 0
-            log_info_on_rank0(
-                logger,
-                f"{disable_reason} Shared experts fusion optimization is disabled.",
-            )
-            return
-
-        self.num_fused_shared_experts = self.config.n_shared_experts
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
-
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -130,117 +94,14 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                    or self.quant_config.get_name() == "compressed_tensors"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
 
-            for moe_layer in moe_layers:
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    # online fp8 quantization does not load weight_scale
-                    if shared_expert_weight_name not in weights_dict:
-                        continue
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+            num_experts=self.config.n_routed_experts,
         )
 
-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
             nextn_spec_weight_names = [
@@ -300,23 +161,36 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                 # name will be updated to mlp.experts[0].gate_up_proj, which
                 # will then be updated below in expert_params_mapping
                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                if ("mlp.experts." in name) and name not in params_dict:
+                if "mlp.experts" in name:
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
+                if name not in params_dict:
+                    continue
 
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
+
+                    # Mark as expert weight regardless of whether we can process it
+                    is_expert_weight = True
+
                     name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        # Expert weight not on this rank, will be skipped below
+                        continue
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -328,64 +202,21 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                     )
                     break
                 else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+
                     if "visual" in name:
-                        # adapt to VisionAttention
+                        # adapt to VisionAttention for GLM-V
                         name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
 
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-                            param_name = (
-                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
-                                if "q_a_proj" in name
-                                else name.replace(
-                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
-                                )
-                            )
-                            param = params_dict[param_name]
+                    if name not in params_dict:
+                        continue
 
-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        if (
-                            "k_scale" in name or "v_scale" in name
-                        ) and name not in params_dict:
-                            # modelopt attn kv scale is named differently
-                            if any(scale in name for scale in ["k_scale", "v_scale"]):
-                                name = name.replace("_proj", "attn_mqa")
-                            else:
-                                logger.warning(
-                                    f"Unknown scale found in checkpoint: {name}"
-                                )
+                    if name in params_dict.keys():
                         param = params_dict[name]
                         weight_loader = getattr(
                             param, "weight_loader", default_weight_loader
@@ -395,6 +226,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                             self.config, name, loaded_weight
                         )
                         weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
 
 
 EntryClass = [Glm4vMoeForConditionalGeneration]
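The rewritten load_weights above drops the shared-expert fusion and fused-QKV remapping paths and instead tracks an is_expert_weight flag, so expert weights that live on other expert-parallel ranks are skipped without falling through to the generic loader. A stripped-down sketch of that for/else control flow (the helper below and its argument order are simplified stand-ins, not the real sglang signatures):

    def _load_one(name, loaded_weight, expert_params_mapping, params_dict):
        is_expert_weight = False
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name not in name:
                continue
            is_expert_weight = True  # the checkpoint name matches an expert weight...
            mapped = name.replace(weight_name, param_name)
            if mapped not in params_dict:
                continue  # ...but this shard lives on another EP rank
            param = params_dict[mapped]
            param.weight_loader(param, loaded_weight, mapped, shard_id, expert_id)
            break
        else:
            if is_expert_weight:
                return  # an expert weight that simply is not ours: skip silently
            ...  # otherwise fall through to the generic (non-expert) loading path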
sglang/srt/models/gpt_oss.py CHANGED
@@ -63,13 +63,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.utils import (
     create_fused_set_kv_buffer_arg,
     enable_fused_set_kv_buffer,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -85,7 +85,7 @@ _is_sm100_supported = is_cuda() and is_sm100_supported()
 
 
 if _is_cuda:
-    from sgl_kernel import FusedSetKVBufferArg
+    from sgl_kernel import FusedSetKVBufferArg  # noqa: F401
 
 
 class GptOssConfig(PretrainedConfig):
@@ -138,7 +138,7 @@ class GptOssSparseMoeBlock(nn.Module):
         }
         self.experts = experts_type(
             num_experts=config.num_local_experts
-            + global_server_args_dict["ep_num_redundant_experts"],
+            + get_global_server_args().ep_num_redundant_experts,
             top_k=config.num_experts_per_tok,
             layer_id=layer_id,
             hidden_size=config.hidden_size,
@@ -259,7 +259,7 @@ class GptOssAttention(nn.Module):
 
         # Choose dtype of sinks based on attention backend: trtllm_mha requires float32,
         # others can use bfloat16
-        attn_backend = global_server_args_dict.get("attention_backend")
+        attn_backend = get_global_server_args().attention_backend
         sinks_dtype = torch.float32 if attn_backend == "trtllm_mha" else torch.bfloat16
         self.sinks = nn.Parameter(
             torch.empty(self.num_heads, dtype=sinks_dtype), requires_grad=False
@@ -591,7 +591,7 @@ class GptOssForCausalLM(nn.Module):
             config.hidden_size,
             # quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
sglang/srt/models/grok.py CHANGED
@@ -28,7 +28,6 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
-    get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
@@ -36,7 +35,6 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.elementwise import (
-    experts_combine_triton,
     fused_dual_residual_rmsnorm,
     fused_rmsnorm,
     gelu_and_mul_triton,
@@ -49,7 +47,6 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.router import fused_moe_router_shim
 from sglang.srt.layers.moe.topk import TopK
@@ -65,10 +62,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file
 
 logger = logging.getLogger(__name__)
@@ -76,9 +73,6 @@ logger = logging.getLogger(__name__)
 
 # Dump tensors for debugging
 debug_tensor_dump_output_folder = None
-debug_tensor_dump_prefill_only = False
-# Skip all the other tensor dumps, only dump the target logits
-debug_tensor_dump_only_target_logprobs = False
 debug_tensor_dump_inject = False
 debug_tensor_dump_layers = None
 debug_tensor_dump_test = False
@@ -176,17 +170,7 @@ class Grok1MoE(nn.Module):
             custom_routing_function=custom_routing_function,
         )
 
-        kwargs = {}
-        if get_moe_expert_parallel_world_size() > 1:
-            MoEImpl = EPMoE
-        else:
-            MoEImpl = FusedMoE
-            kwargs["reduce_results"] = reduce_results
-            kwargs["use_presharded_weights"] = use_presharded_weights
-            kwargs["inplace"] = inplace
-            kwargs["no_combine"] = no_combine
-
-        self.experts = MoEImpl(
+        self.experts = FusedMoE(
             num_experts=num_experts,
             top_k=top_k,
             layer_id=layer_id,
@@ -195,7 +179,10 @@ class Grok1MoE(nn.Module):
             params_dtype=params_dtype,
             quant_config=quant_config,
             activation="gelu",
-            **kwargs,
+            reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
+            inplace=inplace,
+            no_combine=no_combine,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -877,10 +864,10 @@ class Grok1ForCausalLM(nn.Module):
 
         # Dump tensors for debugging
         global debug_tensor_dump_output_folder, debug_tensor_dump_inject
-        debug_tensor_dump_output_folder = global_server_args_dict[
-            "debug_tensor_dump_output_folder"
-        ]
-        debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
+        debug_tensor_dump_output_folder = (
+            get_global_server_args().debug_tensor_dump_output_folder
+        )
+        debug_tensor_dump_inject = get_global_server_args().debug_tensor_dump_inject
         warnings.filterwarnings("ignore", category=FutureWarning)
 
         if get_tensor_model_parallel_rank() == 0:
sglang/srt/models/hunyuan.py CHANGED
@@ -12,18 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only HunYuan model compatible with HuggingFace weights."""
-import logging
 import re
-from dataclasses import dataclass
-from enum import Enum, auto
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
-    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
@@ -46,7 +42,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
@@ -56,7 +51,7 @@ from sglang.srt.model_loader.weight_utils import (
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
-from sglang.srt.utils import add_prefix, is_hip
+from sglang.srt.utils import is_hip
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
sglang/srt/models/interns1.py CHANGED
@@ -5,7 +5,6 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.layers.attention import vision_utils
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
sglang/srt/models/kimi_vl.py CHANGED
@@ -43,10 +43,8 @@
 
 import copy
 import logging
-import math
-from collections.abc import Mapping
 from dataclasses import dataclass
-from typing import Any, Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -56,10 +54,6 @@ from sglang.srt.configs import KimiVLConfig
 from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
sglang/srt/models/kimi_vl_moonvit.py CHANGED
@@ -49,7 +49,7 @@ from typing import List, Optional, Sequence, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers.activations import ACT2FN, GELUTanh
+from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 
 try:
@@ -596,6 +596,8 @@ class MoonVitPretrainedModel(PreTrainedModel):
     _supports_sdpa = True
 
     def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
+        from transformers.activations import GELUTanh
+
         super().__init__(config, *inputs, **kwargs)
         config = deepcopy(config)
         self.merge_kernel_size = config.merge_kernel_size
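Moving the GELUTanh import from module scope into MoonVitPretrainedModel.__init__ defers it until a model is actually constructed, so merely importing the module no longer requires the symbol to resolve; a plausible motivation is tolerating transformers builds where GELUTanh is unavailable. The general deferred-import pattern, sketched with a placeholder class name:

    class SomeVisionModel:
        def __init__(self):
            # Resolved at construction time rather than module-import time,
            # so the enclosing module can still be imported in environments
            # where the optional symbol is missing.
            from transformers.activations import GELUTanh
            self.act = GELUTanh()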
sglang/srt/models/llama.py CHANGED
@@ -45,13 +45,13 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
     maybe_remap_kv_scale_name,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 from sglang.utils import get_exception_traceback
 
@@ -433,7 +433,7 @@ class LlamaForCausalLM(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            use_attn_tp_group=get_global_server_args().enable_dp_lm_head,
         )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)