sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama.py CHANGED
@@ -22,6 +22,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         is_gated: bool = False,
         prefix: str = "",
     ):
@@ -199,14 +201,16 @@ class MllamaVisionEncoderLayer(nn.Module):
             self.num_attention_heads,
             self.hidden_size,
             use_qkv_parallel=True,
-            quant_config=None,
+            quant_config=quant_config,
             dropout=0.0,
             use_context_forward=False,
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
+        self.mlp = MllamaVisionMLP(
+            config, quant_config, prefix=add_prefix("mlp", prefix)
+        )
 
         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
@@ -254,7 +259,10 @@ class MllamaVisionEncoder(nn.Module):
         self.layers = nn.ModuleList(
             [
                 MllamaVisionEncoderLayer(
-                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+                    config,
+                    quant_config,
+                    is_gated,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(num_layers)
             ]
@@ -283,7 +291,12 @@ class MllamaVisionEncoder(nn.Module):
 
 
 class MllamaVisionModel(nn.Module):
-    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -320,6 +333,7 @@ class MllamaVisionModel(nn.Module):
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
@@ -327,6 +341,7 @@ class MllamaVisionModel(nn.Module):
         )
         self.global_transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_global_layers,
             is_gated=True,
             prefix=add_prefix("global_transformer", prefix),
@@ -535,6 +550,7 @@ class MllamaTextCrossAttention(nn.Module):
             self.num_local_key_value_heads,
             layer_id=layer_id,
             is_cross_attention=True,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
@@ -764,6 +780,27 @@ class MllamaForCausalLM(nn.Module):
 
 
 class MllamaForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: config_mllama.MllamaConfig,
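The bitsandbytes_stacked_params_mapping added above follows the usual stacked-parameter convention: a checkpoint shard name (q_proj, k_proj, ...) maps to the fused parameter it is packed into plus a shard index. A minimal, hypothetical sketch of how a loader typically consults such a mapping; the helper below is illustrative only, not sglang's actual loader.

# Illustrative only: toy resolver for a stacked-params mapping of this shape.
bitsandbytes_stacked_params_mapping = {
    # shard_name: (fused_param_name, shard_index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def resolve_stacked_name(ckpt_name: str):
    """Map a checkpoint weight name to (model param name, shard index or None)."""
    for shard_name, (fused_name, shard_idx) in bitsandbytes_stacked_params_mapping.items():
        if shard_name in ckpt_name:
            return ckpt_name.replace(shard_name, fused_name), shard_idx
    return ckpt_name, None  # not part of a stacked/fused parameter

# "...self_attn.q_proj.weight" -> ("...self_attn.qkv_proj.weight", 0)
print(resolve_stacked_name("model.layers.0.self_attn.q_proj.weight"))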
@@ -771,6 +808,7 @@ class MllamaForConditionalGeneration(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -781,17 +819,21 @@ class MllamaForConditionalGeneration(nn.Module):
         self.image_size = config.vision_config.image_size
 
         self.vision_model = MllamaVisionModel(
-            config.vision_config, prefix=add_prefix("vision_model", prefix)
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_model", prefix),
         )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ReplicatedLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
@@ -958,7 +1000,9 @@ class MllamaForConditionalGeneration(nn.Module):
             cross_attention_states = self.vision_model(
                 batched_images, batched_ar_ids, batched_ar_mask
             )
-            cross_attention_states = self.multi_modal_projector(cross_attention_states)
+            cross_attention_states, _ = self.multi_modal_projector(
+                cross_attention_states
+            )
 
             bs, _, _, _, image_token_dim = cross_attention_states.shape
             cross_attention_states = cross_attention_states.view(
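The call site now unpacks a tuple because the projector is a ReplicatedLinear rather than an nn.Linear: sglang's parallel linear layers return an (output, output_bias) pair instead of a bare tensor. A minimal sketch of that calling convention using a toy stand-in, not the real ReplicatedLinear:

import torch
from torch import nn

class TupleReturningLinear(nn.Module):
    """Toy stand-in for the (output, bias) return convention; not sglang's ReplicatedLinear."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=True)

    def forward(self, x: torch.Tensor):
        # The second element mirrors the (output, output_bias) API shape; here the
        # bias is already applied, so callers simply discard it (hence `out, _ = ...`).
        return self.linear(x), self.linear.bias

proj = TupleReturningLinear(8, 4)
out, _ = proj(torch.randn(2, 8))  # tuple unpacking, as in the call site above
print(out.shape)  # torch.Size([2, 4])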
@@ -1012,7 +1056,6 @@ class MllamaForConditionalGeneration(nn.Module):
             if "vision_model" in name:
                 # adapt to VisionAttention
                 name = name.replace("self_attn.o_proj", "self_attn.proj")
-
             param = params_dict.pop(name)
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
sglang/srt/models/mllama4.py CHANGED
@@ -1,13 +1,19 @@
-# TODO: add Aapted from vllm/mllama4.py
 from collections.abc import Iterable
-from typing import Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple
 
 import torch
 from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionModel
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
@@ -16,6 +22,7 @@ from sglang.srt.utils import add_prefix
 class Llama4ForConditionalGeneration(nn.Module):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }
 
     def __init__(
@@ -28,6 +35,9 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config
 
+        self.vision_model = Llama4VisionModel(config.vision_config)
+        self.multi_modal_projector = Llama4MultiModalProjector(config)
+
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM
 
@@ -39,6 +49,29 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         self.logits_processor = LogitsProcessor(config.text_config)
 
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_token_id: int = mm_inputs.im_token_id
+
+        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(
+        self,
+        items: List[MultimodalDataItem],
+    ) -> torch.Tensor:
+        pixel_values = (
+            torch.concat([item.pixel_values for item in items])
+            .to(next(self.vision_model.parameters()).device)
+            .type(next(self.vision_model.parameters()).dtype)
+        )
+
+        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
+        image_features = image_outputs.last_hidden_state
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+        return projected_vision_flat
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -47,7 +80,15 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:
 
-        return self.language_model(input_ids, positions, forward_batch)
+        hs = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            positions=positions,
+        )
+
+        return hs
 
     def permute_qk_weight_for_rotary(
         self,
@@ -96,18 +137,27 @@ class Llama4ForConditionalGeneration(nn.Module):
 
         num_experts = self.config.text_config.num_local_experts
 
-        for name, loaded_weight in weights:
-
-            if name.startswith("vision_model") or name.startswith(
-                "multi_modal_projector"
-            ):
-                continue
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=num_experts,
+        )
 
-            name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
+        for name, loaded_weight in weights:
+            if not "vision" in name:
+                name, loaded_weight = self.permute_qk_weight_for_rotary(
+                    name, loaded_weight
+                )
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
+                if "vision" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
@@ -115,31 +165,54 @@ class Llama4ForConditionalGeneration(nn.Module):
                 break
             else:
                 if ".experts" in name:
-                    if ".gate_up_proj" in name:
-                        name_list = [
-                            name.replace(".experts.gate_up_proj", ".experts.w13_weight")
-                        ] * 2
-                        loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                        shard_id_list = ["w1", "w3"]
-                    else:
-                        name_list = [
-                            name.replace(".experts.down_proj", ".experts.w2_weight")
-                        ]
-                        shard_id_list = ["w2"]
-                        loaded_weight_list = [loaded_weight]
-                    for name, loaded_weight, shard_id in zip(
-                        name_list, loaded_weight_list, shard_id_list
+                    # NOTE: llama4 fp8 has different weight format for experts
+                    if (
+                        "experts.gate_up_proj" not in name
+                        and "experts.down_proj" not in name
                     ):
-                        param = params_dict[name]
-                        weight_loader = param.weight_loader
-                        for expert_id in range(num_experts):
+                        for mapping in expert_params_mapping:
+                            param_name, weight_name, expert_id, shard_id = mapping
+                            if weight_name not in name:
+                                continue
+                            name = name.replace(weight_name, param_name)
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
                             weight_loader(
                                 param,
-                                loaded_weight[expert_id].T,
+                                loaded_weight,
                                 name,
                                 shard_id=shard_id,
                                 expert_id=expert_id,
                             )
+                            break
+                    else:
+                        if ".gate_up_proj" in name:
+                            name_list = [
+                                name.replace(
+                                    ".experts.gate_up_proj", ".experts.w13_weight"
+                                )
+                            ] * 2
+                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+                            shard_id_list = ["w1", "w3"]
+                        else:
+                            name_list = [
+                                name.replace(".experts.down_proj", ".experts.w2_weight")
+                            ]
+                            shard_id_list = ["w2"]
+                            loaded_weight_list = [loaded_weight]
+                        for name, loaded_weight, shard_id in zip(
+                            name_list, loaded_weight_list, shard_id_list
+                        ):
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
+                            for expert_id in range(num_experts):
+                                weight_loader(
+                                    param,
+                                    loaded_weight[expert_id].T,
+                                    name,
+                                    shard_id=shard_id,
+                                    expert_id=expert_id,
+                                )
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
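For context, FusedMoE.make_expert_params_mapping (used above) yields one (param_name, weight_name, expert_id, shard_id) tuple per expert shard, and the loader scans them to route a checkpoint tensor into the fused expert parameter. A hypothetical, self-contained sketch of that shape and matching loop; the exact name patterns are assumptions, not the real FusedMoE output.

# Hypothetical sketch of an expert-params mapping and the matching loop above.
def make_expert_params_mapping(num_experts: int):
    mapping = []
    for expert_id in range(num_experts):
        mapping += [
            ("experts.w13_weight", f"experts.{expert_id}.gate_proj.", expert_id, "w1"),
            ("experts.w13_weight", f"experts.{expert_id}.up_proj.", expert_id, "w3"),
            ("experts.w2_weight", f"experts.{expert_id}.down_proj.", expert_id, "w2"),
        ]
    return mapping

ckpt_name = "model.layers.0.feed_forward.experts.1.down_proj.weight"
for param_name, weight_name, expert_id, shard_id in make_expert_params_mapping(num_experts=4):
    if weight_name not in ckpt_name:
        continue
    target = ckpt_name.replace(weight_name, param_name + ".")
    # A real loader would now copy the tensor into `target`, expert slot `expert_id`,
    # shard `shard_id` ("w1"/"w3" for gate/up, "w2" for down).
    print(target, expert_id, shard_id)
    break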
sglang/srt/models/olmo.py CHANGED
@@ -93,6 +93,7 @@ class OlmoAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/olmo2.py CHANGED
@@ -118,6 +118,7 @@ class Olmo2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/olmoe.py CHANGED
@@ -170,6 +170,7 @@ class OlmoeAttention(nn.Module):
             self.scaling,
             layer_id=layer_id,
             num_kv_heads=self.num_kv_heads,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/phi3_small.py CHANGED
@@ -202,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.scale,
             num_kv_heads=self.num_kv_heads_per_partion,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/qwen.py CHANGED
@@ -133,6 +133,7 @@ class QWenAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
sglang/srt/models/qwen2.py CHANGED
@@ -154,6 +154,7 @@ class Qwen2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
         )
 
@@ -238,6 +239,7 @@
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
     ) -> None:
         super().__init__()
         self.config = config
@@ -249,9 +251,11 @@
             quant_config=quant_config,
             prefix=add_prefix("embed_tokens", prefix),
         )
+        # Use the provided decoder layer type or default to Qwen2DecoderLayer
+        decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
         self.layers = make_layers(
             config.num_hidden_layers,
-            lambda idx, prefix: Qwen2DecoderLayer(
+            lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
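The new decoder_layer_type parameter turns Qwen2Model into a reusable backbone whose decoder layer class is injected by subclassing models, presumably consumed by the newly added qwen3.py/qwen3_moe.py. A hedged sketch of the pattern with stand-in classes, not the actual sglang modules:

from torch import nn

class DecoderLayerA(nn.Module):
    """Stand-in for Qwen2DecoderLayer."""
    def __init__(self, layer_id: int):
        super().__init__()
        self.layer_id = layer_id

class DecoderLayerB(DecoderLayerA):
    """Stand-in for a newer decoder layer variant."""

class BackboneModel(nn.Module):
    """Toy version of the decoder_layer_type injection pattern shown above."""
    def __init__(self, num_layers: int, decoder_layer_type: type = DecoderLayerA):
        super().__init__()
        # Fall back to the default layer class if None is passed explicitly.
        decoder_layer_type = decoder_layer_type or DecoderLayerA
        self.layers = nn.ModuleList(
            decoder_layer_type(layer_id=i) for i in range(num_layers)
        )

base = BackboneModel(num_layers=2)                                       # DecoderLayerA
variant = BackboneModel(num_layers=2, decoder_layer_type=DecoderLayerB)  # DecoderLayerB
print(type(variant.layers[0]).__name__)  # DecoderLayerB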
sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -30,12 +30,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+    Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionPatchEmbed,
+    Qwen2_5_VisionRotaryEmbedding,
+)
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
@@ -137,7 +141,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
-            use_qkv_parallel=False,
+            use_qkv_parallel=True,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
@@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         return x
 
 
-class Qwen2_5_VisionPatchEmbed(nn.Module):
-
-    def __init__(
-        self,
-        patch_size: int = 14,
-        temporal_patch_size: int = 2,
-        in_chans: int = 3,
-        embed_dim: int = 1152,
-    ) -> None:
-        super().__init__()
-        self.patch_size = patch_size
-        self.temporal_patch_size = temporal_patch_size
-        self.embed_dim = embed_dim
-
-        kernel_size = [temporal_patch_size, patch_size, patch_size]
-        self.proj = nn.Conv3d(
-            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        target_dtype = self.proj.weight.dtype
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
-        return x
-
-
 class Qwen2_5_VisionPatchMerger(nn.Module):
 
     def __init__(
@@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         return out
 
 
-class Qwen2_5_VisionRotaryEmbedding(nn.Module):
-
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(
-            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-        )
-        freqs = torch.outer(seq, self.inv_freq)
-        return freqs
-
-
 class Qwen2_5_VisionTransformer(nn.Module):
 
     def __init__(
@@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-        in_chans: int = vision_config.in_channels
+        in_channels: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads
@@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
-            in_chans=in_chans,
+            in_channels=in_channels,
             embed_dim=hidden_size,
         )
 
@@ -363,7 +325,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     @property
     def dtype(self) -> torch.dtype:
-        return self.blocks[0].mlp.gate_proj.weight.dtype
+        return self.patch_embed.proj.weight.dtype
 
     @property
     def device(self) -> torch.device:
@@ -467,9 +429,28 @@ cached_get_processor = lru_cache(get_processor)
 
 
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
-        config: Qwen2VLConfig,
+        config: Qwen2_5_VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -479,9 +460,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.visual = Qwen2_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            # NOTE: Qwen2-VL vision encoder does not support any
-            # quantization method now.
-            quant_config=None,
+            # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
 
@@ -500,6 +481,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
         )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -553,14 +535,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 otherwise it will be `(seq_len,).
                 (Use input_metadata.mrope_positions to replace it)
         """
-        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+        if self.is_mrope_enabled:
             positions = forward_batch.mrope_positions
 
         if not (
             forward_batch.forward_mode.is_decode()
             or not forward_batch.contains_image_inputs()
         ):
-            if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            if self.is_mrope_enabled:
                 assert positions.ndim == 2 and positions.size(0) == 3, (
                     "multimodal section rotary embedding requires "
                     f"(3, seq_len) positions, but got {positions.size()}"
@@ -610,23 +592,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if "visual" in name and "qkv.weight" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(
-                        3, visual_num_heads, head_size, visual_embed_dim
-                    )
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 if "visual" in name:
                     # adapt to VisionAttention
                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")