sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents as they appear in that registry.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_vl_moe.py
@@ -0,0 +1,471 @@
+# Copyright 2025 Qwen Team
+# Copyright 2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
+import logging
+from functools import lru_cache, partial
+from typing import Callable, Iterable, List, Literal, Optional, Tuple, TypedDict, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import BatchFeature
+from transformers.activations import ACT2FN
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionRotaryEmbedding,
+)
+
+from sglang.srt.configs.qwen3_vl import Qwen3VLMoeConfig, Qwen3VLMoeVisionConfig
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    global_server_args_dict,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM, Qwen3MoeModel
+from sglang.srt.models.qwen3_vl import (
+    Qwen3_VisionTransformer,
+    Qwen3VLForConditionalGeneration,
+)
+from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
+
+logger = logging.getLogger(__name__)
+
+cached_get_processor = lru_cache(get_processor)
+
+
+class Qwen3MoeLLMModel(Qwen3MoeModel):
+    def __init__(
+        self,
+        *,
+        config: Qwen3VLMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(config=config, quant_config=quant_config, prefix=prefix)
+
+        self.hidden_size = config.hidden_size
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.embed_tokens
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # in qwen-vl, last dim is the same
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thw = torch.concat([item.image_grid_thw for item in items], dim=0)
+        assert pixel_values.dim() == 2, pixel_values.dim()
+        assert image_grid_thw.dim() == 2, image_grid_thw.dim()
+        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+        return image_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        input_deepstack_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        aux_hidden_states = []
+        for layer_idx, layer in enumerate(
+            self.layers[self.start_layer : self.end_layer]
+        ):
+            layer_idx = layer_idx + self.start_layer
+            if layer_idx in self.layers_to_capture:
+                aux_hidden_states.append(
+                    hidden_states + residual if residual is not None else hidden_states
+                )
+
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+            # process deepstack
+            if input_deepstack_embeds is not None and layer_idx in range(3):
+                sep = self.hidden_size * layer_idx
+                hidden_states = (
+                    hidden_states
+                    + input_deepstack_embeds[:, sep : sep + self.hidden_size]
+                )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            if hidden_states.shape[0] != 0:
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+
158
+class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration):
+    def __init__(
+        self,
+        *,
+        config: Qwen3VLMoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super(Qwen3VLForConditionalGeneration, self).__init__()
+        self.config = config
+
+        self.visual = Qwen3_VisionTransformer(
+            config.vision_config,
+            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+            # NOTE: Qwen3-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
+            prefix=add_prefix("visual", prefix),
+        )
+
+        self.model = Qwen3MoeLLMModel(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        # deepstack
+        self.deepstack_visual_indexes = self.visual.deepstack_visual_indexes
+        self.num_deepstack_embeddings = len(self.deepstack_visual_indexes)
+
+    @property
+    def use_deepstack(self) -> bool:
+        return hasattr(self, "deepstack_visual_indexes")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        """Run forward pass for Qwen3-VL.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for Qwen2-VL
+                opensource models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,)`.
+                (Use input_metadata.mrope_positions to replace it)
+        """
+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
+        if not (
+            forward_batch.forward_mode.is_decode()
+            or not forward_batch.contains_image_inputs()
+        ):
+            if self.is_mrope_enabled:
+                assert positions.ndim == 2 and positions.size(0) == 3, (
+                    "multimodal section rotary embedding requires "
+                    f"(3, seq_len) positions, but got {positions.size()}"
+                )
+
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.model,
+            multimodal_model=self,
+            positions=positions,
+            use_deepstack=self.use_deepstack,
+        )
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
254
+    def load_fused_expert_weights(
+        self,
+        name: str,
+        params_dict: dict,
+        loaded_weight: torch.Tensor,
+        shard_id: str,
+        num_experts: int,
+    ):
+        param = params_dict[name]
+        # weight_loader = typing.cast(Callable[..., bool], param.weight_loader)
+        weight_loader = param.weight_loader
+        ep_rank = get_tensor_model_parallel_rank()
+        ep_size = get_moe_expert_parallel_world_size()
+        if ep_size == 1:
+            for expert_id in range(num_experts):
+                curr_expert_weight = loaded_weight[expert_id]
+                weight_loader(
+                    param,
+                    curr_expert_weight,
+                    name,
+                    shard_id,
+                    expert_id,
+                )
+        else:
+            experts_per_ep = num_experts // ep_size
+            start_expert = ep_rank * experts_per_ep
+            end_expert = (
+                (ep_rank + 1) * experts_per_ep
+                if ep_rank != ep_size - 1
+                else num_experts
+            )
+
+            for idx, expert_id in enumerate(range(start_expert, end_expert)):
+                curr_expert_weight = loaded_weight[expert_id]
+                weight_loader(
+                    param,
+                    curr_expert_weight,
+                    name,
+                    shard_id,
+                    idx,
+                )
+        return True
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            ("gate_up_proj", "up_proj", 1),
+            ("gate_up_proj", "gate_proj", 0),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        # Skip loading extra parameters for GPTQ/modelopt models.
+        ignore_suffixes = (
+            ".bias",
+            "_bias",
+            ".k_scale",
+            "_k_scale",
+            ".v_scale",
+            "_v_scale",
+            ".weight_scale",
+            "_weight_scale",
+            ".input_scale",
+            "_input_scale",
+        )
+
+        is_fused_expert = False
+        fused_expert_params_mapping = [
+            ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
+            ("experts.w2_weight", "experts.down_proj", 0, "w2"),
+        ]
+
+        num_experts = self.config.num_experts
+
+        # Cache params_dict to avoid repeated expensive traversal of model parameters
+        if not hasattr(self, "_cached_params_dict"):
+            self._cached_params_dict = dict(self.named_parameters())
+        params_dict = self._cached_params_dict
+        for name, loaded_weight in weights:
+            if "language_model" in name:
+                name = name.replace(r"model.language_model.", r"model.")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if "experts.gate_up_proj" in name or "experts.down_proj" in name:
+                    is_fused_expert = True
+                    expert_params_mapping = fused_expert_params_mapping
+
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                if "visual" in name:
+                    continue
+
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra parameters for GPTQ/modelopt models.
+                if name.endswith(ignore_suffixes) and name not in params_dict:
+                    continue
+                # [TODO] Skip layers that are on other devices (check if sglang has a similar function)
+                # if is_pp_missing_parameter(name, self):
+                #     continue
+
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
+
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    if "visual" in name:
+                        continue
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+                    name_mapped = name.replace(weight_name, param_name)
+                    if is_fused_expert:
+                        loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
+                        if "experts.gate_up_proj" in name:
+                            loaded_weight = loaded_weight.chunk(2, dim=-2)
+                            self.load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight[0],
+                                "w1",
+                                num_experts,
+                            )
+                            self.load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight[1],
+                                "w3",
+                                num_experts,
+                            )
+                        else:
+                            self.load_fused_expert_weights(
+                                name_mapped,
+                                params_dict,
+                                loaded_weight,
+                                shard_id,
+                                num_experts,
+                            )
+                    else:
+                        # Skip loading extra parameters for GPTQ/modelopt models.
+                        if (
+                            name_mapped.endswith(ignore_suffixes)
+                            and name_mapped not in params_dict
+                        ):
+                            continue
+                        param = params_dict[name_mapped]
+                        # We should ask the weight loader to return success or
+                        # not here since otherwise we may skip experts with
+                        # other available replicas.
+                        weight_loader = param.weight_loader
+                        weight_loader(
+                            param,
+                            loaded_weight,
+                            name_mapped,
+                            shard_id=shard_id,
+                            expert_id=expert_id,
+                        )
+                    name = name_mapped
+                    break
+                else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+                    if "visual" in name:
+                        # adapt to VisionAttention
+                        name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                        name = name.replace(r"model.visual.", r"visual.")
+
+                    # Skip loading extra parameters for GPTQ/modelopt models.
+                    if name.endswith(ignore_suffixes) and name not in params_dict:
+                        continue
+
+                    if name in params_dict.keys():
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight)
+                    else:
+                        logger.warning(f"Parameter {name} not found in params_dict")
+
+        # TODO mimic deepseek
+        # Lazy initialization of expert weights cache to avoid slowing down load_weights
+        # if not hasattr(self, "routed_experts_weights_of_layer"):
+        #     self.routed_experts_weights_of_layer = {
+        #         layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+        #         for layer_id in range(self.start_layer, self.end_layer)
+        #         if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
+        #     }
+
+
+EntryClass = Qwen3VLMoeForConditionalGeneration
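
The expert-parallel slicing in load_fused_expert_weights above is easy to sanity-check in isolation. The sketch below reproduces the same arithmetic under an illustrative helper name (expert_range is not an sglang API): each EP rank takes a contiguous block of num_experts // ep_size experts, and the last rank also absorbs any remainder.

# Standalone sketch of the EP partitioning used by load_fused_expert_weights.
def expert_range(num_experts: int, ep_rank: int, ep_size: int) -> range:
    experts_per_ep = num_experts // ep_size
    start = ep_rank * experts_per_ep
    # The last rank runs to num_experts, covering any remainder.
    end = (ep_rank + 1) * experts_per_ep if ep_rank != ep_size - 1 else num_experts
    return range(start, end)

# 10 experts over 4 ranks: ranks 0-2 get 2 experts each, rank 3 gets 4.
assert [list(expert_range(10, r, 4)) for r in range(4)] == [
    [0, 1],
    [2, 3],
    [4, 5],
    [6, 7, 8, 9],
]
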
sglang/srt/models/registry.py
@@ -17,6 +17,18 @@ class _ModelRegistry:
     # Keyed by model_arch
     models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
 
+    def register(self, package_name: str, overwrite: bool = False):
+        new_models = import_model_classes(package_name)
+        if overwrite:
+            self.models.update(new_models)
+        else:
+            for arch, cls in new_models.items():
+                if arch in self.models:
+                    raise ValueError(
+                        f"Model architecture {arch} already registered. Set overwrite=True to replace."
+                    )
+                self.models[arch] = cls
+
     def get_supported_archs(self) -> AbstractSet[str]:
         return self.models.keys()
 
@@ -74,9 +86,8 @@ class _ModelRegistry:
 
 
 @lru_cache()
-def import_model_classes():
+def import_model_classes(package_name: str):
     model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
     package = importlib.import_module(package_name)
     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
@@ -104,4 +115,5 @@ def import_model_classes():
     return model_arch_name_to_cls
 
 
-ModelRegistry = _ModelRegistry(import_model_classes())
+ModelRegistry = _ModelRegistry()
+ModelRegistry.register("sglang.srt.models")
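
The last two lines replace the eager module-level scan with an instance method, which also makes the registry extensible from outside the package. A minimal usage sketch follows; the package name my_models is hypothetical and would need to be importable and laid out like sglang.srt.models (one module per model, each exposing an EntryClass).

from sglang.srt.models.registry import ModelRegistry

# Register additional architectures from an out-of-tree package.
# Per the diff above, this raises ValueError on a duplicate architecture name...
ModelRegistry.register("my_models")

# ...unless overwrite=True is passed to replace the existing entries.
ModelRegistry.register("my_models", overwrite=True)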