sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,661 @@
1
+ # Copyright 2025 Qwen Team
2
+ # Copyright 2025 SGLang Team
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Inference-only Qwen3-VL model compatible with HuggingFace weights."""
16
+ import math
17
+ from typing import Iterable, List, Optional, Tuple
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from transformers import PreTrainedModel
24
+ from transformers.activations import ACT2FN
25
+ from transformers.modeling_outputs import BaseModelOutput
26
+
27
+ from sglang.srt.configs.qwen3_omni import (
28
+ Qwen3OmniMoeAudioEncoderConfig,
29
+ Qwen3OmniMoeThinkerConfig,
30
+ Qwen3OmniMoeVisionEncoderConfig,
31
+ )
32
+ from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig
33
+ from sglang.srt.layers.attention.vision import VisionAttention
34
+ from sglang.srt.layers.layernorm import RMSNorm
35
+ from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
36
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
37
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
38
+ from sglang.srt.managers.schedule_batch import MultimodalDataItem
39
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
40
+ from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel
41
+ from sglang.srt.models.qwen3_vl_moe import (
42
+ Qwen3MoeLLMModel,
43
+ Qwen3VLMoeForConditionalGeneration,
44
+ load_fused_expert_weights,
45
+ )
46
+ from sglang.srt.utils import add_prefix, logger
47
+
48
+
49
class Qwen3OmniMoeAudioEncoderLayer(nn.Module):
    """One pre-norm transformer encoder layer of the Qwen3-Omni audio tower.

    Structure: LayerNorm -> self-attention -> residual add, then
    LayerNorm -> fc1 -> activation -> fc2 -> residual add, with an
    fp16 overflow clamp at the end.
    """

    def __init__(
        self,
        config: Qwen3OmniMoeAudioEncoderConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        embed_dim = config.d_model
        self.embed_dim = config.d_model
        # Fused-QKV vision attention; "fa3" selects the FlashAttention-3 kernel.
        self.self_attn = VisionAttention(
            embed_dim=embed_dim,
            num_heads=config.encoder_attention_heads,
            projection_size=embed_dim,
            use_qkv_parallel=True,
            rotary_embed="normal",
            proj_bias=True,
            qkv_backend="fa3",
            softmax_in_single_precision=False,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        # Dropout rates are stored but not applied here (inference-only path).
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Run one encoder layer over packed audio features.

        Args:
            hidden_states: input features of shape ``(batch, seq_len, embed_dim)``.
            cu_seqlens: cumulative sequence lengths passed to the attention kernel.

        Returns:
            A one-element tuple holding the transformed hidden states.
        """
        # Attention sub-block (pre-norm + residual).
        attn_input = self.self_attn_layer_norm(hidden_states)
        attn_output = self.self_attn(
            x=attn_input,
            cu_seqlens=cu_seqlens,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block (pre-norm + residual).
        mlp_input = self.final_layer_norm(hidden_states)
        mlp_output = self.fc2(self.activation_fn(self.fc1(mlp_input)))
        hidden_states = hidden_states + mlp_output

        # Clamp to avoid fp16 overflow after the residual additions.
        if hidden_states.dtype == torch.float16:
            bound = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-bound, max=bound)

        return (hidden_states,)
118
+
119
+
120
class SinusoidsPositionEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal position-embedding table.

    The table has shape ``(length, channels)``: the first ``channels // 2``
    columns hold sines and the remaining columns hold cosines, over
    geometrically spaced timescales up to ``max_timescale``.
    """

    def __init__(self, length, channels, max_timescale=10000):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
        half = channels // 2
        # Geometric progression of inverse timescales from 1 down to 1/max_timescale.
        log_increment = np.log(max_timescale) / (half - 1)
        inv_timescales = torch.exp(-log_increment * torch.arange(half).float())
        # Outer product of positions and inverse timescales via broadcasting.
        angles = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
        table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        # Non-persistent: the table is deterministic and rebuilt at load time.
        self.register_buffer("positional_embedding", table, persistent=False)

    def forward(self, seqlen: int):
        """Return the first ``seqlen`` rows of the embedding table."""
        return self.positional_embedding[:seqlen, :]
140
+
141
+
142
+ def _get_feat_extract_output_lengths(input_lengths):
143
+ """
144
+ Computes the output length of the convolutional layers and the output length of the audio encoder
145
+ """
146
+
147
+ input_lengths_leave = input_lengths % 100
148
+ feat_lengths = (input_lengths_leave - 1) // 2 + 1
149
+ output_lengths = (
150
+ ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
151
+ )
152
+ return output_lengths
153
+
154
+
155
+ class Qwen3OmniMoeAudioEncoder(PreTrainedModel):
156
+ config: Qwen3OmniMoeAudioEncoderConfig
157
+
158
+ def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
159
+ super().__init__(config)
160
+ self.dropout = config.dropout
161
+
162
+ embed_dim = config.d_model
163
+ self.num_mel_bins = config.num_mel_bins
164
+ self.max_source_positions = config.max_source_positions
165
+ self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
166
+ self.n_window = config.n_window
167
+ self.positional_embedding = SinusoidsPositionEmbedding(
168
+ self.max_source_positions, embed_dim
169
+ )
170
+ self.layers = nn.ModuleList(
171
+ [
172
+ Qwen3OmniMoeAudioEncoderLayer(config)
173
+ for _ in range(config.encoder_layers)
174
+ ]
175
+ )
176
+ self.ln_post = nn.LayerNorm(config.d_model)
177
+ self.gradient_checkpointing = False
178
+ self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
179
+ self.conv2d2 = nn.Conv2d(
180
+ config.downsample_hidden_size,
181
+ config.downsample_hidden_size,
182
+ 3,
183
+ 2,
184
+ padding=1,
185
+ )
186
+ self.conv2d3 = nn.Conv2d(
187
+ config.downsample_hidden_size,
188
+ config.downsample_hidden_size,
189
+ 3,
190
+ 2,
191
+ padding=1,
192
+ )
193
+ self.conv_out = nn.Linear(
194
+ config.downsample_hidden_size
195
+ * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
196
+ config.d_model,
197
+ bias=False,
198
+ )
199
+ self.proj1 = nn.Linear(config.d_model, config.d_model)
200
+ self.act = ACT2FN[config.activation_function]
201
+ self.proj2 = nn.Linear(config.d_model, config.output_dim)
202
+ self.n_window_infer = self.config.n_window_infer
203
+ self.conv_chunksize = self.config.conv_chunksize
204
+
205
+ def _freeze_parameters(self):
206
+ for param in self.parameters():
207
+ param.requires_grad = False
208
+ self._requires_grad = False
209
+
210
+ def get_input_embeddings(self) -> nn.Module:
211
+ return self.conv1
212
+
213
+ def set_input_embeddings(self, value: nn.Module):
214
+ self.conv1 = value
215
+
216
+ def forward(
217
+ self,
218
+ input_features,
219
+ feature_lens=None,
220
+ aftercnn_lens=None,
221
+ ):
222
+ r"""
223
+ feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
224
+ mel length
225
+ aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
226
+ mel length after cnn
227
+ """
228
+ aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
229
+ chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
230
+
231
+ chunk_lengths = torch.tensor(
232
+ [self.n_window * 2] * chunk_num.sum(),
233
+ dtype=torch.long,
234
+ device=feature_lens.device,
235
+ )
236
+ tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
237
+ chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
238
+ chunk_lengths[chunk_lengths == 0] = self.n_window * 2
239
+
240
+ chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
241
+ padded_feature = nn.utils.rnn.pad_sequence(
242
+ chunk_list, batch_first=True
243
+ ).transpose(1, 2)
244
+ feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
245
+ padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
246
+ [
247
+ torch.ones(length, dtype=torch.bool, device=padded_feature.device)
248
+ for length in feature_lens_after_cnn
249
+ ],
250
+ batch_first=True,
251
+ )
252
+ padded_feature = padded_feature.unsqueeze(1)
253
+ # Split to chunk to avoid OOM during convolution
254
+ padded_embeds = []
255
+ for chunk in padded_feature.split(self.conv_chunksize, dim=0):
256
+ padded_embed = F.gelu(self.conv2d1(chunk))
257
+ padded_embed = F.gelu(self.conv2d2(padded_embed))
258
+ padded_embed = F.gelu(self.conv2d3(padded_embed))
259
+ padded_embeds.append(padded_embed)
260
+ padded_embed = torch.cat(padded_embeds, dim=0)
261
+ b, c, f, t = padded_embed.size()
262
+ padded_embed = self.conv_out(
263
+ padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)
264
+ )
265
+
266
+ positional_embedding = (
267
+ self.positional_embedding.positional_embedding[: padded_embed.shape[1], :]
268
+ .unsqueeze(0)
269
+ .to(padded_embed.dtype)
270
+ )
271
+ padded_embed = padded_embed + positional_embedding
272
+ hidden_states = padded_embed[padded_mask_after_cnn]
273
+ cu_chunk_lens = [0]
274
+ window_aftercnn = padded_mask_after_cnn.shape[-1] * (
275
+ self.n_window_infer // (self.n_window * 2)
276
+ )
277
+ for cnn_len in aftercnn_lens:
278
+ cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
279
+ remainder = cnn_len % window_aftercnn
280
+ if remainder != 0:
281
+ cu_chunk_lens += [remainder]
282
+ cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(
283
+ -1, dtype=torch.int32
284
+ )
285
+
286
+ for encoder_layer in self.layers:
287
+ layer_outputs = encoder_layer(
288
+ hidden_states,
289
+ cu_seqlens,
290
+ )
291
+
292
+ hidden_states = layer_outputs[0]
293
+
294
+ hidden_states = self.ln_post(hidden_states)
295
+ hidden_states = self.proj1(hidden_states)
296
+ hidden_states = self.act(hidden_states)
297
+ hidden_states = self.proj2(hidden_states)
298
+ return BaseModelOutput(last_hidden_state=hidden_states)
299
+
300
+ # Ignore copy
301
+ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
302
+ """
303
+ Computes the output length of the convolutional layers and the output length of the audio encoder
304
+ """
305
+ input_lengths = (input_lengths - 1) // 2 + 1
306
+ output_lengths = (input_lengths - 2) // 2 + 1
307
+ return input_lengths, output_lengths
308
+
309
+
310
+ class Qwen3OmniMoeVisionPatchMerger(nn.Module):
311
+
312
+ def __init__(
313
+ self,
314
+ dim: int,
315
+ context_dim: int,
316
+ spatial_merge_size: int = 2,
317
+ quant_config: Optional[QuantizationConfig] = None,
318
+ prefix: str = "",
319
+ use_postshuffle_norm=False,
320
+ ) -> None:
321
+ super().__init__()
322
+ self.hidden_size = context_dim * (spatial_merge_size**2)
323
+ self.use_postshuffle_norm = use_postshuffle_norm
324
+ self.ln_q = RMSNorm(
325
+ self.hidden_size if use_postshuffle_norm else context_dim, eps=1e-6
326
+ )
327
+ self.mlp = nn.ModuleList(
328
+ [
329
+ ColumnParallelLinear(
330
+ self.hidden_size,
331
+ self.hidden_size,
332
+ bias=True,
333
+ quant_config=quant_config,
334
+ prefix=add_prefix("mlp.0", prefix),
335
+ ),
336
+ nn.GELU(),
337
+ RowParallelLinear(
338
+ self.hidden_size,
339
+ dim,
340
+ bias=True,
341
+ quant_config=quant_config,
342
+ prefix=add_prefix("mlp.2", prefix),
343
+ ),
344
+ ]
345
+ )
346
+
347
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
348
+ x = (
349
+ x.view(-1, self.hidden_size)
350
+ if self.use_postshuffle_norm
351
+ else x.view(-1, x.shape[-1])
352
+ )
353
+ hidden = self.ln_q(x).view(-1, self.hidden_size)
354
+ for layer in self.mlp:
355
+ if isinstance(hidden, tuple):
356
+ hidden = hidden[0]
357
+ hidden = layer(hidden)
358
+
359
+ if isinstance(hidden, tuple):
360
+ hidden = hidden[0]
361
+
362
+ return hidden
363
+
364
+
365
+ class Qwen3OmniMoeVisionEncoder(Qwen3VLMoeVisionModel):
366
+ config: Qwen3OmniMoeVisionEncoderConfig
367
+
368
+ def __init__(
369
+ self,
370
+ config: Qwen3OmniMoeVisionEncoderConfig,
371
+ quant_config: Optional[QuantizationConfig] = None,
372
+ prefix: str = None,
373
+ **kwargs,
374
+ ):
375
+ super().__init__(
376
+ vision_config=config,
377
+ quant_config=quant_config,
378
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
379
+ )
380
+
381
+ self.merger = Qwen3OmniMoeVisionPatchMerger(
382
+ dim=config.out_hidden_size,
383
+ context_dim=config.hidden_size,
384
+ spatial_merge_size=config.spatial_merge_size,
385
+ quant_config=quant_config,
386
+ use_postshuffle_norm=False,
387
+ prefix=add_prefix("merger", prefix),
388
+ )
389
+ self.merger_list = nn.ModuleList(
390
+ [
391
+ Qwen3OmniMoeVisionPatchMerger(
392
+ dim=config.out_hidden_size,
393
+ context_dim=config.hidden_size,
394
+ spatial_merge_size=config.spatial_merge_size,
395
+ use_postshuffle_norm=True,
396
+ quant_config=quant_config,
397
+ prefix=add_prefix("merger_list", prefix),
398
+ )
399
+ for _ in range(len(config.deepstack_visual_indexes))
400
+ ]
401
+ )
402
+ del self.deepstack_merger_list
403
+
404
+ @property
405
+ def deepstack_merger_list(self):
406
+ return self.merger_list
407
+
408
+ @property
409
+ def dtype(self) -> torch.dtype:
410
+ return self.patch_embed.proj.weight.dtype
411
+
412
+ @property
413
+ def device(self) -> torch.device:
414
+ return self.patch_embed.proj.weight.device
415
+
416
+
417
+ class Qwen3OmniMoeThinkerForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
418
+ config: Qwen3OmniMoeThinkerConfig
419
+
420
+ def __init__(
421
+ self,
422
+ config: Qwen3OmniMoeThinkerConfig,
423
+ quant_config: Optional[QuantizationConfig] = None,
424
+ prefix: str = "",
425
+ ):
426
+ super().__init__(
427
+ config, quant_config, prefix, language_model_cls=Qwen3MoeLLMModel
428
+ )
429
+ self.audio_tower = Qwen3OmniMoeAudioEncoder(config.audio_config)
430
+ self.visual = Qwen3OmniMoeVisionEncoder(
431
+ config.vision_config,
432
+ quant_config=quant_config,
433
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
434
+ prefix=add_prefix("visual", prefix),
435
+ )
436
+ self.pad_token_id = (
437
+ self.config.pad_token_id if self.config.pad_token_id is not None else -1
438
+ )
439
+
440
+ def get_audio_feature(self, items: List[MultimodalDataItem]):
441
+ feature_attention_mask = torch.cat(
442
+ [item.feature_attention_mask for item in items], dim=0
443
+ ).type(torch.long)
444
+ input_features = (
445
+ torch.cat([item.feature for item in items])
446
+ .type(self.audio_tower.dtype)
447
+ .to(next(self.audio_tower.parameters()).device)
448
+ )
449
+ if feature_attention_mask is not None:
450
+ audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
451
+ input_features = input_features.permute(0, 2, 1)[
452
+ feature_attention_mask.bool()
453
+ ].permute(1, 0)
454
+ else:
455
+ audio_feature_lengths = None
456
+
457
+ feature_lens = (
458
+ audio_feature_lengths
459
+ if audio_feature_lengths is not None
460
+ else feature_attention_mask.sum(-1)
461
+ )
462
+ audio_outputs = self.audio_tower(
463
+ input_features,
464
+ feature_lens=feature_lens,
465
+ )
466
+ audio_features = audio_outputs.last_hidden_state
467
+
468
+ return audio_features
469
+
470
+
471
+ class Qwen3OmniMoeForConditionalGeneration(PreTrainedModel):
472
+ def __init__(
473
+ self,
474
+ config: Qwen3VLMoeConfig,
475
+ quant_config: Optional[QuantizationConfig] = None,
476
+ prefix: str = "",
477
+ ):
478
+ super().__init__(config)
479
+ self.config = config
480
+
481
+ self.thinker = Qwen3OmniMoeThinkerForConditionalGeneration(
482
+ config.thinker_config, quant_config=quant_config, prefix=prefix
483
+ )
484
+ self.enable_talker = False
485
+ self.pad_input_ids = self.thinker.pad_input_ids
486
+ self.forward = self.thinker.forward
487
+
488
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
489
+ stacked_params_mapping = [
490
+ # (param_name, shard_name, shard_id)
491
+ (".qkv_proj", ".q_proj", "q"),
492
+ (".qkv_proj", ".k_proj", "k"),
493
+ (".qkv_proj", ".v_proj", "v"),
494
+ ("gate_up_proj", "up_proj", 1),
495
+ ("gate_up_proj", "gate_proj", 0),
496
+ ]
497
+
498
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
499
+ ckpt_gate_proj_name="gate_proj",
500
+ ckpt_down_proj_name="down_proj",
501
+ ckpt_up_proj_name="up_proj",
502
+ num_experts=self.config.num_experts,
503
+ )
504
+
505
+ # Skip loading extra parameters for GPTQ/modelopt models.
506
+ ignore_suffixes = (
507
+ ".bias",
508
+ "_bias",
509
+ ".k_scale",
510
+ "_k_scale",
511
+ ".v_scale",
512
+ "_v_scale",
513
+ ".weight_scale",
514
+ "_weight_scale",
515
+ ".input_scale",
516
+ "_input_scale",
517
+ )
518
+
519
+ is_fused_expert = False
520
+ fused_expert_params_mapping = [
521
+ ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
522
+ ("experts.w2_weight", "experts.down_proj", 0, "w2"),
523
+ ]
524
+
525
+ num_experts = self.config.num_experts
526
+
527
+ # Cache params_dict to avoid repeated expensive traversal of model parameters
528
+ if not hasattr(self, "_cached_params_dict"):
529
+ self._cached_params_dict = dict(self.named_parameters())
530
+ params_dict = self._cached_params_dict
531
+
532
+ for name, loaded_weight in weights:
533
+ name = name.replace(r"model.language_model.", r"model.")
534
+
535
+ if ("talker" in name or "code2wav" in name) and not self.enable_talker:
536
+ continue
537
+
538
+ name = name.replace(".self_attn.out_proj", ".self_attn.proj")
539
+
540
+ for param_name, weight_name, shard_id in stacked_params_mapping:
541
+ if "experts.gate_up_proj" in name or "experts.down_proj" in name:
542
+ is_fused_expert = True
543
+ expert_params_mapping = fused_expert_params_mapping
544
+
545
+ # Skip non-stacked layers and experts (experts handled below).
546
+ if weight_name not in name:
547
+ continue
548
+ if "visual" in name:
549
+ continue
550
+
551
+ # We have mlp.experts[0].gate_proj in the checkpoint.
552
+ # Since we handle the experts below in expert_params_mapping,
553
+ # we need to skip here BEFORE we update the name, otherwise
554
+ # name will be updated to mlp.experts[0].gate_up_proj, which
555
+ # will then be updated below in expert_params_mapping
556
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
557
+ if "mlp.experts" in name:
558
+ continue
559
+ name = name.replace(weight_name, param_name)
560
+ # Skip loading extra parameters for GPTQ/modelopt models.
561
+ if name.endswith(ignore_suffixes) and name not in params_dict:
562
+ continue
563
+ # [TODO] Skip layers that are on other devices (check if sglang has a similar function)
564
+ # if is_pp_missing_parameter(name, self):
565
+ # continue
566
+
567
+ if name not in params_dict:
568
+ continue
569
+
570
+ param = params_dict[name]
571
+ weight_loader = param.weight_loader
572
+ weight_loader(param, loaded_weight, shard_id)
573
+ break
574
+ else:
575
+ # Track if this is an expert weight to enable early skipping
576
+ is_expert_weight = False
577
+
578
+ for mapping in expert_params_mapping:
579
+ param_name, weight_name, expert_id, shard_id = mapping
580
+ if weight_name not in name:
581
+ continue
582
+ if "visual" in name or "audio_tower" in name:
583
+ continue
584
+ # Anyway, this is an expert weight and should not be
585
+ # attempted to load as other weights later
586
+ is_expert_weight = True
587
+ name_mapped = name.replace(weight_name, param_name)
588
+ if is_fused_expert:
589
+ loaded_weight = loaded_weight.transpose(-1, -2) # no bias
590
+ if "experts.gate_up_proj" in name:
591
+ loaded_weight = loaded_weight.chunk(2, dim=-2)
592
+ load_fused_expert_weights(
593
+ name_mapped,
594
+ params_dict,
595
+ loaded_weight[0],
596
+ "w1",
597
+ num_experts,
598
+ )
599
+ load_fused_expert_weights(
600
+ name_mapped,
601
+ params_dict,
602
+ loaded_weight[1],
603
+ "w3",
604
+ num_experts,
605
+ )
606
+ else:
607
+ load_fused_expert_weights(
608
+ name_mapped,
609
+ params_dict,
610
+ loaded_weight,
611
+ shard_id,
612
+ num_experts,
613
+ )
614
+ else:
615
+ # Skip loading extra parameters for GPTQ/modelopt models.
616
+ if (
617
+ name_mapped.endswith(ignore_suffixes)
618
+ and name_mapped not in params_dict
619
+ ):
620
+ continue
621
+ param = params_dict[name_mapped]
622
+ # We should ask the weight loader to return success or
623
+ # not here since otherwise we may skip experts with
624
+ # # other available replicas.
625
+ weight_loader = param.weight_loader
626
+ weight_loader(
627
+ param,
628
+ loaded_weight,
629
+ name_mapped,
630
+ shard_id=shard_id,
631
+ expert_id=expert_id,
632
+ )
633
+ name = name_mapped
634
+ break
635
+ else:
636
+ if is_expert_weight:
637
+ # This is an expert weight but not mapped to this rank, skip all remaining processing
638
+ continue
639
+ if "visual" in name or "audio_tower" in name:
640
+ # adapt to VisionAttention
641
+ name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
642
+ name = name.replace(r"model.visual.", r"visual.")
643
+ name = name.replace(r"attn.out_proj.", r"attn.proj.")
644
+
645
+ # Skip loading extra parameters for GPTQ/modelopt models.
646
+ if name.endswith(ignore_suffixes) and name not in params_dict:
647
+ continue
648
+
649
+ if name in params_dict.keys():
650
+ param = params_dict[name]
651
+ weight_loader = getattr(
652
+ param, "weight_loader", default_weight_loader
653
+ )
654
+ weight_loader(param, loaded_weight)
655
+ else:
656
+ logger.warning(
657
+ f"Loaded weight with {name=} not found in params_dict"
658
+ )
659
+
660
+
661
+ EntryClass = Qwen3OmniMoeForConditionalGeneration