sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +474 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +71 -1
- sglang/check_env.py +3 -6
- sglang/srt/constrained/outlines_backend.py +15 -2
- sglang/srt/constrained/xgrammar_backend.py +22 -14
- sglang/srt/layers/activation.py +3 -0
- sglang/srt/layers/attention/flashinfer_backend.py +93 -48
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/custom_op_util.py +26 -0
- sglang/srt/layers/fused_moe/fused_moe.py +11 -4
- sglang/srt/layers/layernorm.py +4 -0
- sglang/srt/layers/logits_processor.py +10 -10
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +74 -9
- sglang/srt/managers/detokenizer_manager.py +1 -0
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/schedule_batch.py +104 -38
- sglang/srt/managers/schedule_policy.py +5 -1
- sglang/srt/managers/scheduler.py +204 -54
- sglang/srt/managers/session_controller.py +62 -0
- sglang/srt/managers/tokenizer_manager.py +38 -0
- sglang/srt/managers/tp_worker.py +12 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
- sglang/srt/model_executor/cuda_graph_runner.py +43 -6
- sglang/srt/model_executor/forward_batch_info.py +109 -15
- sglang/srt/model_executor/model_runner.py +99 -43
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/deepseek_v2.py +147 -44
- sglang/srt/models/gemma2.py +9 -8
- sglang/srt/models/llava.py +1 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/torch_native_llama.py +94 -78
- sglang/srt/openai_api/adapter.py +6 -2
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +58 -57
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +27 -1
- sglang/srt/server_args.py +78 -62
- sglang/srt/utils.py +71 -52
- sglang/test/runners.py +25 -6
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +30 -19
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma2.py
CHANGED
@@ -97,7 +97,7 @@ class Gemma2MLP(nn.Module):
 class Gemma2Attention(nn.Module):
     def __init__(
         self,
-
+        layer_id: int,
         config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
@@ -109,7 +109,7 @@ class Gemma2Attention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.
+        self.layer_id = layer_id
         self.config = config
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -156,13 +156,13 @@ class Gemma2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
         )

-        use_sliding_window =
+        use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            layer_id=
+            layer_id=layer_id,
             logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
@@ -188,7 +188,7 @@ class Gemma2Attention(nn.Module):
 class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
-
+        layer_id: int,
         config: PretrainedConfig,
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -196,7 +196,7 @@ class Gemma2DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
-
+            layer_id=layer_id,
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -269,8 +269,8 @@ class Gemma2Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Gemma2DecoderLayer(
-                for
+                Gemma2DecoderLayer(layer_id, config, cache_config, quant_config)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -332,6 +332,7 @@ class Gemma2ForCausalLM(nn.Module):
     # Gemma does not apply LoRA to the embedding layer.
     embedding_modules = {}
     embedding_padding_modules = []
+    supports_lora = True

     def __init__(
         self,
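Note: the gemma2.py hunks above thread a `layer_id` into `Gemma2Attention` so that sliding-window attention is enabled only on alternating layers, via the new rule `layer_id % 2 == 0 and hasattr(config, "sliding_window")`. The snippet below is an illustrative sketch of that selection rule only; it is not code shipped in the wheel, and the config object and layer count are invented for the example.

```python
# Illustrative sketch (not from the package): which layers the new even/odd rule
# in Gemma2Attention would route through sliding-window attention.
from types import SimpleNamespace


def sliding_window_layers(config) -> list:
    """Layer ids selected by `layer_id % 2 == 0 and hasattr(config, "sliding_window")`."""
    if not hasattr(config, "sliding_window"):
        return []
    return [layer_id for layer_id in range(config.num_hidden_layers) if layer_id % 2 == 0]


# Hypothetical 6-layer config: even layers use the sliding window, odd layers stay global.
cfg = SimpleNamespace(num_hidden_layers=6, sliding_window=4096)
print(sliding_window_layers(cfg))  # [0, 2, 4]
```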
sglang/srt/models/llava.py
CHANGED
@@ -345,7 +345,7 @@ class LlavaBaseForCausalLM(nn.Module):

             # Fill in the placeholder for the image
             extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.
+            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
             pt = 0
             for i in range(bs):
                 if not need_vision[i]:
sglang/srt/models/llavavid.py
CHANGED
@@ -169,7 +169,7 @@ class LlavaVidForCausalLM(nn.Module):

             # Fill in the placeholder for the image
             extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.
+            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
             pt = 0
             for i in range(bs):
                 if not need_vision[i]:
sglang/srt/models/olmo.py
CHANGED
@@ -223,8 +223,8 @@ class OlmoModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                OlmoDecoderLayer(config,
-                for
+                OlmoDecoderLayer(config, layer_id, quant_config)
+                for layer_id in range(config.num_hidden_layers)
             ]
         )
         self.norm = nn.LayerNorm(
@@ -250,7 +250,7 @@ class OlmoModel(nn.Module):
         hidden_states = input_embeds

         # Apply blocks one-by-one.
-        for
+        for layer_id, decoder_layer in enumerate(self.layers):
             # shape: (batch_size, seq_len, d_model)
             hidden_states = decoder_layer(
                 positions,
sglang/srt/models/phi3_small.py
ADDED
@@ -0,0 +1,447 @@
+import math
+from typing import Iterable, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import Phi3Config
+from transformers.configuration_utils import PretrainedConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.utils import make_layers
+
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+@torch.jit.script
+def quick_gelu(x):
+    return x * torch.sigmoid(1.702 * x)
+
+
+@torch.jit.script
+def gegelu(input, limit: Optional[float] = None):
+    a_gelu, a_linear = input[..., ::2], input[..., 1::2]
+    if limit is not None:
+        a_gelu = torch.where(
+            torch.isinf(a_gelu), a_gelu, a_gelu.clamp(min=None, max=limit)
+        )
+        a_linear = torch.where(
+            torch.isinf(a_linear),
+            a_linear,
+            a_linear.clamp(min=-limit, max=limit),
+        )
+    out_gelu = quick_gelu(a_gelu)
+    return out_gelu * (a_linear + 1)
+
+
+class Phi3SmallMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        assert (
+            self.config.hidden_act == "gegelu"
+        ), "Only `gegelu` is supported for the 4.7 series of models .."
+        self.hidden_size = config.hidden_size
+        self.gegelu_limit = config.gegelu_limit
+        self.intermediate_size = config.intermediate_size
+
+        self.up_proj = MergedColumnParallelLinear(
+            self.hidden_size,
+            2 * [self.intermediate_size],
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+    def forward(self, x):
+        gate_up, _ = self.up_proj(x)
+        x = gegelu(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Phi3SmallSelfAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.config = config
+        self.sparse_block_size = config.blocksparse_block_size
+        self.homo_heads = config.blocksparse_homo_head_pattern
+        self.local_blocks = config.blocksparse_num_local_blocks
+        self.vert_stride = config.blocksparse_vert_stride
+
+        assert (
+            config.blocksparse_block_size == config.blocksparse_triton_kernel_block_size
+        )
+
+        self.hidden_size = config.hidden_size
+        # Number of Query Heads
+        self.num_heads = config.num_attention_heads
+
+        self.head_dim = self.hidden_size // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        # Number of total Key Value Heads before tensor parallel
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_q_per_kv = self.num_heads // self.num_key_value_heads
+        if self.tp_size > 1:
+            assert self.num_key_value_heads % self.tp_size == 0
+        self.num_kv_heads_per_partion = max(1, self.num_key_value_heads // self.tp_size)
+        self.num_heads_per_partition = self.num_heads // self.tp_size
+
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_embedding_base = config.rope_embedding_base
+        self.rope_position_scale = config.rope_position_scale
+        self.is_causal = True
+
+        norm_factor = None
+        if config.mup_use_scaling:
+            norm_factor = self.head_dim / config.mup_attn_multiplier
+        else:
+            norm_factor = math.sqrt(self.head_dim)
+        self.scale = 1 / norm_factor
+
+        self.query_key_value = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.num_heads,
+            self.num_key_value_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+
+        self.dense = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        if getattr(self.config, "rope_scaling", None) is not None:
+            rope_scaling = self.config.rope_scaling
+            for key in rope_scaling:
+                if isinstance(rope_scaling[key], list):
+                    rope_scaling[key] = tuple(rope_scaling[key])
+
+            if "factor" not in rope_scaling:
+                rope_scaling["factor"] = self.rope_position_scale
+        else:
+            rope_scaling = {
+                "rope_type": "linear",
+                "factor": self.rope_position_scale,
+            }
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_embedding_base,
+            rope_scaling=rope_scaling,
+        )
+
+        # blocksparse params
+        self.blocksparse_block_size = config.blocksparse_block_size
+        self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
+        self.blocksparse_vert_stride = config.blocksparse_vert_stride
+
+        use_dense_attn = (
+            getattr(self.config, "dense_attention_every_n_layers", None)
+            and (self.layer_id + 1) % self.config.dense_attention_every_n_layers == 0
+        )
+
+        bs_params = None
+        if not use_dense_attn:
+            bs_params = {
+                "max_seqlen": self.max_position_embeddings,
+                "num_heads": self.num_heads_per_partition,
+                "num_kv_heads": self.num_kv_heads_per_partion,
+                "block_size": self.sparse_block_size,
+                "local_blocks": self.local_blocks,
+                "vert_stride": self.vert_stride,
+                "homo_head": self.homo_heads,
+            }
+
+        self.attn = RadixAttention(
+            self.num_heads_per_partition,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads_per_partion,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        qkv, _ = self.query_key_value(hidden_states)
+
+        qkv = qkv.view(qkv.shape[:-1] + (-1, (self.num_q_per_kv + 2), self.head_dim))
+        q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)
+
+        # NOTE: this is required by RotaryEmbed, which indeed does not have to
+        # TODO: allow 3D QK for rotary forward
+        q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
+        k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
+        v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
+
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+        output, _ = self.dense(attn_output)
+
+        return output
+
+
+class Phi3SmallDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        cache_config=None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Phi3SmallSelfAttention(
+            config, layer_id, quant_config=quant_config
+        )
+        self.mlp = Phi3SmallMLP(config, quant_config)
+
+        self.input_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_epsilon
+        )
+        self.post_attention_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_epsilon
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class Phi3SmallModel(nn.Module):
+
+    def __init__(
+        self,
+        config: Phi3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.config = config
+        cache_config = None
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        self.mup_embedding_multiplier = config.mup_embedding_multiplier
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Phi3SmallDecoderLayer(
+                config, int(prefix.split(".")[-1]), cache_config, quant_config
+            ),
+            prefix=f"{prefix}.layers",
+        )
+
+        self.final_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_epsilon
+        )
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: Optional[torch.LongTensor],
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor],
+    ) -> Union[torch.Tensor]:
+
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.get_input_embeddings(input_ids)
+        if (
+            self.mup_embedding_multiplier is not None
+            and self.mup_embedding_multiplier > 0.0
+        ):
+            hidden_states = hidden_states * self.mup_embedding_multiplier
+
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states = layer(positions, hidden_states, forward_batch=forward_batch)
+
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
+
+
+class Phi3SmallForCausalLM(nn.Module):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(
+        self,
+        config: Phi3Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
+    ):
+
+        super().__init__()
+
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Phi3SmallModel(
+            config=config,
+            quant_config=quant_config,
+            prefix="model",
+        )
+        self.torchao_config = global_server_args_dict["torchao_config"]
+        self.vocab_size = config.vocab_size
+        self.mup_width_multiplier = config.mup_width_multiplier
+        self.lm_head = ParallelLMHead(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            quant_config=quant_config,
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        # tokens in tiktoken but not used
+        if hasattr(config, "dummy_token_indices"):
+            device = self.lm_head.weight.device
+            self.register_buffer(
+                "dummy_token_indices",
+                torch.LongTensor(config.dummy_token_indices).to(device),
+                persistent=False,
+            )
+        else:
+            self.dummy_token_indices = None
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, value):
+        self.lm_head = value
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
+        if self.dummy_token_indices is not None and logits is not None:
+            logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
+        return logits
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: Optional[torch.LongTensor],
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
+
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight)
+
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
+
+
+EntryClass = Phi3SmallForCausalLM
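For reference, the `gegelu` activation added in phi3_small.py splits the up-projection's output into interleaved GELU and linear halves, so the activation halves the channel count. Below is a standalone restatement of that function for quick experimentation; the tensor shapes in the usage example are made up, and the packaged version is `@torch.jit.script`-compiled and lives alongside the model code above.

```python
# Standalone sketch of the gegelu activation from phi3_small.py (reference only).
from typing import Optional

import torch


def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(1.702 * x)


def gegelu(x: torch.Tensor, limit: Optional[float] = None) -> torch.Tensor:
    # Even-indexed channels take the gated-GELU path, odd-indexed channels stay linear;
    # the result has half as many channels as the input.
    a_gelu, a_linear = x[..., ::2], x[..., 1::2]
    if limit is not None:
        a_gelu = torch.where(torch.isinf(a_gelu), a_gelu, a_gelu.clamp(max=limit))
        a_linear = torch.where(
            torch.isinf(a_linear), a_linear, a_linear.clamp(min=-limit, max=limit)
        )
    return quick_gelu(a_gelu) * (a_linear + 1)


# The MLP's up_proj emits 2 * intermediate_size features; gegelu folds them back down.
x = torch.randn(2, 8, 2 * 32)            # hypothetical (batch, seq, 2 * intermediate_size)
print(gegelu(x, limit=10.0).shape)       # torch.Size([2, 8, 32])
```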
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -44,6 +44,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
 )
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -559,6 +560,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         )

         self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
         pixel_values = image_input["pixel_values"].type(self.visual.dtype)
@@ -577,6 +579,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        get_embedding: bool = False,
     ):
         """Run forward pass for Qwen2-VL.

@@ -599,8 +602,8 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         image_inputs = [
             img for img in forward_batch.image_inputs if img is not None
         ]
-
-
+        if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+            positions = forward_batch.mrope_positions
         if (
             forward_batch.forward_mode.is_decode()
             or image_inputs is None
@@ -616,7 +619,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):

             inputs_embeds = self.model.embed_tokens(input_ids)
             extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            prefix_lens_cpu = forward_batch.
+            prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
             for i, image in enumerate(forward_batch.image_inputs):
                 if image is None:
                     continue
@@ -655,9 +658,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
         )
-
-
-
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
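Both phi3_small.py and qwen2_vl.py now construct `Pooler(pooling_type=PoolingType.LAST, normalize=True)` and return pooled embeddings when `forward(..., get_embedding=True)` is called. The sketch below only illustrates what last-token pooling with L2 normalization means for a packed batch of requests; it is not the sglang `Pooler` implementation, and the tensor shapes and helper name are invented for the example.

```python
# Minimal illustration (not sglang's Pooler) of last-token pooling + L2 normalization,
# the behavior PoolingType.LAST with normalize=True is configured for above.
import torch


def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    """hidden_states: packed (total_tokens, hidden) tensor; seq_lens: per-request lengths."""
    last_indices = torch.cumsum(seq_lens, dim=0) - 1   # index of each request's final token
    embeddings = hidden_states[last_indices]
    return torch.nn.functional.normalize(embeddings, p=2, dim=-1)


hidden = torch.randn(7, 16)                 # hypothetical: 7 packed tokens, hidden size 16
lens = torch.tensor([3, 4])                 # two requests of lengths 3 and 4
print(last_token_pool(hidden, lens).shape)  # torch.Size([2, 16])
```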