sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
     column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+        ".gate_proj": (".gate_up_proj", 0),
+        ".up_proj": (".gate_up_proj", 1),
     }
 
     def __init__(
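
The mapping change above (for LlamaForCausalLM, presumably sglang/srt/models/llama.py in the file list) adds leading dots so that shard names are matched at module-name boundaries when checkpoint names are remapped onto stacked parameters. A standalone, hypothetical illustration of why the anchoring matters (plain Python, not sglang code):

# Hypothetical illustration: remapping names via naive substring replacement.
already_stacked = "model.layers.0.mlp.gate_up_proj.weight"

# Without the dots, "up_proj" also matches inside "gate_up_proj":
print(already_stacked.replace("up_proj", "gate_up_proj"))
# -> model.layers.0.mlp.gate_gate_up_proj.weight  (corrupted)

# With dotted keys, the match only triggers on a standalone module name:
print(already_stacked.replace(".up_proj.", ".gate_up_proj."))
# -> model.layers.0.mlp.gate_up_proj.weight  (unchanged, as intended)
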
@@ -40,9 +40,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import bmm_fp8
 
 
@@ -93,158 +93,6 @@ def input_to_float8(x, dtype=torch.float8_e4m3fn):
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
-class MiniCPM3Attention(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        hidden_size: int,
-        num_heads: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        v_head_dim: int,
-        q_lora_rank: int,
-        kv_lora_rank: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        max_position_embeddings: int = 8192,
-        quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.layer_id = layer_id
-        self.hidden_size = hidden_size
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
-        self.v_head_dim = v_head_dim
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-        self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
-        self.scaling = self.qk_head_dim**-0.5
-        self.rope_theta = rope_theta
-        self.max_position_embeddings = max_position_embeddings
-
-        if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
-            )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_b_proj", prefix),
-            )
-        else:
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_proj", prefix),
-            )
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
-        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_b_proj", prefix),
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("o_proj", prefix),
-        )
-        self.rotary_emb = get_rope(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-        )
-
-        # TODO support head_size 96
-        self.attn = RadixAttention(
-            self.num_local_heads,
-            128,
-            self.scaling,
-            num_kv_heads=self.num_local_heads,
-            layer_id=layer_id,
-            quant_config=quant_config,
-            prefix=add_prefix("attn", prefix),
-        )
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
-        if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-        original_shapes = [q_pe.shape, k_pe.shape]
-        q_pe, k_pe = self.rotary_emb(
-            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
-        )
-        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-        q = torch.nn.functional.pad(q, [0, 128 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 128
-        )
-        k = torch.nn.functional.pad(k, [0, 128 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 128
-        )
-        v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
-            -1, self.num_local_heads * 128
-        )
-        attn_output = self.attn(q, k, v, forward_batch)
-        attn_output = attn_output.view(-1, self.num_local_heads, 128)[
-            ..., : self.v_head_dim
-        ].reshape(-1, self.num_local_heads * self.v_head_dim)
-        output, _ = self.o_proj(attn_output)
-        return output
-
-
 class MiniCPM3AttentionMLA(nn.Module):
 
     def __init__(
@@ -434,44 +282,25 @@ class MiniCPM3DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        if not global_server_args_dict["disable_mla"]:
-            self.self_attn = MiniCPM3AttentionMLA(
-                config=config,
-                hidden_size=self.hidden_size,
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=self.hidden_size // config.num_attention_heads,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                prefix=add_prefix("self_attn", prefix),
-            )
-        else:
-            self.self_attn = MiniCPM3Attention(
-                config=config,
-                hidden_size=self.hidden_size,
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=self.hidden_size // config.num_attention_heads,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                prefix=add_prefix("self_attn", prefix),
-            )
+        self.self_attn = MiniCPM3AttentionMLA(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            qk_nope_head_dim=config.qk_nope_head_dim,
+            qk_rope_head_dim=config.qk_rope_head_dim,
+            v_head_dim=self.hidden_size // config.num_attention_heads,
+            q_lora_rank=(
+                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+            ),
+            kv_lora_rank=config.kv_lora_rank,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            layer_id=layer_id,
+            prefix=add_prefix("self_attn", prefix),
+        )
+
         self.mlp = MiniCPM3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
@@ -674,17 +503,16 @@ class MiniCPM3ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-        if not global_server_args_dict["disable_mla"]:
-            for layer_id in range(self.config.num_hidden_layers):
-                self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
-                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if hasattr(self_attn.kv_b_proj, "weight_scale"):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                del self_attn.kv_b_proj
+        for layer_id in range(self.config.num_hidden_layers):
+            self_attn = self.model.layers[layer_id].self_attn
+            w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+            if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+            del self_attn.kv_b_proj
 
 
 EntryClass = MiniCPM3ForCausalLM
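
For readers unfamiliar with the MLA weight absorption that the hunk above now applies unconditionally, here is a shape-only sketch of the kv_b_proj reshaping; the dimensions are illustrative, not MiniCPM3's real configuration:

import torch

# Illustrative dimensions only (not taken from the MiniCPM3 config).
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 8, 64, 64, 256

# kv_b_proj maps the kv_lora_rank latent to per-head no-PE K and V features.
kv_b_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Unflatten the output dim into (heads, nope + v) and split it into the two
# blocks that MLA decoding uses separately.
w_kc, w_vc = kv_b_weight.unflatten(
    0, (-1, qk_nope_head_dim + v_head_dim)
).split([qk_nope_head_dim, v_head_dim], dim=1)

print(w_kc.shape)  # torch.Size([8, 64, 256])
print(w_vc.shape)  # torch.Size([8, 64, 256])
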
@@ -25,7 +25,7 @@ import torch.nn.functional as F
 import torch.nn.utils.parametrize as P
 import torch.types
 from torch import nn
-from torch.nn.utils import weight_norm
+from torch.nn.utils import parametrizations
 from tqdm import tqdm
 from transformers import LlamaConfig, LlamaModel, PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
@@ -585,7 +585,7 @@ class ConditionalChatTTS(PreTrainedModel):
         self.emb_text = nn.Embedding(config.num_text_tokens, config.hidden_size)
         self.head_code = nn.ModuleList(
             [
-                weight_norm(
+                parametrizations.weight_norm(
                     nn.Linear(config.hidden_size, config.num_audio_tokens, bias=False),
                     name="weight",
                 )
@@ -1859,11 +1859,22 @@ class MiniCPMO(MiniCPMBaseModel):
                 # the checkpoint. Skip them.
                 continue
 
-            # adapt to parametrization
+            # For weight_norm parametrization, handle both old and new formats
             if self.config.init_tts and "tts" in name:
-                name = name.replace(".parametrizations", "")
-                name = name.replace(".weight.original0", ".weight_g")
-                name = name.replace(".weight.original1", ".weight_v")
+                # Handle loading from older checkpoints with weight_g/weight_v format
+                if ".weight_g" in name or ".weight_v" in name:
+                    name = name.replace(
+                        ".weight_g", ".parametrizations.weight.original0"
+                    )
+                    name = name.replace(
+                        ".weight_v", ".parametrizations.weight.original1"
+                    )
+                elif ".weight" in name and name not in params_dict:
+                    param_name = name.replace(
+                        ".weight", ".parametrizations.weight.original0"
+                    )
+                    if param_name in params_dict:
+                        name = param_name
 
             # adapt to VisionAttention
             if "vpm" in name:
@@ -239,6 +239,7 @@ class Qwen2Model(nn.Module):
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
     ) -> None:
         super().__init__()
         self.config = config
@@ -250,9 +251,11 @@
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
         )
+        # Use the provided decoder layer type or default to Qwen2DecoderLayer
+        decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
         self.layers = make_layers(
             config.num_hidden_layers,
-            lambda idx, prefix: Qwen2DecoderLayer(
+            lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
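
The new decoder_layer_type hook lets other models reuse Qwen2Model while swapping in their own decoder layer (the newly added qwen3.py presumably relies on this). A minimal hypothetical sketch of the pattern, not the actual Qwen3 implementation:

from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2Model


class MyDecoderLayer(Qwen2DecoderLayer):
    # Hypothetical subclass: override attention/MLP construction here as needed.
    pass


class MyModel(Qwen2Model):
    def __init__(self, config, quant_config=None, prefix=""):
        # Reuse the Qwen2 trunk but instantiate MyDecoderLayer for every layer.
        super().__init__(
            config,
            quant_config=quant_config,
            prefix=prefix,
            decoder_layer_type=MyDecoderLayer,
        )
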
@@ -47,7 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers
 
 expert_distribution_recorder = ExpertDistributionRecorder()
 
@@ -262,8 +262,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        # note: replace config.num_hidden_layers < 80 with True once its available in transformers 4.50.0
-        qkv_bias = getattr(config, "qkv_bias", config.num_hidden_layers < 80)
+        qkv_bias = getattr(config, "qkv_bias", True)
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -334,6 +333,7 @@ class Qwen2MoeModel(nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -344,16 +344,17 @@
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
-        self.layers = nn.ModuleList(
-            [
-                Qwen2MoeDecoderLayer(
-                    config,
-                    layer_id,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{layer_id}", prefix),
-                )
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
+        decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: decoder_layer_type(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
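
make_layers comes from sglang.srt.utils (see the import hunk above). Its implementation is not shown in this diff, but judging from the call sites a simplified stand-in would look roughly like the following hypothetical sketch (ignoring anything extra, such as pipeline-parallel handling, that the real helper may do):

from typing import Callable

import torch.nn as nn


def make_layers_sketch(
    num_layers: int,
    layer_fn: Callable[[int, str], nn.Module],
    prefix: str = "",
) -> nn.ModuleList:
    # Hypothetical stand-in, not sglang's make_layers: build one layer per index
    # and hand the layer factory a dotted prefix such as "layers.3".
    return nn.ModuleList(
        layer_fn(idx, f"{prefix}.{idx}" if prefix else str(idx))
        for idx in range(num_layers)
    )
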