sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -1,60 +1,29 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

- import importlib
- from abc import abstractmethod
+ import logging
  from enum import Enum
- from typing import Callable, List, Optional, Tuple
+ from typing import List, Optional, Tuple

  import torch

- from sglang.srt.custom_op import CustomOp
  from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
- from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
- from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.moe.topk import TopKOutput
  from sglang.srt.layers.quantization.base_config import (
      QuantizationConfig,
      QuantizeMethodBase,
  )
+ from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
- from sglang.srt.utils import (
-     cpu_has_amx_support,
-     get_bool_env_var,
-     is_cpu,
-     is_hip,
-     set_weight_attrs,
-     use_intel_amx_backend,
- )
-
- has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
-
- if torch.cuda.is_available():
-     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-
-     if has_triton_kernels:
-         from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-             triton_kernel_moe_forward,
-         )
- else:
-     fused_experts = None  # type: ignore
-
- import logging
+ from sglang.srt.utils import cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip

  _is_hip = is_hip()
  _is_cpu_amx_available = cpu_has_amx_support()
  _is_cpu = is_cpu()
- _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-
- if _use_aiter:
-     from aiter import ActivationType
-     from aiter.fused_moe import fused_moe
-     from aiter.fused_moe_bf16_asm import ck_moe_2stages
-     from aiter.ops.shuffle import shuffle_weight

  logger = logging.getLogger(__name__)

@@ -66,333 +35,6 @@ class FusedMoeWeightScaleSupported(Enum):
      BLOCK = "block"


- class FusedMoEMethodBase(QuantizeMethodBase):
-
-     @abstractmethod
-     def create_weights(
-         self,
-         layer: torch.nn.Module,
-         num_experts: int,
-         hidden_size: int,
-         intermediate_size: int,
-         params_dtype: torch.dtype,
-         **extra_weight_attrs,
-     ):
-         raise NotImplementedError
-
-     @abstractmethod
-     def apply(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-     ) -> torch.Tensor:
-         raise NotImplementedError
-
-
- class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
-     """MoE method without quantization."""
-
-     def __init__(self, use_triton_kernels: bool = False):
-         super().__init__()
-         self.use_triton_kernels = use_triton_kernels
-
-     def create_weights(
-         self,
-         layer: torch.nn.Module,
-         num_experts: int,
-         hidden_size: int,
-         intermediate_size: int,
-         params_dtype: torch.dtype,
-         **extra_weight_attrs,
-     ):
-         # Fused gate_up_proj (column parallel)
-         w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
-         if self.use_triton_kernels:
-             w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
-         w13_weight = torch.nn.Parameter(
-             torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
-             requires_grad=False,
-         )
-         layer.register_parameter("w13_weight", w13_weight)
-         set_weight_attrs(w13_weight, extra_weight_attrs)
-
-         # down_proj (row parallel)
-         w2_weight_n, w2_weight_k = (
-             hidden_size,
-             intermediate_size,
-         )
-         if self.use_triton_kernels:
-             w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
-         w2_weight = torch.nn.Parameter(
-             torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
-             requires_grad=False,
-         )
-         layer.register_parameter("w2_weight", w2_weight)
-         set_weight_attrs(w2_weight, extra_weight_attrs)
-
-     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-         if _use_aiter:
-             layer.w13_weight = torch.nn.Parameter(
-                 shuffle_weight(layer.w13_weight.data, (16, 16)),
-                 requires_grad=False,
-             )
-             torch.cuda.empty_cache()
-             layer.w2_weight = torch.nn.Parameter(
-                 shuffle_weight(layer.w2_weight.data, (16, 16)),
-                 requires_grad=False,
-             )
-             torch.cuda.empty_cache()
-
-         # Pack weight for get better performance on CPU
-         if _is_cpu and _is_cpu_amx_available:
-             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-
-         return
-
-     def apply(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
-     ) -> torch.Tensor:
-         return self.forward(
-             x=x,
-             layer=layer,
-             router_logits=router_logits,
-             top_k=top_k,
-             renormalize=renormalize,
-             use_grouped_topk=use_grouped_topk,
-             topk_group=topk_group,
-             num_expert_group=num_expert_group,
-             num_fused_shared_experts=num_fused_shared_experts,
-             custom_routing_function=custom_routing_function,
-             correction_bias=correction_bias,
-             activation=activation,
-             apply_router_weight_on_input=apply_router_weight_on_input,
-             inplace=inplace,
-             no_combine=no_combine,
-             routed_scaling_factor=routed_scaling_factor,
-         )
-
-     def forward_cuda(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         use_grouped_topk: bool,
-         top_k: int,
-         router_logits: torch.Tensor,
-         renormalize: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
-     ) -> torch.Tensor:
-
-         if self.use_triton_kernels:
-             return triton_kernel_moe_forward(
-                 hidden_states=x,
-                 w1=layer.w13_weight,
-                 w2=layer.w2_weight,
-                 gating_output=router_logits,
-                 topk=top_k,
-                 renormalize=renormalize,
-             )
-         else:
-             topk_weights, topk_ids = select_experts(
-                 hidden_states=x,
-                 router_logits=router_logits,
-                 use_grouped_topk=use_grouped_topk,
-                 top_k=top_k,
-                 renormalize=renormalize,
-                 topk_group=topk_group,
-                 num_expert_group=num_expert_group,
-                 num_fused_shared_experts=num_fused_shared_experts,
-                 custom_routing_function=custom_routing_function,
-                 correction_bias=correction_bias,
-                 routed_scaling_factor=routed_scaling_factor,
-             )
-
-             if _use_aiter:
-                 assert not no_combine, "unsupported"
-                 if apply_router_weight_on_input:
-                     assert (
-                         topk_weights.dim() == 2
-                     ), "`topk_weights` should be in shape (num_tokens, topk)"
-                     _, topk = topk_weights.shape
-                     assert (
-                         topk == 1
-                     ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-                     x = x * topk_weights.to(x.dtype)
-                     topk_weights = torch.ones_like(
-                         topk_weights, dtype=torch.float32
-                     )  # topk_weights must be FP32 (float32)
-
-                 return fused_moe(
-                     x,
-                     layer.w13_weight,
-                     layer.w2_weight,
-                     topk_weights,
-                     topk_ids,
-                     activation=(
-                         ActivationType.Silu
-                         if activation == "silu"
-                         else ActivationType.Gelu
-                     ),
-                 )
-             else:
-                 return fused_experts(
-                     hidden_states=x,
-                     w1=layer.w13_weight,
-                     w2=layer.w2_weight,
-                     topk_weights=topk_weights,
-                     topk_ids=topk_ids,
-                     inplace=inplace and not no_combine,
-                     activation=activation,
-                     apply_router_weight_on_input=apply_router_weight_on_input,
-                     no_combine=no_combine,
-                     routed_scaling_factor=routed_scaling_factor,
-                 )
-
-     def forward_cpu(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         use_grouped_topk: bool,
-         top_k: int,
-         router_logits: torch.Tensor,
-         renormalize: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
-     ) -> torch.Tensor:
-         assert activation == "silu", f"activation = {activation} is not supported."
-
-         if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
-             topk_weights, topk_ids = select_experts(
-                 hidden_states=x,
-                 router_logits=router_logits,
-                 use_grouped_topk=use_grouped_topk,
-                 top_k=top_k,
-                 renormalize=renormalize,
-                 topk_group=topk_group,
-                 num_expert_group=num_expert_group,
-                 num_fused_shared_experts=num_fused_shared_experts,
-                 custom_routing_function=custom_routing_function,
-                 correction_bias=correction_bias,
-                 routed_scaling_factor=routed_scaling_factor,
-             )
-
-             # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
-             return torch.ops.sgl_kernel.fused_experts_cpu(
-                 x,
-                 layer.w13_weight,
-                 layer.w2_weight,
-                 topk_weights,
-                 topk_ids,
-                 False,  # inplace # See [Note] inplace should be False in fused_experts.
-                 False,  # use_int8_w8a8
-                 False,  # use_fp8_w8a16
-                 None,  # w1_scale
-                 None,  # w2_scale
-                 None,  # block_size
-                 None,  # a1_scale
-                 None,  # a2_scale
-                 True,  # is_vnni
-             )
-         else:
-             return moe_forward_native(
-                 layer,
-                 x,
-                 use_grouped_topk,
-                 top_k,
-                 router_logits,
-                 renormalize,
-                 topk_group,
-                 num_expert_group,
-                 num_fused_shared_experts,
-                 custom_routing_function,
-                 correction_bias,
-                 activation,
-                 apply_router_weight_on_input,
-                 inplace,
-                 no_combine,
-                 routed_scaling_factor,
-             )
-
-     def forward_npu(
-         self,
-         layer: torch.nn.Module,
-         x: torch.Tensor,
-         use_grouped_topk: bool,
-         top_k: int,
-         router_logits: torch.Tensor,
-         renormalize: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
-         activation: str = "silu",
-         apply_router_weight_on_input: bool = False,
-         inplace: bool = True,
-         no_combine: bool = False,
-         routed_scaling_factor: Optional[float] = None,
-     ) -> torch.Tensor:
-         return moe_forward_native(
-             layer,
-             x,
-             use_grouped_topk,
-             top_k,
-             router_logits,
-             renormalize,
-             topk_group,
-             num_expert_group,
-             num_fused_shared_experts,
-             custom_routing_function,
-             correction_bias,
-             activation,
-             apply_router_weight_on_input,
-             inplace,
-             no_combine,
-             routed_scaling_factor,
-         )
-
-     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
-         raise NotImplementedError("The TPU backend currently does not support MoE.")
-
-     forward_native = forward_cpu
-
-
  class FusedMoE(torch.nn.Module):
      """FusedMoE layer for MoE models.

@@ -418,22 +60,15 @@ class FusedMoE(torch.nn.Module):
      def __init__(
          self,
          num_experts: int,
-         top_k: int,
          hidden_size: int,
          intermediate_size: int,
+         top_k: Optional[int] = None,
          layer_id: Optional[int] = None,
          params_dtype: Optional[torch.dtype] = None,
          reduce_results: bool = False,
-         renormalize: bool = True,
-         use_grouped_topk: bool = False,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         topk_group: Optional[int] = None,
          quant_config: Optional[QuantizationConfig] = None,
          tp_size: Optional[int] = None,
          prefix: str = "",
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
          activation: str = "silu",
          apply_router_weight_on_input: bool = False,
          use_presharded_weights: bool = False,
@@ -448,6 +83,7 @@ class FusedMoE(torch.nn.Module):
          if params_dtype is None:
              params_dtype = torch.get_default_dtype()

+         self.top_k = top_k
          self.hidden_size = hidden_size
          self.tp_size = (
              tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
@@ -485,19 +121,9 @@ class FusedMoE(torch.nn.Module):
              self.ep_rank = 0
              self.local_num_experts = num_experts
          self.routed_scaling_factor = routed_scaling_factor
-         self.top_k = top_k
          assert intermediate_size % self.tp_size == 0
          self.intermediate_size_per_partition = intermediate_size // self.tp_size
          self.reduce_results = reduce_results
-         self.renormalize = renormalize
-         self.use_grouped_topk = use_grouped_topk
-         if self.use_grouped_topk:
-             assert num_expert_group is not None and topk_group is not None
-         self.num_expert_group = num_expert_group
-         self.num_fused_shared_experts = num_fused_shared_experts
-         self.topk_group = topk_group
-         self.custom_routing_function = custom_routing_function
-         self.correction_bias = correction_bias
          self.activation = activation
          self.apply_router_weight_on_input = apply_router_weight_on_input
          self.use_presharded_weights = use_presharded_weights
@@ -553,7 +179,7 @@ class FusedMoE(torch.nn.Module):
          shard_dim: int,
          expert_data: torch.Tensor,
          shard_id: str,
-         loaded_weight: torch.tensor,
+         loaded_weight: torch.Tensor,
          tp_rank: int,
      ):
          # Load grouped weight scales for group quantization
@@ -580,7 +206,7 @@ class FusedMoE(torch.nn.Module):
          expert_data: torch.Tensor,
          shard_dim: int,
          shard_id: str,
-         loaded_weight: torch.tensor,
+         loaded_weight: torch.Tensor,
          tp_rank: int,
      ):
          # for per channel weight quantization
@@ -600,7 +226,7 @@ class FusedMoE(torch.nn.Module):
          expert_data: torch.Tensor,
          shard_dim: int,
          shard_id: str,
-         loaded_weight: torch.tensor,
+         loaded_weight: torch.Tensor,
          tp_rank: int,
      ):

@@ -645,7 +271,7 @@ class FusedMoE(torch.nn.Module):
          expert_data: torch.Tensor,
          shard_dim: int,
          shard_id: str,
-         loaded_weight: torch.tensor,
+         loaded_weight: torch.Tensor,
          tp_rank: int,
      ):
          """Load w2 weights for down projection.
@@ -717,7 +343,7 @@ class FusedMoE(torch.nn.Module):
          shard_id: str,
          expert_data: torch.Tensor,
          shard_dim: int,
-         loaded_weight: torch.tensor,
+         loaded_weight: torch.Tensor,
          tp_rank: int,
      ):

@@ -921,22 +547,14 @@ class FusedMoE(torch.nn.Module):
              )
              return

-     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
          assert self.quant_method is not None

          # Matrix multiply.
          final_hidden_states = self.quant_method.apply(
              layer=self,
              x=hidden_states,
-             router_logits=router_logits,
-             top_k=self.top_k,
-             renormalize=self.renormalize,
-             use_grouped_topk=self.use_grouped_topk,
-             topk_group=self.topk_group,
-             num_expert_group=self.num_expert_group,
-             num_fused_shared_experts=self.num_fused_shared_experts,
-             custom_routing_function=self.custom_routing_function,
-             correction_bias=self.correction_bias,
+             topk_output=topk_output,
              activation=self.activation,
              apply_router_weight_on_input=self.apply_router_weight_on_input,
              routed_scaling_factor=self.routed_scaling_factor,
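The headline change in the hunks above (the counts match the +14 -396 entry for sglang/srt/layers/moe/fused_moe_triton/layer.py in the file list) is that FusedMoE no longer performs expert routing itself: FusedMoEMethodBase and UnquantizedFusedMoEMethod leave this module (the latter is now imported from sglang.srt.layers.quantization.unquant), and FusedMoE.forward takes a precomputed TopKOutput instead of raw router_logits plus routing hyperparameters stored on the layer. The lines below are a hypothetical caller-side sketch of that difference, not code from the package: only TopKOutput and the new forward(hidden_states, topk_output) signature come from the diff; the TopK routing helper and its constructor arguments are assumptions, and constructing FusedMoE still requires an initialized sglang tensor-parallel group.

# Hypothetical before/after sketch of the FusedMoE calling convention.
# Assumed names are marked; only TopKOutput and forward(hidden_states, topk_output)
# are confirmed by the diff above.
import torch

from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK  # assumed routing helper in 0.4.9.post4

# 0.4.9.post2: routing configuration lived on the layer and forward took logits.
#   experts = FusedMoE(num_experts=8, top_k=2, hidden_size=4096,
#                      intermediate_size=14336, renormalize=True)
#   out = experts(hidden_states, router_logits)

# 0.4.9.post4: routing is computed first and handed to the layer as a TopKOutput.
topk = TopK(top_k=2, renormalize=True)  # assumed signature
experts = FusedMoE(
    num_experts=8,
    hidden_size=4096,
    intermediate_size=14336,
    top_k=2,
)

hidden_states = torch.randn(4, 4096, dtype=torch.bfloat16)
router_logits = torch.randn(4, 8, dtype=torch.float32)

topk_output = topk(hidden_states, router_logits)  # -> TopKOutput
out = experts(hidden_states, topk_output)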