sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/unquant.py
@@ -0,0 +1,422 @@
+from __future__ import annotations
+
+import importlib
+from typing import TYPE_CHECKING, Callable, List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
+    QuantizeMethodBase,
+)
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_hip,
+    set_weight_attrs,
+    use_intel_amx_backend,
+)
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
+has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+
+
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_hip = is_hip()
+_is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter import ActivationType
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.shuffle import shuffle_weight
+
+
+class UnquantizedEmbeddingMethod(QuantizeMethodBase):
+    """Unquantized method for embeddings."""
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        """Create weights for embedding layer."""
+        weight = Parameter(
+            torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return F.linear(x, layer.weight, bias)
+
+    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
+        return F.embedding(input_, layer.weight)
+
+
+class UnquantizedLinearMethod(LinearMethodBase):
+    """Linear method without quantization."""
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        weight = Parameter(
+            torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu and _is_cpu_amx_available:
+            _amx_process_weight_after_loading(layer, ["weight"])
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.weight_packed_linear(
+                x, layer.weight, bias, True  # is_vnni
+            )
+
+        return F.linear(x, layer.weight, bias)
+
+
+class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
+    """MoE method without quantization."""
+
+    def __init__(self, use_triton_kernels: bool = False):
+        super().__init__()
+        self.use_triton_kernels = use_triton_kernels
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        # Fused gate_up_proj (column parallel)
+        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        if self.use_triton_kernels:
+            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
+        w13_weight = torch.nn.Parameter(
+            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight_n, w2_weight_k = (
+            hidden_size,
+            intermediate_size,
+        )
+        if self.use_triton_kernels:
+            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
+        w2_weight = torch.nn.Parameter(
+            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _use_aiter:
+            layer.w13_weight = torch.nn.Parameter(
+                shuffle_weight(layer.w13_weight.data, (16, 16)),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                shuffle_weight(layer.w2_weight.data, (16, 16)),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+
+        # Pack weights to get better performance on CPU
+        if _is_cpu and _is_cpu_amx_available:
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
+        return
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        *,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        return self.forward(
+            x=x,
+            layer=layer,
+            topk_output=topk_output,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            inplace=inplace,
+            no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+    def forward_cuda(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        *,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+
+        if self.use_triton_kernels:
+            # TODO(ch-wan): re-enable the Triton kernel
+            raise NotImplementedError("The Triton kernel is temporarily disabled.")
+            # return triton_kernel_moe_forward(
+            #     hidden_states=x,
+            #     w1=layer.w13_weight,
+            #     w2=layer.w2_weight,
+            #     gating_output=router_logits,
+            #     topk=top_k,
+            #     renormalize=renormalize,
+            # )
+        else:
+            if _use_aiter:
+                assert not no_combine, "unsupported"
+                topk_weights, topk_ids, _ = topk_output
+                if apply_router_weight_on_input:
+                    assert (
+                        topk_weights.dim() == 2
+                    ), "`topk_weights` should be in shape (num_tokens, topk)"
+                    _, topk = topk_weights.shape
+                    assert (
+                        topk == 1
+                    ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                    x = x * topk_weights.to(x.dtype)
+                    topk_weights = torch.ones_like(
+                        topk_weights, dtype=torch.float32
+                    )  # topk_weights must be FP32 (float32)
+                return fused_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+            else:
+                from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+                    fused_experts,
+                )
+
+                return fused_experts(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_output=topk_output,
+                    inplace=inplace and not no_combine,
+                    activation=activation,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                    no_combine=no_combine,
+                    routed_scaling_factor=routed_scaling_factor,
+                )
+
+    def forward_cpu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        *,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
+            from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
+
+            topk_weights, topk_ids, _ = topk_output
+            x, topk_weights = apply_topk_weights_cpu(
+                apply_router_weight_on_input, topk_weights, x
+            )
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace # See [Note] inplace should be False in fused_experts.
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
+
+            return moe_forward_native(
+                layer,
+                x,
+                topk_output,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                inplace=inplace,
+                no_combine=no_combine,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+    def forward_npu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        *,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
+
+        return moe_forward_native(
+            layer,
+            x,
+            topk_output,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            inplace=inplace,
+            no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+    def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError("The TPU backend currently does not support MoE.")
+
+    forward_native = forward_cpu
+
+
+class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts_per_partition: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        # Fused gate_up_proj (column parallel)
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts_per_partition,
+                2 * intermediate_size,
+                hidden_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts_per_partition,
+                hidden_size,
+                intermediate_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # scale
+        layer.register_parameter("w13_input_scale", None)
+        layer.register_parameter("w13_weight_scale", None)
+
+        ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
+
+        w2_input_scale = torch.nn.Parameter(
+            ones_tensor,
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_input_scale", w2_input_scale)
+        set_weight_attrs(w2_input_scale, extra_weight_attrs)
+
+        w2_weight_scale = torch.nn.Parameter(
+            ones_tensor,
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
+    ) -> torch.Tensor:
+        raise NotImplementedError
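
For orientation, the classes in this new unquant.py module follow the quantize-method lifecycle visible above: create_weights registers plain torch parameters on a host layer, process_weights_after_loading optionally repacks them (AMX packing on supported CPUs, aiter shuffling under SGLANG_USE_AITER on ROCm), and apply runs the forward pass. Below is a minimal sketch of that flow for UnquantizedLinearMethod; the bare torch.nn.Module host layer and the concrete sizes are illustrative stand-ins, not how sglang wires these methods internally.

import torch
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod

layer = torch.nn.Module()  # hypothetical stand-in for one of sglang's linear layers
method = UnquantizedLinearMethod()

# 1) Allocate and register the unquantized weight (out_features x in_features).
method.create_weights(
    layer,
    input_size_per_partition=128,
    output_partition_sizes=[256],
    input_size=128,
    output_size=256,
    params_dtype=torch.float16,
)

# 2) Optional post-load repacking; a no-op except on a CPU with AMX support.
method.process_weights_after_loading(layer)

# 3) Forward pass; falls through to F.linear when the AMX backend is not in use.
x = torch.randn(4, 128, dtype=torch.float16)
y = method.apply(layer, x)  # -> shape (4, 256)

The MoE variants follow the same shape: UnquantizedFusedMoEMethod.apply delegates to self.forward, which CustomOp routes to the platform-specific forward_cuda/forward_cpu/forward_npu implementations shown in the diff.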