sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (175)
  1. sglang/__init__.py +8 -3
  2. sglang/bench_one_batch.py +119 -17
  3. sglang/lang/chat_template.py +18 -0
  4. sglang/srt/bench_utils.py +137 -0
  5. sglang/srt/configs/model_config.py +42 -7
  6. sglang/srt/conversation.py +9 -5
  7. sglang/srt/disaggregation/base/conn.py +5 -2
  8. sglang/srt/disaggregation/decode.py +14 -4
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  10. sglang/srt/disaggregation/mooncake/conn.py +286 -160
  11. sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
  12. sglang/srt/disaggregation/prefill.py +2 -0
  13. sglang/srt/distributed/parallel_state.py +15 -11
  14. sglang/srt/entrypoints/context.py +227 -0
  15. sglang/srt/entrypoints/engine.py +15 -9
  16. sglang/srt/entrypoints/harmony_utils.py +372 -0
  17. sglang/srt/entrypoints/http_server.py +74 -4
  18. sglang/srt/entrypoints/openai/protocol.py +218 -1
  19. sglang/srt/entrypoints/openai/serving_chat.py +41 -11
  20. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  21. sglang/srt/entrypoints/openai/tool_server.py +175 -0
  22. sglang/srt/entrypoints/tool.py +87 -0
  23. sglang/srt/eplb/expert_location.py +5 -1
  24. sglang/srt/function_call/ebnf_composer.py +1 -0
  25. sglang/srt/function_call/function_call_parser.py +2 -0
  26. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  27. sglang/srt/function_call/gpt_oss_detector.py +331 -0
  28. sglang/srt/function_call/kimik2_detector.py +3 -3
  29. sglang/srt/function_call/qwen3_coder_detector.py +219 -9
  30. sglang/srt/hf_transformers_utils.py +30 -3
  31. sglang/srt/jinja_template_utils.py +14 -1
  32. sglang/srt/layers/attention/aiter_backend.py +375 -115
  33. sglang/srt/layers/attention/ascend_backend.py +3 -0
  34. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +18 -0
  36. sglang/srt/layers/attention/flashinfer_backend.py +52 -13
  37. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  38. sglang/srt/layers/attention/triton_backend.py +85 -14
  39. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  41. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  42. sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
  43. sglang/srt/layers/attention/vision.py +22 -6
  44. sglang/srt/layers/attention/wave_backend.py +627 -0
  45. sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
  46. sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
  47. sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
  48. sglang/srt/layers/communicator.py +29 -14
  49. sglang/srt/layers/dp_attention.py +12 -0
  50. sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
  51. sglang/srt/layers/linear.py +3 -7
  52. sglang/srt/layers/moe/cutlass_moe.py +12 -3
  53. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
  54. sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
  55. sglang/srt/layers/moe/ep_moe/layer.py +135 -73
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  59. sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
  60. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  61. sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
  62. sglang/srt/layers/moe/topk.py +16 -4
  63. sglang/srt/layers/moe/utils.py +16 -0
  64. sglang/srt/layers/quantization/__init__.py +27 -3
  65. sglang/srt/layers/quantization/fp4.py +557 -0
  66. sglang/srt/layers/quantization/fp8.py +3 -6
  67. sglang/srt/layers/quantization/fp8_kernel.py +277 -0
  68. sglang/srt/layers/quantization/fp8_utils.py +51 -10
  69. sglang/srt/layers/quantization/modelopt_quant.py +258 -68
  70. sglang/srt/layers/quantization/mxfp4.py +654 -0
  71. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  72. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  73. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  74. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  75. sglang/srt/layers/quantization/quark/utils.py +107 -0
  76. sglang/srt/layers/quantization/unquant.py +60 -6
  77. sglang/srt/layers/quantization/w4afp8.py +21 -12
  78. sglang/srt/layers/quantization/w8a8_int8.py +48 -34
  79. sglang/srt/layers/rotary_embedding.py +506 -3
  80. sglang/srt/layers/utils.py +9 -0
  81. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  82. sglang/srt/lora/backend/base_backend.py +3 -23
  83. sglang/srt/lora/layers.py +60 -114
  84. sglang/srt/lora/lora.py +17 -62
  85. sglang/srt/lora/lora_manager.py +82 -62
  86. sglang/srt/lora/lora_registry.py +23 -11
  87. sglang/srt/lora/mem_pool.py +63 -68
  88. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  89. sglang/srt/lora/utils.py +25 -58
  90. sglang/srt/managers/cache_controller.py +75 -58
  91. sglang/srt/managers/detokenizer_manager.py +1 -1
  92. sglang/srt/managers/io_struct.py +20 -8
  93. sglang/srt/managers/mm_utils.py +6 -13
  94. sglang/srt/managers/multimodal_processor.py +1 -1
  95. sglang/srt/managers/schedule_batch.py +61 -25
  96. sglang/srt/managers/schedule_policy.py +6 -6
  97. sglang/srt/managers/scheduler.py +41 -19
  98. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  99. sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
  100. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  101. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  102. sglang/srt/managers/template_manager.py +35 -1
  103. sglang/srt/managers/tokenizer_manager.py +47 -30
  104. sglang/srt/managers/tp_worker.py +3 -0
  105. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  106. sglang/srt/mem_cache/allocator.py +61 -87
  107. sglang/srt/mem_cache/hicache_storage.py +1 -1
  108. sglang/srt/mem_cache/hiradix_cache.py +80 -22
  109. sglang/srt/mem_cache/lora_radix_cache.py +421 -0
  110. sglang/srt/mem_cache/memory_pool_host.py +34 -36
  111. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  112. sglang/srt/mem_cache/radix_cache.py +2 -5
  113. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  114. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
  115. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
  116. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
  117. sglang/srt/model_executor/cuda_graph_runner.py +29 -9
  118. sglang/srt/model_executor/forward_batch_info.py +61 -19
  119. sglang/srt/model_executor/model_runner.py +148 -37
  120. sglang/srt/model_loader/loader.py +18 -6
  121. sglang/srt/model_loader/weight_utils.py +10 -0
  122. sglang/srt/models/bailing_moe.py +425 -0
  123. sglang/srt/models/deepseek_v2.py +137 -59
  124. sglang/srt/models/ernie4.py +426 -0
  125. sglang/srt/models/ernie4_eagle.py +203 -0
  126. sglang/srt/models/gemma2.py +0 -34
  127. sglang/srt/models/gemma3n_mm.py +38 -0
  128. sglang/srt/models/glm4.py +6 -0
  129. sglang/srt/models/glm4_moe.py +28 -16
  130. sglang/srt/models/glm4v.py +589 -0
  131. sglang/srt/models/glm4v_moe.py +400 -0
  132. sglang/srt/models/gpt_oss.py +1251 -0
  133. sglang/srt/models/granite.py +0 -25
  134. sglang/srt/models/llama.py +0 -25
  135. sglang/srt/models/llama4.py +1 -1
  136. sglang/srt/models/qwen2.py +6 -0
  137. sglang/srt/models/qwen2_5_vl.py +7 -3
  138. sglang/srt/models/qwen2_audio.py +10 -9
  139. sglang/srt/models/qwen2_moe.py +6 -0
  140. sglang/srt/models/qwen3.py +0 -24
  141. sglang/srt/models/qwen3_moe.py +32 -6
  142. sglang/srt/models/registry.py +1 -1
  143. sglang/srt/models/step3_vl.py +9 -0
  144. sglang/srt/models/torch_native_llama.py +0 -24
  145. sglang/srt/models/transformers.py +2 -5
  146. sglang/srt/multimodal/processors/base_processor.py +23 -13
  147. sglang/srt/multimodal/processors/glm4v.py +132 -0
  148. sglang/srt/multimodal/processors/qwen_audio.py +4 -2
  149. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  150. sglang/srt/reasoning_parser.py +332 -37
  151. sglang/srt/server_args.py +186 -75
  152. sglang/srt/speculative/eagle_worker.py +16 -0
  153. sglang/srt/two_batch_overlap.py +169 -9
  154. sglang/srt/utils.py +41 -5
  155. sglang/srt/weight_sync/tensor_bucket.py +106 -0
  156. sglang/test/attention/test_trtllm_mla_backend.py +186 -36
  157. sglang/test/doc_patch.py +59 -0
  158. sglang/test/few_shot_gsm8k.py +1 -1
  159. sglang/test/few_shot_gsm8k_engine.py +1 -1
  160. sglang/test/run_eval.py +4 -1
  161. sglang/test/runners.py +2 -2
  162. sglang/test/simple_eval_common.py +6 -0
  163. sglang/test/simple_eval_gpqa.py +2 -0
  164. sglang/test/test_fp4_moe.py +118 -36
  165. sglang/test/test_utils.py +1 -1
  166. sglang/utils.py +1 -1
  167. sglang/version.py +1 -1
  168. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
  169. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
  170. sglang/srt/lora/backend/flashinfer_backend.py +0 -131
  171. /sglang/{api.py → lang/api.py} +0 -0
  172. /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
  173. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
  174. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  175. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -319,6 +319,7 @@ def fused_moe_kernel(
     # Pointers to matrices
     a_ptr,
     b_ptr,
+    bias_ptr,
     c_ptr,
     a_scale_ptr,
     b_scale_ptr,
@@ -340,6 +341,8 @@ def fused_moe_kernel(
     stride_be,
     stride_bk,
     stride_bn,
+    stride_bias_e,
+    stride_bias_n,
     stride_cm,
     stride_cn,
     stride_asm,
@@ -449,6 +452,10 @@ def fused_moe_kernel(
         + off_experts * stride_be
         + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
     )
+    if bias_ptr is not None:
+        bias = tl.load(
+            bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
+        )
     if use_int8_w8a16:
         b_scale_ptrs = (
             b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
@@ -526,18 +533,20 @@ def fused_moe_kernel(
         a_ptrs += BLOCK_SIZE_K * stride_ak
         b_ptrs += BLOCK_SIZE_K * stride_bk

-    if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
-        accumulator = accumulator * moe_weight[:, None]
     if use_int8_w8a16:
-        accumulator = (accumulator * b_scale).to(compute_type)
+        accumulator *= b_scale
     elif use_fp8_w8a8 or use_int8_w8a8:
-        if group_k > 0 and group_n > 0:
-            accumulator = accumulator.to(compute_type)
-        else:
-            accumulator = (accumulator * a_scale * b_scale).to(compute_type)
-    else:
-        accumulator = accumulator.to(compute_type)
+        if group_k == 0 or group_n == 0:
+            accumulator *= a_scale * b_scale
+
+    if bias_ptr is not None:
+        accumulator += bias
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator *= moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
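Note: the epilogue rewrite above changes the ordering of the final per-tile operations: dequantization scales are applied first, the optional per-expert bias is added next, the routed top-k weight (when MUL_ROUTED_WEIGHT is set) is applied last, and only then is the accumulator cast once to the output dtype, so the bias is also scaled by the routing weight. A minimal PyTorch sketch of that ordering for a single tile follows; the function name and arguments are hypothetical and only illustrate the order, not the kernel itself.

import torch

def epilogue_order_sketch(acc, a_scale=None, b_scale=None, bias=None,
                          moe_weight=None, out_dtype=torch.bfloat16):
    # 1. Dequantize (fp8/int8 per-tensor scale paths) while still in float32.
    if a_scale is not None and b_scale is not None:
        acc = acc * a_scale * b_scale
    # 2. Add the per-expert bias before routing, so the bias is scaled below too.
    if bias is not None:
        acc = acc + bias
    # 3. Apply the routed top-k weight per token row.
    if moe_weight is not None:
        acc = acc * moe_weight[:, None]
    # 4. Single cast to the output dtype at the very end.
    return acc.to(out_dtype)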
@@ -622,6 +631,7 @@ def moe_align_block_size(
 def invoke_fused_moe_kernel(
     A: torch.Tensor,
     B: torch.Tensor,
+    bias: Optional[torch.Tensor],
     C: torch.Tensor,
     A_scale: Optional[torch.Tensor],
     B_scale: Optional[torch.Tensor],
@@ -711,6 +721,7 @@ def invoke_fused_moe_kernel(
     ):
         assert B_scale is not None and B_scale.ndim == 3
         assert B_zp is None or B_zp.ndim == 3
+        assert bias is None
         fused_moe_kernel_gptq_awq[grid](
             A,
             B,
@@ -754,6 +765,7 @@ def invoke_fused_moe_kernel(
         fused_moe_kernel[grid](
             A,
             B,
+            bias,
             C,
             A_scale,
             B_scale,
@@ -770,6 +782,8 @@ def invoke_fused_moe_kernel(
             B.stride(0),
             B.stride(2),
             B.stride(1),
+            bias.stride(0) if bias is not None else 0,
+            bias.stride(1) if bias is not None else 0,
             C.stride(1),
             C.stride(2),
             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
@@ -994,6 +1008,8 @@ def inplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1009,6 +1025,8 @@ def inplace_fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -1016,6 +1034,8 @@ def inplace_fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        b1,
+        b2,
         True,
         activation,
         apply_router_weight_on_input,
@@ -1033,6 +1053,8 @@ def inplace_fused_experts(
         block_shape,
         False,
         routed_scaling_factor,
+        activation_alpha,
+        swiglu_limit,
     )


@@ -1042,6 +1064,8 @@ def inplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1057,6 +1081,8 @@ def inplace_fused_experts_fake(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> None:
     pass

@@ -1075,6 +1101,8 @@ def outplace_fused_experts(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1091,6 +1119,8 @@ def outplace_fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -1098,6 +1128,8 @@ def outplace_fused_experts(
         w2,
         topk_weights,
         topk_ids,
+        b1,
+        b2,
         False,
         activation,
         apply_router_weight_on_input,
@@ -1115,6 +1147,8 @@ def outplace_fused_experts(
         block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
     )


@@ -1124,6 +1158,8 @@ def outplace_fused_experts_fake(
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
@@ -1140,6 +1176,8 @@ def outplace_fused_experts_fake(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)

@@ -1157,6 +1195,8 @@ def fused_experts(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1174,6 +1214,8 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
     if inplace:
@@ -1184,6 +1226,8 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            b1,
+            b2,
             activation,
             apply_router_weight_on_input,
             use_fp8_w8a8,
@@ -1199,6 +1243,8 @@ def fused_experts(
             a2_scale,
             block_shape,
             routed_scaling_factor,
+            activation_alpha,
+            swiglu_limit,
         )
         return hidden_states
     else:
@@ -1208,6 +1254,8 @@ def fused_experts(
             w2,
             topk_weights,
             topk_ids,
+            b1,
+            b2,
             activation,
             apply_router_weight_on_input,
             use_fp8_w8a8,
@@ -1224,6 +1272,8 @@ def fused_experts(
             block_shape,
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
+            activation_alpha=activation_alpha,
+            swiglu_limit=swiglu_limit,
         )

@@ -1319,12 +1369,22 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
     out.mul_(routed_scaling_factor)


+@torch.compile
+def swiglu_with_alpha_and_limit(x, alpha, limit):
+    gate, up = x[..., ::2], x[..., 1::2]
+    gate = gate.clamp(min=None, max=limit)
+    up = up.clamp(min=-limit, max=limit)
+    return gate * torch.sigmoid(gate * alpha) * (up + 1)
+
+
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1342,6 +1402,8 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ):
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
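Note: the new swiglu_with_alpha_and_limit helper expects the gate and up projections interleaved along the last dimension (even columns are gate, odd columns are up), clamps gate from above and up on both sides, and returns gate * sigmoid(alpha * gate) * (up + 1). A small self-contained sketch of that computation; the shapes and the alpha/limit values below are illustrative assumptions, not values shipped in this release.

import torch

# 4 tokens, intermediate size 8, so the fused w1 output has 16 interleaved columns.
x = torch.randn(4, 16)
alpha, limit = 1.702, 7.0  # illustrative values only

gate, up = x[..., ::2], x[..., 1::2]   # de-interleave gate/up columns
gate = gate.clamp(max=limit)           # clamp gate from above only
up = up.clamp(min=-limit, max=limit)   # clamp up symmetrically
out = gate * torch.sigmoid(gate * alpha) * (up + 1)
print(out.shape)  # torch.Size([4, 8]) -- half the input width, like silu_and_mul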
@@ -1353,7 +1415,7 @@ def fused_experts_impl(
     else:
         assert (
             hidden_states.shape[1] == w1.shape[2] - padded_size
-        ), "Hidden size mismatch"
+        ), f"Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -1449,6 +1511,7 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             curr_hidden_states,
             w1,
+            b1,
             intermediate_cache1,
             a1_scale,
             w1_scale,
@@ -1470,13 +1533,24 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
         if activation == "silu":
-            if _is_cuda:
+            if activation_alpha is not None:
+                assert swiglu_limit is not None
+                intermediate_cache2 = swiglu_with_alpha_and_limit(
+                    intermediate_cache1.view(-1, N),
+                    activation_alpha,
+                    swiglu_limit,
+                )
+            elif _is_cuda:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
                 vllm_ops.silu_and_mul(
                     intermediate_cache2, intermediate_cache1.view(-1, N)
                 )
         elif activation == "gelu":
+            assert (
+                activation_alpha is None
+            ), "activation_alpha is not supported for gelu"
+            assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
             if _is_cuda:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
@@ -1489,6 +1563,7 @@ def fused_experts_impl(
         invoke_fused_moe_kernel(
             intermediate_cache2,
             w2,
+            b2,
             (
                 intermediate_cache3
                 if not no_combine and topk_ids.shape[1] != 1
@@ -1567,6 +1642,8 @@ def fused_moe(
     w1: torch.Tensor,
     w2: torch.Tensor,
     topk_output: TopKOutput,
+    b1: Optional[torch.Tensor] = None,
+    b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -1584,6 +1661,8 @@ def fused_moe(
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1594,6 +1673,8 @@ def fused_moe(
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
     - topk_output (TopKOutput): The top-k output of the experts.
+    - b1 (Optional[torch.Tensor]): Optional bias for w1.
+    - b2 (Optional[torch.Tensor]): Optional bias for w2.
     - inplace (bool): If True, perform the operation in-place.
         Defaults to False.
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
@@ -1615,6 +1696,10 @@ def fused_moe(
         a2.
     - block_shape: (Optional[List[int]]): Optional block size for block-wise
         quantization.
+    - activation_alpha (Optional[float]): Optional alpha for the activation
+        function.
+    - swiglu_limit (Optional[float]): Optional limit for the swiglu activation
+        function.

     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
@@ -1625,6 +1710,8 @@ def fused_moe(
         w1,
         w2,
         topk_output,
+        b1=b1,
+        b2=b2,
         inplace=inplace,
         activation=activation,
         apply_router_weight_on_input=apply_router_weight_on_input,
@@ -1642,4 +1729,6 @@ def fused_moe(
         block_shape=block_shape,
         no_combine=no_combine,
         routed_scaling_factor=routed_scaling_factor,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
    )
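Note: taken together, these changes let callers thread per-expert biases and the clamped-SwiGLU parameters through the public fused_moe entry point. A hedged call sketch follows; the tensor shapes are assumptions inferred from the kernel's per-expert, per-output-column bias indexing (b1 matching w1's output dimension, b2 matching w2's), and the alpha/limit values are placeholders that would come from a model config in practice.

# Assumed shapes for E experts, hidden size H, intermediate size I:
#   w1: [E, 2 * I, H], b1: [E, 2 * I]
#   w2: [E, H, I],     b2: [E, H]
output = fused_moe(
    hidden_states,           # [num_tokens, H]
    w1,
    w2,
    topk_output,             # TopKOutput from the existing top-k routing step
    b1=b1,
    b2=b2,
    activation="silu",       # the alpha/limit variant is only wired into the silu path
    activation_alpha=1.702,  # placeholder value
    swiglu_limit=7.0,        # placeholder value
)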