sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/nemotron_nas.py
@@ -0,0 +1,435 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_nas.py
+
+"""Inference-only deci model compatible with HuggingFace weights."""
+from typing import Iterable, Optional, Tuple, Type, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import get_pp_group
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.utils import PPMissingLayer
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.models.llama import LlamaAttention, LlamaMLP
+from sglang.srt.utils import add_prefix, make_layers
+from sglang.utils import logger
+
+
+def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
+    # DeciLM-specific code
+    intermediate_size = int(2 * ffn_mult * n_embd / 3)
+    return _find_multiple(intermediate_size, 256)
+
+
+def _find_multiple(n: int, k: int) -> int:
+    # DeciLM-specific code
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+class DeciLMDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        block_config = config.block_configs[layer_idx]
+        self._is_no_op_attention = block_config.attention.no_op
+        self._is_no_op_ffn = block_config.ffn.no_op
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support abacusai/Smaug-72B-v0.1 with attention_bias
+        # Support internlm/internlm-7b with bias
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        # support internlm/internlm3-8b with qkv_bias
+        if hasattr(config, "qkv_bias"):
+            attention_bias = config.qkv_bias
+
+        if not self._is_no_op_attention:
+            num_kv_heads = (
+                config.num_attention_heads // block_config.attention.n_heads_in_group
+            )
+            self.self_attn = LlamaAttention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                num_kv_heads=num_kv_heads,
+                layer_id=layer_idx,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                rope_is_neox_style=rope_is_neox_style,
+                max_position_embeddings=max_position_embeddings,
+                quant_config=quant_config,
+                prefix=add_prefix("self_attn", prefix),
+                bias=attention_bias,
+            )
+            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if not self._is_no_op_ffn:
+            ffn_mult = block_config.ffn.ffn_mult
+            intermediate_size = _ffn_mult_to_intermediate_size(
+                ffn_mult, config.hidden_size
+            )
+            self.mlp = LlamaMLP(
+                hidden_size=self.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+            self.post_attention_layernorm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+
+        if self._is_no_op_attention:
+            pass
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        # Fully Connected
+        if not self._is_no_op_ffn:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
+            hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class DeciModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        layer_type: Type[DeciLMDecoderLayer] = DeciLMDecoderLayer,
+    ):
+        super().__init__()
+
+        lora_config = None
+        self.config = config
+        self.quant_config = quant_config
+        self.padding_idx = config.pad_token_id
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        vocab_size = config.vocab_size + lora_vocab
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                quant_config=quant_config,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        def get_layer(idx: int, prefix: str):
+            return layer_type(
+                config,
+                layer_idx=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            get_layer,
+            pp_rank=get_pp_group().rank_in_group,
+            pp_size=get_pp_group().world_size,
+            prefix=add_prefix("layers", prefix),
+        )
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        kv_cache_index = 0
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            if not layer._is_no_op_attention:
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+                kv_cache_index += 1
+            else:
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+
+        if not get_pp_group().is_last_rank:
+            return PPProxyTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class DeciLMForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm",
+    }
+
+    def __init__(
+        self,
+        *,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        lora_config = None
+        self.config = config
+        self.lora_config = lora_config
+
+        self.model = self._init_model(
+            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=(
+                    DEFAULT_VOCAB_PADDING_SIZE
+                    # We need bigger padding if using lora for kernel
+                    # compatibility
+                    if not lora_config
+                    else lora_config.lora_vocab_padding_size
+                ),
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return DeciModel(config=config, quant_config=quant_config, prefix=prefix)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            inputs_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if get_pp_group().is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if self.model.quant_config is not None and (
+                scale_name := self.model.quant_config.get_cache_scale(name)
+            ):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                weight_loader(param, loaded_weight)
+                continue
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+
+EntryClass = [DeciLMForCausalLM]
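Note: the FFN sizing helpers in the new module above are pure functions, so their effect is easy to check in isolation. A minimal, self-contained sketch follows; the ffn_mult and hidden-size values are hypothetical examples, not taken from any released Nemotron-NAS config.

# Standalone sketch of the DeciLM FFN sizing helpers introduced in
# sglang/srt/models/nemotron_nas.py above; values are illustrative only.
def _find_multiple(n: int, k: int) -> int:
    # Round n up to the nearest multiple of k.
    if n % k == 0:
        return n
    return n + k - (n % k)


def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
    # Each DeciLM block stores an FFN multiplier; the MLP width is
    # 2/3 * ffn_mult * hidden_size, rounded up to a multiple of 256.
    intermediate_size = int(2 * ffn_mult * n_embd / 3)
    return _find_multiple(intermediate_size, 256)


if __name__ == "__main__":
    # e.g. a hypothetical ffn_mult of 2.625 with hidden_size 4096 yields a 7168-wide MLP
    print(_ffn_mult_to_intermediate_size(2.625, 4096))  # -> 7168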
sglang/srt/models/olmoe.py
@@ -89,7 +89,6 @@ class OlmoeMoE(nn.Module):
             intermediate_size=intermediate_size,
             reduce_results=True,
             quant_config=quant_config,
-            tp_size=tp_size,
             layer_id=layer_id,
             prefix=add_prefix("experts", prefix),
         )
@@ -54,25 +54,6 @@ VISION_ENCODER_TO_PROCESSING_CONFIG = {
54
54
  }
55
55
 
56
56
 
57
- def get_navit_vision_model():
58
- vision_config = {
59
- "hidden_size": 1152,
60
- "image_size": 448,
61
- "intermediate_size": 4304,
62
- "model_type": "siglip_vision_model",
63
- "num_attention_heads": 16,
64
- "num_hidden_layers": 26, # Model is originally 27-layer, we only need the first 26 layers for feature extraction.
65
- "patch_size": 14,
66
- }
67
- model_config = SiglipVisionConfig(**vision_config)
68
-
69
- vision_model = Idefics2VisionTransformer(
70
- config=model_config, require_post_norm=False
71
- )
72
-
73
- return vision_model
74
-
75
-
76
57
  class Phi4MMImageEncoder(nn.Module):
77
58
  """Image embedding."""
78
59
 
@@ -88,8 +69,9 @@ class Phi4MMImageEncoder(nn.Module):
88
69
  # n_embed or hidden_size
89
70
  hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
90
71
  self.type_feature = "patch"
91
-
92
- self.img_processor = get_navit_vision_model()
72
+ self.img_processor = Idefics2VisionTransformer(
73
+ config=config.vision_config, require_post_norm=False
74
+ )
93
75
 
94
76
  pe_weight = self.img_processor.embeddings.position_embedding.weight
95
77
  L, D = pe_weight.size()
sglang/srt/models/qwen2.py
@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -43,7 +44,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -273,7 +273,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
sglang/srt/models/qwen2_5_vl.py
@@ -117,6 +117,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         attn_implementation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        num_dummy_heads: int = 0,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -157,6 +158,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             flatten_batch=flatten_batch,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
+            num_dummy_heads=num_dummy_heads,
         )
         self.mlp = Qwen2_5_VLMLP(
             dim,
sglang/srt/models/qwen2_moe.py
@@ -17,8 +17,6 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
 import logging
-from dataclasses import dataclass
-from enum import Enum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
@@ -31,10 +29,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.eplb.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
@@ -45,7 +40,7 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -54,8 +49,8 @@ from sglang.srt.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -107,10 +102,14 @@ class Qwen2MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
         return x
 
 
@@ -144,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
         )
 
         self.gate = ReplicatedLinear(
@@ -175,7 +166,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -193,6 +187,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -331,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
 
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()
 
         # Qwen2MoE all layers are sparse and have no nextn now
         self.is_layer_sparse = True
@@ -367,6 +361,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
        )
 
     def forward(
@@ -392,7 +387,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
 
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
@@ -420,7 +420,7 @@ class Qwen2MoeModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not global_server_args_dict["enable_dp_attention"],
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
sglang/srt/models/qwen3.py
@@ -327,8 +327,8 @@ class Qwen3ForCausalLM(nn.Module):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.get_input_embeddings()
 
     @torch.no_grad()
     def forward(