sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/lang/chat_template.py +24 -0
  4. sglang/srt/configs/model_config.py +40 -4
  5. sglang/srt/constrained/base_grammar_backend.py +26 -5
  6. sglang/srt/constrained/llguidance_backend.py +1 -0
  7. sglang/srt/constrained/outlines_backend.py +1 -0
  8. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  9. sglang/srt/constrained/xgrammar_backend.py +1 -0
  10. sglang/srt/conversation.py +29 -4
  11. sglang/srt/disaggregation/base/__init__.py +8 -0
  12. sglang/srt/disaggregation/base/conn.py +113 -0
  13. sglang/srt/disaggregation/decode.py +18 -5
  14. sglang/srt/disaggregation/mini_lb.py +53 -122
  15. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  16. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  17. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  18. sglang/srt/disaggregation/prefill.py +43 -19
  19. sglang/srt/disaggregation/utils.py +31 -0
  20. sglang/srt/entrypoints/EngineBase.py +53 -0
  21. sglang/srt/entrypoints/engine.py +36 -8
  22. sglang/srt/entrypoints/http_server.py +37 -8
  23. sglang/srt/entrypoints/http_server_engine.py +142 -0
  24. sglang/srt/entrypoints/verl_engine.py +37 -10
  25. sglang/srt/hf_transformers_utils.py +4 -0
  26. sglang/srt/layers/attention/flashattention_backend.py +609 -202
  27. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  28. sglang/srt/layers/attention/vision.py +1 -1
  29. sglang/srt/layers/dp_attention.py +2 -4
  30. sglang/srt/layers/elementwise.py +15 -2
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  33. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  49. sglang/srt/layers/moe/router.py +7 -1
  50. sglang/srt/layers/moe/topk.py +37 -16
  51. sglang/srt/layers/quantization/__init__.py +13 -5
  52. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  53. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  54. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  55. sglang/srt/layers/quantization/fp8.py +28 -14
  56. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  57. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  58. sglang/srt/layers/quantization/kv_cache.py +43 -52
  59. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  62. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  63. sglang/srt/layers/radix_attention.py +14 -0
  64. sglang/srt/layers/rotary_embedding.py +75 -1
  65. sglang/srt/managers/io_struct.py +254 -97
  66. sglang/srt/managers/mm_utils.py +3 -2
  67. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  68. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  69. sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
  70. sglang/srt/managers/schedule_batch.py +62 -21
  71. sglang/srt/managers/scheduler.py +71 -14
  72. sglang/srt/managers/tokenizer_manager.py +17 -3
  73. sglang/srt/managers/tp_worker.py +1 -0
  74. sglang/srt/mem_cache/memory_pool.py +14 -1
  75. sglang/srt/metrics/collector.py +9 -0
  76. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  77. sglang/srt/model_executor/forward_batch_info.py +234 -15
  78. sglang/srt/model_executor/model_runner.py +49 -9
  79. sglang/srt/model_loader/loader.py +31 -4
  80. sglang/srt/model_loader/weight_utils.py +4 -2
  81. sglang/srt/models/baichuan.py +2 -0
  82. sglang/srt/models/chatglm.py +1 -0
  83. sglang/srt/models/commandr.py +1 -0
  84. sglang/srt/models/dbrx.py +1 -0
  85. sglang/srt/models/deepseek.py +1 -0
  86. sglang/srt/models/deepseek_v2.py +248 -61
  87. sglang/srt/models/exaone.py +1 -0
  88. sglang/srt/models/gemma.py +1 -0
  89. sglang/srt/models/gemma2.py +1 -0
  90. sglang/srt/models/gemma3_causal.py +1 -0
  91. sglang/srt/models/gpt2.py +1 -0
  92. sglang/srt/models/gpt_bigcode.py +1 -0
  93. sglang/srt/models/granite.py +1 -0
  94. sglang/srt/models/grok.py +1 -0
  95. sglang/srt/models/internlm2.py +1 -0
  96. sglang/srt/models/llama.py +13 -4
  97. sglang/srt/models/llama4.py +487 -0
  98. sglang/srt/models/minicpm.py +1 -0
  99. sglang/srt/models/minicpm3.py +2 -0
  100. sglang/srt/models/mixtral.py +1 -0
  101. sglang/srt/models/mixtral_quant.py +1 -0
  102. sglang/srt/models/mllama.py +51 -8
  103. sglang/srt/models/mllama4.py +227 -0
  104. sglang/srt/models/olmo.py +1 -0
  105. sglang/srt/models/olmo2.py +1 -0
  106. sglang/srt/models/olmoe.py +1 -0
  107. sglang/srt/models/phi3_small.py +1 -0
  108. sglang/srt/models/qwen.py +1 -0
  109. sglang/srt/models/qwen2.py +1 -0
  110. sglang/srt/models/qwen2_5_vl.py +35 -70
  111. sglang/srt/models/qwen2_moe.py +1 -0
  112. sglang/srt/models/qwen2_vl.py +27 -25
  113. sglang/srt/models/stablelm.py +1 -0
  114. sglang/srt/models/xverse.py +1 -0
  115. sglang/srt/models/xverse_moe.py +1 -0
  116. sglang/srt/openai_api/adapter.py +4 -1
  117. sglang/srt/patch_torch.py +11 -0
  118. sglang/srt/server_args.py +34 -0
  119. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  120. sglang/srt/speculative/eagle_utils.py +1 -11
  121. sglang/srt/speculative/eagle_worker.py +6 -2
  122. sglang/srt/utils.py +120 -9
  123. sglang/test/attention/test_flashattn_backend.py +259 -221
  124. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  125. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  126. sglang/test/test_block_fp8.py +57 -0
  127. sglang/test/test_utils.py +19 -8
  128. sglang/version.py +1 -1
  129. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  130. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
  131. sglang/srt/disaggregation/conn.py +0 -81
  132. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  133. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  134. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama4.py ADDED
@@ -0,0 +1,227 @@
+ from collections.abc import Iterable
+ from typing import List, Optional, Set, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import Llama4Config, Llama4VisionModel
+ from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+ from sglang.srt.layers.quantization import QuantizationConfig
+ from sglang.srt.managers.mm_utils import (
+ MultiModalityDataPaddingPatternImageTokens,
+ general_mm_embed_routine,
+ )
+ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.utils import add_prefix
+
+
+ class Llama4ForConditionalGeneration(nn.Module):
+ packed_modules_mapping = {
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+ "gate_up_proj": ["gate_proj", "up_proj"],
+ }
+
+ def __init__(
+ self,
+ config: Llama4Config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.config = config
+ self.quant_config = quant_config
+
+ self.vision_model = Llama4VisionModel(config.vision_config)
+ self.multi_modal_projector = Llama4MultiModalProjector(config)
+
+ # Initialize the language model
+ from sglang.srt.models.llama4 import Llama4ForCausalLM
+
+ self.language_model = Llama4ForCausalLM(
+ config.text_config,
+ quant_config=quant_config,
+ prefix=add_prefix("language_model", prefix),
+ )
+
+ self.logits_processor = LogitsProcessor(config.text_config)
+
+ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+ # Get all special token IDs
+ im_token_id: int = mm_inputs.im_token_id
+
+ pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+ return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+ def get_image_feature(
+ self,
+ items: List[MultimodalDataItem],
+ ) -> torch.Tensor:
+ pixel_values = (
+ torch.concat([item.pixel_values for item in items])
+ .to(next(self.vision_model.parameters()).device)
+ .type(next(self.vision_model.parameters()).dtype)
+ )
+
+ image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
+ image_features = image_outputs.last_hidden_state
+ vision_flat = image_features.view(-1, image_features.size(-1))
+ projected_vision_flat = self.multi_modal_projector(vision_flat)
+ return projected_vision_flat
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ **kwargs: object,
+ ) -> torch.Tensor:
+
+ hs = general_mm_embed_routine(
+ input_ids=input_ids,
+ forward_batch=forward_batch,
+ language_model=self.language_model,
+ image_data_embedding_func=self.get_image_feature,
+ positions=positions,
+ )
+
+ return hs
+
+ def permute_qk_weight_for_rotary(
+ self,
+ name: str,
+ loaded_weight: torch.Tensor,
+ ) -> Tuple[str, torch.Tensor]:
+
+ def permute(w: torch.Tensor, n_heads: int):
+ attn_in = self.language_model.config.head_dim * n_heads
+ attn_out = self.language_model.config.hidden_size
+
+ return (
+ w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
+ .transpose(1, 2)
+ .reshape(attn_in, attn_out)
+ )
+
+ modules = name.split(".")
+
+ # rotary embeds should be sliced
+ if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
+ loaded_weight = permute(
+ loaded_weight, self.language_model.config.num_key_value_heads
+ )
+ elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
+ loaded_weight = permute(
+ loaded_weight, self.language_model.config.num_attention_heads
+ )
+
+ return name, loaded_weight
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
+ (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
+ (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
+ (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
+ (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
+ (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
+ (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
+ ]
+
+ params_dict = dict(self.named_parameters())
+
+ num_experts = self.config.text_config.num_local_experts
+
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=num_experts,
+ )
+
+ for name, loaded_weight in weights:
+ if not "vision" in name:
+ name, loaded_weight = self.permute_qk_weight_for_rotary(
+ name, loaded_weight
+ )
+
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+
+ if "vision" in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ if ".experts" in name:
+ # NOTE: llama4 fp8 has different weight format for experts
+ if (
+ "experts.gate_up_proj" not in name
+ and "experts.down_proj" not in name
+ ):
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(
+ param,
+ loaded_weight,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ )
+ break
+ else:
+ if ".gate_up_proj" in name:
+ name_list = [
+ name.replace(
+ ".experts.gate_up_proj", ".experts.w13_weight"
+ )
+ ] * 2
+ loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+ shard_id_list = ["w1", "w3"]
+ else:
+ name_list = [
+ name.replace(".experts.down_proj", ".experts.w2_weight")
+ ]
+ shard_id_list = ["w2"]
+ loaded_weight_list = [loaded_weight]
+ for name, loaded_weight, shard_id in zip(
+ name_list, loaded_weight_list, shard_id_list
+ ):
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ for expert_id in range(num_experts):
+ weight_loader(
+ param,
+ loaded_weight[expert_id].T,
+ name,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ )
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+
+
+ EntryClass = Llama4ForConditionalGeneration
sglang/srt/models/olmo.py CHANGED
@@ -93,6 +93,7 @@ class OlmoAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
 
sglang/srt/models/olmo2.py CHANGED
@@ -118,6 +118,7 @@ class Olmo2Attention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
 
sglang/srt/models/olmoe.py CHANGED
@@ -170,6 +170,7 @@ class OlmoeAttention(nn.Module):
  self.scaling,
  layer_id=layer_id,
  num_kv_heads=self.num_kv_heads,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
 
sglang/srt/models/phi3_small.py CHANGED
@@ -202,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module):
  self.scale,
  num_kv_heads=self.num_kv_heads_per_partion,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
 
sglang/srt/models/qwen.py CHANGED
@@ -133,6 +133,7 @@ class QWenAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
 
sglang/srt/models/qwen2.py CHANGED
@@ -154,6 +154,7 @@ class Qwen2Attention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
 
sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -30,12 +30,16 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange
- from transformers import Qwen2VLConfig
  from transformers.activations import ACT2FN
  from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
  from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+ Qwen2_5_VLConfig,
  Qwen2_5_VLVisionConfig,
  )
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+ Qwen2_5_VisionPatchEmbed,
+ Qwen2_5_VisionRotaryEmbedding,
+ )
 
  from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.attention.vision import VisionAttention
@@ -137,7 +141,7 @@ class Qwen2_5_VisionBlock(nn.Module):
  embed_dim=dim,
  num_heads=num_heads,
  projection_size=dim,
- use_qkv_parallel=False,
+ use_qkv_parallel=True,
  use_context_forward=use_context_forward,
  softmax_in_single_precision=softmax_in_single_precision,
  flatten_batch=flatten_batch,
@@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module):
  return x
 
 
- class Qwen2_5_VisionPatchEmbed(nn.Module):
-
- def __init__(
- self,
- patch_size: int = 14,
- temporal_patch_size: int = 2,
- in_chans: int = 3,
- embed_dim: int = 1152,
- ) -> None:
- super().__init__()
- self.patch_size = patch_size
- self.temporal_patch_size = temporal_patch_size
- self.embed_dim = embed_dim
-
- kernel_size = [temporal_patch_size, patch_size, patch_size]
- self.proj = nn.Conv3d(
- in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
- )
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- target_dtype = self.proj.weight.dtype
- L, C = x.shape
- x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
- x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
- return x
-
-
  class Qwen2_5_VisionPatchMerger(nn.Module):
 
  def __init__(
@@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
  return out
 
 
- class Qwen2_5_VisionRotaryEmbedding(nn.Module):
-
- def __init__(self, dim: int, theta: float = 10000.0) -> None:
- super().__init__()
- inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- def forward(self, seqlen: int) -> torch.Tensor:
- seq = torch.arange(
- seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
- )
- freqs = torch.outer(seq, self.inv_freq)
- return freqs
-
-
  class Qwen2_5_VisionTransformer(nn.Module):
 
  def __init__(
@@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
  spatial_merge_size: int = vision_config.spatial_merge_size
  self.spatial_merge_size = spatial_merge_size
  self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
- in_chans: int = vision_config.in_channels
+ in_channels: int = vision_config.in_channels
  hidden_size: int = vision_config.hidden_size
  depth: int = vision_config.depth
  num_heads: int = vision_config.num_heads
@@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
  self.patch_embed = Qwen2_5_VisionPatchEmbed(
  patch_size=patch_size,
  temporal_patch_size=temporal_patch_size,
- in_chans=in_chans,
+ in_channels=in_channels,
  embed_dim=hidden_size,
  )
 
@@ -363,7 +325,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
  @property
  def dtype(self) -> torch.dtype:
- return self.blocks[0].mlp.gate_proj.weight.dtype
+ return self.patch_embed.proj.weight.dtype
 
  @property
  def device(self) -> torch.device:
@@ -467,9 +429,28 @@ cached_get_processor = lru_cache(get_processor)
 
 
  class Qwen2_5_VLForConditionalGeneration(nn.Module):
+ # BitandBytes specific attributes
+ default_bitsandbytes_target_modules = [
+ ".gate_proj.",
+ ".down_proj.",
+ ".up_proj.",
+ ".q_proj.",
+ ".k_proj.",
+ ".v_proj.",
+ ".o_proj.",
+ ]
+ bitsandbytes_stacked_params_mapping = {
+ # shard_name, weight_name, index
+ "q_proj": ("qkv_proj", 0),
+ "k_proj": ("qkv_proj", 1),
+ "v_proj": ("qkv_proj", 2),
+ "gate_proj": ("gate_up_proj", 0),
+ "up_proj": ("gate_up_proj", 1),
+ }
+
  def __init__(
  self,
- config: Qwen2VLConfig,
+ config: Qwen2_5_VLConfig,
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  ) -> None:
@@ -479,9 +460,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
  self.visual = Qwen2_5_VisionTransformer(
  config.vision_config,
  norm_eps=getattr(config, "rms_norm_eps", 1e-6),
- # NOTE: Qwen2-VL vision encoder does not support any
- # quantization method now.
- quant_config=None,
+ # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+ # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+ quant_config=quant_config,
  prefix=add_prefix("visual", prefix),
  )
 
@@ -500,6 +481,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
  quant_config=quant_config,
  prefix=add_prefix("lm_head", prefix),
  )
+ self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
  self.logits_processor = LogitsProcessor(config)
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -553,14 +535,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
  otherwise it will be `(seq_len,).
  (Use input_metadata.mrope_positions to replace it)
  """
- if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+ if self.is_mrope_enabled:
  positions = forward_batch.mrope_positions
 
  if not (
  forward_batch.forward_mode.is_decode()
  or not forward_batch.contains_image_inputs()
  ):
- if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+ if self.is_mrope_enabled:
  assert positions.ndim == 2 and positions.size(0) == 3, (
  "multimodal section rotary embedding requires "
  f"(3, seq_len) positions, but got {positions.size()}"
@@ -610,23 +592,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
  weight_loader(param, loaded_weight, shard_id)
  break
  else:
- if "visual" in name and "qkv.weight" in name:
- visual_num_heads = self.config.vision_config.num_heads
- visual_embed_dim = self.config.vision_config.hidden_size
- head_size = visual_embed_dim // visual_num_heads
- loaded_weight = loaded_weight.view(
- 3, visual_num_heads, head_size, visual_embed_dim
- )
- loaded_weight = loaded_weight.transpose(0, 1)
- loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
- elif "visual" in name and "qkv.bias" in name:
- visual_num_heads = self.config.vision_config.num_heads
- visual_embed_dim = self.config.vision_config.hidden_size
- head_size = visual_embed_dim // visual_num_heads
- loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
- loaded_weight = loaded_weight.transpose(0, 1)
- loaded_weight = loaded_weight.reshape(-1)
-
  if "visual" in name:
  # adapt to VisionAttention
  name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
sglang/srt/models/qwen2_moe.py CHANGED
@@ -231,6 +231,7 @@ class Qwen2MoeAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
 
sglang/srt/models/qwen2_vl.py CHANGED
@@ -152,7 +152,7 @@ class Qwen2VisionBlock(nn.Module):
  embed_dim=dim,
  num_heads=num_heads,
  projection_size=dim,
- use_qkv_parallel=False,
+ use_qkv_parallel=True,
  use_context_forward=use_context_forward,
  softmax_in_single_precision=softmax_in_single_precision,
  flatten_batch=True,
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
 
  @property
  def dtype(self) -> torch.dtype:
- return next(self.parameters()).dtype
+ return self.patch_embed.proj.weight.dtype
 
  @property
  def device(self) -> torch.device:
@@ -423,6 +423,25 @@ cached_get_processor = lru_cache(get_processor)
 
 
  class Qwen2VLForConditionalGeneration(nn.Module):
+ # BitandBytes specific attributes
+ default_bitsandbytes_target_modules = [
+ ".gate_proj.",
+ ".down_proj.",
+ ".up_proj.",
+ ".q_proj.",
+ ".k_proj.",
+ ".v_proj.",
+ ".o_proj.",
+ ]
+ bitsandbytes_stacked_params_mapping = {
+ # shard_name, weight_name, index
+ "q_proj": ("qkv_proj", 0),
+ "k_proj": ("qkv_proj", 1),
+ "v_proj": ("qkv_proj", 2),
+ "gate_proj": ("gate_up_proj", 0),
+ "up_proj": ("gate_up_proj", 1),
+ }
+
  def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
  processor = cached_get_processor(self.config._name_or_path)
  grid_t, grid_h, grid_w = image_grid_thw
@@ -447,9 +466,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
  self.visual = Qwen2VisionTransformer(
  config.vision_config,
  norm_eps=getattr(config, "rms_norm_eps", 1e-6),
- # NOTE: Qwen2-VL vision encoder does not support any
- # quantization method now.
- quant_config=None,
+ # NOTE: Qwen2-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+ # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+ quant_config=quant_config,
  prefix=add_prefix("visual", prefix),
  )
 
@@ -467,6 +486,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
  prefix=add_prefix("lm_head", prefix),
  )
 
+ self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
  self.logits_processor = LogitsProcessor(config)
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
@@ -521,14 +541,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
  otherwise it will be `(seq_len,).
  (Use input_metadata.mrope_positions to replace it)
  """
- if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+ if self.is_mrope_enabled:
  positions = forward_batch.mrope_positions
 
  if not (
  forward_batch.forward_mode.is_decode()
  or not forward_batch.contains_image_inputs()
  ):
- if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+ if self.is_mrope_enabled:
  assert positions.ndim == 2 and positions.size(0) == 3, (
  "multimodal section rotary embedding requires "
  f"(3, seq_len) positions, but got {positions.size()}"
@@ -577,24 +597,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
  weight_loader(param, loaded_weight, shard_id)
  break
  else:
-
- if "visual" in name and "qkv.weight" in name:
- visual_num_heads = self.config.vision_config.num_heads
- visual_embed_dim = self.config.vision_config.embed_dim
- head_size = visual_embed_dim // visual_num_heads
- loaded_weight = loaded_weight.view(
- 3, visual_num_heads, head_size, visual_embed_dim
- )
- loaded_weight = loaded_weight.transpose(0, 1)
- loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
- elif "visual" in name and "qkv.bias" in name:
- visual_num_heads = self.config.vision_config.num_heads
- visual_embed_dim = self.config.vision_config.embed_dim
- head_size = visual_embed_dim // visual_num_heads
- loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
- loaded_weight = loaded_weight.transpose(0, 1)
- loaded_weight = loaded_weight.reshape(-1)
-
  if "visual" in name:
  # adapt to VisionAttention
  name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
sglang/srt/models/stablelm.py CHANGED
@@ -149,6 +149,7 @@ class StablelmAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_key_value_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
 
sglang/srt/models/xverse.py CHANGED
@@ -153,6 +153,7 @@ class XverseAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
 
sglang/srt/models/xverse_moe.py CHANGED
@@ -252,6 +252,7 @@ class XverseAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ quant_config=quant_config,
  prefix=add_prefix("attn", prefix),
  )
 
sglang/srt/openai_api/adapter.py CHANGED
@@ -983,6 +983,8 @@ def v1_chat_generate_request(
  ):
  encoded = encoded[1:]
  prompt_ids += encoded
+ if tokenizer_manager.model_config.is_multimodal:
+ prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
  stop = request.stop
  image_data = None
  audio_data = None
@@ -993,7 +995,8 @@ def v1_chat_generate_request(
  image_data = conv.image_data
  audio_data = conv.audio_data
  modalities = conv.modalities
- stop = conv.stop_str or []
+ stop = conv.stop_str or [] if not request.ignore_eos else []
+
  if request.stop:
  if isinstance(request.stop, str):
  stop.append(request.stop)
sglang/srt/patch_torch.py CHANGED
@@ -14,6 +14,7 @@
  from typing import Callable, Union
 
  import torch
+ from packaging import version
  from torch.multiprocessing import reductions
 
 
@@ -69,3 +70,13 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
 
  def _modify_tuple(t, index: int, modifier: Callable):
  return *t[:index], modifier(t[index]), *t[index + 1 :]
+
+
+ def monkey_patch_torch_compile():
+ if version.parse(torch.__version__) < version.parse("2.8.0"):
+ # These things are cacheable by torch.compile. torch.compile just doesn't know it.
+ # This was fixed in PyTorch 2.8, but until then, we monkey patch.
+ import torch._higher_order_ops.auto_functionalize as af
+
+ af.auto_functionalized_v2._cacheable = True
+ af.auto_functionalized._cacheable = True