sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +4 -4
- sglang/backend/litellm.py +2 -2
- sglang/backend/openai.py +26 -15
- sglang/bench_latency.py +299 -0
- sglang/global_config.py +4 -1
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +1 -1
- sglang/lang/ir.py +15 -5
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +2 -1
- sglang/srt/constrained/__init__.py +13 -6
- sglang/srt/constrained/fsm_cache.py +6 -3
- sglang/srt/constrained/jump_forward.py +113 -25
- sglang/srt/conversation.py +2 -0
- sglang/srt/flush_cache.py +2 -0
- sglang/srt/hf_transformers_utils.py +64 -9
- sglang/srt/layers/fused_moe.py +186 -89
- sglang/srt/layers/logits_processor.py +53 -25
- sglang/srt/layers/radix_attention.py +34 -7
- sglang/srt/managers/controller/dp_worker.py +6 -3
- sglang/srt/managers/controller/infer_batch.py +142 -67
- sglang/srt/managers/controller/manager_multi.py +5 -5
- sglang/srt/managers/controller/manager_single.py +8 -3
- sglang/srt/managers/controller/model_runner.py +154 -54
- sglang/srt/managers/controller/radix_cache.py +4 -0
- sglang/srt/managers/controller/schedule_heuristic.py +2 -0
- sglang/srt/managers/controller/tp_worker.py +140 -135
- sglang/srt/managers/detokenizer_manager.py +15 -19
- sglang/srt/managers/io_struct.py +10 -4
- sglang/srt/managers/tokenizer_manager.py +14 -13
- sglang/srt/model_config.py +83 -4
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/gemma.py +5 -1
- sglang/srt/models/grok.py +204 -137
- sglang/srt/models/llama2.py +11 -4
- sglang/srt/models/llama_classification.py +104 -0
- sglang/srt/models/llava.py +11 -8
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/mixtral.py +164 -115
- sglang/srt/models/mixtral_quant.py +0 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/models/yivl.py +2 -2
- sglang/srt/openai_api_adapter.py +33 -23
- sglang/srt/openai_protocol.py +1 -1
- sglang/srt/server.py +60 -19
- sglang/srt/server_args.py +79 -44
- sglang/srt/utils.py +146 -37
- sglang/test/test_programs.py +28 -10
- sglang/utils.py +4 -3
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
- sglang-0.1.18.dist-info/RECORD +78 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
- sglang/srt/managers/router/infer_batch.py +0 -596
- sglang/srt/managers/router/manager.py +0 -82
- sglang/srt/managers/router/model_rpc.py +0 -818
- sglang/srt/managers/router/model_runner.py +0 -445
- sglang/srt/managers/router/radix_cache.py +0 -267
- sglang/srt/managers/router/scheduler.py +0 -59
- sglang-0.1.17.dist-info/RECORD +0 -81
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
- {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py CHANGED

@@ -1,7 +1,7 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -9,7 +9,6 @@ import torch.nn.functional as F
 import tqdm
 from torch import nn
 from transformers import PretrainedConfig
-
 from vllm import _custom_ops as ops
 from vllm.config import CacheConfig
 from vllm.distributed import (
@@ -35,12 +34,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import print_warning_once
 
-from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.fused_moe import fused_moe
+from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
 
-
 use_fused = True
 
 
@@ -134,9 +132,12 @@ class Grok1MoEUnfused(nn.Module):
 
         final_hidden_states = torch.zeros(
             (hidden_states.shape[0], hidden_dim),
-            dtype=hidden_states.dtype,
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
         )
-        expert_mask = torch.nn.functional.one_hot(
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_total_experts
+        ).permute(2, 1, 0)
 
         for expert_idx in self.expert_indicies:
             expert_layer = self.experts[expert_idx]
@@ -153,7 +154,10 @@ class Grok1MoEUnfused(nn.Module):
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states =
+            current_hidden_states = (
+                expert_layer(current_state)
+                * routing_weights[top_x_list, idx_list, None]
+            )
 
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
@@ -198,32 +202,46 @@ class Grok1MoE(nn.Module):
         self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            quant_config=None,
+        )
 
         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
         self.w13_weight = nn.Parameter(
-            torch.empty(
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.hidden_size,
+                dtype=params_dtype,
+            )
+        )
         self.w2_weight = nn.Parameter(
-            torch.empty(
-        set_weight_attrs(
+            torch.empty(
+                self.num_total_experts,
+                self.hidden_size,
+                self.intermediate_size,
+                dtype=params_dtype,
+            )
+        )
+
+        set_weight_attrs(
+            self.w13_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
 
         # Used for fp8.
         self.w13_scale = None
@@ -233,46 +251,69 @@ class Grok1MoE(nn.Module):
 
         if self.use_fp8:
             # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(
+            self.w13_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
+            self.w2_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
 
             # If loading fp8 checkpoint, pass the weight loaders.
             # If loading an fp16 checkpoint, do not (we will quantize in
             # process_weights_after_loading()
            if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(
+                set_weight_attrs(
+                    self.w13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.w2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
 
             # ACT_SCALE (for fp8)
             if quant_config.activation_scheme == "static":
                 if not quant_config.is_checkpoint_fp8_serialized:
                     raise ValueError(
                         "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8."
-                set_weight_attrs(
+                        "was not serialized fp8."
+                    )
+                self.a13_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                self.a2_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+
+                set_weight_attrs(
+                    self.a13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.a2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        expert_id: int,
+        pre_sharded: bool,
+    ):
         param_data = param.data
         shard_size = self.intermediate_size
         if pre_sharded:
@@ -284,8 +325,9 @@ class Grok1MoE(nn.Module):
         if weight_name.endswith("w1.weight"):
             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
+            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                shard, :
+            ]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
         if "act_scale" in weight_name or "weight_scale" in weight_name:
@@ -298,17 +340,17 @@ class Grok1MoE(nn.Module):
 
         # If checkpoint is fp16, quantize here.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
+            w13_weight = torch.empty_like(
+                self.w13_weight.data, dtype=torch.float8_e4m3fn
+            )
+            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[
-                    expert
-                w2_weight[expert, :, :], self.w2_scale[
-                    expert
+                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                    self.w13_weight.data[expert, :, :]
+                )
+                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                    self.w2_weight.data[expert, :, :]
+                )
             self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
 
@@ -319,40 +361,40 @@ class Grok1MoE(nn.Module):
             if self.a13_scale is None or self.a2_scale is None:
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
-                    "activation scales are None."
+                    "activation scales are None."
+                )
 
-            if
-                    or not all_close_1d(self.a2_scale)):
+            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
                 print_warning_once(
                     "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. "
+                    "Using the maximum across experts for each layer. "
+                )
 
-                self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                             requires_grad=False)
+                self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+                self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            router_logits,
+            self.top_k,
+            renormalize=False,
+            inplace=True,
+            use_fp8=self.use_fp8,
+            w1_scale=self.w13_scale,
+            w2_scale=self.w2_scale,
+            a1_scale=self.a13_scale,
+            a2_scale=self.a2_scale,
+        )
 
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_size)
 
@@ -462,10 +504,12 @@ class Grok1DecoderLayer(nn.Module):
                 top_k=config.num_experts_per_tok,
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
-                quant_config=quant_config
+                quant_config=quant_config,
+            )
         else:
             self.block_sparse_moe = Grok1MoEUnfused(
-                config=config, quant_config=quant_config
+                config=config, quant_config=quant_config
+            )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -477,13 +521,21 @@ class Grok1DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
+        hidden_states = (
+            self.post_attn_norm(
+                self.self_attn(
+                    positions=positions,
+                    hidden_states=self.pre_attn_norm(hidden_states),
+                    input_metadata=input_metadata,
+                )
+            )
+            + hidden_states
+        )
 
-        hidden_states =
-        )
-        hidden_states = self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states
+        hidden_states = (
+            self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
+            + hidden_states
+        )
 
         return hidden_states
 
@@ -525,9 +577,7 @@ class Grok1Model(nn.Module):
         hidden_states.mul_(self.config.embedding_multiplier_scale)
 
         for i in range(len(self.layers)):
-            hidden_states = self.layers[i](
-                positions, hidden_states, input_metadata
-            )
+            hidden_states = self.layers[i](positions, hidden_states, input_metadata)
 
         hidden_states = self.norm(hidden_states)
         hidden_states.mul_(self.config.output_multiplier_scale)
@@ -572,28 +622,41 @@ class Grok1ModelForCausalLM(nn.Module):
         ]
 
         if use_fused:
-            expert_params_mapping =
+            expert_params_mapping = (
+                [
+                    # These are the weight scales for the experts
+                    # (param_name, weight_name, expert_id)
+                    (
+                        "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                        f"experts.{expert_id}.{weight_name}.weight_scale",
+                        expert_id,
+                    )
+                    for expert_id in range(self.config.num_local_experts)
+                    for weight_name in ["w1", "w2", "w3"]
+                ]
+                + [
+                    # These are the weights for the experts
+                    # (param_name, weight_name, expert_id)
+                    (
+                        "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+                        f"experts.{expert_id}.{weight_name}.weight",
+                        expert_id,
+                    )
+                    for expert_id in range(self.config.num_local_experts)
+                    for weight_name in ["w1", "w2", "w3"]
+                ]
+                + [
+                    # These are the activation scales for the experts
+                    # (param_name, weight_name, expert_id)
+                    (
+                        "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+                        f"experts.{expert_id}.{weight_name}.act_scale",
+                        expert_id,
+                    )
+                    for expert_id in range(self.config.num_local_experts)
+                    for weight_name in ["w1", "w2", "w3"]
+                ]
+            )
         else:
             expert_params_mapping = []
 
@@ -601,11 +664,11 @@ class Grok1ModelForCausalLM(nn.Module):
         if get_tensor_model_parallel_rank() == 0:
            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
        for name, loaded_weight in weights:
-            #print(get_tensor_model_parallel_rank(), name)
+            # print(get_tensor_model_parallel_rank(), name)
            if "rotary_emb.inv_freq" in name:
                continue
 
-            for
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
@@ -623,19 +686,22 @@ class Grok1ModelForCausalLM(nn.Module):
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
-                    weight_loader(
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        weight_name,
+                        expert_id=expert_id,
+                        pre_sharded=get_tensor_model_parallel_world_size() > 1,
+                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
-                    weight_loader = getattr(
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
                    weight_loader(param, loaded_weight)
 
 
@@ -645,10 +711,11 @@ def all_close_1d(x: torch.Tensor) -> bool:
 
 
 old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
+
+
+def _prepare_presharded_weights(
+    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+) -> Tuple[str, List[str], bool]:
     import glob
     import os
 
@@ -668,4 +735,4 @@ def _prepare_presharded_weights(self,
     return hf_folder, hf_weights_files, use_safetensors
 
 
-EntryClass = Grok1ModelForCausalLM
+EntryClass = Grok1ModelForCausalLM
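Note: the reformatted Grok1MoEUnfused path above follows the familiar Mixtral-style token routing: softmax the router logits, take the top-k experts per token, build a one-hot expert mask, and scatter each expert's output back with index_add_. The snippet below is a minimal, self-contained sketch of that pattern with toy shapes; num_experts, top_k, and the identity stand-in for the expert MLP are illustrative and not taken from the diff.

    import torch
    import torch.nn.functional as F

    # Toy shapes; the real values come from the model config.
    num_experts, top_k, hidden_dim, num_tokens = 8, 2, 16, 4
    hidden_states = torch.randn(num_tokens, hidden_dim)
    router_logits = torch.randn(num_tokens, num_experts)

    routing_weights = F.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    final_hidden_states = torch.zeros(
        (num_tokens, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
    )
    # expert_mask[e] is a (top_k, num_tokens) mask of the tokens routed to expert e.
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

    for expert_idx in range(num_experts):
        idx, top_x = torch.where(expert_mask[expert_idx])
        if top_x.numel() == 0:
            continue
        current_state = hidden_states[top_x]
        # A real expert is an MLP; the identity keeps the sketch short.
        current_hidden_states = current_state * routing_weights[top_x, idx, None]
        final_hidden_states.index_add_(0, top_x, current_hidden_states)

    print(final_hidden_states.shape)  # torch.Size([4, 16])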
sglang/srt/models/llama2.py CHANGED

@@ -1,7 +1,8 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
+
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 import tqdm
@@ -10,7 +11,7 @@ from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size
+    get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -76,6 +77,7 @@ class LlamaAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -123,6 +125,7 @@ class LlamaAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -158,9 +161,12 @@ class LlamaDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
             rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
@@ -169,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
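Note: the llama2.py change above threads a rope_is_neox_style flag from the HuggingFace config down to the rotary embedding, keeping the previous NeoX-style behavior when the attribute is absent. A minimal sketch of that getattr plumbing, using a hypothetical stand-in object rather than a real transformers.LlamaConfig:

    # DummyConfig stands in for a transformers.LlamaConfig; the attribute names
    # mirror the getattr calls shown in the diff above.
    class DummyConfig:
        rope_theta = 10000.0  # illustrative value; rope_is_neox_style left unset

    config = DummyConfig()
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
    print(rope_theta, rope_scaling, rope_is_neox_style)  # 10000.0 None True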
sglang/srt/models/llama_classification.py ADDED

@@ -0,0 +1,104 @@
+from typing import Iterable, Optional, Tuple
+
+import torch
+import tqdm
+from torch import nn
+from transformers import LlamaConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.models.llama2 import LlamaModel
+
+
+class LlamaForClassification(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
+
+        self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
+        self.eos_token_id = config.eos_token_id
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        is_eos_token = input_ids == self.eos_token_id
+        hidden_states = hidden_states[is_eos_token]
+        scores = self.classification_head(hidden_states)
+
+        if scores.shape[0] != input_metadata.batch_size:
+            print("Warning: the EOS tokens are missing in some sentences.")
+            scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
+
+        return LogitProcessorOutput(
+            next_token_logits=scores,
+            next_token_logprobs=scores,
+            normalized_prompt_logprobs=scores,
+            prefill_token_logprobs=torch.ones_like(input_ids),
+            prefill_top_logprobs=None,
+            decode_top_logprobs=None,
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        if get_tensor_model_parallel_rank() == 0:
+            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if "lm_head" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+EntryClass = LlamaForClassification
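Note: the new LlamaForClassification reuses the LLaMA backbone and classifies each request by taking the hidden state at its EOS token and passing it through a linear head. Below is a minimal sketch of that EOS-pooling step with toy tensors; hidden_size, num_classes, and eos_token_id are illustrative stand-ins for config.hidden_size, config.classification_out_size, and config.eos_token_id.

    import torch
    from torch import nn

    hidden_size, num_classes, eos_token_id = 16, 3, 2
    classification_head = nn.Linear(hidden_size, num_classes)

    # Two packed sequences flattened into one token stream, each ending with EOS (id 2).
    input_ids = torch.tensor([5, 7, 2, 9, 4, 4, 2])
    hidden_states = torch.randn(input_ids.shape[0], hidden_size)

    # Keep only the hidden states at EOS positions (one per sequence), then classify.
    is_eos_token = input_ids == eos_token_id
    scores = classification_head(hidden_states[is_eos_token])
    print(scores.shape)  # torch.Size([2, 3]): one row of logits per sequence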