sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -0
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +22 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +215 -83
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +170 -215
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +3 -3
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +33 -3
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +68 -28
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +27 -21
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +129 -77
- sglang/srt/model_executor/forward_batch_info.py +51 -21
- sglang/srt/model_executor/model_runner.py +72 -64
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +109 -29
- sglang/srt/models/llama.py +9 -2
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +22 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +20 -13
- sglang/srt/server_args.py +120 -58
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +626 -0
- sglang/srt/speculative/eagle_worker.py +184 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +47 -7
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -17,7 +17,7 @@ import gc
 import json
 import logging
 import time
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -48,8 +48,8 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
@@ -75,6 +75,7 @@ class ModelRunner:
         tp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        is_draft_worker: bool = False,
     ):
         # Parse args
         self.model_config = model_config
@@ -85,8 +86,13 @@ class ModelRunner:
         self.tp_size = tp_size
         self.dist_port = nccl_port
         self.server_args = server_args
+        self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.should_log = tp_rank == 0
+        self.spec_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
 
         # Model-specific adjustment
         if (
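Note: the new spec_algorithm field is parsed once from --speculative-algorithm and consulted later when sizing the KV-cache pools. As a rough illustration only (the real definition lives in sglang/srt/speculative/spec_info.py and may differ in members and behavior), such a selector typically looks like this:

    # Illustrative sketch, not the actual sglang implementation.
    from enum import IntEnum, auto
    from typing import Optional


    class SpeculativeAlgorithm(IntEnum):
        NONE = auto()
        EAGLE = auto()

        def is_none(self) -> bool:
            return self == SpeculativeAlgorithm.NONE

        @staticmethod
        def from_string(name: Optional[str]) -> "SpeculativeAlgorithm":
            # No --speculative-algorithm flag means speculative decoding is disabled.
            return {None: SpeculativeAlgorithm.NONE, "EAGLE": SpeculativeAlgorithm.EAGLE}[name]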
@@ -112,15 +118,21 @@ class ModelRunner:
 
         if self.is_multimodal:
             self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
             if self.model_config.hf_config.architectures == [
                 "MllamaForConditionalGeneration"
             ]:
                 logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                 server_args.chunked_prefill_size = -1
-
+
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                 logger.info(
                     "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                 )
@@ -192,9 +204,9 @@ class ModelRunner:
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
-        # ToDO(liangan1):Just use gloo to bypass the initilization fail
-        # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
+            # TODO(liangan1): Just use gloo to bypass the initilization fail
+            # Need to use xccl for xpu backend in the future
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
@@ -206,14 +218,18 @@ class ModelRunner:
         else:
             dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
-        init_distributed_environment(
-            backend=backend,
-            world_size=self.tp_size,
-            rank=self.tp_rank,
-            local_rank=self.gpu_id,
-            distributed_init_method=dist_init_method,
-        )
-        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
+        if not self.is_draft_worker:
+            # Only initilzie the distributed environment on the target model worker.
+            init_distributed_environment(
+                backend=backend,
+                world_size=self.tp_size,
+                rank=self.tp_rank,
+                local_rank=self.gpu_id,
+                distributed_init_method=dist_init_method,
+            )
+            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
@@ -408,7 +424,6 @@ class ModelRunner:
         target_dtype = (
             dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
         )
-        current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
 
         assert (
             self._model_update_group is not None
@@ -429,9 +444,9 @@ class ModelRunner:
             logger.error(error_msg)
             return False, error_msg
 
-    def update_weights_from_tensor(self,
-        self.model.load_weights(
-        return True, "Success"
+    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
+        self.model.load_weights(named_tensors)
+        return True, "Success"
 
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
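Note: the rewritten update_weights_from_tensor takes a list of (name, tensor) pairs and forwards it directly to model.load_weights(). A toy stand-in (not the real ModelRunner, names and shapes are made up) showing the contract:

    from typing import List, Tuple
    import torch


    class _ToyModel:
        def __init__(self):
            self.weights = {}

        def load_weights(self, named_tensors: List[Tuple[str, torch.Tensor]]):
            # Mirror of what the real model does: consume (name, tensor) pairs.
            self.weights.update(dict(named_tensors))


    class _ToyRunner:
        def __init__(self):
            self.model = _ToyModel()

        def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
            self.model.load_weights(named_tensors)
            return True, "Success"


    ok, msg = _ToyRunner().update_weights_from_tensor([("lm_head.weight", torch.zeros(4, 4))])
    print(ok, msg)  # True Success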
@@ -507,6 +522,28 @@ class ModelRunner:
         )
 
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+
+        if max_num_reqs is None:
+            max_num_reqs = min(
+                max(
+                    int(
+                        self.max_total_num_tokens / self.model_config.context_len * 512
+                    ),
+                    2048,
+                ),
+                4096,
+            )
+
+        if not self.spec_algorithm.is_none():
+            if self.is_draft_worker:
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+            else:
+                self.server_args.draft_runner_cache_size = (
+                    self.max_total_num_tokens
+                    + max_num_reqs * self.server_args.speculative_num_steps
+                    + 100
+                )
+
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
                 logging.warning(
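Note: for reference, the new sizing logic works out as follows with hypothetical numbers (only the formulas come from the hunk above):

    # Hypothetical values chosen for illustration.
    max_total_num_tokens = 400_000          # result of profile_max_num_token(...)
    context_len = 32_768                    # model_config.context_len
    speculative_num_steps = 5               # server_args.speculative_num_steps

    # Request slots are clamped to [2048, 4096], scaled by the token budget.
    max_num_reqs = min(max(int(max_total_num_tokens / context_len * 512), 2048), 4096)

    # The target worker reserves extra KV cache for the draft runner.
    draft_runner_cache_size = (
        max_total_num_tokens + max_num_reqs * speculative_num_steps + 100
    )
    print(max_num_reqs, draft_runner_cache_size)  # 4096 420580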
@@ -521,17 +558,6 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
-        if max_num_reqs is None:
-            max_num_reqs = min(
-                max(
-                    int(
-                        self.max_total_num_tokens / self.model_config.context_len * 512
-                    ),
-                    2048,
-                ),
-                4096,
-            )
-
         self.req_to_token_pool = ReqToTokenPool(
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
@@ -608,7 +634,6 @@ class ModelRunner:
         )
 
     def init_double_sparsity_channel_config(self, selected_channel):
-
         selected_channel = "." + selected_channel + "_proj"
         self.sorted_channels = []
         # load channel config
@@ -651,10 +676,6 @@ class ModelRunner:
         tensor_parallel(self.model, device_mesh)
 
     def forward_decode(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
-            return self.cuda_graph_runner.replay(forward_batch)
-
-        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
         self.attn_backend.init_forward_metadata(forward_batch)
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
@@ -684,14 +705,18 @@ class ModelRunner:
         )
 
     def forward_idle(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
-            return self.cuda_graph_runner.replay(forward_batch)
-
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+        if (
+            forward_batch.forward_mode.is_cuda_graph()
+            and self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(forward_batch)
+        ):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
@@ -699,11 +724,12 @@ class ModelRunner:
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
-            raise ValueError(f"
+            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
+        # Apply logit bias
         sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
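Note: taken together with the previous hunk, CUDA-graph replay is now gated once at the top of forward() instead of separately inside forward_decode()/forward_idle(). A condensed sketch of the resulting control flow, where runner and forward_batch are duck-typed stand-ins rather than the real ModelRunner/ForwardBatch:

    def forward(runner, forward_batch):
        if (
            forward_batch.forward_mode.is_cuda_graph()               # mode is capturable
            and runner.cuda_graph_runner                             # graphs were captured
            and runner.cuda_graph_runner.can_run(forward_batch)      # batch fits a captured graph
        ):
            return runner.cuda_graph_runner.replay(forward_batch)

        if forward_batch.forward_mode.is_decode():
            return runner.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return runner.forward_extend(forward_batch)
        elif forward_batch.forward_mode.is_idle():
            return runner.forward_idle(forward_batch)
        raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")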
@@ -714,35 +740,17 @@ class ModelRunner:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
             sampling_info.update_penalties()
-
-
-        # Sample the next tokens
-        next_token_ids = self.sampler(
+            sampling_info.apply_logits_bias(logits_output.next_token_logits)
+
+        # Sample the next tokens
+        next_token_ids = self.sampler(
+            logits_output,
+            sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+        )
         return next_token_ids
 
-    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Apply logit_bias
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits.add_(sampling_info.linear_penalties)
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        # Apply regex vocab_mask
-        if sampling_info.vocab_mask is not None:
-            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
-
-        return logits
-
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
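Note: the bias/penalty logic removed from ModelRunner above is now invoked as sampling_info.apply_logits_bias(...) (see the +22 lines in sampling_batch_info.py in the file list). A free-function sketch of the same in-place operations, assuming the attribute names from the removed code:

    import torch

    def apply_logits_bias(logits: torch.Tensor, sampling_info) -> torch.Tensor:
        if sampling_info.logit_bias is not None:
            logits.add_(sampling_info.logit_bias)          # per-token additive bias
        if sampling_info.linear_penalties is not None:
            logits.add_(sampling_info.linear_penalties)    # min-token / presence / frequency
        if sampling_info.scaling_penalties is not None:    # repetition penalty
            logits = torch.where(
                logits > 0,
                logits / sampling_info.scaling_penalties,
                logits * sampling_info.scaling_penalties,
            )
        if sampling_info.vocab_mask is not None:           # grammar / regex constraints
            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
        return logits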
sglang/srt/models/chatglm.py
CHANGED
@@ -23,8 +23,8 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.transformers_utils.configs import ChatGLMConfig
 
+from sglang.srt.configs import ChatGLMConfig
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
sglang/srt/models/dbrx.py
CHANGED
@@ -25,8 +25,8 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
+from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     input_to_float8,
+    normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -55,7 +56,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, is_hip
+
+is_hip_ = is_hip()
 
 if is_flashinfer_available():
     from flashinfer import bmm_fp8
@@ -573,7 +576,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.w_kc.dtype == torch.
+        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            q_nope_out = torch.bmm(
+                q_nope.to(torch.bfloat16).transpose(0, 1),
+                self.w_kc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = input_to_float8(
                 q_nope.transpose(0, 1), torch.float8_e4m3fn
             )
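Note: per the TODO comments, there is no bmm_fp8 path for float8_e4m3fnuz yet, so the new branches dequantize to bfloat16 and use a plain batched matmul. A standalone equivalent with illustrative shapes (not the real weights or dimensions):

    import torch

    heads, tokens, d_nope, kv_rank = 2, 3, 4, 5
    q_nope = torch.randn(tokens, heads, d_nope, dtype=torch.bfloat16)
    w_kc = torch.randn(heads, d_nope, kv_rank, dtype=torch.bfloat16)  # stands in for the fp8 weight
    w_scale = 0.5                                                     # per-tensor dequantization scale

    # (heads, tokens, d_nope) @ (heads, d_nope, kv_rank) -> (heads, tokens, kv_rank)
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc * w_scale)
    print(q_nope_out.shape)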
@@ -598,7 +607,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.w_vc.dtype == torch.
+        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            attn_bmm_output = torch.bmm(
+                attn_output.to(torch.bfloat16).transpose(0, 1),
+                self.w_vc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = input_to_float8(
                 attn_output.transpose(0, 1), torch.float8_e4m3fn
             )
@@ -940,15 +955,25 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w = self_attn.kv_b_proj.weight
                 # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
                 # This may affect the accuracy of fp8 model.
-                if (
-
-
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
                 ):
                     weight_block_size = self.quant_config.weight_block_size
                     if weight_block_size is not None:
                         assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if is_hip_:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
                         w, scale = block_quant_to_tensor_quant(
-
+                            weight, weight_scale, weight_block_size
                         )
                         self_attn.w_scale = scale
                 w_kc, w_vc = w.unflatten(
@@ -961,6 +986,8 @@ class DeepseekV2ForCausalLM(nn.Module):
                 and self_attn.w_scale is None
             ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                if is_hip_:
+                    self_attn.w_scale *= 2.0
 
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
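Note: the extra w_scale *= 2.0 on ROCm (is_hip_) reflects the different exponent bias of float8_e4m3fnuz (8, versus 7 for float8_e4m3fn): the same bit pattern decodes to half the value, so the dequantization scale is doubled to compensate. A quick check, assuming a PyTorch build that exposes both fp8 dtypes:

    import torch

    t = torch.ones(1).to(torch.float8_e4m3fn)
    # Reinterpret the same bits as e4m3fnuz: the decoded value halves.
    print(t.float().item(), t.view(torch.uint8).view(torch.float8_e4m3fnuz).float().item())
    # -> 1.0 0.5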
sglang/srt/models/grok.py
CHANGED
@@ -16,13 +16,16 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
 
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import GeluAndMul
@@ -42,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
@@ -53,6 +57,7 @@ class Grok1MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -61,6 +66,7 @@ class Grok1MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            use_presharded_weights=use_presharded_weights,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -69,6 +75,7 @@ class Grok1MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
             reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
         )
         self.act_fn = GeluAndMul(approximate="tanh")
 
@@ -99,6 +106,7 @@ class Grok1MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -125,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -152,6 +161,7 @@ class Grok1Attention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.config = config
@@ -190,6 +200,7 @@ class Grok1Attention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            reduce_results=reduce_results,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -230,10 +241,12 @@ class Grok1DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
+        self.layer_id = layer_id
 
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
@@ -258,6 +271,7 @@ class Grok1DecoderLayer(nn.Module):
             ),
             quant_config=quant_config,
             reduce_results=True,
+            use_presharded_weights=use_presharded_weights,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -295,6 +309,7 @@ class Grok1Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -307,7 +322,12 @@ class Grok1Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Grok1DecoderLayer(
+                Grok1DecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    use_presharded_weights=use_presharded_weights,
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -343,7 +363,21 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-
+
+        if (
+            self.config.num_local_experts > 0
+            and get_tensor_model_parallel_world_size() > 1
+        ):
+            self.use_presharded_weights = True
+            setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+        else:
+            self.use_presharded_weights = False
+
+        self.model = Grok1Model(
+            config,
+            quant_config=quant_config,
+            use_presharded_weights=self.use_presharded_weights,
+        )
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
@@ -359,7 +393,12 @@ class Grok1ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def load_weights(
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        num_experts = self.config.num_local_experts
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -375,10 +414,23 @@ class Grok1ForCausalLM(nn.Module):
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
-            num_experts=
+            num_experts=num_experts,
         )
 
         params_dict = dict(self.named_parameters())
+        all_names = set(params_dict.keys())
+        hit_names = set()
+
+        def load_weight_wrapper(name, loaded_weight, *args, **kwargs):
+            if name not in params_dict:
+                return
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight, *args, **kwargs)
+
+            hit_names.add(name)
+
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -391,9 +443,7 @@ class Grok1ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
 
-
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                load_weight_wrapper(name, loaded_weight, shard_id)
                 break
             else:
                 for mapping in expert_params_mapping:
@@ -402,15 +452,8 @@ class Grok1ForCausalLM(nn.Module):
                         continue
                     name = name.replace(weight_name, param_name)
 
-
-                        name
-                    ) and name not in params_dict:
-                        continue
-
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
+                    load_weight_wrapper(
+                        name,
                         loaded_weight,
                         name,
                         shard_id=shard_id,
@@ -419,21 +462,58 @@ class Grok1ForCausalLM(nn.Module):
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if (
-                        name.endswith(".bias") or name.endswith("_bias")
-                    ) and name not in params_dict:
-                        continue
-                    # Skip loading kv_scale from ckpts towards new design.
-                    if name.endswith(".kv_scale") and name not in params_dict:
+                    if name.endswith(".bias") and name not in params_dict:
                         continue
                     if name is None:
                         continue
 
-
-
-
-
-
+                    load_weight_wrapper(name=name, loaded_weight=loaded_weight)
+
+
+old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
+
+
+def _prepare_presharded_weights(
+    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+) -> Tuple[str, List[str], bool]:
+    import glob
+    import os
+
+    if get_tensor_model_parallel_world_size() == 1:
+        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
+
+    if not os.path.isdir(model_name_or_path):
+        from sglang.srt.model_loader.weight_utils import download_weights_from_hf
+
+        allow_patterns = ["*.safetensors", "*.bin"]
+        hf_folder = download_weights_from_hf(
+            model_name_or_path,
+            self.load_config.download_dir,
+            allow_patterns,
+            revision,
+            ignore_patterns=self.load_config.ignore_patterns,
+        )
+    else:
+        hf_folder = model_name_or_path
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # The old format
+    allow_patterns = [f"*-{tp_rank:03d}.bin"]
+
+    # The new format
+    allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]
+
+    hf_weights_files: List[str] = []
+    for pattern in allow_patterns:
+        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+
+    if hf_weights_files[0].endswith("safetensors"):
+        use_safetensors = True
+    else:
+        use_safetensors = False
+
+    return hf_folder, hf_weights_files, use_safetensors
 
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
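Note: the monkey-patched _prepare_presharded_weights above selects per-rank checkpoint shards by filename. For example, with tensor parallel rank 1 it would match files like these (the folder path and filenames are illustrative, not from the package):

    import glob
    import os

    hf_folder = "/models/grok-1-presharded"        # hypothetical local checkpoint directory
    tp_rank = 1
    patterns = [
        f"*-{tp_rank:03d}.bin",                    # old format, e.g. pytorch-model-001.bin
        f"*-TP-{tp_rank:03d}.safetensors",         # new format, e.g. model-TP-001.safetensors
        "*-TP-common.safetensors",                 # weights shared by all ranks
    ]
    files = [f for p in patterns for f in glob.glob(os.path.join(hf_folder, p))]
    print(files)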
|