sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +3 -13
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +158 -8
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +119 -75
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +5 -2
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +18 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +71 -53
- sglang/srt/conversation.py +78 -46
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +11 -3
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +236 -138
- sglang/srt/disaggregation/nixl/conn.py +242 -71
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +51 -2
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +31 -4
- sglang/srt/entrypoints/http_server.py +45 -3
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +147 -51
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +343 -8
- sglang/srt/layers/moe/ep_moe/layer.py +121 -2
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +77 -71
- sglang/srt/layers/quantization/fp8.py +110 -97
- sglang/srt/layers/quantization/fp8_kernel.py +81 -62
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +11 -14
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +13 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +93 -23
- sglang/srt/managers/schedule_policy.py +11 -8
- sglang/srt/managers/scheduler.py +140 -100
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +157 -47
- sglang/srt/managers/tp_worker.py +21 -21
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +4 -2
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +57 -41
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +3 -3
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +77 -39
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +3 -1
- sglang/srt/models/llama4.py +58 -13
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +52 -42
- sglang/srt/openai_api/protocol.py +20 -16
- sglang/srt/reasoning_parser.py +1 -1
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +2 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +64 -10
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +41 -6
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +92 -15
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +150 -137
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +1 -1
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/models/phi3_small.py
CHANGED
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig

-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -17,6 +17,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -294,13 +295,24 @@ class Phi3SmallModel(nn.Module):
         super().__init__()

         self.config = config
+
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
-        self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: Phi3SmallDecoderLayer(
                 config,
@@ -308,6 +320,8 @@ class Phi3SmallModel(nn.Module):
                 quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )

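Note on the phi3_small hunks above: the change wires the model into pipeline parallelism, so only the first pipeline rank builds the token embedding (later ranks get a PPMissingLayer placeholder) and make_layers is told which slice of decoder layers belongs to this rank. The following is a minimal standalone sketch of that pattern, not sglang's actual helpers; PPMissingLayerStub, build_embedding, and layer_range are illustrative names invented here.

import torch.nn as nn


class PPMissingLayerStub(nn.Module):
    """Illustrative stand-in for a layer that lives on another pipeline rank."""

    def forward(self, *args, **kwargs):
        raise RuntimeError("called a layer owned by a different pipeline-parallel rank")


def build_embedding(is_first_rank: bool, vocab_size: int, hidden_size: int) -> nn.Module:
    # Only the first pipeline-parallel rank materializes the token embedding;
    # downstream ranks receive hidden states from the previous stage, not token ids.
    if is_first_rank:
        return nn.Embedding(vocab_size, hidden_size)
    return PPMissingLayerStub()


def layer_range(num_layers: int, pp_rank: int, pp_size: int) -> tuple[int, int]:
    # Evenly split decoder layers across pipeline ranks, mirroring the
    # (start_layer, end_layer) pair a make_layers-style helper would return.
    per_rank = (num_layers + pp_size - 1) // pp_size
    start = pp_rank * per_rank
    end = min(start + per_rank, num_layers)
    return start, end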
sglang/srt/models/pixtral.py
ADDED
@@ -0,0 +1,467 @@
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""
+Using mistral-community/pixtral-12b as reference.
+"""
+
+import logging
+import math
+from typing import Iterable, List, Optional, Set, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import PixtralVisionConfig, PretrainedConfig
+from transformers.models.pixtral.modeling_pixtral import PixtralRotaryEmbedding
+from transformers.models.pixtral.modeling_pixtral import (
+    generate_block_attention_mask as _get_pixtral_attention_mask,
+)
+from transformers.models.pixtral.modeling_pixtral import position_ids_in_meshgrid
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import MergedColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+
+class PixtralHFMLP(nn.Module):
+    """MLP for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert config.intermediate_size is not None
+
+        # Use MergedColumnParallelLinear for gate_up_proj to handle combined weights
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=config.hidden_size,
+            output_sizes=[config.intermediate_size, config.intermediate_size],
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+
+        self.down_proj = RowParallelLinear(
+            input_size=config.intermediate_size,
+            output_size=config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up_output, _ = self.gate_up_proj(x)
+
+        # Apply SiLU activation and multiply
+        gate_up = self.act_fn(gate_up_output)
+
+        # Project back to hidden size
+        out, _ = self.down_proj(gate_up)
+        return out
+
+
+class PixtralHFTransformerBlock(nn.Module):
+    """Transformer block for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+        # Use SGLang's VisionAttention instead of vLLM's PixtralHFAttention
+        self.attention = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            quant_config=quant_config,
+            dropout=0.0,
+            use_context_forward=False,
+            softmax_in_single_precision=False,
+            flatten_batch=False,
+            prefix=f"{prefix}.attention",
+        )
+
+        self.feed_forward = PixtralHFMLP(
+            config, quant_config=quant_config, prefix=f"{prefix}.feed_forward"
+        )
+
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> torch.Tensor:
+        # Ensure hidden_states has the batch dimension [batch, seq_len, hidden_dim]
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+
+        # Apply attention norm - normalize along the last dimension
+        attn_normalized = self.attention_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through attention layer
+        attention_output = self.attention(
+            attn_normalized,
+            attention_mask=attention_mask,
+            cu_seqlens=None,
+            position_embeddings=position_embeddings,
+        )
+
+        # Apply first residual connection
+        hidden_states = hidden_states + attention_output
+
+        # Apply feed-forward norm - normalize along the last dimension
+        ffn_normalized = self.ffn_norm(hidden_states.view(-1, hidden_dim)).view(
+            batch_size, seq_len, hidden_dim
+        )
+
+        # Pass through feed-forward layer
+        # First reshape to 2D for the feed-forward network, then reshape back
+        ffn_output = self.feed_forward(ffn_normalized)
+
+        # Apply second residual connection
+        output = hidden_states + ffn_output
+
+        return output
+
+
+class PixtralHFTransformer(nn.Module):
+    """Transformer for PixtralHFVisionModel using SGLang components."""
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        num_hidden_layers = config.num_hidden_layers
+        if num_hidden_layers_override is not None:
+            num_hidden_layers = num_hidden_layers_override
+
+        self.layers = nn.ModuleList(
+            [
+                PixtralHFTransformerBlock(
+                    config=config,
+                    layer_id=layer_idx,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.layers.{layer_idx}",
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """Forward pass through transformer layers.
+
+        Args:
+            x: Input tensor
+            attention_mask: Optional attention mask
+            position_embeddings: Optional position embeddings for rotary attention
+            return_all_hidden_states: Whether to return all hidden states
+
+        Returns:
+            Either the final hidden state, or a list of all hidden states if
+            return_all_hidden_states is True
+        """
+        # For HF model compatibility, always start with the input
+        hidden_states = x
+        all_hidden_states = [hidden_states] if return_all_hidden_states else None
+
+        for i, layer in enumerate(self.layers):
+            hidden_states = layer(hidden_states, attention_mask, position_embeddings)
+            if return_all_hidden_states:
+                all_hidden_states.append(hidden_states)
+
+        if return_all_hidden_states:
+            return all_hidden_states
+        return hidden_states
+
+
+def resolve_visual_encoder_outputs(
+    outputs: Union[torch.Tensor, List[torch.Tensor]],
+    feature_sample_layers: Optional[List[int]],
+    post_norm: Optional[nn.Module],
+    num_hidden_layers: int,
+) -> torch.Tensor:
+    """Resolve outputs from visual encoder based on feature_sample_layers."""
+    if feature_sample_layers is None:
+        # Just use the last layer's output
+        if isinstance(outputs, list):
+            outputs = outputs[-1]
+        if post_norm is not None:
+            outputs = post_norm(outputs)
+        return outputs
+
+    # Handle the case where we want to use specific layers
+    if not isinstance(outputs, list):
+        raise ValueError(
+            "Expected outputs to be a list when feature_sample_layers is provided"
+        )
+
+    # Validate layer indices
+    for layer_idx in feature_sample_layers:
+        if layer_idx < 0 or layer_idx > num_hidden_layers:
+            raise ValueError(
+                f"Feature sample layer index {layer_idx} is out of range "
+                f"[0, {num_hidden_layers}]"
+            )
+
+    # Collect outputs from specified layers
+    selected_outputs = [outputs[layer_idx] for layer_idx in feature_sample_layers]
+
+    # Combine the outputs
+    combined_outputs = torch.cat(selected_outputs, dim=-1)
+
+    if post_norm is not None:
+        combined_outputs = post_norm(combined_outputs)
+
+    return combined_outputs
+
+
+class PixtralHFVisionModel(nn.Module):
+    """Hugging Face Pixtral Vision Model implemented using SGLang components."""
+
+    DEFAULT_IMAGE_TOKEN_ID = 10
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        return self.input_padder.pad_input_tokens(input_ids, image_inputs)
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        image_token_id: int = DEFAULT_IMAGE_TOKEN_ID,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_conv = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+
+        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
+
+        self.transformer = PixtralHFTransformer(
+            config,
+            quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.transformer",
+        )
+
+        # Check that num_hidden_layers is valid
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.transformer.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.transformer.layers)} "
+                "layers."
+            )
+
+        # Initialize patch position embedding
+        self.image_token_id = image_token_id
+        self.patch_positional_embedding = PixtralRotaryEmbedding(config)
+        self.input_padder = MultiModalityDataPaddingPatternMultimodalTokens(
+            [self.image_token_id]
+        )
+
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        image_sizes: list[tuple[int, int]],
+        output_hidden_states: bool = False,
+        feature_sample_layers: Optional[list[int]] = None,
+    ) -> Union[torch.Tensor, tuple]:
+        """
+        Args:
+            pixel_values: [batch_size, C, H, W], padded if multiple images
+            image_sizes: list of (H, W) for each image in the batch
+            output_hidden_states: Whether to return all hidden states.
+            feature_sample_layers: Layer indices whose features should be
+                concatenated and used as the visual encoder output. If none
+                are provided, the last layer is used.
+
+        Returns:
+            A tuple containing:
+            - hidden_states: Final model outputs (or selected layers if feature_sample_layers given)
+            - hidden_states tuple (optional): All hidden states if output_hidden_states=True
+        """
+        # batch patch images
+        embeds_orig = self.patch_conv(
+            pixel_values.to(device=self.device, dtype=self.dtype)
+        )
+        # crop the embeddings
+        embeds_2d = [
+            embed[..., : h // self.patch_size, : w // self.patch_size]
+            for embed, (h, w) in zip(embeds_orig, image_sizes)
+        ]
+
+        # flatten to sequence
+        embeds_1d = torch.cat([p.flatten(1).T for p in embeds_2d], dim=0)
+        embeds_featurized = self.ln_pre(embeds_1d).unsqueeze(0)
+
+        # positional embeddings
+        position_ids = position_ids_in_meshgrid(
+            embeds_2d,
+            max_width=self.image_size // self.patch_size,
+        ).to(self.device)
+
+        # The original PixtralRotaryEmbedding expects 2D input but returns a tuple of tensors (cos, sin)
+        # These tensors are used by apply_rotary_pos_emb in the transformer blocks
+        position_embedding = self.patch_positional_embedding(
+            embeds_featurized, position_ids
+        )
+        attention_mask = _get_pixtral_attention_mask(
+            [p.shape[-2] * p.shape[-1] for p in embeds_2d], embeds_featurized
+        )
+
+        return_all_hidden_states = (
+            output_hidden_states or feature_sample_layers is not None
+        )
+
+        transformer_outputs = self.transformer(
+            embeds_featurized,  # add batch dimension
+            attention_mask,
+            position_embedding,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        # Store all hidden states if requested
+        all_hidden_states = None
+        if isinstance(transformer_outputs, list):
+            all_hidden_states = transformer_outputs
+            # Use the last layer by default if feature_sample_layers is not specified
+            if feature_sample_layers is None:
+                out = transformer_outputs[-1]
+            else:
+                # Resolve outputs based on feature sample layers
+                out = resolve_visual_encoder_outputs(
+                    transformer_outputs,
+                    feature_sample_layers,
+                    None,
+                    self.config.num_hidden_layers,
+                )
+        else:
+            out = transformer_outputs
+
+        # Format return to be compatible with HuggingFace vision models
+        if output_hidden_states:
+            return type(
+                "VisualOutput",
+                (),
+                {
+                    "last_hidden_state": out,
+                    "hidden_states": all_hidden_states,
+                },
+            )
+        else:
+            return out
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
+        """Load weights from a HuggingFace checkpoint with proper parameter mapping."""
+        params_dict = dict(self.named_parameters())
+
+        # for (param, weight, shard_id): load weight into param as param's shard_id part
+        stacked_params_mapping = [
+            (".attention.qkv_proj", ".attention.q_proj", "q"),
+            (".attention.qkv_proj", ".attention.k_proj", "k"),
+            (".attention.qkv_proj", ".attention.v_proj", "v"),
+            (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
+            (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
+        ]
+
+        # Process each weight
+        for name, loaded_weight in weights:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name in name:
+                    # Replace the weight name part with the combined parameter name
+                    transformed_name = name.replace(weight_name, param_name)
+                    if transformed_name in params_dict:
+                        param = params_dict[transformed_name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, loaded_weight, shard_id)
+                    break
+            else:
+                if ".attention.o_proj" in name:
+                    alt_name = name.replace(".attention.o_proj", ".attention.proj")
+                    if alt_name in params_dict:
+                        name = alt_name
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+class PixtralVisionModel(PixtralHFVisionModel):
+    pass
+
+
+# Register the model classes for external access
+EntryClass = [PixtralVisionModel]
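Note on the new pixtral.py above: its load_weights method folds HuggingFace checkpoint names into SGLang's fused parameters (q/k/v projections into qkv_proj, gate/up projections into gate_up_proj, and o_proj onto VisionAttention's proj). Below is a small standalone sketch of that name remapping, using only the mapping tuples visible in the diff; remap_pixtral_weight_name is a hypothetical helper written here for illustration, not part of sglang.

from typing import Optional, Tuple, Union

# (fused parameter suffix, checkpoint suffix, shard id) as listed in the diff above.
STACKED_PARAMS_MAPPING = [
    (".attention.qkv_proj", ".attention.q_proj", "q"),
    (".attention.qkv_proj", ".attention.k_proj", "k"),
    (".attention.qkv_proj", ".attention.v_proj", "v"),
    (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
    (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
]


def remap_pixtral_weight_name(name: str) -> Tuple[str, Optional[Union[str, int]]]:
    """Return (model parameter name, shard id) for a checkpoint tensor name."""
    for fused_suffix, ckpt_suffix, shard_id in STACKED_PARAMS_MAPPING:
        if ckpt_suffix in name:
            # The fused weight is filled shard by shard via its weight_loader.
            return name.replace(ckpt_suffix, fused_suffix), shard_id
    if ".attention.o_proj" in name:
        # VisionAttention names its output projection "proj".
        return name.replace(".attention.o_proj", ".attention.proj"), None
    return name, None

# Example: remap_pixtral_weight_name("vision_tower.transformer.layers.0.attention.q_proj.weight")
# -> ("vision_tower.transformer.layers.0.attention.qkv_proj.weight", "q")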
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -125,16 +125,20 @@ class Qwen2_5_VisionBlock(nn.Module):
         self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
         if attn_implementation == "sdpa":
-            use_context_forward = False
             softmax_in_single_precision = False
+            qkv_backend = "sdpa"
             flatten_batch = True
         elif attn_implementation == "flash_attention_2":
             softmax_in_single_precision = False
-            use_context_forward = True
+            qkv_backend = "triton_attn"
             flatten_batch = True
         elif attn_implementation == "eager":
             softmax_in_single_precision = True
-            use_context_forward = False
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_3":
+            softmax_in_single_precision = False
+            qkv_backend = "fa3"
             flatten_batch = True

         self.attn = VisionAttention(
@@ -142,7 +146,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
             quant_config=quant_config,
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -139,21 +139,21 @@ class Qwen2VisionBlock(nn.Module):
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         if attn_implementation == "sdpa":
-            use_context_forward = False
+            qkv_backend = "sdpa"
             softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
             softmax_in_single_precision = False
-            use_context_forward = True
         elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
             softmax_in_single_precision = True
-            use_context_forward = False

         self.attn = VisionAttention(
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
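Note on the two Qwen vision hunks above: both blocks replace the old use_context_forward flag with a qkv_backend string passed to VisionAttention. The branch-by-branch mapping visible in the diffs can be restated as a small helper; select_vision_attn_backend is a hypothetical function written here only to summarize those branches, not part of sglang.

def select_vision_attn_backend(attn_implementation: str) -> tuple[str, bool]:
    """Map an HF-style attn_implementation to (qkv_backend, softmax_in_single_precision),
    following the branches shown in the hunks above."""
    if attn_implementation == "sdpa":
        return "sdpa", False
    if attn_implementation == "flash_attention_2":
        return "triton_attn", False
    if attn_implementation == "eager":
        # The eager path keeps the softmax in single precision.
        return "sdpa", True
    if attn_implementation == "flash_attention_3":
        # Only the Qwen2.5-VL block adds this branch in the diff.
        return "fa3", False
    raise ValueError(f"unsupported attn_implementation: {attn_implementation}")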
sglang/srt/models/roberta.py
CHANGED
@@ -57,7 +57,7 @@ class RobertaEmbedding(nn.Module):
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)

-        #
+        # Adapted from vllm: https://github.com/vllm-project/vllm/commit/4a18fd14ba4a349291c798a16bf62fa8a9af0b6b/vllm/model_executor/models/roberta.py
 
         pos_list = []
         token_list = []
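Note on the roberta.py hunk: the only change is a comment crediting the vLLM code the position-id handling was adapted from; the surrounding pos_list/token_list logic (not shown in full here) rebuilds position ids per sequence. For context, the standard RoBERTa convention offsets positions by the padding index and leaves padding tokens at padding_idx. The sketch below illustrates that generic convention only; it is not taken from the sglang implementation.

import torch


def roberta_position_ids(input_ids: torch.Tensor, padding_idx: int) -> torch.Tensor:
    """Generic RoBERTa position ids: non-padding tokens get padding_idx + 1,
    padding_idx + 2, ... while padding tokens stay at padding_idx."""
    mask = input_ids.ne(padding_idx).int()
    incremental = torch.cumsum(mask, dim=1) * mask
    return incremental.long() + padding_idx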