sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py
CHANGED
@@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
     column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+        ".gate_proj": (".gate_up_proj", 0),
+        ".up_proj": (".gate_up_proj", 1),
     }
 
     def __init__(
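
The only change in this hunk is that the bitsandbytes stacked-parameter mapping keys and values now carry a leading dot, so substring matching anchors on the module-path separator; for example, ".up_proj" can no longer match inside a name that already contains "gate_up_proj". Below is a minimal sketch of how such a mapping is typically consumed during weight loading; the resolve helper and the example weight names are illustrative, not sglang's actual loader.

bitsandbytes_stacked_params_mapping = {
    # shard_name, stacked_name, shard index
    ".q_proj": (".qkv_proj", 0),
    ".k_proj": (".qkv_proj", 1),
    ".v_proj": (".qkv_proj", 2),
    ".gate_proj": (".gate_up_proj", 0),
    ".up_proj": (".gate_up_proj", 1),
}


def resolve(weight_name: str):
    """Map a checkpoint weight name to (stacked parameter name, shard index)."""
    for shard_name, (stacked_name, shard_id) in bitsandbytes_stacked_params_mapping.items():
        if shard_name in weight_name:
            return weight_name.replace(shard_name, stacked_name), shard_id
    return weight_name, None


print(resolve("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 0)
print(resolve("model.layers.0.mlp.gate_up_proj.weight"))
# ('model.layers.0.mlp.gate_up_proj.weight', None) -- already stacked; an
# undotted "up_proj" key would have wrongly matched inside "gate_up_proj".
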
sglang/srt/models/minicpm3.py
CHANGED
@@ -40,9 +40,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix,
+from sglang.srt.utils import add_prefix, is_cuda
 
-if
+if is_cuda():
     from sgl_kernel import bmm_fp8
 
 
@@ -93,158 +93,6 @@ def input_to_float8(x, dtype=torch.float8_e4m3fn):
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
-class MiniCPM3Attention(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        hidden_size: int,
-        num_heads: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        v_head_dim: int,
-        q_lora_rank: int,
-        kv_lora_rank: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        max_position_embeddings: int = 8192,
-        quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.layer_id = layer_id
-        self.hidden_size = hidden_size
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
-        self.v_head_dim = v_head_dim
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-        self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
-        self.scaling = self.qk_head_dim**-0.5
-        self.rope_theta = rope_theta
-        self.max_position_embeddings = max_position_embeddings
-
-        if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
-            )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_b_proj", prefix),
-            )
-        else:
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_proj", prefix),
-            )
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
-        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_b_proj", prefix),
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("o_proj", prefix),
-        )
-        self.rotary_emb = get_rope(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-        )
-
-        # TODO support head_size 96
-        self.attn = RadixAttention(
-            self.num_local_heads,
-            128,
-            self.scaling,
-            num_kv_heads=self.num_local_heads,
-            layer_id=layer_id,
-            quant_config=quant_config,
-            prefix=add_prefix("attn", prefix),
-        )
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
-        if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-        original_shapes = [q_pe.shape, k_pe.shape]
-        q_pe, k_pe = self.rotary_emb(
-            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
-        )
-        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-        q = torch.nn.functional.pad(q, [0, 128 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 128
-        )
-        k = torch.nn.functional.pad(k, [0, 128 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 128
-        )
-        v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
-            -1, self.num_local_heads * 128
-        )
-        attn_output = self.attn(q, k, v, forward_batch)
-        attn_output = attn_output.view(-1, self.num_local_heads, 128)[
-            ..., : self.v_head_dim
-        ].reshape(-1, self.num_local_heads * self.v_head_dim)
-        output, _ = self.o_proj(attn_output)
-        return output
-
-
 class MiniCPM3AttentionMLA(nn.Module):
 
     def __init__(
@@ -434,44 +282,25 @@ class MiniCPM3DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-
-
-
-
-
-
-
-
-            q_lora_rank
-
-
-
-
-
-
-
-
-
-
-        else:
-            self.self_attn = MiniCPM3Attention(
-                config=config,
-                hidden_size=self.hidden_size,
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=self.hidden_size // config.num_attention_heads,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                prefix=add_prefix("self_attn", prefix),
-            )
+        self.self_attn = MiniCPM3AttentionMLA(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            qk_nope_head_dim=config.qk_nope_head_dim,
+            qk_rope_head_dim=config.qk_rope_head_dim,
+            v_head_dim=self.hidden_size // config.num_attention_heads,
+            q_lora_rank=(
+                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+            ),
+            kv_lora_rank=config.kv_lora_rank,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            layer_id=layer_id,
+            prefix=add_prefix("self_attn", prefix),
+        )
+
         self.mlp = MiniCPM3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
@@ -674,17 +503,16 @@ class MiniCPM3ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-
-
-
-
-
-
-
-
-
-
-            del self_attn.kv_b_proj
+        for layer_id in range(self.config.num_hidden_layers):
+            self_attn = self.model.layers[layer_id].self_attn
+            w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+            if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+            del self_attn.kv_b_proj
 
 
 EntryClass = MiniCPM3ForCausalLM
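
The rewritten load_weights tail pre-splits each layer's kv_b_proj weight into the w_kc / w_vc blocks consumed by the MLA attention path, then deletes the original projection. Below is a standalone sketch of the tensor shapes involved; the dimensions are made-up example values, not MiniCPM3's real config.

import torch

# Hypothetical example sizes, chosen only to show the reshaping.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 64, 64, 32

# kv_b_proj projects the rank-kv_lora_rank latent up to per-head nope-K and V.
kv_b_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

w_kc, w_vc = kv_b_weight.unflatten(
    0, (-1, qk_nope_head_dim + v_head_dim)
).split([qk_nope_head_dim, v_head_dim], dim=1)

print(w_kc.shape)  # torch.Size([4, 64, 32]) -> (num_heads, qk_nope_head_dim, kv_lora_rank)
print(w_vc.shape)  # torch.Size([4, 64, 32]) -> (num_heads, v_head_dim, kv_lora_rank)

The transpose/contiguous calls in the diff then arrange these blocks in the layout the attention backend expects; the split itself is what the shapes above show.
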
sglang/srt/models/minicpmo.py
CHANGED
@@ -25,7 +25,7 @@ import torch.nn.functional as F
 import torch.nn.utils.parametrize as P
 import torch.types
 from torch import nn
-from torch.nn.utils import
+from torch.nn.utils import parametrizations
 from tqdm import tqdm
 from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
@@ -585,7 +585,7 @@ class ConditionalChatTTS(PreTrainedModel):
         self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
         self.head_code = nn.ModuleList(
             [
-                weight_norm(
+                parametrizations.weight_norm(
                     nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
                     name="weight",
                 )
@@ -1859,11 +1859,22 @@ class MiniCPMO(MiniCPMBaseModel):
                 # the checkpoint. Skip them.
                 continue
 
-            #
+            # For weight_norm parametrization, handle both old and new formats
            if self.config.init_tts and "tts" in name:
-
-
-
+                # Handle loading from older checkpoints with weight_g/weight_v format
+                if ".weight_g" in name or ".weight_v" in name:
+                    name = name.replace(
+                        ".weight_g", ".parametrizations.weight.original0"
+                    )
+                    name = name.replace(
+                        ".weight_v", ".parametrizations.weight.original1"
+                    )
+                elif ".weight" in name and name not in params_dict:
+                    param_name = name.replace(
+                        ".weight", ".parametrizations.weight.original0"
+                    )
+                    if param_name in params_dict:
+                        name = param_name
 
             # adapt to VisionAttention
             if "vpm" in name:
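
The renaming logic added to load_weights exists because torch.nn.utils.parametrizations.weight_norm stores its two factors under parametrizations.<name>.original0 and .original1, while checkpoints saved with the older torch.nn.utils.weight_norm used <name>_g and <name>_v. A standalone snippet (not sglang code) showing the parameter names the new API produces:

import torch
from torch import nn
from torch.nn.utils import parametrizations

layer = parametrizations.weight_norm(nn.Linear(8, 4, bias=False), name="weight")
print(sorted(name for name, _ in layer.named_parameters()))
# ['parametrizations.weight.original0', 'parametrizations.weight.original1']
# original0 is the magnitude (old weight_g), original1 the direction (old weight_v),
# which is exactly the key mapping applied to the TTS checkpoint names above.
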
sglang/srt/models/qwen2.py
CHANGED
@@ -239,6 +239,7 @@ class Qwen2Model(nn.Module):
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
     ) -> None:
         super().__init__()
         self.config = config
@@ -250,9 +251,11 @@ class Qwen2Model(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
         )
+        # Use the provided decoder layer type or default to Qwen2DecoderLayer
+        decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
         self.layers = make_layers(
             config.num_hidden_layers,
-            lambda idx, prefix:
+            lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
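
The new decoder_layer_type argument lets other models reuse Qwen2Model as a backbone while swapping in their own decoder layer, which is presumably how the newly added qwen3.py builds on it. A hedged sketch of that reuse pattern follows; MyDecoderLayer and MyModel are hypothetical names, not the actual Qwen3 classes.

from torch import nn

from sglang.srt.models.qwen2 import Qwen2Model


class MyDecoderLayer(nn.Module):
    """Stand-in decoder layer; a real model would build attention/MLP here.

    The constructor keywords mirror what the make_layers lambda passes in the
    diff above (layer_id, config, quant_config, prefix)."""

    def __init__(self, config, layer_id=0, quant_config=None, prefix=""):
        super().__init__()


class MyModel(Qwen2Model):
    """Reuses the Qwen2 backbone but builds it out of MyDecoderLayer blocks."""

    def __init__(self, config, quant_config=None, prefix=""):
        super().__init__(
            config,
            quant_config=quant_config,
            prefix=prefix,
            decoder_layer_type=MyDecoderLayer,  # the hook added in this diff
        )
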
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -47,7 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
@@ -262,8 +262,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-
-        qkv_bias = getattr(config, "qkv_bias", config.num_hidden_layers < 80)
+        qkv_bias = getattr(config, "qkv_bias", True)
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -334,6 +333,7 @@ class Qwen2MoeModel(nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -344,16 +344,17 @@ class Qwen2MoeModel(nn.Module):
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
-
-
-
-
-
-
-
-
-
-
+        # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
+        decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: decoder_layer_type(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
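
Qwen2MoeModel now builds its layers through make_layers (newly imported from sglang.srt.utils) with the same decoder_layer_type hook as the dense model. The helper below is a simplified stand-in written for illustration; sglang's real make_layers may handle prefix formatting and layer partitioning differently.

from typing import Callable

from torch import nn


def make_layers_sketch(
    num_layers: int,
    layer_fn: Callable[[int, str], nn.Module],
    prefix: str = "",
) -> nn.ModuleList:
    """Simplified stand-in for sglang.srt.utils.make_layers: build every decoder
    layer by calling a factory with the layer index and a dotted prefix."""
    return nn.ModuleList(
        [layer_fn(idx, f"{prefix}.{idx}" if prefix else str(idx)) for idx in range(num_layers)]
    )


# Usage mirrors the diff: the lambda closes over config/quant_config and
# instantiates decoder_layer_type for each index.
layers = make_layers_sketch(3, lambda idx, prefix: nn.Identity(), prefix="layers")
print(len(layers))  # 3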