sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,467 @@
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Using mistral-community/pixtral-12b as reference.
+"""
+
+import logging
+import math
+from typing import Iterable, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PixtralVisionConfig, PretrainedConfig
+from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
+from transformers.models.pixtral.modeling_pixtral import (
+    generate_block_attention_mask as _get_pixtral_attention_mask,
+)
+from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+
+class PixtralHFMLP(nn.Module):
+    """MLP for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert config.intermediate_size is not None
+
+        # Use MergedColumnParallelLinear for gate_up_proj to handle combined weights
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_sizes=[config.intermediate_size, config.intermediate_size],
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+
+        self.down_proj = RowParallelLinear(
+            input_size=config.intermediate_size,
+            output_size=config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up_output, _ = self.gate_up_proj(x)
+
+        # Apply SiLU activation and multiply
+        gate_up = self.act_fn(gate_up_output)
+
+        # Project back to hidden size
+        out, _ = self.down_proj(gate_up)
+        return out
+
+
+class PixtralHFTransformerBlock(nn.Module):
+    """Transformer block for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+        # Use SGLang's VisionAttention instead of vLLM's PixtralHFAttention
+        self.attention = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=0.0,
+            use_context_forward=False,
+            softmax_in_single_precision=False,
+            flatten_batch=False,
+            prefix=f"{prefix}.attention",
+        )
+
+        self.feed_forward = PixtralHFMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
+
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> torch.Tensor:
+        # Ensure hidden_states has the batch dimension [batch, seq_len, hidden_dim]
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+
+        # Apply attention norm - normalize along the last dimension
+        attn_normalized = self.attention_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through attention layer
+        attention_output = self.attention(
+            attn_normalized,
+            attention_mask=attention_mask,
+            cu_seqlens=None,
+            position_embeddings=position_embeddings,
+        )
+
+        # Apply first residual connection
+        hidden_states = hidden_states + attention_output
+
+        # Apply feed-forward norm - normalize along the last dimension
+        ffn_normalized = self.ffn_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through feed-forward layer
+        # First reshape to 2D for the feed-forward network, then reshape back
+        ffn_output = self.feed_forward(ffn_normalized)
+
+        # Apply second residual connection
+        output = hidden_states + ffn_output
+
+        return output
+
+
+class PixtralHFTransformer(nn.Module):
+    """Transformer for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        num_hidden_layers = config.num_hidden_layers
+        if num_hidden_layers_override is not None:
+            num_hidden_layers = num_hidden_layers_override
+
+        self.layers = nn.ModuleList(
+            [
+                PixtralHFTransformerBlock(
+                    config=config,
+                    layer_id=layer_idx,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Forward pass through transformer layers.
+
+        Args:
+            x: Input tensor
+            attention_mask: Optional attention mask
+            position_embeddings: Optional position embeddings for rotary attention
+            return_all_hidden_states: Whether to return all hidden states
+
+        Returns:
+            Either the final hidden state, or a list of all hidden states if
+            return_all_hidden_states is True
+        """
+        # For HF model compatibility, always start with the input
+        hidden_states = x
+        all_hidden_states = [hidden_states] if return_all_hidden_states else None
+
+        for i, layer in enumerate(self.layers):
+            hidden_states = layer(hidden_states, attention_mask, position_embeddings)
+            if return_all_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+        if return_all_hidden_states:
+            return all_hidden_states
+        return hidden_states
+
+
+def resolve_visual_encoder_outputs(
+    outputs: Union[torch.Tensor, List[torch.Tensor]],
+    feature_sample_layers: Optional[List[int]],
+    post_norm: Optional[nn.Module],
+    num_hidden_layers: int,
+) -> torch.Tensor:
+    """Resolve outputs from visual encoder based on feature_sample_layers."""
+    if feature_sample_layers is None:
+        # Just use the last layer's output
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+        if post_norm is not None:
+            outputs = post_norm(outputs)
+        return outputs
+
+    # Handle the case where we want to use specific layers
+    if not isinstance(outputs, list):
+        raise ValueError(
+            "Expected outputs to be a list when feature_sample_layers is provided"
+        )
+
+    # Validate layer indices
+    for layer_idx in feature_sample_layers:
+        if layer_idx < 0 or layer_idx > num_hidden_layers:
+            raise ValueError(
+                f"Feature sample layer index {layer_idx} is out of range "
+                f"[0, {num_hidden_layers}]"
+            )
+
+    # Collect outputs from specified layers
+    selected_outputs = [outputs[layer_idx] for layer_idx in feature_sample_layers]
+
+    # Combine the outputs
+    combined_outputs = torch.cat(selected_outputs, dim=-1)
+
+    if post_norm is not None:
+        combined_outputs = post_norm(combined_outputs)
+
+    return combined_outputs
+
+
+class PixtralHFVisionModel(nn.Module):
+    """Hugging Face Pixtral Vision Model implemented using SGLang components."""
+
+    DEFAULT_IMAGE_TOKEN_ID = 10
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, image_inputs)
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_conv = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+
+        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
+
+        self.transformer = PixtralHFTransformer(
+            config,
+            quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.transformer",
+        )
+
+        # Check that num_hidden_layers is valid
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.transformer.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.transformer.layers)} "
+                "layers."
+            )
+
+        # Initialize patch position embedding
+        self.image_token_id = image_token_id
+        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
+            [self.image_token_id]
+        )
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+        output_hidden_states: bool = False,
+        feature_sample_layers: Optional[list[int]] = None,
+    ) -> Union[torch.Tensor, tuple]:
+        """
+        Args:
+            pixel_values: [batch_size, C, H, W], padded if multiple images
+            image_sizes: list of (H, W) for each image in the batch
+            output_hidden_states: Whether to return all hidden states.
+            feature_sample_layers: Layer indices whose features should be
+                concatenated and used as the visual encoder output. If none
+                are provided, the last layer is used.
+
+        Returns:
+            A tuple containing:
+            - hidden_states: Final model outputs (or selected layers if feature_sample_layers given)
+            - hidden_states tuple (optional): All hidden states if output_hidden_states=True
+        """
+        # batch patch images
+        embeds_orig = self.patch_conv(
+            pixel_values.to(device=self.device, dtype=self.dtype)
+        )
+        # crop the embeddings
+        embeds_2d = [
+            embed[..., : h // self.patch_size, : w // self.patch_size]
+            for embed, (h, w) in zip(embeds_orig, image_sizes)
+        ]
+
+        # flatten to sequence
+        embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
+        embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0)
+
+        # positional embeddings
+        position_ids = position_ids_in_meshgrid(
+            embeds_2d,
+            max_width=self.image_size // self.patch_size,
+        ).to(self.device)
+
+        # The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin)
+        # These tensors are used by apply_rotary_pos_emb in the transformer blocks
+        position_embedding = self.patch_positional_embedding(
+            embeds_featurized, position_ids
+        )
+        attention_mask = _get_pixtral_attention_mask(
+            [p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized
+        )
+
+        return_all_hidden_states = (
+            output_hidden_states or feature_sample_layers is not None
+        )
+
+        transformer_outputs = self.transformer(
+            embeds_featurized,  # add batch dimension
+            attention_mask,
+            position_embedding,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        # Store all hidden states if requested
+        all_hidden_states = None
+        if isinstance(transformer_outputs, list):
+            all_hidden_states = transformer_outputs
+            # Use the last layer by default if feature_sample_layers is not specified
+            if feature_sample_layers is None:
+                out = transformer_outputs[-1]
+            else:
+                # Resolve outputs based on feature sample layers
+                out = resolve_visual_encoder_outputs(
+                    transformer_outputs,
+                    feature_sample_layers,
+                    None,
+                    self.config.num_hidden_layers,
+                )
+        else:
+            out = transformer_outputs
+
+        # Format return to be compatible with HuggingFace vision models
+        if output_hidden_states:
+            return type(
+                "VisualOutput",
+                (),
+                {
+                    "last_hidden_state": out,
+                    "hidden_states": all_hidden_states,
+                },
+            )
+        else:
+            return out
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+        """Load weights from a HuggingFace checkpoint with proper parameter mapping."""
+        params_dict = dict(self.named_parameters())
+
+        # for (param, weight, shard_id): load weight into param as param's shard_id part
+        stacked_params_mapping = [
+            (".attention.qkv_proj", ".attention.q_proj", "q"),
+            (".attention.qkv_proj", ".attention.k_proj", "k"),
+            (".attention.qkv_proj", ".attention.v_proj", "v"),
+            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
+            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
+        ]
+
+        # Process each weight
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    # Replace the weight name part with the combined parameter name
+                    transformed_name = name.replace(weight_name, param_name)
+                    if transformed_name in params_dict:
+                        param = params_dict[transformed_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                if ".attention.o_proj" in name:
+                    alt_name = name.replace(".attention.o_proj", ".attention.proj")
+                    if alt_name in params_dict:
+                        name = alt_name
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+class PixtralVisionModel(PixtralHFVisionModel):
+    pass
+
+
+# Register the model classes for external access
+EntryClass = [PixtralVisionModel]
sglang/srt/models/qwen2.py
CHANGED
@@ -15,12 +15,14 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -36,11 +38,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
@@ -50,6 +53,9 @@ from sglang.srt.utils import add_prefix, make_layers
 Qwen2Config = None
 
 
+logger = logging.getLogger(__name__)
+
+
 class Qwen2MLP(nn.Module):
     def __init__(
         self,
@@ -245,15 +251,21 @@ class Qwen2Model(nn.Module):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.
-
-
-
-
-
+        self.pp_group = get_pp_group()
+
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         # Use the provided decoder layer type or default to Qwen2DecoderLayer
         decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: decoder_layer_type(
                 layer_id=idx,
@@ -261,9 +273,14 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
-        self.
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
 
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         if hasattr(self.config, "scale_emb"):
@@ -280,13 +297,20 @@ class Qwen2Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-
-
-
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-
-
-
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -294,7 +318,15 @@ class Qwen2Model(nn.Module):
                 forward_batch,
                 residual,
             )
-
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
     # If this function is called, it should always initialize KV cache scale
@@ -348,6 +380,7 @@ class Qwen2ForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(
@@ -379,14 +412,33 @@ class Qwen2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(
-
-
-
-
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            return
+            return hidden_states
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -400,6 +452,17 @@ class Qwen2ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -426,9 +489,15 @@ class Qwen2ForCausalLM(nn.Module):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
-
-
-
+
+            if name in params_dict.keys():
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+            else:
+                logger.warning(f"Parameter {name} not found in params_dict")
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -146,6 +146,8 @@ class Qwen2_5_VisionBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
+            rotary_embed="normal",
+            proj_bias=True,
             qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
@@ -497,6 +499,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         return pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        if any(item.precomputed_features is not None for item in items):
+            if not all(item.precomputed_features is not None for item in items):
+                raise NotImplementedError(
+                    "MM inputs where only some items are precomputed."
+                )
+            return torch.concat([item.precomputed_features for item in items])
         # in qwen-vl, last dim is the same
         pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
             self.visual.dtype