sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/linear.py
CHANGED
@@ -1,4 +1,4 @@
-
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

 import logging
 from abc import abstractmethod
@@ -7,7 +7,8 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -15,17 +16,12 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-
-# workaround
-from vllm.model_executor.layers.linear import LinearBase
-
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
-    _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -43,9 +39,13 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
     "MarlinLinearMethod",
-    "GPTQLinearMethod",
     "QQQLinearMethod",
+    "GPTQMarlin24LinearMethod",
+    "TPUInt8LinearMethod",
+    "GPTQLinearMethod",
+    "FBGEMMFp8LinearMethod",
     "ModelOptFp8LinearMethod",
+    "IPEXAWQLinearMethod",
 ]


@@ -95,62 +95,6 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight


-def load_column_qkv_weight(
-    self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
-):
-    if (
-        isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
-        and self.output_dim == self.packed_dim
-    ):
-        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
-            shard_offset=shard_offset, shard_size=shard_size
-        )
-
-    param_data = self.data
-    shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
-    param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-    loaded_weight = loaded_weight.narrow(
-        self.output_dim, shard_id * shard_size, shard_size
-    )
-
-    assert param_data.shape == loaded_weight.shape
-    param_data.copy_(loaded_weight)
-
-
-def load_column_parallel_weight(
-    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
-):
-    if isinstance(self, _ColumnvLLMParameter):
-        if not use_presharded_weights:
-            shard_size = self.data.shape[self.output_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
-            )
-        assert self.data.shape == loaded_weight.shape
-        self.data.copy_(loaded_weight)
-    else:
-        self.data.copy_(loaded_weight)
-
-
-def load_row_parallel_weight(
-    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
-):
-    if isinstance(self, RowvLLMParameter):
-        if not use_presharded_weights:
-            shard_size = self.data.shape[self.input_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
-            )
-
-        if len(loaded_weight.shape) == 0:
-            loaded_weight = loaded_weight.reshape(1)
-
-        assert self.data.shape == loaded_weight.shape
-        self.data.copy_(loaded_weight)
-    else:
-        self.data.copy_(loaded_weight)
-
-
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""

@@ -227,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return F.linear(x, layer.weight, bias)


+class LinearBase(torch.nn.Module):
+    """Base linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        # Keep input parameters
+        self.input_size = input_size
+        self.output_size = output_size
+        self.skip_bias_add = skip_bias_add
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.

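For orientation, the relocated `LinearBase` follows a simple pattern: resolve a compute method from an optional quantization config at construction time and leave `forward` to subclasses. Below is a minimal, self-contained sketch of that pattern; the Demo* names and the trivial quant method are illustrative only, not the sglang API.

import torch
import torch.nn.functional as F


class DemoUnquantizedMethod:
    # Stand-in for an unquantized linear method: plain F.linear on the layer's weight.
    def apply(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, layer.weight)


class DemoLinearBase(torch.nn.Module):
    # Mirrors the quant_config handling in the LinearBase added above.
    def __init__(self, input_size: int, output_size: int, quant_config=None):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        if quant_config is None:
            self.quant_method = DemoUnquantizedMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix="")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class DemoReplicatedLinear(DemoLinearBase):
    # A subclass that owns a full (non-sharded) weight, like ReplicatedLinear.
    def __init__(self, input_size: int, output_size: int):
        super().__init__(input_size, output_size)
        self.weight = torch.nn.Parameter(torch.empty(output_size, input_size))
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.quant_method.apply(self, x)


print(DemoReplicatedLinear(16, 32)(torch.randn(2, 16)).shape)  # torch.Size([2, 32])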
@@ -426,9 +409,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -437,7 +418,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight=
+        param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -565,9 +546,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         param_data, loaded_weight, 0
                     )

-                assert (
-                    param_data.shape == loaded_weight.shape
-                ), f"{param_data.shape=}, {loaded_weight.shape=}"
+                assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
             current_shard_offset = 0
@@ -643,9 +622,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                         "the same for all partitions."
                     )

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

     def _load_fused_module_from_checkpoint(
@@ -697,6 +674,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return

@@ -882,6 +860,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return

@@ -896,24 +875,14 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_offset = (shard_offset + block_n - 1) // block_n
             shard_size = (shard_size + block_n - 1) // block_n

-
-
-
-
-
-
-
-
-                tp_rank=self.tp_rank,
-            )
-        else:
-            param.load_qkv_weight(
-                loaded_weight=loaded_weight,
-                num_heads=self.num_kv_head_replicas,
-                shard_id=loaded_shard_id,
-                shard_offset=shard_offset,
-                shard_size=shard_size,
-            )
+        param.load_qkv_weight(
+            loaded_weight=loaded_weight,
+            num_heads=self.num_kv_head_replicas,
+            shard_id=loaded_shard_id,
+            shard_offset=shard_offset,
+            shard_size=shard_size,
+            tp_rank=self.tp_rank,
+        )

     def weight_loader(
         self,
@@ -962,9 +931,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                         param_data, loaded_weight, 0
                     )

-                assert (
-                    param_data.shape == loaded_weight.shape
-                ), f"{param_data.shape=}, {loaded_weight.shape=}"
+                assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
             shard_offsets = [
@@ -1105,9 +1072,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                         "for all partitions."
                     )

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)


@@ -1234,9 +1199,7 @@ class RowParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)

-        assert (
-            param_data.shape == loaded_weight.shape
-        ), f"{param_data.shape=}, {loaded_weight.shape=}"
+        assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

     def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
@@ -1247,7 +1210,18 @@ class RowParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)

-        param
+        if isinstance(param, BasevLLMParameter):
+            # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
+            # It supports additional parameters like tp_rank and use_presharded_weights.
+            param.load_row_parallel_weight(
+                loaded_weight,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+        else:
+            # `params` is defined in `vllm/model_executor/parameter.py`,
+            # It does not support additional parameters.
+            param.load_row_parallel_weight(loaded_weight)

     def forward(self, input_):
         if self.input_is_parallel:
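The new `weight_loader_v2` branch above dispatches on the parameter class so that sglang's own parameter type can be given extra keyword arguments (`tp_rank`, `use_presharded_weights`) while any other parameter keeps the narrower legacy call. A rough, self-contained sketch of that dispatch; every name here is illustrative rather than the real sglang/vllm API.

import torch


class ShardAwareParam:
    # Stand-in for a parameter class that understands tp_rank / presharded weights.
    def __init__(self, data: torch.Tensor):
        self.data = data

    def load_row_parallel_weight(self, w, tp_rank=0, use_presharded_weights=False):
        if not use_presharded_weights:
            shard = self.data.shape[1]
            w = w.narrow(1, tp_rank * shard, shard)  # take this rank's input columns
        self.data.copy_(w)


class LegacyParam:
    # Stand-in for a parameter class with the old single-argument signature.
    def __init__(self, data: torch.Tensor):
        self.data = data

    def load_row_parallel_weight(self, w):
        self.data.copy_(w)


def load_row_weight(param, loaded_weight, tp_rank, use_presharded_weights):
    # Same shape of dispatch as the isinstance() branch added in the hunk above.
    if isinstance(param, ShardAwareParam):
        param.load_row_parallel_weight(
            loaded_weight,
            tp_rank=tp_rank,
            use_presharded_weights=use_presharded_weights,
        )
    else:
        param.load_row_parallel_weight(loaded_weight)


full = torch.arange(32.0).reshape(4, 8)
p = ShardAwareParam(torch.zeros(4, 4))
load_row_weight(p, full, tp_rank=1, use_presharded_weights=False)
print(p.data)  # columns 4..7 of the full weight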
sglang/srt/layers/logits_processor.py
CHANGED
@@ -14,17 +14,18 @@
 """Logits processing."""

 import dataclasses
+import logging
 from typing import List, Optional, Union

 import torch
 import triton
 import triton.language as tl
 from torch import nn
-
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
-
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -50,8 +53,6 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_idx: Optional[List] = None

     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
-    # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
     input_token_logprobs: torch.Tensor = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
@@ -129,59 +130,70 @@ class LogitsProcessor(nn.Module):
         hidden_states,
         lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
-    ):
+    ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

         # Get the last hidden states and last logits for the next token prediction
         if (
-            logits_metadata.forward_mode.
+            logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
-
-
-
+            pruned_states = hidden_states
+            sample_indices = None
+        elif (
+            logits_metadata.forward_mode.is_extend()
+            and not logits_metadata.extend_return_logprob
+        ):
+            # Prefill without input logprobs.
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
-
+            pruned_states = hidden_states[last_index]
+            sample_indices = None
+        else:
+            # Slice the requested tokens to compute logprob
+            sample_index_pt = -1
+            sample_indices = []
+            pt, pruned_states, pruned_input_ids = 0, [], []
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                sample_index_pt += extend_len - start_len
+                sample_indices.append(sample_index_pt)
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            pruned_states = torch.cat(pruned_states)
+
+        # Compute logits for both input and sampled tokens.
+        logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+        sampled_logits = (
+            logits[sample_indices] if sample_indices is not None else logits
+        )

-        # Compute logits
-        last_logits = self._get_logits(last_hidden, lm_head)
         if (
             not logits_metadata.extend_return_logprob
             or logits_metadata.capture_hidden_mode.need_capture()
         ):
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
-                next_token_logits=
+                next_token_logits=sampled_logits,
                 hidden_states=(
                     hidden_states
                     if logits_metadata.capture_hidden_mode.is_full()
                     else (
-
+                        pruned_states
                         if logits_metadata.capture_hidden_mode.is_last()
                         else None
                     )
                 ),
             )
         else:
-
-
-            for start_len, extend_len in zip(
-                logits_metadata.extend_logprob_start_lens_cpu,
-                logits_metadata.extend_seq_lens_cpu,
-            ):
-                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
-                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                pt += extend_len
-
-            # Compute the logits of all required tokens
-            pruned_states = torch.cat(pruned_states)
-            del hidden_states
-            input_token_logits = self._get_logits(pruned_states, lm_head)
-            del pruned_states
+            input_logprobs = logits
+            del hidden_states, logits

             # Normalize the logprob w/o temperature, top-p
-            input_logprobs = input_token_logits
             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                 input_logprobs, logits_metadata
             )
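A small standalone example of what the new `sample_indices` bookkeeping in the hunk above computes: each request keeps `extend_len - start_len` tokens in the pruned tensor, and the index of that request's last kept token (the position whose logits are used for sampling) is the running total minus one. The lengths below are made up for illustration.

# Hypothetical per-request lengths, standing in for extend_seq_lens_cpu and
# extend_logprob_start_lens_cpu from the hunk above.
extend_seq_lens = [5, 3, 7]
extend_logprob_start_lens = [2, 0, 4]

sample_index_pt = -1
sample_indices = []
for start_len, extend_len in zip(extend_logprob_start_lens, extend_seq_lens):
    sample_index_pt += extend_len - start_len  # tokens kept for this request
    sample_indices.append(sample_index_pt)     # last kept token = sampling position

print(sample_indices)  # [2, 5, 8] -> sampled_logits = logits[sample_indices]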
@@ -195,25 +207,18 @@ class LogitsProcessor(nn.Module):
             else:
                 input_top_logprobs_val = input_top_logprobs_idx = None

-            # Compute the normalized logprobs for the requested tokens.
-            # Note that we pad a zero at the end for easy batching.
             input_token_logprobs = input_logprobs[
-                torch.arange(input_logprobs.shape[0], device=
+                torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
                 torch.cat(
                     [
                         torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device=
+                        torch.tensor([0], device=input_logprobs.device),
                     ]
                 ),
             ]
-            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                input_token_logprobs,
-                logits_metadata,
-            )

             return LogitsProcessorOutput(
-                next_token_logits=
-                normalized_prompt_logprobs=normalized_prompt_logprobs,
+                next_token_logits=sampled_logits,
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
@@ -223,8 +228,11 @@ class LogitsProcessor(nn.Module):
         self,
         hidden_states: torch.Tensor,
         lm_head: VocabParallelEmbedding,
+        logits_metadata: LogitsMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """Get logits from hidden_states."""
+
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(hidden_states, lm_head.weight.T)
         else:
@@ -237,8 +245,6 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather:
             logits = tensor_model_parallel_all_gather(logits)

-        # Compute the normalized logprobs for the requested tokens.
-        # Note that we pad a zero at the end for easy batching.
         logits = logits[:, : self.config.vocab_size].float()

         if self.final_logit_softcapping:
@@ -246,27 +252,6 @@ class LogitsProcessor(nn.Module):

         return logits

-    @staticmethod
-    def _get_normalized_prompt_logprobs(
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
-        )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         max_k = max(logits_metadata.top_logprobs_nums)
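With `_get_normalized_prompt_logprobs` and the `normalized_prompt_logprobs` output field removed, a caller that still wants a single length-normalized prompt score has to derive it from the per-token values in `input_token_logprobs`. A minimal sketch with made-up numbers; the exact token slicing the removed helper used differed slightly, so treat this as an approximation rather than the old value.

import torch

# Hypothetical per-token prompt logprobs for one request, in the shape a caller
# would assemble from input_token_logprobs.
token_logprobs = torch.tensor([-1.2, -0.7, -2.1, -0.3])

normalized = token_logprobs.mean()      # length-normalized score
total = token_logprobs.sum()            # unnormalized joint logprob
print(float(normalized), float(total))  # roughly -1.075 and -4.3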
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -4,13 +4,12 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
-
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.utils import is_hip, set_weight_attrs

 logger = logging.getLogger(__name__)
sglang/srt/layers/moe/fused_moe_native.py
CHANGED
@@ -8,6 +8,7 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F

+from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts


@@ -44,3 +45,71 @@ def fused_moe_forward_native(
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
+
+
+def moe_forward_native(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    use_grouped_topk: bool,
+    top_k: int,
+    router_logits: torch.Tensor,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+
+    topk_weights, topk_ids = select_experts(
+        hidden_states=x,
+        router_logits=router_logits,
+        use_grouped_topk=use_grouped_topk,
+        top_k=top_k,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+        correction_bias=correction_bias,
+        torch_native=True,
+    )
+
+    # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
+    len_experts = layer.num_experts
+
+    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
+    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
+    tokens_per_expert = cnts.sum(dim=0)
+    idxs = topk_ids.view(-1).argsort()
+
+    sorted_tokens = x[idxs // topk_ids.shape[1]]
+    tokens_per_expert = tokens_per_expert.cpu().numpy()
+
+    outputs = []
+    start_idx = 0
+    for i, num_tokens in enumerate(tokens_per_expert):
+        end_idx = start_idx + num_tokens
+        if num_tokens == 0:
+            continue
+        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
+
+        layer_w13_weight = layer.w13_weight[i]
+        layer_w2_weight = layer.w2_weight[i]
+
+        gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
+        gate_up = SiluAndMul()(gate_up)
+        expert_out = F.linear(gate_up, layer_w2_weight)
+        outputs.append(expert_out)
+        start_idx = end_idx
+
+    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
+    new_x = torch.empty_like(outs)
+
+    new_x[idxs] = outs
+    final_out = (
+        new_x.view(*topk_ids.shape, -1)
+        .type(topk_weights.dtype)
+        .mul_(topk_weights.unsqueeze(dim=-1))
+        .sum(dim=1)
+        .type(new_x.dtype)
+    )
+    return final_out
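The heart of the new `moe_forward_native` above is the routing bookkeeping: count assignments per expert, argsort the flattened token-to-expert assignments so each expert sees a contiguous slice, apply that expert, then scatter results back and mix them with the top-k routing weights. A cut-down, self-contained illustration of just that grouping step (toy sizes and a fake per-expert transform, not the sglang function itself):

import torch

torch.manual_seed(0)
num_tokens, top_k, num_experts, dim = 4, 2, 3, 8
x = torch.randn(num_tokens, dim)

# Toy router: softmax over experts, keep the top-k per token.
probs = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)
topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)

# Count how many (token, expert) assignments each expert receives.
cnts = topk_ids.new_zeros((num_tokens, num_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)

# Sort assignments by expert id so each expert's tokens are contiguous.
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // top_k]

outputs, start = [], 0
for e, n in enumerate(tokens_per_expert.tolist()):
    if n == 0:
        continue
    chunk = sorted_tokens[start : start + n]
    outputs.append(chunk * (e + 1))  # stand-in for this expert's gate/up/down MLP
    start += n

# Undo the sort, then mix each token's top-k expert outputs by routing weight.
outs = torch.cat(outputs, dim=0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
mixed = (new_x.view(num_tokens, top_k, dim) * topk_weights.unsqueeze(-1)).sum(dim=1)
print(mixed.shape)  # torch.Size([4, 8])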
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -15,15 +15,18 @@ from vllm import _custom_ops as ops

 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    direct_register_custom_op,
+    get_device_name,
+    is_cuda_available,
+    is_hip,
+)

-
-
+is_cuda = is_cuda_available()
+is_hip_flag = is_hip()
+if is_cuda:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

-    is_hip_flag = False
-else:
-    is_hip_flag = True

 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0