sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_offline_throughput.py +0 -12
  2. sglang/bench_one_batch.py +0 -12
  3. sglang/bench_serving.py +11 -2
  4. sglang/lang/backend/openai.py +10 -0
  5. sglang/srt/aio_rwlock.py +100 -0
  6. sglang/srt/configs/model_config.py +8 -1
  7. sglang/srt/constrained/xgrammar_backend.py +6 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +49 -5
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  10. sglang/srt/layers/linear.py +20 -2
  11. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
  12. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  13. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  14. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
  15. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
  16. sglang/srt/layers/moe/topk.py +205 -0
  17. sglang/srt/layers/quantization/__init__.py +3 -3
  18. sglang/srt/layers/quantization/fp8.py +169 -32
  19. sglang/srt/layers/quantization/fp8_kernel.py +292 -0
  20. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  21. sglang/srt/layers/torchao_utils.py +11 -15
  22. sglang/srt/managers/schedule_batch.py +16 -10
  23. sglang/srt/managers/schedule_policy.py +1 -1
  24. sglang/srt/managers/scheduler.py +13 -16
  25. sglang/srt/managers/tokenizer_manager.py +130 -111
  26. sglang/srt/mem_cache/memory_pool.py +15 -8
  27. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  28. sglang/srt/model_loader/loader.py +22 -11
  29. sglang/srt/models/dbrx.py +1 -1
  30. sglang/srt/models/deepseek.py +1 -1
  31. sglang/srt/models/deepseek_v2.py +67 -18
  32. sglang/srt/models/gemma2.py +19 -0
  33. sglang/srt/models/grok.py +1 -1
  34. sglang/srt/models/llama.py +2 -2
  35. sglang/srt/models/mixtral.py +2 -2
  36. sglang/srt/models/olmoe.py +1 -1
  37. sglang/srt/models/qwen2_moe.py +1 -1
  38. sglang/srt/models/xverse_moe.py +1 -1
  39. sglang/srt/openai_api/adapter.py +23 -0
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_params.py +9 -2
  42. sglang/srt/server.py +21 -37
  43. sglang/srt/utils.py +33 -44
  44. sglang/test/test_block_fp8.py +341 -0
  45. sglang/version.py +1 -1
  46. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
  47. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
  48. sglang/srt/layers/fused_moe_patch.py +0 -133
  49. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  50. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  51. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
  52. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,205 @@
+ # Copyright 2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import Callable, Optional
+
+ import torch
+ import torch.nn.functional as F
+
+
+ def fused_topk_native(
+     hidden_states: torch.Tensor,
+     gating_output: torch.Tensor,
+     topk: int,
+     renormalize: bool,
+ ):
+     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+     M, _ = hidden_states.shape
+     topk_weights = torch.empty(
+         M, topk, dtype=torch.float32, device=hidden_states.device
+     )
+     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+     topk_weights = F.softmax(gating_output.float(), dim=-1)
+     topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
+     if renormalize:
+         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+     return topk_weights, topk_ids
+
+
+ def fused_topk(
+     hidden_states: torch.Tensor,
+     gating_output: torch.Tensor,
+     topk: int,
+     renormalize: bool,
+ ):
+     from vllm import _custom_ops as ops
+
+     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+     M, _ = hidden_states.shape
+
+     topk_weights = torch.empty(
+         M, topk, dtype=torch.float32, device=hidden_states.device
+     )
+     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+     token_expert_indicies = torch.empty(
+         M, topk, dtype=torch.int32, device=hidden_states.device
+     )
+
+     ops.topk_softmax(
+         topk_weights,
+         topk_ids,
+         token_expert_indicies,
+         gating_output.float(),
+     )
+     del token_expert_indicies
+
+     if renormalize:
+         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+     return topk_weights, topk_ids
+
+
+ # This is used by the Deepseek-V2 model
+ def grouped_topk(
+     hidden_states: torch.Tensor,
+     gating_output: torch.Tensor,
+     topk: int,
+     renormalize: bool,
+     num_expert_group: int = 0,
+     topk_group: int = 0,
+ ):
+     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+     scores = torch.softmax(gating_output, dim=-1)
+     num_token = scores.shape[0]
+     group_scores = (
+         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+     )  # [n, n_group]
+     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+         1
+     ]  # [n, top_k_group]
+     group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+     group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+     score_mask = (
+         group_mask.unsqueeze(-1)
+         .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+         .reshape(num_token, -1)
+     )  # [n, e]
+     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+     topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+     if renormalize:
+         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+ def biased_grouped_topk(
+     hidden_states: torch.Tensor,
+     gating_output: torch.Tensor,
+     correction_bias: torch.Tensor,
+     topk: int,
+     renormalize: bool,
+     num_expert_group: int = 0,
+     topk_group: int = 0,
+ ):
+     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+     scores = gating_output.sigmoid()
+     num_token = scores.shape[0]
+     scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
+     group_scores = (
+         scores_for_choice.view(num_token, num_expert_group, -1)
+         .topk(2, dim=-1)[0]
+         .sum(dim=-1)
+     )  # [n, n_group]
+     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+         1
+     ]  # [n, top_k_group]
+     group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+     group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+     score_mask = (
+         group_mask.unsqueeze(-1)
+         .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+         .reshape(num_token, -1)
+     )  # [n, e]
+     tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+     _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+     topk_weights = scores.gather(1, topk_ids)
+
+     if renormalize:
+         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+ def select_experts(
+     hidden_states: torch.Tensor,
+     router_logits: torch.Tensor,
+     top_k: int,
+     use_grouped_topk: bool,
+     renormalize: bool,
+     topk_group: Optional[int] = None,
+     num_expert_group: Optional[int] = None,
+     custom_routing_function: Optional[Callable] = None,
+     correction_bias: Optional[torch.Tensor] = None,
+     torch_native: bool = False,
+ ):
+     # DeepSeek-V2 uses grouped_topk
+     if use_grouped_topk:
+         assert topk_group is not None
+         assert num_expert_group is not None
+         if correction_bias is None:
+             topk_weights, topk_ids = grouped_topk(
+                 hidden_states=hidden_states,
+                 gating_output=router_logits,
+                 topk=top_k,
+                 renormalize=renormalize,
+                 num_expert_group=num_expert_group,
+                 topk_group=topk_group,
+             )
+         else:
+             topk_weights, topk_ids = biased_grouped_topk(
+                 hidden_states=hidden_states,
+                 gating_output=router_logits,
+                 correction_bias=correction_bias,
+                 topk=top_k,
+                 renormalize=renormalize,
+                 num_expert_group=num_expert_group,
+                 topk_group=topk_group,
+             )
+     elif torch_native:
+         topk_weights, topk_ids = fused_topk_native(
+             hidden_states=hidden_states,
+             gating_output=router_logits,
+             topk=top_k,
+             renormalize=renormalize,
+         )
+     elif custom_routing_function is None:
+         topk_weights, topk_ids = fused_topk(
+             hidden_states=hidden_states,
+             gating_output=router_logits,
+             topk=top_k,
+             renormalize=renormalize,
+         )
+     else:
+         topk_weights, topk_ids = custom_routing_function(
+             hidden_states=hidden_states,
+             gating_output=router_logits,
+             topk=top_k,
+             renormalize=renormalize,
+         )
+
+     return topk_weights, topk_ids
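Note: the routing helpers in this new module are pure PyTorch except for fused_topk, which calls vllm's topk_softmax kernel. A minimal CPU-only sketch of driving them is shown below; the tensor shapes, expert count, and group sizes are made up for illustration and are not taken from any model config.

    import torch

    from sglang.srt.layers.moe.topk import grouped_topk, select_experts

    M, E = 4, 8                      # tokens, experts (illustrative)
    hidden_states = torch.randn(M, 16)
    router_logits = torch.randn(M, E)

    # Plain softmax top-k routing on the PyTorch-native path (no vllm kernel needed).
    weights, ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=2,
        use_grouped_topk=False,
        renormalize=True,
        torch_native=True,
    )
    print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])

    # Group-limited routing as used by the DeepSeek-V2 style models: pick the best
    # expert groups first, then the top-k experts inside those groups.
    weights, ids = grouped_topk(
        hidden_states=hidden_states,
        gating_output=router_logits,
        topk=2,
        renormalize=True,
        num_expert_group=2,
        topk_group=1,
    )
    print(weights.dtype, ids.dtype)  # torch.float32 torch.int32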
@@ -60,8 +60,8 @@ def fp8_get_quant_method(self, layer, prefix):
          is_layer_skipped,
      )

-     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
      from sglang.srt.layers.linear import UnquantizedLinearMethod
+     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
      from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod

      if isinstance(layer, LinearBase):
@@ -80,7 +80,7 @@ def gptq_get_quant_method(self, layer, prefix):
          GPTQMarlinMoEMethod,
      )

-     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

      if isinstance(layer, LinearBase):
          return GPTQMarlinLinearMethod(self)
@@ -96,7 +96,7 @@ def awq_get_quant_method(self, layer, prefix):
          AWQMoEMethod,
      )

-     from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
+     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

      if isinstance(layer, LinearBase):
          return AWQMarlinLinearMethod(self)
@@ -9,6 +9,7 @@ import torch.nn.functional as F
  from torch.nn import Module
  from torch.nn.parameter import Parameter
  from vllm import _custom_ops as ops
+ from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.linear import LinearBase
  from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
  from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -26,13 +27,17 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
  )
  from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

- from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
  from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
  from sglang.srt.layers.quantization.base_config import (
      QuantizationConfig,
      QuantizeMethodBase,
  )
- from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+ from sglang.srt.layers.quantization.fp8_utils import (
+     BlockQuantScaleParameter,
+     apply_w8a8_block_fp8_linear,
+     normalize_e4m3fn_to_e4m3fnuz,
+ )
  from sglang.srt.utils import (
      get_bool_env_var,
      is_hip,
@@ -53,6 +58,7 @@ class Fp8Config(QuantizationConfig):
          is_checkpoint_fp8_serialized: bool = False,
          activation_scheme: str = "dynamic",
          ignored_layers: Optional[List[str]] = None,
+         weight_block_size: List[int] = None,
      ) -> None:
          self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
          if is_checkpoint_fp8_serialized:
@@ -64,6 +70,20 @@ class Fp8Config(QuantizationConfig):
              raise ValueError(f"Unsupported activation scheme {activation_scheme}")
          self.activation_scheme = activation_scheme
          self.ignored_layers = ignored_layers or []
+         if weight_block_size is not None:
+             if not is_checkpoint_fp8_serialized:
+                 raise ValueError(
+                     f"The block-wise quantization only supports fp8-serialized checkpoint for now."
+                 )
+             if len(weight_block_size) != 2:
+                 raise ValueError(
+                     f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
+                 )
+             if activation_scheme != "dynamic":
+                 raise ValueError(
+                     f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
+                 )
+         self.weight_block_size = weight_block_size

      @classmethod
      def get_name(cls) -> str:
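For reference, the sketch below exercises the new weight_block_size argument and its validation rules; the constructor arguments come from the hunk above, but the [128, 128] block size is just an example value, and running this requires an sglang install (with its vllm dependency).

    # Sketch: exercising the new weight_block_size validation in Fp8Config.
    from sglang.srt.layers.quantization.fp8 import Fp8Config

    # Accepted: fp8-serialized checkpoint, dynamic activations, 2-D block size.
    cfg = Fp8Config(
        is_checkpoint_fp8_serialized=True,
        activation_scheme="dynamic",
        weight_block_size=[128, 128],  # example block size, not a required value
    )
    assert cfg.weight_block_size == [128, 128]

    # Rejected: block-wise quantization currently requires dynamic activations.
    try:
        Fp8Config(
            is_checkpoint_fp8_serialized=True,
            activation_scheme="static",
            weight_block_size=[128, 128],
        )
    except ValueError as e:
        print(e)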
@@ -87,10 +107,12 @@ class Fp8Config(QuantizationConfig):
          is_checkpoint_fp8_serialized = "fp8" in quant_method
          activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
          ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+         weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
          return cls(
              is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
              activation_scheme=activation_scheme,
              ignored_layers=ignored_layers,
+             weight_block_size=weight_block_size,
          )

      def get_quant_method(
@@ -98,7 +120,7 @@ class Fp8Config(QuantizationConfig):
      ) -> Optional["QuantizeMethodBase"]:
          from vllm.attention.layer import Attention  # Avoid circular import

-         from sglang.srt.layers.fused_moe_triton import FusedMoE
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

          if isinstance(layer, LinearBase):
              if is_layer_skipped(prefix, self.ignored_layers):
@@ -143,6 +165,11 @@ class Fp8LinearMethod(LinearMethodBase):
          if is_hip():
              self.use_marlin = False

+         self.block_quant = self.quant_config.weight_block_size is not None
+         if self.block_quant:
+             # Marlin doesn't support block-wise fp8
+             self.use_marlin = False
+
      def create_weights(
          self,
          layer: torch.nn.Module,
@@ -153,10 +180,35 @@ class Fp8LinearMethod(LinearMethodBase):
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ):
-         del input_size, output_size
          output_size_per_partition = sum(output_partition_sizes)
          weight_loader = extra_weight_attrs.get("weight_loader")

+         tp_size = get_tensor_model_parallel_world_size()
+         if self.block_quant:
+             block_n, block_k = (
+                 self.quant_config.weight_block_size[0],
+                 self.quant_config.weight_block_size[1],
+             )
+             # Required by row parallel
+             if tp_size > 1 and input_size // input_size_per_partition == tp_size:
+                 if input_size_per_partition % block_k != 0:
+                     raise ValueError(
+                         f"Weight input_size_per_partition = "
+                         f"{input_size_per_partition} is not divisible by "
+                         f"weight quantization block_k = {block_k}."
+                     )
+             # Required by column parallel or enabling merged weights
+             if (
+                 tp_size > 1 and output_size // output_size_per_partition == tp_size
+             ) or len(output_partition_sizes) > 1:
+                 for output_partition_size in output_partition_sizes:
+                     if output_partition_size % block_n != 0:
+                         raise ValueError(
+                             f"Weight output_partition_size = "
+                             f"{output_partition_size} is not divisible by "
+                             f"weight quantization block_n = {block_n}."
+                         )
+
          layer.logical_widths = output_partition_sizes

          layer.input_size_per_partition = input_size_per_partition
@@ -184,13 +236,27 @@ class Fp8LinearMethod(LinearMethodBase):
          # Otherwise, wait until process_weights_after_loading.
          if self.quant_config.is_checkpoint_fp8_serialized:
              # WEIGHT SCALE
-             scale = PerTensorScaleParameter(
-                 data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
-                 weight_loader=weight_loader,
-             )
-
-             scale[:] = torch.finfo(torch.float32).min
-             layer.register_parameter("weight_scale", scale)
+             if self.block_quant:
+                 assert self.quant_config.activation_scheme == "dynamic"
+                 scale = BlockQuantScaleParameter(
+                     data=torch.empty(
+                         (output_size_per_partition + block_n - 1) // block_n,
+                         (input_size_per_partition + block_k - 1) // block_k,
+                         dtype=torch.float32,
+                     ),
+                     input_dim=1,
+                     output_dim=0,
+                     weight_loader=weight_loader,
+                 )
+                 scale[:] = torch.finfo(torch.float32).min
+                 layer.register_parameter("weight_scale_inv", scale)
+             else:
+                 scale = PerTensorScaleParameter(
+                     data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                     weight_loader=weight_loader,
+                 )
+                 scale[:] = torch.finfo(torch.float32).min
+                 layer.register_parameter("weight_scale", scale)

              # INPUT ACTIVATION SCALE
              if self.quant_config.activation_scheme == "static":
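The weight_scale_inv parameter above holds one float32 scale per (block_n x block_k) tile of the partitioned weight, using ceiling division for ragged edges. A small sketch of the shape arithmetic with illustrative partition sizes (not any particular model):

    # Scale-tensor shape implied by the code above, on illustrative partition sizes.
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    block_n, block_k = 128, 128        # example weight_block_size
    output_size_per_partition = 4096   # N of this shard (illustrative)
    input_size_per_partition = 7168    # K of this shard (illustrative)

    scale_shape = (
        ceil_div(output_size_per_partition, block_n),
        ceil_div(input_size_per_partition, block_k),
    )
    print(scale_shape)  # (32, 56) -> one float32 scale per 128x128 weight block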
@@ -205,6 +271,9 @@ class Fp8LinearMethod(LinearMethodBase):
                  layer.register_parameter("input_scale", None)

      def process_weights_after_loading(self, layer: Module) -> None:
+         # Block quant doesn't need to process weights after loading
+         if self.block_quant:
+             return
          layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
          # If checkpoint not serialized fp8, quantize the weights.
          if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -295,6 +364,16 @@ class Fp8LinearMethod(LinearMethodBase):
                  bias=bias,
              )

+         if self.block_quant:
+             return apply_w8a8_block_fp8_linear(
+                 input=x,
+                 weight=layer.weight,
+                 block_size=self.quant_config.weight_block_size,
+                 weight_scale=layer.weight_scale_inv,
+                 input_scale=layer.input_scale,
+                 bias=bias,
+             )
+
          return apply_fp8_linear(
              input=x,
              weight=layer.weight,
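For intuition only, the snippet below spells out the weight-side arithmetic that the block-wise path stands for: each (block_n x block_k) tile of the fp8 weight is scaled by its entry in weight_scale_inv. The real apply_w8a8_block_fp8_linear (backed by the new fp8_kernel.py Triton kernels) also quantizes activations per token group and never materializes the dequantized weight, so treat this purely as a reference sketch; the helper names are local to the example.

    # Reference-only sketch of the weight side of block-wise FP8: expand the
    # per-block scales back over the weight and run the matmul in high precision.
    import torch

    def dequant_block_fp8_weight(weight_q, weight_scale_inv, block_size):
        # weight_q: [N, K] fp8 values stored per 2-D block
        # weight_scale_inv: [ceil(N/block_n), ceil(K/block_k)] float32 scales
        block_n, block_k = block_size
        n, k = weight_q.shape
        scales = weight_scale_inv.repeat_interleave(block_n, dim=0)[:n]
        scales = scales.repeat_interleave(block_k, dim=1)[:, :k]
        return weight_q.to(torch.float32) * scales

    def block_fp8_linear_reference(x, weight_q, weight_scale_inv, block_size, bias=None):
        # y = x @ W^T + b using the per-block dequantized weight.
        w = dequant_block_fp8_weight(weight_q, weight_scale_inv, block_size)
        y = x.to(torch.float32) @ w.t()
        return y if bias is None else y + bias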
@@ -320,7 +399,7 @@ class Fp8MoEMethod:
      """

      def __new__(cls, *args, **kwargs):
-         from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

          if not hasattr(cls, "_initialized"):
              original_init = cls.__init__
@@ -339,6 +418,7 @@ class Fp8MoEMethod:

      def __init__(self, quant_config):
          self.quant_config = quant_config
+         self.block_quant = self.quant_config.weight_block_size is not None

      def create_weights(
          self,
@@ -349,10 +429,32 @@ class Fp8MoEMethod:
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ):
-         from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

          if self.quant_config.is_checkpoint_fp8_serialized:
              params_dtype = torch.float8_e4m3fn
+         tp_size = get_tensor_model_parallel_world_size()
+         if self.block_quant:
+             block_n, block_k = (
+                 self.quant_config.weight_block_size[0],
+                 self.quant_config.weight_block_size[1],
+             )
+             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+             # Required by column parallel or enabling merged weights
+             if intermediate_size % block_n != 0:
+                 raise ValueError(
+                     f"The output_size of gate's and up's weight = "
+                     f"{intermediate_size} is not divisible by "
+                     f"weight quantization block_n = {block_n}."
+                 )
+             if tp_size > 1:
+                 # Required by row parallel
+                 if intermediate_size % block_k != 0:
+                     raise ValueError(
+                         f"The input_size of down's weight = "
+                         f"{intermediate_size} is not divisible by "
+                         f"weight quantization block_k = {block_k}."
+                     )

          # WEIGHTS
          w13_weight = torch.nn.Parameter(
@@ -374,21 +476,45 @@ class Fp8MoEMethod:
          set_weight_attrs(w2_weight, extra_weight_attrs)

          # WEIGHT_SCALES
-         # Allocate 2 scales for w1 and w3 respectively.
-         # They will be combined to a single scale after weight loading.
-         w13_weight_scale = torch.nn.Parameter(
-             torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-         )
-         layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-         w2_weight_scale = torch.nn.Parameter(
-             torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-         )
-         layer.register_parameter("w2_weight_scale", w2_weight_scale)
+         if self.block_quant:
+             w13_weight_scale = torch.nn.Parameter(
+                 torch.ones(
+                     num_experts,
+                     2 * ((intermediate_size + block_n - 1) // block_n),
+                     (hidden_size + block_k - 1) // block_k,
+                     dtype=torch.float32,
+                 ),
+                 requires_grad=False,
+             )
+             w2_weight_scale = torch.nn.Parameter(
+                 torch.ones(
+                     num_experts,
+                     (hidden_size + block_n - 1) // block_n,
+                     (intermediate_size + block_k - 1) // block_k,
+                     dtype=torch.float32,
+                 ),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+             assert self.quant_config.activation_scheme == "dynamic"
+         else:
+             # Allocate 2 scales for w1 and w3 respectively.
+             # They will be combined to a single scale after weight loading.
+             w13_weight_scale = torch.nn.Parameter(
+                 torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+             )
+             w2_weight_scale = torch.nn.Parameter(
+                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+             )
+             layer.register_parameter("w13_weight_scale", w13_weight_scale)
+             layer.register_parameter("w2_weight_scale", w2_weight_scale)
          # Add the quantization method used (per tensor/grouped/channel)
          # to ensure the weight scales are loaded in properly
          extra_weight_attrs.update(
-             {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+             {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+             if self.block_quant
+             else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
          )
          # If loading fp8 checkpoint, pass the weight loaders.
          # If loading an fp16 checkpoint, do not (we will quantize in
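The block-wise MoE scales allocated above follow the same ceiling-division rule as the linear case, with the gate and up projections stacked along the second dimension of w13_weight_scale_inv. Worked through on illustrative sizes (not any particular checkpoint):

    # Shapes of the block-wise MoE scales allocated above, on illustrative sizes.
    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    num_experts, hidden_size, intermediate_size = 64, 7168, 2048  # illustrative
    block_n, block_k = 128, 128                                   # example block size

    w13_scale_shape = (
        num_experts,
        2 * ceil_div(intermediate_size, block_n),  # gate and up projections stacked
        ceil_div(hidden_size, block_k),
    )
    w2_scale_shape = (
        num_experts,
        ceil_div(hidden_size, block_n),
        ceil_div(intermediate_size, block_k),
    )
    print(w13_scale_shape)  # (64, 32, 56)
    print(w2_scale_shape)   # (64, 56, 16)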
@@ -422,7 +548,9 @@ class Fp8MoEMethod:
              layer.w2_input_scale = None

      def process_weights_after_loading(self, layer: Module) -> None:
-
+         # Block quant doesn't need to process weights after loading
+         if self.block_quant:
+             return
          # If checkpoint is fp16 or bfloat16, quantize in place.
          if not self.quant_config.is_checkpoint_fp8_serialized:
              # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
@@ -519,7 +647,6 @@ class Fp8MoEMethod:
                  layer.w2_input_scale = torch.nn.Parameter(
                      w2_input_scale, requires_grad=False
                  )
-
              # Fp8 moe kernel needs single weight scale for w13 per expert.
              # We take the max then dequant and requant each expert.
              assert layer.w13_weight_scale is not None
@@ -566,12 +693,14 @@ class Fp8MoEMethod:
          topk_group: Optional[int] = None,
          num_expert_group: Optional[int] = None,
          custom_routing_function: Optional[Callable] = None,
+         correction_bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
-         from sglang.srt.layers.fused_moe_triton import FusedMoE
-         from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+         from sglang.srt.layers.moe.topk import select_experts

          # Expert selection
-         topk_weights, topk_ids = FusedMoE.select_experts(
+         topk_weights, topk_ids = select_experts(
              hidden_states=x,
              router_logits=router_logits,
              use_grouped_topk=use_grouped_topk,
@@ -580,6 +709,7 @@ class Fp8MoEMethod:
              topk_group=topk_group,
              num_expert_group=num_expert_group,
              custom_routing_function=custom_routing_function,
+             correction_bias=correction_bias,
          )

          # Expert fusion with FP8 quantization
@@ -591,10 +721,17 @@ class Fp8MoEMethod:
              topk_ids=topk_ids,
              inplace=True,
              use_fp8_w8a8=True,
-             w1_scale=layer.w13_weight_scale,
-             w2_scale=layer.w2_weight_scale,
+             w1_scale=(
+                 layer.w13_weight_scale_inv
+                 if self.block_quant
+                 else layer.w13_weight_scale
+             ),
+             w2_scale=(
+                 layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+             ),
              a1_scale=layer.w13_input_scale,
              a2_scale=layer.w2_input_scale,
+             block_shape=self.quant_config.weight_block_size,
          )
