sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -478,11 +478,13 @@ def post_reorder_triton_kernel(
  end_expert_id,
  topk,
  hidden_size,
+ dst_start,
  BLOCK_SIZE: tl.constexpr,
  ):
  InDtype = down_output_ptr.dtype.element_ty

- src_idx = tl.program_id(0)
+ src_idx_int32 = tl.program_id(0)
+ src_idx = src_idx_int32.to(tl.int64)
  src2dst_ptr = src2dst_ptr + src_idx * topk
  topk_ids_ptr = topk_ids_ptr + src_idx * topk
  topk_weights_ptr = topk_weights_ptr + src_idx * topk
@@ -501,7 +503,9 @@ def post_reorder_triton_kernel(
  expert_id = tl.load(topk_ids_ptr + idx)
  if expert_id >= start_expert_id and expert_id <= end_expert_id:
  computed = True
- dst_idx = tl.load(src2dst_ptr + idx)
+ dst_idx_int32 = tl.load(src2dst_ptr + idx)
+ dst_idx = dst_idx_int32.to(tl.int64)
+ dst_idx = dst_idx - dst_start
  weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
  load_ptr = down_output_ptr + dst_idx * hidden_size
  in_data = tl.load(load_ptr + offset, mask=mask)
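The two hunks above widen the Triton program and destination indices from int32 to int64 before they are multiplied by hidden_size, and subtract the new dst_start argument so the kernel can index an expert-local output buffer. The sketch below is not from the diff and uses hypothetical sizes; it only illustrates why the widening matters:

```python
# Illustrative only: with 32-bit indices, the flat element offset
# dst_idx * hidden_size wraps around once it exceeds 2**31 - 1.
INT32_MAX = 2**31 - 1

hidden_size = 7168       # hypothetical hidden size
dst_idx = 320_000        # hypothetical destination row after expert reordering

offset = dst_idx * hidden_size
print(offset > INT32_MAX)  # True -> would overflow int32, fits comfortably in int64
```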
@@ -1086,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
  BLOCK_SIZE_K=BLOCK_SIZE_K,
  )
  return output.t()[:m]
+
+
+ @triton.jit
+ def compute_masked_m_triton_kernel(seg_indptr, masked_m):
+ expert_id = tl.program_id(0)
+ start = tl.load(seg_indptr + expert_id)
+ end = tl.load(seg_indptr + expert_id + 1)
+ tl.store(masked_m + expert_id, (end - start))
+
+
+ @triton.jit
+ def deepgemm_compute_src2dst_triton_kernel(
+ topk_ids,
+ reorder_ids,
+ seg_indptr,
+ src2dst,
+ m_max,
+ num_toks,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ pid = tl.program_id(axis=0)
+ dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+ mask = dst_id < num_toks
+ src_id = tl.load(reorder_ids + dst_id, mask=mask)
+ expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
+ expert_dst_start = tl.load(seg_indptr + expert_id)
+ expert_dst_offset = dst_id - expert_dst_start
+ dst_id = expert_id * m_max + expert_dst_offset
+ tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+ @triton.jit
+ def fill_gateup_input_triton_kernel(
+ input_ptr,
+ scale_ptr,
+ gateup_input_ptr,
+ gateup_input_scale_ptr,
+ src2dst_ptr,
+ topk_ids_ptr,
+ start_expert_id,
+ end_expert_id,
+ topk,
+ m_max,
+ hidden_size,
+ scale_size,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+
+ src_idx_int32 = tl.program_id(0)
+ src_idx = src_idx_int32.to(tl.int64)
+ src2dst_ptr = src2dst_ptr + src_idx * topk
+ topk_ids_ptr = topk_ids_ptr + src_idx * topk
+ src_ptr = input_ptr + src_idx * hidden_size
+ scale_src_ptr = scale_ptr + src_idx * scale_size
+
+ vec = tl.arange(0, BLOCK_SIZE)
+ for idx in range(topk):
+ expert_id = tl.load(topk_ids_ptr + idx)
+ if expert_id >= start_expert_id and expert_id <= end_expert_id:
+ dst_idx_int32 = tl.load(src2dst_ptr + idx)
+ dst_idx = dst_idx_int32.to(tl.int64)
+ dst_idx = dst_idx - start_expert_id * m_max
+ dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+ for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+ offset = start_offset + vec
+ mask = offset < hidden_size
+ in_data = tl.load(src_ptr + offset, mask=mask)
+ tl.store(dst_ptr + offset, in_data, mask=mask)
+ scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
+ for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
+ offset = start_offset + vec
+ mask = offset < scale_size
+ in_scale = tl.load(scale_src_ptr + offset, mask=mask)
+ tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
+
+
+ def moe_ep_deepgemm_preprocess(
+ topk_ids: torch.Tensor,
+ num_experts: int,
+ hidden_states: torch.Tensor,
+ top_k: int,
+ start_expert_id,
+ end_expert_id,
+ block_shape,
+ output_dtype: torch.dtype = torch.float8_e4m3fn,
+ ):
+ reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+ seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+ src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
+ masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
+
+ compute_seg_indptr_triton_kernel[(num_experts,)](
+ reorder_topk_ids, seg_indptr, topk_ids.numel()
+ )
+
+ grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
+ compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
+
+ # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
+ m_max = (hidden_states.size(0) + 255) // 256 * 256
+ expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
+ gateup_input = torch.empty(
+ (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
+ device=hidden_states.device,
+ dtype=output_dtype,
+ )
+
+ deepgemm_compute_src2dst_triton_kernel[grid](
+ topk_ids,
+ reorder_ids,
+ seg_indptr,
+ src2dst,
+ m_max,
+ topk_ids.numel(),
+ BLOCK_SIZE=256,
+ )
+
+ if block_shape is None:
+ block_shape = [128, 128]
+ assert len(block_shape) == 2
+ block_n, block_k = block_shape[0], block_shape[1]
+ hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
+
+ gateup_input_scale = torch.empty(
+ (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
+ device=hidden_states.device,
+ dtype=scale.dtype,
+ )
+
+ fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
+ hidden_states,
+ scale,
+ gateup_input,
+ gateup_input_scale,
+ src2dst,
+ topk_ids,
+ start_expert_id,
+ end_expert_id,
+ top_k,
+ m_max,
+ hidden_states.size(1),
+ scale.size(1),
+ BLOCK_SIZE=1024,
+ )
+
+ return (
+ m_max,
+ masked_m[start_expert_id : (end_expert_id + 1)],
+ expected_m,
+ src2dst,
+ gateup_input,
+ gateup_input_scale,
+ )
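The new moe_ep_deepgemm_preprocess helper sorts tokens by expert, builds the src2dst mapping, and packs the FP8-quantized activations into an (experts, m_max, hidden) buffer whose per-expert row capacity m_max is rounded up to a multiple of 256, as the linked DeepGEMM masked grouped GEMM requires. A minimal sketch of the shape bookkeeping, using made-up sizes:

```python
# Hypothetical sizes, mirroring the arithmetic in moe_ep_deepgemm_preprocess.
num_tokens, top_k, num_experts = 1000, 8, 64

# Per-expert row capacity, rounded up to a multiple of 256 for the masked grouped GEMM.
m_max = (num_tokens + 255) // 256 * 256                             # 1024

# Average number of rows routed to each expert, used as a size hint for the GEMM.
expected_m = (num_tokens * top_k + num_experts - 1) // num_experts  # 125

print(m_max, expected_m)
```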
sglang/srt/layers/moe/ep_moe/layer.py

@@ -16,6 +16,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  ep_scatter,
  gelu_and_mul_triton_kernel,
  grouped_gemm_triton,
+ moe_ep_deepgemm_preprocess,
  post_reorder_triton_kernel,
  pre_reorder_triton_kernel,
  run_moe_ep_preproess,
@@ -33,10 +34,12 @@ from sglang.srt.layers.quantization.base_config import (
  )
  from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
  from sglang.srt.layers.quantization.fp8_kernel import (
+ is_fp8_fnuz,
  scaled_fp8_quant,
  sglang_per_token_group_quant_fp8,
  sglang_per_token_quant_fp8,
  )
+ from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
  from sglang.srt.managers.expert_location import get_global_expert_location_metadata
  from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -50,10 +53,17 @@ from sglang.srt.utils import (
  )

  _is_hip = is_hip()
+ _is_fp8_fnuz = is_fp8_fnuz()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

  if _is_hip:
  from vllm._custom_ops import scaled_fp8_quant

+ if _use_aiter:
+ from aiter import ActivationType, QuantType
+ from aiter.fused_moe import fused_moe
+ from aiter.ops.shuffle import shuffle_weight
+
  logger = logging.getLogger(__name__)

@@ -175,6 +185,7 @@ class EPMoE(torch.nn.Module):
  assert (
  num_fused_shared_experts == 0
  ), "num_fused_shared_experts is not supported in EP"
+ self.num_fused_shared_experts = num_fused_shared_experts
  self.num_experts_per_partition = self.num_experts // self.tp_size
  self.start_expert_id = self.tp_rank * self.num_experts_per_partition
  self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -224,13 +235,182 @@ class EPMoE(torch.nn.Module):

  self.grouped_gemm_runner = None

+ self.w13_weight_fp8 = (
+ self.w13_weight,
+ (
+ self.w13_weight_scale_inv
+ if self.use_block_quant
+ else self.w13_weight_scale
+ ),
+ )
+ self.w2_weight_fp8 = (
+ self.w2_weight,
+ self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+ )
+
  def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+ return self.forward_deepgemm(hidden_states, router_logits)
+ else:
+ return self.forward_normal(hidden_states, router_logits)
+
+ def forward_deepgemm(
+ self, hidden_states: torch.Tensor, router_logits: torch.Tensor
+ ):
+ assert self.quant_method is not None
+ assert self.activation == "silu"
  hidden_states_shape = hidden_states.shape
  hidden_states_dtype = hidden_states.dtype
  hidden_states_device = hidden_states.device
+ topk_weights, topk_ids = select_experts(
+ hidden_states=hidden_states,
+ router_logits=router_logits,
+ top_k=self.top_k,
+ use_grouped_topk=self.use_grouped_topk,
+ renormalize=self.renormalize,
+ topk_group=self.topk_group,
+ num_expert_group=self.num_expert_group,
+ num_fused_shared_experts=self.num_fused_shared_experts,
+ correction_bias=self.correction_bias,
+ custom_routing_function=self.custom_routing_function,
+ routed_scaling_factor=self.routed_scaling_factor,
+ )

- assert self.quant_method is not None
+ if not self.use_block_quant:
+ # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+ scale_block_size = 128
+ w13_weight_scale_n = 2 * (
+ (self.intermediate_size + scale_block_size - 1) // scale_block_size
+ )
+ w13_weight_scale_k = (
+ hidden_states_shape[-1] + scale_block_size - 1
+ ) // scale_block_size
+ w13_weight_scale = (
+ self.w13_weight_scale.unsqueeze(1)
+ .repeat_interleave(w13_weight_scale_n, dim=1)
+ .unsqueeze(2)
+ .repeat_interleave(w13_weight_scale_k, dim=2)
+ )
+ self.w13_weight_fp8 = (
+ self.w13_weight,
+ w13_weight_scale,
+ )
+ w2_weight_scale_n = (
+ hidden_states_shape[-1] + scale_block_size - 1
+ ) // scale_block_size
+ w2_weight_scale_k = (
+ self.intermediate_size + scale_block_size - 1
+ ) // scale_block_size
+ w2_weight_scale = (
+ self.w2_weight_scale.unsqueeze(1)
+ .repeat_interleave(w2_weight_scale_n, dim=1)
+ .unsqueeze(2)
+ .repeat_interleave(w2_weight_scale_k, dim=2)
+ )
+ self.w2_weight_fp8 = (
+ self.w2_weight,
+ w2_weight_scale,
+ )
+
+ # PreReorder
+ m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
+ moe_ep_deepgemm_preprocess(
+ topk_ids,
+ self.num_experts,
+ hidden_states,
+ self.top_k,
+ self.start_expert_id,
+ self.end_expert_id,
+ self.block_shape,
+ )
+ )
+
+ dispose_tensor(hidden_states)

+ # GroupGemm-0
+ gateup_input_fp8 = (
+ gateup_input,
+ deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+ )
+ num_groups, m, k = gateup_input_fp8[0].size()
+ n = self.w13_weight.size(1)
+ gateup_output = torch.empty(
+ (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+ )
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+ )
+ del gateup_input
+ del gateup_input_fp8
+
+ # Act
+ down_input = torch.empty(
+ (
+ gateup_output.shape[0],
+ gateup_output.shape[1],
+ gateup_output.shape[2] // 2,
+ ),
+ device=hidden_states_device,
+ dtype=self.fp8_dtype,
+ )
+ scale_block_size = 128
+ down_input_scale = torch.empty(
+ (
+ gateup_output.shape[0],
+ gateup_output.shape[1],
+ gateup_output.shape[2] // 2 // scale_block_size,
+ ),
+ device=hidden_states_device,
+ dtype=torch.float32,
+ )
+ silu_and_mul_masked_post_quant_fwd(
+ gateup_output,
+ down_input,
+ down_input_scale,
+ scale_block_size,
+ masked_m,
+ )
+ del gateup_output
+
+ # GroupGemm-1
+ n = self.w2_weight.size(1)
+ down_input_fp8 = (
+ down_input,
+ deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+ )
+ down_output = torch.empty(
+ (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+ )
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+ down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+ )
+ del down_input
+ del down_input_fp8
+
+ # PostReorder
+ output = torch.empty(
+ hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+ )
+ post_reorder_triton_kernel[(hidden_states_shape[0],)](
+ down_output,
+ output,
+ src2dst,
+ topk_ids,
+ topk_weights,
+ self.start_expert_id,
+ self.end_expert_id,
+ self.top_k,
+ hidden_states_shape[1],
+ m_max * self.start_expert_id,
+ BLOCK_SIZE=512,
+ )
+ return output
+
+ def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+ assert self.quant_method is not None
+ hidden_states_shape = hidden_states.shape
+ hidden_states_dtype = hidden_states.dtype
+ hidden_states_device = hidden_states.device
  if self.grouped_gemm_runner is None:
  self.grouped_gemm_runner = GroupedGemmRunner(
  hidden_states.device,
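When the weights are per-tensor quantized, forward_deepgemm expands each scalar weight scale into a per-block layout by repeating it over the 128x128 block grid, as the hunk's comment notes. A standalone sketch of that expansion with hypothetical shapes (the real sizes come from the model config):

```python
import torch

# Hypothetical sizes.
num_experts_per_partition = 4
intermediate_size, hidden_size = 4096, 7168
scale_block_size = 128

# Per-tensor quantization leaves one scale per expert.
w13_weight_scale = torch.rand(num_experts_per_partition)

n_blocks = 2 * ((intermediate_size + scale_block_size - 1) // scale_block_size)  # 64
k_blocks = (hidden_size + scale_block_size - 1) // scale_block_size              # 56

# Repeat the scalar over the (n_blocks, k_blocks) grid, as forward_deepgemm does.
w13_block_scale = (
    w13_weight_scale.unsqueeze(1)
    .repeat_interleave(n_blocks, dim=1)
    .unsqueeze(2)
    .repeat_interleave(k_blocks, dim=2)
)
print(w13_block_scale.shape)  # torch.Size([4, 64, 56])
```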
@@ -246,6 +426,7 @@ class EPMoE(torch.nn.Module):
  renormalize=self.renormalize,
  topk_group=self.topk_group,
  num_expert_group=self.num_expert_group,
+ num_fused_shared_experts=self.num_fused_shared_experts,
  correction_bias=self.correction_bias,
  custom_routing_function=self.custom_routing_function,
  routed_scaling_factor=self.routed_scaling_factor,
@@ -437,6 +618,7 @@ class EPMoE(torch.nn.Module):
  self.end_expert_id,
  self.top_k,
  hidden_states_shape[1],
+ 0,
  BLOCK_SIZE=512,
  )
  return output
@@ -843,6 +1025,42 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
  torch.max(layer.w13_weight_scale, dim=1).values,
  requires_grad=False,
  )
+ if self.block_quant:
+ # If ROCm, normalize the weights and scales to e4m3fnuz
+ if _is_fp8_fnuz:
+ # activation_scheme: dynamic
+ w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=layer.w13_weight,
+ weight_scale=layer.w13_weight_scale_inv,
+ input_scale=None,
+ )
+ w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=layer.w2_weight,
+ weight_scale=layer.w2_weight_scale_inv,
+ input_scale=None,
+ )
+ # Reset the parameter
+ layer.w13_weight = torch.nn.Parameter(
+ w13_weight, requires_grad=False
+ )
+ layer.w13_weight_scale_inv = torch.nn.Parameter(
+ w13_weight_scale, requires_grad=False
+ )
+ layer.w13_input_scale = None
+ layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+ layer.w2_weight_scale_inv = torch.nn.Parameter(
+ w2_weight_scale, requires_grad=False
+ )
+ layer.w2_input_scale = None
+ if _use_aiter:
+ layer.w13_weight = torch.nn.Parameter(
+ shuffle_weight(layer.w13_weight.data, (16, 16)),
+ requires_grad=False,
+ )
+ layer.w2_weight = torch.nn.Parameter(
+ shuffle_weight(layer.w2_weight.data, (16, 16)),
+ requires_grad=False,
+ )
  return

  def apply(
@@ -914,18 +1132,36 @@ class DeepEPMoE(EPMoE):
  assert (
  deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
  ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
- self.w13_weight_fp8 = (
- self.w13_weight,
- (
- self.w13_weight_scale_inv
- if self.use_block_quant
- else self.w13_weight_scale
- ),
- )
- self.w2_weight_fp8 = (
- self.w2_weight,
- self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
- )
+ if _use_aiter:
+ # expert_mask is of size (self.num_experts_per_partition + 1),
+ # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
+ # for instance, if we have 4 experts on this rank, we would have a expert_mask like:
+ # self.expert_mask = [1, 1, 1, 1, 0]
+ # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
+ self.expert_mask = torch.zeros(
+ (self.num_experts_per_partition + 1),
+ device=torch.cuda.current_device(),
+ dtype=torch.int,
+ )
+ # the last one is invalid rank_id
+ self.expert_mask[:-1] = 1
+ else:
+ self.w13_weight_fp8 = (
+ self.w13_weight,
+ (
+ self.w13_weight_scale_inv
+ if self.use_block_quant
+ else self.w13_weight_scale
+ ),
+ )
+ self.w2_weight_fp8 = (
+ self.w2_weight,
+ (
+ self.w2_weight_scale_inv
+ if self.use_block_quant
+ else self.w2_weight_scale
+ ),
+ )

  def forward(
  self,
@@ -939,6 +1175,9 @@ class DeepEPMoE(EPMoE):
  num_recv_tokens_per_expert: List[int],
  forward_mode: ForwardMode,
  ):
+ if _use_aiter:
+ # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
+ return self.forward_aiter(hidden_states, topk_idx, topk_weights)
  resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
  if resolved_deepep_mode == DeepEPMode.normal:
  if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
@@ -1071,6 +1310,37 @@ class DeepEPMoE(EPMoE):
  )
  return down_output

+ def forward_aiter(
+ self,
+ hidden_states: torch.Tensor,
+ topk_idx: torch.Tensor,
+ topk_weights: torch.Tensor,
+ ):
+ if hidden_states.shape[0] == 0:
+ return hidden_states
+ # in original deepep, idx == -1 meaning invalid and will not be processed.
+ # aiter does not accept -1, we use a expert mask to make these idx invalid
+ # (idx == num_experts_per_partition) meaning not used in aiter fused_moe
+ topk_idx_copy = topk_idx.to(torch.int32)
+ topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+
+ return fused_moe(
+ hidden_states,
+ self.w13_weight,
+ self.w2_weight,
+ topk_weights,
+ topk_idx_copy,
+ w1_scale=self.w13_weight_scale_inv,
+ w2_scale=self.w2_weight_scale_inv,
+ quant_type=QuantType.per_128x128,
+ activation=(
+ ActivationType.Silu
+ if self.activation == "silu"
+ else ActivationType.Gelu
+ ),
+ expert_mask=self.expert_mask,
+ )
+
  def forward_deepgemm_contiguous(
  self,
  hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
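The aiter path drops DeepEP's convention of marking unused routing slots with -1: invalid ids are remapped to one extra expert slot and suppressed through expert_mask, because aiter's fused_moe does not accept negative ids. A small CPU-only sketch of the remapping (values are made up):

```python
import torch

num_experts_per_partition = 4

# The trailing slot absorbs ids that DeepEP would have marked as -1.
expert_mask = torch.zeros(num_experts_per_partition + 1, dtype=torch.int)
expert_mask[:-1] = 1                      # tensor([1, 1, 1, 1, 0])

topk_idx = torch.tensor([[0, 2, -1], [3, -1, -1]])
topk_idx_copy = topk_idx.to(torch.int32)
topk_idx_copy[topk_idx_copy == -1] = num_experts_per_partition

print(topk_idx_copy)                      # entries equal to 4 fall on the masked slot
```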
@@ -1265,6 +1535,9 @@ class DeepEPMoE(EPMoE):
  def get_moe_impl_class():
  if global_server_args_dict["enable_deepep_moe"]:
  return DeepEPMoE
+ if global_server_args_dict["enable_flashinfer_moe"]:
+ # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+ return FusedMoE
  if global_server_args_dict["enable_ep_moe"]:
  return EPMoE
  return FusedMoE
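get_moe_impl_class now checks enable_flashinfer_moe before enable_ep_moe, since FusedMoE also supports EP mode. A sketch of the resulting precedence (the dictionary below is a stand-in for global_server_args_dict, and the strings stand in for the classes):

```python
def pick_moe_impl(args: dict) -> str:
    # Precedence mirrors get_moe_impl_class after this change.
    if args.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if args.get("enable_flashinfer_moe"):
        # Checked before enable_ep_moe because FusedMoE also supports EP.
        return "FusedMoE"
    if args.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

print(pick_moe_impl({"enable_flashinfer_moe": True, "enable_ep_moe": True}))  # FusedMoE
```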
sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -6,7 +6,13 @@ from sglang.srt.managers.expert_distribution import (
  get_global_expert_distribution_recorder,
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
- from sglang.srt.utils import DeepEPMode, get_int_env_var, load_json_config
+ from sglang.srt.utils import (
+ DeepEPMode,
+ get_bool_env_var,
+ get_int_env_var,
+ is_hip,
+ load_json_config,
+ )

  try:
  from deep_ep import Buffer, Config
@@ -32,6 +38,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardMode

+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+
  logger = logging.getLogger(__name__)

@@ -376,6 +384,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
  https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
  """
+ if _use_aiter:
+ # skip permutation here as aiter fused_moe has fused inside
+ reorder_topk_ids = torch.empty(
+ (0,), device=hidden_states.device, dtype=torch.int64
+ )
+ seg_indptr = torch.zeros(
+ (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+ )
+ return reorder_topk_ids, seg_indptr, hidden_states

  reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
  topk_idx, self.num_experts
@@ -409,7 +426,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
  topk_idx: torch.Tensor,
  topk_weights: torch.Tensor,
  ):
- if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
  output = hidden_states
  else:
  if hidden_states.shape[0] > 0:
sglang/srt/layers/moe/fused_moe_native.py

@@ -77,8 +77,15 @@ def moe_forward_native(
  custom_routing_function: Optional[Callable] = None,
  correction_bias: Optional[torch.Tensor] = None,
  activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
+ inplace: bool = True,
+ no_combine: bool = False,
  routed_scaling_factor: Optional[float] = None,
  ) -> torch.Tensor:
+
+ if apply_router_weight_on_input:
+ raise NotImplementedError()
+
  topk_weights, topk_ids = select_experts(
  hidden_states=x,
  router_logits=router_logits,