sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import importlib
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -12,23 +13,33 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
     is_hip,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
+has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
+
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+    if has_triton_kernels:
+        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+            triton_kernel_moe_forward,
+        )
 else:
     fused_experts = None  # type: ignore
 
@@ -85,6 +96,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
+    def __init__(self, use_triton_kernels: bool = False):
+        super().__init__()
+        self.use_triton_kernels = use_triton_kernels
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -95,20 +110,25 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         **extra_weight_attrs,
     ):
         # Fused gate_up_proj (column parallel)
+        w13_weight_n, w13_weight_k = 2 * intermediate_size, hidden_size
+        if self.use_triton_kernels:
+            w13_weight_n, w13_weight_k = w13_weight_k, w13_weight_n
         w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w13_weight_n, w13_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         # down_proj (row parallel)
+        w2_weight_n, w2_weight_k = (
+            hidden_size,
+            intermediate_size,
+        )
+        if self.use_triton_kernels:
+            w2_weight_n, w2_weight_k = w2_weight_k, w2_weight_n
         w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts, hidden_size, intermediate_size, dtype=params_dtype
-            ),
+            torch.empty(num_experts, w2_weight_n, w2_weight_k, dtype=params_dtype),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
@@ -129,7 +149,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         # Pack weight for get better performance on CPU
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
 
         return
 
@@ -190,59 +210,72 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
-        if _use_aiter:
-            assert not no_combine, "unsupported"
-            if apply_router_weight_on_input:
-                assert (
-                    topk_weights.dim() == 2
-                ), "`topk_weights` should be in shape (num_tokens, topk)"
-                _, topk = topk_weights.shape
-                assert (
-                    topk == 1
-                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
-                x = x * topk_weights.to(x.dtype)
-                topk_weights = torch.ones_like(
-                    topk_weights, dtype=torch.float32
-                )  # topk_weights must be FP32 (float32)
-
-            return fused_moe(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-        else:
-            return fused_experts(
+        if self.use_triton_kernels:
+            return triton_kernel_moe_forward(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                no_combine=no_combine,
+                gating_output=router_logits,
+                topk=top_k,
+                renormalize=renormalize,
+            )
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
                 routed_scaling_factor=routed_scaling_factor,
             )
 
+            if _use_aiter:
+                assert not no_combine, "unsupported"
+                if apply_router_weight_on_input:
+                    assert (
+                        topk_weights.dim() == 2
+                    ), "`topk_weights` should be in shape (num_tokens, topk)"
+                    _, topk = topk_weights.shape
+                    assert (
+                        topk == 1
+                    ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                    x = x * topk_weights.to(x.dtype)
+                    topk_weights = torch.ones_like(
+                        topk_weights, dtype=torch.float32
+                    )  # topk_weights must be FP32 (float32)
+
+                return fused_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    activation=(
+                        ActivationType.Silu
+                        if activation == "silu"
+                        else ActivationType.Gelu
+                    ),
+                )
+            else:
+                return fused_experts(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    inplace=inplace and not no_combine,
+                    activation=activation,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                    no_combine=no_combine,
+                    routed_scaling_factor=routed_scaling_factor,
+                )
+
     def forward_cpu(
         self,
         layer: torch.nn.Module,
@@ -264,10 +297,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."
 
-        if (
-            getattr(layer, "use_intel_amx_backend", False)
-            and not apply_router_weight_on_input
-        ):
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
@@ -287,11 +317,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
-                topk_weights.to(
-                    torch.float
-                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+                topk_weights,
                 topk_ids,
-                True,  # inplace
+                False,  # inplace # See [Note] inplace should be False in fused_experts.
                 False,  # use_int8_w8a8
                 False,  # use_fp8_w8a16
                 None,  # w1_scale
@@ -321,6 +349,44 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             routed_scaling_factor,
         )
 
+    def forward_npu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        return moe_forward_native(
+            layer,
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            num_fused_shared_experts,
+            custom_routing_function,
+            correction_bias,
+            activation,
+            apply_router_weight_on_input,
+            inplace,
+            no_combine,
+            routed_scaling_factor,
+        )
+
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
@@ -438,9 +504,13 @@ class FusedMoE(torch.nn.Module):
         self.inplace = inplace
         self.no_combine = no_combine
 
+        self.use_triton_kernels = (
+            not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
+        )
+
         if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
+                self.use_triton_kernels
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
@@ -537,11 +607,6 @@ class FusedMoE(torch.nn.Module):
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
 
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
-            )
-
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
@@ -552,7 +617,26 @@ class FusedMoE(torch.nn.Module):
             start = shard_size
         else:
             start = 0
-        expert_data = expert_data.narrow(shard_dim, start, shard_size)
+
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                start,
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
+            )
+        else:
+            if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
+
+            expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
@@ -563,16 +647,54 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.tensor,
         tp_rank: int,
     ):
+        """Load w2 weights for down projection.
+
+        Args:
+            expert_data: The expert data tensor to load into
+            shard_dim: The dimension to shard along
+            shard_id: The shard ID (must be "w2")
+            loaded_weight: The weight tensor to load from
+            tp_rank: The tensor parallel rank
+        """
+        if not isinstance(expert_data, torch.Tensor) or not isinstance(
+            loaded_weight, torch.Tensor
+        ):
+            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
+
+        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+            raise ValueError(
+                f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
+            )
+
+        if shard_id != "w2":
+            raise ValueError(f"shard_id must be 'w2', got {shard_id}")
 
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
 
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
             )
+        else:
+            if not self.use_presharded_weights:
+                if self.use_triton_kernels:
+                    loaded_weight = loaded_weight.transpose(-2, -1)
+                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
+                    raise ValueError(
+                        f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
+                    )
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
 
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
@@ -656,6 +778,8 @@ class FusedMoE(torch.nn.Module):
         # should be whatever dimension intermediate_size is
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
+        if self.use_triton_kernels:
+            is_transposed = True
         if is_transposed:
             shard_dim = int(not shard_dim)
 
@@ -694,8 +818,21 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
+
         if "ModelOpt" in self.quant_method.__class__.__name__:
-            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+            # Determine per-tensor weight scale patterns based on variant
+            is_fp4_variant = (
+                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
+            )
+
+            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
+            per_tensor_conditions = (
+                "weight_scale_2" in weight_name
+                if is_fp4_variant
+                else "weight_scale" in weight_name
+            ) or "input_scale" in weight_name
+
+            if per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
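The core change in this file is that UnquantizedFusedMoEMethod now takes a use_triton_kernels flag (driven by the enable_triton_kernel_moe entry in global_server_args_dict) and, when it is set, allocates the expert weights in transposed layout and routes forward_cuda through triton_kernel_moe_forward. A minimal sketch of the shape bookkeeping, using invented sizes purely for illustration:

import torch

# Hypothetical sizes, chosen only for this example.
num_experts, hidden_size, intermediate_size = 4, 1024, 2048
use_triton_kernels = True  # mirrors UnquantizedFusedMoEMethod.use_triton_kernels

# gate_up_proj (w13): (E, 2*I, H) by default, swapped to (E, H, 2*I) on the triton_kernels path.
w13_n, w13_k = 2 * intermediate_size, hidden_size
# down_proj (w2): (E, H, I) by default, swapped to (E, I, H) on the triton_kernels path.
w2_n, w2_k = hidden_size, intermediate_size
if use_triton_kernels:
    w13_n, w13_k = w13_k, w13_n
    w2_n, w2_k = w2_k, w2_n

w13_weight = torch.empty(num_experts, w13_n, w13_k, dtype=torch.bfloat16)
w2_weight = torch.empty(num_experts, w2_n, w2_k, dtype=torch.bfloat16)
print(w13_weight.shape)  # torch.Size([4, 1024, 4096])
print(w2_weight.shape)   # torch.Size([4, 2048, 1024])

The weight loaders (_load_w13, _load_w2) compensate for this by transposing the checkpoint tensors before narrowing, which is why the diff also flips is_transposed when use_triton_kernels is set.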
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py (new file)

@@ -0,0 +1,176 @@
+# Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
+from typing import Optional
+
+import torch
+from sgl_kernel import gelu_and_mul, silu_and_mul
+from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+
+from sglang.srt.utils import direct_register_custom_op
+
+
+def triton_kernel_moe_forward(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    if not renormalize:
+        gating_output = torch.softmax(gating_output, dim=-1)
+    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+
+    return triton_kernel_fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        routing_data,
+        gather_idx,
+        scatter_idx,
+        inplace=inplace,
+        activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+    )
+
+
+# This is a triton implementation of the fused_experts function
+def triton_kernel_fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    routing_data: RoutingData,
+    gather_indx: GatherIndx,
+    scatter_indx: ScatterIndx,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant == False, "per_channel_quant is not supported"
+    assert expert_map == None, "expert_map is not supported"
+    assert w1_scale == None, "w1_scale is not supported"
+    assert w2_scale == None, "w2_scale is not supported"
+    assert a1_scale == None, "a1_scale is not supported"
+    assert a2_scale == None, "a2_scale is not supported"
+    assert block_shape == None, "block_shape is not supported"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
+    assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        w2.shape[-1] == w1.shape[1]
+    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
+
+    # feature check
+    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+    M, K = hidden_states.shape
+    E, _, N = w1.shape
+    n_expts_act = routing_data.n_expts_act
+    dtype = hidden_states.dtype
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    # consistent with default implementation
+    intermediate_cache2 = torch.empty(
+        (M * n_expts_act, N // 2), device="cuda", dtype=dtype
+    )
+
+    intermediate_cache1 = matmul_ogs(
+        hidden_states,
+        w1,
+        None,
+        routing_data,
+        gather_indx=gather_indx,
+        gammas=routing_data.gate_scal if apply_router_weight_on_input else None,
+    )
+
+    if activation == "silu":
+        silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    elif activation == "gelu":
+        gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
+    else:
+        raise ValueError(f"Unsupported FusedMoe activation: {activation}")
+
+    intermediate_cache3 = matmul_ogs(
+        intermediate_cache2,
+        w2,
+        None,
+        routing_data,
+        scatter_indx=scatter_indx,
+        gammas=None if apply_router_weight_on_input else routing_data.gate_scal,
+    )
+
+    return intermediate_cache3
+
+
+def triton_kernel_moe_forward_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    inplace: bool = False,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    op_name="forward_cuda_triton",
+    op_func=triton_kernel_moe_forward,
+    mutates_args=[],
+    fake_impl=triton_kernel_moe_forward_fake,
+)
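For orientation, here is a rough usage sketch of the new triton_kernel_moe_forward entry point. It assumes a CUDA device, the optional triton_kernels package, bfloat16 tensors, and expert weights already stored in the transposed (E, K, N) layout that the file's shape asserts expect; all sizes below are invented for the example.

import torch
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
    triton_kernel_moe_forward,
)

# Invented sizes: 8 experts, hidden 512, intermediate 1024, 16 tokens, top-2 routing.
E, H, I, M, topk = 8, 512, 1024, 16, 2
x = torch.randn(M, H, device="cuda", dtype=torch.bfloat16)
w13 = torch.randn(E, H, 2 * I, device="cuda", dtype=torch.bfloat16)  # transposed w13
w2 = torch.randn(E, I, H, device="cuda", dtype=torch.bfloat16)       # transposed w2
router_logits = torch.randn(M, E, device="cuda", dtype=torch.bfloat16)

out = triton_kernel_moe_forward(
    hidden_states=x,
    w1=w13,
    w2=w2,
    gating_output=router_logits,
    topk=topk,
    renormalize=True,
)
print(out.shape)  # expected: torch.Size([16, 512])

The function routes tokens with triton_kernels.routing, runs the two matmul_ogs calls for gate_up_proj and down_proj, and applies the SiLU/GELU-and-mul activation from sgl_kernel in between; only the unquantized bfloat16 path is supported in this release.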