sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 1,
+         "num_warps": 4,
+         "num_stages": 4
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 256,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "512": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 3
+     }
+ }
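Reviewer note: each top-level key in this tuning table is a token-count (M) bucket, and the Triton MoE kernel looks up the block/warp/stage configuration for the bucket nearest the actual batch. The sketch below shows that lookup pattern in isolation; the helper names and the nearest-key rule are illustrative assumptions, not the exact sglang helpers (the real lookup lives in sglang/srt/layers/moe/fused_moe_triton/fused_moe.py).

# Hedged sketch: consuming a per-M tuning table like the JSON above.
import json
from typing import Any, Dict

def load_moe_config(path: str) -> Dict[int, Dict[str, Any]]:
    # Keys are stored as strings ("1", "2", ..., "4096"); convert to ints.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: Dict[int, Dict[str, Any]], num_tokens: int) -> Dict[str, Any]:
    # Assumed selection rule: choose the bucket whose key is closest to M.
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# Example: with the table above, a batch of 90 tokens would use the "96" entry.
# cfg = pick_config(load_moe_config(config_path), 90)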
@@ -6,13 +6,13 @@ import functools
  import json
  import logging
  import os
- from typing import Any, Callable, Dict, List, Optional, Tuple
+ from typing import Any, Dict, List, Optional, Tuple

  import torch
  import triton
  import triton.language as tl

- from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.moe.topk import TopKOutput
  from sglang.srt.layers.quantization.fp8_kernel import (
      per_token_group_quant_fp8,
      scaled_fp8_quant,
@@ -39,11 +39,20 @@ _is_hip = is_hip()
  _is_cuda = is_cuda()
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

  if _is_cuda:
      from sgl_kernel import gelu_and_mul, silu_and_mul
  elif _is_cpu and _is_cpu_amx_available:
      pass
+ elif _is_hip:
+     from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+
+     if _use_aiter:
+         try:
+             from aiter import moe_sum
+         except ImportError:
+             raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
  else:
      from vllm import _custom_ops as vllm_ops
      from vllm._custom_ops import scaled_fp8_quant
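Reviewer note: the new `_use_aiter` gate is evaluated once at module scope, so `SGLANG_USE_AITER` must already be set in the environment before this module is imported; flipping it later in the process has no effect. A minimal illustration of that ordering (the in-process assignment is just one way to do it; exporting the variable before launch works the same):

import os

# Must happen before any sglang module that reads the flag is imported,
# because _use_aiter is computed at import time on the HIP/ROCm path.
os.environ["SGLANG_USE_AITER"] = "1"

import sglang  # the module-level check above now sees the flag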
@@ -752,14 +761,13 @@ def moe_align_block_size(
      sorted_ids = torch.empty(
          (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
      )
-     sorted_ids.fill_(topk_ids.numel())
-
      max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
      expert_ids = torch.empty(
          (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
      )
      num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
      if enable_moe_align_block_size_triton:
+         sorted_ids.fill_(topk_ids.numel())
          moe_align_block_size_triton(
              topk_ids,
              num_experts,
@@ -778,6 +786,11 @@
              device=topk_ids.device,
          )

+         # Threshold based on benchmark results
+         fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+         if not fuse_sorted_ids_padding:
+             sorted_ids.fill_(topk_ids.numel())
+
          sgl_moe_align_block_size(
              topk_ids,
              num_experts,
@@ -787,6 +800,7 @@
              num_tokens_post_pad,
              token_cnts_buffer,
              cumsum_buffer,
+             fuse_sorted_ids_padding,
          )
      return sorted_ids, expert_ids, num_tokens_post_pad

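Reviewer note: the three hunks above change how the `sorted_ids` padding is written. For buffers of at most 4096 entries the padding fill is fused into `sgl_moe_align_block_size` via the new `fuse_sorted_ids_padding` flag; larger buffers still issue a separate `fill_` first (the 4096 cutoff is the benchmark-derived threshold noted in the diff). A minimal sketch of that decision pulled out of the surrounding call; the interpretation of `topk_ids.numel()` as an out-of-range pad sentinel is an assumption about the downstream kernels:

import torch

def prepare_sorted_ids(topk_ids: torch.Tensor, max_num_tokens_padded: int):
    # Sketch of the padding decision in moe_align_block_size (sgl_kernel path).
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    # Small buffers: defer the pad write to the alignment kernel itself.
    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
    if not fuse_sorted_ids_padding:
        # Large buffers: the diff's benchmarks favor an explicit fill_ up front.
        sorted_ids.fill_(topk_ids.numel())
    return sorted_ids, fuse_sorted_ids_padding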
@@ -1328,8 +1342,7 @@ def fused_experts(
      hidden_states: torch.Tensor,
      w1: torch.Tensor,
      w2: torch.Tensor,
-     topk_weights: torch.Tensor,
-     topk_ids: torch.Tensor,
+     topk_output: TopKOutput,
      inplace: bool = False,
      activation: str = "silu",
      apply_router_weight_on_input: bool = False,
@@ -1348,7 +1361,7 @@
      no_combine: bool = False,
      routed_scaling_factor: Optional[float] = None,
  ):
-
+     topk_weights, topk_ids, _ = topk_output
      if inplace:
          assert not no_combine, "no combine + inplace makes no sense"
          torch.ops.sglang.inplace_fused_experts(
@@ -1517,11 +1530,7 @@ def fused_experts_impl(
      routed_scaling_factor: Optional[float] = None,
  ):
      padded_size = padding_size
-     if (
-         not (use_fp8_w8a8 or use_int8_w8a8)
-         or block_shape is not None
-         or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
-     ):
+     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
          padded_size = 0

      # Check constraints.
@@ -1719,6 +1728,17 @@
                  out_hidden_states[begin_chunk_idx:end_chunk_idx],
                  routed_scaling_factor,
              )
+         elif _is_hip:
+             if _use_aiter:
+                 moe_sum(
+                     intermediate_cache3.view(*intermediate_cache3.shape),
+                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                 )
+             else:
+                 vllm_ops.moe_sum(
+                     intermediate_cache3.view(*intermediate_cache3.shape),
+                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                 )
          else:
              vllm_ops.moe_sum(
                  intermediate_cache3.view(*intermediate_cache3.shape),
@@ -1732,17 +1752,10 @@ def fused_moe(
      hidden_states: torch.Tensor,
      w1: torch.Tensor,
      w2: torch.Tensor,
-     gating_output: torch.Tensor,
-     topk: int,
-     renormalize: bool,
+     topk_output: TopKOutput,
      inplace: bool = False,
      activation: str = "silu",
      apply_router_weight_on_input: bool = False,
-     use_grouped_topk: bool = False,
-     num_expert_group: Optional[int] = None,
-     num_fused_shared_experts: int = 0,
-     topk_group: Optional[int] = None,
-     custom_routing_function: Optional[Callable] = None,
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
@@ -1766,16 +1779,9 @@
      - hidden_states (torch.Tensor): The input tensor to the MoE layer.
      - w1 (torch.Tensor): The first set of expert weights.
      - w2 (torch.Tensor): The second set of expert weights.
-     - gating_output (torch.Tensor): The output of the gating operation
-         (before softmax).
-     - topk (int): The number of top-k experts to select.
-     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
+     - topk_output (TopKOutput): The top-k output of the experts.
      - inplace (bool): If True, perform the operation in-place.
          Defaults to False.
-     - num_expert_group: Optional[int]: additional parameter for grouped_topk
-     - topk_group: Optional[int]: additional parameter for grouped_topk
-     - use_grouped_topk: If True, use grouped_topk instead of fused_topk
-         note: Deepseek V2/V3/R1 series models use grouped_topk
      - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
          products for w1 and w2. Defaults to False.
      - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
@@ -1799,28 +1805,12 @@
      Returns:
      - torch.Tensor: The output tensor after applying the MoE layer.
      """
-     # Check constraints.
-     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-
-     topk_weights, topk_ids = select_experts(
-         hidden_states=hidden_states,
-         router_logits=gating_output,
-         use_grouped_topk=use_grouped_topk,
-         top_k=topk,
-         renormalize=renormalize,
-         topk_group=topk_group,
-         num_expert_group=num_expert_group,
-         num_fused_shared_experts=num_fused_shared_experts,
-         custom_routing_function=custom_routing_function,
-         routed_scaling_factor=routed_scaling_factor,
-     )

      return fused_experts(
          hidden_states,
          w1,
          w2,
-         topk_weights,
-         topk_ids,
+         topk_output,
          inplace=inplace,
          activation=activation,
          apply_router_weight_on_input=apply_router_weight_on_input,
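Reviewer note: the caller-visible change in this file is the routing API. `fused_moe` no longer accepts `gating_output`/`topk`/`renormalize` and `fused_experts` no longer accepts raw `topk_weights`/`topk_ids`; both now take a single `TopKOutput` from sglang.srt.layers.moe.topk, whose first two fields are unpacked as `topk_weights, topk_ids` (see the `topk_weights, topk_ids, _ = topk_output` line above). The sketch below is a hedged reference for what those first two fields hold under plain softmax top-k routing with renormalization, i.e. the simple non-grouped case the removed select_experts call covered; it deliberately does not construct TopKOutput, since its exact constructor and third field are not shown in this diff.

import torch

def reference_topk(gating_output: torch.Tensor, top_k: int, renormalize: bool = True):
    # gating_output: [num_tokens, num_experts] router logits (pre-softmax).
    scores = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(scores, top_k, dim=-1)
    if renormalize:
        # Renormalize the selected weights to sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids  # matches the first two fields of a TopKOutput

# Shapes only: logits [T, E] -> weights/ids [T, top_k]
# w, ids = reference_topk(torch.randn(4, 8), top_k=2)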