sglang 0.4.4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/managers/cache_controller.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +52 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +9 -1
- sglang/srt/mem_cache/memory_pool.py +4 -1
- sglang/srt/model_executor/cuda_graph_runner.py +59 -16
- sglang/srt/model_executor/forward_batch_info.py +13 -4
- sglang/srt/models/deepseek_v2.py +180 -177
- sglang/srt/models/grok.py +374 -119
- sglang/srt/openai_api/adapter.py +22 -20
- sglang/srt/server_args.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +24 -22
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py
CHANGED
@@ -15,28 +15,36 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
-
-
-
+import functools
+import json
+import logging
+import math
+import os
+import warnings
+from typing import Iterable, Optional, Tuple
+
+import numpy as np
 import torch
-import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.
+from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
-    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.router import fused_moe_router_shim
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -44,47 +52,17 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import
+from sglang.srt.utils import dump_to_file

+logger = logging.getLogger(__name__)

-class Grok1MLP(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        reduce_results=True,
-        use_presharded_weights: bool = False,
-    ) -> None:
-        super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size,
-            [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("gate_up_proj", prefix),
-            use_presharded_weights=use_presharded_weights,
-        )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("down_proj", prefix),
-            reduce_results=reduce_results,
-            use_presharded_weights=use_presharded_weights,
-        )
-        self.act_fn = GeluAndMul(approximate="tanh")

-
-
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
+debug_tensor_dump_output_folder = None
+debug_tensor_dump_inject = False


 class Grok1MoE(nn.Module):
@@ -108,51 +86,55 @@ class Grok1MoE(nn.Module):
         tp_size: Optional[int] = None,
         reduce_results=True,
         use_presharded_weights: bool = False,
-
+        inplace: bool = True,
+        no_combine: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size

-        # Gate always runs at
+        # Gate always runs at full precision for stability (see https://arxiv.org/pdf/2101.03961)
         self.gate = ReplicatedLinear(
             hidden_size,
             num_experts,
             bias=False,
-            params_dtype=
+            params_dtype=torch.float32,
             quant_config=None,
-            prefix=add_prefix("gate", prefix),
         )

         self.router_logit_softcapping = getattr(
             config, "router_logit_softcapping", 30.0
         )
-
+        custom_routing_function = functools.partial(
+            fused_moe_router_shim, self.router_logit_softcapping
+        )
+
+        kwargs = {}
+        if global_server_args_dict["enable_ep_moe"]:
+            MoEImpl = EPMoE
+        else:
+            MoEImpl = FusedMoE
+            kwargs["reduce_results"] = reduce_results
+            kwargs["use_presharded_weights"] = use_presharded_weights
+            kwargs["inplace"] = inplace
+            kwargs["no_combine"] = no_combine
+
+        self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=reduce_results,
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            custom_routing_function=custom_routing_function,
             activation="gelu",
-
-            prefix=add_prefix("experts", prefix),
+            **kwargs,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # NOTE: hidden_states can have either 1D or 2D shape.
-        orig_shape = hidden_states.shape
-        hidden_states = hidden_states.view(-1, self.hidden_size)
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        router_logits = 30.0 * F.tanh(router_logits / 30.0)
-
         # need to assert self.gate.quant_method is unquantized
-
-        return final_hidden_states.view(orig_shape)
+        return self.experts(hidden_states, self.gate.weight)


 class Grok1Attention(nn.Module):
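
The hunk above removes the eager routing path, where the gate ran as a `ReplicatedLinear` followed by `30.0 * tanh(logits / 30.0)` softcapping, and instead hands a `custom_routing_function` (built with `functools.partial` around the new `fused_moe_router_shim`) to `FusedMoE`/`EPMoE`, passing `self.gate.weight` straight into the expert kernel. The sketch below is a rough, un-fused PyTorch reference of the routing math implied by the removed lines; the real Triton router lives in the new `sglang/srt/layers/moe/router.py`, and its exact signature, softmax placement, and renormalization behavior (note `renormalize=False`) are assumptions here.

```python
import torch


def softcapped_router_reference(
    hidden_states: torch.Tensor,  # (num_tokens, hidden_size)
    gate_weight: torch.Tensor,    # (num_experts, hidden_size), kept in fp32
    top_k: int,
    softcap: float = 30.0,        # config.router_logit_softcapping defaults to 30.0
):
    # fp32 gate matmul, as in the removed ReplicatedLinear gate
    router_logits = hidden_states.float() @ gate_weight.float().t()
    # tanh softcapping, matching the removed `30.0 * F.tanh(router_logits / 30.0)`
    router_logits = softcap * torch.tanh(router_logits / softcap)
    # top-k expert selection; softmax-before-top-k without renormalization is an
    # assumption, not something this diff shows
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    return topk_weights, topk_ids
```
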
@@ -167,31 +149,33 @@ class Grok1Attention(nn.Module):
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
-
+        load_presharded_attn: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
         self.layer_id = layer_id
         self.hidden_size = hidden_size
-
+        attn_tp_rank = get_tensor_model_parallel_rank()
+        attn_tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
-        assert self.total_num_heads %
-        self.num_heads = self.total_num_heads //
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >=
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads %
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert
-        self.num_kv_heads = max(1, self.total_num_kv_heads //
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = getattr(config, "head_dim", 128)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        self.load_presharded_attn = load_presharded_attn

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
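
The rewritten constructor computes `attn_tp_rank` / `attn_tp_size` once and threads them (plus the presharded flag) into `QKVParallelLinear` and `RowParallelLinear`, while keeping the existing partition-or-replicate rule for KV heads. A standalone sketch of that head arithmetic (helper name and example numbers are illustrative, not sglang APIs):

```python
def per_rank_head_counts(total_num_heads: int, total_num_kv_heads: int, attn_tp_size: int):
    """Head counts each tensor-parallel rank ends up with, per the asserts above."""
    # Query heads are always split evenly across ranks.
    assert total_num_heads % attn_tp_size == 0
    num_heads = total_num_heads // attn_tp_size

    if total_num_kv_heads >= attn_tp_size:
        # Enough KV heads: partition them across ranks.
        assert total_num_kv_heads % attn_tp_size == 0
    else:
        # Fewer KV heads than ranks: replicate each KV head on a group of ranks.
        assert attn_tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // attn_tp_size)
    return num_heads, num_kv_heads


# e.g. 48 query heads, 8 KV heads, 8 ranks -> (6, 1) per rank (illustrative numbers)
print(per_rank_head_counts(48, 8, 8))
```
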
@@ -200,7 +184,9 @@ class Grok1Attention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            load_presharded_attn=self.load_presharded_attn,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -208,7 +194,9 @@ class Grok1Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
-
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            use_presharded_weights=self.load_presharded_attn,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -227,7 +215,6 @@ class Grok1Attention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=logit_cap,
-            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -236,10 +223,73 @@ class Grok1Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
+        if debug_tensor_dump_output_folder:
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                f"attn_input_{self.layer_id}",
+                hidden_states,
+            )
+
+        if debug_tensor_dump_inject:
+            name = os.path.join(
+                debug_tensor_dump_output_folder,
+                f"jax_dump_attn_input_{self.layer_id}.npy",
+            )
+            logger.info(f"Load {name} from jax.")
+            x = np.load(name)
+            hidden_states = torch.tensor(x[0, : hidden_states.shape[0]]).to(
+                hidden_states
+            )
+
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
+
+        if debug_tensor_dump_output_folder:
+            num_tokens = q.shape[0]
+            num_heads_q = self.num_heads
+            head_dim = self.head_dim
+            num_heads_kv = k.numel() // (num_tokens * head_dim)
+
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                f"q_{self.layer_id}",
+                tensor_model_parallel_all_gather(
+                    q.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1
+                ).contiguous(),
+            )
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                f"k_{self.layer_id}",
+                tensor_model_parallel_all_gather(
+                    k.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
+                ).contiguous(),
+            )
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                f"v_{self.layer_id}",
+                tensor_model_parallel_all_gather(
+                    v.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
+                ).contiguous(),
+            )
+
         attn_output = self.attn(q, k, v, forward_batch)
+
+        if debug_tensor_dump_output_folder:
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                f"attn_output_{self.layer_id}",
+                tensor_model_parallel_all_gather(
+                    attn_output.reshape(num_tokens, num_heads_q, head_dim).contiguous(),
+                    dim=1,
+                ).contiguous(),
+            )
+
         output, _ = self.o_proj(attn_output)
         return output

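
The new forward-path instrumentation is gated on the module-level `debug_tensor_dump_output_folder` / `debug_tensor_dump_inject` globals (populated from `global_server_args_dict` in `Grok1ForCausalLM.__init__` further down) and all-gathers the per-rank head shards before writing, so each dump holds the full tensor rather than one TP shard. Roughly, a `dump_to_file`-style helper could look like the sketch below; the real function lives in `sglang/srt/utils.py`, and this file-naming scheme is only an assumption:

```python
import os

import numpy as np
import torch


def dump_to_file_sketch(folder: str, name: str, tensor: torch.Tensor) -> None:
    """Hypothetical stand-in for sglang.srt.utils.dump_to_file.

    Writes one .npy per named tensor so a reference run (e.g. the JAX dump loaded by
    the debug_tensor_dump_inject branch) can be compared offline.
    """
    os.makedirs(folder, exist_ok=True)
    array = tensor.detach().to(torch.float32).cpu().numpy()
    np.save(os.path.join(folder, f"{name}.npy"), array)
```
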
@@ -250,8 +300,9 @@ class Grok1DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
-
-
+        load_presharded_moe: bool = False,
+        load_presharded_attn: bool = False,
+        load_presharded_mlp: bool = False,
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
@@ -268,7 +319,8 @@ class Grok1DecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             quant_config=quant_config,
-
+            reduce_results=False,
+            load_presharded_attn=load_presharded_attn,
         )
         self.block_sparse_moe = Grok1MoE(
             config=config,
@@ -282,38 +334,68 @@ class Grok1DecoderLayer(nn.Module):
             ),
             quant_config=quant_config,
             reduce_results=True,
-            use_presharded_weights=
-
+            use_presharded_weights=load_presharded_moe,
+            inplace=True,
+            no_combine=False,  # just a suggestion to not combine topk
         )
+
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+        self.ffn = self.block_sparse_moe
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-
+        residual: Optional[torch.Tensor] = None,
+        deferred_norm: Optional[RMSNorm] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
         # Self Attention
-
-
-
-
-
-
-
+        if deferred_norm is not None:
+            assert residual is not None
+            # here hidden_states is output of ffn, residual is residual from after previous attn layer
+            hidden_states, residual = fused_dual_residual_rmsnorm(
+                hidden_states,
+                residual,
+                deferred_norm.weight,
+                self.pre_attn_norm.weight,
+                deferred_norm.variance_epsilon,
             )
-
+        else:
+            # here hidden_states is the residual
+            hidden_states, residual = (
+                fused_rmsnorm(
+                    hidden_states,
+                    self.pre_attn_norm.weight,
+                    self.pre_attn_norm.variance_epsilon,
+                ),
+                hidden_states,
+            )
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
         )

-
-
-
-
+        if get_tensor_model_parallel_world_size() > 1:
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+
+        hidden_states, residual = fused_dual_residual_rmsnorm(
+            hidden_states,
+            residual,
+            self.post_attn_norm.weight,
+            self.pre_moe_norm.weight,
+            self.post_attn_norm.variance_epsilon,
         )
-
+
+        # Fully Connected
+        hidden_states = self.ffn(hidden_states)
+        return hidden_states, residual, self.post_moe_norm  # defer layernorm


 class Grok1Model(nn.Module):
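
With this change each decoder layer returns `(hidden_states, residual, post_moe_norm)` and defers the final RMSNorm to whoever consumes the output next, so a deferred norm, the residual add, and the next sublayer's input norm can be fused into a single `fused_dual_residual_rmsnorm` kernel from the new `sglang/srt/layers/elementwise.py`. The un-fused reference below captures the semantics implied by the call sites in this file; the actual Triton kernels may treat dtypes and epsilons differently:

```python
import torch


def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def fused_rmsnorm_ref(x, weight, eps):
    # Plain RMSNorm, used for the first layer's input and in the debug branch.
    return rmsnorm_ref(x, weight, eps)


def fused_dual_residual_rmsnorm_ref(x, residual, weight1, weight2, eps):
    """Assumed semantics of the fused op, inferred from how it is called here.

    1. Apply the deferred norm (weight1) to x and add the running residual;
       the sum becomes the new residual stream.
    2. Apply the next sublayer's input norm (weight2) to that sum.
    Returns (input_for_next_sublayer, new_residual). A single eps is passed at the
    call sites, so it is reused for both norms here.
    """
    new_residual = rmsnorm_ref(x, weight1, eps) + residual
    return rmsnorm_ref(new_residual, weight2, eps), new_residual
```

Under this reading, a layer performs pre_attn_norm, attention, post_attn_norm plus pre_moe_norm (fused), MoE, and then hands post_moe_norm back as `deferred_norm`; that matches the debug branch in `Grok1Model.forward` below, which applies the deferred norm and adds `residual` explicitly.
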
@@ -321,8 +403,10 @@ class Grok1Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-
-
+        load_presharded_moe: bool = False,
+        load_presharded_embedding: bool = False,
+        load_presharded_attn: bool = False,
+        load_presharded_mlp: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -332,7 +416,7 @@ class Grok1Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-
+            use_presharded_weights=load_presharded_embedding,
         )
         self.layers = nn.ModuleList(
             [
@@ -340,8 +424,9 @@ class Grok1Model(nn.Module):
                 config,
                 i,
                 quant_config=quant_config,
-
-
+                load_presharded_moe=load_presharded_moe,
+                load_presharded_attn=load_presharded_attn,
+                load_presharded_mlp=load_presharded_mlp,
             )
             for i in range(config.num_hidden_layers)
         ]
@@ -361,10 +446,48 @@ class Grok1Model(nn.Module):
         else:
             hidden_states = input_embeds

+        residual, deferred_norm = None, None
         for i in range(len(self.layers)):
-            hidden_states = self.layers[i](
-
-
+            hidden_states, residual, deferred_norm = self.layers[i](
+                positions, hidden_states, forward_batch, residual, deferred_norm
+            )
+
+        if debug_tensor_dump_output_folder:
+            hidden_states = (
+                fused_rmsnorm(
+                    hidden_states,
+                    deferred_norm.weight,
+                    deferred_norm.variance_epsilon,
+                )
+                + residual
+            )
+
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                "last_hidden_before_norm",
+                hidden_states,
+            )
+
+            hidden_states = fused_rmsnorm(
+                hidden_states,
+                self.norm.weight,
+                self.norm.variance_epsilon,
+            )
+
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                "last_hidden_after_norm",
+                hidden_states,
+            )
+        else:
+            hidden_states, _ = fused_dual_residual_rmsnorm(
+                hidden_states,
+                residual,
+                deferred_norm.weight,
+                self.norm.weight,
+                deferred_norm.variance_epsilon,
+            )
+
         return hidden_states

@@ -373,31 +496,77 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config

-
+        # Get presharded weights.
+        self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
+        self.load_presharded_moe = (
             self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
-        )
-
+        )
+        self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
+        self.load_presharded_embedding = getattr(
+            config, "load_presharded_embedding", False
+        )
+
+        self.is_weights_presharded = (
+            self.load_presharded_mlp
+            or self.load_presharded_moe
+            or self.load_presharded_attn
+            or self.load_presharded_embedding
+        )
+
+        if self.is_weights_presharded:
             setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
-
-
+
+        default_replicate_lm_head = False
+        self.replicate_lm_head = getattr(
+            config, "replicate_lm_head", default_replicate_lm_head
+        )

         self.model = Grok1Model(
             config,
             quant_config=quant_config,
-
-
+            load_presharded_moe=self.load_presharded_moe,
+            load_presharded_embedding=self.load_presharded_embedding,
+            load_presharded_attn=self.load_presharded_attn,
+            load_presharded_mlp=self.load_presharded_mlp,
         )
-
-
-
-
+
+        lm_head_params_dtype = None
+        if self.replicate_lm_head:
+            self.lm_head = ReplicatedLinear(
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+                params_dtype=lm_head_params_dtype,
+            )
+            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                use_presharded_weights=self.load_presharded_embedding,
+                params_dtype=lm_head_params_dtype,
+            )
+            self.logits_processor = LogitsProcessor(config)
+
+        # Dump tensors for debugging
+        global debug_tensor_dump_output_folder, debug_tensor_dump_inject
+        debug_tensor_dump_output_folder = global_server_args_dict[
+            "debug_tensor_dump_output_folder"
+        ]
+        debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
+        warnings.filterwarnings("ignore", category=FutureWarning)
+
+        if get_tensor_model_parallel_rank() == 0:
+            logger.info(
+                f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
+                f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
+            )

     def forward(
         self,
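
`Grok1ForCausalLM.__init__` now pulls several optional flags off the HF config via `getattr` (`load_presharded_mlp`, `load_presharded_attn`, `load_presharded_embedding`, `replicate_lm_head`), derives `load_presharded_moe` from the expert count and TP size, and monkey-patches `DefaultModelLoader._prepare_weights` whenever any presharded mode is active. A minimal sketch of that flag logic with a stand-in config object (the attribute names come from the `getattr` calls above; the values are purely illustrative):

```python
from types import SimpleNamespace

# Stand-in for a PretrainedConfig; real deployments set these in config.json,
# and missing attributes simply default to False via getattr().
config = SimpleNamespace(
    num_local_experts=8,
    load_presharded_attn=True,
    load_presharded_embedding=True,
)
tp_size = 8  # would come from get_tensor_model_parallel_world_size()

load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
load_presharded_moe = config.num_local_experts > 0 and tp_size > 1
load_presharded_attn = getattr(config, "load_presharded_attn", False)
load_presharded_embedding = getattr(config, "load_presharded_embedding", False)

is_weights_presharded = (
    load_presharded_mlp
    or load_presharded_moe
    or load_presharded_attn
    or load_presharded_embedding
)
print(is_weights_presharded)  # True -> DefaultModelLoader._prepare_weights is patched
```
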
@@ -406,6 +575,9 @@ class Grok1ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        if debug_tensor_dump_output_folder:
+            dump_to_file(debug_tensor_dump_output_folder, "input_ids", input_ids)
+
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
@@ -414,21 +586,28 @@ class Grok1ForCausalLM(nn.Module):
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-
-
-
-
+        num_experts: Optional[int] = None,
+        ignore_parent_name: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        if num_experts is None:
+            num_experts = self.config.num_local_experts
+        stacked_params_mapping = []
+        stacked_params_mapping += [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+        ]
+        stacked_params_mapping += [
+            # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
@@ -439,14 +618,25 @@ class Grok1ForCausalLM(nn.Module):
         all_names = set(params_dict.keys())
         hit_names = set()

-        def load_weight_wrapper(
+        def load_weight_wrapper(
+            name: str, loaded_weight: torch.Tensor, *args, **kwargs
+        ):
+            if ignore_parent_name:
+                name = name.split(".")[-1]
+
             if name not in params_dict:
                 return

+            # Fuse constant multipliers into the weights
+            if "lm_head" in name:
+                loaded_weight = (
+                    loaded_weight.to(torch.float32)
+                    * self.config.output_multiplier_scale
+                )
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight, *args, **kwargs)
-
             hit_names.add(name)

         for name, loaded_weight in weights:
@@ -460,7 +650,6 @@ class Grok1ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-
                 load_weight_wrapper(name, loaded_weight, shard_id)
                 break
             else:
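
In the loading loop above, `stacked_params_mapping` translates per-projection checkpoint names (`q_proj`, `k_proj`, `v_proj`, `gate_proj`, `up_proj`) into the fused parameters (`qkv_proj`, `gate_up_proj`) plus a shard id before `load_weight_wrapper` is called. A compact sketch of that translation step, with the loop details simplified and the example tensor name invented for illustration:

```python
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def translate(ckpt_name: str):
    """Map a checkpoint tensor name to (fused_param_name, shard_id); None if unmapped."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in ckpt_name:
            return ckpt_name.replace(shard_name, param_name), shard_id
    return None


print(translate("model.layers.0.self_attn.k_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'k')
```
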
@@ -487,13 +676,79 @@ class Grok1ForCausalLM(nn.Module):

                 load_weight_wrapper(name=name, loaded_weight=loaded_weight)

+        if len(hit_names) > 5:
+            missing = all_names - hit_names
+            missing_exclude_scales = {x for x in missing if "scale" not in x}
+            logger.info(
+                f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
+            )
+            if len(missing_exclude_scales) > 0:
+                raise ValueError(
+                    f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
+                )
+
+        elif len(hit_names) == 0:
+            raise ValueError("load_weights failed because it did not hit any names.")
+
+        return hit_names
+
+    def get_num_params_analytical(self):
+        cfg = self.config
+        moe_intermediate_size = getattr(
+            cfg,
+            "moe_intermediate_size",
+            getattr(cfg, "intermediate_size", None),
+        )
+        num_experts = cfg.num_local_experts
+
+        wq = (
+            cfg.num_hidden_layers
+            * cfg.hidden_size
+            * cfg.num_attention_heads
+            * cfg.head_dim
+        )
+        wkv = (
+            cfg.num_hidden_layers
+            * cfg.hidden_size
+            * cfg.num_key_value_heads
+            * cfg.head_dim
+            * 2
+        )
+        out = (
+            cfg.num_hidden_layers
+            * cfg.hidden_size
+            * cfg.num_attention_heads
+            * cfg.head_dim
+        )
+        ffn1 = (
+            cfg.num_hidden_layers
+            * num_experts
+            * cfg.hidden_size
+            * moe_intermediate_size
+            * 2
+        )
+        ffn2 = (
+            cfg.num_hidden_layers
+            * num_experts
+            * cfg.hidden_size
+            * moe_intermediate_size
+        )
+        embed = cfg.hidden_size * cfg.vocab_size * 2
+        return wq + wkv + out + ffn1 + ffn2 + embed
+
+    def get_num_params_torch(self):
+        return (
+            sum(p.numel() for p in self.parameters())
+            * get_tensor_model_parallel_world_size()
+        )
+

 old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")


 def _prepare_presharded_weights(
     self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
-) -> Tuple[str,
+) -> Tuple[str, list[str], bool]:
     import glob
     import os

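
The new `get_num_params_analytical` sums the attention projections, the per-expert FFN matrices, and the embedding plus lm_head tables (hence the final factor of 2) straight from the config, while `get_num_params_torch` scales the local shard count by the TP world size; the rank-0 log line added in `__init__` prints both so they can be cross-checked. A quick usage sketch of the same arithmetic with made-up config values (not any real Grok configuration):

```python
from types import SimpleNamespace

cfg = SimpleNamespace(
    num_hidden_layers=4,
    hidden_size=1024,
    num_attention_heads=8,
    num_key_value_heads=2,
    head_dim=128,
    num_local_experts=8,
    moe_intermediate_size=4096,
    vocab_size=32000,
)

wq = cfg.num_hidden_layers * cfg.hidden_size * cfg.num_attention_heads * cfg.head_dim
wkv = cfg.num_hidden_layers * cfg.hidden_size * cfg.num_key_value_heads * cfg.head_dim * 2
out = cfg.num_hidden_layers * cfg.hidden_size * cfg.num_attention_heads * cfg.head_dim
ffn1 = cfg.num_hidden_layers * cfg.num_local_experts * cfg.hidden_size * cfg.moe_intermediate_size * 2
ffn2 = cfg.num_hidden_layers * cfg.num_local_experts * cfg.hidden_size * cfg.moe_intermediate_size
embed = cfg.hidden_size * cfg.vocab_size * 2

total = wq + wkv + out + ffn1 + ffn2 + embed
print(f"{total / 1e9:.2f} B")  # roughly 0.48 B for these toy values
```
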
@@ -522,7 +777,7 @@ def _prepare_presharded_weights(
     # The new format
     allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]

-    hf_weights_files
+    hf_weights_files = []
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))

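
`_prepare_presharded_weights` (installed over `DefaultModelLoader._prepare_weights` when presharded loading is enabled) has each TP rank glob only its own shard files plus any rank-common file instead of indexing the full checkpoint. For illustration, with the "new format" patterns above, rank 3 would match files like the hypothetical listing below:

```python
import fnmatch

tp_rank = 3
allow_patterns = [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]

# Hypothetical directory listing of a presharded checkpoint folder.
files = [
    "model-00001-TP-000.safetensors",
    "model-00001-TP-003.safetensors",
    "model-00002-TP-003.safetensors",
    "embeddings-TP-common.safetensors",
]

hf_weights_files = []
for pattern in allow_patterns:
    hf_weights_files += [f for f in files if fnmatch.fnmatch(f, pattern)]

print(hf_weights_files)
# ['model-00001-TP-003.safetensors', 'model-00002-TP-003.safetensors',
#  'embeddings-TP-common.safetensors']
```
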