sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff covers publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registry.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py
CHANGED
@@ -1,141 +1,269 @@
 # Adapted from
-# https://github.com/vllm-project/vllm/blob/
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
-from typing import
+from typing import Iterable, Optional, Tuple
 
 import numpy as np
 import torch
 import torch.nn.functional as F
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import MixtralConfig
+from vllm import _custom_ops as ops
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.
-
-
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import print_warning_once
 
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
 
-class MixtralMLP(nn.Module):
     def __init__(
         self,
         num_experts: int,
+        top_k: int,
         hidden_size: int,
         intermediate_size: int,
-
-
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
-        self.
-        self.
-        self.
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // self.tp_size
+        self.quant_config = quant_config
+
+        # FIXME(pcmoritz): Make this more general to support different
+        # quantization schemes
+        self.use_fp8 = isinstance(quant_config, Fp8Config)
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
 
-
-
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            quant_config=None,
         )
-
-
+
+        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        self.w13_weight = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.hidden_size,
+                dtype=params_dtype,
+            )
         )
-        self.
-
+        self.w2_weight = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                self.hidden_size,
+                self.intermediate_size,
+                dtype=params_dtype,
+            )
         )
 
-
-
+        set_weight_attrs(
+            self.w13_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
 
-
-
-
-
-
-
-
+        # Used for fp8.
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
+            self.w2_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
 
+            # If loading fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            # process_weights_after_loading()
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(
+                    self.w13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.w2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
 
-
-
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for checkpoint that "
+                        "was not serialized fp8."
+                    )
+                self.a13_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                self.a2_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+
+                set_weight_attrs(
+                    self.a13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.a2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+
+    def weight_loader(
         self,
-
-
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        expert_id: int,
     ):
-
-
-
-
-
-
-        if
-
-
-
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                shard, :
+            ]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight
+
+    def process_weights_after_loading(self):
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(
+                self.w13_weight.data, dtype=torch.float8_e4m3fn
             )
-
-
-
-
-        if not self.expert_indicies:
-            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList(
-            [
-                (
-                    MixtralMLP(
-                        self.num_total_experts,
-                        config.hidden_size,
-                        config.intermediate_size,
-                        linear_method=linear_method,
-                    )
-                    if idx in self.expert_indicies
-                    else None
+            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
+            for expert in range(self.num_total_experts):
+                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                    self.w13_weight.data[expert, :, :]
                 )
-
-
-
-
-
-
+                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                    self.w2_weight.data[expert, :, :]
+                )
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If checkpoint is fp8 + static, cleanup act_scales.
+        # Since state_dict has an act_scale per expert but our kernels
+        # are passed one act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None."
+                )
+
+            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer. "
+                )
+
+            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-
-
-
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+            use_fp8=self.use_fp8,
+            w1_scale=self.w13_scale,
+            w2_scale=self.w2_scale,
+            a1_scale=self.a13_scale,
+            a2_scale=self.a2_scale,
        )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
 
-
-
-            expert_layer = self.experts[expert_idx]
-            expert_mask = selected_experts == expert_idx
-            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
-
-            if final_hidden_states is None:
-                final_hidden_states = current_hidden_states
-            else:
-                final_hidden_states.add_(current_hidden_states)
-
-        return tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states.view(num_tokens, hidden_size)
 
 
 class MixtralAttention(nn.Module):
@@ -147,7 +275,7 @@ class MixtralAttention(nn.Module):
         layer_id: int = 0,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
-
+        quant_config: Optional[QuantizationConfig] = None,
         sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -179,13 +307,13 @@ class MixtralAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -221,7 +349,7 @@ class MixtralDecoderLayer(nn.Module):
         self,
         config: MixtralConfig,
         layer_id: int = 0,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -235,9 +363,15 @@ class MixtralDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
-
+            quant_config=quant_config,
+        )
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
         )
-        self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -272,7 +406,7 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -285,7 +419,7 @@ class MixtralModel(nn.Module):
         # config.num_hidden_layers=16
         self.layers = nn.ModuleList(
             [
-                MixtralDecoderLayer(config, i,
+                MixtralDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -316,12 +450,13 @@ class MixtralForCausalLM(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.
-        self.model = MixtralModel(config,
+        self.quant_config = quant_config
+        self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
@@ -337,13 +472,7 @@ class MixtralForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -351,16 +480,47 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
+        expert_params_mapping = (
+            [
+                # These are the weight scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                    f"experts.{expert_id}.{weight_name}.weight_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the weights for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+                    f"experts.{expert_id}.{weight_name}.weight",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the activation scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+                    f"experts.{expert_id}.{weight_name}.act_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+        )
+
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in
-            model_name_or_path,
-            cache_dir,
-            load_format,
-            revision,
-            fall_back_to_pt=False,
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -373,15 +533,30 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-
-
-
-
-
-
-
-
-
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param, loaded_weight, weight_name, expert_id=expert_id
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
 
 
 EntryClass = MixtralForCausalLM
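For readers of the diff above: the new MixtralMoE docstring describes sharding each expert across all tensor-parallel ranks, with each expert's w1 and w3 projections stacked into a single w13 buffer and sliced along the intermediate dimension (this is what weight_loader implements). The following standalone Python sketch illustrates that layout only; the sizes, rank count, and the shard_w13 helper are hypothetical and are not part of sglang or vllm.

# Illustrative sketch (not sglang/vllm code): how per-expert w1/w3 weights
# map onto the stacked, tensor-parallel-sharded w13 buffer.
import torch

num_experts = 4          # hypothetical number of experts
hidden_size = 16         # hypothetical hidden size
intermediate_size = 8    # hypothetical full (unsharded) intermediate size
tp_size = 2              # hypothetical number of tensor-parallel ranks
shard_size = intermediate_size // tp_size

def shard_w13(w1: torch.Tensor, w3: torch.Tensor, tp_rank: int) -> torch.Tensor:
    # Each rank keeps rows [tp_rank * shard_size, (tp_rank + 1) * shard_size)
    # of both w1 and w3, with the w1 shard stacked on top of the w3 shard,
    # matching the slicing in weight_loader above.
    rows = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    return torch.cat([w1[rows, :], w3[rows, :]], dim=0)

# Build the stacked buffer one rank would hold across all experts.
tp_rank = 0
w13_weight = torch.empty(num_experts, 2 * shard_size, hidden_size)
for expert_id in range(num_experts):
    w1 = torch.randn(intermediate_size, hidden_size)  # stand-in checkpoint weights
    w3 = torch.randn(intermediate_size, hidden_size)
    w13_weight[expert_id] = shard_w13(w1, w3, tp_rank)

print(w13_weight.shape)  # torch.Size([4, 8, 16]) -> [experts, 2 * shard, hidden]

In the module itself, the analogous w2 shard keeps a slice of columns rather than rows, and forward sums the per-rank partial outputs with tensor_model_parallel_all_reduce.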