sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
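A comparison like this can be reproduced locally. Below is a minimal sketch, assuming both wheels have already been fetched from PyPI (for example with "pip download sglang==0.5.2rc2 --no-deps -d old/" and "pip download sglang==0.5.3rc0 --no-deps -d new/"); the wheel paths and output format are illustrative, not how the registry page itself was generated.

# Illustrative sketch only: diff the Python sources of two locally downloaded wheels.
import difflib
import sys
import zipfile


def python_sources(wheel_path: str) -> dict:
    """Return {member name: list of text lines} for every .py file in the wheel (a zip archive)."""
    sources = {}
    with zipfile.ZipFile(wheel_path) as wheel:
        for name in wheel.namelist():
            if name.endswith(".py"):
                sources[name] = wheel.read(name).decode("utf-8", errors="replace").splitlines()
    return sources


def diff_wheels(old_wheel: str, new_wheel: str) -> None:
    old, new = python_sources(old_wheel), python_sources(new_wheel)
    for name in sorted(set(old) | set(new)):
        hunks = list(
            difflib.unified_diff(
                old.get(name, []),
                new.get(name, []),
                fromfile=f"old/{name}",
                tofile=f"new/{name}",
                lineterm="",
            )
        )
        if hunks:
            sys.stdout.write("\n".join(hunks) + "\n")


if __name__ == "__main__":
    diff_wheels(
        "old/sglang-0.5.2rc2-py3-none-any.whl",  # assumed download locations
        "new/sglang-0.5.3rc0-py3-none-any.whl",
    )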
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/bailing_moe_nextn.py
@@ -0,0 +1,168 @@
+ # coding=utf-8
+ # Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """ SGLang BailingMoENextN model."""
+ import logging
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.dp_attention import is_dp_attention_enabled
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
+ from sglang.srt.utils import add_prefix
+
+ LoraConfig = None
+ logger = logging.getLogger(__name__)
+
+
+ class BailingMoEModelNextN(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+             logger.warning(
+                 "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+             )
+             quant_config = None
+
+         self.vocab_size = config.vocab_size
+
+         self.word_embeddings = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+             enable_tp=not is_dp_attention_enabled(),
+             prefix=add_prefix("word_embeddings", prefix),
+         )
+
+         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+
+         self.decoder = BailingMoEBlock(
+             config,
+             0,
+             quant_config=quant_config,
+             # is_nextn=True,
+             prefix=add_prefix("decoder", prefix),
+         )
+
+         self.shared_head = nn.Module()
+         self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+
+         if input_embeds is None:
+             hidden_states = self.word_embeddings(input_ids)
+         else:
+             hidden_states = input_embeds
+
+         if hidden_states.shape[0] > 0:
+             hidden_states = self.eh_proj(
+                 torch.cat(
+                     (
+                         self.enorm(hidden_states),
+                         self.hnorm(forward_batch.spec_info.hidden_states),
+                     ),
+                     dim=-1,
+                 )
+             )
+
+         residual = None
+         hidden_states, residual = self.decoder(
+             positions, hidden_states, forward_batch, residual
+         )
+
+         if not forward_batch.forward_mode.is_idle():
+             if residual is not None:
+                 hidden_states, _ = self.final_layernorm(hidden_states, residual)
+             else:
+                 hidden_states = self.final_layernorm(hidden_states)
+
+         return hidden_states
+
+
+ class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         nn.Module.__init__(self)
+         self.config = config
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.quant_config = quant_config
+         if hasattr(self, "determine_num_fused_shared_experts"):
+             # Asystem has determine_num_fused_shared_experts but theta does not.
+             self.determine_num_fused_shared_experts("BailingMoeForCausalLMNextN")
+
+         self.model = BailingMoEModelNextN(
+             config, quant_config, prefix=add_prefix("model", prefix)
+         )
+         self.lm_head = ParallelLMHead(
+             config.vocab_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("model.shared_head.head", prefix),
+             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+         )
+         self.logits_processor = LogitsProcessor(config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, forward_batch)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head, forward_batch
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         super().load_weights(weights, is_nextn=True)
+
+
+ EntryClass = [BailingMoeForCausalLMNextN]
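The core of the NextN (multi-token prediction) block added above is the eh_proj fusion: the normalized current-token embeddings and the normalized hidden states carried in forward_batch.spec_info are concatenated and projected back to hidden_size before entering the single decoder block. A minimal, self-contained sketch of that step follows; hidden_size and token count are made-up values, and torch.nn.RMSNorm from recent PyTorch stands in for the model's own RMSNorm.

# Illustrative sketch of the NextN eh_proj step; sizes are illustrative only.
import torch
from torch import nn

hidden_size, num_tokens = 8, 4
enorm = nn.RMSNorm(hidden_size)   # normalizes the current token embeddings
hnorm = nn.RMSNorm(hidden_size)   # normalizes the hidden states from the previous step
eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

token_embeds = torch.randn(num_tokens, hidden_size)  # stands in for word_embeddings(input_ids)
prev_hidden = torch.randn(num_tokens, hidden_size)   # stands in for forward_batch.spec_info.hidden_states

# Concatenate along the feature dimension ([T, 2H]) and project back to [T, H].
fused = eh_proj(torch.cat((enorm(token_embeds), hnorm(prev_hidden)), dim=-1))
print(fused.shape)  # torch.Size([4, 8])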
sglang/srt/models/deepseek_v2.py
@@ -65,10 +65,11 @@ from sglang.srt.layers.moe import (
      get_deepep_mode,
      get_moe_a2a_backend,
      should_use_flashinfer_cutlass_moe_fp4_allgather,
+     should_use_flashinfer_trtllm_moe,
  )
  from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
  from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
- from sglang.srt.layers.moe.topk import TopK
+ from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8_kernel import (
@@ -151,6 +152,7 @@ if _is_cuda:
      from sgl_kernel import (
          awq_dequantize,
          bmm_fp8,
+         concat_mla_k,
          dsv3_fused_a_gemm,
          dsv3_router_gemm,
          merge_state_v2,
@@ -246,7 +248,11 @@ class DeepseekV2MLP(nn.Module):
          if (self.tp_size == 1) and x.shape[0] == 0:
              return x

-         if gemm_output_zero_allocator != None and x.shape[0] <= 256:
+         if (
+             gemm_output_zero_allocator is not None
+             and x.shape[0] <= 256
+             and self.gate_up_proj.weight.dtype == torch.uint8
+         ):
              y = gemm_output_zero_allocator.allocate(
                  x.shape[0] * self.gate_up_proj.output_size_per_partition
              ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
@@ -264,6 +270,7 @@ class MoEGate(nn.Module):
      def __init__(
          self,
          config,
+         quant_config,
          prefix: str = "",
          is_nextn: bool = False,
      ):
@@ -273,8 +280,15 @@ class MoEGate(nn.Module):
              torch.empty((config.n_routed_experts, config.hidden_size))
          )
          if config.topk_method == "noaux_tc":
+             correction_bias_dtype = (
+                 torch.bfloat16
+                 if quant_config is not None
+                 and quant_config.get_name() == "modelopt_fp4"
+                 and should_use_flashinfer_trtllm_moe()
+                 else torch.float32
+             )
              self.e_score_correction_bias = nn.Parameter(
-                 torch.empty((config.n_routed_experts), dtype=torch.float32)
+                 torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
              )
          else:
              self.e_score_correction_bias = None
@@ -299,7 +313,9 @@ class MoEGate(nn.Module):
              and _device_sm >= 90
          ):
              # router gemm output float32
-             logits = dsv3_router_gemm(hidden_states, self.weight)
+             logits = dsv3_router_gemm(
+                 hidden_states, self.weight, out_dtype=torch.float32
+             )
          elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
              logits = aiter_dsv3_router_gemm(
                  hidden_states, self.weight, gemm_output_zero_allocator
@@ -347,7 +363,10 @@ class DeepseekV2MoE(nn.Module):
          )

          self.gate = MoEGate(
-             config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
+             config=config,
+             quant_config=quant_config,
+             prefix=add_prefix("gate", prefix),
+             is_nextn=is_nextn,
          )

          self.experts = get_moe_impl_class(quant_config)(
@@ -372,9 +391,12 @@ class DeepseekV2MoE(nn.Module):
              num_fused_shared_experts=self.num_fused_shared_experts,
              topk_group=config.topk_group,
              correction_bias=self.gate.e_score_correction_bias,
+             quant_config=quant_config,
              routed_scaling_factor=self.routed_scaling_factor,
              apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
-             force_topk=quant_config is None,
+             # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+             # and requires the output format to be standard. We use quant_config to determine the output format.
+             output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
          )

          self.shared_experts_is_int8 = False
@@ -661,10 +683,14 @@ class DeepseekV2MoE(nn.Module):

          if shared_output is not None:
              x = shared_output
-             x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+             if self.experts.should_fuse_routed_scaling_factor_in_topk():
+                 x.add_(final_hidden_states)
+             else:
+                 x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
              final_hidden_states = x
          else:
-             final_hidden_states *= self.routed_scaling_factor
+             if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+                 final_hidden_states *= self.routed_scaling_factor

          return final_hidden_states

@@ -1033,6 +1059,15 @@ class DeepseekV2AttentionMLA(nn.Module):
          # Determine attention backend used by current forward batch
          if forward_batch.forward_mode.is_decode_or_idle():
              attention_backend = global_server_args_dict["decode_attention_backend"]
+         elif (
+             forward_batch.forward_mode.is_target_verify()
+             or forward_batch.forward_mode.is_draft_extend()
+         ):
+             # Use the specified backend for speculative operations (both verify and draft extend)
+             if global_server_args_dict["speculative_attention_mode"] == "decode":
+                 attention_backend = global_server_args_dict["decode_attention_backend"]
+             else:  # default to prefill
+                 attention_backend = global_server_args_dict["prefill_attention_backend"]
          else:
              attention_backend = global_server_args_dict["prefill_attention_backend"]
          self.current_attention_backend = attention_backend
@@ -1050,7 +1085,6 @@ class DeepseekV2AttentionMLA(nn.Module):
              attention_backend == "flashinfer"
              or attention_backend == "fa3"
              or attention_backend == "flashmla"
-             or attention_backend == "trtllm_mla"
              or attention_backend == "cutlass_mla"
          ):
              # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1063,6 +1097,8 @@ class DeepseekV2AttentionMLA(nn.Module):
              disable_ragged = (
                  attention_backend == "flashinfer" or attention_backend == "flashmla"
              ) and self.flashinfer_mla_disable_ragged
+
+             original_mode = getattr(forward_batch, "_original_forward_mode", None)
              if (
                  not disable_ragged
                  and forward_batch.forward_mode.is_extend()
@@ -1075,6 +1111,40 @@ class DeepseekV2AttentionMLA(nn.Module):
                      )
                      or sum_extend_prefix_lens == 0
                  )
+                 # TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
+                 # dp case. Redirect to mla kernel as a workaround.
+                 # Tracked by https://github.com/sgl-project/sglang/issues/9806.
+                 and not (
+                     original_mode is not None
+                     and original_mode.is_decode()
+                     and is_sm100_supported()
+                     and self.current_attention_backend in ("cutlass_mla", "flashinfer")
+                 )
+             ):
+                 return AttnForwardMethod.MHA_CHUNKED_KV
+             else:
+                 return _dispatch_mla_subtype()
+         elif attention_backend == "trtllm_mla":
+             original_mode = getattr(forward_batch, "_original_forward_mode", None)
+             if (
+                 original_mode is not None
+                 and original_mode.is_decode()
+                 and is_sm100_supported()
+             ):
+                 return _dispatch_mla_subtype()
+
+             sum_extend_prefix_lens = (
+                 sum(forward_batch.extend_prefix_lens_cpu)
+                 if forward_batch.extend_prefix_lens_cpu is not None
+                 else 0
+             )
+             if (
+                 forward_batch.forward_mode.is_extend()
+                 and not forward_batch.forward_mode.is_target_verify()
+                 and not forward_batch.forward_mode.is_draft_extend()
+                 and (
+                     not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+                 )
              ):
                  return AttnForwardMethod.MHA_CHUNKED_KV
              else:
@@ -1235,8 +1305,18 @@ class DeepseekV2AttentionMLA(nn.Module):
          q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
          q[..., self.qk_nope_head_dim :] = q_pe
          k = torch.empty_like(q)
-         k[..., : self.qk_nope_head_dim] = k_nope
-         k[..., self.qk_nope_head_dim :] = k_pe
+
+         # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+         if (
+             _is_cuda
+             and (self.num_local_heads == 128)
+             and (self.qk_nope_head_dim == 128)
+             and (self.qk_rope_head_dim == 64)
+         ):
+             concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+         else:
+             k[..., : self.qk_nope_head_dim] = k_nope
+             k[..., self.qk_nope_head_dim :] = k_pe

          if not _is_npu:
              latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
@@ -1998,7 +2078,10 @@ class DeepseekV2DecoderLayer(nn.Module):
          quant_format = (
              "mxfp4"
              if _is_gfx95_supported
-             and self.self_attn.fused_qkv_a_proj_with_mqa.weight == torch.uint8
+             and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+             and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+             is not None
+             and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
              else ""
          )

@@ -2170,8 +2253,15 @@ class DeepseekV2Model(nn.Module):
                  [
                      "w13_weight",
                      "w2_weight",
-                     "w13_blockscale_swizzled",
-                     "w2_blockscale_swizzled",
+                     # only for nvfp4
+                     *(
+                         [
+                             "w13_blockscale_swizzled",
+                             "w2_blockscale_swizzled",
+                         ]
+                         if hasattr(module, "w13_blockscale_swizzled")
+                         else []
+                     ),
                  ]
                  if isinstance(module, FusedMoE)
                  else []
@@ -2553,7 +2643,11 @@ class DeepseekV2ForCausalLM(nn.Module):
                      0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)

-                 if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
+                 if (
+                     _use_aiter_gfx95
+                     and self.quant_config is not None
+                     and self.quant_config.get_name() == "quark"
+                 ):
                      w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
                          quark_post_load_weights(self_attn, w, "mxfp4")
                      )
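One of the hunks above swaps the plain slice assignment that assembles the MLA key tensor for a fused concat_mla_k kernel when the DeepSeek V3/R1 head configuration is detected. The sketch below shows what the unfused fallback path computes; the shapes are illustrative assumptions, not the real head counts, and the point is only that the "nope" part and the shared rope part are written into one per-head buffer, with the rope part broadcast across heads.

# Illustrative sketch of the unfused k assembly; dimensions are made-up examples.
import torch

num_tokens, num_heads, nope_dim, rope_dim = 3, 2, 4, 2
k_nope = torch.randn(num_tokens, num_heads, nope_dim)  # per-head, non-rotary part
k_pe = torch.randn(num_tokens, 1, rope_dim)            # rotary part, shared across heads

k = torch.empty(num_tokens, num_heads, nope_dim + rope_dim)
k[..., :nope_dim] = k_nope
k[..., nope_dim:] = k_pe  # broadcasts over the head dimension

assert torch.equal(k[..., nope_dim:], k_pe.expand(num_tokens, num_heads, rope_dim))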
sglang/srt/models/dots_vlm.py
@@ -0,0 +1,174 @@
+ # Copyright 2025 The RedNote HiLab team.
+ # Copyright 2025 The SGLang team.
+ #
+ # This code is based on the DeepseekVL2ForCausalLM and DotsVisionTransformer
+ # implementation in this library.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Inference-only Dots-VL model compatible with HuggingFace weights."""
+
+ from typing import Iterable, List, Optional, Tuple
+
+ import torch
+ from torch import nn
+
+ from sglang.srt.configs.dots_vlm import DotsVLMConfig
+ from sglang.srt.distributed import parallel_state
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.managers.mm_utils import (
+     MultiModalityDataPaddingPatternMultimodalTokens,
+     general_mm_embed_routine,
+ )
+ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
+
+ from .dots_vlm_vit import DotsVisionTransformer
+
+
+ class DotsVLMForCausalLM(nn.Module):
+     """DotsVLM model for sglang inference"""
+
+     def __init__(
+         self, config: DotsVLMConfig, quant_config: Optional[QuantizationConfig] = None
+     ) -> None:
+         super().__init__()
+
+         self.config = config
+         self.image_token_id = config.im_span_id
+         self.video_token_id = config.video_span_id
+
+         self.language_model = DeepseekV2ForCausalLM(
+             config.language_config, quant_config
+         )
+
+         # Initialize vision tower (matching transformers naming for weight compatibility)
+         self.vision_tower = DotsVisionTransformer(config.vision_config)
+
+     def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+         """pad attn qkv weights for dummy heads"""
+         num_dummy_heads = self.config.vision_config.num_dummy_heads
+         if num_dummy_heads == 0:
+             return loaded_weight
+         head_dim = self.config.vision_config.head_dim
+
+         if "attn.qkv_proj" in name:
+             wq, wk, wv = loaded_weight.chunk(3, dim=0)
+             if name.endswith(".weight"):
+                 dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+             elif name.endswith(".bias"):
+                 dummy_shape = [num_dummy_heads, head_dim]
+             else:
+                 raise RuntimeError(f"Unsupported weight with name={name}")
+             pad_func = lambda x: torch.cat(
+                 [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+             ).flatten(0, 1)
+             wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+             loaded_weight = torch.cat([wq, wk, wv], dim=0)
+         if "attn.proj.weight" in name:
+             padded_weight = loaded_weight.new_zeros(
+                 loaded_weight.shape[0], head_dim * num_dummy_heads
+             )
+             loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+         if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+             padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+             loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+         return loaded_weight
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         """Load weights for the model, separating vision and language weights"""
+         weights = list(weights)
+
+         # Separate vision tower weights and language model weights
+         vision_weights = []
+         language_weights = []
+
+         for name, loaded_weight in weights:
+             if name.startswith("vision_tower."):
+                 vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                 vision_weights.append((vision_name, loaded_weight))
+             else:
+                 # All other weights go to language model
+                 language_weights.append((name, loaded_weight))
+
+         # Load vision tower weights
+         vision_state_dict = dict(vision_weights)
+         params_dict = dict(self.named_parameters(remove_duplicate=False))
+         for name, loaded_weight in vision_state_dict.items():
+             if name not in params_dict:
+                 raise ValueError(f"Weight {name} not found in params_dict")
+             param = params_dict[name]
+             weight_loader = getattr(param, "weight_loader", default_weight_loader)
+             loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+             weight_loader(param, loaded_weight)
+
+         # Load language model weights
+         if language_weights:
+             self.language_model.load_weights(language_weights)
+
+     @classmethod
+     def get_model_config_for_expert_location(cls, config):
+         return DeepseekV2ForCausalLM.get_model_config_for_expert_location(config)
+
+     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+         """Pad input_ids with multimodal tokens"""
+         # Get image token ID for padding pattern
+         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+         padded_input_ids = pattern.pad_input_tokens(input_ids, mm_inputs)
+         return padded_input_ids
+
+     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+         # Extract pixel values and grid information (following reference pattern)
+         pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+             self.vision_tower.dtype
+         )
+         image_grid_thw = torch.concat(
+             [item.image_grid_thw for item in items], dim=0
+         ).to(self.vision_tower.device)
+
+         # Add dimension checks like in reference code
+         assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
+         assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"
+
+         # Process through vision tower
+         image_embeds = self.vision_tower(pixel_values, image_grid_thw)
+
+         # Ensure consistent dtype for FlashInfer compatibility
+         # Force bfloat16 to match model's expected dtype
+         if image_embeds.dtype != torch.bfloat16 and hasattr(
+             self.language_model.model, "embed_tokens"
+         ):
+             target_dtype = self.language_model.model.embed_tokens.weight.dtype
+             image_embeds = image_embeds.to(target_dtype)
+
+         return image_embeds
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         **kwargs: object,
+     ) -> torch.Tensor:
+         hidden_states = general_mm_embed_routine(
+             input_ids=input_ids,
+             positions=positions,
+             forward_batch=forward_batch,
+             multimodal_model=self,
+             language_model=self.language_model,
+         )
+         return hidden_states
+
+
+ EntryClass = [DotsVLMForCausalLM]
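The _pad_vit_attn_dummy_heads helper above zero-pads fused qkv weights so the vision tower can run with extra "dummy" attention heads. A minimal sketch of the .weight case for a single qkv chunk follows; head sizes are illustrative assumptions. The weight is viewed as [num_heads, head_dim, in_features], zero-filled heads are appended, and the result is flattened back to a 2-D weight.

# Illustrative sketch of dummy-head padding for one qkv chunk; sizes are made up.
import torch

head_dim, num_heads, num_dummy_heads, in_features = 4, 2, 1, 6
wq = torch.randn(num_heads * head_dim, in_features)  # one of the wq/wk/wv chunks of a fused qkv .weight


def pad_dummy_heads(w: torch.Tensor) -> torch.Tensor:
    dummy = w.new_zeros(num_dummy_heads, head_dim, w.shape[-1])  # zero-filled extra heads
    return torch.cat([w.unflatten(0, (-1, head_dim)), dummy], dim=0).flatten(0, 1)


padded = pad_dummy_heads(wq)
print(padded.shape)  # torch.Size([12, 6]) == ((num_heads + num_dummy_heads) * head_dim, in_features)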