sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py
CHANGED
@@ -16,29 +16,24 @@ limitations under the License.
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
+import warnings
 from typing import Iterable, List, Optional, Tuple
 
-import numpy as np
 import torch
 import torch.nn.functional as F
-import tqdm
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops
 from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -46,140 +41,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.utils import print_warning_once
 
-from sglang.srt.layers.fused_moe import
+from sglang.srt.layers.fused_moe import FusedMoE
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
-use_fused = True
-
-
-class Grok1MLP(nn.Module):
-    def __init__(
-        self,
-        num_experts: int,
-        hidden_size: int,
-        intermediate_size: int,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
-
-        self.w1 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
-        )
-        self.w2 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
-        )
-        self.w3 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
-        )
-
-        self.act_fn = nn.GELU()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
-
-
-class Grok1MoEUnfused(nn.Module):
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
-        super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}."
-            )
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(
-            range(self.num_total_experts), self.tp_size
-        )[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList(
-            [
-                (
-                    Grok1MLP(
-                        self.num_total_experts,
-                        config.hidden_size,
-                        config.intermediate_size,
-                        quant_config=quant_config,
-                    )
-                    if idx in self.expert_indicies
-                    else None
-                )
-                for idx in range(self.num_total_experts)
-            ]
-        )
-        self.gate = ReplicatedLinear(
-            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
-        )
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        router_logits, _ = self.gate(hidden_states)
-        router_logits = 30 * F.tanh(router_logits / 30)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
-        )
-        routing_weights = routing_weights.to(hidden_states.dtype)
-        hidden_dim = hidden_states.shape[1]
-
-        final_hidden_states = torch.zeros(
-            (hidden_states.shape[0], hidden_dim),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_total_experts
-        ).permute(2, 1, 0)
-
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            if top_x.shape[0] == 0:
-                continue
-
-            # in torch it is faster to index using lists than torch tensors
-            top_x_list = top_x.tolist()
-            idx_list = idx.tolist()
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = (
-                expert_layer(current_state)
-                * routing_weights[top_x_list, idx_list, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(0, top_x, current_hidden_states)
-
-        return tensor_model_parallel_all_reduce(final_hidden_states)
-
 
 class Grok1MoE(nn.Module):
     """A tensor-parallel MoE implementation for Grok1 that shards each expert
@@ -197,221 +66,42 @@ class Grok1MoE(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
         self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // self.tp_size
-        self.quant_config = quant_config
-
-        # FIXME(pcmoritz): Make this more general to support different
-        # quantization schemes
-        self.use_fp8 = isinstance(quant_config, Fp8Config)
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(
-
-
+            hidden_size,
+            num_experts,
             bias=False,
-            params_dtype=
+            params_dtype=params_dtype,
             quant_config=None,
         )
 
-
-
-
-
-
-
-
-
-
-
-        )
-        self.w2_weight = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                dtype=params_dtype,
-            )
-        )
-
-        set_weight_attrs(
-            self.w13_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2_weight,
-            {
-                "weight_loader": self.weight_loader,
-            },
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            params_dtype=params_dtype,
+            reduce_results=True,
+            renormalize=False,
+            quant_config=quant_config,
+            tp_size=tp_size,
         )
 
-        # Used for fp8.
-        self.w13_scale = None
-        self.w2_scale = None
-        self.a13_scale = None
-        self.a2_scale = None
-
-        if self.use_fp8:
-            # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-            self.w2_scale = nn.Parameter(
-                torch.ones(self.num_total_experts, dtype=torch.float32),
-                requires_grad=False,
-            )
-
-            # If loading fp8 checkpoint, pass the weight loaders.
-            # If loading an fp16 checkpoint, do not (we will quantize in
-            # process_weights_after_loading()
-            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(
-                    self.w13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.w2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-            # ACT_SCALE (for fp8)
-            if quant_config.activation_scheme == "static":
-                if not quant_config.is_checkpoint_fp8_serialized:
-                    raise ValueError(
-                        "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8."
-                    )
-                self.a13_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-                self.a2_scale = nn.Parameter(
-                    torch.zeros(self.num_total_experts, dtype=torch.float32),
-                    requires_grad=False,
-                )
-
-                set_weight_attrs(
-                    self.a13_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-                set_weight_attrs(
-                    self.a2_scale,
-                    {
-                        "weight_loader": self.weight_loader,
-                    },
-                )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-        pre_sharded: bool,
-    ):
-        param_data = param.data
-        shard_size = self.intermediate_size
-        if pre_sharded:
-            # The weight is already sharded. Readl the full shard
-            shard = slice(None)
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("w1.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("w3.weight"):
-            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
-                shard, :
-            ]
-        if weight_name.endswith("w2.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name or "weight_scale" in weight_name:
-            param_data[expert_id] = loaded_weight
-
-    def process_weights_after_loading(self):
-        # Fp8 is the only case where we need to process after loading.
-        if not self.use_fp8:
-            return
-
-        # If checkpoint is fp16, quantize here.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                self.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
-            for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
-                    self.w13_weight.data[expert, :, :]
-                )
-                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
-                    self.w2_weight.data[expert, :, :]
-                )
-            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
-            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
-
-        # If checkpoint is fp8 + static, cleanup act_scales.
-        # Since state_dict has an act_scale per expert but our kernels
-        # are passed one act_scale shared across all experts.
-        elif self.quant_config.activation_scheme == "static":
-            if self.a13_scale is None or self.a2_scale is None:
-                raise ValueError(
-                    "QuantConfig has static quantization, but found "
-                    "activation scales are None."
-                )
-
-            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
-                print_warning_once(
-                    "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. "
-                )
-
-            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-
-
-            self.w2_weight,
-            router_logits,
-            self.top_k,
-            renormalize=False,
-            inplace=True,
-            use_fp8=self.use_fp8,
-            w1_scale=self.w13_scale,
-            w2_scale=self.w2_scale,
-            a1_scale=self.a13_scale,
-            a2_scale=self.a2_scale,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        router_logits = 30.0 * F.tanh(router_logits / 30.0)
+        final_hidden_states = self.experts(hidden_states, router_logits)
+        return final_hidden_states.view(orig_shape)
 
 
 class Grok1Attention(nn.Module):
@@ -478,6 +168,7 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
         )
+        # TODO(lianmin): load logit cap from config
 
     def forward(
         self,
@@ -502,7 +193,7 @@ class Grok1DecoderLayer(nn.Module):
     ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
-
+
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = Grok1Attention(
            hidden_size=self.hidden_size,
@@ -513,18 +204,13 @@ class Grok1DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             quant_config=quant_config,
         )
-
-
-
-
-
-
-
-            )
-        else:
-            self.block_sparse_moe = Grok1MoEUnfused(
-                config=config, quant_config=quant_config
-            )
+        self.block_sparse_moe = Grok1MoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+        )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -536,6 +222,7 @@ class Grok1DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
+        # Self Attention
         hidden_states = (
             self.post_attn_norm(
                 self.self_attn(
@@ -547,11 +234,11 @@ class Grok1DecoderLayer(nn.Module):
             + hidden_states
         )
 
+        # Fully Connected
         hidden_states = (
             self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
             + hidden_states
         )
-
         return hidden_states
 
 
@@ -593,7 +280,6 @@ class Grok1Model(nn.Module):
 
         for i in range(len(self.layers)):
             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
-
         hidden_states = self.norm(hidden_states)
         hidden_states.mul_(self.config.output_multiplier_scale)
         return hidden_states
@@ -612,11 +298,15 @@ class Grok1ModelForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
 
-
+        self.use_presharded_weights = True
+
+        warnings.filterwarnings("ignore", category=FutureWarning)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -625,9 +315,11 @@ class Grok1ModelForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -637,50 +329,17 @@ class Grok1ModelForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-
-
-
-
-
-
-
-
-                        expert_id,
-                    )
-                    for expert_id in range(self.config.num_local_experts)
-                    for weight_name in ["w1", "w2", "w3"]
-                ]
-                + [
-                    # These are the weights for the experts
-                    # (param_name, weight_name, expert_id)
-                    (
-                        "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-                        f"experts.{expert_id}.{weight_name}.weight",
-                        expert_id,
-                    )
-                    for expert_id in range(self.config.num_local_experts)
-                    for weight_name in ["w1", "w2", "w3"]
-                ]
-                + [
-                    # These are the activation scales for the experts
-                    # (param_name, weight_name, expert_id)
-                    (
-                        "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-                        f"experts.{expert_id}.{weight_name}.act_scale",
-                        expert_id,
-                    )
-                    for expert_id in range(self.config.num_local_experts)
-                    for weight_name in ["w1", "w2", "w3"]
-                ]
-            )
-        else:
-            expert_params_mapping = []
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="w1",
+            ckpt_down_proj_name="w2",
+            ckpt_up_proj_name="w3",
+            num_experts=self.config.num_local_experts,
+        )
 
         params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
         for name, loaded_weight in weights:
-            # print(get_tensor_model_parallel_rank(), name)
             if "rotary_emb.inv_freq" in name:
                 continue
 
@@ -691,29 +350,43 @@ class Grok1ModelForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+
+                    if self.use_presharded_weights:
+                        extra_kwargs = {
+                            "use_presharded_weights": self.use_presharded_weights
+                        }
+                    else:
+                        extra_kwargs = {}
+
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
                         param,
                         loaded_weight,
                        weight_name,
+                        shard_id=shard_id,
                        expert_id=expert_id,
-
+                        **extra_kwargs,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
+                    if name is None:
+                        continue
+
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
@@ -721,11 +394,6 @@ class Grok1ModelForCausalLM(nn.Module):
                    weight_loader(param, loaded_weight)
 
 
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
 
 
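Note on the Grok1MoE rewrite above: the hand-rolled expert parameters, fp8 scale handling, and the unfused Grok1MoEUnfused fallback are replaced by sglang's shared FusedMoE layer, while the Grok-specific routing detail that survives is the 30 * tanh(x / 30) soft cap applied to the router logits before dispatch (with renormalize=False, so the top-k weights are not rescaled). The snippet below is a plain-PyTorch illustration of that routing step only; the helper name and tensor shapes are invented for the example and are not part of sglang's or vllm's API.

import torch
import torch.nn.functional as F

def route_tokens(hidden_states, gate_weight, top_k=2, logit_cap=30.0):
    """Illustrative soft-capped top-k routing (hypothetical helper, not sglang API).

    hidden_states: (num_tokens, hidden_size)
    gate_weight:   (num_experts, hidden_size), the gate's linear weight
    """
    # Router logits, followed by the 30 * tanh(x / 30) soft cap kept by the new forward().
    router_logits = hidden_states @ gate_weight.t()
    router_logits = logit_cap * torch.tanh(router_logits / logit_cap)

    # Softmax over experts, then pick the top-k experts per token.
    # renormalize=False in the diff, so the selected weights are not rescaled to sum to 1.
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    return routing_weights.to(hidden_states.dtype), selected_experts

# Toy example: 4 tokens, hidden size 16, 8 experts.
weights, experts = route_tokens(torch.randn(4, 16), torch.randn(8, 16))
print(weights.shape, experts.shape)  # torch.Size([4, 2]) torch.Size([4, 2])

In the real model the capped router_logits and the flattened hidden states are handed straight to self.experts (the FusedMoE layer), which performs the expert GEMMs and, with reduce_results=True, the tensor-parallel all-reduce that the removed code used to do by hand.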
sglang/srt/models/internlm2.py
CHANGED
@@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -262,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -272,9 +274,11 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
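The same two-line pattern recurs across the model files in this release (internlm2.py above, llama2.py below, grok.py, and most of the other models in the file list): the model now owns a Sampler, and forward() returns a (sample_output, logits_output) pair instead of the logits alone. The sketch below uses toy stand-ins to show the shape of that contract; the real LogitsProcessor and Sampler take the batch's InputMetadata / sampling_info and apply its sampling parameters rather than a plain greedy argmax.

import torch

class ToyLogitsProcessor:
    # Stand-in for sglang's LogitsProcessor: project hidden states to vocab logits.
    def __init__(self, lm_head_weight):
        self.lm_head_weight = lm_head_weight

    def __call__(self, hidden_states):
        return hidden_states @ self.lm_head_weight.t()

class ToySampler:
    # Stand-in for sglang.srt.layers.sampler.Sampler: here just greedy argmax.
    def __call__(self, logits):
        return torch.argmax(logits, dim=-1)

# New contract from the diffs: compute logits, sample inside the model,
# and return both the sampled token ids and the logits output.
hidden = torch.randn(2, 32)       # 2 sequences, hidden size 32
lm_head = torch.randn(100, 32)    # vocab size 100
logits_output = ToyLogitsProcessor(lm_head)(hidden)
sample_output = ToySampler()(logits_output)
print(sample_output.shape)        # torch.Size([2])
# forward() now ends with: return sample_output, logits_output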
sglang/srt/models/llama2.py
CHANGED
@@ -39,8 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.logits_processor import
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -302,6 +303,7 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -310,11 +312,13 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) ->
+    ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        return self.logits_processor(
+        logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
+        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
+        return sample_output, logits_output
 
     def get_module_name(self, name):
         stacked_params_mapping = [