sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +3 -1
- sglang/api.py +3 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +8 -1
- sglang/lang/interpreter.py +114 -67
- sglang/lang/ir.py +17 -2
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +1 -1
- sglang/srt/hf_transformers_utils.py +75 -1
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +15 -11
- sglang/srt/managers/router/infer_batch.py +103 -59
- sglang/srt/managers/router/manager.py +1 -1
- sglang/srt/managers/router/model_rpc.py +175 -122
- sglang/srt/managers/router/model_runner.py +91 -104
- sglang/srt/managers/router/radix_cache.py +7 -1
- sglang/srt/managers/router/scheduler.py +6 -6
- sglang/srt/managers/tokenizer_manager.py +152 -89
- sglang/srt/model_config.py +4 -5
- sglang/srt/models/commandr.py +10 -13
- sglang/srt/models/dbrx.py +9 -15
- sglang/srt/models/gemma.py +8 -15
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +19 -15
- sglang/srt/models/llava.py +84 -20
- sglang/srt/models/llavavid.py +11 -20
- sglang/srt/models/mixtral.py +248 -118
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +9 -13
- sglang/srt/models/qwen2.py +11 -13
- sglang/srt/models/stablelm.py +9 -15
- sglang/srt/models/yivl.py +17 -22
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +77 -42
- sglang/srt/server_args.py +51 -6
- sglang/srt/utils.py +124 -66
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +22 -4
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -417
- sglang-0.1.16.dist-info/RECORD +0 -72
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py (new file)
@@ -0,0 +1,671 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
"""Inference-only Grok1 model."""
from typing import Iterable, Optional, Tuple, List

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from transformers import PretrainedConfig

from vllm import _custom_ops as ops
from vllm.config import CacheConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.fused_moe import fused_moe
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata


use_fused = True


class Grok1MLP(nn.Module):
    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        self.w1 = ReplicatedLinear(
            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
        )
        self.w2 = ReplicatedLinear(
            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
        )
        self.w3 = ReplicatedLinear(
            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
        )

        self.act_fn = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        w1_out, _ = self.w1(hidden_states)
        w1_out = self.act_fn(w1_out)
        w3_out, _ = self.w3(hidden_states)
        current_hidden_states = w1_out * w3_out
        current_hidden_states, _ = self.w2(current_hidden_states)
        return current_hidden_states


class Grok1MoEUnfused(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}."
            )
        # Split experts equally between ranks
        self.expert_indicies = np.array_split(
            range(self.num_total_experts), self.tp_size
        )[self.rank].tolist()
        if not self.expert_indicies:
            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")

        self.experts = nn.ModuleList(
            [
                (
                    Grok1MLP(
                        self.num_total_experts,
                        config.hidden_size,
                        config.intermediate_size,
                        quant_config=quant_config,
                    )
                    if idx in self.expert_indicies
                    else None
                )
                for idx in range(self.num_total_experts)
            ]
        )
        self.gate = ReplicatedLinear(
            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        router_logits, _ = self.gate(hidden_states)
        router_logits = 30 * F.tanh(router_logits / 30)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        routing_weights = routing_weights.to(hidden_states.dtype)
        hidden_dim = hidden_states.shape[1]

        final_hidden_states = torch.zeros(
            (hidden_states.shape[0], hidden_dim),
            dtype=hidden_states.dtype, device=hidden_states.device
        )
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_total_experts).permute(2, 1, 0)

        for expert_idx in self.expert_indicies:
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            # in torch it is faster to index using lists than torch tensors
            top_x_list = top_x.tolist()
            idx_list = idx.tolist()

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states)

        return tensor_model_parallel_all_reduce(final_hidden_states)


class Grok1MoE(nn.Module):
    """A tensor-parallel MoE implementation for Grok1 that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size
        self.quant_config = quant_config

        # FIXME(pcmoritz): Make this more general to support different
        # quantization schemes
        self.use_fp8 = isinstance(quant_config, Fp8Config)

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     quant_config=None)

        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        self.w13_weight = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        dtype=params_dtype))
        self.w2_weight = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        dtype=params_dtype))

        set_weight_attrs(self.w13_weight, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2_weight, {
            "weight_loader": self.weight_loader,
        })

        # Used for fp8.
        self.w13_scale = None
        self.w2_scale = None
        self.a13_scale = None
        self.a2_scale = None

        if self.use_fp8:
            # WEIGHT_SCALE (for fp8)
            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                     dtype=torch.float32),
                                          requires_grad=False)
            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                    dtype=torch.float32),
                                         requires_grad=False)

            # If loading fp8 checkpoint, pass the weight loaders.
            # If loading an fp16 checkpoint, do not (we will quantize in
            # process_weights_after_loading()
            if quant_config.is_checkpoint_fp8_serialized:
                set_weight_attrs(self.w13_scale, {
                    "weight_loader": self.weight_loader,
                })
                set_weight_attrs(self.w2_scale, {
                    "weight_loader": self.weight_loader,
                })

            # ACT_SCALE (for fp8)
            if quant_config.activation_scheme == "static":
                if not quant_config.is_checkpoint_fp8_serialized:
                    raise ValueError(
                        "Found static activation scheme for checkpoint that "
                        "was not serialized fp8.")
                self.a13_scale = nn.Parameter(torch.zeros(
                    self.num_total_experts, dtype=torch.float32),
                                              requires_grad=False)
                self.a2_scale = nn.Parameter(torch.zeros(
                    self.num_total_experts, dtype=torch.float32),
                                             requires_grad=False)

                set_weight_attrs(self.a13_scale, {
                    "weight_loader": self.weight_loader,
                })
                set_weight_attrs(self.a2_scale, {
                    "weight_loader": self.weight_loader,
                })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int, pre_sharded: bool):
        param_data = param.data
        shard_size = self.intermediate_size
        if pre_sharded:
            # The weight is already sharded. Read the full shard
            shard = slice(None)
        else:
            tp_rank = get_tensor_model_parallel_rank()
            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        if "act_scale" in weight_name or "weight_scale" in weight_name:
            param_data[expert_id] = loaded_weight

    def process_weights_after_loading(self):
        # Fp8 is the only case where we need to process after loading.
        if not self.use_fp8:
            return

        # If checkpoint is fp16, quantize here.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            w13_weight = torch.empty_like(self.w13_weight.data,
                                          dtype=torch.float8_e4m3fn)
            w2_weight = torch.empty_like(self.w2_weight.data,
                                         dtype=torch.float8_e4m3fn)
            for expert in range(self.num_total_experts):
                w13_weight[expert, :, :], self.w13_scale[
                    expert] = ops.scaled_fp8_quant(
                        self.w13_weight.data[expert, :, :])
                w2_weight[expert, :, :], self.w2_scale[
                    expert] = ops.scaled_fp8_quant(
                        self.w2_weight.data[expert, :, :])
            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)

        # If checkpoint is fp8 + static, cleanup act_scales.
        # Since state_dict has an act_scale per expert but our kernels
        # are passed one act_scale shared across all experts.
        elif self.quant_config.activation_scheme == "static":
            if self.a13_scale is None or self.a2_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")

            if (not all_close_1d(self.a13_scale)
                    or not all_close_1d(self.a2_scale)):
                print_warning_once(
                    "Found act_scales that are not equal for fp8 MoE layer. "
                    "Using the maximum across experts for each layer. ")

            self.a13_scale = nn.Parameter(self.a13_scale.max(),
                                          requires_grad=False)
            self.a2_scale = nn.Parameter(self.a2_scale.max(),
                                         requires_grad=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.w13_weight,
                                        self.w2_weight,
                                        router_logits,
                                        self.top_k,
                                        renormalize=False,
                                        inplace=True,
                                        use_fp8=self.use_fp8,
                                        w1_scale=self.w13_scale,
                                        w2_scale=self.w2_scale,
                                        a1_scale=self.a13_scale,
                                        a2_scale=self.a2_scale)

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)


class Grok1Attention(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        logit_cap: float = 30,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = 128
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            logit_cap=logit_cap,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class Grok1DecoderLayer(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = Grok1Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=rope_theta,
            quant_config=quant_config,
        )
        if use_fused:
            self.block_sparse_moe = Grok1MoE(
                num_experts=config.num_local_experts,
                top_k=config.num_experts_per_tok,
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                quant_config=quant_config)
        else:
            self.block_sparse_moe = Grok1MoEUnfused(
                config=config, quant_config=quant_config)
        self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:

        hidden_states = self.post_attn_norm(self.self_attn(
            positions=positions, hidden_states=self.pre_attn_norm(hidden_states),
            input_metadata=input_metadata,
        )) + hidden_states

        hidden_states = self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states

        return hidden_states


class Grok1Model(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                Grok1DecoderLayer(config, i, quant_config=quant_config)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds
        hidden_states.mul_(self.config.embedding_multiplier_scale)

        for i in range(len(self.layers)):
            hidden_states = self.layers[i](
                positions, hidden_states, input_metadata
            )

        hidden_states = self.norm(hidden_states)
        hidden_states.mul_(self.config.output_multiplier_scale)
        return hidden_states


class Grok1ModelForCausalLM(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = Grok1Model(config, quant_config=quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config)

        # Monkey patch _prepare_weights to load pre-sharded weights
        setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        if use_fused:
            expert_params_mapping = [
                # These are the weight scales for the experts
                # (param_name, weight_name, expert_id)
                ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
                 f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
                for expert_id in range(self.config.num_local_experts)
                for weight_name in ["w1", "w2", "w3"]
            ] + [
                # These are the weights for the experts
                # (param_name, weight_name, expert_id)
                ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
                 f"experts.{expert_id}.{weight_name}.weight", expert_id)
                for expert_id in range(self.config.num_local_experts)
                for weight_name in ["w1", "w2", "w3"]
            ] + [
                # These are the activation scales for the experts
                # (param_name, weight_name, expert_id)
                ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
                 f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
                for expert_id in range(self.config.num_local_experts)
                for weight_name in ["w1", "w2", "w3"]
            ]
        else:
            expert_params_mapping = []

        params_dict = dict(self.named_parameters())
        if get_tensor_model_parallel_rank() == 0:
            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
        for name, loaded_weight in weights:
            #print(get_tensor_model_parallel_rank(), name)
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id,
                                  pre_sharded=get_tensor_model_parallel_world_size() > 1)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)


def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))


old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
def _prepare_presharded_weights(self,
                                model_name_or_path: str,
                                revision: Optional[str],
                                fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
    import glob
    import os

    if get_tensor_model_parallel_world_size() == 1:
        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)

    tp_rank = get_tensor_model_parallel_rank()
    allow_patterns = [f"*-{tp_rank:03d}.bin"]

    hf_folder = model_name_or_path

    hf_weights_files: List[str] = []
    for pattern in allow_patterns:
        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
    use_safetensors = False

    return hf_folder, hf_weights_files, use_safetensors


EntryClass = Grok1ModelForCausalLM