sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
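The most visible structural change in this release is the attention refactor: the monolithic `attention_backend.py` is removed and split into per-backend modules under the new `sglang/srt/layers/attention/` package, the Triton kernels move from `layers/triton_attention/` to `layers/attention/triton_ops/`, and `flashinfer_utils.py` moves under the same package. A new `managers/scheduler.py` is carved out of `tp_worker.py`, and `controller_single.py`/`controller_multi.py` are removed, their role apparently absorbed by the scheduler. Downstream code that imported the old module paths has to follow the move; a minimal sketch of the path update, using only module paths that appear in the file list above (the guarded import is illustrative):

# Old (0.3.1.post3) locations, now removed or relocated:
#   sglang/srt/layers/attention_backend.py
#   sglang/srt/layers/flashinfer_utils.py
#   sglang/srt/layers/triton_attention/extend_attention.py
# New (0.3.3) locations, taken from the file list above. Importing them still
# requires the corresponding backend dependencies (flashinfer / triton), so the
# import is guarded here in case a backend is optional in your setup.
try:
    from sglang.srt.layers.attention import flashinfer_backend, triton_backend
    from sglang.srt.layers.attention.triton_ops import extend_attention
except ImportError as exc:
    print(f"attention backend dependencies missing: {exc}")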
sglang/srt/models/torch_native_llama.py ADDED
@@ -0,0 +1,506 @@
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""

import types
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


def gate_up_proj_weight_loader(
    self,
    param: Parameter,
    loaded_weight: torch.Tensor,
    loaded_shard_id: Optional[int] = None,
):
    if loaded_shard_id is None:
        shard_offsets: List[Tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size
        for shard_id, shard_offset, shard_size in shard_offsets:
            loaded_weight_shard = loaded_weight.narrow(
                output_dim, shard_offset, shard_size
            )
            self.weight_loader(param, loaded_weight_shard, shard_id)
    else:
        assert loaded_shard_id < len(self.output_sizes)
        param_data = param.data
        shard_size = loaded_weight.shape[0]
        shard_offset = loaded_shard_id * shard_size
        param_data = param_data.narrow(0, shard_offset, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
    return


class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = torch.nn.Linear(
            hidden_size,
            intermediate_size * 2,
            bias=False,
        )
        self.gate_up_proj.output_sizes = [intermediate_size] * 2
        self.gate_up_proj.weight_loader = types.MethodType(
            gate_up_proj_weight_loader, self.gate_up_proj
        )
        self.gate_up_proj.weight.weight_loader = self.gate_up_proj.weight_loader
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x = self.down_proj(x)
        return x


def _get_shard_offset_mapping(self, loaded_shard_id: str):
    shard_offset_mapping = {
        "q": 0,
        "k": self.num_heads * self.head_size,
        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
    }
    return shard_offset_mapping.get(loaded_shard_id)


def _get_shard_size_mapping(self, loaded_shard_id: str):
    shard_size_mapping = {
        "q": self.num_heads * self.head_size,
        "k": self.num_kv_heads * self.head_size,
        "v": self.num_kv_heads * self.head_size,
    }
    return shard_size_mapping.get(loaded_shard_id)


def qkv_proj_weight_loader(
    self,
    param: Parameter,
    loaded_weight: torch.Tensor,
    loaded_shard_id: Optional[str] = None,
):
    if loaded_shard_id is None:
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            (
                "k",
                self.total_num_heads * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
            (
                "v",
                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
                self.total_num_kv_heads * self.head_size,
            ),
        ]
        for shard_id, shard_offset, shard_size in shard_offsets:
            loaded_weight_shard = loaded_weight.narrow(
                param.output_dim, shard_offset, shard_size
            )
            self.weight_loader(param, loaded_weight_shard, shard_id)
    else:
        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)
        param_data = param.data
        param_data = param_data.narrow(0, shard_offset, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
    return


class LlamaAttention(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        rope_is_neox_style: bool = True,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(
            config, "head_dim", self.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = torch.nn.Linear(
            hidden_size,
            (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
            bias=False,
        )
        self.qkv_proj.total_num_heads = self.total_num_heads
        self.qkv_proj.head_size = self.head_dim
        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
        self.qkv_proj.num_heads = self.total_num_heads
        self.qkv_proj.num_kv_heads = self.total_num_kv_heads
        self.qkv_proj.weight_loader = types.MethodType(
            qkv_proj_weight_loader, self.qkv_proj
        )
        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
            _get_shard_offset_mapping, self.qkv_proj
        )
        self.qkv_proj._get_shard_size_mapping = types.MethodType(
            _get_shard_size_mapping, self.qkv_proj
        )
        self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
        self.qkv_proj.weight.output_dim = 0
        self.o_proj = torch.nn.Linear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=rope_is_neox_style,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None
        ):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings
            )
        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            rope_is_neox_style=rope_is_neox_style,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(
                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
                )
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                forward_batch,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class TorchNativeLlamaForCausalLM(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.torchao_config = global_server_args_dict["torchao_config"]
        self.model = LlamaModel(config, quant_config=quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> LogitsProcessorOutput:
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, forward_batch
        )

    def get_hidden_dim(self, module_name):
        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
            return self.config.hidden_size, self.config.hidden_size
        elif module_name in ["kv_proj"]:
            return self.config.hidden_size, self.config.hidden_size // (
                self.config.num_attention_heads // self.config.num_key_value_heads
            )
        elif module_name == "gate_up_proj":
            return self.config.hidden_size, self.config.intermediate_size
        elif module_name == "down_proj":
            return self.config.intermediate_size, self.config.hidden_size
        else:
            raise NotImplementedError()

    def get_module_name(self, name):
        params_mapping = {
            "q_proj": "qkv_proj",
            "k_proj": "qkv_proj",
            "v_proj": "qkv_proj",
            "gate_proj": "gate_up_proj",
            "up_proj": "gate_up_proj",
        }
        return params_mapping.get(name, name)

    def get_module_name_from_weight_name(self, name):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id, num_shard)
            ("qkv_proj", "q_proj", "q", 3),
            ("qkv_proj", "k_proj", "k", 3),
            ("qkv_proj", "v_proj", "v", 3),
            ("gate_up_proj", "gate_proj", 0, 2),
            ("gate_up_proj", "up_proj", 1, 2),
        ]
        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
            if weight_name in name:
                return (
                    name.replace(weight_name, param_name)[: -len(".weight")],
                    num_shard,
                )
        return name[: -len(".weight")], 1

    def get_num_params(self):
        params_dict = dict(self.named_parameters())
        return len(params_dict)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        if (
            hasattr(self.config, "tie_word_embeddings")
            and self.config.tie_word_embeddings
        ):
            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
            param = self.lm_head.weight
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, self.model.embed_tokens.weight)
        apply_torchao_config_(self, params_dict, set(["proj.weight"]))


class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
    pass


EntryClass = [TorchNativeLlamaForCausalLM, TorchNativePhi3ForCausalLM]
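The new `torch_native_llama.py` above swaps vLLM's tensor-parallel linear layers for plain `torch.nn.Linear` modules and then re-attaches sharded weight loading by binding free functions onto each module with `types.MethodType`, so separate `q_proj`/`k_proj`/`v_proj` and `gate_proj`/`up_proj` checkpoint tensors can still be copied into the fused `qkv_proj` and `gate_up_proj` weights. Keeping the projections as ordinary `torch.nn.Linear` modules is what makes the model a natural target for the new torchao hook (`apply_torchao_config_` from `layers/torchao_utils.py`). The sketch below isolates that pattern with made-up shapes and attribute names; it is not sglang code, only an illustration of the `types.MethodType` trick used above.

# Minimal sketch of the pattern: bind a per-module weight loader onto a plain
# nn.Linear so q/k/v checkpoint tensors land in the right slice of a fused weight.
# All names and sizes here are illustrative.
import types

import torch
from torch import nn


def fused_weight_loader(self, param, loaded_weight, shard_id):
    # `self` is the nn.Linear instance; `shard_id` selects the slice of the
    # fused weight that this checkpoint tensor should fill.
    offset = self.shard_offsets[shard_id]
    target = param.data.narrow(0, offset, loaded_weight.shape[0])
    assert target.shape == loaded_weight.shape
    target.copy_(loaded_weight)


hidden_size, q_size, kv_size = 64, 64, 16
qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)
qkv_proj.shard_offsets = {"q": 0, "k": q_size, "v": q_size + kv_size}
# Bind the loader to this specific module, as torch_native_llama.py does.
qkv_proj.weight_loader = types.MethodType(fused_weight_loader, qkv_proj)

# Pretend these are the separate q/k/v tensors read from a checkpoint.
qkv_proj.weight_loader(qkv_proj.weight, torch.randn(q_size, hidden_size), "q")
qkv_proj.weight_loader(qkv_proj.weight, torch.randn(kv_size, hidden_size), "k")
qkv_proj.weight_loader(qkv_proj.weight, torch.randn(kv_size, hidden_size), "v")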
sglang/srt/models/xverse.py CHANGED
@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.model_runner import
+from sglang.srt.model_executor.model_runner import ForwardBatch


 class XverseMLP(nn.Module):
@@ -160,12 +160,12 @@ class XverseAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -222,7 +222,7 @@ class XverseDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -234,7 +234,7 @@ class XverseDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -271,7 +271,7 @@ class XverseModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -284,7 +284,7 @@ class XverseModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-
+                forward_batch,
                 residual,
             )
             # print(f"layer[{i}].hidden_states: {hidden_states}")
@@ -312,12 +312,12 @@ class XverseForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(
sglang/srt/models/xverse_moe.py CHANGED
@@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class XverseMLP(nn.Module):
@@ -244,12 +244,12 @@ class XverseAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -300,7 +300,7 @@ class XverseDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -312,7 +312,7 @@ class XverseDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -353,14 +353,14 @@ class XverseModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states,
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -388,11 +388,11 @@ class XverseMoeForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/openai_api/adapter.py CHANGED
@@ -858,11 +858,18 @@ def v1_chat_generate_request(
                                 openai_compatible_messages.append(
                                     {"role": message.role, "content": content["text"]}
                                 )
+                if openai_compatible_messages[-1]["role"] == "assistant":
+                    assistant_prefix = openai_compatible_messages[-1]["content"]
+                    openai_compatible_messages = openai_compatible_messages[:-1]
+                else:
+                    assistant_prefix = None
                 prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
                     openai_compatible_messages,
                     tokenize=True,
                     add_generation_prompt=True,
                 )
+                if assistant_prefix:
+                    prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
                 stop = request.stop
                 image_data = None
                 modalities = []