sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -14,13 +14,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import (
-    FlashInferFusedMoE,
-    FusedMoE,
-    should_use_flashinfer_trtllm_moe,
-)
+from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode
+from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import (
@@ -48,7 +44,6 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
@@ -60,6 +55,22 @@ if _use_aiter:
 logger = logging.getLogger(__name__)
 
 
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
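
For reference, a minimal standalone sketch of the e8m0 cast introduced above: it reinterprets float32 block scales as raw bits and rounds the exponent up whenever nonzero mantissa bits would be dropped (with guards at the exponent extremes). This is a copy of the bit logic without `@torch.compile`; shapes and values are illustrative only, not taken from sglang.

import torch

def cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
    # Same bit manipulation as the helper in the hunk above.
    temp = x.to(torch.float32).view(torch.int32)
    exp = torch.bitwise_right_shift(temp, 23)      # biased exponent bits
    mant = torch.bitwise_and(temp, 0x7FFFFF)       # mantissa bits
    is_ru = torch.logical_and(
        torch.logical_and(mant > 0, exp != 0xFE),  # round up unless at the exponent cap
        ~torch.logical_and(exp == 0, mant <= 0x400000),
    )
    exp = torch.where(is_ru, exp + 1, exp)
    new_x = exp.to(torch.uint8).view(torch.int)    # pack 4 e8m0 bytes per int32
    return new_x.transpose(1, 2).contiguous().transpose(1, 2)

scales = torch.rand(2, 8, 8) + 0.5   # (experts, m, k_blocks); last dim must be a multiple of 4
packed = cast_to_e8m0_with_rounding_up(scales)
print(packed.shape, packed.dtype)    # torch.Size([2, 8, 2]) torch.int32

The multiple-of-4 requirement mirrors the `s_mn % 4 == 0 and s_k % 4 == 0` assertion added further down for the UE8M0 path.
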
@@ -81,6 +92,9 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -96,6 +110,9 @@ class EPMoE(FusedMoE):
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
+            with_bias=with_bias,
         )
 
         self.start_expert_id = self.moe_ep_rank * self.num_local_experts
@@ -203,10 +220,22 @@ class EPMoE(FusedMoE):
 
         dispose_tensor(hidden_states)
 
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = gateup_input_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
         # GroupGemm-0
         gateup_input_fp8 = (
             gateup_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+            (
+                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    gateup_input_scale
+                )
+            ),
         )
         num_groups, m, k = gateup_input_fp8[0].size()
         n = self.w13_weight.size(1)
@@ -214,7 +243,12 @@ class EPMoE(FusedMoE):
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
         )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+            gateup_input_fp8,
+            self.w13_weight_fp8,
+            gateup_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del gateup_input
         del gateup_input_fp8
@@ -245,6 +279,7 @@ class EPMoE(FusedMoE):
             down_input_scale,
             scale_block_size,
             masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del gateup_output
 
@@ -252,13 +287,24 @@ class EPMoE(FusedMoE):
         n = self.w2_weight.size(1)
         down_input_fp8 = (
             down_input,
-            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+            (
+                down_input_scale
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else deep_gemm_wrapper.get_col_major_tma_aligned_tensor(
+                    down_input_scale
+                )
+            ),
         )
         down_output = torch.empty(
             (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
        )
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+            down_input_fp8,
+            self.w2_weight_fp8,
+            down_output,
+            masked_m,
+            expected_m,
+            recipe=(1, 128, 128) if deep_gemm_wrapper.DEEPGEMM_BLACKWELL else None,
         )
         del down_input
         del down_input_fp8
@@ -678,71 +724,29 @@ class DeepEPMoE(EPMoE):
         return down_output
 
 
-class FlashInferEPMoE(EPMoE):
-    def __init__(self, *args, **kwargs):
-        renormalize = kwargs.pop("renormalize", True)
-        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
-        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
-        num_expert_group = kwargs.pop("num_expert_group", None)
-        topk_group = kwargs.pop("topk_group", None)
-        correction_bias = kwargs.pop("correction_bias", None)
-        super().__init__(*args, **kwargs)
-        self.renormalize = renormalize
-        self.num_fused_shared_experts = num_fused_shared_experts
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
-
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert self.use_flashinfer_trtllm_moe
-        assert (
-            self.activation == "silu"
-        ), "Only silu is supported for flashinfer blockscale fp8 moe"
-        assert (
-            self.renormalize
-        ), "Renormalize is required for flashinfer blockscale fp8 moe"
-        assert (
-            self.num_fused_shared_experts == 0
-        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
-        a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
-        # NOTE: scales of hidden states have to be transposed!
-        a_sf_t = a_sf.t().contiguous()
-        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
-
-        return trtllm_fp8_block_scale_moe(
-            routing_logits=router_logits.to(torch.float32),
-            routing_bias=self.correction_bias.to(hidden_states.dtype),
-            hidden_states=a_q,
-            hidden_states_scale=a_sf_t,
-            gemm1_weights=self.w13_weight,
-            gemm1_weights_scale=self.w13_weight_scale_inv,
-            gemm2_weights=self.w2_weight,
-            gemm2_weights_scale=self.w2_weight_scale_inv,
-            num_experts=self.num_experts,
-            top_k=self.top_k,
-            n_group=self.num_expert_group,
-            topk_group=self.topk_group,
-            intermediate_size=self.w2_weight.shape[2],
-            local_expert_offset=self.start_expert_id,
-            local_num_experts=self.num_local_experts,
-            routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=get_tile_tokens_dim(
-                hidden_states.shape[0], self.top_k, self.num_experts
-            ),
-            routing_method_type=2,  # DeepSeek-styled routing method
-            use_shuffled_weight=False,
-        )
-
-
 def get_moe_impl_class():
     if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
+
+    # NEW: Direct FP4 detection (bypasses EP requirements)
+    # Check for FP4 quantization with TRTLLM flag, regardless of EP
+    if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
+        try:
+            # Check the quantization argument directly
+            quantization = global_server_args_dict.get("quantization")
+            if quantization == "modelopt_fp4":
+                from sglang.srt.layers.moe.fused_moe_triton.layer import (
+                    FlashInferFP4MoE,
+                )
+
+                return FlashInferFP4MoE
+        except:
+            pass
+
+    if should_use_flashinfer_trtllm_moe():
+        return FlashInferFusedMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
-        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
-    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
+        return EPMoE
+    return FusedMoE
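
The rewritten `get_moe_impl_class` drops the `FlashInferEPMoE` wrapper and moves the FlashInfer TRTLLM checks ahead of the cutlass and expert-parallel checks. A reading aid only, with the server-args lookups reduced to plain parameters and `should_use_flashinfer_trtllm_moe()` approximated by a single boolean flag (this is not sglang code):

from typing import Optional

def select_moe_impl(
    use_deepep: bool,
    trtllm_moe: bool,                  # stand-in for the TRTLLM flag / helper
    quantization: Optional[str],
    cutlass_moe: bool,
    ep_world_size: int,
) -> str:
    if use_deepep:
        return "DeepEPMoE"
    if trtllm_moe and quantization == "modelopt_fp4":
        return "FlashInferFP4MoE"      # new FP4 path, checked regardless of EP size
    if trtllm_moe:
        return "FlashInferFusedMoE"
    if cutlass_moe:
        return "FusedMoE"
    if ep_world_size > 1:
        return "EPMoE"                 # FlashInferEPMoE no longer exists
    return "FusedMoE"

# With EP enabled and modelopt_fp4, the FP4 class now wins over EPMoE:
assert select_moe_impl(False, True, "modelopt_fp4", False, 8) == "FlashInferFP4MoE"
assert select_moe_impl(False, False, None, False, 8) == "EPMoE"
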
@@ -319,6 +319,7 @@ def fused_moe_kernel(
     # Pointers to matrices
     a_ptr,
     b_ptr,
+    bias_ptr,
     c_ptr,
     a_scale_ptr,
     b_scale_ptr,
@@ -340,6 +341,8 @@ def fused_moe_kernel(
     stride_be,
     stride_bk,
     stride_bn,
+    stride_bias_e,
+    stride_bias_n,
     stride_cm,
     stride_cn,
     stride_asm,
@@ -449,6 +452,10 @@ def fused_moe_kernel(
         + off_experts * stride_be
         + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     )
+    if bias_ptr is not None:
+        bias = tl.load(
+            bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
+        )
     if use_int8_w8a16:
         b_scale_ptrs = (
             b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
@@ -526,18 +533,20 @@ def fused_moe_kernel(
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
-    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-        accumulator = accumulator * moe_weight[:, None]
     if use_int8_w8a16:
-        accumulator = (accumulator * b_scale).to(compute_type)
+        accumulator *= b_scale
     elif use_fp8_w8a8 or use_int8_w8a8:
-        if group_k > 0 and group_n > 0:
-            accumulator = accumulator.to(compute_type)
-        else:
-            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
-    else:
-        accumulator = accumulator.to(compute_type)
+        if group_k == 0 or group_n == 0:
+            accumulator *= a_scale * b_scale
+
+    if bias_ptr is not None:
+        accumulator += bias
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator *= moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
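
In the reordered epilogue above, the optional per-expert bias is added after the dequantization scaling and before the routed top-k weight multiply, so the routed weight scales the bias term as well. A dense PyTorch reference of that ordering for the unquantized path; shapes and names below are illustrative, not from sglang:

import torch

tokens, hidden, inter, num_experts = 4, 16, 32, 2
a = torch.randn(tokens, hidden)                     # activations for this block
w = torch.randn(num_experts, inter, hidden)         # expert weights, laid out (E, N, K)
bias = torch.randn(num_experts, inter)              # new optional per-expert bias, (E, N)
expert_id = torch.randint(num_experts, (tokens,))   # expert chosen per routed token
moe_weight = torch.rand(tokens)                     # routed top-k weight per token

acc = torch.einsum("tk,tnk->tn", a, w[expert_id])   # the kernel's accumulator block
acc = acc + bias[expert_id]                         # bias before MUL_ROUTED_WEIGHT
out = acc * moe_weight[:, None]                     # routed weight scales the bias too
print(out.shape)                                    # torch.Size([4, 32])
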
@@ -622,6 +631,7 @@ def moe_align_block_size(
 def invoke_fused_moe_kernel(
     A: torch.Tensor,
     B: torch.Tensor,
+    bias: Optional[torch.Tensor],
     C: torch.Tensor,
     A_scale: Optional[torch.Tensor],
     B_scale: Optional[torch.Tensor],
@@ -711,6 +721,7 @@ def invoke_fused_moe_kernel(
     ):
         assert B_scale is not None and B_scale.ndim == 3
         assert B_zp is None or B_zp.ndim == 3
+        assert bias is None
         fused_moe_kernel_gptq_awq[grid](
             A,
             B,
@@ -754,6 +765,7 @@ def invoke_fused_moe_kernel(
         fused_moe_kernel[grid](
             A,
             B,
+            bias,
             C,
             A_scale,
             B_scale,
@@ -770,6 +782,8 @@ def invoke_fused_moe_kernel(
             B.stride(0),
             B.stride(2),
             B.stride(1),
+            bias.stride(0) if bias is not None else 0,
+            bias.stride(1) if bias is not None else 0,
             C.stride(1),
             C.stride(2),
             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
994
1008
  w2: torch.Tensor,
995
1009
  topk_weights: torch.Tensor,
996
1010
  topk_ids: torch.Tensor,
1011
+ b1: Optional[torch.Tensor] = None,
1012
+ b2: Optional[torch.Tensor] = None,
997
1013
  activation: str = "silu",
998
1014
  apply_router_weight_on_input: bool = False,
999
1015
  use_fp8_w8a8: bool = False,
@@ -1009,6 +1025,8 @@ def inplace_fused_experts(
1009
1025
  a2_scale: Optional[torch.Tensor] = None,
1010
1026
  block_shape: Optional[List[int]] = None,
1011
1027
  routed_scaling_factor: Optional[float] = None,
1028
+ activation_alpha: Optional[float] = None,
1029
+ swiglu_limit: Optional[float] = None,
1012
1030
  ) -> None:
1013
1031
  fused_experts_impl(
1014
1032
  hidden_states,
@@ -1016,6 +1034,8 @@ def inplace_fused_experts(
1016
1034
  w2,
1017
1035
  topk_weights,
1018
1036
  topk_ids,
1037
+ b1,
1038
+ b2,
1019
1039
  True,
1020
1040
  activation,
1021
1041
  apply_router_weight_on_input,
@@ -1033,6 +1053,8 @@ def inplace_fused_experts(
1033
1053
  block_shape,
1034
1054
  False,
1035
1055
  routed_scaling_factor,
1056
+ activation_alpha,
1057
+ swiglu_limit,
1036
1058
  )
1037
1059
 
1038
1060
 
@@ -1042,6 +1064,8 @@ def inplace_fused_experts_fake(
1042
1064
  w2: torch.Tensor,
1043
1065
  topk_weights: torch.Tensor,
1044
1066
  topk_ids: torch.Tensor,
1067
+ b1: Optional[torch.Tensor] = None,
1068
+ b2: Optional[torch.Tensor] = None,
1045
1069
  activation: str = "silu",
1046
1070
  apply_router_weight_on_input: bool = False,
1047
1071
  use_fp8_w8a8: bool = False,
@@ -1057,6 +1081,8 @@ def inplace_fused_experts_fake(
1057
1081
  a2_scale: Optional[torch.Tensor] = None,
1058
1082
  block_shape: Optional[List[int]] = None,
1059
1083
  routed_scaling_factor: Optional[float] = None,
1084
+ activation_alpha: Optional[float] = None,
1085
+ swiglu_limit: Optional[float] = None,
1060
1086
  ) -> None:
1061
1087
  pass
1062
1088
 
@@ -1075,6 +1101,8 @@ def outplace_fused_experts(
1075
1101
  w2: torch.Tensor,
1076
1102
  topk_weights: torch.Tensor,
1077
1103
  topk_ids: torch.Tensor,
1104
+ b1: Optional[torch.Tensor] = None,
1105
+ b2: Optional[torch.Tensor] = None,
1078
1106
  activation: str = "silu",
1079
1107
  apply_router_weight_on_input: bool = False,
1080
1108
  use_fp8_w8a8: bool = False,
@@ -1091,6 +1119,8 @@ def outplace_fused_experts(
1091
1119
  block_shape: Optional[List[int]] = None,
1092
1120
  no_combine: bool = False,
1093
1121
  routed_scaling_factor: Optional[float] = None,
1122
+ activation_alpha: Optional[float] = None,
1123
+ swiglu_limit: Optional[float] = None,
1094
1124
  ) -> torch.Tensor:
1095
1125
  return fused_experts_impl(
1096
1126
  hidden_states,
@@ -1098,6 +1128,8 @@ def outplace_fused_experts(
1098
1128
  w2,
1099
1129
  topk_weights,
1100
1130
  topk_ids,
1131
+ b1,
1132
+ b2,
1101
1133
  False,
1102
1134
  activation,
1103
1135
  apply_router_weight_on_input,
@@ -1115,6 +1147,8 @@ def outplace_fused_experts(
1115
1147
  block_shape,
1116
1148
  no_combine=no_combine,
1117
1149
  routed_scaling_factor=routed_scaling_factor,
1150
+ activation_alpha=activation_alpha,
1151
+ swiglu_limit=swiglu_limit,
1118
1152
  )
1119
1153
 
1120
1154
 
@@ -1124,6 +1158,8 @@ def outplace_fused_experts_fake(
1124
1158
  w2: torch.Tensor,
1125
1159
  topk_weights: torch.Tensor,
1126
1160
  topk_ids: torch.Tensor,
1161
+ b1: Optional[torch.Tensor] = None,
1162
+ b2: Optional[torch.Tensor] = None,
1127
1163
  activation: str = "silu",
1128
1164
  apply_router_weight_on_input: bool = False,
1129
1165
  use_fp8_w8a8: bool = False,
@@ -1140,6 +1176,8 @@ def outplace_fused_experts_fake(
1140
1176
  block_shape: Optional[List[int]] = None,
1141
1177
  no_combine: bool = False,
1142
1178
  routed_scaling_factor: Optional[float] = None,
1179
+ activation_alpha: Optional[float] = None,
1180
+ swiglu_limit: Optional[float] = None,
1143
1181
  ) -> torch.Tensor:
1144
1182
  return torch.empty_like(hidden_states)
1145
1183
 
@@ -1157,6 +1195,8 @@ def fused_experts(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1174,6 +1214,8 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
     if inplace:
@@ -1184,6 +1226,8 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            b1,
+            b2,
             activation,
             apply_router_weight_on_input,
             use_fp8_w8a8,
@@ -1199,6 +1243,8 @@ def fused_experts(
             a2_scale,
             block_shape,
             routed_scaling_factor,
+            activation_alpha,
+            swiglu_limit,
         )
         return hidden_states
     else:
@@ -1208,6 +1254,8 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            b1,
+            b2,
             activation,
             apply_router_weight_on_input,
             use_fp8_w8a8,
@@ -1224,6 +1272,8 @@ def fused_experts(
             block_shape,
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
         )
 
 
@@ -1319,12 +1369,22 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
     out.mul_(routed_scaling_factor)
 
 
+@torch.compile
+def swiglu_with_alpha_and_limit(x, alpha, limit):
+    gate, up = x[..., ::2], x[..., 1::2]
+    gate = gate.clamp(min=None, max=limit)
+    up = up.clamp(min=-limit, max=limit)
+    return gate * torch.sigmoid(gate * alpha) * (up + 1)
+
+
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
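
The new `swiglu_with_alpha_and_limit` helper expects the gate and up projections interleaved along the last dimension (even indices gate, odd indices up) and applies a clamped, alpha-scaled SiLU gate to `up + 1`. A standalone copy without `@torch.compile`; the alpha and limit values below are illustrative, not defaults taken from the diff:

import torch

def swiglu_with_alpha_and_limit(x, alpha, limit):
    gate, up = x[..., ::2], x[..., 1::2]        # interleaved gate/up on the last dim
    gate = gate.clamp(min=None, max=limit)      # clamp gate from above only
    up = up.clamp(min=-limit, max=limit)        # clamp up symmetrically
    return gate * torch.sigmoid(gate * alpha) * (up + 1)

x = torch.randn(4, 2 * 32)                      # (tokens, 2 * intermediate_size)
y = swiglu_with_alpha_and_limit(x, alpha=1.702, limit=7.0)
print(y.shape)                                  # torch.Size([4, 32])
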
@@ -1342,6 +1402,8 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ):
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -1353,7 +1415,7 @@ def fused_experts_impl(
     else:
         assert (
             hidden_states.shape[1] == w1.shape[2] - padded_size
-        ), "Hidden size mismatch"
+        ), f"Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -1449,6 +1511,7 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             curr_hidden_states,
             w1,
+            b1,
             intermediate_cache1,
             a1_scale,
             w1_scale,
@@ -1470,13 +1533,24 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
         if activation == "silu":
-            if _is_cuda:
+            if activation_alpha is not None:
+                assert swiglu_limit is not None
+                intermediate_cache2 = swiglu_with_alpha_and_limit(
+                    intermediate_cache1.view(-1, N),
+                    activation_alpha,
+                    swiglu_limit,
+                )
+            elif _is_cuda:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.silu_and_mul(
                     intermediate_cache2, intermediate_cache1.view(-1, N)
                 )
         elif activation == "gelu":
+            assert (
+                activation_alpha is None
+            ), "activation_alpha is not supported for gelu"
+            assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
             if _is_cuda:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
@@ -1489,6 +1563,7 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             intermediate_cache2,
             w2,
+            b2,
             (
                 intermediate_cache3
                 if not no_combine and topk_ids.shape[1] != 1
@@ -1567,6 +1642,8 @@ def fused_moe(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1584,6 +1661,8 @@ def fused_moe(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1594,6 +1673,8 @@ def fused_moe(
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
     - topk_output (TopKOutput): The top-k output of the experts.
+    - b1 (Optional[torch.Tensor]): Optional bias for w1.
+    - b2 (Optional[torch.Tensor]): Optional bias for w2.
     - inplace (bool): If True, perform the operation in-place.
       Defaults to False.
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
@@ -1615,6 +1696,10 @@ def fused_moe(
       a2.
     - block_shape: (Optional[List[int]]): Optional block size for block-wise
       quantization.
+    - activation_alpha (Optional[float]): Optional alpha for the activation
+      function.
+    - swiglu_limit (Optional[float]): Optional limit for the swiglu activation
+      function.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -1625,6 +1710,8 @@ def fused_moe(
         w1,
         w2,
         topk_output,
+        b1=b1,
+        b2=b2,
         inplace=inplace,
         activation=activation,
         apply_router_weight_on_input=apply_router_weight_on_input,
@@ -1642,4 +1729,6 @@ def fused_moe(
         block_shape=block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
     )