sglang 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +3 -2
  3. sglang/global_config.py +1 -1
  4. sglang/lang/backend/runtime_endpoint.py +60 -49
  5. sglang/lang/interpreter.py +4 -2
  6. sglang/lang/ir.py +13 -4
  7. sglang/srt/constrained/jump_forward.py +13 -2
  8. sglang/srt/layers/activation.py +0 -1
  9. sglang/srt/layers/extend_attention.py +3 -1
  10. sglang/srt/layers/fused_moe/__init__.py +1 -0
  11. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  12. sglang/srt/layers/fused_moe/layer.py +587 -0
  13. sglang/srt/layers/logits_processor.py +4 -4
  14. sglang/srt/layers/radix_attention.py +38 -14
  15. sglang/srt/managers/schedule_batch.py +9 -14
  16. sglang/srt/managers/tokenizer_manager.py +1 -1
  17. sglang/srt/managers/tp_worker.py +1 -7
  18. sglang/srt/model_executor/cuda_graph_runner.py +48 -17
  19. sglang/srt/model_executor/forward_batch_info.py +132 -58
  20. sglang/srt/model_executor/model_runner.py +61 -28
  21. sglang/srt/models/chatglm.py +2 -2
  22. sglang/srt/models/commandr.py +1 -1
  23. sglang/srt/models/deepseek.py +2 -2
  24. sglang/srt/models/deepseek_v2.py +7 -6
  25. sglang/srt/models/gemma.py +1 -1
  26. sglang/srt/models/gemma2.py +11 -5
  27. sglang/srt/models/grok.py +50 -396
  28. sglang/srt/models/minicpm.py +2 -2
  29. sglang/srt/models/mixtral.py +56 -254
  30. sglang/srt/models/mixtral_quant.py +1 -4
  31. sglang/srt/models/qwen.py +2 -2
  32. sglang/srt/models/qwen2.py +2 -2
  33. sglang/srt/models/qwen2_moe.py +2 -2
  34. sglang/srt/models/stablelm.py +1 -1
  35. sglang/srt/openai_api/adapter.py +32 -21
  36. sglang/srt/sampling_params.py +0 -4
  37. sglang/srt/server.py +23 -15
  38. sglang/srt/server_args.py +7 -1
  39. sglang/srt/utils.py +1 -2
  40. sglang/test/runners.py +18 -10
  41. sglang/test/test_programs.py +32 -5
  42. sglang/test/test_utils.py +5 -1
  43. sglang/version.py +1 -1
  44. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/METADATA +12 -4
  45. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/RECORD +48 -48
  46. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  47. sglang/srt/model_loader/model_loader.py +0 -292
  48. sglang/srt/model_loader/utils.py +0 -275
  49. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  50. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py CHANGED
@@ -16,29 +16,24 @@ limitations under the License.
  # Adapted from
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
  """Inference-only Grok1 model."""
+ import warnings
  from typing import Iterable, List, Optional, Tuple

- import numpy as np
  import torch
  import torch.nn.functional as F
- import tqdm
  from torch import nn
  from transformers import PretrainedConfig
- from vllm import _custom_ops as ops
  from vllm.config import CacheConfig
  from vllm.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
-     tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
      QKVParallelLinear,
      ReplicatedLinear,
      RowParallelLinear,
  )
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
- from vllm.model_executor.layers.quantization.fp8 import Fp8Config
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
@@ -46,140 +41,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  )
  from vllm.model_executor.model_loader.loader import DefaultModelLoader
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.utils import set_weight_attrs
- from vllm.utils import print_warning_once

- from sglang.srt.layers.fused_moe import fused_moe
+ from sglang.srt.layers.fused_moe import FusedMoE
+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata

- use_fused = True
-
-
- class Grok1MLP(nn.Module):
-     def __init__(
-         self,
-         num_experts: int,
-         hidden_size: int,
-         intermediate_size: int,
-         quant_config: Optional[QuantizationConfig] = None,
-     ) -> None:
-         super().__init__()
-         self.num_experts = num_experts
-         self.ffn_dim = intermediate_size
-         self.hidden_dim = hidden_size
-
-         self.w1 = ReplicatedLinear(
-             self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
-         )
-         self.w2 = ReplicatedLinear(
-             self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
-         )
-         self.w3 = ReplicatedLinear(
-             self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
-         )
-
-         self.act_fn = nn.GELU()
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         w1_out, _ = self.w1(hidden_states)
-         w1_out = self.act_fn(w1_out)
-         w3_out, _ = self.w3(hidden_states)
-         current_hidden_states = w1_out * w3_out
-         current_hidden_states, _ = self.w2(current_hidden_states)
-         return current_hidden_states
-
-
- class Grok1MoEUnfused(nn.Module):
-     def __init__(
-         self,
-         config: PretrainedConfig,
-         quant_config: Optional[QuantizationConfig] = None,
-     ):
-         super().__init__()
-         self.config = config
-         self.rank = get_tensor_model_parallel_rank()
-         self.tp_size = get_tensor_model_parallel_world_size()
-         self.num_total_experts = config.num_local_experts
-         self.top_k = config.num_experts_per_tok
-         if self.tp_size > self.num_total_experts:
-             raise ValueError(
-                 f"Tensor parallel size {self.tp_size} is greater than "
-                 f"the number of experts {self.num_total_experts}."
-             )
-         # Split experts equally between ranks
-         self.expert_indicies = np.array_split(
-             range(self.num_total_experts), self.tp_size
-         )[self.rank].tolist()
-         if not self.expert_indicies:
-             raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
-
-         self.experts = nn.ModuleList(
-             [
-                 (
-                     Grok1MLP(
-                         self.num_total_experts,
-                         config.hidden_size,
-                         config.intermediate_size,
-                         quant_config=quant_config,
-                     )
-                     if idx in self.expert_indicies
-                     else None
-                 )
-                 for idx in range(self.num_total_experts)
-             ]
-         )
-         self.gate = ReplicatedLinear(
-             config.hidden_size, self.num_total_experts, bias=False, quant_config=None
-         )
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         router_logits, _ = self.gate(hidden_states)
-         router_logits = 30 * F.tanh(router_logits / 30)
-
-         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-         routing_weights, selected_experts = torch.topk(
-             routing_weights, self.top_k, dim=-1
-         )
-         routing_weights = routing_weights.to(hidden_states.dtype)
-         hidden_dim = hidden_states.shape[1]
-
-         final_hidden_states = torch.zeros(
-             (hidden_states.shape[0], hidden_dim),
-             dtype=hidden_states.dtype,
-             device=hidden_states.device,
-         )
-         expert_mask = torch.nn.functional.one_hot(
-             selected_experts, num_classes=self.num_total_experts
-         ).permute(2, 1, 0)
-
-         for expert_idx in self.expert_indicies:
-             expert_layer = self.experts[expert_idx]
-             idx, top_x = torch.where(expert_mask[expert_idx])
-
-             if top_x.shape[0] == 0:
-                 continue
-
-             # in torch it is faster to index using lists than torch tensors
-             top_x_list = top_x.tolist()
-             idx_list = idx.tolist()
-
-             # Index the correct hidden states and compute the expert hidden state for
-             # the current expert. We need to make sure to multiply the output hidden
-             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-             current_hidden_states = (
-                 expert_layer(current_state)
-                 * routing_weights[top_x_list, idx_list, None]
-             )
-
-             # However `index_add_` only support torch tensors for indexing so we'll use
-             # the `top_x` tensor here.
-             final_hidden_states.index_add_(0, top_x, current_hidden_states)
-
-         return tensor_model_parallel_all_reduce(final_hidden_states)
-

  class Grok1MoE(nn.Module):
      """A tensor-parallel MoE implementation for Grok1 that shards each expert
@@ -197,221 +65,42 @@ class Grok1MoE(nn.Module):
          hidden_size: int,
          intermediate_size: int,
          params_dtype: Optional[torch.dtype] = None,
-         tp_size: Optional[int] = None,
          quant_config: Optional[QuantizationConfig] = None,
+         tp_size: Optional[int] = None,
      ):
          super().__init__()
-         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-         self.num_total_experts = num_experts
-         self.top_k = top_k
          self.hidden_size = hidden_size
-         self.intermediate_size = intermediate_size // self.tp_size
-         self.quant_config = quant_config
-
-         # FIXME(pcmoritz): Make this more general to support different
-         # quantization schemes
-         self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-         if params_dtype is None:
-             params_dtype = torch.get_default_dtype()
-         self.params_dtype = params_dtype

          # Gate always runs at half / full precision for now.
          self.gate = ReplicatedLinear(
-             self.hidden_size,
-             self.num_total_experts,
+             hidden_size,
+             num_experts,
              bias=False,
-             params_dtype=self.params_dtype,
+             params_dtype=params_dtype,
              quant_config=None,
          )

-         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-             params_dtype = torch.float8_e4m3fn
-
-         self.w13_weight = nn.Parameter(
-             torch.empty(
-                 self.num_total_experts,
-                 2 * self.intermediate_size,
-                 self.hidden_size,
-                 dtype=params_dtype,
-             )
-         )
-         self.w2_weight = nn.Parameter(
-             torch.empty(
-                 self.num_total_experts,
-                 self.hidden_size,
-                 self.intermediate_size,
-                 dtype=params_dtype,
-             )
-         )
-
-         set_weight_attrs(
-             self.w13_weight,
-             {
-                 "weight_loader": self.weight_loader,
-             },
-         )
-         set_weight_attrs(
-             self.w2_weight,
-             {
-                 "weight_loader": self.weight_loader,
-             },
+         self.experts = FusedMoE(
+             num_experts=num_experts,
+             top_k=top_k,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             params_dtype=params_dtype,
+             reduce_results=True,
+             renormalize=False,
+             quant_config=quant_config,
+             tp_size=tp_size,
          )

-         # Used for fp8.
-         self.w13_scale = None
-         self.w2_scale = None
-         self.a13_scale = None
-         self.a2_scale = None
-
-         if self.use_fp8:
-             # WEIGHT_SCALE (for fp8)
-             self.w13_scale = nn.Parameter(
-                 torch.ones(self.num_total_experts, dtype=torch.float32),
-                 requires_grad=False,
-             )
-             self.w2_scale = nn.Parameter(
-                 torch.ones(self.num_total_experts, dtype=torch.float32),
-                 requires_grad=False,
-             )
-
-             # If loading fp8 checkpoint, pass the weight loaders.
-             # If loading an fp16 checkpoint, do not (we will quantize in
-             # process_weights_after_loading()
-             if quant_config.is_checkpoint_fp8_serialized:
-                 set_weight_attrs(
-                     self.w13_scale,
-                     {
-                         "weight_loader": self.weight_loader,
-                     },
-                 )
-                 set_weight_attrs(
-                     self.w2_scale,
-                     {
-                         "weight_loader": self.weight_loader,
-                     },
-                 )
-
-             # ACT_SCALE (for fp8)
-             if quant_config.activation_scheme == "static":
-                 if not quant_config.is_checkpoint_fp8_serialized:
-                     raise ValueError(
-                         "Found static activation scheme for checkpoint that "
-                         "was not serialized fp8."
-                     )
-                 self.a13_scale = nn.Parameter(
-                     torch.zeros(self.num_total_experts, dtype=torch.float32),
-                     requires_grad=False,
-                 )
-                 self.a2_scale = nn.Parameter(
-                     torch.zeros(self.num_total_experts, dtype=torch.float32),
-                     requires_grad=False,
-                 )
-
-                 set_weight_attrs(
-                     self.a13_scale,
-                     {
-                         "weight_loader": self.weight_loader,
-                     },
-                 )
-                 set_weight_attrs(
-                     self.a2_scale,
-                     {
-                         "weight_loader": self.weight_loader,
-                     },
-                 )
-
-     def weight_loader(
-         self,
-         param: nn.Parameter,
-         loaded_weight: torch.Tensor,
-         weight_name: str,
-         expert_id: int,
-         pre_sharded: bool,
-     ):
-         param_data = param.data
-         shard_size = self.intermediate_size
-         if pre_sharded:
-             # The weight is already sharded. Readl the full shard
-             shard = slice(None)
-         else:
-             tp_rank = get_tensor_model_parallel_rank()
-             shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-         if weight_name.endswith("w1.weight"):
-             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-         if weight_name.endswith("w3.weight"):
-             param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                 shard, :
-             ]
-         if weight_name.endswith("w2.weight"):
-             param_data[expert_id, :, :] = loaded_weight[:, shard]
-         if "act_scale" in weight_name or "weight_scale" in weight_name:
-             param_data[expert_id] = loaded_weight
-
-     def process_weights_after_loading(self):
-         # Fp8 is the only case where we need to process after loading.
-         if not self.use_fp8:
-             return
-
-         # If checkpoint is fp16, quantize here.
-         if not self.quant_config.is_checkpoint_fp8_serialized:
-             w13_weight = torch.empty_like(
-                 self.w13_weight.data, dtype=torch.float8_e4m3fn
-             )
-             w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
-             for expert in range(self.num_total_experts):
-                 w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
-                     self.w13_weight.data[expert, :, :]
-                 )
-                 w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
-                     self.w2_weight.data[expert, :, :]
-                 )
-             self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-         # If checkpoint is fp8 + static, cleanup act_scales.
-         # Since state_dict has an act_scale per expert but our kernels
-         # are passed one act_scale shared across all experts.
-         elif self.quant_config.activation_scheme == "static":
-             if self.a13_scale is None or self.a2_scale is None:
-                 raise ValueError(
-                     "QuantConfig has static quantization, but found "
-                     "activation scales are None."
-                 )
-
-             if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
-                 print_warning_once(
-                     "Found act_scales that are not equal for fp8 MoE layer. "
-                     "Using the maximum across experts for each layer. "
-                 )
-
-             self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
-             self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
-
      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         num_tokens, hidden_size = hidden_states.shape
+         # NOTE: hidden_states can have either 1D or 2D shape.
+         orig_shape = hidden_states.shape
          hidden_states = hidden_states.view(-1, self.hidden_size)
          # router_logits: (num_tokens, n_experts)
          router_logits, _ = self.gate(hidden_states)
-         final_hidden_states = fused_moe(
-             hidden_states,
-             self.w13_weight,
-             self.w2_weight,
-             router_logits,
-             self.top_k,
-             renormalize=False,
-             inplace=True,
-             use_fp8=self.use_fp8,
-             w1_scale=self.w13_scale,
-             w2_scale=self.w2_scale,
-             a1_scale=self.a13_scale,
-             a2_scale=self.a2_scale,
-         )
-
-         if self.tp_size > 1:
-             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-         return final_hidden_states.view(num_tokens, hidden_size)
+         router_logits = 30.0 * F.tanh(router_logits / 30.0)
+         final_hidden_states = self.experts(hidden_states, router_logits)
+         return final_hidden_states.view(orig_shape)


  class Grok1Attention(nn.Module):
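Note on the hunk above: the hand-written expert modules (`Grok1MLP`, `Grok1MoEUnfused`) and the per-expert fp8 bookkeeping are replaced by sglang's `FusedMoE` layer, and `Grok1MoE.forward` now only computes the capped router logits before delegating routing, the expert GEMMs, and the tensor-parallel reduce to `self.experts`. For reference, the routing math that the deleted unfused path implemented (and that the fused layer is configured to match via `renormalize=False`) looks roughly like the sketch below; the function and argument names are illustrative, not part of the package.

import torch
import torch.nn.functional as F

def reference_grok1_routing(hidden_states, gate, experts, top_k):
    # Grok1 caps the router logits at +/-30 before the softmax.
    router_logits = gate(hidden_states)
    router_logits = 30.0 * torch.tanh(router_logits / 30.0)
    # Top-k routing; the weights are NOT renormalized after top-k.
    weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    weights, selected = torch.topk(weights, top_k, dim=-1)
    weights = weights.to(hidden_states.dtype)
    # Weighted sum of the selected experts' outputs for each token.
    out = torch.zeros_like(hidden_states)
    for token in range(hidden_states.shape[0]):
        for weight, expert_id in zip(weights[token], selected[token].tolist()):
            out[token] += weight * experts[expert_id](hidden_states[token])
    return out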
@@ -478,6 +167,7 @@ class Grok1Attention(nn.Module):
              layer_id=layer_id,
              logit_cap=logit_cap,
          )
+         # TODO(lianmin): load logit cap from config

      def forward(
          self,
@@ -502,7 +192,7 @@ class Grok1DecoderLayer(nn.Module):
      ) -> None:
          super().__init__()
          self.hidden_size = config.hidden_size
-         # Requires transformers > 4.32.0
+
          rope_theta = getattr(config, "rope_theta", 10000)
          self.self_attn = Grok1Attention(
              hidden_size=self.hidden_size,
@@ -513,18 +203,13 @@ class Grok1DecoderLayer(nn.Module):
              rope_theta=rope_theta,
              quant_config=quant_config,
          )
-         if use_fused:
-             self.block_sparse_moe = Grok1MoE(
-                 num_experts=config.num_local_experts,
-                 top_k=config.num_experts_per_tok,
-                 hidden_size=config.hidden_size,
-                 intermediate_size=config.intermediate_size,
-                 quant_config=quant_config,
-             )
-         else:
-             self.block_sparse_moe = Grok1MoEUnfused(
-                 config=config, quant_config=quant_config
-             )
+         self.block_sparse_moe = Grok1MoE(
+             num_experts=config.num_local_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             intermediate_size=config.intermediate_size,
+             quant_config=quant_config,
+         )
          self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
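The four `RMSNorm` modules constructed here implement Grok1's sandwich-norm residual pattern, which the next two hunks annotate in the layer's `forward`. Schematically, with an illustrative helper that is not package code (the real attention call also takes positions and `input_metadata`):

def grok1_decoder_layer(x, self_attn, moe, pre_attn, post_attn, pre_moe, post_moe):
    # Normalize before and after each sub-block, then add the residual.
    x = post_attn(self_attn(pre_attn(x))) + x  # Self Attention
    x = post_moe(moe(pre_moe(x))) + x          # Fully Connected (MoE)
    return x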
@@ -536,6 +221,7 @@ class Grok1DecoderLayer(nn.Module):
          hidden_states: torch.Tensor,
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
+         # Self Attention
          hidden_states = (
              self.post_attn_norm(
                  self.self_attn(
@@ -547,11 +233,11 @@ class Grok1DecoderLayer(nn.Module):
              + hidden_states
          )

+         # Fully Connected
          hidden_states = (
              self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
              + hidden_states
          )
-
          return hidden_states


@@ -593,7 +279,6 @@ class Grok1Model(nn.Module):

          for i in range(len(self.layers)):
              hidden_states = self.layers[i](positions, hidden_states, input_metadata)
-
          hidden_states = self.norm(hidden_states)
          hidden_states.mul_(self.config.output_multiplier_scale)
          return hidden_states
@@ -615,8 +300,8 @@ class Grok1ModelForCausalLM(nn.Module):

          # Monkey patch _prepare_weights to load pre-sharded weights
          setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+         warnings.filterwarnings("ignore", category=FutureWarning)

-     @torch.no_grad()
      def forward(
          self,
          input_ids: torch.Tensor,
@@ -637,50 +322,17 @@ class Grok1ModelForCausalLM(nn.Module):
              ("qkv_proj", "v_proj", "v"),
          ]

-         if use_fused:
-             expert_params_mapping = (
-                 [
-                     # These are the weight scales for the experts
-                     # (param_name, weight_name, expert_id)
-                     (
-                         "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-                         f"experts.{expert_id}.{weight_name}.weight_scale",
-                         expert_id,
-                     )
-                     for expert_id in range(self.config.num_local_experts)
-                     for weight_name in ["w1", "w2", "w3"]
-                 ]
-                 + [
-                     # These are the weights for the experts
-                     # (param_name, weight_name, expert_id)
-                     (
-                         "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-                         f"experts.{expert_id}.{weight_name}.weight",
-                         expert_id,
-                     )
-                     for expert_id in range(self.config.num_local_experts)
-                     for weight_name in ["w1", "w2", "w3"]
-                 ]
-                 + [
-                     # These are the activation scales for the experts
-                     # (param_name, weight_name, expert_id)
-                     (
-                         "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-                         f"experts.{expert_id}.{weight_name}.act_scale",
-                         expert_id,
-                     )
-                     for expert_id in range(self.config.num_local_experts)
-                     for weight_name in ["w1", "w2", "w3"]
-                 ]
-             )
-         else:
-             expert_params_mapping = []
+         # Params for weights, fp8 weight scales, fp8 activation scales
+         # (param_name, weight_name, expert_id, shard_id)
+         expert_params_mapping = FusedMoE.make_expert_params_mapping(
+             ckpt_gate_proj_name="w1",
+             ckpt_down_proj_name="w2",
+             ckpt_up_proj_name="w3",
+             num_experts=self.config.num_local_experts,
+         )

          params_dict = dict(self.named_parameters())
-         if get_tensor_model_parallel_rank() == 0:
-             weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
          for name, loaded_weight in weights:
-             # print(get_tensor_model_parallel_rank(), name)
              if "rotary_emb.inv_freq" in name:
                  continue

@@ -691,21 +343,25 @@ class Grok1ModelForCausalLM(nn.Module):
              # Skip loading extra bias for GPTQ models.
              if name.endswith(".bias") and name not in params_dict:
                  continue
+
              param = params_dict[name]
              weight_loader = param.weight_loader
              weight_loader(param, loaded_weight, shard_id)
              break
          else:
-             for param_name, weight_name, expert_id in expert_params_mapping:
+             for mapping in expert_params_mapping:
+                 param_name, weight_name, expert_id, shard_id = mapping
                  if weight_name not in name:
                      continue
                  name = name.replace(weight_name, param_name)
+
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(
                      param,
                      loaded_weight,
                      weight_name,
+                     shard_id=shard_id,
                      expert_id=expert_id,
                      pre_sharded=get_tensor_model_parallel_world_size() > 1,
                  )
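With the switch to `FusedMoE`, the loader no longer builds the w13/w2 name mappings by hand: `FusedMoE.make_expert_params_mapping` returns 4-tuples `(param_name, weight_name, expert_id, shard_id)` that map checkpoint tensor names such as `experts.<i>.w1.weight` onto the fused expert parameters, and the loop above renames each matching tensor and hands it to the parameter's `weight_loader` together with the expert and shard it belongs to. A hedged sketch of that rename-and-dispatch step follows; the mapping entries are hypothetical (the real helper generates one per expert and per w1/w2/w3 tensor, and the exact param_name strings may differ), and `pre_sharded` handling is omitted.

# Hypothetical entries: (param_name, checkpoint_weight_name, expert_id, shard_id).
expert_params_mapping = [
    ("experts.w13_weight", "experts.0.w1.weight", 0, "w1"),
    ("experts.w13_weight", "experts.0.w3.weight", 0, "w3"),
    ("experts.w2_weight", "experts.0.w2.weight", 0, "w2"),
]

def route_expert_weight(name, loaded_weight, params_dict):
    # Rename the checkpoint tensor to its fused parameter and dispatch the load.
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        target = name.replace(weight_name, param_name)
        param = params_dict[target]
        param.weight_loader(
            param, loaded_weight, weight_name,
            shard_id=shard_id, expert_id=expert_id,
        )
        return True
    return False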
@@ -714,6 +370,9 @@ class Grok1ModelForCausalLM(nn.Module):
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
+                 if name is None:
+                     continue
+
                  param = params_dict[name]
                  weight_loader = getattr(
                      param, "weight_loader", default_weight_loader
@@ -721,11 +380,6 @@ class Grok1ModelForCausalLM(nn.Module):
                  weight_loader(param, loaded_weight)


- def all_close_1d(x: torch.Tensor) -> bool:
-     assert len(x.shape) == 1
-     return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
  old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")


@@ -22,8 +22,6 @@ import torch
  from torch import nn
  from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
@@ -37,6 +35,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.model_executor.forward_batch_info import InputMetadata