sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py CHANGED
@@ -16,29 +16,24 @@ limitations under the License.
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
+import warnings
 from typing import Iterable, List, Optional, Tuple
 
-import numpy as np
 import torch
 import torch.nn.functional as F
-import tqdm
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops
 from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -46,140 +41,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import print_warning_once
 
-from sglang.srt.layers.fused_moe import fused_moe
+from sglang.srt.layers.fused_moe import FusedMoE
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
-use_fused = True
-
-
-class Grok1MLP(nn.Module):
-    def __init__(
-        self,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
-
-        self.w1 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
-        )
-        self.w2 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
-        )
-        self.w3 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
-        )
-
-        self.act_fn = nn.GELU()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
-
-
-class Grok1MoEUnfused(nn.Module):
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
-        super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}."
-            )
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(
-            range(self.num_total_experts), self.tp_size
-        )[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList(
-            [
-                (
-                    Grok1MLP(
-                        self.num_total_experts,
-                        config.hidden_size,
-                        config.intermediate_size,
-                        quant_config=quant_config,
-                    )
-                    if idx in self.expert_indicies
-                    else None
-                )
-                for idx in range(self.num_total_experts)
-            ]
-        )
-        self.gate = ReplicatedLinear(
-            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
-        )
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        router_logits, _ = self.gate(hidden_states)
-        router_logits = 30 * F.tanh(router_logits / 30)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
-        )
-        routing_weights = routing_weights.to(hidden_states.dtype)
-        hidden_dim = hidden_states.shape[1]
-
-        final_hidden_states = torch.zeros(
-            (hidden_states.shape[0], hidden_dim),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_total_experts
-        ).permute(2, 1, 0)
-
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            if top_x.shape[0] == 0:
-                continue
-
-            # in torch it is faster to index using lists than torch tensors
-            top_x_list = top_x.tolist()
-            idx_list = idx.tolist()
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = (
-                expert_layer(current_state)
-                * routing_weights[top_x_list, idx_list, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states)
-
-        return tensor_model_parallel_all_reduce(final_hidden_states)
-
 
 class Grok1MoE(nn.Module):
     """A tensor-parallel MoE implementation for Grok1 that shards each expert
@@ -197,221 +66,42 @@ class Grok1MoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(
-            self.hidden_size,
-            self.num_total_experts,
+            hidden_size,
+            num_experts,
             bias=False,
-            params_dtype=self.params_dtype,
+            params_dtype=params_dtype,
             quant_config=None,
         )
 
-        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        self.w13_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.hidden_size,
-                dtype=params_dtype,
-            )
-        )
-        self.w2_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                dtype=params_dtype,
-            )
-        )
-
-        set_weight_attrs(
-            self.w13_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=False,
+            quant_config=quant_config,
+            tp_size=tp_size,
         )
 
-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-            self.w2_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            # process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(
-                    self.w13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.w2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8."
-                    )
-                self.a13_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-                self.a2_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-
-                set_weight_attrs(
-                    self.a13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.a2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-        pre_sharded: bool,
-    ):
-        param_data = param.data
-        shard_size = self.intermediate_size
-        if pre_sharded:
-            # The weight is already sharded. Readl the full shard
-            shard = slice(None)
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                shard, :
-            ]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
-            param_data[expert_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                self.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
-                    self.w13_weight.data[expert, :, :]
-                )
-                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
-                    self.w2_weight.data[expert, :, :]
-                )
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        # If checkpoint is fp8 + static, cleanup act_scales.
-        # Since state_dict has an act_scale per expert but our kernels
-        # are passed one act_scale shared across all experts.
-        elif self.quant_config.activation_scheme == "static":
-            if self.a13_scale is None or self.a2_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None."
-                )
-
-            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
-                print_warning_once(
-                    "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. "
-                )
-
-            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.w13_weight,
-            self.w2_weight,
-            router_logits,
-            self.top_k,
-            renormalize=False,
-            inplace=True,
-            use_fp8=self.use_fp8,
-            w1_scale=self.w13_scale,
-            w2_scale=self.w2_scale,
-            a1_scale=self.a13_scale,
-            a2_scale=self.a2_scale,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        router_logits = 30.0 * F.tanh(router_logits / 30.0)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)
 
 
 class Grok1Attention(nn.Module):
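The net effect of the rewrite above is that Grok1MoE keeps only the replicated gate; expert weights, top-k dispatch, the fp8 scale handling, and the tensor-parallel all-reduce (reduce_results=True) now live inside the shared FusedMoE layer. As a rough reference only, the plain-PyTorch sketch below restates the routing math that the deleted Grok1MoEUnfused class implemented and that the fused kernel computes: router logits are soft-capped to ±30 with tanh, softmax probabilities are taken top-k without renormalization (renormalize=False), and each token's output is the weighted sum of its selected experts. The function name and the toy expert_mlps argument are illustrative, not part of sglang.

import torch
import torch.nn.functional as F


def grok1_routing_reference(
    hidden_states: torch.Tensor,  # [num_tokens, hidden_size]
    router_logits: torch.Tensor,  # [num_tokens, num_experts]
    expert_mlps: list,            # one callable per expert (toy stand-ins)
    top_k: int,
) -> torch.Tensor:
    # Soft-cap the router logits to (-30, 30), as in the new forward().
    router_logits = 30.0 * torch.tanh(router_logits / 30.0)

    # Top-k routing weights; with renormalize=False the selected softmax
    # probabilities are used as-is instead of being rescaled to sum to 1.
    probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
    weights, selected = probs.topk(top_k, dim=-1)
    weights = weights.to(hidden_states.dtype)

    # Dense (slow) dispatch: run each selected expert on its tokens and
    # accumulate the weighted outputs.
    out = torch.zeros_like(hidden_states)
    for k in range(top_k):
        for expert_id, mlp in enumerate(expert_mlps):
            mask = selected[:, k] == expert_id
            if mask.any():
                out[mask] += weights[mask, k, None] * mlp(hidden_states[mask])
    return out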
@@ -478,6 +168,7 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
         )
+        # TODO(lianmin): load logit cap from config
 
     def forward(
         self,
@@ -502,7 +193,7 @@ class Grok1DecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-        # Requires transformers > 4.32.0
+
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
             hidden_size=self.hidden_size,
@@ -513,18 +204,13 @@ class Grok1DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             quant_config=quant_config,
         )
-        if use_fused:
-            self.block_sparse_moe = Grok1MoE(
-                num_experts=config.num_local_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.intermediate_size,
-                quant_config=quant_config,
-            )
-        else:
-            self.block_sparse_moe = Grok1MoEUnfused(
-                config=config, quant_config=quant_config
-            )
+        self.block_sparse_moe = Grok1MoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+        )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -536,6 +222,7 @@ class Grok1DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
+        # Self Attention
         hidden_states = (
             self.post_attn_norm(
                 self.self_attn(
@@ -547,11 +234,11 @@ class Grok1DecoderLayer(nn.Module):
             + hidden_states
         )
 
+        # Fully Connected
         hidden_states = (
             self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
             + hidden_states
         )
-
         return hidden_states
 
 
@@ -593,7 +280,6 @@ class Grok1Model(nn.Module):
 
         for i in range(len(self.layers)):
             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
-
         hidden_states = self.norm(hidden_states)
         hidden_states.mul_(self.config.output_multiplier_scale)
         return hidden_states
@@ -612,11 +298,15 @@ class Grok1ModelForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
 
-    @torch.no_grad()
+        self.use_presharded_weights = True
+
+        warnings.filterwarnings("ignore", category=FutureWarning)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -625,9 +315,11 @@ class Grok1ModelForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
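The forward() change above is the pattern this release applies across the model zoo (the internlm2.py and llama2.py hunks further down are the same edit): instead of returning the logits-processor output directly, the model now also runs a Sampler over it with the batch's sampling_info and returns (sample_output, logits_output). As a mental model only, the stand-in below shows the kind of per-batch work such a sampling step performs; it is not sglang's Sampler, and the plain tensors are simplified stand-ins for LogitsProcessorOutput and SamplingBatchInfo.

import torch


def sample_next_tokens(
    last_logits: torch.Tensor,   # [batch_size, vocab_size], logits at each request's last position
    temperatures: torch.Tensor,  # [batch_size]; 0.0 is treated as greedy decoding
) -> torch.Tensor:
    greedy = last_logits.argmax(dim=-1)
    # Clamp to avoid dividing by zero for greedy rows; their result is
    # overwritten by torch.where below anyway.
    scaled = last_logits / temperatures.clamp_min(1e-5).unsqueeze(-1)
    probs = torch.softmax(scaled, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return torch.where(temperatures == 0.0, greedy, sampled)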
@@ -637,50 +329,17 @@ class Grok1ModelForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        if use_fused:
-            expert_params_mapping = (
-                [
-                    # These are the weight scales for the experts
-                    # (param_name, weight_name, expert_id)
-                    (
-                        "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-                        f"experts.{expert_id}.{weight_name}.weight_scale",
-                        expert_id,
-                    )
-                    for expert_id in range(self.config.num_local_experts)
-                    for weight_name in ["w1", "w2", "w3"]
-                ]
-                + [
-                    # These are the weights for the experts
-                    # (param_name, weight_name, expert_id)
-                    (
-                        "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-                        f"experts.{expert_id}.{weight_name}.weight",
-                        expert_id,
-                    )
-                    for expert_id in range(self.config.num_local_experts)
-                    for weight_name in ["w1", "w2", "w3"]
-                ]
-                + [
-                    # These are the activation scales for the experts
-                    # (param_name, weight_name, expert_id)
-                    (
-                        "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-                        f"experts.{expert_id}.{weight_name}.act_scale",
-                        expert_id,
-                    )
-                    for expert_id in range(self.config.num_local_experts)
-                    for weight_name in ["w1", "w2", "w3"]
-                ]
-            )
-        else:
-            expert_params_mapping = []
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
+        )
 
         params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
         for name, loaded_weight in weights:
-            # print(get_tensor_model_parallel_rank(), name)
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -691,29 +350,43 @@ class Grok1ModelForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+
+                    if self.use_presharded_weights:
+                        extra_kwargs = {
+                            "use_presharded_weights": self.use_presharded_weights
+                        }
+                    else:
+                        extra_kwargs = {}
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
                         param,
                         loaded_weight,
                         weight_name,
+                        shard_id=shard_id,
                         expert_id=expert_id,
-                        pre_sharded=get_tensor_model_parallel_world_size() > 1,
+                        **extra_kwargs,
                     )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if name is None:
+                        continue
+
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
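The loader loop above leans on two things: the (param_name, weight_name, expert_id, shard_id) tuples from FusedMoE.make_expert_params_mapping, and the use_presharded_weights flag forwarded via extra_kwargs, which tells the expert weight loader to take each checkpoint shard as-is instead of slicing it by tensor-parallel rank. The sketch below is an assumption about the shape of that mapping, written to mirror the inline list the old code built (w1/w3 fold into the fused gate-and-up parameter, w2 is the down projection); the exact strings returned by the real helper may differ.

def sketch_expert_params_mapping(num_experts: int):
    # Hypothetical stand-in for FusedMoE.make_expert_params_mapping:
    # one (param_name, weight_name, expert_id, shard_id) tuple per expert
    # and per checkpoint projection. load_weights() rewrites any checkpoint
    # key containing weight_name to the fused param_name, then passes
    # expert_id and shard_id on to that parameter's weight_loader.
    mapping = []
    for expert_id in range(num_experts):
        for ckpt_name in ("w1", "w2", "w3"):
            fused = "experts.w13_" if ckpt_name in ("w1", "w3") else "experts.w2_"
            mapping.append(
                (fused, f"experts.{expert_id}.{ckpt_name}.", expert_id, ckpt_name)
            )
    return mapping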
@@ -721,11 +394,6 @@ class Grok1ModelForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight)
 
 
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
 
 
sglang/srt/models/internlm2.py CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/llama2.py CHANGED
@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) -> LogitProcessorOutput:
+    ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def get_module_name(self, name):
         stacked_params_mapping = [