sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py
CHANGED
@@ -18,38 +18,30 @@ limitations under the License.
 """Inference-only Mixtral model."""
 from typing import Iterable, Optional, Tuple

-import numpy as np
 import torch
-import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
-from vllm import _custom_ops as ops
 from vllm.config import CacheConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
-)
-from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import print_warning_once

+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -69,216 +61,44 @@ class MixtralMoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
     ):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype

         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(
-            self.hidden_size,
-            self.num_total_experts,
+            hidden_size,
+            num_experts,
             bias=False,
-            params_dtype=self.params_dtype,
+            params_dtype=params_dtype,
             quant_config=None,
+            prefix=f"{prefix}.gate",
         )

-        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        self.w13_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.hidden_size,
-                dtype=params_dtype,
-            )
-        )
-        self.w2_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                dtype=params_dtype,
-            )
-        )
-
-        set_weight_attrs(
-            self.w13_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=f"{prefix}.experts",
         )

-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-            self.w2_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            # process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(
-                    self.w13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.w2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8."
-                    )
-                self.a13_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-                self.a2_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-
-                set_weight_attrs(
-                    self.a13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.a2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-    ):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                shard, :
-            ]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
-            param_data[expert_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                self.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
-                    self.w13_weight.data[expert, :, :]
-                )
-                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
-                    self.w2_weight.data[expert, :, :]
-                )
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        # If checkpoint is fp8 + static, cleanup act_scales.
-        # Since state_dict has an act_scale per expert but our kernels
-        # are passed one act_scale shared across all experts.
-        elif self.quant_config.activation_scheme == "static":
-            if self.a13_scale is None or self.a2_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None."
-                )
-
-            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
-                print_warning_once(
-                    "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. "
-                )
-
-            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.w13_weight,
-            self.w2_weight,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-            use_fp8=self.use_fp8,
-            w1_scale=self.w13_scale,
-            w2_scale=self.w2_scale,
-            a1_scale=self.a13_scale,
-            a2_scale=self.a2_scale,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)


 class MixtralAttention(nn.Module):
@@ -291,7 +111,7 @@ class MixtralAttention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
-        sliding_window: Optional[int] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -314,7 +134,6 @@ class MixtralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -323,12 +142,14 @@ class MixtralAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -365,6 +186,7 @@ class MixtralDecoderLayer(nn.Module):
         config: MixtralConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -377,8 +199,8 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.block_sparse_moe = MixtralMoE(
             num_experts=config.num_local_experts,
@@ -386,6 +208,7 @@ class MixtralDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.block_sparse_moe",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -422,6 +245,7 @@ class MixtralModel(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -431,10 +255,11 @@ class MixtralModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        # config.num_hidden_layers=16
         self.layers = nn.ModuleList(
             [
-                MixtralDecoderLayer(config, i, quant_config=quant_config)
+                MixtralDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"{prefix}.layers"
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -462,6 +287,7 @@ class MixtralModel(nn.Module):


 class MixtralForCausalLM(nn.Module):
+
     def __init__(
         self,
         config: MixtralConfig,
@@ -471,11 +297,11 @@ class MixtralForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = MixtralModel(config, quant_config=quant_config)
+        self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

-    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -484,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -496,40 +324,13 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]

-        expert_params_mapping = (
-            [
-                # These are the weight scales for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-                    f"experts.{expert_id}.{weight_name}.weight_scale",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-            + [
-                # These are the weights for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-                    f"experts.{expert_id}.{weight_name}.weight",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-            + [
-                # These are the activation scales for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-                    f"experts.{expert_id}.{weight_name}.act_scale",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
         )

         params_dict = dict(self.named_parameters())
@@ -544,25 +345,35 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
-                        param, loaded_weight, weight_name, expert_id=expert_id
+                        param,
+                        loaded_weight,
+                        weight_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if name is None:
+                        continue
+
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
@@ -570,9 +381,4 @@ class MixtralForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight)


-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 EntryClass = MixtralForCausalLM
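Note on the MixtralMoE refactor above: the hand-rolled expert weights, fp8 scales, and custom weight_loader are replaced by vllm's FusedMoE layer, and load_weights now walks (param_name, weight_name, expert_id, shard_id) tuples produced by FusedMoE.make_expert_params_mapping. The sketch below imitates only the renaming step of that loop so the data flow is easier to follow; the mapping entries are hypothetical placeholders with the documented shape, not the values vllm actually generates.

# Hedged sketch: mimics the per-expert rename performed in the new load_weights.
# The tuples are made-up examples; real ones come from FusedMoE.make_expert_params_mapping.
expert_params_mapping = [
    ("experts.w13_weight", "experts.0.w1.weight", 0, "w1"),  # hypothetical entry
    ("experts.w2_weight", "experts.0.w2.weight", 0, "w2"),   # hypothetical entry
    ("experts.w13_weight", "experts.0.w3.weight", 0, "w3"),  # hypothetical entry
]

checkpoint_name = "model.layers.0.block_sparse_moe.experts.0.w1.weight"

for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
    if weight_name not in checkpoint_name:
        continue
    # Same rename the diff performs before looking the parameter up in
    # params_dict and invoking its weight_loader with shard_id/expert_id.
    renamed = checkpoint_name.replace(weight_name, param_name)
    print(renamed, expert_id, shard_id)
    break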
@@ -29,7 +29,6 @@ from vllm.distributed import (
|
|
29
29
|
get_tensor_model_parallel_world_size,
|
30
30
|
tensor_model_parallel_all_reduce,
|
31
31
|
)
|
32
|
-
from vllm.model_executor.layers.layernorm import RMSNorm
|
33
32
|
from vllm.model_executor.layers.linear import (
|
34
33
|
QKVParallelLinear,
|
35
34
|
ReplicatedLinear,
|
@@ -43,8 +42,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
43
42
|
)
|
44
43
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
45
44
|
|
45
|
+
from sglang.srt.layers.layernorm import RMSNorm
|
46
46
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
47
47
|
from sglang.srt.layers.radix_attention import RadixAttention
|
48
|
+
from sglang.srt.layers.sampler import Sampler
|
48
49
|
from sglang.srt.model_executor.forward_batch_info import InputMetadata
|
49
50
|
|
50
51
|
|
@@ -160,7 +161,6 @@ class MixtralAttention(nn.Module):
|
|
160
161
|
max_position: int = 4096 * 32,
|
161
162
|
rope_theta: float = 10000,
|
162
163
|
quant_config: Optional[QuantizationConfig] = None,
|
163
|
-
sliding_window: Optional[int] = None,
|
164
164
|
) -> None:
|
165
165
|
super().__init__()
|
166
166
|
self.hidden_size = hidden_size
|
@@ -183,7 +183,6 @@ class MixtralAttention(nn.Module):
|
|
183
183
|
self.kv_size = self.num_kv_heads * self.head_dim
|
184
184
|
self.scaling = self.head_dim**-0.5
|
185
185
|
self.rope_theta = rope_theta
|
186
|
-
self.sliding_window = sliding_window
|
187
186
|
|
188
187
|
self.qkv_proj = QKVParallelLinear(
|
189
188
|
hidden_size,
|
@@ -246,7 +245,6 @@ class MixtralDecoderLayer(nn.Module):
|
|
246
245
|
num_kv_heads=config.num_key_value_heads,
|
247
246
|
layer_id=layer_id,
|
248
247
|
rope_theta=rope_theta,
|
249
|
-
sliding_window=config.sliding_window,
|
250
248
|
quant_config=quant_config,
|
251
249
|
)
|
252
250
|
self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
|
@@ -336,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
|
|
336
334
|
self.model = MixtralModel(config, quant_config=quant_config)
|
337
335
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
338
336
|
self.logits_processor = LogitsProcessor(config)
|
337
|
+
self.sampler = Sampler()
|
339
338
|
|
340
339
|
@torch.no_grad()
|
341
340
|
def forward(
|
@@ -346,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
|
|
346
345
|
input_embeds: torch.Tensor = None,
|
347
346
|
) -> torch.Tensor:
|
348
347
|
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
|
349
|
-
|
348
|
+
logits_output = self.logits_processor(
|
350
349
|
input_ids, hidden_states, self.lm_head.weight, input_metadata
|
351
350
|
)
|
351
|
+
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
|
352
|
+
return sample_output, logits_output
|
352
353
|
|
353
354
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
354
355
|
stacked_params_mapping = [
|
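A pattern shared by every model file in this release: forward no longer returns logits directly; the logits are passed through an sglang Sampler and forward returns a (sample_output, logits_output) pair. The toy below only models the call shape; ToySampler is a stand-in for sglang.srt.layers.sampler.Sampler, whose actual sampling logic and sampling_info argument are not part of this diff.

# Hedged toy model of the new return convention, not sglang's Sampler.
import torch

class ToySampler:
    def __call__(self, logits, sampling_info=None):
        # Greedy argmax stands in for whatever sampling strategy is configured.
        return torch.argmax(logits, dim=-1)

logits_output = torch.randn(2, 32000)       # pretend logits for 2 sequences
sampler = ToySampler()
sample_output = sampler(logits_output)      # token ids, shape (2,)
result = (sample_output, logits_output)     # what forward() now returns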
sglang/srt/models/qwen.py
CHANGED
@@ -22,8 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -37,8 +35,11 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module):
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        next_tokens = self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        return next_tokens
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
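The qwen.py hunks mostly relocate SiluAndMul and RMSNorm from vllm.model_executor.layers to sglang.srt.layers and wire in the new Sampler. For reference, the gated SiLU activation these models use computes silu(gate) * up over a tensor whose last dimension packs both halves. The sketch below writes out that conventional formula for illustration; it is not copied from sglang.srt.layers.activation.

# Hedged sketch of the standard SiLU-and-multiply activation (not sglang's code).
import torch

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half, apply SiLU to the gate half,
    # and multiply elementwise with the other half.
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]

out = silu_and_mul(torch.randn(4, 2 * 11008))  # e.g. an MLP intermediate of 11008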
sglang/srt/models/qwen2.py
CHANGED
@@ -22,8 +22,6 @@ import torch
 from torch import nn
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -37,8 +35,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 Qwen2Config = None
@@ -275,6 +277,8 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

     @torch.no_grad()
     def forward(
@@ -283,11 +287,17 @@ class Qwen2ForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
-        )
+        if not get_embedding:
+            logits_output = self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, input_metadata
+            )
+            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+            return sample_output, logits_output
+        else:
+            return self.pooler(hidden_states, input_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
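qwen2.py gains an embedding path: with get_embedding=True, forward skips the logits processor and returns self.pooler(hidden_states, input_metadata), where the pooler is constructed as Pooler(pooling_type=PoolingType.LAST, normalize=True). Assuming that configuration means last-token pooling followed by L2 normalization (an assumption drawn from the constructor arguments, not from sglang's Pooler source), a minimal stand-alone sketch of the operation looks like this:

# Hedged sketch of last-token pooling with L2 normalization (assumed behavior).
import torch

def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: (total_tokens, hidden_size) for a packed batch;
    # seq_lens: (batch,) length of each sequence in the packed batch.
    last_indices = torch.cumsum(seq_lens, dim=0) - 1   # index of each sequence's last token
    pooled = hidden_states[last_indices]               # (batch, hidden_size)
    return torch.nn.functional.normalize(pooled, p=2, dim=-1)

hidden = torch.randn(10, 8)                  # 10 packed tokens, hidden size 8
lens = torch.tensor([4, 6])                  # two sequences of length 4 and 6
embeddings = last_token_pool(hidden, lens)   # (2, 8), rows have unit L2 norm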