sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_moe.py ADDED
@@ -0,0 +1,423 @@
+# Adapted from qwen2_moe.py
+
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
+from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.utils import add_prefix
+
+Qwen3MoeConfig = None
+
+
+class Qwen3MoeSparseMoeBlock(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}."
+            )
+
+        self.experts = FusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
+        )
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+class Qwen3MoeAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        head_dim: Optional[int] = None,
+        rms_norm_eps: float = 1e-06,
+        attention_bias: bool = False,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= self.tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % self.tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
+        )
+
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+    def _apply_qk_norm(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        q_by_head = q.reshape(-1, self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+        k_by_head = k.reshape(-1, self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        return q, k
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self._apply_qk_norm(q, k)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Qwen3MoeDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        rms_norm_eps = config.rms_norm_eps
+        attention_bias = config.attention_bias
+        self.self_attn = Qwen3MoeAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            head_dim=head_dim,
+            rms_norm_eps=rms_norm_eps,
+            attention_bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+
+        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
+        # `mlp_only_layers` in the config.
+        mlp_only_layers = (
+            [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
+        )
+        if (layer_id not in mlp_only_layers) and (
+            config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3MoeSparseMoeBlock(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+        else:
+            self.mlp = Qwen3MoeMLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=config.intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class Qwen3MoeModel(Qwen2MoeModel):
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            config=config,
+            quant_config=quant_config,
+            prefix=prefix,
+            decoder_layer_type=Qwen3MoeDecoderLayer,
+        )
+
+
+class Qwen3MoeForCausalLM(nn.Module):
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: Qwen3MoeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen3MoeModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = Qwen3MoeForCausalLM
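The new file registers Qwen3MoeForCausalLM through EntryClass, so a checkpoint whose config declares that architecture should route to this model through the usual sglang entry points. A minimal offline-engine sketch follows; the model path is a placeholder (not a checkpoint shipped with this release), and the single-prompt return form of Engine.generate is assumed.

# Sketch only: loading a Qwen3-MoE-style checkpoint via the offline Engine API.
# "Qwen/Qwen3-MoE-placeholder" is a hypothetical path; substitute a real HF repo
# or local directory whose architecture is Qwen3MoeForCausalLM.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="Qwen/Qwen3-MoE-placeholder")
    out = llm.generate(
        "Explain expert parallelism in one sentence.",
        {"max_new_tokens": 64, "temperature": 0.0},
    )
    print(out["text"])  # single-prompt calls return one dict with a "text" field
    llm.shutdown()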
sglang/srt/models/stablelm.py CHANGED
@@ -149,6 +149,7 @@ class StablelmAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_key_value_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/xverse.py CHANGED
@@ -153,6 +153,7 @@ class XverseAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/xverse_moe.py CHANGED
@@ -252,6 +252,7 @@ class XverseAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/openai_api/adapter.py CHANGED
@@ -983,6 +983,8 @@ def v1_chat_generate_request(
                     ):
                         encoded = encoded[1:]
                     prompt_ids += encoded
+                if tokenizer_manager.model_config.is_multimodal:
+                    prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
                 stop = request.stop
                 image_data = None
                 audio_data = None
@@ -993,7 +995,8 @@ def v1_chat_generate_request(
                 image_data = conv.image_data
                 audio_data = conv.audio_data
                 modalities = conv.modalities
-                stop = conv.stop_str or []
+                stop = conv.stop_str or [] if not request.ignore_eos else []
+
                 if request.stop:
                     if isinstance(request.stop, str):
                         stop.append(request.stop)
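The two hunks above decode prompt_ids back into a prompt string for multimodal models, and, when ignore_eos is set on the request, drop the chat template's stop strings so generation can run to the token limit. A hedged client-side sketch against the OpenAI-compatible endpoint, assuming a server is already running on localhost:30000 and the served model name matches whatever the server was launched with:

# Sketch only: exercising the ignore_eos path of /v1/chat/completions.
# base_url, port, and model name are assumptions about the local deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Count from 1 to 5."}],
    max_tokens=32,
    extra_body={"ignore_eos": True},  # per the diff, template stop strings are dropped
)
print(resp.choices[0].message.content)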
sglang/srt/patch_torch.py CHANGED
@@ -14,6 +14,7 @@
 from typing import Callable, Union
 
 import torch
+from packaging import version
 from torch.multiprocessing import reductions
 
 
@@ -69,3 +70,13 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
 
 def _modify_tuple(t, index: int, modifier: Callable):
     return *t[:index], modifier(t[index]), *t[index + 1 :]
+
+
+def monkey_patch_torch_compile():
+    if version.parse(torch.__version__) < version.parse("2.8.0"):
+        # These things are cacheable by torch.compile. torch.compile just doesn't know it.
+        # This was fixed in PyTorch 2.8, but until then, we monkey patch.
+        import torch._higher_order_ops.auto_functionalize as af
+
+        af.auto_functionalized_v2._cacheable = True
+        af.auto_functionalized._cacheable = True
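The added monkey_patch_torch_compile marks the auto_functionalized higher-order ops as cacheable so torch.compile artifact caching works on PyTorch releases before 2.8. A standalone sketch of how the helper could be applied before compiling; the toy module is a stand-in, and the patch is a no-op on torch >= 2.8.0.

# Sketch only: apply the patch before any torch.compile call so graphs that
# contain auto_functionalized ops become cacheable on torch < 2.8.
import torch

from sglang.srt.patch_torch import monkey_patch_torch_compile

monkey_patch_torch_compile()

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
compiled = torch.compile(model)
print(compiled(torch.randn(4, 16)).shape)  # torch.Size([4, 16])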
sglang/srt/reasoning_parser.py CHANGED
@@ -1,4 +1,3 @@
-import re
 from typing import Dict, Tuple
 
 
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -10,12 +10,11 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 
-logger = logging.getLogger(__name__)
-
-
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class SamplingBatchInfo: