sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_offline_throughput.py +0 -12
  2. sglang/bench_one_batch.py +0 -12
  3. sglang/bench_serving.py +11 -2
  4. sglang/lang/backend/openai.py +10 -0
  5. sglang/srt/aio_rwlock.py +100 -0
  6. sglang/srt/configs/model_config.py +8 -1
  7. sglang/srt/constrained/xgrammar_backend.py +6 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +49 -5
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  10. sglang/srt/layers/linear.py +20 -2
  11. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
  12. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  13. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  14. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
  15. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
  16. sglang/srt/layers/moe/topk.py +205 -0
  17. sglang/srt/layers/quantization/__init__.py +3 -3
  18. sglang/srt/layers/quantization/fp8.py +169 -32
  19. sglang/srt/layers/quantization/fp8_kernel.py +292 -0
  20. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  21. sglang/srt/layers/torchao_utils.py +11 -15
  22. sglang/srt/managers/schedule_batch.py +16 -10
  23. sglang/srt/managers/schedule_policy.py +1 -1
  24. sglang/srt/managers/scheduler.py +13 -16
  25. sglang/srt/managers/tokenizer_manager.py +130 -111
  26. sglang/srt/mem_cache/memory_pool.py +15 -8
  27. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  28. sglang/srt/model_loader/loader.py +22 -11
  29. sglang/srt/models/dbrx.py +1 -1
  30. sglang/srt/models/deepseek.py +1 -1
  31. sglang/srt/models/deepseek_v2.py +67 -18
  32. sglang/srt/models/gemma2.py +19 -0
  33. sglang/srt/models/grok.py +1 -1
  34. sglang/srt/models/llama.py +2 -2
  35. sglang/srt/models/mixtral.py +2 -2
  36. sglang/srt/models/olmoe.py +1 -1
  37. sglang/srt/models/qwen2_moe.py +1 -1
  38. sglang/srt/models/xverse_moe.py +1 -1
  39. sglang/srt/openai_api/adapter.py +23 -0
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_params.py +9 -2
  42. sglang/srt/server.py +21 -37
  43. sglang/srt/utils.py +33 -44
  44. sglang/test/test_block_fp8.py +341 -0
  45. sglang/version.py +1 -1
  46. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
  47. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
  48. sglang/srt/layers/fused_moe_patch.py +0 -133
  49. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  50. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  51. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
  52. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
@@ -6,14 +6,22 @@ import functools
 import json
 import logging
 import os
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import triton
 import triton.language as tl
 from vllm import _custom_ops as ops
 
-from sglang.srt.utils import direct_register_custom_op, get_device_name
+from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
+
+not_hip = False
+if not is_hip():
+    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+
+    not_hip = True
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
@@ -47,8 +55,14 @@ def fused_moe_kernel(
     stride_bn,
     stride_cm,
     stride_cn,
+    stride_asm,
+    stride_ask,
     stride_bse,
+    stride_bsk,
     stride_bsn,
+    # Block size for block-wise quantization
+    group_n: tl.constexpr,
+    group_k: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -132,8 +146,15 @@ def fused_moe_kernel(
         b_scale = tl.load(b_scale_ptrs)
 
     if use_fp8_w8a8:
-        a_scale = tl.load(a_scale_ptr)
-        b_scale = tl.load(b_scale_ptr + off_experts)
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            offs_bsn = offs_bn // group_n
+            b_scale_ptrs = (
+                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+            )
+        else:
+            a_scale = tl.load(a_scale_ptr)
+            b_scale = tl.load(b_scale_ptr + off_experts)
 
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
@@ -164,7 +185,17 @@ def fused_moe_kernel(
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
         elif use_fp8_w8a8:
-            accumulator = tl.dot(a, b, acc=accumulator)
+            if group_k > 0 and group_n > 0:
+                k_start = k * BLOCK_SIZE_K
+                offs_ks = k_start // group_k
+                a_scale = tl.load(
+                    a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
+                )
+                b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+
+                accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+            else:
+                accumulator = tl.dot(a, b, acc=accumulator)
         else:
             accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
@@ -177,7 +208,10 @@ def fused_moe_kernel(
     if use_int8_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_fp8_w8a8:
-        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
+        if group_k > 0 and group_n > 0:
+            accumulator = accumulator.to(compute_type)
+        else:
+            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
     else:
         accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
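Note: as a reference for the block-wise FP8 path added above, each K-group of the dot product is rescaled by a per-token activation scale and a per-(N-group, K-group) weight scale. The sketch below is not part of the wheel; the names, shapes, and pure-PyTorch formulation are illustrative assumptions that mirror the kernel's accumulation.

    import torch

    def blockwise_scaled_gemm_reference(A_q, A_s, B_q, B_s, group_n, group_k):
        # A_q: (M, K) quantized activations (kept as floats here for clarity)
        # A_s: (M, K // group_k) per-token, per-K-group scales
        # B_q: (N, K) quantized weights
        # B_s: (N // group_n, K // group_k) per-block weight scales
        # Assumes K % group_k == 0 and N % group_n == 0 for brevity.
        M, K = A_q.shape
        N = B_q.shape[0]
        out = torch.zeros(M, N, dtype=torch.float32)
        for ks in range(0, K, group_k):
            a = A_q[:, ks : ks + group_k].float()
            b = B_q[:, ks : ks + group_k].float()
            a_scale = A_s[:, ks // group_k]                              # (M,)
            b_scale = B_s[:, ks // group_k].repeat_interleave(group_n)   # (N,)
            # Same update shape as:
            #   accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
            out += (a @ b.t()) * a_scale[:, None] * b_scale[None, :]
        return out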
@@ -238,9 +272,33 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    ops.moe_align_block_size(
-        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
-    )
+    if not_hip and num_experts >= 224:
+        token_cnts_buffer = torch.empty(
+            (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
+        )
+        cumsum_buffer = torch.empty(
+            num_experts + 1, dtype=torch.int32, device=topk_ids.device
+        )
+
+        sgl_moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+            token_cnts_buffer,
+            cumsum_buffer,
+        )
+    else:
+        ops.moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+        )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
@@ -261,6 +319,7 @@ def invoke_fused_moe_kernel(
     compute_type: tl.dtype,
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -268,8 +327,16 @@ def invoke_fused_moe_kernel(
     padded_size = 0
     if use_fp8_w8a8:
         padded_size = padding_size
-        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
+        if block_shape is None:
+            A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+        else:
+            assert len(block_shape) == 2
+            block_n, block_k = block_shape[0], block_shape[1]
+            A, A_scale = per_token_group_quant_fp8(A, block_k)
+            assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
+            assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
+            assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
     elif use_int8_w8a16:
         assert B_scale is not None
     else:
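Note: the asserts added above pin down the scale layouts expected when block_shape = [block_n, block_k]. A small illustrative helper (hypothetical sizes, not taken from the diff) that reproduces the same arithmetic:

    import math

    def expected_scale_shapes(M, K, E, N, block_n, block_k):
        # Mirrors the asserts: A_scale has one scale per (token, K-group),
        # B_scale has one scale per (expert, N-group, K-group) tile.
        a_scale_shape = (M, math.ceil(K / block_k))
        b_scale_shape = (E, math.ceil(N / block_n), math.ceil(K / block_k))
        return a_scale_shape, b_scale_shape

    # e.g. expected_scale_shapes(16, 2048, 8, 1408, 128, 128)
    #      -> ((16, 16), (8, 11, 16))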
@@ -308,8 +375,13 @@ def invoke_fused_moe_kernel(
         B.stride(1),
         C.stride(1),
         C.stride(2),
-        B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
-        B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
+        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         top_k=top_k,
         compute_type=compute_type,
@@ -414,8 +486,9 @@ def try_get_optimal_moe_config(
     dtype: Optional[str],
     M: int,
     is_marlin: bool = False,
+    block_shape: Optional[List[int]] = None,
 ):
-    from sglang.srt.layers.fused_moe_triton import get_config
+    from sglang.srt.layers.moe.fused_moe_triton import get_config
 
     override_config = get_config()
     if override_config:
@@ -432,77 +505,16 @@ def try_get_optimal_moe_config(
     else:
         # Else use the default config
         config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
+    # TODO(HandH1998): Optimize the configs of block-wise quant.
+    # NOTE(HandH1998): For block-wise quant,
+    # BLOCK_K must be divisable by block_shape[1]
+    # BLOCK_N and BLOCK_M has no requirements
+    if block_shape is not None:
+        config["BLOCK_SIZE_N"] = block_shape[0]
+        config["BLOCK_SIZE_K"] = block_shape[1]
     return config
 
 
-def fused_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    M, _ = hidden_states.shape
-
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
-
-    ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
-
-
-# This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
-
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
-
-
 def get_config_dtype_str(
     dtype: torch.dtype,
     use_int8_w8a16: Optional[bool] = False,
@@ -531,6 +543,7 @@ def inplace_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -545,6 +558,7 @@
         w2_scale,
         a1_scale,
         a2_scale,
+        block_shape,
     )
 
 
@@ -560,6 +574,7 @@ def inplace_fused_experts_fake(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> None:
     pass
 
@@ -584,6 +599,7 @@ def outplace_fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -598,6 +614,7 @@
         w2_scale,
         a1_scale,
         a2_scale,
+        block_shape,
     )
 
 
@@ -613,6 +630,7 @@ def outplace_fused_experts_fake(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
 
@@ -638,6 +656,7 @@ def fused_experts(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ):
     if inplace:
         torch.ops.sglang.inplace_fused_experts(
@@ -652,6 +671,7 @@ def fused_experts(
             w2_scale,
             a1_scale,
             a2_scale,
+            block_shape,
         )
         return hidden_states
     else:
@@ -667,6 +687,7 @@ def fused_experts(
             w2_scale,
             a1_scale,
             a2_scale,
+            block_shape,
         )
 
 
@@ -683,6 +704,7 @@ def fused_experts_impl(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ):
     padded_size = padding_size
     if not use_fp8_w8a8:
@@ -714,6 +736,7 @@ def fused_experts_impl(
         (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
         topk_ids.shape[1],
         config_dtype,
+        block_shape=block_shape,
     )
 
     config = get_config_func(M)
@@ -786,6 +809,7 @@ def fused_experts_impl(
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            block_shape=block_shape,
         )
 
         ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -807,6 +831,7 @@ def fused_experts_impl(
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            block_shape=block_shape,
        )
 
         torch.sum(
@@ -835,6 +860,7 @@ def fused_moe(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[List[int]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -862,6 +888,12 @@ def fused_moe(
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
         w2.
+    - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
+        a1.
+    - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
+        a2.
+    - block_shape: (Optional[List[int]]): Optional block size for block-wise
+        quantization.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -869,24 +901,16 @@
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
-    if use_grouped_topk:
-        assert num_expert_group is not None and topk_group is not None
-        topk_weights, topk_ids = grouped_topk(
-            hidden_states,
-            gating_output,
-            topk,
-            renormalize,
-            num_expert_group,
-            topk_group,
-        )
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = fused_topk(
-            hidden_states, gating_output, topk, renormalize
-        )
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            hidden_states, gating_output, topk, renormalize
-        )
+    topk_weights, topk_ids = select_experts(
+        hidden_states=hidden_states,
+        router_logits=gating_output,
+        use_grouped_topk=use_grouped_topk,
+        top_k=topk,
+        renormalize=renormalize,
+        topk_group=topk_group,
+        num_expert_group=num_expert_group,
+        custom_routing_function=custom_routing_function,
+    )
 
     return fused_experts(
         hidden_states,
@@ -901,4 +925,5 @@ def fused_moe(
         w2_scale=w2_scale,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
+        block_shape=block_shape,
     )
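Note: together, the fused_moe.py hunks above thread an optional block_shape argument from the public fused_moe entry point down to the kernel launch. A hedged usage sketch follows; the tensor sizes, dtypes, and the commented FP8 arguments are assumptions for illustration, not values taken from this diff.

    import torch
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

    M, K, I, E, top_k = 4, 512, 1024, 8, 2            # illustrative sizes
    hidden_states = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    w1 = torch.randn(E, 2 * I, K, dtype=torch.bfloat16, device="cuda")  # gate+up proj
    w2 = torch.randn(E, K, I, dtype=torch.bfloat16, device="cuda")      # down proj
    gating_output = torch.randn(M, E, dtype=torch.float32, device="cuda")

    out = fused_moe(
        hidden_states=hidden_states,
        w1=w1,
        w2=w2,
        gating_output=gating_output,
        topk=top_k,
        renormalize=True,
        # New in this release: with FP8 weights and scales, passing e.g.
        #   use_fp8_w8a8=True, w1_scale=..., w2_scale=..., block_shape=[128, 128]
        # selects the block-wise quantized path shown above.
    )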
@@ -13,6 +13,7 @@ from vllm.distributed import (
13
13
  from vllm.model_executor.custom_op import CustomOp
14
14
 
15
15
  from sglang.srt.layers.custom_op_util import register_custom_op
16
+ from sglang.srt.layers.moe.topk import select_experts
16
17
  from sglang.srt.layers.quantization.base_config import (
17
18
  QuantizationConfig,
18
19
  QuantizeMethodBase,
@@ -20,7 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
20
21
  from sglang.srt.utils import set_weight_attrs
21
22
 
22
23
  if torch.cuda.is_available():
23
- from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
24
+ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
24
25
  else:
25
26
  fused_experts = None # type: ignore
26
27
 
@@ -33,6 +34,7 @@ class FusedMoeWeightScaleSupported(Enum):
33
34
  TENSOR = "tensor"
34
35
  CHANNEL = "channel"
35
36
  GROUP = "group"
37
+ BLOCK = "block"
36
38
 
37
39
 
38
40
  class FusedMoEMethodBase(QuantizeMethodBase):
@@ -106,6 +108,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
106
108
  topk_group: Optional[int] = None,
107
109
  num_expert_group: Optional[int] = None,
108
110
  custom_routing_function: Optional[Callable] = None,
111
+ correction_bias: Optional[torch.Tensor] = None,
109
112
  ) -> torch.Tensor:
110
113
  return self.forward(
111
114
  x=x,
@@ -117,6 +120,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
117
120
  topk_group=topk_group,
118
121
  num_expert_group=num_expert_group,
119
122
  custom_routing_function=custom_routing_function,
123
+ correction_bias=correction_bias,
120
124
  )
121
125
 
122
126
  def forward_cuda(
@@ -130,8 +134,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
130
134
  topk_group: Optional[int] = None,
131
135
  num_expert_group: Optional[int] = None,
132
136
  custom_routing_function: Optional[Callable] = None,
137
+ correction_bias: Optional[torch.Tensor] = None,
133
138
  ) -> torch.Tensor:
134
- topk_weights, topk_ids = FusedMoE.select_experts(
139
+ topk_weights, topk_ids = select_experts(
135
140
  hidden_states=x,
136
141
  router_logits=router_logits,
137
142
  use_grouped_topk=use_grouped_topk,
@@ -140,6 +145,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
140
145
  topk_group=topk_group,
141
146
  num_expert_group=num_expert_group,
142
147
  custom_routing_function=custom_routing_function,
148
+ correction_bias=correction_bias,
143
149
  )
144
150
 
145
151
  return fused_experts(
@@ -197,6 +203,7 @@ class FusedMoE(torch.nn.Module):
197
203
  tp_size: Optional[int] = None,
198
204
  prefix: str = "",
199
205
  custom_routing_function: Optional[Callable] = None,
206
+ correction_bias: Optional[torch.Tensor] = None,
200
207
  ):
201
208
  super().__init__()
202
209
 
@@ -208,6 +215,7 @@ class FusedMoE(torch.nn.Module):
208
215
  )
209
216
  self.top_k = top_k
210
217
  self.num_experts = num_experts
218
+ assert intermediate_size % self.tp_size == 0
211
219
  self.intermediate_size_per_partition = intermediate_size // self.tp_size
212
220
  self.reduce_results = reduce_results
213
221
  self.renormalize = renormalize
@@ -217,6 +225,7 @@ class FusedMoE(torch.nn.Module):
217
225
  self.num_expert_group = num_expert_group
218
226
  self.topk_group = topk_group
219
227
  self.custom_routing_function = custom_routing_function
228
+ self.correction_bias = correction_bias
220
229
 
221
230
  if quant_config is None:
222
231
  self.quant_method: Optional[QuantizeMethodBase] = (
@@ -463,7 +472,10 @@ class FusedMoE(torch.nn.Module):
463
472
  expert_data=expert_data,
464
473
  tp_rank=tp_rank,
465
474
  )
466
- elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
475
+ elif quant_method in [
476
+ FusedMoeWeightScaleSupported.GROUP.value,
477
+ FusedMoeWeightScaleSupported.BLOCK.value,
478
+ ]:
467
479
  self._load_model_weight_or_group_weight_scale(
468
480
  shard_id=shard_id,
469
481
  shard_dim=shard_dim,
@@ -503,51 +515,6 @@ class FusedMoE(torch.nn.Module):
503
515
  )
504
516
  return
505
517
 
506
- @staticmethod
507
- def select_experts(
508
- hidden_states: torch.Tensor,
509
- router_logits: torch.Tensor,
510
- top_k: int,
511
- use_grouped_topk: bool,
512
- renormalize: bool,
513
- topk_group: Optional[int] = None,
514
- num_expert_group: Optional[int] = None,
515
- custom_routing_function: Optional[Callable] = None,
516
- ):
517
- from sglang.srt.layers.fused_moe_triton.fused_moe import (
518
- fused_topk,
519
- grouped_topk,
520
- )
521
-
522
- # DeekSeekv2 uses grouped_top_k
523
- if use_grouped_topk:
524
- assert topk_group is not None
525
- assert num_expert_group is not None
526
- topk_weights, topk_ids = grouped_topk(
527
- hidden_states=hidden_states,
528
- gating_output=router_logits,
529
- topk=top_k,
530
- renormalize=renormalize,
531
- num_expert_group=num_expert_group,
532
- topk_group=topk_group,
533
- )
534
- elif custom_routing_function is None:
535
- topk_weights, topk_ids = fused_topk(
536
- hidden_states=hidden_states,
537
- gating_output=router_logits,
538
- topk=top_k,
539
- renormalize=renormalize,
540
- )
541
- else:
542
- topk_weights, topk_ids = custom_routing_function(
543
- hidden_states=hidden_states,
544
- gating_output=router_logits,
545
- topk=top_k,
546
- renormalize=renormalize,
547
- )
548
-
549
- return topk_weights, topk_ids
550
-
551
518
  def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
552
519
  assert self.quant_method is not None
553
520
 
@@ -562,6 +529,7 @@ class FusedMoE(torch.nn.Module):
562
529
  topk_group=self.topk_group,
563
530
  num_expert_group=self.num_expert_group,
564
531
  custom_routing_function=self.custom_routing_function,
532
+ correction_bias=self.correction_bias,
565
533
  )
566
534
 
567
535
  if self.reduce_results and self.tp_size > 1:
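Note: the layer.py hunks above retire the FusedMoE.select_experts staticmethod in favour of the shared helper in sglang/srt/layers/moe/topk.py and pass an optional correction_bias through to it. A minimal sketch of calling the helper directly, using only keyword arguments that appear at the call sites above (the sizes are illustrative assumptions):

    import torch
    from sglang.srt.layers.moe.topk import select_experts

    M, H, E, top_k = 4, 512, 8, 2                     # illustrative sizes
    hidden_states = torch.randn(M, H, dtype=torch.float16, device="cuda")
    router_logits = torch.randn(M, E, dtype=torch.float32, device="cuda")

    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        use_grouped_topk=False,        # grouped top-k is the DeepSeek-V2 path
        top_k=top_k,
        renormalize=True,
        topk_group=None,
        num_expert_group=None,
        custom_routing_function=None,
    )
    # topk_weights: (M, top_k) routing weights; topk_ids: (M, top_k) expert indices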