sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py
@@ -1,12 +1,15 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

  import logging
+ import os
  from typing import Any, Callable, Dict, List, Optional

  import torch
+ import torch.nn.functional as F
  from torch.nn import Module
  from torch.nn.parameter import Parameter
  from vllm import _custom_ops as ops
+ from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.linear import LinearBase
  from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
  from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -24,17 +27,17 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
  )
  from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter

- from sglang.srt.layers.fused_moe_triton import (
-     FusedMoE,
-     FusedMoEMethodBase,
-     FusedMoeWeightScaleSupported,
- )
  from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
  from sglang.srt.layers.quantization.base_config import (
      QuantizationConfig,
      QuantizeMethodBase,
  )
- from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+ from sglang.srt.layers.quantization.fp8_utils import (
+     BlockQuantScaleParameter,
+     apply_w8a8_block_fp8_linear,
+     normalize_e4m3fn_to_e4m3fnuz,
+ )
  from sglang.srt.utils import (
      get_bool_env_var,
      is_hip,
@@ -55,6 +58,7 @@ class Fp8Config(QuantizationConfig):
          is_checkpoint_fp8_serialized: bool = False,
          activation_scheme: str = "dynamic",
          ignored_layers: Optional[List[str]] = None,
+         weight_block_size: List[int] = None,
      ) -> None:
          self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
          if is_checkpoint_fp8_serialized:
@@ -66,6 +70,20 @@ class Fp8Config(QuantizationConfig):
              raise ValueError(f"Unsupported activation scheme {activation_scheme}")
          self.activation_scheme = activation_scheme
          self.ignored_layers = ignored_layers or []
+         if weight_block_size is not None:
+             if not is_checkpoint_fp8_serialized:
+                 raise ValueError(
+                     f"The block-wise quantization only supports fp8-serialized checkpoint for now."
+                 )
+             if len(weight_block_size) != 2:
+                 raise ValueError(
+                     f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions."
+                 )
+             if activation_scheme != "dynamic":
+                 raise ValueError(
+                     f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme."
+                 )
+         self.weight_block_size = weight_block_size

      @classmethod
      def get_name(cls) -> str:
@@ -89,10 +107,12 @@ class Fp8Config(QuantizationConfig):
          is_checkpoint_fp8_serialized = "fp8" in quant_method
          activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
          ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
+         weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
          return cls(
              is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
              activation_scheme=activation_scheme,
              ignored_layers=ignored_layers,
+             weight_block_size=weight_block_size,
          )

      def get_quant_method(
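The two hunks above make `weight_block_size` part of `Fp8Config`: `from_config` now reads it from the checkpoint's quantization config, and `__init__` rejects it unless the checkpoint is fp8-serialized with dynamic activations. As a hedged illustration only (the dict below is hypothetical, not taken from this diff, and `[128, 128]` is just an assumed example value), a block-quantized checkpoint would carry something like:

```python
# Hypothetical quantization_config of a block-quantized fp8 checkpoint (illustrative only).
hf_quantization_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "weight_block_size": [128, 128],  # assumed example: [block_n, block_k]
}
# Fp8Config.from_config(hf_quantization_config) would then yield a config with
# is_checkpoint_fp8_serialized=True and weight_block_size=[128, 128], which is what
# routes Fp8LinearMethod / Fp8MoEMethod into their block-quant paths below.
```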
@@ -100,6 +120,8 @@ class Fp8Config(QuantizationConfig):
      ) -> Optional["QuantizeMethodBase"]:
          from vllm.attention.layer import Attention  # Avoid circular import

+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
          if isinstance(layer, LinearBase):
              if is_layer_skipped(prefix, self.ignored_layers):
                  return UnquantizedLinearMethod()
@@ -143,6 +165,11 @@ class Fp8LinearMethod(LinearMethodBase):
          if is_hip():
              self.use_marlin = False

+         self.block_quant = self.quant_config.weight_block_size is not None
+         if self.block_quant:
+             # Marlin doesn't support block-wise fp8
+             self.use_marlin = False
+
      def create_weights(
          self,
          layer: torch.nn.Module,
@@ -153,10 +180,35 @@ class Fp8LinearMethod(LinearMethodBase):
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ):
-         del input_size, output_size
          output_size_per_partition = sum(output_partition_sizes)
          weight_loader = extra_weight_attrs.get("weight_loader")

+         tp_size = get_tensor_model_parallel_world_size()
+         if self.block_quant:
+             block_n, block_k = (
+                 self.quant_config.weight_block_size[0],
+                 self.quant_config.weight_block_size[1],
+             )
+             # Required by row parallel
+             if tp_size > 1 and input_size // input_size_per_partition == tp_size:
+                 if input_size_per_partition % block_k != 0:
+                     raise ValueError(
+                         f"Weight input_size_per_partition = "
+                         f"{input_size_per_partition} is not divisible by "
+                         f"weight quantization block_k = {block_k}."
+                     )
+             # Required by collum parallel or enabling merged weights
+             if (
+                 tp_size > 1 and output_size // output_size_per_partition == tp_size
+             ) or len(output_partition_sizes) > 1:
+                 for output_partition_size in output_partition_sizes:
+                     if output_partition_size % block_n != 0:
+                         raise ValueError(
+                             f"Weight output_partition_size = "
+                             f"{output_partition_size} is not divisible by "
+                             f"weight quantization block_n = {block_n}."
+                         )
+
          layer.logical_widths = output_partition_sizes

          layer.input_size_per_partition = input_size_per_partition
@@ -184,13 +236,27 @@ class Fp8LinearMethod(LinearMethodBase):
          # Otherwise, wait until process_weights_after_loading.
          if self.quant_config.is_checkpoint_fp8_serialized:
              # WEIGHT SCALE
-             scale = PerTensorScaleParameter(
-                 data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
-                 weight_loader=weight_loader,
-             )
-
-             scale[:] = torch.finfo(torch.float32).min
-             layer.register_parameter("weight_scale", scale)
+             if self.block_quant:
+                 assert self.quant_config.activation_scheme == "dynamic"
+                 scale = BlockQuantScaleParameter(
+                     data=torch.empty(
+                         (output_size_per_partition + block_n - 1) // block_n,
+                         (input_size_per_partition + block_k - 1) // block_k,
+                         dtype=torch.float32,
+                     ),
+                     input_dim=1,
+                     output_dim=0,
+                     weight_loader=weight_loader,
+                 )
+                 scale[:] = torch.finfo(torch.float32).min
+                 layer.register_parameter("weight_scale_inv", scale)
+             else:
+                 scale = PerTensorScaleParameter(
+                     data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                     weight_loader=weight_loader,
+                 )
+                 scale[:] = torch.finfo(torch.float32).min
+                 layer.register_parameter("weight_scale", scale)

              # INPUT ACTIVATION SCALE
              if self.quant_config.activation_scheme == "static":
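For block quantization the scale parameter above is allocated with one float32 entry per `block_n x block_k` tile of the weight shard, using ceil division. A small sketch of the same arithmetic, with made-up shard sizes (the helper name and the example shapes are illustrative, not from the diff):

```python
def block_scale_shape(output_size_per_partition, input_size_per_partition, weight_block_size):
    """Same ceil division as the BlockQuantScaleParameter allocation above."""
    block_n, block_k = weight_block_size
    return (
        (output_size_per_partition + block_n - 1) // block_n,
        (input_size_per_partition + block_k - 1) // block_k,
    )

# Made-up shard: a (4096, 7168) fp8 weight with 128x128 blocks carries a (32, 56) scale tensor.
assert block_scale_shape(4096, 7168, [128, 128]) == (32, 56)
```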
@@ -205,6 +271,9 @@ class Fp8LinearMethod(LinearMethodBase):
              layer.register_parameter("input_scale", None)

      def process_weights_after_loading(self, layer: Module) -> None:
+         # Block quant doesn't need to process weights after loading
+         if self.block_quant:
+             return
          layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
          # If checkpoint not serialized fp8, quantize the weights.
          if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -295,6 +364,16 @@ class Fp8LinearMethod(LinearMethodBase):
                  bias=bias,
              )

+         if self.block_quant:
+             return apply_w8a8_block_fp8_linear(
+                 input=x,
+                 weight=layer.weight,
+                 block_size=self.quant_config.weight_block_size,
+                 weight_scale=layer.weight_scale_inv,
+                 input_scale=layer.input_scale,
+                 bias=bias,
+             )
+
          return apply_fp8_linear(
              input=x,
              weight=layer.weight,
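`apply()` now short-circuits into `apply_w8a8_block_fp8_linear`, which runs a fused kernel. Purely as a reading aid for the data layout, here is a hedged pure-PyTorch reference of what `weight_scale_inv` is assumed to mean (each per-block scale multiplies its fp8 tile back toward full precision); this function is not the kernel the diff adds:

```python
import torch

def dequant_block_fp8_weight_reference(weight_fp8, weight_scale_inv, block_size):
    """Reference semantics only (assumed convention: scale * fp8 tile restores magnitude).

    weight_fp8:       (N, K) tensor in torch.float8_e4m3fn
    weight_scale_inv: (ceil(N / block_n), ceil(K / block_k)) float32 tensor
    """
    block_n, block_k = block_size
    n, k = weight_fp8.shape
    # Broadcast each per-block scale over its (block_n x block_k) tile, then crop to (N, K).
    scale = weight_scale_inv.repeat_interleave(block_n, dim=0)[:n]
    scale = scale.repeat_interleave(block_k, dim=1)[:, :k]
    return weight_fp8.to(torch.float32) * scale
```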
@@ -306,7 +385,7 @@ class Fp8LinearMethod(LinearMethodBase):
          )


- class Fp8MoEMethod(FusedMoEMethodBase):
+ class Fp8MoEMethod:
      """MoE method for FP8.
      Supports loading FP8 checkpoints with static weight scale and
      dynamic/static activation scale.
@@ -319,8 +398,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
          quant_config: The quantization config.
      """

-     def __init__(self, quant_config: Fp8Config):
+     def __new__(cls, *args, **kwargs):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+         if not hasattr(cls, "_initialized"):
+             original_init = cls.__init__
+             new_cls = type(
+                 cls.__name__,
+                 (FusedMoEMethodBase,),
+                 {
+                     "__init__": original_init,
+                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                 },
+             )
+             obj = super(new_cls, new_cls).__new__(new_cls)
+             obj.__init__(*args, **kwargs)
+             return obj
+         return super().__new__(cls)
+
+     def __init__(self, quant_config):
          self.quant_config = quant_config
+         self.block_quant = self.quant_config.weight_block_size is not None

      def create_weights(
          self,
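`Fp8MoEMethod` no longer inherits `FusedMoEMethodBase` at class-definition time; the `__new__` hook imports the base lazily and synthesizes a subclass on first instantiation, so fp8.py does not pull in the MoE package at module load. A standalone toy of the same pattern (hypothetical names `LazyBase`/`Method`; it drops the `_initialized` guard for brevity and is not the sglang implementation):

```python
class LazyBase:  # stands in for FusedMoEMethodBase, which fp8.py imports lazily
    def hello(self):
        return "from LazyBase"


class Method:
    def __new__(cls, *args, **kwargs):
        # Build a subclass of the lazily available base at instantiation time,
        # then construct and initialize that subclass instead of cls.
        new_cls = type(
            cls.__name__,
            (LazyBase,),
            {
                "__init__": cls.__init__,
                **{k: v for k, v in cls.__dict__.items() if not k.startswith("__")},
            },
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)  # called by hand: obj is not an instance of cls
        return obj

    def __init__(self, tag):
        self.tag = tag


m = Method("fp8")
assert isinstance(m, LazyBase) and m.tag == "fp8" and m.hello() == "from LazyBase"
```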
@@ -331,9 +429,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

          if self.quant_config.is_checkpoint_fp8_serialized:
              params_dtype = torch.float8_e4m3fn
+         tp_size = get_tensor_model_parallel_world_size()
+         if self.block_quant:
+             block_n, block_k = (
+                 self.quant_config.weight_block_size[0],
+                 self.quant_config.weight_block_size[1],
+             )
+             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
+             # Required by collum parallel or enabling merged weights
+             if intermediate_size % block_n != 0:
+                 raise ValueError(
+                     f"The output_size of gate's and up's weight = "
+                     f"{intermediate_size} is not divisible by "
+                     f"weight quantization block_n = {block_n}."
+                 )
+             if tp_size > 1:
+                 # Required by row parallel
+                 if intermediate_size % block_k != 0:
+                     raise ValueError(
+                         f"The input_size of down's weight = "
+                         f"{intermediate_size} is not divisible by "
+                         f"weight quantization block_k = {block_k}."
+                     )

          # WEIGHTS
          w13_weight = torch.nn.Parameter(
@@ -355,21 +476,45 @@ class Fp8MoEMethod(FusedMoEMethodBase):
          set_weight_attrs(w2_weight, extra_weight_attrs)

          # WEIGHT_SCALES
-         # Allocate 2 scales for w1 and w3 respectively.
-         # They will be combined to a single scale after weight loading.
-         w13_weight_scale = torch.nn.Parameter(
-             torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
-         )
-         layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-         w2_weight_scale = torch.nn.Parameter(
-             torch.ones(num_experts, dtype=torch.float32), requires_grad=False
-         )
-         layer.register_parameter("w2_weight_scale", w2_weight_scale)
+         if self.block_quant:
+             w13_weight_scale = torch.nn.Parameter(
+                 torch.ones(
+                     num_experts,
+                     2 * ((intermediate_size + block_n - 1) // block_n),
+                     (hidden_size + block_k - 1) // block_k,
+                     dtype=torch.float32,
+                 ),
+                 requires_grad=False,
+             )
+             w2_weight_scale = torch.nn.Parameter(
+                 torch.ones(
+                     num_experts,
+                     (hidden_size + block_n - 1) // block_n,
+                     (intermediate_size + block_k - 1) // block_k,
+                     dtype=torch.float32,
+                 ),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+             assert self.quant_config.activation_scheme == "dynamic"
+         else:
+             # Allocate 2 scales for w1 and w3 respectively.
+             # They will be combined to a single scale after weight loading.
+             w13_weight_scale = torch.nn.Parameter(
+                 torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+             )
+             w2_weight_scale = torch.nn.Parameter(
+                 torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+             )
+             layer.register_parameter("w13_weight_scale", w13_weight_scale)
+             layer.register_parameter("w2_weight_scale", w2_weight_scale)
          # Add the quantization method used (per tensor/grouped/channel)
          # to ensure the weight scales are loaded in properly
          extra_weight_attrs.update(
-             {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+             {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+             if self.block_quant
+             else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
          )
          # If loading fp8 checkpoint, pass the weight loaders.
          # If loading an fp16 checkpoint, do not (we will quantize in
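In the block-quant branch above, each expert gets one scale per weight tile: the merged gate/up (`w13`) scales and the down-projection (`w2`) scales are both allocated by ceil division over `block_n`/`block_k`. A sketch reproducing those two shape computations with made-up sizes (names and numbers are illustrative only):

```python
def moe_block_scale_shapes(num_experts, hidden_size, intermediate_size, weight_block_size):
    """Mirrors the w13/w2 scale allocations above (one scale per weight block per expert)."""
    block_n, block_k = weight_block_size

    def cdiv(a, b):
        return (a + b - 1) // b

    w13 = (num_experts, 2 * cdiv(intermediate_size, block_n), cdiv(hidden_size, block_k))
    w2 = (num_experts, cdiv(hidden_size, block_n), cdiv(intermediate_size, block_k))
    return w13, w2

# Made-up example: 64 experts, hidden_size 7168, intermediate_size 2048, 128x128 blocks.
assert moe_block_scale_shapes(64, 7168, 2048, [128, 128]) == ((64, 32, 56), (64, 56, 16))
```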
@@ -403,8 +548,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                  layer.w2_input_scale = None

      def process_weights_after_loading(self, layer: Module) -> None:
-
-         # If checkpoint is fp16, quantize in place.
+         # Block quant doesn't need to process weights after loading
+         if self.block_quant:
+             return
+         # If checkpoint is fp16 or bfloat16, quantize in place.
          if not self.quant_config.is_checkpoint_fp8_serialized:
              # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
              fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
@@ -428,6 +575,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                  )
              layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
              layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+             # If ROCm, apply weight padding (min. Mem channel contention) only if set
+             if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                 layer.w13_weight = torch.nn.Parameter(
+                     F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                     requires_grad=False,
+                 )
+                 torch.cuda.empty_cache()
+                 layer.w2_weight = torch.nn.Parameter(
+                     F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                     requires_grad=False,
+                 )
+                 torch.cuda.empty_cache()
              return

          # If checkpoint is fp8, we need to handle that the
@@ -456,6 +616,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                  layer.w2_input_scale = torch.nn.Parameter(
                      layer.w2_input_scale.max(), requires_grad=False
                  )
+
              # If ROCm, normalize the weights and scales to e4m3fnuz
              if is_hip():
                  # Normalize the weights and scales
@@ -486,7 +647,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                  layer.w2_input_scale = torch.nn.Parameter(
                      w2_input_scale, requires_grad=False
                  )
-
              # Fp8 moe kernel needs single weight scale for w13 per expert.
              # We take the max then dequant and requant each expert.
              assert layer.w13_weight_scale is not None
@@ -507,6 +667,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
              layer.w13_weight_scale = torch.nn.Parameter(
                  max_w13_scales, requires_grad=False
              )
+
+             # If ROCm, apply weight padding (min. Mem channel contention) only if set
+             if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                 layer.w13_weight = torch.nn.Parameter(
+                     F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                     requires_grad=False,
+                 )
+                 torch.cuda.empty_cache()
+                 layer.w2_weight = torch.nn.Parameter(
+                     F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                     requires_grad=False,
+                 )
+                 torch.cuda.empty_cache()
              return

      def apply(
@@ -520,11 +693,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
          topk_group: Optional[int] = None,
          num_expert_group: Optional[int] = None,
          custom_routing_function: Optional[Callable] = None,
+         correction_bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+         from sglang.srt.layers.moe.topk import select_experts

-         from vllm.model_executor.layers.fused_moe import fused_experts
-
-         topk_weights, topk_ids = FusedMoE.select_experts(
+         # Expert selection
+         topk_weights, topk_ids = select_experts(
              hidden_states=x,
              router_logits=router_logits,
              use_grouped_topk=use_grouped_topk,
@@ -533,8 +709,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
              topk_group=topk_group,
              num_expert_group=num_expert_group,
              custom_routing_function=custom_routing_function,
+             correction_bias=correction_bias,
          )

+         # Expert fusion with FP8 quantization
          return fused_experts(
              x,
              layer.w13_weight,
@@ -543,10 +721,17 @@ class Fp8MoEMethod(FusedMoEMethodBase):
              topk_ids=topk_ids,
              inplace=True,
              use_fp8_w8a8=True,
-             w1_scale=layer.w13_weight_scale,
-             w2_scale=layer.w2_weight_scale,
+             w1_scale=(
+                 layer.w13_weight_scale_inv
+                 if self.block_quant
+                 else layer.w13_weight_scale
+             ),
+             w2_scale=(
+                 layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+             ),
              a1_scale=layer.w13_input_scale,
              a2_scale=layer.w2_input_scale,
+             block_shape=self.quant_config.weight_block_size,
          )