sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0

sglang/srt/models/llama.py
@@ -362,11 +362,11 @@ class LlamaForCausalLM(nn.Module):
     column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
-        "q_proj": ("qkv_proj", 0),
-        "k_proj": ("qkv_proj", 1),
-        "v_proj": ("qkv_proj", 2),
-        "gate_proj": ("gate_up_proj", 0),
-        "up_proj": ("gate_up_proj", 1),
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+        ".gate_proj": (".gate_up_proj", 0),
+        ".up_proj": (".gate_up_proj", 1),
     }

     def __init__(
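
Note on the hunk above: these mapping keys are matched as substrings of fully qualified checkpoint parameter names during bitsandbytes weight loading, so the leading dots presumably keep a key like "q_proj" from matching other modules whose names merely end in the same text. A minimal, hypothetical sketch of that kind of matching (not sglang's actual loader; the helper map_to_stacked is made up for illustration):

# Hypothetical sketch of dot-prefixed substring matching; not sglang's loader code.
bnb_mapping = {
    ".q_proj": (".qkv_proj", 0),
    ".k_proj": (".qkv_proj", 1),
    ".v_proj": (".qkv_proj", 2),
    ".gate_proj": (".gate_up_proj", 0),
    ".up_proj": (".gate_up_proj", 1),
}

def map_to_stacked(name):
    # Rewrite a checkpoint weight name to its stacked parameter name and shard index.
    for shard_key, (stacked_key, index) in bnb_mapping.items():
        if shard_key in name:
            return name.replace(shard_key, stacked_key), index
    return None

# ".q_proj" matches only the attention projection; a bare "q_proj" key would also
# hit names that merely end in "q_proj" (e.g. a hypothetical "extra_q_proj").
print(map_to_stacked("model.layers.0.self_attn.q_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 0)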

sglang/srt/models/minicpm3.py
@@ -93,158 +93,6 @@ def input_to_float8(x, dtype=torch.float8_e4m3fn):
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


-class MiniCPM3Attention(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        hidden_size: int,
-        num_heads: int,
-        qk_nope_head_dim: int,
-        qk_rope_head_dim: int,
-        v_head_dim: int,
-        q_lora_rank: int,
-        kv_lora_rank: int,
-        rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
-        max_position_embeddings: int = 8192,
-        quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.layer_id = layer_id
-        self.hidden_size = hidden_size
-        self.qk_nope_head_dim = qk_nope_head_dim
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
-        self.v_head_dim = v_head_dim
-        self.q_lora_rank = q_lora_rank
-        self.kv_lora_rank = kv_lora_rank
-        self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
-        self.scaling = self.qk_head_dim**-0.5
-        self.rope_theta = rope_theta
-        self.max_position_embeddings = max_position_embeddings
-
-        if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
-            )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_b_proj", prefix),
-            )
-        else:
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("q_proj", prefix),
-            )
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
-        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_b_proj", prefix),
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("o_proj", prefix),
-        )
-        self.rotary_emb = get_rope(
-            qk_rope_head_dim,
-            rotary_dim=qk_rope_head_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-            rope_scaling=rope_scaling,
-        )
-
-        # TODO support head_size 96
-        self.attn = RadixAttention(
-            self.num_local_heads,
-            128,
-            self.scaling,
-            num_kv_heads=self.num_local_heads,
-            layer_id=layer_id,
-            quant_config=quant_config,
-            prefix=add_prefix("attn", prefix),
-        )
-
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
-        if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
-            q = self.q_a_layernorm(q)
-            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
-        else:
-            q = self.q_proj(hidden_states)[0].view(
-                -1, self.num_local_heads, self.qk_head_dim
-            )
-        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
-        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        latent_cache = latent_cache.unsqueeze(1)
-        kv_a = self.kv_a_layernorm(kv_a.contiguous())
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-        k_pe = latent_cache[:, :, self.kv_lora_rank :]
-        original_shapes = [q_pe.shape, k_pe.shape]
-        q_pe, k_pe = self.rotary_emb(
-            positions, q_pe.reshape(q_pe.shape[0], -1), k_pe.reshape(k_pe.shape[0], -1)
-        )
-        q_pe, k_pe = q_pe.view(original_shapes[0]), k_pe.view(original_shapes[1])
-        q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
-        k[..., : self.qk_nope_head_dim] = k_nope
-        k[..., self.qk_nope_head_dim :] = k_pe
-        q = torch.nn.functional.pad(q, [0, 128 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 128
-        )
-        k = torch.nn.functional.pad(k, [0, 128 - self.qk_head_dim], value=0).view(
-            -1, self.num_local_heads * 128
-        )
-        v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
-            -1, self.num_local_heads * 128
-        )
-        attn_output = self.attn(q, k, v, forward_batch)
-        attn_output = attn_output.view(-1, self.num_local_heads, 128)[
-            ..., : self.v_head_dim
-        ].reshape(-1, self.num_local_heads * self.v_head_dim)
-        output, _ = self.o_proj(attn_output)
-        return output
-
-
 class MiniCPM3AttentionMLA(nn.Module):

     def __init__(

@@ -434,44 +282,25 @@ class MiniCPM3DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        if not global_server_args_dict["disable_mla"]:
-            self.self_attn = MiniCPM3AttentionMLA(
-                config=config,
-                hidden_size=self.hidden_size,
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=self.hidden_size // config.num_attention_heads,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                prefix=add_prefix("self_attn", prefix),
-            )
-        else:
-            self.self_attn = MiniCPM3Attention(
-                config=config,
-                hidden_size=self.hidden_size,
-                num_heads=config.num_attention_heads,
-                qk_nope_head_dim=config.qk_nope_head_dim,
-                qk_rope_head_dim=config.qk_rope_head_dim,
-                v_head_dim=self.hidden_size // config.num_attention_heads,
-                q_lora_rank=(
-                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
-                ),
-                kv_lora_rank=config.kv_lora_rank,
-                rope_theta=rope_theta,
-                rope_scaling=rope_scaling,
-                max_position_embeddings=max_position_embeddings,
-                quant_config=quant_config,
-                layer_id=layer_id,
-                prefix=add_prefix("self_attn", prefix),
-            )
+        self.self_attn = MiniCPM3AttentionMLA(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            qk_nope_head_dim=config.qk_nope_head_dim,
+            qk_rope_head_dim=config.qk_rope_head_dim,
+            v_head_dim=self.hidden_size // config.num_attention_heads,
+            q_lora_rank=(
+                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+            ),
+            kv_lora_rank=config.kv_lora_rank,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            layer_id=layer_id,
+            prefix=add_prefix("self_attn", prefix),
+        )
+
         self.mlp = MiniCPM3MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,

@@ -674,17 +503,16 @@ class MiniCPM3ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)

-        if not global_server_args_dict["disable_mla"]:
-            for layer_id in range(self.config.num_hidden_layers):
-                self_attn = self.model.layers[layer_id].self_attn
-                w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
-                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if hasattr(self_attn.kv_b_proj, "weight_scale"):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                del self_attn.kv_b_proj
+        for layer_id in range(self.config.num_hidden_layers):
+            self_attn = self.model.layers[layer_id].self_attn
+            w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+            if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+            del self_attn.kv_b_proj


 EntryClass = MiniCPM3ForCausalLM
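
For readers following the MLA weight-absorption step above, the unflatten/split on kv_b_proj.weight is easier to see with concrete shapes. A standalone sketch with made-up dimensions (illustration only, not tied to any real MiniCPM3 checkpoint):

import torch

# Hypothetical sizes for illustration only.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 8, 64, 32, 256

# kv_b_proj.weight has shape (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank).
weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# Unflatten per head, then split each head's rows into the K and V parts.
w_kc, w_vc = weight.unflatten(
    0, (-1, qk_nope_head_dim + v_head_dim)
).split([qk_nope_head_dim, v_head_dim], dim=1)

print(w_kc.shape)  # torch.Size([8, 64, 256])
print(w_vc.shape)  # torch.Size([8, 32, 256])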

sglang/srt/models/qwen2.py
@@ -239,6 +239,7 @@ class Qwen2Model(nn.Module):
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
     ) -> None:
         super().__init__()
         self.config = config
@@ -250,9 +251,11 @@
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
         )
+        # Use the provided decoder layer type or default to Qwen2DecoderLayer
+        decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
         self.layers = make_layers(
             config.num_hidden_layers,
-            lambda idx, prefix: Qwen2DecoderLayer(
+            lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
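
The decoder_layer_type hook added above is what lets Qwen2Model be reused with a different decoder layer class. Condensed from the new qwen3.py later in this diff (Qwen3DecoderLayer is defined in that file):

from sglang.srt.models.qwen2 import Qwen2Model

class Qwen3Model(Qwen2Model):
    def __init__(self, config, quant_config=None, prefix=""):
        # Reuse Qwen2Model wholesale; only the decoder layer class differs.
        super().__init__(
            config=config,
            quant_config=quant_config,
            prefix=prefix,
            decoder_layer_type=Qwen3DecoderLayer,  # defined in qwen3.py
        )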

sglang/srt/models/qwen2_moe.py
@@ -47,7 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, make_layers

 expert_distribution_recorder = ExpertDistributionRecorder()

@@ -262,8 +262,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        # note: replace config.num_hidden_layers < 80 with True once its available in transformers 4.50.0
-        qkv_bias = getattr(config, "qkv_bias", config.num_hidden_layers < 80)
+        qkv_bias = getattr(config, "qkv_bias", True)
         self.self_attn = Qwen2MoeAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -334,6 +333,7 @@ class Qwen2MoeModel(nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -344,16 +344,17 @@ class Qwen2MoeModel(nn.Module):
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
-        self.layers = nn.ModuleList(
-            [
-                Qwen2MoeDecoderLayer(
-                    config,
-                    layer_id,
-                    quant_config=quant_config,
-                    prefix=add_prefix(f"layers.{layer_id}", prefix),
-                )
-                for layer_id in range(config.num_hidden_layers)
-            ]
+        # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
+        decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: decoder_layer_type(
+                layer_id=idx,
+                config=config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)


sglang/srt/models/qwen3.py (new file)
@@ -0,0 +1,335 @@
+# Adapted from qwen2.py
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
+from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.utils import add_prefix
+
+Qwen3Config = None
+
+
+class Qwen3Attention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 1000000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        head_dim: Optional[int] = None,
+        max_position_embeddings: int = 32768,
+        quant_config: Optional[QuantizationConfig] = None,
+        rms_norm_eps: float = None,
+        attention_bias: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % self.tp_size == 0
+        self.num_heads = self.total_num_heads // self.tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= self.tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % self.tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+        self.head_dim = head_dim or hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=attention_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
+        )
+
+    def _apply_qk_norm(
+        self, q: torch.Tensor, k: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        q_by_head = q.reshape(-1, self.head_dim)
+        q_by_head = self.q_norm(q_by_head)
+        q = q_by_head.view(q.shape)
+        k_by_head = k.reshape(-1, self.head_dim)
+        k_by_head = self.k_norm(k_by_head)
+        k = k_by_head.view(k.shape)
+        return q, k
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self._apply_qk_norm(q, k)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class Qwen3DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: Qwen3Config,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 1000000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+        head_dim = getattr(config, "head_dim", None)
+        self.self_attn = Qwen3Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            head_dim=head_dim,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            rms_norm_eps=config.rms_norm_eps,
+            attention_bias=config.attention_bias,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        self.mlp = Qwen3MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class Qwen3Model(Qwen2Model):
+    def __init__(
+        self,
+        config: Qwen3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            config=config,
+            quant_config=quant_config,
+            prefix=prefix,
+            decoder_layer_type=Qwen3DecoderLayer,
+        )
+
+
+class Qwen3ForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen3Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = Qwen3ForCausalLM