sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +0 -12
- sglang/bench_one_batch.py +0 -12
- sglang/bench_serving.py +11 -2
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/flashinfer_backend.py +49 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
- sglang/srt/layers/moe/topk.py +205 -0
- sglang/srt/layers/quantization/__init__.py +3 -3
- sglang/srt/layers/quantization/fp8.py +169 -32
- sglang/srt/layers/quantization/fp8_kernel.py +292 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/torchao_utils.py +11 -15
- sglang/srt/managers/schedule_batch.py +16 -10
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +13 -16
- sglang/srt/managers/tokenizer_manager.py +130 -111
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +23 -0
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +21 -37
- sglang/srt/utils.py +33 -44
- sglang/test/test_block_fp8.py +341 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py (new file):

@@ -0,0 +1,205 @@
+# Copyright 2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+
+
+def fused_topk_native(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    M, _ = hidden_states.shape
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    topk_weights = F.softmax(gating_output.float(), dim=-1)
+    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    from vllm import _custom_ops as ops
+
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+
+    ops.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),
+    )
+    del token_expert_indicies
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights, topk_ids
+
+
+# This is used by the Deepseek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+def biased_grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = gating_output.sigmoid()
+    num_token = scores.shape[0]
+    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
+    group_scores = (
+        scores_for_choice.view(num_token, num_expert_group, -1)
+        .topk(2, dim=-1)[0]
+        .sum(dim=-1)
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    topk_weights = scores.gather(1, topk_ids)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+def select_experts(
+    hidden_states: torch.Tensor,
+    router_logits: torch.Tensor,
+    top_k: int,
+    use_grouped_topk: bool,
+    renormalize: bool,
+    topk_group: Optional[int] = None,
+    num_expert_group: Optional[int] = None,
+    custom_routing_function: Optional[Callable] = None,
+    correction_bias: Optional[torch.Tensor] = None,
+    torch_native: bool = False,
+):
+    # DeekSeekv2 uses grouped_top_k
+    if use_grouped_topk:
+        assert topk_group is not None
+        assert num_expert_group is not None
+        if correction_bias is None:
+            topk_weights, topk_ids = grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+            )
+        else:
+            topk_weights, topk_ids = biased_grouped_topk(
+                hidden_states=hidden_states,
+                gating_output=router_logits,
+                correction_bias=correction_bias,
+                topk=top_k,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+            )
+    elif torch_native:
+        topk_weights, topk_ids = fused_topk_native(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            hidden_states=hidden_states,
+            gating_output=router_logits,
+            topk=top_k,
+            renormalize=renormalize,
+        )
+
+    return topk_weights, topk_ids
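The new module funnels every routing variant through `select_experts`: `use_grouped_topk` selects the Deepseek-style grouped routing, passing a `correction_bias` switches it to the sigmoid-plus-bias variant (`biased_grouped_topk`), `torch_native` keeps everything in plain PyTorch, and `custom_routing_function` overrides the default `fused_topk`. A minimal usage sketch follows; the toy shapes and values are ours, not taken from the package:

```python
# Sketch only: routes 4 tokens over 8 experts, top-2 per token, using the
# pure-PyTorch path (torch_native=True) so no vllm custom op is needed.
import torch

from sglang.srt.layers.moe.topk import select_experts

hidden_states = torch.randn(4, 16)  # [num_tokens, hidden_size]
router_logits = torch.randn(4, 8)   # [num_tokens, num_experts]

topk_weights, topk_ids = select_experts(
    hidden_states=hidden_states,
    router_logits=router_logits,
    top_k=2,
    use_grouped_topk=False,
    renormalize=True,
    torch_native=True,
)
print(topk_weights.shape, topk_ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```

Grouped routing (`use_grouped_topk=True`) additionally requires `num_expert_group` and `topk_group`, as the asserts in `select_experts` show.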
sglang/srt/layers/quantization/__init__.py:

@@ -60,8 +60,8 @@ def fp8_get_quant_method(self, layer, prefix):
         is_layer_skipped,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.linear import UnquantizedLinearMethod
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
     from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
 
     if isinstance(layer, LinearBase):
@@ -80,7 +80,7 @@ def gptq_get_quant_method(self, layer, prefix):
         GPTQMarlinMoEMethod,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
         return GPTQMarlinLinearMethod(self)
@@ -96,7 +96,7 @@ def awq_get_quant_method(self, layer, prefix):
         AWQMoEMethod,
     )
 
-    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
     if isinstance(layer, LinearBase):
         return AWQMarlinLinearMethod(self)
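These three hunks only track the package move recorded in the file list (`layers/{fused_moe_triton → moe/fused_moe_triton}`); the imported symbols are unchanged. For code outside sglang that imports the old path, a small compatibility shim like the following (our suggestion, not part of the package) keeps it working across both versions:

```python
# Compatibility shim for code that must run against both sglang versions.
try:
    # sglang >= 0.4.1.post1: FusedMoE now lives under layers/moe/
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
except ImportError:
    # sglang <= 0.4.0.post2: old location
    from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
```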
sglang/srt/layers/quantization/fp8.py:

@@ -9,6 +9,7 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -26,13 +27,17 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
-from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import
+from sglang.srt.layers.quantization.fp8_utils import (
+    BlockQuantScaleParameter,
+    apply_w8a8_block_fp8_linear,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
@@ -53,6 +58,7 @@ class Fp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized: bool = False,
         activation_scheme: str = "dynamic",
         ignored_layers: Optional[List[str]] = None,
+        weight_block_size: List[int] = None,
     ) -> None:
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
@@ -64,6 +70,20 @@ class Fp8Config(QuantizationConfig):
             raise ValueError(f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
         self.ignored_layers = ignored_layers or []
+        if weight_block_size is not None:
+            if not is_checkpoint_fp8_serialized:
+                raise ValueError(
+                    f"The block-wise quantization only supports fp8-serialized checkpoint for now."
+                )
+            if len(weight_block_size) != 2:
+                raise ValueError(
+                    f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
+                )
+            if activation_scheme != "dynamic":
+                raise ValueError(
+                    f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
+                )
+        self.weight_block_size = weight_block_size
 
     @classmethod
     def get_name(cls) -> str:
@@ -87,10 +107,12 @@ class Fp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized = "fp8" in quant_method
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
         ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+        weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
         return cls(
             is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
             activation_scheme=activation_scheme,
             ignored_layers=ignored_layers,
+            weight_block_size=weight_block_size,
         )
 
     def get_quant_method(
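Together these hunks let `Fp8Config` carry an optional `weight_block_size`, read it from the checkpoint's quantization config, and reject unsupported combinations (non-fp8-serialized checkpoints, block sizes that are not 2-dimensional, non-dynamic activation schemes). A hypothetical config that satisfies the new checks might look like this; the keys mirror what `from_config` reads above, and the 128×128 block size is only an example value:

```python
# Hypothetical quantization_config (e.g. from a checkpoint's config.json) that
# would take the new block-wise FP8 path; values are illustrative only.
quantization_config = {
    "quant_method": "fp8",            # "fp8" in quant_method -> fp8-serialized checkpoint
    "activation_scheme": "dynamic",   # block-wise quant requires the dynamic scheme
    "ignored_layers": [],
    "weight_block_size": [128, 128],  # [block_n, block_k]; must have exactly 2 entries
}
```

With such a config, `Fp8LinearMethod` and `Fp8MoEMethod` take the `block_quant` branches shown in the following hunks.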
sglang/srt/layers/quantization/fp8.py (continued):

@@ -98,7 +120,7 @@ class Fp8Config(QuantizationConfig):
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
-        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
@@ -143,6 +165,11 @@ class Fp8LinearMethod(LinearMethodBase):
         if is_hip():
             self.use_marlin = False
 
+        self.block_quant = self.quant_config.weight_block_size is not None
+        if self.block_quant:
+            # Marlin doesn't support block-wise fp8
+            self.use_marlin = False
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -153,10 +180,35 @@ class Fp8LinearMethod(LinearMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        del input_size, output_size
         output_size_per_partition = sum(output_partition_sizes)
         weight_loader = extra_weight_attrs.get("weight_loader")
 
+        tp_size = get_tensor_model_parallel_world_size()
+        if self.block_quant:
+            block_n, block_k = (
+                self.quant_config.weight_block_size[0],
+                self.quant_config.weight_block_size[1],
+            )
+            # Required by row parallel
+            if tp_size > 1 and input_size // input_size_per_partition == tp_size:
+                if input_size_per_partition % block_k != 0:
+                    raise ValueError(
+                        f"Weight input_size_per_partition = "
+                        f"{input_size_per_partition} is not divisible by "
+                        f"weight quantization block_k = {block_k}."
+                    )
+            # Required by collum parallel or enabling merged weights
+            if (
+                tp_size > 1 and output_size // output_size_per_partition == tp_size
+            ) or len(output_partition_sizes) > 1:
+                for output_partition_size in output_partition_sizes:
+                    if output_partition_size % block_n != 0:
+                        raise ValueError(
+                            f"Weight output_partition_size = "
+                            f"{output_partition_size} is not divisible by "
+                            f"weight quantization block_n = {block_n}."
+                        )
+
         layer.logical_widths = output_partition_sizes
 
         layer.input_size_per_partition = input_size_per_partition
@@ -184,13 +236,27 @@ class Fp8LinearMethod(LinearMethodBase):
         # Otherwise, wait until process_weights_after_loading.
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
-
-
-
-
-
-
-
+            if self.block_quant:
+                assert self.quant_config.activation_scheme == "dynamic"
+                scale = BlockQuantScaleParameter(
+                    data=torch.empty(
+                        (output_size_per_partition + block_n - 1) // block_n,
+                        (input_size_per_partition + block_k - 1) // block_k,
+                        dtype=torch.float32,
+                    ),
+                    input_dim=1,
+                    output_dim=0,
+                    weight_loader=weight_loader,
+                )
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("weight_scale_inv", scale)
+            else:
+                scale = PerTensorScaleParameter(
+                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                    weight_loader=weight_loader,
+                )
+                scale[:] = torch.finfo(torch.float32).min
+                layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
             if self.quant_config.activation_scheme == "static":
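For a block-quantized linear layer, `weight_scale_inv` stores one float32 scale per `block_n × block_k` tile of the partitioned weight, hence the ceiling divisions above. A small sketch of the shape arithmetic (the sizes are made up for illustration):

```python
# Illustration of the weight_scale_inv shape created above; sizes are arbitrary.
output_size_per_partition = 4608  # rows of this partition's weight
input_size_per_partition = 7168   # columns of this partition's weight
block_n, block_k = 128, 128       # quant_config.weight_block_size

scale_rows = (output_size_per_partition + block_n - 1) // block_n  # ceil(4608 / 128) = 36
scale_cols = (input_size_per_partition + block_k - 1) // block_k   # ceil(7168 / 128) = 56
print((scale_rows, scale_cols))  # (36, 56) -> weight_scale_inv is a [36, 56] float32 tensor
```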
sglang/srt/layers/quantization/fp8.py (continued):

@@ -205,6 +271,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        # Block quant doesn't need to process weights after loading
+        if self.block_quant:
+            return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -295,6 +364,16 @@ class Fp8LinearMethod(LinearMethodBase):
                 bias=bias,
             )
 
+        if self.block_quant:
+            return apply_w8a8_block_fp8_linear(
+                input=x,
+                weight=layer.weight,
+                block_size=self.quant_config.weight_block_size,
+                weight_scale=layer.weight_scale_inv,
+                input_scale=layer.input_scale,
+                bias=bias,
+            )
+
         return apply_fp8_linear(
             input=x,
             weight=layer.weight,
@@ -320,7 +399,7 @@ class Fp8MoEMethod:
     """
 
     def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
 
         if not hasattr(cls, "_initialized"):
             original_init = cls.__init__
@@ -339,6 +418,7 @@ class Fp8MoEMethod:
 
     def __init__(self, quant_config):
         self.quant_config = quant_config
+        self.block_quant = self.quant_config.weight_block_size is not None
 
     def create_weights(
         self,
@@ -349,10 +429,32 @@ class Fp8MoEMethod:
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
+        tp_size = get_tensor_model_parallel_world_size()
+        if self.block_quant:
+            block_n, block_k = (
+                self.quant_config.weight_block_size[0],
+                self.quant_config.weight_block_size[1],
+            )
+            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+            # Required by collum parallel or enabling merged weights
+            if intermediate_size % block_n != 0:
+                raise ValueError(
+                    f"The output_size of gate's and up's weight = "
+                    f"{intermediate_size} is not divisible by "
+                    f"weight quantization block_n = {block_n}."
+                )
+            if tp_size > 1:
+                # Required by row parallel
+                if intermediate_size % block_k != 0:
+                    raise ValueError(
+                        f"The input_size of down's weight = "
+                        f"{intermediate_size} is not divisible by "
+                        f"weight quantization block_k = {block_k}."
+                    )
 
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
@@ -374,21 +476,45 @@ class Fp8MoEMethod:
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-
-
-
-
-
-
-
-
-
-
-
+        if self.block_quant:
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    2 * ((intermediate_size + block_n - 1) // block_n),
+                    (hidden_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(
+                    num_experts,
+                    (hidden_size + block_n - 1) // block_n,
+                    (intermediate_size + block_k - 1) // block_k,
+                    dtype=torch.float32,
+                ),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+            assert self.quant_config.activation_scheme == "dynamic"
+        else:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+            )
+            w2_weight_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
         # Add the quantization method used (per tensor/grouped/channel)
         # to ensure the weight scales are loaded in properly
         extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+            if self.block_quant
+            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
         )
         # If loading fp8 checkpoint, pass the weight loaders.
         # If loading an fp16 checkpoint, do not (we will quantize in
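The MoE block scales follow the same ceiling division, with a leading `num_experts` dimension and a factor of 2 on the `w13` scale because the gate and up projections are stored stacked. A sketch with made-up sizes:

```python
# Made-up sizes; mirrors the w13/w2 block-scale shapes registered above.
num_experts, hidden_size, intermediate_size = 8, 2048, 1408
block_n, block_k = 128, 128

w13_scale_shape = (
    num_experts,
    2 * ((intermediate_size + block_n - 1) // block_n),  # gate + up stacked: 2 * 11 = 22
    (hidden_size + block_k - 1) // block_k,              # 16
)
w2_scale_shape = (
    num_experts,
    (hidden_size + block_n - 1) // block_n,              # 16
    (intermediate_size + block_k - 1) // block_k,        # 11
)
print(w13_scale_shape, w2_scale_shape)  # (8, 22, 16) (8, 16, 11)
```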
sglang/srt/layers/quantization/fp8.py (continued):

@@ -422,7 +548,9 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None
 
     def process_weights_after_loading(self, layer: Module) -> None:
-
+        # Block quant doesn't need to process weights after loading
+        if self.block_quant:
+            return
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
@@ -519,7 +647,6 @@ class Fp8MoEMethod:
                 layer.w2_input_scale = torch.nn.Parameter(
                     w2_input_scale, requires_grad=False
                 )
-
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
             assert layer.w13_weight_scale is not None
@@ -566,12 +693,14 @@ class Fp8MoEMethod:
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.fused_moe_triton import FusedMoE
-        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
 
         # Expert selection
-        topk_weights, topk_ids =
+        topk_weights, topk_ids = select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -580,6 +709,7 @@ class Fp8MoEMethod:
             topk_group=topk_group,
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
         )
 
         # Expert fusion with FP8 quantization
@@ -591,10 +721,17 @@ class Fp8MoEMethod:
             topk_ids=topk_ids,
             inplace=True,
             use_fp8_w8a8=True,
-            w1_scale=
-
+            w1_scale=(
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            ),
+            w2_scale=(
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            ),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
         )
 
 