sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,227 @@
+from collections.abc import Iterable
+from typing import List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+from transformers import Llama4Config, Llama4VisionModel
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
+
+
+class Llama4ForConditionalGeneration(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    def __init__(
+        self,
+        config: Llama4Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+
+        self.vision_model = Llama4VisionModel(config.vision_config)
+        self.multi_modal_projector = Llama4MultiModalProjector(config)
+
+        # Initialize the language model
+        from sglang.srt.models.llama4 import Llama4ForCausalLM
+
+        self.language_model = Llama4ForCausalLM(
+            config.text_config,
+            quant_config=quant_config,
+            prefix=add_prefix("language_model", prefix),
+        )
+
+        self.logits_processor = LogitsProcessor(config.text_config)
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_token_id: int = mm_inputs.im_token_id
+
+        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(
+        self,
+        items: List[MultimodalDataItem],
+    ) -> torch.Tensor:
+        pixel_values = (
+            torch.concat([item.pixel_values for item in items])
+            .to(next(self.vision_model.parameters()).device)
+            .type(next(self.vision_model.parameters()).dtype)
+        )
+
+        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
+        image_features = image_outputs.last_hidden_state
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+        return projected_vision_flat
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: object,
+    ) -> torch.Tensor:
+
+        hs = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            positions=positions,
+        )
+
+        return hs
+
+    def permute_qk_weight_for_rotary(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+    ) -> Tuple[str, torch.Tensor]:
+
+        def permute(w: torch.Tensor, n_heads: int):
+            attn_in = self.language_model.config.head_dim * n_heads
+            attn_out = self.language_model.config.hidden_size
+
+            return (
+                w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
+                .transpose(1, 2)
+                .reshape(attn_in, attn_out)
+            )
+
+        modules = name.split(".")
+
+        # rotary embeds should be sliced
+        if ("wk" in modules or "k_proj" in modules) and modules[-1] == "weight":
+            loaded_weight = permute(
+                loaded_weight, self.language_model.config.num_key_value_heads
+            )
+        elif ("wq" in modules or "q_proj" in modules) and modules[-1] == "weight":
+            loaded_weight = permute(
+                loaded_weight, self.language_model.config.num_attention_heads
+            )
+
+        return name, loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
+            (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
+            (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
+            (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
+            (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
+            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
+            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+
+        num_experts = self.config.text_config.num_local_experts
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=num_experts,
+        )
+
+        for name, loaded_weight in weights:
+            if not "vision" in name:
+                name, loaded_weight = self.permute_qk_weight_for_rotary(
+                    name, loaded_weight
+                )
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                if "vision" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if ".experts" in name:
+                    # NOTE: llama4 fp8 has different weight format for experts
+                    if (
+                        "experts.gate_up_proj" not in name
+                        and "experts.down_proj" not in name
+                    ):
+                        for mapping in expert_params_mapping:
+                            param_name, weight_name, expert_id, shard_id = mapping
+                            if weight_name not in name:
+                                continue
+                            name = name.replace(weight_name, param_name)
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
+                            weight_loader(
+                                param,
+                                loaded_weight,
+                                name,
+                                shard_id=shard_id,
+                                expert_id=expert_id,
+                            )
+                            break
+                    else:
+                        if ".gate_up_proj" in name:
+                            name_list = [
+                                name.replace(
+                                    ".experts.gate_up_proj", ".experts.w13_weight"
+                                )
+                            ] * 2
+                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+                            shard_id_list = ["w1", "w3"]
+                        else:
+                            name_list = [
+                                name.replace(".experts.down_proj", ".experts.w2_weight")
+                            ]
+                            shard_id_list = ["w2"]
+                            loaded_weight_list = [loaded_weight]
+                        for name, loaded_weight, shard_id in zip(
+                            name_list, loaded_weight_list, shard_id_list
+                        ):
+                            param = params_dict[name]
+                            weight_loader = param.weight_loader
+                            for expert_id in range(num_experts):
+                                weight_loader(
+                                    param,
+                                    loaded_weight[expert_id].T,
+                                    name,
+                                    shard_id=shard_id,
+                                    expert_id=expert_id,
+                                )
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = Llama4ForConditionalGeneration
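The permute_qk_weight_for_rotary hook above is the notable part of this new file: within each attention head it regroups adjacently stored rotary pairs of rows into two half-blocks (the usual interleaved-to-half-split conversion) before the q/k weights reach the stacked-param loaders. A standalone sketch of the same view/transpose/reshape on toy sizes (the dimensions below are invented for the demo, not Llama 4's real ones):

import torch

# Toy sizes: 2 heads, head_dim 4, hidden_size 6 -> attn_in 8, attn_out 6.
n_heads, head_dim, hidden_size = 2, 4, 6
attn_in, attn_out = n_heads * head_dim, hidden_size
w = torch.arange(attn_in * attn_out).reshape(attn_in, attn_out)

permuted = (
    w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
    .transpose(1, 2)
    .reshape(attn_in, attn_out)
)

# Within each head, row order [0, 1, 2, 3] becomes [0, 2, 1, 3]:
# adjacent (cos, sin) pair rows are regrouped into two half-blocks.
print(permuted[:4, 0].tolist())  # [0, 12, 6, 18]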
sglang/srt/models/olmo.py CHANGED
sglang/srt/models/olmo2.py CHANGED
sglang/srt/models/olmoe.py CHANGED
sglang/srt/models/phi3_small.py CHANGED
sglang/srt/models/qwen.py CHANGED
sglang/srt/models/qwen2.py CHANGED
sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -30,12 +30,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import Qwen2VLConfig
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+    Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
 )
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+    Qwen2_5_VisionPatchEmbed,
+    Qwen2_5_VisionRotaryEmbedding,
+)
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
@@ -137,7 +141,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
-            use_qkv_parallel=False,
+            use_qkv_parallel=True,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
@@ -173,33 +177,6 @@ class Qwen2_5_VisionBlock(nn.Module):
         return x
 
 
-class Qwen2_5_VisionPatchEmbed(nn.Module):
-
-    def __init__(
-        self,
-        patch_size: int = 14,
-        temporal_patch_size: int = 2,
-        in_chans: int = 3,
-        embed_dim: int = 1152,
-    ) -> None:
-        super().__init__()
-        self.patch_size = patch_size
-        self.temporal_patch_size = temporal_patch_size
-        self.embed_dim = embed_dim
-
-        kernel_size = [temporal_patch_size, patch_size, patch_size]
-        self.proj = nn.Conv3d(
-            in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        target_dtype = self.proj.weight.dtype
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x.to(dtype=target_dtype)).view(L, self.embed_dim)
-        return x
-
-
 class Qwen2_5_VisionPatchMerger(nn.Module):
 
     def __init__(
@@ -244,21 +221,6 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         return out
 
 
-class Qwen2_5_VisionRotaryEmbedding(nn.Module):
-
-    def __init__(self, dim: int, theta: float = 10000.0) -> None:
-        super().__init__()
-        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(
-            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
-        )
-        freqs = torch.outer(seq, self.inv_freq)
-        return freqs
-
-
 class Qwen2_5_VisionTransformer(nn.Module):
 
     def __init__(
@@ -275,7 +237,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         spatial_merge_size: int = vision_config.spatial_merge_size
         self.spatial_merge_size = spatial_merge_size
         self.spatial_merge_unit: int = spatial_merge_size * spatial_merge_size
-
+        in_channels: int = vision_config.in_channels
         hidden_size: int = vision_config.hidden_size
         depth: int = vision_config.depth
         num_heads: int = vision_config.num_heads
@@ -286,7 +248,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
-
+            in_channels=in_channels,
             embed_dim=hidden_size,
         )
 
@@ -363,7 +325,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     @property
    def dtype(self) -> torch.dtype:
-        return self.
+        return self.patch_embed.proj.weight.dtype
 
     @property
     def device(self) -> torch.device:
@@ -467,9 +429,28 @@ cached_get_processor = lru_cache(get_processor)
 
 
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
-        config: Qwen2VLConfig,
+        config: Qwen2_5_VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -479,9 +460,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.visual = Qwen2_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            # NOTE:
-            # quantization
-            quant_config=None,
+            # NOTE: Qwen2_5-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
 
@@ -500,6 +481,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
         )
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@@ -553,14 +535,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
            otherwise it will be `(seq_len,).
            (Use input_metadata.mrope_positions to replace it)
        """
-        if "mrope_section" in self.config.rope_scaling:
+        if self.is_mrope_enabled:
            positions = forward_batch.mrope_positions
 
        if not (
            forward_batch.forward_mode.is_decode()
            or not forward_batch.contains_image_inputs()
        ):
-            if "mrope_section" in self.config.rope_scaling:
+            if self.is_mrope_enabled:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
@@ -610,23 +592,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if "visual" in name and "qkv.weight" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(
-                        3, visual_num_heads, head_size, visual_embed_dim
-                    )
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.hidden_size
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 if "visual" in name:
                     # adapt to VisionAttention
                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
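Both Qwen VL files in this release hoist the mrope check out of forward() into a single __init__-time flag, self.is_mrope_enabled, instead of re-testing config.rope_scaling on every step. A small illustration of what that flag is derived from (the rope_scaling fragment below is illustrative; actual values come from the checkpoint's config.json):

# Illustrative config fragment; "mrope_section" marks multimodal rotary embedding.
rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}

is_mrope_enabled = "mrope_section" in rope_scaling
# True -> forward() expects (3, seq_len) positions (temporal/height/width
# sections) via forward_batch.mrope_positions instead of a flat (seq_len,).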
sglang/srt/models/qwen2_moe.py CHANGED
sglang/srt/models/qwen2_vl.py CHANGED
@@ -152,7 +152,7 @@ class Qwen2VisionBlock(nn.Module):
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
-            use_qkv_parallel=False,
+            use_qkv_parallel=True,
             use_context_forward=use_context_forward,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
 
     @property
     def dtype(self) -> torch.dtype:
-        return
+        return self.patch_embed.proj.weight.dtype
 
     @property
     def device(self) -> torch.device:
@@ -423,6 +423,25 @@ cached_get_processor = lru_cache(get_processor)
 
 
 class Qwen2VLForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
         processor = cached_get_processor(self.config._name_or_path)
         grid_t, grid_h, grid_w = image_grid_thw
@@ -447,9 +466,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            # NOTE: Qwen2-VL vision encoder
-            # quantization
-            quant_config=None,
+            # NOTE: Qwen2-VL vision encoder currently supports BitsAndBytes 4-bit quantization.
+            # Other quantization methods (e.g., GPTQ, AWQ) are untested and may not be supported.
+            quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
 
@@ -467,6 +486,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             prefix=add_prefix("lm_head", prefix),
         )
 
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
@@ -521,14 +541,14 @@ class Qwen2VLForConditionalGeneration(nn.Module):
            otherwise it will be `(seq_len,).
            (Use input_metadata.mrope_positions to replace it)
        """
-        if "mrope_section" in self.config.rope_scaling:
+        if self.is_mrope_enabled:
            positions = forward_batch.mrope_positions
 
        if not (
            forward_batch.forward_mode.is_decode()
            or not forward_batch.contains_image_inputs()
        ):
-            if "mrope_section" in self.config.rope_scaling:
+            if self.is_mrope_enabled:
                assert positions.ndim == 2 and positions.size(0) == 3, (
                    "multimodal section rotary embedding requires "
                    f"(3, seq_len) positions, but got {positions.size()}"
@@ -577,24 +597,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-
-                if "visual" in name and "qkv.weight" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(
-                        3, visual_num_heads, head_size, visual_embed_dim
-                    )
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
-                elif "visual" in name and "qkv.bias" in name:
-                    visual_num_heads = self.config.vision_config.num_heads
-                    visual_embed_dim = self.config.vision_config.embed_dim
-                    head_size = visual_embed_dim // visual_num_heads
-                    loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
-                    loaded_weight = loaded_weight.transpose(0, 1)
-                    loaded_weight = loaded_weight.reshape(-1)
-
                 if "visual" in name:
                     # adapt to VisionAttention
                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
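The bitsandbytes_stacked_params_mapping added to both VL classes tells a quantized loader which fused parameter a checkpoint shard folds into, and at which shard index. A sketch of how such a mapping is typically consumed; the resolve helper below is hypothetical, not sglang's actual loader:

bitsandbytes_stacked_params_mapping = {
    # shard_name: (fused_param_name, shard_index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def resolve_fused_name(ckpt_name: str):
    """Hypothetical helper: map a checkpoint weight name to (fused name, index)."""
    for shard, (fused, index) in bitsandbytes_stacked_params_mapping.items():
        if shard in ckpt_name:
            return ckpt_name.replace(shard, fused), index
    return ckpt_name, None

print(resolve_fused_name("model.layers.0.self_attn.q_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 0)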
sglang/srt/models/stablelm.py CHANGED
sglang/srt/models/xverse.py CHANGED
sglang/srt/models/xverse_moe.py CHANGED
sglang/srt/openai_api/adapter.py CHANGED
@@ -983,6 +983,8 @@ def v1_chat_generate_request(
                 ):
                     encoded = encoded[1:]
                 prompt_ids += encoded
+        if tokenizer_manager.model_config.is_multimodal:
+            prompt = tokenizer_manager.tokenizer.decode(prompt_ids)
         stop = request.stop
         image_data = None
         audio_data = None
@@ -993,7 +995,8 @@
         image_data = conv.image_data
         audio_data = conv.audio_data
         modalities = conv.modalities
-        stop = conv.stop_str or []
+        stop = conv.stop_str or [] if not request.ignore_eos else []
+
         if request.stop:
             if isinstance(request.stop, str):
                 stop.append(request.stop)
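Note the precedence in the new stop expression: it parses as (conv.stop_str or []) if not request.ignore_eos else [], so template stop strings are dropped entirely when the client sets ignore_eos, while explicit request.stop values are still appended afterwards. A standalone sketch of that behavior (plain arguments stand in for sglang's request and conversation objects):

def resolve_stop(conv_stop_str, ignore_eos, request_stop=None):
    # Same parse as the diff: the ternary binds looser than `or`.
    stop = (conv_stop_str or []) if not ignore_eos else []
    if request_stop:
        if isinstance(request_stop, str):
            stop.append(request_stop)
        else:
            stop.extend(request_stop)
    return stop

print(resolve_stop(["</s>"], ignore_eos=False))                      # ['</s>']
print(resolve_stop(["</s>"], ignore_eos=True))                       # []
print(resolve_stop(["</s>"], ignore_eos=False, request_stop="###"))  # ['</s>', '###']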
sglang/srt/patch_torch.py CHANGED
@@ -14,6 +14,7 @@
 from typing import Callable, Union
 
 import torch
+from packaging import version
 from torch.multiprocessing import reductions
 
 
@@ -69,3 +70,13 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
 
 def _modify_tuple(t, index: int, modifier: Callable):
     return *t[:index], modifier(t[index]), *t[index + 1 :]
+
+
+def monkey_patch_torch_compile():
+    if version.parse(torch.__version__) < version.parse("2.8.0"):
+        # These things are cacheable by torch.compile. torch.compile just doesn't know it.
+        # This was fixed in PyTorch 2.8, but until then, we monkey patch.
+        import torch._higher_order_ops.auto_functionalize as af
+
+        af.auto_functionalized_v2._cacheable = True
+        af.auto_functionalized._cacheable = True
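A usage sketch for the new monkey_patch_torch_compile helper: calling it once before compiling lets torch.compile's cache treat the auto-functionalized higher-order ops as cacheable on PyTorch < 2.8 (on 2.8 and later the fix is upstream, so the patch body is skipped). Where sglang itself invokes the helper is not shown in this hunk:

import torch
from sglang.srt.patch_torch import monkey_patch_torch_compile

monkey_patch_torch_compile()  # no-op on PyTorch >= 2.8

@torch.compile
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

print(double(torch.ones(4)))  # tensor([2., 2., 2., 2.])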