sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py
CHANGED
@@ -16,29 +16,24 @@ limitations under the License.
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
+import warnings
 from typing import Iterable, List, Optional, Tuple

-import numpy as np
 import torch
 import torch.nn.functional as F
-import tqdm
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops
 from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -46,140 +41,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import print_warning_once

-from sglang.srt.layers.fused_moe import
+from sglang.srt.layers.fused_moe import FusedMoE
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

-use_fused = True
-
-
-class Grok1MLP(nn.Module):
-    def __init__(
-        self,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
-
-        self.w1 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
-        )
-        self.w2 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
-        )
-        self.w3 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
-        )
-
-        self.act_fn = nn.GELU()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
-
-
-class Grok1MoEUnfused(nn.Module):
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
-        super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}."
-            )
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(
-            range(self.num_total_experts), self.tp_size
-        )[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList(
-            [
-                (
-                    Grok1MLP(
-                        self.num_total_experts,
-                        config.hidden_size,
-                        config.intermediate_size,
-                        quant_config=quant_config,
-                    )
-                    if idx in self.expert_indicies
-                    else None
-                )
-                for idx in range(self.num_total_experts)
-            ]
-        )
-        self.gate = ReplicatedLinear(
-            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
-        )
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        router_logits, _ = self.gate(hidden_states)
-        router_logits = 30 * F.tanh(router_logits / 30)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
-        )
-        routing_weights = routing_weights.to(hidden_states.dtype)
-        hidden_dim = hidden_states.shape[1]
-
-        final_hidden_states = torch.zeros(
-            (hidden_states.shape[0], hidden_dim),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_total_experts
-        ).permute(2, 1, 0)
-
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            if top_x.shape[0] == 0:
-                continue
-
-            # in torch it is faster to index using lists than torch tensors
-            top_x_list = top_x.tolist()
-            idx_list = idx.tolist()
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = (
-                expert_layer(current_state)
-                * routing_weights[top_x_list, idx_list, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states)
-
-        return tensor_model_parallel_all_reduce(final_hidden_states)
-

 class Grok1MoE(nn.Module):
     """A tensor-parallel MoE implementation for Grok1 that shards each expert
@@ -197,221 +65,42 @@ class Grok1MoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype

         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(
-
-
+            hidden_size,
+            num_experts,
             bias=False,
-            params_dtype=
+            params_dtype=params_dtype,
             quant_config=None,
         )

-
-
-
-
-
-
-
-
-
-
-            )
-        self.w2_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                dtype=params_dtype,
-            )
-        )
-
-        set_weight_attrs(
-            self.w13_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=False,
+            quant_config=quant_config,
+            tp_size=tp_size,
         )

-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-            self.w2_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            # process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(
-                    self.w13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.w2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8."
-                    )
-                self.a13_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-                self.a2_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-
-                set_weight_attrs(
-                    self.a13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.a2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-        pre_sharded: bool,
-    ):
-        param_data = param.data
-        shard_size = self.intermediate_size
-        if pre_sharded:
-            # The weight is already sharded. Readl the full shard
-            shard = slice(None)
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                shard, :
-            ]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
-            param_data[expert_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                self.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
-                    self.w13_weight.data[expert, :, :]
-                )
-                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
-                    self.w2_weight.data[expert, :, :]
-                )
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        # If checkpoint is fp8 + static, cleanup act_scales.
-        # Since state_dict has an act_scale per expert but our kernels
-        # are passed one act_scale shared across all experts.
-        elif self.quant_config.activation_scheme == "static":
-            if self.a13_scale is None or self.a2_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None."
-                )
-
-            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
-                print_warning_once(
-                    "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. "
-                )
-
-            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-
-
-            self.w2_weight,
-            router_logits,
-            self.top_k,
-            renormalize=False,
-            inplace=True,
-            use_fp8=self.use_fp8,
-            w1_scale=self.w13_scale,
-            w2_scale=self.w2_scale,
-            a1_scale=self.a13_scale,
-            a2_scale=self.a2_scale,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        router_logits = 30.0 * F.tanh(router_logits / 30.0)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)


 class Grok1Attention(nn.Module):
@@ -478,6 +167,7 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
         )
+        # TODO(lianmin): load logit cap from config

     def forward(
         self,
@@ -502,7 +192,7 @@ class Grok1DecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
-
+
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
             hidden_size=self.hidden_size,
@@ -513,18 +203,13 @@ class Grok1DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             quant_config=quant_config,
         )
-
-
-
-
-
-
-
-            )
-        else:
-            self.block_sparse_moe = Grok1MoEUnfused(
-                config=config, quant_config=quant_config
-            )
+        self.block_sparse_moe = Grok1MoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+        )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -536,6 +221,7 @@ class Grok1DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
+        # Self Attention
         hidden_states = (
             self.post_attn_norm(
                 self.self_attn(
@@ -547,11 +233,11 @@ class Grok1DecoderLayer(nn.Module):
             + hidden_states
         )

+        # Fully Connected
         hidden_states = (
             self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
             + hidden_states
         )
-
         return hidden_states


@@ -593,7 +279,6 @@ class Grok1Model(nn.Module):

         for i in range(len(self.layers)):
             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
-
         hidden_states = self.norm(hidden_states)
         hidden_states.mul_(self.config.output_multiplier_scale)
         return hidden_states
@@ -615,8 +300,8 @@ class Grok1ModelForCausalLM(nn.Module):

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+        warnings.filterwarnings("ignore", category=FutureWarning)

-    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -637,50 +322,17 @@ class Grok1ModelForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]

-
-
-
-
-
-
-
-
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-            + [
-                # These are the weights for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-                    f"experts.{expert_id}.{weight_name}.weight",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-            + [
-                # These are the activation scales for the experts
-                # (param_name, weight_name, expert_id)
-                (
-                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-                    f"experts.{expert_id}.{weight_name}.act_scale",
-                    expert_id,
-                )
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
-        )
-        else:
-            expert_params_mapping = []
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
+        )

         params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
         for name, loaded_weight in weights:
-            # print(get_tensor_model_parallel_rank(), name)
             if "rotary_emb.inv_freq" in name:
                 continue

@@ -691,21 +343,25 @@ class Grok1ModelForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
                         param,
                         loaded_weight,
                         weight_name,
+                        shard_id=shard_id,
                         expert_id=expert_id,
                         pre_sharded=get_tensor_model_parallel_world_size() > 1,
                     )
@@ -714,6 +370,9 @@ class Grok1ModelForCausalLM(nn.Module):
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if name is None:
+                        continue
+
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
@@ -721,11 +380,6 @@ class Grok1ModelForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight)


-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")

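The upshot of the grok.py change: the per-expert Python loop, the hand-written fp8 weight/scale handling, and the custom weight_loader are all removed, and routing plus expert dispatch are delegated to the new sglang.srt.layers.fused_moe.FusedMoE layer. Below is a minimal single-process sketch of the resulting forward path; the NaiveFusedMoE stand-in, its weight shapes, and the gated-GELU expert MLP are assumptions made so the sketch runs on CPU, while the real FusedMoE uses fused Triton kernels and tensor parallelism.

import torch
import torch.nn.functional as F
from torch import nn


class NaiveFusedMoE(nn.Module):
    # Stand-in for sglang.srt.layers.fused_moe.FusedMoE (assumption): plain top-k
    # routing over per-expert gated MLPs, mirroring the removed Grok1MLP's GELU gating.
    def __init__(self, num_experts, top_k, hidden_size, intermediate_size):
        super().__init__()
        self.top_k = top_k
        self.w1 = nn.Parameter(torch.randn(num_experts, intermediate_size, hidden_size) * 0.02)
        self.w3 = nn.Parameter(torch.randn(num_experts, intermediate_size, hidden_size) * 0.02)
        self.w2 = nn.Parameter(torch.randn(num_experts, hidden_size, intermediate_size) * 0.02)

    def forward(self, x, router_logits):
        # renormalize=False in the diff: softmax over all experts, then pick top-k.
        weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
        topk_w, topk_idx = torch.topk(weights, self.top_k, dim=-1)
        topk_w = topk_w.to(x.dtype)
        out = torch.zeros_like(x)
        for k in range(self.top_k):
            e = topk_idx[:, k]  # expert chosen for each token at rank k
            h = F.gelu(torch.einsum("td,tfd->tf", x, self.w1[e])) * torch.einsum(
                "td,tfd->tf", x, self.w3[e]
            )
            out += topk_w[:, k, None] * torch.einsum("tf,tdf->td", h, self.w2[e])
        return out


class Grok1MoESketch(nn.Module):
    # Mirrors the refactored Grok1MoE.forward shown in the diff above.
    def __init__(self, num_experts, top_k, hidden_size, intermediate_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = NaiveFusedMoE(num_experts, top_k, hidden_size, intermediate_size)

    def forward(self, hidden_states):
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        router_logits = self.gate(hidden_states)
        # Grok1 soft-caps the router logits at +/-30 via tanh, as in the new code.
        router_logits = 30.0 * torch.tanh(router_logits / 30.0)
        return self.experts(hidden_states, router_logits).view(orig_shape)


if __name__ == "__main__":
    moe = Grok1MoESketch(num_experts=8, top_k=2, hidden_size=64, intermediate_size=128)
    print(moe(torch.randn(4, 64)).shape)  # torch.Size([4, 64])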
sglang/srt/models/internlm2.py
CHANGED
@@ -23,8 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -38,13 +36,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata


 class InternLM2MLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -74,7 +73,6 @@ class InternLM2MLP(nn.Module):


 class InternLM2Attention(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -150,7 +148,6 @@ class InternLM2Attention(nn.Module):


 class InternLMDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -207,7 +204,6 @@ class InternLMDecoderLayer(nn.Module):


 class InternLM2Model(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -254,7 +250,6 @@ class InternLM2Model(nn.Module):


 class InternLM2ForCausalLM(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
sglang/srt/models/llama2.py
CHANGED
@@ -24,8 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,7 +37,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.layers.
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -310,7 +310,7 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) ->
+    ) -> LogitProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata