sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_vl_moe.py
@@ -0,0 +1,343 @@
+ # Copyright 2025 Qwen Team
+ # Copyright 2025 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """Inference-only Qwen3-VL model compatible with HuggingFace weights."""
+ import logging
+ from functools import lru_cache
+ from typing import Iterable, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
+ from sglang.srt.distributed import (
+     get_moe_expert_parallel_world_size,
+     get_tensor_model_parallel_rank,
+ )
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.qwen3_moe import Qwen3MoeModel
+ from sglang.srt.models.qwen3_vl import Qwen3VLForConditionalGeneration
+ from sglang.srt.utils.hf_transformers_utils import get_processor
+
+ logger = logging.getLogger(__name__)
+
+ cached_get_processor = lru_cache(get_processor)
+
+
+ class Qwen3MoeLLMModel(Qwen3MoeModel):
+     def __init__(
+         self,
+         *,
+         config: Qwen3VLMoeTextConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+         self.hidden_size = config.hidden_size
+
+     def get_input_embeddings(self) -> nn.Embedding:
+         return self.embed_tokens
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+         input_deepstack_embeds: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, PPProxyTensors]:
+         if self.pp_group.is_first_rank:
+             if input_embeds is None:
+                 hidden_states = self.embed_tokens(input_ids)
+             else:
+                 hidden_states = input_embeds
+             residual = None
+         else:
+             assert pp_proxy_tensors is not None
+             hidden_states = pp_proxy_tensors["hidden_states"]
+             residual = pp_proxy_tensors["residual"]
+
+         aux_hidden_states = []
+         for layer_idx, layer in enumerate(
+             self.layers[self.start_layer : self.end_layer]
+         ):
+             layer_idx += self.start_layer
+             if layer_idx in self.layers_to_capture:
+                 aux_hidden_states.append(
+                     hidden_states + residual if residual is not None else hidden_states
+                 )
+
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 forward_batch,
+                 residual,
+             )
+
+             # process deepstack
+             if input_deepstack_embeds is not None and layer_idx < 3:
+                 sep = self.hidden_size * layer_idx
+                 hidden_states.add_(
+                     input_deepstack_embeds[:, sep : sep + self.hidden_size]
+                 )
+
+         if not self.pp_group.is_last_rank:
+             return PPProxyTensors(
+                 {
+                     "hidden_states": hidden_states,
+                     "residual": residual,
+                 }
+             )
+         else:
+             if hidden_states.shape[0] != 0:
+                 if residual is None:
+                     hidden_states = self.norm(hidden_states)
+                 else:
+                     hidden_states, _ = self.norm(hidden_states, residual)
+
+         if len(aux_hidden_states) == 0:
+             return hidden_states
+
+         return hidden_states, aux_hidden_states
+
+
+ def load_fused_expert_weights(
+     name: str,
+     params_dict: dict,
+     loaded_weight: torch.Tensor,
+     shard_id: str,
+     num_experts: int,
+ ):
+     param = params_dict[name]
+     # weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
+     weight_loader = param.weight_loader
+     ep_rank = get_tensor_model_parallel_rank()
+     ep_size = get_moe_expert_parallel_world_size()
+     if ep_size == 1:
+         for expert_id in range(num_experts):
+             curr_expert_weight = loaded_weight[expert_id]
+             weight_loader(
+                 param,
+                 curr_expert_weight,
+                 name,
+                 shard_id,
+                 expert_id,
+             )
+     else:
+         experts_per_ep = num_experts // ep_size
+         start_expert = ep_rank * experts_per_ep
+         end_expert = (
+             (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
+         )
+
+         for idx, expert_id in enumerate(range(start_expert, end_expert)):
+             curr_expert_weight = loaded_weight[expert_id]
+             weight_loader(
+                 param,
+                 curr_expert_weight,
+                 name,
+                 shard_id,
+                 idx,
+             )
+     return True
+
+
+ class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+     def __init__(
+         self,
+         config: Qwen3VLMoeConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+         language_model_cls=Qwen3MoeLLMModel,
+     ):
+         super().__init__(config, quant_config, prefix, language_model_cls)
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             (".qkv_proj", ".q_proj", "q"),
+             (".qkv_proj", ".k_proj", "k"),
+             (".qkv_proj", ".v_proj", "v"),
+             ("gate_up_proj", "up_proj", 1),
+             ("gate_up_proj", "gate_proj", 0),
+         ]
+
+         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="gate_proj",
+             ckpt_down_proj_name="down_proj",
+             ckpt_up_proj_name="up_proj",
+             num_experts=self.config.num_experts,
+         )
+
+         # Skip loading extra parameters for GPTQ/modelopt models.
+         ignore_suffixes = (
+             ".bias",
+             "_bias",
+             ".k_scale",
+             "_k_scale",
+             ".v_scale",
+             "_v_scale",
+             ".weight_scale",
+             "_weight_scale",
+             ".input_scale",
+             "_input_scale",
+         )
+
+         is_fused_expert = False
+         fused_expert_params_mapping = [
+             ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
+             ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+         ]
+
+         num_experts = self.config.num_experts
+
+         # Cache params_dict to avoid repeated expensive traversal of model parameters
+         if not hasattr(self, "_cached_params_dict"):
+             self._cached_params_dict = dict(self.named_parameters())
+         params_dict = self._cached_params_dict
+         for name, loaded_weight in weights:
+             name = name.replace(r"model.language_model.", r"model.")
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if "experts.gate_up_proj" in name or "experts.down_proj" in name:
+                     is_fused_expert = True
+                     expert_params_mapping = fused_expert_params_mapping
+
+                 # Skip non-stacked layers and experts (experts handled below).
+                 if weight_name not in name:
+                     continue
+                 if "visual" in name:
+                     continue
+
+                 # We have mlp.experts[0].gate_proj in the checkpoint.
+                 # Since we handle the experts below in expert_params_mapping,
+                 # we need to skip here BEFORE we update the name, otherwise
+                 # name will be updated to mlp.experts[0].gate_up_proj, which
+                 # will then be updated below in expert_params_mapping
+                 # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                 if "mlp.experts" in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra parameters for GPTQ/modelopt models.
+                 if name.endswith(ignore_suffixes) and name not in params_dict:
+                     continue
+                 # [TODO] Skip layers that are on other devices (check if sglang has a similar function)
+                 # if is_pp_missing_parameter(name, self):
+                 #     continue
+
+                 if name not in params_dict:
+                     continue
+
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Track if this is an expert weight to enable early skipping
+                 is_expert_weight = False
+
+                 for mapping in expert_params_mapping:
+                     param_name, weight_name, expert_id, shard_id = mapping
+                     if weight_name not in name:
+                         continue
+                     if "visual" in name:
+                         continue
+                     # Anyway, this is an expert weight and should not be
+                     # attempted to load as other weights later
+                     is_expert_weight = True
+                     name_mapped = name.replace(weight_name, param_name)
+                     if is_fused_expert:
+                         loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
+                         if "experts.gate_up_proj" in name:
+                             loaded_weight = loaded_weight.chunk(2, dim=-2)
+                             load_fused_expert_weights(
+                                 name_mapped,
+                                 params_dict,
+                                 loaded_weight[0],
+                                 "w1",
+                                 num_experts,
+                             )
+                             load_fused_expert_weights(
+                                 name_mapped,
+                                 params_dict,
+                                 loaded_weight[1],
+                                 "w3",
+                                 num_experts,
+                             )
+                         else:
+                             load_fused_expert_weights(
+                                 name_mapped,
+                                 params_dict,
+                                 loaded_weight,
+                                 shard_id,
+                                 num_experts,
+                             )
+                     else:
+                         # Skip loading extra parameters for GPTQ/modelopt models.
+                         if (
+                             name_mapped.endswith(ignore_suffixes)
+                             and name_mapped not in params_dict
+                         ):
+                             continue
+                         param = params_dict[name_mapped]
+                         # We should ask the weight loader to return success or
+                         # not here since otherwise we may skip experts with
+                         # other available replicas.
+                         weight_loader = param.weight_loader
+                         weight_loader(
+                             param,
+                             loaded_weight,
+                             name_mapped,
+                             shard_id=shard_id,
+                             expert_id=expert_id,
+                         )
+                     name = name_mapped
+                     break
+                 else:
+                     if is_expert_weight:
+                         # This is an expert weight but not mapped to this rank, skip all remaining processing
+                         continue
+                     if "visual" in name:
+                         # adapt to VisionAttention
+                         name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                         name = name.replace(r"model.visual.", r"visual.")
+
+                     # Skip loading extra parameters for GPTQ/modelopt models.
+                     if name.endswith(ignore_suffixes) and name not in params_dict:
+                         continue
+
+                     if name in params_dict.keys():
+                         param = params_dict[name]
+                         weight_loader = getattr(
+                             param, "weight_loader", default_weight_loader
+                         )
+                         weight_loader(param, loaded_weight)
+                     else:
+                         logger.warning(f"Parameter {name} not found in params_dict")
+
+         # TODO mimic deepseek
+         # Lazy initialization of expert weights cache to avoid slowing down load_weights
+         # if not hasattr(self, "routed_experts_weights_of_layer"):
+         #     self.routed_experts_weights_of_layer = {
+         #         layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+         #         for layer_id in range(self.start_layer, self.end_layer)
+         #         if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
+         #     }
+
+
+ EntryClass = Qwen3VLMoeForConditionalGeneration
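
The fused-expert branch of load_weights above is the least obvious step: checkpoints that ship all experts in a single experts.gate_up_proj tensor are transposed, split into the gate ("w1") and up ("w3") halves, and then handed out so that each expert-parallel rank loads only its contiguous slice of experts. The following is a minimal sketch of that tensor bookkeeping; the shapes are invented for illustration, and the [num_experts, hidden, 2 * intermediate] checkpoint layout is an assumption inferred from the transpose and chunk calls, not taken from the model.

import torch

# Illustrative sizes only: 4 experts, hidden 8, intermediate 6.
num_experts, hidden, inter = 4, 8, 6
gate_up = torch.randn(num_experts, hidden, 2 * inter)  # fused checkpoint tensor

w = gate_up.transpose(-1, -2)   # -> [num_experts, 2 * inter, hidden]
w1, w3 = w.chunk(2, dim=-2)     # gate half ("w1") and up half ("w3")
assert w1.shape == (num_experts, inter, hidden)

# Expert-parallel sharding as in load_fused_expert_weights: contiguous
# slices per rank, with the last rank absorbing any remainder.
ep_rank, ep_size = 1, 2
experts_per_ep = num_experts // ep_size
start = ep_rank * experts_per_ep
end = (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
local_w1 = w1[start:end]        # rank 1 loads experts 2 and 3
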
sglang/srt/models/registry.py
@@ -17,6 +17,18 @@ class _ModelRegistry:
      # Keyed by model_arch
      models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
  
+     def register(self, package_name: str, overwrite: bool = False):
+         new_models = import_model_classes(package_name)
+         if overwrite:
+             self.models.update(new_models)
+         else:
+             for arch, cls in new_models.items():
+                 if arch in self.models:
+                     raise ValueError(
+                         f"Model architecture {arch} already registered. Set overwrite=True to replace."
+                     )
+                 self.models[arch] = cls
+
      def get_supported_archs(self) -> AbstractSet[str]:
          return self.models.keys()
  
@@ -74,9 +86,8 @@ class _ModelRegistry:
  
  
  @lru_cache()
- def import_model_classes():
+ def import_model_classes(package_name: str):
      model_arch_name_to_cls = {}
-     package_name = "sglang.srt.models"
      package = importlib.import_module(package_name)
      for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
          if not ispkg:
@@ -104,4 +115,5 @@ def import_model_classes():
      return model_arch_name_to_cls
  
  
- ModelRegistry = _ModelRegistry(import_model_classes())
+ ModelRegistry = _ModelRegistry()
+ ModelRegistry.register("sglang.srt.models")
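
With this change the built-in models are loaded through the same register hook that external packages can call. A hedged sketch of how an out-of-tree package might plug in; my_models is a hypothetical package whose modules each expose an EntryClass model class, following the discovery convention used by sglang's own model files (see the qwen3_vl_moe.py hunk above):

from sglang.srt.models.registry import ModelRegistry

# Hypothetical plugin package; each submodule defines EntryClass.
ModelRegistry.register("my_models")
# Registering an architecture that already exists raises ValueError
# unless overwrite=True is passed to replace the existing entry.
ModelRegistry.register("my_models", overwrite=True)
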
sglang/srt/models/roberta.py
@@ -1,6 +1,6 @@
  # SPDX-License-Identifier: Apache-2.0
  
- import itertools
+ import os
  from typing import Iterable, Optional, Tuple
  
  import torch
@@ -8,10 +8,12 @@ from torch import nn
  
  from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.sparse_pooler import SparsePooler
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.bert import BertEncoder
+ from sglang.srt.utils.hf_transformers_utils import download_from_hf
  
  RobertaConfig = None
  
@@ -206,12 +208,29 @@ class XLMRobertaModel(nn.Module):
          config: RobertaConfig,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
+         sparse_head: Optional[str] = None,
+         model_path: Optional[str] = None,
      ):
          super().__init__()
          self.roberta = XLMRobertaBaseModel(
              config=config, quant_config=quant_config, prefix=prefix
          )
-         self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+         if sparse_head is not None:
+             self._is_sparse = True
+             self._model_path = model_path
+             self._sparse_head = sparse_head
+             self.pooler = SparsePooler(config=config)
+             # Zero out special tokens
+             self._special_tokens = [
+                 config.bos_token_id,
+                 config.eos_token_id,
+                 config.pad_token_id,
+                 # self.config.unk_token_id  # not available in the XLMRobertaConfig
+             ]
+             self._special_tokens = [t for t in self._special_tokens if t is not None]
+         else:
+             self._is_sparse = False
+             self.pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
  
      def forward(
          self,
@@ -224,11 +243,44 @@ class XLMRobertaModel(nn.Module):
          hidden_states = self.roberta(
              input_ids, positions, forward_batch, input_embeds, get_embedding
          )
-         return self.pooler(hidden_states, forward_batch)
+         embeddings = self.pooler(hidden_states, forward_batch)
+
+         if self._is_sparse:
+             for token_id in self._special_tokens:
+                 embeddings.embeddings[:, token_id] = 0.0
+             embeddings.embeddings = embeddings.embeddings.to_sparse()
+
+         return embeddings
  
      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          self.roberta.load_weights(weights)
  
+         if self._is_sparse:
+             sparse_dict = XLMRobertaModel._load_sparse_linear(
+                 self._model_path, self._sparse_head
+             )
+             self.pooler.load_weights(sparse_dict)
+
+     @staticmethod
+     def _load_sparse_linear(model_path_or_dir: str, sparse_head: str) -> dict:
+         """
+         Load sparse_head from local dir or HF Hub.
+         Returns a state_dict suitable for nn.Linear.load_state_dict().
+         """
+         if os.path.isdir(model_path_or_dir):
+             path = os.path.join(model_path_or_dir, sparse_head)
+             if not os.path.exists(path):
+                 raise FileNotFoundError(
+                     f"'{sparse_head}' not found in {model_path_or_dir}"
+                 )
+         else:
+             # remote → use SGLang HF utility
+             local_dir = download_from_hf(model_path_or_dir, allow_patterns=sparse_head)
+             path = os.path.join(local_dir, sparse_head)
+
+         state_dict = torch.load(path)
+         return state_dict
+
  
  class XLMRobertaForSequenceClassification(nn.Module):
      def __init__(
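
The sparse path added above turns the pooler output into a bag-of-words style vector over the vocabulary: special-token columns are zeroed so they never contribute to similarity scores, and the dense [batch, vocab] matrix is converted to a sparse COO tensor that stores only the non-zero term weights. A small self-contained illustration of that post-processing; the shapes and token ids here are invented for the example, not taken from XLM-RoBERTa:

import torch

batch, vocab = 2, 16
scores = torch.relu(torch.randn(batch, vocab))  # stand-in for SparsePooler output
special_tokens = [0, 1, 2]                      # assumed bos/pad/eos ids

for token_id in special_tokens:
    scores[:, token_id] = 0.0                   # drop special-token weights

sparse_scores = scores.to_sparse()              # COO tensor, non-zeros only
print(sparse_scores.values().numel(), "non-zero term weights")
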