sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama_embedding.py
ADDED
@@ -0,0 +1,88 @@
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+
+
+class LlamaEmbeddingModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config=None,
+        cache_config=None,
+        efficient_weight_load=False,
+    ) -> None:
+        super().__init__()
+        self.model = LlamaModel(config, quant_config=quant_config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> EmbeddingPoolerOutput:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.pooler(hidden_states, input_metadata)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.model.named_parameters())
+
+        def load_weights_per_param(name, loaded_weight):
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                return
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                return
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    return
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    return
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+        if name is None or loaded_weight is None:
+            for name, loaded_weight in weights:
+                load_weights_per_param(name, loaded_weight)
+        else:
+            load_weights_per_param(name, loaded_weight)
+
+
+EntryClass = LlamaEmbeddingModel
+# compat: e5-mistral model.config class == MistralModel
+EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
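The new `LlamaEmbeddingModel` above routes the `LlamaModel` hidden states through `Pooler(pooling_type=PoolingType.LAST, normalize=True)` from the newly added `sglang/srt/layers/pooler.py`. As a rough illustration of what last-token pooling with normalization computes, here is a minimal sketch; it is not the sglang pooler itself, and the packed-batch layout and `seq_lens` argument are assumptions for illustration:

```python
# Minimal sketch of last-token pooling with L2 normalization, assuming a packed
# batch layout where `seq_lens` gives each sequence's length. Not the sglang Pooler.
import torch


def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: (total_tokens, hidden_size); seq_lens: (batch,)
    last_idx = torch.cumsum(seq_lens, dim=0) - 1   # index of each sequence's final token
    embeddings = hidden_states[last_idx]           # (batch, hidden_size)
    return torch.nn.functional.normalize(embeddings, p=2, dim=-1)


# Example: two sequences of lengths 3 and 5 packed into 8 token states.
states = torch.randn(8, 4096)
print(last_token_pool(states, torch.tensor([3, 5])).shape)  # torch.Size([2, 4096])
```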
sglang/srt/models/minicpm.py
CHANGED
@@ -22,8 +22,6 @@ import torch
 from torch import nn
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -37,6 +35,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
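minicpm.py above, as well as qwen.py and qwen2.py further below, switch their `SiluAndMul` and `RMSNorm` imports from vLLM to the new in-tree modules `sglang/srt/layers/activation.py` and `sglang/srt/layers/layernorm.py` listed at the top of this diff. For reference, a sketch of what these two layers conventionally compute in LLaMA-style models follows; it is a plain PyTorch approximation, not the package's fused kernels:

```python
# Reference semantics only: plain PyTorch equivalents of the gated-SiLU activation
# and RMSNorm used by LLaMA-style MLPs, not sglang's fused implementations.
import torch


def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # x: (..., 2 * d) -> (..., d); SiLU-gate the first half, scale by the second half.
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then apply a learned scale.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight
```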
sglang/srt/models/mixtral.py
CHANGED
@@ -18,36 +18,27 @@ limitations under the License.
 """Inference-only Mixtral model."""
 from typing import Iterable, Optional, Tuple

-import numpy as np
 import torch
-import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
-from vllm import _custom_ops as ops
 from vllm.config import CacheConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
-)
-from vllm.model_executor.layers.fused_moe import fused_moe
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import print_warning_once

+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -69,216 +60,44 @@ class MixtralMoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
     ):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype

         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(
-            self.hidden_size,
-            self.num_total_experts,
+            hidden_size,
+            num_experts,
             bias=False,
-            params_dtype=self.params_dtype,
+            params_dtype=params_dtype,
             quant_config=None,
+            prefix=f"{prefix}.gate",
         )

-        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        self.w13_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.hidden_size,
-                dtype=params_dtype,
-            )
-        )
-        self.w2_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                dtype=params_dtype,
-            )
-        )
-
-        set_weight_attrs(
-            self.w13_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            tp_size=tp_size,
+            prefix=f"{prefix}.experts",
         )

-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-            self.w2_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            # process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(
-                    self.w13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.w2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8."
-                    )
-                self.a13_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-                self.a2_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-
-                set_weight_attrs(
-                    self.a13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.a2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-    ):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                shard, :
-            ]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
-            param_data[expert_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                self.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
-                    self.w13_weight.data[expert, :, :]
-                )
-                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
-                    self.w2_weight.data[expert, :, :]
-                )
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        # If checkpoint is fp8 + static, cleanup act_scales.
-        # Since state_dict has an act_scale per expert but our kernels
-        # are passed one act_scale shared across all experts.
-        elif self.quant_config.activation_scheme == "static":
-            if self.a13_scale is None or self.a2_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None."
-                )
-
-            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
-                print_warning_once(
-                    "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. "
-                )
-
-            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.w13_weight,
-            self.w2_weight,
-            router_logits,
-            self.top_k,
-            renormalize=True,
-            inplace=True,
-            use_fp8=self.use_fp8,
-            w1_scale=self.w13_scale,
-            w2_scale=self.w2_scale,
-            a1_scale=self.a13_scale,
-            a2_scale=self.a2_scale,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)


 class MixtralAttention(nn.Module):
@@ -291,7 +110,7 @@ class MixtralAttention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
-        sliding_window: Optional[int] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -314,7 +133,6 @@ class MixtralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -323,12 +141,14 @@ class MixtralAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -365,6 +185,7 @@ class MixtralDecoderLayer(nn.Module):
         config: MixtralConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -377,8 +198,8 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
         )
         self.block_sparse_moe = MixtralMoE(
             num_experts=config.num_local_experts,
@@ -386,6 +207,7 @@ class MixtralDecoderLayer(nn.Module):
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.block_sparse_moe",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -422,6 +244,7 @@ class MixtralModel(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -431,10 +254,11 @@ class MixtralModel(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        # config.num_hidden_layers=16
         self.layers = nn.ModuleList(
             [
-                MixtralDecoderLayer(config, i, quant_config=quant_config)
+                MixtralDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"{prefix}.layers"
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -462,6 +286,7 @@ class MixtralModel(nn.Module):


 class MixtralForCausalLM(nn.Module):
+
     def __init__(
         self,
         config: MixtralConfig,
@@ -471,11 +296,10 @@ class MixtralForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = MixtralModel(config, quant_config=quant_config)
+        self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

-    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -496,40 +320,13 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]

-        expert_params_mapping = (
-            [
-                # These are the weight scales for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-                    f"experts.{expert_id}.{weight_name}.weight_scale",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-            + [
-                # These are the weights for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-                    f"experts.{expert_id}.{weight_name}.weight",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-            + [
-                # These are the activation scales for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-                    f"experts.{expert_id}.{weight_name}.act_scale",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
         )

         params_dict = dict(self.named_parameters())
@@ -544,25 +341,35 @@ class MixtralForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
-                        param, loaded_weight, weight_name, expert_id=expert_id
+                        param,
+                        loaded_weight,
+                        weight_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
                     )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if name is None:
+                        continue
+
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
@@ -570,9 +377,4 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight)


-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 EntryClass = MixtralForCausalLM
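The rewritten Mixtral loader above no longer enumerates per-expert checkpoint names by hand; it asks vLLM's `FusedMoE.make_expert_params_mapping(ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", num_experts=...)` for `(param_name, weight_name, expert_id, shard_id)` tuples and hands each matching checkpoint tensor to the fused parameter's `weight_loader`. The sketch below only illustrates the shape of such a mapping; the fused parameter names and shard ids are assumptions for illustration, not the exact strings vLLM emits:

```python
# Hypothetical illustration of an expert-parameter mapping in the
# (param_name, weight_name, expert_id, shard_id) form that the loader unpacks.
# The "experts.w13_weight" / "experts.w2_weight" names are assumed, not vLLM's exact output.
def sketch_expert_params_mapping(num_experts: int):
    mapping = []
    for expert_id in range(num_experts):
        for ckpt_name in ("w1", "w3", "w2"):
            fused_param = (
                "experts.w13_weight" if ckpt_name in ("w1", "w3") else "experts.w2_weight"
            )
            ckpt_weight = f"experts.{expert_id}.{ckpt_name}.weight"
            mapping.append((fused_param, ckpt_weight, expert_id, ckpt_name))
    return mapping


for row in sketch_expert_params_mapping(2)[:3]:
    print(row)  # e.g. ('experts.w13_weight', 'experts.0.w1.weight', 0, 'w1')
```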
sglang/srt/models/mixtral_quant.py
CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
@@ -43,6 +42,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -160,7 +160,6 @@ class MixtralAttention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
-        sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -183,7 +182,6 @@ class MixtralAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
-        self.sliding_window = sliding_window

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -246,7 +244,6 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
-            sliding_window=config.sliding_window,
             quant_config=quant_config,
         )
         self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
sglang/srt/models/qwen.py
CHANGED
@@ -22,8 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -37,6 +35,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
sglang/srt/models/qwen2.py
CHANGED
@@ -22,8 +22,6 @@ import torch
 from torch import nn
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -37,6 +35,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata