sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -5,20 +5,21 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 
 import torch
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.custom_op import CustomOp
-
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -27,6 +28,8 @@ else:
 
 import logging
 
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
 
 
@@ -97,6 +100,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+        return
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -148,17 +165,52 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )
 
-
-
-
-
-
-
-
-
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+            )
+        else:
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+            )
 
-    def forward_cpu(
-
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return moe_forward_native(
+            layer,
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            custom_routing_function,
+            correction_bias,
+        )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
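The new `forward_cpu` path above just forwards its arguments to `moe_forward_native` from the new `sglang/srt/layers/moe/fused_moe_native.py` module, which this diff lists but does not expand. As a rough point of reference only, a torch-native MoE forward of this shape usually looks like the sketch below; the names, shapes, and SwiGLU activation are assumptions for illustration, not sglang's actual implementation.

```python
import torch
import torch.nn.functional as F


def naive_moe_forward(x, w13, w2, router_logits, top_k, renormalize=True):
    """Reference MoE forward: x [T, H], w13 [E, 2I, H], w2 [E, H, I]."""
    # Route each token to its top-k experts.
    topk_weights, topk_ids = torch.topk(router_logits.softmax(dim=-1), top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    out = torch.zeros_like(x)
    for e in range(w13.shape[0]):
        token_idx, slot = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ w13[e].t()          # fused gate/up projection -> [t, 2I]
        gate, up = h.chunk(2, dim=-1)
        h = F.silu(gate) * up                  # SwiGLU
        out[token_idx] += topk_weights[token_idx, slot, None] * (h @ w2[e].t())
    return out


x = torch.randn(5, 16)                          # 5 tokens, hidden size 16
w13 = torch.randn(4, 2 * 32, 16)                # 4 experts, intermediate size 32
w2 = torch.randn(4, 16, 32)
print(naive_moe_forward(x, w13, w2, torch.randn(5, 4), top_k=2).shape)  # [5, 16]
```

The CK_MOE and Triton branches in `apply` compute the same weighted mixture of expert outputs, just with fused kernels instead of this per-expert Python loop.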
sglang/srt/layers/moe/topk.py
CHANGED
@@ -24,7 +24,9 @@ def fused_topk_native(
     topk: int,
     renormalize: bool,
 ):
-    assert
+    assert (
+        hidden_states.shape[0] == gating_output.shape[0]
+    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
     M, _ = hidden_states.shape
     topk_weights = torch.empty(
         M, topk, dtype=torch.float32, device=hidden_states.device
@@ -180,7 +182,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
         )
-    elif torch_native:
+    elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
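For context on the two changes above: `fused_topk_native` performs per-token softmax routing over the router logits, and the extra `and custom_routing_function is None` condition keeps a caller-supplied routing function from being silently bypassed on the torch-native path. A minimal sketch of the same routing math (illustrative only, not the sglang function itself):

```python
import torch


def topk_softmax_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
    # gating_output: [num_tokens, num_experts]
    assert gating_output.dim() == 2, "expects [num_tokens, num_experts]"
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Make the selected expert weights sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


weights, ids = topk_softmax_reference(torch.randn(6, 8), topk=2, renormalize=True)
print(weights.sum(dim=-1))  # each row sums to 1 after renormalization
```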
sglang/srt/layers/parameter.py
CHANGED
@@ -1,7 +1,4 @@
-"""
-Adapted from vLLM (0.6.4.post1).
-https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
-"""
+"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/parameter.py"""
 
 import logging
 from fractions import Fraction
@@ -9,7 +6,8 @@ from typing import Callable, Optional, Union
 
 import torch
 from torch.nn import Parameter
-
+
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 
 __all__ = [
     "BasevLLMParameter",
@@ -88,12 +86,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     def output_dim(self):
         return self._output_dim
 
-    def load_column_parallel_weight(
-
-
-
-
-
+    def load_column_parallel_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+    ):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.output_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
@@ -121,7 +124,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
-    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
@@ -137,7 +140,6 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         )
 
         param_data = self.data
-        tp_rank = get_tensor_model_parallel_rank()
         shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
         loaded_weight = loaded_weight.narrow(
@@ -164,11 +166,14 @@ class RowvLLMParameter(BasevLLMParameter):
     def input_dim(self):
         return self._input_dim
 
-    def load_row_parallel_weight(
-
-
-
+    def load_row_parallel_weight(
+        self,
+        loaded_weight: torch.Tensor,
+        tp_rank: int,
+        use_presharded_weights: bool = False,
+    ):
         if not use_presharded_weights:
+            shard_size = self.data.shape[self.input_dim]
             loaded_weight = loaded_weight.narrow(
                 self.input_dim, tp_rank * shard_size, shard_size
             )
@@ -238,6 +243,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
     # For row parallel layers, no sharding needed
     # load weight into parameter as is
     def load_row_parallel_weight(self, *args, **kwargs):
+        kwargs.pop("tp_rank", None)
+        kwargs.pop("use_presharded_weights", None)
        super().load_row_parallel_weight(*args, **kwargs)
 
     def load_merged_column_weight(self, *args, **kwargs):
@@ -247,6 +254,8 @@ class PerTensorScaleParameter(BasevLLMParameter):
         self._load_into_shard_id(*args, **kwargs)
 
     def load_column_parallel_weight(self, *args, **kwargs):
+        kwargs.pop("tp_rank", None)
+        kwargs.pop("use_presharded_weights", None)
         super().load_row_parallel_weight(*args, **kwargs)
 
     def _load_into_shard_id(
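The reworked loaders above now receive `tp_rank` explicitly and accept a `use_presharded_weights` flag instead of looking up the rank inside the parameter class. The slicing itself is plain `Tensor.narrow` on the sharded dimension; a toy illustration with made-up shapes:

```python
import torch

# Hypothetical sizes: a [out=8, in=4] checkpoint weight, column-parallel over 4 ranks.
full_weight = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)
tp_size, tp_rank, output_dim = 4, 1, 0

local_param = torch.empty(8 // tp_size, 4)      # each rank owns out/tp_size rows
shard_size = local_param.shape[output_dim]      # 2
shard = full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
assert shard.shape == local_param.shape
local_param.copy_(shard)                        # rank 1 receives rows 2..3

# With use_presharded_weights=True the checkpoint already stores only this rank's
# rows, so the narrow step is skipped and the tensor is copied as-is.
```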
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -42,6 +43,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
+    "w8a8_int8": W8A8Int8Config,
 }
 
 
@@ -54,33 +56,13 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def fp8_get_quant_method(self, layer, prefix):
-    """Enhanced get_quant_method for FP8 config."""
-    from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.quantization.utils.quant_utils import (
-        is_layer_skipped,
-    )
-
-    from sglang.srt.layers.linear import UnquantizedLinearMethod
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-    from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
-
-    if isinstance(layer, LinearBase):
-        if is_layer_skipped(prefix, self.ignored_layers):
-            return UnquantizedLinearMethod()
-        return Fp8LinearMethod(self)
-    elif isinstance(layer, FusedMoE):
-        return Fp8MoEMethod(self)
-    return None
-
-
 def gptq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.gptq_marlin import (
         GPTQMarlinLinearMethod,
         GPTQMarlinMoEMethod,
     )
 
+    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
@@ -91,12 +73,12 @@ def gptq_get_quant_method(self, layer, prefix):
 
 
 def awq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.awq_marlin import (
         AWQMarlinLinearMethod,
         AWQMoEMethod,
     )
 
+    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
@@ -106,13 +88,30 @@ def awq_get_quant_method(self, layer, prefix):
     return None
 
 
+def patch_vllm_linear_base_isinstance():
+    import builtins
+
+    from vllm.model_executor.layers.linear import LinearBase
+
+    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+
+    original_isinstance = builtins.isinstance
+
+    def patched_isinstance(obj, classinfo):
+        if classinfo is LinearBase:
+            return original_isinstance(obj, PatchedLinearBase)
+        return original_isinstance(obj, classinfo)
+
+    builtins.isinstance = patched_isinstance
+
+
 def apply_monkey_patches():
     """Apply all monkey patches in one place."""
-    setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method)
 
 
+patch_vllm_linear_base_isinstance()
 # Apply patches when module is imported
 apply_monkey_patches()
 
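The new `patch_vllm_linear_base_isinstance` helper rebinds `builtins.isinstance` so that third-party checks against vLLM's `LinearBase` also accept layers built on sglang's own `LinearBase`. A self-contained demo of the same pattern, using stand-in classes rather than the real ones:

```python
import builtins


class UpstreamBase:       # stands in for vllm's LinearBase
    pass


class LocalBase:          # stands in for sglang's LinearBase
    pass


original_isinstance = builtins.isinstance


def patched_isinstance(obj, classinfo):
    # Checks against the upstream base are answered against the local base,
    # so objects built on LocalBase pass third-party isinstance checks.
    if classinfo is UpstreamBase:
        return original_isinstance(obj, LocalBase)
    return original_isinstance(obj, classinfo)


builtins.isinstance = patched_isinstance
print(isinstance(LocalBase(), UpstreamBase))    # True after the patch
builtins.isinstance = original_isinstance       # restore; the demo undoes the patch
```

The trade-off is that the patch is process-global, which is why the module applies it exactly once at import time.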
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -1,7 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
@@ -9,8 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear,
@@ -26,7 +23,12 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     requantize_with_max_scale,
 )
 
-from sglang.srt.
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -40,12 +42,15 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
+    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
 
 
@@ -161,7 +166,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip():
+        if is_hip_:
             self.use_marlin = False
 
         self.block_quant = self.quant_config.weight_block_size is not None
@@ -273,7 +278,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -330,7 +335,7 @@ class Fp8LinearMethod(LinearMethodBase):
         weight_scale = layer.weight_scale
 
         # If ROCm, normalize the weights and scales to e4m3fnuz
-        if is_hip():
+        if is_hip_:
             weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                 weight=weight,
                 weight_scale=weight_scale,
@@ -567,7 +572,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -594,7 +599,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
@@ -616,18 +621,30 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
-
-
-
-
-
-
-
-
-
-
-
-
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -658,7 +675,7 @@ class Fp8MoEMethod:
             )
 
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -708,18 +725,30 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )
 
-
-
-
-
-
-
-
-
-
-
-
-
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif get_bool_env_var("MOE_PADDING"):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
+                    layer.w13_weight = torch.nn.Parameter(
+                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
             return
 
     def apply(
@@ -752,27 +781,55 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if is_hip_ and get_bool_env_var("CK_MOE"):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+
+        else:
+            # Expert fusion with FP8 quantization
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                block_shape=self.quant_config.weight_block_size,
+            )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
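Both ROCm branches added to `process_weights_after_loading` above are weight-layout transforms: `CK_MOE` re-permutes the expert weights for the CK kernels via `permute_weight`, while `MOE_PADDING` zero-pads the innermost dimension of `w13`/`w2` to reduce memory-channel contention. The padding is ordinary `F.pad` on the last dimension; a quick illustration with an arbitrary `padding_size` (the real value comes from the module and is not shown in this diff):

```python
import torch
import torch.nn.functional as F

padding_size = 128                                   # example value only
w = torch.randn(8, 14336, 4096)                      # [num_experts, out, in]
w_padded = F.pad(w, (0, padding_size), "constant", 0)
print(w_padded.shape)                                # torch.Size([8, 14336, 4224])
```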
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -1,8 +1,8 @@
 from typing import List, Optional, Tuple
 
 import torch
-from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
 
+from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
     w8a8_block_fp8_matmul,
sglang/srt/layers/quantization/int8_kernel.py
ADDED
@@ -0,0 +1,54 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _per_token_quant_int8(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    BLOCK: tl.constexpr,
+):
+    # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
+    row_id = tl.program_id(0)
+
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
+
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
+
+
+def per_token_quant_int8(x):
+    M = x.numel() // x.shape[-1]
+    N = x.shape[-1]
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    # heuristics for number of warps
+    num_warps = min(max(BLOCK // 256, 1), 8)
+
+    assert x.is_contiguous()
+    _per_token_quant_int8[(M,)](
+        x,
+        x_q,
+        scales,
+        stride_x=x.stride(-2),
+        stride_xq=x_q.stride(-2),
+        N=N,
+        BLOCK=BLOCK,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+    return x_q, scales
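The new `int8_kernel.py` above performs per-token (per-row) absmax quantization to int8 in Triton, storing one float32 scale per row. A pure-PyTorch reference of the same math is handy for spot-checking the kernel on small inputs (a sketch, not part of sglang):

```python
import torch


def per_token_quant_int8_ref(x: torch.Tensor):
    # Per-row absmax scaling, matching the kernel's scale = absmax / 127 convention.
    x_f32 = x.float()
    absmax = x_f32.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10)
    scales = absmax / 127
    x_q = torch.round(x_f32 * (127 / absmax)).to(torch.int8)
    return x_q, scales


x = torch.randn(4, 16)
x_q, scales = per_token_quant_int8_ref(x)
# Dequantization error is bounded by half a quantization step per row.
assert torch.allclose(x_q.float() * scales, x, atol=scales.max().item())
print(x_q.dtype, scales.shape)  # torch.int8 torch.Size([4, 1])
```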
sglang/srt/layers/quantization/modelopt_quant.py
CHANGED
@@ -5,15 +5,14 @@ from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
-from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     requantize_with_max_scale,
 )
-from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
-from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.linear import LinearBase, LinearMethodBase
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,