sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/dots_vlm.py
@@ -0,0 +1,174 @@
+# Copyright 2025 The RedNote HiLab team.
+# Copyright 2025 The SGLang team.
+#
+# This code is based on the DeepseekVL2ForCausalLM and DotsVisionTransformer
+# implementation in this library.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Dots-VL model compatible with HuggingFace weights."""
+
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.configs.dots_vlm import DotsVLMConfig
+from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
+
+from .dots_vlm_vit import DotsVisionTransformer
+
+
+class DotsVLMForCausalLM(nn.Module):
+    """DotsVLM model for sglang inference"""
+
+    def __init__(
+        self, config: DotsVLMConfig, quant_config: Optional[QuantizationConfig] = None
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.image_token_id = config.im_span_id
+        self.video_token_id = config.video_span_id
+
+        self.language_model = DeepseekV2ForCausalLM(
+            config.language_config, quant_config
+        )
+
+        # Initialize vision tower (matching transformers naming for weight compatibility)
+        self.vision_tower = DotsVisionTransformer(config.vision_config)
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        if "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for the model, separating vision and language weights"""
+        weights = list(weights)
+
+        # Separate vision tower weights and language model weights
+        vision_weights = []
+        language_weights = []
+
+        for name, loaded_weight in weights:
+            if name.startswith("vision_tower."):
+                vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                vision_weights.append((vision_name, loaded_weight))
+            else:
+                # All other weights go to language model
+                language_weights.append((name, loaded_weight))
+
+        # Load vision tower weights
+        vision_state_dict = dict(vision_weights)
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in vision_state_dict.items():
+            if name not in params_dict:
+                raise ValueError(f"Weight {name} not found in params_dict")
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+            weight_loader(param, loaded_weight)
+
+        # Load language model weights
+        if language_weights:
+            self.language_model.load_weights(language_weights)
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return DeepseekV2ForCausalLM.get_model_config_for_expert_location(config)
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        """Pad input_ids with multimodal tokens"""
+        # Get image token ID for padding pattern
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        padded_input_ids = pattern.pad_input_tokens(input_ids, mm_inputs)
+        return padded_input_ids
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Extract pixel values and grid information (following reference pattern)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.vision_tower.dtype
+        )
+        image_grid_thw = torch.concat(
+            [item.image_grid_thw for item in items], dim=0
+        ).to(self.vision_tower.device)
+
+        # Add dimension checks like in reference code
+        assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
+        assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"
+
+        # Process through vision tower
+        image_embeds = self.vision_tower(pixel_values, image_grid_thw)
+
+        # Ensure consistent dtype for FlashInfer compatibility
+        # Force bfloat16 to match model's expected dtype
+        if image_embeds.dtype != torch.bfloat16 and hasattr(
+            self.language_model.model, "embed_tokens"
+        ):
+            target_dtype = self.language_model.model.embed_tokens.weight.dtype
+            image_embeds = image_embeds.to(target_dtype)
+
+        return image_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: object,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            multimodal_model=self,
+            language_model=self.language_model,
+        )
+        return hidden_states
+
+
+EntryClass = [DotsVLMForCausalLM]
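
Note on the new model above: the vision tower pads its attention head count up to a multiple of the tensor-parallel world size by appending zero-initialized "dummy" heads to the checkpoint weights (see `_pad_vit_attn_dummy_heads` here and `_update_vision_config` in dots_vlm_vit.py below). A minimal sketch of that arithmetic on toy tensors; the `pad_heads` helper and all shapes are illustrative assumptions, not part of the package:

```python
import torch

# Toy sizes: 14 heads of dim 8 cannot be split evenly across 4 TP ranks.
num_heads, head_dim, world_size = 14, 8, 4

# Same rounding as _update_vision_config: pad up to a multiple of world_size.
num_dummy_heads = ((num_heads + world_size) // world_size) * world_size - num_heads
assert (num_heads + num_dummy_heads) % world_size == 0  # 16 heads over 4 ranks

def pad_heads(w: torch.Tensor) -> torch.Tensor:
    # Append zero rows for the dummy heads, as pad_func does per q/k/v chunk.
    dummy = w.new_zeros(num_dummy_heads, head_dim, w.shape[-1])
    return torch.cat([w.unflatten(0, (-1, head_dim)), dummy], dim=0).flatten(0, 1)

wq = torch.randn(num_heads * head_dim, 32)  # (112, 32)
print(pad_heads(wq).shape)                  # torch.Size([128, 32]) == (14 + 2) * 8 rows
```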
sglang/srt/models/dots_vlm_vit.py
@@ -0,0 +1,337 @@
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.nn import LayerNorm
+from transformers.modeling_utils import PreTrainedModel
+
+from sglang.srt.configs.dots_vlm import DotsVisionConfig
+from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        seq = torch.arange(
+            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+        )
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs
+
+
+class PatchMerger(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        context_dim: int,
+        spatial_merge_size: int = 2,
+        pre_norm="layernorm",
+        init_merger_std=None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        self.pre_norm = pre_norm
+        if self.pre_norm == "layernorm":
+            self.ln_q = LayerNorm(context_dim, eps=1e-6)
+        elif self.pre_norm == "rmsnorm":
+            self.ln_q = RMSNorm(context_dim, eps=1e-6)
+        else:
+            logger.warning(f"no norm in patch merger: {self.pre_norm}")
+
+        self.mlp = nn.Sequential(
+            nn.Linear(self.hidden_size, self.hidden_size),
+            nn.GELU(),
+            nn.Linear(self.hidden_size, dim),
+        )
+
+        if init_merger_std is not None:
+            nn.init.normal_(self.mlp[0].weight, mean=0.0, std=init_merger_std)
+            nn.init.zeros_(self.mlp[0].bias)
+            nn.init.normal_(self.mlp[2].weight, mean=0.0, std=init_merger_std)
+            nn.init.zeros_(self.mlp[2].bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.pre_norm:
+            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+        else:
+            x = self.mlp(x.view(-1, self.hidden_size))
+        return x
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(dim))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    def extra_repr(self) -> str:
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+
+class DotsSwiGLUFFN(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        hidden_features = config.intermediate_size
+        in_features = config.embed_dim
+        bias = config.use_bias
+
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
+        self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.silu(self.fc1(x)) * self.fc3(x)
+        x = self.fc2(x)
+        return x
+
+
+class DotsPatchEmbed(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.num_channels = config.num_channels
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.embed_dim = config.embed_dim
+        self.config = config
+        self.proj = nn.Conv2d(
+            config.num_channels,
+            config.embed_dim,
+            kernel_size=(config.patch_size, config.patch_size),
+            stride=(config.patch_size, config.patch_size),
+        )
+        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
+        x = x.view(
+            -1,
+            self.num_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        )[:, :, 0]
+        x = self.proj(x).view(-1, self.embed_dim)
+        x = self.norm(x)
+        return x
+
+
+class DotsViTPreprocessor(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.patch_h = config.patch_size
+        self.patch_w = config.patch_size
+        self.embed_dim = config.embed_dim
+        self.config = config
+        self.patchifier = DotsPatchEmbed(config, quant_config)
+
+    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
+        tokens = self.patchifier(x, grid_thw)
+        return tokens
+
+
+class DotsVisionBlock(nn.Module):
+    def __init__(
+        self,
+        config: DotsVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        attn_implementation: str = "flash_attention_2",
+    ):
+        super().__init__()
+        if attn_implementation == "flash_attention_2":
+            qkv_backend = "fa3"
+            softmax_in_single_precision = False
+        else:
+            raise RuntimeError("Unimplemented")
+        self.attn = VisionAttention(
+            embed_dim=config.embed_dim,
+            num_heads=config.num_attention_heads,
+            projection_size=config.embed_dim,
+            use_qkv_parallel=True,
+            qkv_backend=qkv_backend,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=True,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+            num_dummy_heads=config.num_dummy_heads,
+            qkv_bias=config.use_bias,
+            proj_bias=config.use_bias,
+        )
+        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+        self.mlp = DotsSwiGLUFFN(config, quant_config)
+        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            position_embeddings=rotary_pos_emb,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class DotsVisionTransformer(PreTrainedModel):
+    def __init__(
+        self,
+        config: DotsVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config)
+        self.config = config
+        self._update_vision_config()
+        self.spatial_merge_size = config.spatial_merge_size
+
+        self.patch_embed = DotsViTPreprocessor(config, quant_config)
+        self._init_weights(self.patch_embed.patchifier.proj)
+
+        head_dim = config.embed_dim // config.num_attention_heads
+
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        _num_hidden_layers = config.num_hidden_layers
+        self.blocks = nn.ModuleList(
+            [
+                DotsVisionBlock(
+                    config, quant_config, f"blocks.{i}", config.attn_implementation
+                )
+                for i in range(_num_hidden_layers)
+            ]
+        )
+
+        if self.config.post_norm:
+            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+        self.merger = PatchMerger(
+            dim=config.hidden_size,
+            context_dim=config.embed_dim,
+            spatial_merge_size=config.spatial_merge_size,
+            init_merger_std=self.config.init_merger_std,
+            quant_config=quant_config,
+        )
+
+        self.gradient_checkpointing = False
+
+    def _update_vision_config(self):
+        """update vision config to support tp"""
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        num_heads = self.config.num_attention_heads
+        head_dim = self.config.embed_dim // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % world_size != 0:
+            num_dummy_heads = (
+                (num_heads + world_size) // world_size
+            ) * world_size - num_heads
+
+        setattr(self.config, "head_dim", head_dim)
+        setattr(self.config, "num_dummy_heads", num_dummy_heads)
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.fc2.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.blocks[0].mlp.fc2.weight.device
+
+    def get_pos_ids_by_grid(self, grid_thw):
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+        return pos_ids
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = self.get_pos_ids_by_grid(grid_thw)
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def calc_cos_sin(self, rotary_pos_emb):
+        cos = rotary_pos_emb.cos()
+        sin = rotary_pos_emb.sin()
+        cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+        sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+        rotary_pos_emb = (cos, sin)
+        return rotary_pos_emb
+
+    def forward(
+        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=True
+    ) -> torch.Tensor:
+        if bf16:
+            hidden_states = hidden_states.bfloat16()
+        hidden_states = self.patch_embed(hidden_states, grid_thw)
+
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.calc_cos_sin(rotary_pos_emb)
+
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(
+            dim=0,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        for blk in self.blocks:
+            hidden_states = blk(
+                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            )
+
+        if self.config.post_norm:
+            hidden_states = self.post_trunk_norm(hidden_states)
+
+        hidden_states = self.merger(hidden_states)
+        return hidden_states
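
For orientation, here is the `cu_seqlens` computation from `DotsVisionTransformer.forward` run standalone on a made-up `grid_thw` (the values are illustrative): each image contributes `t` sequences of `h * w` patch tokens, and the cumulative boundaries drive the varlen attention in each block.

```python
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]])  # per-image (t, h, w), made up
cu_seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]  # [24, 4, 4] tokens per sequence
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)  # tensor([ 0, 24, 28, 32], dtype=torch.int32)
```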
sglang/srt/models/ernie4.py
@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
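
The same `get_moe_impl_class()` → `get_moe_impl_class(quant_config)` change recurs below in glm4_moe.py and gpt_oss.py: the MoE implementation class is now resolved from the quantization config rather than from global state alone. A hypothetical sketch of the shape of such a factory; this is not sglang's actual logic (the real function lives in sglang.srt.layers.moe and also consults server args):

```python
from typing import Optional, Type

class FusedMoE:  # stand-in for the default expert implementation
    pass

class QuantAwareFusedMoE(FusedMoE):  # stand-in for a quantization-aware variant
    pass

def get_moe_impl_class(quant_config: Optional[object]) -> Type[FusedMoE]:
    # Assumed behavior: branch on the quant config instead of a global flag.
    return QuantAwareFusedMoE if quant_config is not None else FusedMoE

experts_cls = get_moe_impl_class(quant_config=None)
print(experts_cls.__name__)  # FusedMoE
```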
sglang/srt/models/gemma3n_mm.py
@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def should_apply_lora(self, module_name: str) -> bool:
         return bool(self.lora_pattern.match(module_name))
 
-    def get_hidden_dim(self, module_name):
+    def get_hidden_dim(self, module_name, layer_idx):
         # return input_dim, output_dim
         if module_name == "qkv_proj":
             return (
sglang/srt/models/glm4_moe.py
@@ -153,7 +153,13 @@ class Glm4MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        should_allreduce_fusion=False,
+        gemm_output_zero_allocator: BumpAllocator = None,
+    ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
@@ -423,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.n_routed_experts
             + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
@@ -501,6 +507,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
@@ -543,6 +550,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -666,6 +674,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
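
The glm4_moe forward paths above now thread a `gemm_output_zero_allocator: BumpAllocator = None` argument through the layers. As a rough illustration of the bump-allocation pattern behind such a parameter, here is a minimal sketch with assumed semantics; sglang's actual BumpAllocator API may differ:

```python
import torch

class BumpAllocator:
    """Hands out slices of one preallocated zeroed buffer; reset between batches."""

    def __init__(self, buffer_size: int, dtype: torch.dtype, device: str):
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, size: int) -> torch.Tensor:
        # Hand out the next `size` elements; there is no per-allocation free.
        assert self._offset + size <= self._buffer.numel(), "buffer exhausted"
        out = self._buffer[self._offset : self._offset + size]
        self._offset += size
        return out

    def reset(self) -> None:
        self._offset = 0  # reuse the whole buffer for the next forward pass

alloc = BumpAllocator(1024, dtype=torch.float32, device="cpu")
gemm_out = alloc.allocate(256)  # a zero-initialized view into the shared buffer
alloc.reset()
```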
sglang/srt/models/glm4v.py
@@ -93,9 +93,8 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock):
             quant_config=quant_config,
             prefix=prefix,
             num_dummy_heads=config.num_dummy_heads,
+            rms_norm_eps=config.rms_norm_eps,
         )
-        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.mlp = Glm4vVisionMLP(
             config.hidden_size,
@@ -498,6 +497,9 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = torch.cat(
             [item.feature.squeeze(0) for item in items], dim=0
sglang/srt/models/glm4v_moe.py
@@ -74,6 +74,9 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def determine_num_fused_shared_experts(
         self, architecture: str = "Glm4MoeForCausalLM"
     ):
sglang/srt/models/gpt_oss.py
@@ -121,7 +121,7 @@ class GptOssSparseMoeBlock(nn.Module):
         )
 
         self.top_k = config.num_experts_per_tok
-        experts_type = get_moe_impl_class()
+        experts_type = get_moe_impl_class(quant_config)
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
             quant_config_name = (
sglang/srt/models/internvl.py
@@ -26,8 +26,10 @@ from sglang.srt.managers.schedule_batch import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
+from sglang.srt.models.gpt_oss import GptOssForCausalLM
 from sglang.srt.models.internlm2 import InternLM2ForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM
 from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger
 
@@ -445,6 +447,14 @@ class InternVLChatModel(nn.Module):
             self.language_model = Qwen3MoeForCausalLM(
                 config=config.llm_config, quant_config=quant_config
             )
+        elif config.llm_config.architectures[0] == "GptOssForCausalLM":
+            self.language_model = GptOssForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
+        elif config.llm_config.architectures[0] == "Qwen3ForCausalLM":
+            self.language_model = Qwen3ForCausalLM(
+                config=config.llm_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.llm_config.architectures[0]} is not implemented."
@@ -577,6 +587,15 @@ class InternVLChatModel(nn.Module):
                 ckpt_up_proj_name="up_proj",
                 num_experts=self.config.num_experts,
             )
+        elif "Qwen3ForCausalLM" in self.config.llm_config.architectures:
+            stacked_params_mapping = [
+                # (param_name, shard_name, shard_id)
+                ("qkv_proj", "q_proj", "q"),
+                ("qkv_proj", "k_proj", "k"),
+                ("qkv_proj", "v_proj", "v"),
+                ("gate_up_proj", "gate_proj", 0),
+                ("gate_up_proj", "up_proj", 1),
+            ]
 
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
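
The `stacked_params_mapping` added above tells the weight loader which checkpoint shards fold into fused parameters (q/k/v into `qkv_proj`, gate/up into `gate_up_proj`). A toy rendition of how such a mapping is typically consumed; the real loop in `load_weights` handles many more cases:

```python
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def resolve(ckpt_name: str):
    # Rewrite a checkpoint name to its fused parameter plus a shard id, so the
    # loader can copy the tensor into the right slice of the fused weight.
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in ckpt_name:
            return ckpt_name.replace(shard_name, param_name), shard_id
    return ckpt_name, None  # not a stacked shard: load directly

print(resolve("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')
```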
sglang/srt/models/internvl.py (continued)
@@ -661,6 +680,15 @@ class InternVLChatModel(nn.Module):
 
             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
+        # Skip params that are created by quantization wrappers and are not expected in the ckpt
+        _quant_only_fragments = (
+            "weight_scale",  # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
+        )
+        unloaded_params = {
+            n
+            for n in unloaded_params
+            if not any(frag in n for frag in _quant_only_fragments)
+        }
         if unloaded_params:
             raise RuntimeError(
                 f"Some weights are not initialized from checkpoints: {unloaded_params}"