sglang 0.4.4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/srt/models/grok.py CHANGED
@@ -15,28 +15,36 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
-
-from typing import Iterable, List, Optional, Tuple
-
+import functools
+import json
+import logging
+import math
+import os
+import warnings
+from typing import Iterable, Optional, Tuple
+
+import numpy as np
 import torch
-import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
-    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.router import fused_moe_router_shim
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -44,47 +52,17 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import dump_to_file

+logger = logging.getLogger(__name__)

-class Grok1MLP(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        reduce_results=True,
-        use_presharded_weights: bool = False,
-    ) -> None:
-        super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size,
-            [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("gate_up_proj", prefix),
-            use_presharded_weights=use_presharded_weights,
-        )
-        self.down_proj = RowParallelLinear(
-            intermediate_size,
-            hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("down_proj", prefix),
-            reduce_results=reduce_results,
-            use_presharded_weights=use_presharded_weights,
-        )
-        self.act_fn = GeluAndMul(approximate="tanh")

-    def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
+debug_tensor_dump_output_folder = None
+debug_tensor_dump_inject = False


 class Grok1MoE(nn.Module):
@@ -108,51 +86,55 @@ class Grok1MoE(nn.Module):
         tp_size: Optional[int] = None,
         reduce_results=True,
         use_presharded_weights: bool = False,
-        prefix: str = "",
+        inplace: bool = True,
+        no_combine: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size

-        # Gate always runs at half / full precision for now.
+        # Gate always runs at full precision for stability (see https://arxiv.org/pdf/2101.03961)
         self.gate = ReplicatedLinear(
             hidden_size,
             num_experts,
             bias=False,
-            params_dtype=params_dtype,
+            params_dtype=torch.float32,
             quant_config=None,
-            prefix=add_prefix("gate", prefix),
         )

         self.router_logit_softcapping = getattr(
             config, "router_logit_softcapping", 30.0
         )
-        self.experts = FusedMoE(
+        custom_routing_function = functools.partial(
+            fused_moe_router_shim, self.router_logit_softcapping
+        )
+
+        kwargs = {}
+        if global_server_args_dict["enable_ep_moe"]:
+            MoEImpl = EPMoE
+        else:
+            MoEImpl = FusedMoE
+            kwargs["reduce_results"] = reduce_results
+            kwargs["use_presharded_weights"] = use_presharded_weights
+            kwargs["inplace"] = inplace
+            kwargs["no_combine"] = no_combine
+
+        self.experts = MoEImpl(
             num_experts=num_experts,
             top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
-            reduce_results=reduce_results,
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            custom_routing_function=custom_routing_function,
             activation="gelu",
-            use_presharded_weights=use_presharded_weights,
-            prefix=add_prefix("experts", prefix),
+            **kwargs,
         )

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # NOTE: hidden_states can have either 1D or 2D shape.
-        orig_shape = hidden_states.shape
-        hidden_states = hidden_states.view(-1, self.hidden_size)
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        router_logits = 30.0 * F.tanh(router_logits / 30.0)
-
         # need to assert self.gate.quant_method is unquantized
-        final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(orig_shape)
+        return self.experts(hidden_states, self.gate.weight)


 class Grok1Attention(nn.Module):
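Note: in the hunk above, the removed forward path computed the router logits in Python (gate matmul followed by `30.0 * F.tanh(router_logits / 30.0)`), while the new path hands `self.gate.weight` and a `custom_routing_function` built from `fused_moe_router_shim` to the experts. A minimal plain-PyTorch sketch of that softcapping step, for reference only (the helper below is hypothetical, not an sglang API):

import torch


def softcapped_router_logits(
    hidden_states: torch.Tensor,  # [num_tokens, hidden_size]
    gate_weight: torch.Tensor,  # [num_experts, hidden_size], kept in float32
    softcap: float = 30.0,
) -> torch.Tensor:
    # Illustrative sketch, not sglang's fused Triton router: gate matmul,
    # then soft-cap the logits into (-softcap, softcap), matching the removed
    # `30.0 * F.tanh(router_logits / 30.0)` line.
    logits = hidden_states.float() @ gate_weight.float().t()
    return softcap * torch.tanh(logits / softcap)

Keeping the gate in float32 (the `params_dtype=torch.float32` change above) keeps these logits numerically stable before top-k expert selection.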
@@ -167,31 +149,33 @@ class Grok1Attention(nn.Module):
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
-        prefix: str = "",
+        load_presharded_attn: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
         self.layer_id = layer_id
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        attn_tp_rank = get_tensor_model_parallel_rank()
+        attn_tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
-        assert self.total_num_heads % tp_size == 0
-        self.num_heads = self.total_num_heads // tp_size
+        assert self.total_num_heads % attn_tp_size == 0
+        self.num_heads = self.total_num_heads // attn_tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= attn_tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % attn_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
         self.head_dim = getattr(config, "head_dim", 128)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        self.load_presharded_attn = load_presharded_attn

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -200,7 +184,9 @@ class Grok1Attention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=add_prefix("qkv_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            load_presharded_attn=self.load_presharded_attn,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -208,7 +194,9 @@ class Grok1Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
-            prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
+            use_presharded_weights=self.load_presharded_attn,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -227,7 +215,6 @@ class Grok1Attention(nn.Module):
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=logit_cap,
-            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -236,10 +223,73 @@ class Grok1Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
+        if debug_tensor_dump_output_folder:
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                f"attn_input_{self.layer_id}",
+                hidden_states,
+            )
+
+        if debug_tensor_dump_inject:
+            name = os.path.join(
+                debug_tensor_dump_output_folder,
+                f"jax_dump_attn_input_{self.layer_id}.npy",
+            )
+            logger.info(f"Load {name} from jax.")
+            x = np.load(name)
+            hidden_states = torch.tensor(x[0, : hidden_states.shape[0]]).to(
+                hidden_states
+            )
+
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
+
+        if debug_tensor_dump_output_folder:
+            num_tokens = q.shape[0]
+            num_heads_q = self.num_heads
+            head_dim = self.head_dim
+            num_heads_kv = k.numel() // (num_tokens * head_dim)
+
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                f"q_{self.layer_id}",
+                tensor_model_parallel_all_gather(
+                    q.reshape(num_tokens, num_heads_q, head_dim).contiguous(), dim=1
+                ).contiguous(),
+            )
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                f"k_{self.layer_id}",
+                tensor_model_parallel_all_gather(
+                    k.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
+                ).contiguous(),
+            )
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                f"v_{self.layer_id}",
+                tensor_model_parallel_all_gather(
+                    v.reshape(num_tokens, num_heads_kv, head_dim).contiguous(), dim=1
+                ).contiguous(),
+            )
+
         attn_output = self.attn(q, k, v, forward_batch)
+
+        if debug_tensor_dump_output_folder:
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                f"attn_output_{self.layer_id}",
+                tensor_model_parallel_all_gather(
+                    attn_output.reshape(num_tokens, num_heads_q, head_dim).contiguous(),
+                    dim=1,
+                ).contiguous(),
+            )
+
         output, _ = self.o_proj(attn_output)
         return output

@@ -250,8 +300,9 @@ class Grok1DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
-        use_presharded_weights: bool = False,
-        prefix: str = "",
+        load_presharded_moe: bool = False,
+        load_presharded_attn: bool = False,
+        load_presharded_mlp: bool = False,
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
@@ -268,7 +319,8 @@ class Grok1DecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             quant_config=quant_config,
-            prefix=add_prefix("attn", prefix),
+            reduce_results=False,
+            load_presharded_attn=load_presharded_attn,
         )
         self.block_sparse_moe = Grok1MoE(
             config=config,
@@ -282,38 +334,68 @@ class Grok1DecoderLayer(nn.Module):
             ),
             quant_config=quant_config,
             reduce_results=True,
-            use_presharded_weights=use_presharded_weights,
-            prefix=add_prefix("block_sparse_moe", prefix),
+            use_presharded_weights=load_presharded_moe,
+            inplace=True,
+            no_combine=False,  # just a suggestion to not combine topk
         )
+
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+        self.ffn = self.block_sparse_moe
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor] = None,
+        deferred_norm: Optional[RMSNorm] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
         # Self Attention
-        hidden_states = (
-            self.post_attn_norm(
-                self.self_attn(
-                    positions=positions,
-                    hidden_states=self.pre_attn_norm(hidden_states),
-                    forward_batch=forward_batch,
-                )
+        if deferred_norm is not None:
+            assert residual is not None
+            # here hidden_states is output of ffn, residual is residual from after previous attn layer
+            hidden_states, residual = fused_dual_residual_rmsnorm(
+                hidden_states,
+                residual,
+                deferred_norm.weight,
+                self.pre_attn_norm.weight,
+                deferred_norm.variance_epsilon,
             )
-            + hidden_states
+        else:
+            # here hidden_states is the residual
+            hidden_states, residual = (
+                fused_rmsnorm(
+                    hidden_states,
+                    self.pre_attn_norm.weight,
+                    self.pre_attn_norm.variance_epsilon,
+                ),
+                hidden_states,
+            )
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
         )

-        # Fully Connected
-        hidden_states = (
-            self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
-            + hidden_states
+        if get_tensor_model_parallel_world_size() > 1:
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+
+        hidden_states, residual = fused_dual_residual_rmsnorm(
+            hidden_states,
+            residual,
+            self.post_attn_norm.weight,
+            self.pre_moe_norm.weight,
+            self.post_attn_norm.variance_epsilon,
         )
-        return hidden_states
+
+        # Fully Connected
+        hidden_states = self.ffn(hidden_states)
+        return hidden_states, residual, self.post_moe_norm  # defer layernorm


 class Grok1Model(nn.Module):
@@ -321,8 +403,10 @@ class Grok1Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        use_presharded_weights: bool = False,
-        prefix: str = "",
+        load_presharded_moe: bool = False,
+        load_presharded_embedding: bool = False,
+        load_presharded_attn: bool = False,
+        load_presharded_mlp: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -332,7 +416,7 @@ class Grok1Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
-            prefix=add_prefix("embed_tokens", prefix),
+            use_presharded_weights=load_presharded_embedding,
         )
         self.layers = nn.ModuleList(
             [
@@ -340,8 +424,9 @@ class Grok1Model(nn.Module):
                     config,
                     i,
                     quant_config=quant_config,
-                    use_presharded_weights=use_presharded_weights,
-                    prefix=add_prefix(f"layers.{i}", prefix),
+                    load_presharded_moe=load_presharded_moe,
+                    load_presharded_attn=load_presharded_attn,
+                    load_presharded_mlp=load_presharded_mlp,
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -361,10 +446,48 @@ class Grok1Model(nn.Module):
         else:
             hidden_states = input_embeds

+        residual, deferred_norm = None, None
         for i in range(len(self.layers)):
-            hidden_states = self.layers[i](positions, hidden_states, forward_batch)
-        hidden_states = self.norm(hidden_states)
-        hidden_states.mul_(self.config.output_multiplier_scale)
+            hidden_states, residual, deferred_norm = self.layers[i](
+                positions, hidden_states, forward_batch, residual, deferred_norm
+            )
+
+        if debug_tensor_dump_output_folder:
+            hidden_states = (
+                fused_rmsnorm(
+                    hidden_states,
+                    deferred_norm.weight,
+                    deferred_norm.variance_epsilon,
+                )
+                + residual
+            )
+
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                "last_hidden_before_norm",
+                hidden_states,
+            )
+
+            hidden_states = fused_rmsnorm(
+                hidden_states,
+                self.norm.weight,
+                self.norm.variance_epsilon,
+            )
+
+            dump_to_file(
+                debug_tensor_dump_output_folder,
+                "last_hidden_after_norm",
+                hidden_states,
+            )
+        else:
+            hidden_states, _ = fused_dual_residual_rmsnorm(
+                hidden_states,
+                residual,
+                deferred_norm.weight,
+                self.norm.weight,
+                deferred_norm.variance_epsilon,
+            )
+
         return hidden_states

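Note: the debug branch above decomposes what `fused_dual_residual_rmsnorm` computes in a single pass: apply the deferred norm from the previous layer, add the residual, then apply the next norm, returning both the normed activations and the new residual. A minimal plain-PyTorch reference of that decomposition (illustrative only; the helper names are hypothetical, not sglang APIs):

import torch


def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Standard RMSNorm: normalize by the root-mean-square, then scale.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def dual_residual_rmsnorm_ref(x, residual, w_deferred, w_next, eps):
    # Apply the deferred norm, add the residual to form the new residual
    # stream, then apply the next norm; return both values, matching the
    # (hidden_states, residual) pair threaded through Grok1DecoderLayer.forward.
    new_residual = rmsnorm_ref(x, w_deferred, eps) + residual
    return rmsnorm_ref(new_residual, w_next, eps), new_residual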
@@ -373,31 +496,77 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config

-        if (
+        # Get presharded weights.
+        self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
+        self.load_presharded_moe = (
             self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
-        ):
-            self.use_presharded_weights = True
+        )
+        self.load_presharded_attn = getattr(config, "load_presharded_attn", False)
+        self.load_presharded_embedding = getattr(
+            config, "load_presharded_embedding", False
+        )
+
+        self.is_weights_presharded = (
+            self.load_presharded_mlp
+            or self.load_presharded_moe
+            or self.load_presharded_attn
+            or self.load_presharded_embedding
+        )
+
+        if self.is_weights_presharded:
             setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
-        else:
-            self.use_presharded_weights = False
+
+        default_replicate_lm_head = False
+        self.replicate_lm_head = getattr(
+            config, "replicate_lm_head", default_replicate_lm_head
+        )

         self.model = Grok1Model(
             config,
             quant_config=quant_config,
-            use_presharded_weights=self.use_presharded_weights,
-            prefix=add_prefix("model", prefix),
+            load_presharded_moe=self.load_presharded_moe,
+            load_presharded_embedding=self.load_presharded_embedding,
+            load_presharded_attn=self.load_presharded_attn,
+            load_presharded_mlp=self.load_presharded_mlp,
         )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
-        )
-        self.logits_processor = LogitsProcessor(config)
+
+        lm_head_params_dtype = None
+        if self.replicate_lm_head:
+            self.lm_head = ReplicatedLinear(
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+                params_dtype=lm_head_params_dtype,
+            )
+            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                use_presharded_weights=self.load_presharded_embedding,
+                params_dtype=lm_head_params_dtype,
+            )
+            self.logits_processor = LogitsProcessor(config)
+
+        # Dump tensors for debugging
+        global debug_tensor_dump_output_folder, debug_tensor_dump_inject
+        debug_tensor_dump_output_folder = global_server_args_dict[
+            "debug_tensor_dump_output_folder"
+        ]
+        debug_tensor_dump_inject = global_server_args_dict["debug_tensor_dump_inject"]
+        warnings.filterwarnings("ignore", category=FutureWarning)
+
+        if get_tensor_model_parallel_rank() == 0:
+            logger.info(
+                f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
+                f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
+            )

     def forward(
         self,
@@ -406,6 +575,9 @@ class Grok1ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        if debug_tensor_dump_output_folder:
+            dump_to_file(debug_tensor_dump_output_folder, "input_ids", input_ids)
+
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
@@ -414,21 +586,28 @@ class Grok1ForCausalLM(nn.Module):
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-    ):
-        num_experts = self.config.num_local_experts
-
-        stacked_params_mapping = [
+        num_experts: Optional[int] = None,
+        ignore_parent_name: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        if num_experts is None:
+            num_experts = self.config.num_local_experts
+        stacked_params_mapping = []
+        stacked_params_mapping += [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+        ]
+        stacked_params_mapping += [
+            # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
@@ -439,14 +618,25 @@ class Grok1ForCausalLM(nn.Module):
         all_names = set(params_dict.keys())
         hit_names = set()

-        def load_weight_wrapper(name, loaded_weight, *args, **kwargs):
+        def load_weight_wrapper(
+            name: str, loaded_weight: torch.Tensor, *args, **kwargs
+        ):
+            if ignore_parent_name:
+                name = name.split(".")[-1]
+
             if name not in params_dict:
                 return

+            # Fuse constant multipliers into the weights
+            if "lm_head" in name:
+                loaded_weight = (
+                    loaded_weight.to(torch.float32)
+                    * self.config.output_multiplier_scale
+                )
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight, *args, **kwargs)
-
             hit_names.add(name)

         for name, loaded_weight in weights:
@@ -460,7 +650,6 @@ class Grok1ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-
                 load_weight_wrapper(name, loaded_weight, shard_id)
                 break
             else:
@@ -487,13 +676,79 @@ class Grok1ForCausalLM(nn.Module):

                 load_weight_wrapper(name=name, loaded_weight=loaded_weight)

+        if len(hit_names) > 5:
+            missing = all_names - hit_names
+            missing_exclude_scales = {x for x in missing if "scale" not in x}
+            logger.info(
+                f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
+            )
+            if len(missing_exclude_scales) > 0:
+                raise ValueError(
+                    f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
+                )
+
+        elif len(hit_names) == 0:
+            raise ValueError("load_weights failed because it did not hit any names.")
+
+        return hit_names
+
+    def get_num_params_analytical(self):
+        cfg = self.config
+        moe_intermediate_size = getattr(
+            cfg,
+            "moe_intermediate_size",
+            getattr(cfg, "intermediate_size", None),
+        )
+        num_experts = cfg.num_local_experts
+
+        wq = (
+            cfg.num_hidden_layers
+            * cfg.hidden_size
+            * cfg.num_attention_heads
+            * cfg.head_dim
+        )
+        wkv = (
+            cfg.num_hidden_layers
+            * cfg.hidden_size
+            * cfg.num_key_value_heads
+            * cfg.head_dim
+            * 2
+        )
+        out = (
+            cfg.num_hidden_layers
+            * cfg.hidden_size
+            * cfg.num_attention_heads
+            * cfg.head_dim
+        )
+        ffn1 = (
+            cfg.num_hidden_layers
+            * num_experts
+            * cfg.hidden_size
+            * moe_intermediate_size
+            * 2
+        )
+        ffn2 = (
+            cfg.num_hidden_layers
+            * num_experts
+            * cfg.hidden_size
+            * moe_intermediate_size
+        )
+        embed = cfg.hidden_size * cfg.vocab_size * 2
+        return wq + wkv + out + ffn1 + ffn2 + embed
+
+    def get_num_params_torch(self):
+        return (
+            sum(p.numel() for p in self.parameters())
+            * get_tensor_model_parallel_world_size()
+        )
+

 old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")


 def _prepare_presharded_weights(
     self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
-) -> Tuple[str, List[str], bool]:
+) -> Tuple[str, list[str], bool]:
     import glob
     import os

@@ -522,7 +777,7 @@ def _prepare_presharded_weights(
     # The new format
     allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]

-    hf_weights_files: List[str] = []
+    hf_weights_files = []
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))