sglang 0.4.0.post2__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. sglang/bench_offline_throughput.py +0 -12
  2. sglang/bench_one_batch.py +0 -12
  3. sglang/bench_serving.py +1 -0
  4. sglang/srt/aio_rwlock.py +100 -0
  5. sglang/srt/configs/model_config.py +8 -1
  6. sglang/srt/layers/attention/flashinfer_backend.py +49 -5
  7. sglang/srt/layers/linear.py +20 -2
  8. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
  9. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  10. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  11. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +110 -98
  12. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
  13. sglang/srt/layers/moe/topk.py +191 -0
  14. sglang/srt/layers/quantization/__init__.py +3 -3
  15. sglang/srt/layers/quantization/fp8.py +169 -32
  16. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  17. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  18. sglang/srt/layers/torchao_utils.py +11 -15
  19. sglang/srt/managers/schedule_batch.py +16 -10
  20. sglang/srt/managers/scheduler.py +2 -2
  21. sglang/srt/managers/tokenizer_manager.py +86 -76
  22. sglang/srt/mem_cache/memory_pool.py +15 -8
  23. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  24. sglang/srt/model_executor/model_runner.py +6 -0
  25. sglang/srt/models/dbrx.py +1 -1
  26. sglang/srt/models/deepseek.py +1 -1
  27. sglang/srt/models/deepseek_v2.py +67 -18
  28. sglang/srt/models/grok.py +1 -1
  29. sglang/srt/models/mixtral.py +2 -2
  30. sglang/srt/models/olmoe.py +1 -1
  31. sglang/srt/models/qwen2_moe.py +1 -1
  32. sglang/srt/models/xverse_moe.py +1 -1
  33. sglang/srt/openai_api/adapter.py +4 -0
  34. sglang/srt/server.py +1 -0
  35. sglang/srt/utils.py +33 -44
  36. sglang/test/test_block_fp8.py +341 -0
  37. sglang/version.py +1 -1
  38. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/METADATA +3 -3
  39. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/RECORD +44 -40
  40. sglang/srt/layers/fused_moe_patch.py +0 -133
  41. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  42. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  43. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  44. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  45. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py
@@ -6,13 +6,16 @@ import functools
  import json
  import logging
  import os
- from typing import Any, Callable, Dict, Optional, Tuple
+ from typing import Any, Callable, Dict, List, Optional, Tuple

  import torch
  import triton
  import triton.language as tl
+ from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
  from vllm import _custom_ops as ops

+ from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
  from sglang.srt.utils import direct_register_custom_op, get_device_name

  logger = logging.getLogger(__name__)
@@ -47,8 +50,14 @@ def fused_moe_kernel(
      stride_bn,
      stride_cm,
      stride_cn,
+     stride_asm,
+     stride_ask,
      stride_bse,
+     stride_bsk,
      stride_bsn,
+     # Block size for block-wise quantization
+     group_n: tl.constexpr,
+     group_k: tl.constexpr,
      # Meta-parameters
      BLOCK_SIZE_M: tl.constexpr,
      BLOCK_SIZE_N: tl.constexpr,
@@ -132,8 +141,15 @@ def fused_moe_kernel(
          b_scale = tl.load(b_scale_ptrs)

      if use_fp8_w8a8:
-         a_scale = tl.load(a_scale_ptr)
-         b_scale = tl.load(b_scale_ptr + off_experts)
+         if group_k > 0 and group_n > 0:
+             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+             offs_bsn = offs_bn // group_n
+             b_scale_ptrs = (
+                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+             )
+         else:
+             a_scale = tl.load(a_scale_ptr)
+             b_scale = tl.load(b_scale_ptr + off_experts)

      # -----------------------------------------------------------
      # Iterate to compute a block of the C matrix.
@@ -164,7 +180,17 @@ def fused_moe_kernel(
          if use_int8_w8a16:
              accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
          elif use_fp8_w8a8:
-             accumulator = tl.dot(a, b, acc=accumulator)
+             if group_k > 0 and group_n > 0:
+                 k_start = k * BLOCK_SIZE_K
+                 offs_ks = k_start // group_k
+                 a_scale = tl.load(
+                     a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
+                 )
+                 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                 accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+             else:
+                 accumulator = tl.dot(a, b, acc=accumulator)
          else:
              accumulator += tl.dot(a, b)
          # Advance the ptrs to the next K block.
@@ -177,7 +203,10 @@
      if use_int8_w8a16:
          accumulator = (accumulator * b_scale).to(compute_type)
      elif use_fp8_w8a8:
-         accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+         if group_k > 0 and group_n > 0:
+             accumulator = accumulator.to(compute_type)
+         else:
+             accumulator = (accumulator * a_scale * b_scale).to(compute_type)
      else:
          accumulator = accumulator.to(compute_type)
      # -----------------------------------------------------------
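
Note on the block-wise fp8 path added above: instead of one activation scale per tensor and one weight scale per expert, the kernel now reloads a per-token-group activation scale and a per-(N, K)-block weight scale for every K tile and applies them while accumulating. A minimal PyTorch sketch of the dequantized math this approximates (sizes and names are illustrative, not taken from the kernel):

    import torch

    # Illustrative sizes: M tokens, N output columns, K input columns, one expert.
    M, N, K = 4, 8, 16
    group_n, group_k = 4, 8                            # block_shape = [group_n, group_k]

    a_q = torch.randn(M, K)                            # stand-in for fp8-quantized activations
    b_q = torch.randn(N, K)                            # stand-in for fp8-quantized weights
    a_scale = torch.rand(M, K // group_k)              # one scale per token per K group
    b_scale = torch.rand(N // group_n, K // group_k)   # one scale per (group_n x group_k) weight block

    acc = torch.zeros(M, N)
    for ks in range(K // group_k):                     # the kernel advances one K block at a time
        a_blk = a_q[:, ks * group_k:(ks + 1) * group_k]
        b_blk = b_q[:, ks * group_k:(ks + 1) * group_k]
        col_scale = b_scale[:, ks].repeat_interleave(group_n)   # expand each block scale to its columns
        acc += (a_blk @ b_blk.T) * a_scale[:, ks:ks + 1] * col_scale
    # acc now holds the rescaled accumulator; the kernel casts it to compute_type at the end.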
@@ -238,9 +267,25 @@ def moe_align_block_size(
          (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
      )
      num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-     ops.moe_align_block_size(
-         topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
-     )
+     # FIXME(zhyncs)
+     if num_experts >= 256:
+         sgl_moe_align_block_size(
+             topk_ids,
+             num_experts,
+             block_size,
+             sorted_ids,
+             expert_ids,
+             num_tokens_post_pad,
+         )
+     else:
+         ops.moe_align_block_size(
+             topk_ids,
+             num_experts,
+             block_size,
+             sorted_ids,
+             expert_ids,
+             num_tokens_post_pad,
+         )
      return sorted_ids, expert_ids, num_tokens_post_pad
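
The new branch only changes which backend performs the alignment: models with 256 or more experts go through the sgl_kernel implementation, everything else keeps using vLLM's op. Both compute the same layout; a rough pure-PyTorch reference of that alignment (illustrative only, not the CUDA kernels' exact implementation):

    import torch

    def moe_align_block_size_ref(topk_ids: torch.Tensor, num_experts: int, block_size: int):
        """Group token slots by expert and pad each expert's count to a multiple of block_size."""
        flat = topk_ids.flatten()
        pad_value = flat.numel()                                  # sentinel for padded slots
        counts = torch.bincount(flat, minlength=num_experts)
        padded = (counts + block_size - 1) // block_size * block_size
        total = int(padded.sum())

        sorted_ids = torch.full((total,), pad_value, dtype=torch.int32)
        expert_ids = torch.empty(total // block_size, dtype=torch.int32)
        offset = 0
        for e in range(num_experts):
            idx = (flat == e).nonzero(as_tuple=True)[0].to(torch.int32)
            sorted_ids[offset:offset + idx.numel()] = idx
            expert_ids[offset // block_size:(offset + int(padded[e])) // block_size] = e
            offset += int(padded[e])
        num_tokens_post_pad = torch.tensor([total], dtype=torch.int32)
        return sorted_ids, expert_ids, num_tokens_post_pad

    sorted_ids, expert_ids, n_post_pad = moe_align_block_size_ref(
        torch.tensor([[0, 2], [1, 2], [2, 3]]), num_experts=4, block_size=4
    )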
@@ -261,6 +306,7 @@ def invoke_fused_moe_kernel(
      compute_type: tl.dtype,
      use_fp8_w8a8: bool,
      use_int8_w8a16: bool,
+     block_shape: Optional[List[int]] = None,
  ) -> None:
      assert topk_weights.stride(1) == 1
      assert sorted_token_ids.stride(0) == 1
@@ -268,8 +314,16 @@
      padded_size = 0
      if use_fp8_w8a8:
          padded_size = padding_size
-         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
          assert B_scale is not None
+         if block_shape is None:
+             A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+         else:
+             assert len(block_shape) == 2
+             block_n, block_k = block_shape[0], block_shape[1]
+             A, A_scale = per_token_group_quant_fp8(A, block_k)
+             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
      elif use_int8_w8a16:
          assert B_scale is not None
      else:
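
With block_shape set, activations are quantized per token group along K by the new per_token_group_quant_fp8 (from sglang.srt.layers.quantization.fp8_kernel) rather than per tensor, and the asserts tie every scale tensor to the block grid. The shape relationships they enforce, spelled out with illustrative sizes:

    import math

    # Illustrative sizes; block_shape = [block_n, block_k] as in the asserts above.
    M, K = 16, 4096          # activations A: [M, K]
    E, N = 8, 2048           # expert weights B: [E, N, K]
    block_n, block_k = 128, 128

    a_scale_shape = (M, math.ceil(K / block_k))                          # A_scale: one scale per token per K group
    b_scale_shape = (E, math.ceil(N / block_n), math.ceil(K / block_k))  # B_scale: one scale per weight block
    print(a_scale_shape, b_scale_shape)  # (16, 32) (8, 16, 32)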
@@ -308,8 +362,13 @@
          B.stride(1),
          C.stride(1),
          C.stride(2),
-         B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
-         B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
+         A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+         A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+         B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+         B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+         B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+         0 if block_shape is None else block_shape[0],
+         0 if block_shape is None else block_shape[1],
          MUL_ROUTED_WEIGHT=mul_routed_weight,
          top_k=top_k,
          compute_type=compute_type,
@@ -414,8 +473,9 @@ def try_get_optimal_moe_config(
      dtype: Optional[str],
      M: int,
      is_marlin: bool = False,
+     block_shape: Optional[List[int]] = None,
  ):
-     from sglang.srt.layers.fused_moe_triton import get_config
+     from sglang.srt.layers.moe.fused_moe_triton import get_config

      override_config = get_config()
      if override_config:
@@ -432,77 +492,16 @@
      else:
          # Else use the default config
          config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
+     # TODO(HandH1998): Optimize the configs of block-wise quant.
+     # NOTE(HandH1998): For block-wise quant,
+     # BLOCK_K must be divisable by block_shape[1]
+     # BLOCK_N and BLOCK_M has no requirements
+     if block_shape is not None:
+         config["BLOCK_SIZE_N"] = block_shape[0]
+         config["BLOCK_SIZE_K"] = block_shape[1]
      return config


- def fused_topk(
-     hidden_states: torch.Tensor,
-     gating_output: torch.Tensor,
-     topk: int,
-     renormalize: bool,
- ):
-     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-     M, _ = hidden_states.shape
-
-     topk_weights = torch.empty(
-         M, topk, dtype=torch.float32, device=hidden_states.device
-     )
-     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-     token_expert_indicies = torch.empty(
-         M, topk, dtype=torch.int32, device=hidden_states.device
-     )
-
-     ops.topk_softmax(
-         topk_weights,
-         topk_ids,
-         token_expert_indicies,
-         gating_output.float(),  # TODO(woosuk): Optimize this.
-     )
-     del token_expert_indicies  # Not used. Will be used in the future.
-
-     if renormalize:
-         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-     return topk_weights, topk_ids
-
-
- # This is used by the Deepseek-V2 model
- def grouped_topk(
-     hidden_states: torch.Tensor,
-     gating_output: torch.Tensor,
-     topk: int,
-     renormalize: bool,
-     num_expert_group: int = 0,
-     topk_group: int = 0,
- ):
-
-     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-     scores = torch.softmax(gating_output, dim=-1)
-     num_token = scores.shape[0]
-     group_scores = (
-         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-     )  # [n, n_group]
-     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-         1
-     ]  # [n, top_k_group]
-     group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-     group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-     score_mask = (
-         group_mask.unsqueeze(-1)
-         .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-         .reshape(num_token, -1)
-     )  # [n, e]
-     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-     topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-     if renormalize:
-         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
-
-
  def get_config_dtype_str(
      dtype: torch.dtype,
      use_int8_w8a16: Optional[bool] = False,
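
fused_topk and grouped_topk are not gone: they moved into the new sglang/srt/layers/moe/topk.py (file 13 above), and callers now route through its select_experts helper. A hedged sketch of a call, with the keyword names taken from the call sites later in this diff (defaults and extra parameters in the real module may differ):

    import torch
    from sglang.srt.layers.moe.topk import select_experts

    hidden_states = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)   # 8 tokens
    router_logits = torch.randn(8, 64, device="cuda", dtype=torch.float32)      # 64 experts

    # Plain softmax top-k routing; grouped top-k would set use_grouped_topk=True
    # plus num_expert_group / topk_group.
    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        use_grouped_topk=False,
        top_k=2,
        renormalize=True,
        topk_group=None,
        num_expert_group=None,
        custom_routing_function=None,
    )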
@@ -531,6 +530,7 @@ def inplace_fused_experts(
      w2_scale: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
+     block_shape: Optional[List[int]] = None,
  ) -> None:
      fused_experts_impl(
          hidden_states,
@@ -545,6 +545,7 @@
          w2_scale,
          a1_scale,
          a2_scale,
+         block_shape,
      )


@@ -560,6 +561,7 @@ def inplace_fused_experts_fake(
      w2_scale: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
+     block_shape: Optional[List[int]] = None,
  ) -> None:
      pass

@@ -584,6 +586,7 @@ def outplace_fused_experts(
      w2_scale: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
+     block_shape: Optional[List[int]] = None,
  ) -> torch.Tensor:
      return fused_experts_impl(
          hidden_states,
@@ -598,6 +601,7 @@
          w2_scale,
          a1_scale,
          a2_scale,
+         block_shape,
      )


@@ -613,6 +617,7 @@ def outplace_fused_experts_fake(
      w2_scale: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
+     block_shape: Optional[List[int]] = None,
  ) -> torch.Tensor:
      return torch.empty_like(hidden_states)

@@ -638,6 +643,7 @@ def fused_experts(
      w2_scale: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
+     block_shape: Optional[List[int]] = None,
  ):
      if inplace:
          torch.ops.sglang.inplace_fused_experts(
@@ -652,6 +658,7 @@
              w2_scale,
              a1_scale,
              a2_scale,
+             block_shape,
          )
          return hidden_states
      else:
@@ -667,6 +674,7 @@
              w2_scale,
              a1_scale,
              a2_scale,
+             block_shape,
          )


@@ -683,6 +691,7 @@ def fused_experts_impl(
      w2_scale: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
+     block_shape: Optional[List[int]] = None,
  ):
      padded_size = padding_size
      if not use_fp8_w8a8:
@@ -714,6 +723,7 @@
          (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
          topk_ids.shape[1],
          config_dtype,
+         block_shape=block_shape,
      )

      config = get_config_func(M)
@@ -786,6 +796,7 @@
              compute_type=compute_type,
              use_fp8_w8a8=use_fp8_w8a8,
              use_int8_w8a16=use_int8_w8a16,
+             block_shape=block_shape,
          )

          ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -807,6 +818,7 @@
              compute_type=compute_type,
              use_fp8_w8a8=use_fp8_w8a8,
              use_int8_w8a16=use_int8_w8a16,
+             block_shape=block_shape,
          )

          torch.sum(
@@ -835,6 +847,7 @@ def fused_moe(
      w2_scale: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
+     block_shape: Optional[List[int]] = None,
  ) -> torch.Tensor:
      """
      This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -862,6 +875,12 @@
          w1.
      - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
          w2.
+     - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
+         a1.
+     - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
+         a2.
+     - block_shape: (Optional[List[int]]): Optional block size for block-wise
+         quantization.

      Returns:
      - torch.Tensor: The output tensor after applying the MoE layer.
@@ -869,24 +888,16 @@
      # Check constraints.
      assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

-     if use_grouped_topk:
-         assert num_expert_group is not None and topk_group is not None
-         topk_weights, topk_ids = grouped_topk(
-             hidden_states,
-             gating_output,
-             topk,
-             renormalize,
-             num_expert_group,
-             topk_group,
-         )
-     elif custom_routing_function is None:
-         topk_weights, topk_ids = fused_topk(
-             hidden_states, gating_output, topk, renormalize
-         )
-     else:
-         topk_weights, topk_ids = custom_routing_function(
-             hidden_states, gating_output, topk, renormalize
-         )
+     topk_weights, topk_ids = select_experts(
+         hidden_states=hidden_states,
+         router_logits=gating_output,
+         use_grouped_topk=use_grouped_topk,
+         top_k=topk,
+         renormalize=renormalize,
+         topk_group=topk_group,
+         num_expert_group=num_expert_group,
+         custom_routing_function=custom_routing_function,
+     )

      return fused_experts(
          hidden_states,
@@ -901,4 +912,5 @@ def fused_moe(
          w2_scale=w2_scale,
          a1_scale=a1_scale,
          a2_scale=a2_scale,
+         block_shape=block_shape,
      )
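
Putting the pieces together, a block-quantized call now threads block_shape from fused_moe all the way down to the kernel. A hedged end-to-end sketch for a [128, 128] fp8 block layout; tensor shapes and scale layouts here are assumptions consistent with the asserts above (the new sglang/test/test_block_fp8.py, file 36, contains the authoritative tests):

    import torch
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

    M, K, N, E, topk = 16, 4096, 2048, 8, 2
    block_shape = [128, 128]                      # [block_n, block_k]

    hidden_states = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    w1 = torch.randn(E, 2 * N, K, device="cuda").to(torch.float8_e4m3fn)   # gate + up projection
    w2 = torch.randn(E, K, N, device="cuda").to(torch.float8_e4m3fn)       # down projection
    w1_scale = torch.rand(E, 2 * N // 128, K // 128, device="cuda", dtype=torch.float32)
    w2_scale = torch.rand(E, K // 128, N // 128, device="cuda", dtype=torch.float32)
    gating_output = torch.randn(M, E, device="cuda", dtype=torch.float32)

    out = fused_moe(
        hidden_states,
        w1,
        w2,
        gating_output,
        topk,
        renormalize=True,
        use_fp8_w8a8=True,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        block_shape=block_shape,
    )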
sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py
@@ -13,6 +13,7 @@ from vllm.distributed import (
  from vllm.model_executor.custom_op import CustomOp

  from sglang.srt.layers.custom_op_util import register_custom_op
+ from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.base_config import (
      QuantizationConfig,
      QuantizeMethodBase,
@@ -20,7 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
  from sglang.srt.utils import set_weight_attrs

  if torch.cuda.is_available():
-     from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
+     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
  else:
      fused_experts = None  # type: ignore

@@ -33,6 +34,7 @@ class FusedMoeWeightScaleSupported(Enum):
      TENSOR = "tensor"
      CHANNEL = "channel"
      GROUP = "group"
+     BLOCK = "block"


  class FusedMoEMethodBase(QuantizeMethodBase):
@@ -106,6 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
          topk_group: Optional[int] = None,
          num_expert_group: Optional[int] = None,
          custom_routing_function: Optional[Callable] = None,
+         correction_bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          return self.forward(
              x=x,
@@ -117,6 +120,7 @@
              topk_group=topk_group,
              num_expert_group=num_expert_group,
              custom_routing_function=custom_routing_function,
+             correction_bias=correction_bias,
          )

      def forward_cuda(
@@ -130,8 +134,9 @@
          topk_group: Optional[int] = None,
          num_expert_group: Optional[int] = None,
          custom_routing_function: Optional[Callable] = None,
+         correction_bias: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
-         topk_weights, topk_ids = FusedMoE.select_experts(
+         topk_weights, topk_ids = select_experts(
              hidden_states=x,
              router_logits=router_logits,
              use_grouped_topk=use_grouped_topk,
@@ -140,6 +145,7 @@
              topk_group=topk_group,
              num_expert_group=num_expert_group,
              custom_routing_function=custom_routing_function,
+             correction_bias=correction_bias,
          )

          return fused_experts(
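
correction_bias is pure plumbing here: it is stored on FusedMoE (below) and forwarded into select_experts, where it can bias which experts are chosen without changing the routing weights, the gating style used by DeepSeek-V3-class models. A generic illustration of that idea (not the package's actual topk implementation):

    import torch

    def biased_topk_sketch(router_logits: torch.Tensor, correction_bias: torch.Tensor, top_k: int):
        """Illustrative only: the bias influences which experts win, while the
        returned weights still come from the unbiased scores."""
        scores = router_logits.sigmoid()
        ranked = scores + correction_bias                  # bias shifts the ranking only
        topk_ids = ranked.topk(top_k, dim=-1).indices
        topk_weights = scores.gather(-1, topk_ids)         # weights use the raw scores
        return topk_weights, topk_ids

    logits = torch.randn(4, 8)                             # 4 tokens, 8 experts
    bias = torch.zeros(8)                                  # a learned per-expert bias in practice
    weights, ids = biased_topk_sketch(logits, bias, top_k=2)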
@@ -197,6 +203,7 @@ class FusedMoE(torch.nn.Module):
          tp_size: Optional[int] = None,
          prefix: str = "",
          custom_routing_function: Optional[Callable] = None,
+         correction_bias: Optional[torch.Tensor] = None,
      ):
          super().__init__()

@@ -208,6 +215,7 @@
          )
          self.top_k = top_k
          self.num_experts = num_experts
+         assert intermediate_size % self.tp_size == 0
          self.intermediate_size_per_partition = intermediate_size // self.tp_size
          self.reduce_results = reduce_results
          self.renormalize = renormalize
@@ -217,6 +225,7 @@
          self.num_expert_group = num_expert_group
          self.topk_group = topk_group
          self.custom_routing_function = custom_routing_function
+         self.correction_bias = correction_bias

          if quant_config is None:
              self.quant_method: Optional[QuantizeMethodBase] = (
@@ -463,7 +472,10 @@ class FusedMoE(torch.nn.Module):
                      expert_data=expert_data,
                      tp_rank=tp_rank,
                  )
-             elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
+             elif quant_method in [
+                 FusedMoeWeightScaleSupported.GROUP.value,
+                 FusedMoeWeightScaleSupported.BLOCK.value,
+             ]:
                  self._load_model_weight_or_group_weight_scale(
                      shard_id=shard_id,
                      shard_dim=shard_dim,
@@ -503,51 +515,6 @@
              )
              return

-     @staticmethod
-     def select_experts(
-         hidden_states: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         use_grouped_topk: bool,
-         renormalize: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         custom_routing_function: Optional[Callable] = None,
-     ):
-         from sglang.srt.layers.fused_moe_triton.fused_moe import (
-             fused_topk,
-             grouped_topk,
-         )
-
-         # DeekSeekv2 uses grouped_top_k
-         if use_grouped_topk:
-             assert topk_group is not None
-             assert num_expert_group is not None
-             topk_weights, topk_ids = grouped_topk(
-                 hidden_states=hidden_states,
-                 gating_output=router_logits,
-                 topk=top_k,
-                 renormalize=renormalize,
-                 num_expert_group=num_expert_group,
-                 topk_group=topk_group,
-             )
-         elif custom_routing_function is None:
-             topk_weights, topk_ids = fused_topk(
-                 hidden_states=hidden_states,
-                 gating_output=router_logits,
-                 topk=top_k,
-                 renormalize=renormalize,
-             )
-         else:
-             topk_weights, topk_ids = custom_routing_function(
-                 hidden_states=hidden_states,
-                 gating_output=router_logits,
-                 topk=top_k,
-                 renormalize=renormalize,
-             )
-
-         return topk_weights, topk_ids
-
      def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
          assert self.quant_method is not None

@@ -562,6 +529,7 @@
              topk_group=self.topk_group,
              num_expert_group=self.num_expert_group,
              custom_routing_function=self.custom_routing_function,
+             correction_bias=self.correction_bias,
          )

          if self.reduce_results and self.tp_size > 1: