sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_moe.py
@@ -47,7 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import add_prefix
+ from sglang.srt.utils import add_prefix, make_layers

  expert_distribution_recorder = ExpertDistributionRecorder()

@@ -231,6 +231,7 @@ class Qwen2MoeAttention(nn.Module):
              self.scaling,
              num_kv_heads=self.num_kv_heads,
              layer_id=layer_id,
+             quant_config=quant_config,
              prefix=add_prefix("attn", prefix),
          )

@@ -261,8 +262,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
          rope_theta = getattr(config, "rope_theta", 10000)
          rope_scaling = getattr(config, "rope_scaling", None)
          max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-         # note: replace config.num_hidden_layers < 80 with True once its available in transformers 4.50.0
-         qkv_bias = getattr(config, "qkv_bias", config.num_hidden_layers < 80)
+         qkv_bias = getattr(config, "qkv_bias", True)
          self.self_attn = Qwen2MoeAttention(
              hidden_size=self.hidden_size,
              num_heads=config.num_attention_heads,
@@ -333,6 +333,7 @@ class Qwen2MoeModel(nn.Module):
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
+         decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer,
      ) -> None:
          super().__init__()
          self.padding_idx = config.pad_token_id
@@ -343,16 +344,17 @@ class Qwen2MoeModel(nn.Module):
              config.hidden_size,
              prefix=add_prefix("embed_tokens", prefix),
          )
-         self.layers = nn.ModuleList(
-             [
-                 Qwen2MoeDecoderLayer(
-                     config,
-                     layer_id,
-                     quant_config=quant_config,
-                     prefix=add_prefix(f"layers.{layer_id}", prefix),
-                 )
-                 for layer_id in range(config.num_hidden_layers)
-             ]
+         # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
+         decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
+         self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda idx, prefix: decoder_layer_type(
+                 layer_id=idx,
+                 config=config,
+                 quant_config=quant_config,
+                 prefix=prefix,
+             ),
+             prefix=add_prefix("layers", prefix),
          )
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

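The new `decoder_layer_type` hook is what lets the Qwen3 model files added in this release reuse the existing Qwen2 model bodies (the new `qwen3.py` below passes `Qwen3DecoderLayer` into `Qwen2Model` the same way). A minimal sketch of the reuse pattern; `MyMoeDecoderLayer` and `MyMoeModel` are hypothetical names, not classes from this diff:

```python
# Sketch: reuse Qwen2MoeModel with a custom decoder layer by injecting the
# layer class instead of copying the whole model definition.
from typing import Optional

from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.models.qwen2_moe import Qwen2MoeDecoderLayer, Qwen2MoeModel


class MyMoeDecoderLayer(Qwen2MoeDecoderLayer):
    # Override attention / MoE sub-modules here as needed.
    pass


class MyMoeModel(Qwen2MoeModel):
    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        # Embeddings, layer construction, and the final norm all come from the
        # base class; only the per-layer class is swapped.
        super().__init__(
            config=config,
            quant_config=quant_config,
            prefix=prefix,
            decoder_layer_type=MyMoeDecoderLayer,
        )
```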
sglang/srt/models/qwen2_vl.py
@@ -152,7 +152,7 @@ class Qwen2VisionBlock(nn.Module):
              embed_dim=dim,
              num_heads=num_heads,
              projection_size=dim,
-             use_qkv_parallel=False,
+             use_qkv_parallel=True,
              use_context_forward=use_context_forward,
              softmax_in_single_precision=softmax_in_single_precision,
              flatten_batch=True,
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):

      @property
      def dtype(self) -> torch.dtype:
-         return next(self.parameters()).dtype
+         return self.patch_embed.proj.weight.dtype

      @property
      def device(self) -> torch.device:
@@ -423,6 +423,25 @@ cached_get_processor = lru_cache(get_processor)


  class Qwen2VLForConditionalGeneration(nn.Module):
+     # BitandBytes specific attributes
+     default_bitsandbytes_target_modules = [
+         ".gate_proj.",
+         ".down_proj.",
+         ".up_proj.",
+         ".q_proj.",
+         ".k_proj.",
+         ".v_proj.",
+         ".o_proj.",
+     ]
+     bitsandbytes_stacked_params_mapping = {
+         # shard_name, weight_name, index
+         "q_proj": ("qkv_proj", 0),
+         "k_proj": ("qkv_proj", 1),
+         "v_proj": ("qkv_proj", 2),
+         "gate_proj": ("gate_up_proj", 0),
+         "up_proj": ("gate_up_proj", 1),
+     }
+
      def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
          processor = cached_get_processor(self.config._name_or_path)
          grid_t, grid_h, grid_w = image_grid_thw
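These class attributes drive BitsAndBytes weight loading: each per-shard checkpoint name maps to the fused parameter it is packed into, plus its shard index. A standalone illustration of the name translation such a mapping implies (this is not the actual loader code in `model_loader/loader.py`):

```python
# Illustrative only: translate per-shard checkpoint names into the fused
# parameter names used by the model, using a mapping like the one above.
bitsandbytes_stacked_params_mapping = {
    # shard_name: (fused_param_name, shard_index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}


def fused_name_and_index(checkpoint_name: str):
    """Map e.g. 'model.layers.0.self_attn.q_proj.weight' to
    ('model.layers.0.self_attn.qkv_proj.weight', 0)."""
    for shard_name, (fused_name, index) in bitsandbytes_stacked_params_mapping.items():
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, fused_name), index
    return checkpoint_name, None


print(fused_name_and_index("model.layers.0.self_attn.k_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 1)
```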
@@ -447,9 +466,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
          self.visual = Qwen2VisionTransformer(
              config.vision_config,
              norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-             # NOTE: Qwen2-VL vision encoder does not support any
-             # quantization method now.
-             quant_config=None,
+             # NOTE: Qwen2-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+             # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+             quant_config=quant_config,
              prefix=add_prefix("visual", prefix),
          )

@@ -467,6 +486,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
              prefix=add_prefix("lm_head", prefix),
          )

+         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
          self.logits_processor = LogitsProcessor(config)
          self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

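`is_mrope_enabled` is now computed once at construction time, and the check keys on the presence of `mrope_section` rather than on the `rope_scaling` `type` string. For reference, a Qwen2-VL-style `rope_scaling` block looks roughly like the following (values are illustrative, not read from this diff):

```python
# Illustrative Qwen2-VL-style rope_scaling config entry.
rope_scaling = {
    "type": "mrope",
    "mrope_section": [16, 24, 24],  # rotary dims split across temporal / height / width
}

# Mirrors the new check in __init__: mrope is detected by the key, not the type string.
is_mrope_enabled = "mrope_section" in rope_scaling
assert is_mrope_enabled
```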
@@ -521,14 +541,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                  otherwise it will be `(seq_len,).
              (Use input_metadata.mrope_positions to replace it)
          """
-         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+         if self.is_mrope_enabled:
              positions = forward_batch.mrope_positions

          if not (
              forward_batch.forward_mode.is_decode()
              or not forward_batch.contains_image_inputs()
          ):
-             if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+             if self.is_mrope_enabled:
                  assert positions.ndim == 2 and positions.size(0) == 3, (
                      "multimodal section rotary embedding requires "
                      f"(3, seq_len) positions, but got {positions.size()}"
@@ -577,24 +597,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
-
-                 if "visual" in name and "qkv.weight" in name:
-                     visual_num_heads = self.config.vision_config.num_heads
-                     visual_embed_dim = self.config.vision_config.embed_dim
-                     head_size = visual_embed_dim // visual_num_heads
-                     loaded_weight = loaded_weight.view(
-                         3, visual_num_heads, head_size, visual_embed_dim
-                     )
-                     loaded_weight = loaded_weight.transpose(0, 1)
-                     loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                 elif "visual" in name and "qkv.bias" in name:
-                     visual_num_heads = self.config.vision_config.num_heads
-                     visual_embed_dim = self.config.vision_config.embed_dim
-                     head_size = visual_embed_dim // visual_num_heads
-                     loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                     loaded_weight = loaded_weight.transpose(0, 1)
-                     loaded_weight = loaded_weight.reshape(-1)
-
                  if "visual" in name:
                      # adapt to VisionAttention
                      name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
sglang/srt/models/qwen3.py (new file)
@@ -0,0 +1,335 @@
+ # Adapted from qwen2.py
+
+ from functools import partial
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     split_tensor_along_last_dim,
+     tensor_model_parallel_all_gather,
+ )
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.pooler import Pooler, PoolingType
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
+ from sglang.srt.models.qwen2 import Qwen2Model
+ from sglang.srt.utils import add_prefix
+
+ Qwen3Config = None
+
+
+ class Qwen3Attention(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 1000000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         head_dim: Optional[int] = None,
+         max_position_embeddings: int = 32768,
+         quant_config: Optional[QuantizationConfig] = None,
+         rms_norm_eps: float = None,
+         attention_bias: bool = False,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % self.tp_size == 0
+         self.num_heads = self.total_num_heads // self.tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= self.tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % self.tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
+         self.head_dim = head_dim or hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+         self.tp_rank = get_tensor_model_parallel_rank()
+
+         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("qkv_proj", prefix),
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=attention_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("o_proj", prefix),
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             prefix=add_prefix("attn", prefix),
+         )
+
+     def _apply_qk_norm(
+         self, q: torch.Tensor, k: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         q_by_head = q.reshape(-1, self.head_dim)
+         q_by_head = self.q_norm(q_by_head)
+         q = q_by_head.view(q.shape)
+         k_by_head = k.reshape(-1, self.head_dim)
+         k_by_head = self.k_norm(k_by_head)
+         k = k_by_head.view(k.shape)
+         return q, k
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self._apply_qk_norm(q, k)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Qwen3DecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: Qwen3Config,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 1000000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 32768)
+         head_dim = getattr(config, "head_dim", None)
+         self.self_attn = Qwen3Attention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             head_dim=head_dim,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+             rms_norm_eps=config.rms_norm_eps,
+             attention_bias=config.attention_bias,
+             prefix=add_prefix("self_attn", prefix),
+         )
+         self.mlp = Qwen3MLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             quant_config=quant_config,
+             prefix=add_prefix("mlp", prefix),
+         )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class Qwen3Model(Qwen2Model):
+     def __init__(
+         self,
+         config: Qwen3Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__(
+             config=config,
+             quant_config=quant_config,
+             prefix=prefix,
+             decoder_layer_type=Qwen3DecoderLayer,
+         )
+
+
+ class Qwen3ForCausalLM(nn.Module):
+     # BitandBytes specific attributes
+     default_bitsandbytes_target_modules = [
+         ".gate_proj.",
+         ".down_proj.",
+         ".up_proj.",
+         ".q_proj.",
+         ".k_proj.",
+         ".v_proj.",
+         ".o_proj.",
+     ]
+     bitsandbytes_stacked_params_mapping = {
+         # shard_name, weight_name, index
+         "q_proj": ("qkv_proj", 0),
+         "k_proj": ("qkv_proj", 1),
+         "v_proj": ("qkv_proj", 2),
+         "gate_proj": ("gate_up_proj", 0),
+         "up_proj": ("gate_up_proj", 1),
+     }
+
+     def __init__(
+         self,
+         config: Qwen3Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = Qwen3Model(
+             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+         )
+         if config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size,
+                 config.hidden_size,
+                 quant_config=quant_config,
+                 prefix=add_prefix("lm_head", prefix),
+             )
+         self.logits_processor = LogitsProcessor(config)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.model.get_input_embeddings(input_ids)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         get_embedding: bool = False,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+         if not get_embedding:
+             return self.logits_processor(
+                 input_ids, hidden_states, self.lm_head, forward_batch
+             )
+         else:
+             return self.pooler(hidden_states, forward_batch)
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name or "projector" in name:
+                 continue
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 continue
+             if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                 continue
+             if name.startswith("model.vision_tower") and name not in params_dict:
+                 continue
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+     def get_embed_and_head(self):
+         return self.model.embed_tokens.weight, self.lm_head.weight
+
+     def set_embed_and_head(self, embed, head):
+         del self.model.embed_tokens.weight
+         del self.lm_head.weight
+         self.model.embed_tokens.weight = embed
+         self.lm_head.weight = head
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+         self.model.load_kv_cache_scales(quantization_param_path)
+
+
+ EntryClass = Qwen3ForCausalLM
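With `EntryClass = Qwen3ForCausalLM` registered, the new architecture is loadable through sglang's usual entry points. A minimal offline-engine sketch; the model path is a placeholder, not a checkpoint named anywhere in this diff:

```python
# Minimal sketch: serve a Qwen3 checkpoint with sglang's offline engine.
# "Qwen/Qwen3-8B" is a placeholder model id used for illustration only.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="Qwen/Qwen3-8B")
    sampling_params = {"temperature": 0.7, "top_p": 0.95, "max_new_tokens": 128}
    outputs = llm.generate(["What is the capital of France?"], sampling_params)
    print(outputs[0]["text"])
    llm.shutdown()
```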