sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -16,6 +16,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_scatter,
     gelu_and_mul_triton_kernel,
     grouped_gemm_triton,
+    moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
@@ -33,10 +34,12 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     scaled_fp8_quant,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.managers.expert_location import get_global_expert_location_metadata
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -50,6 +53,7 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()

 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
@@ -175,6 +179,7 @@ class EPMoE(torch.nn.Module):
         assert (
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
+        self.num_fused_shared_experts = num_fused_shared_experts
         self.num_experts_per_partition = self.num_experts // self.tp_size
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -224,13 +229,182 @@ class EPMoE(torch.nn.Module):

         self.grouped_gemm_runner = None

+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+            return self.forward_deepgemm(hidden_states, router_logits)
+        else:
+            return self.forward_normal(hidden_states, router_logits)
+
+    def forward_deepgemm(
+        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )

-        assert self.quant_method is not None
+        if not self.use_block_quant:
+            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+            scale_block_size = 128
+            w13_weight_scale_n = 2 * (
+                (self.intermediate_size + scale_block_size - 1) // scale_block_size
+            )
+            w13_weight_scale_k = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w13_weight_scale = (
+                self.w13_weight_scale.unsqueeze(1)
+                .repeat_interleave(w13_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w13_weight_scale_k, dim=2)
+            )
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                w13_weight_scale,
+            )
+            w2_weight_scale_n = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale_k = (
+                self.intermediate_size + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale = (
+                self.w2_weight_scale.unsqueeze(1)
+                .repeat_interleave(w2_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w2_weight_scale_k, dim=2)
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                w2_weight_scale,
+            )
+
+        # PreReorder
+        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
+            moe_ep_deepgemm_preprocess(
+                topk_ids,
+                self.num_experts,
+                hidden_states,
+                self.top_k,
+                self.start_expert_id,
+                self.end_expert_id,
+                self.block_shape,
+            )
+        )
+
+        dispose_tensor(hidden_states)
+
+        # GroupGemm-0
+        gateup_input_fp8 = (
+            gateup_input,
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+        )
+        num_groups, m, k = gateup_input_fp8[0].size()
+        n = self.w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        )
+        del gateup_input
+        del gateup_input_fp8
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=self.fp8_dtype,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )
+        del gateup_output

+        # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+        )
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        )
+        del down_input
+        del down_input_fp8
+
+        # PostReorder
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
+            down_output,
+            output,
+            src2dst,
+            topk_ids,
+            topk_weights,
+            self.start_expert_id,
+            self.end_expert_id,
+            self.top_k,
+            hidden_states_shape[1],
+            m_max * self.start_expert_id,
+            BLOCK_SIZE=512,
+        )
+        return output
+
+    def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
                 hidden_states.device,
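
The `forward_deepgemm` path above handles checkpoints that are not block-quantized by expanding their per-expert scales into the per-block layout the DeepGEMM grouped kernels expect. A minimal standalone sketch of that expansion (the function name and shapes are illustrative, not the sglang implementation):

```python
import torch

# Illustrative sketch of the per-tensor -> per-block scale expansion done in
# forward_deepgemm above: a per-expert scale of shape [E] is broadcast into a
# [E, ceil(n / block), ceil(k / block)] per-block scale tensor.
def expand_per_expert_scale(
    scale: torch.Tensor, n: int, k: int, block: int = 128
) -> torch.Tensor:
    n_blocks = (n + block - 1) // block
    k_blocks = (k + block - 1) // block
    return (
        scale.unsqueeze(1)
        .repeat_interleave(n_blocks, dim=1)
        .unsqueeze(2)
        .repeat_interleave(k_blocks, dim=2)
    )

scale = torch.tensor([0.5, 1.0, 2.0])            # one scale per expert (E = 3)
blocked = expand_per_expert_scale(scale, n=384, k=512)
print(blocked.shape)                              # torch.Size([3, 3, 4])
```

In the hunk itself the w13 case uses `2 * ceil(intermediate_size / 128)` blocks along the first weight dimension because the gate and up projections are stacked in one fused tensor.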
@@ -246,6 +420,7 @@ class EPMoE(torch.nn.Module):
             renormalize=self.renormalize,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -437,6 +612,7 @@ class EPMoE(torch.nn.Module):
             self.end_expert_id,
             self.top_k,
             hidden_states_shape[1],
+            0,
             BLOCK_SIZE=512,
         )
         return output
@@ -843,6 +1019,33 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 torch.max(layer.w13_weight_scale, dim=1).values,
                 requires_grad=False,
             )
+            if self.block_quant:
+                # If ROCm, normalize the weights and scales to e4m3fnuz
+                if _is_fp8_fnuz:
+                    # activation_scheme: dynamic
+                    w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.w13_weight,
+                        weight_scale=layer.w13_weight_scale_inv,
+                        input_scale=None,
+                    )
+                    w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                        weight=layer.w2_weight,
+                        weight_scale=layer.w2_weight_scale_inv,
+                        input_scale=None,
+                    )
+                    # Reset the parameter
+                    layer.w13_weight = torch.nn.Parameter(
+                        w13_weight, requires_grad=False
+                    )
+                    layer.w13_weight_scale_inv = torch.nn.Parameter(
+                        w13_weight_scale, requires_grad=False
+                    )
+                    layer.w13_input_scale = None
+                    layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                    layer.w2_weight_scale_inv = torch.nn.Parameter(
+                        w2_weight_scale, requires_grad=False
+                    )
+                    layer.w2_input_scale = None
             return

     def apply(
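
The body of `normalize_e4m3fn_to_e4m3fnuz` is not shown in this diff. As a hedged sketch of what a helper with that name conventionally does (it mirrors the vLLM utility that sglang's `fp8_utils` is modeled on): reinterpret the fp8 storage as `e4m3fnuz` for ROCm, clear the one bit pattern that would otherwise become NaN, and double the scales, since the same bit pattern encodes half the value under the fnuz exponent bias.

```python
import torch

ROCM_FP8_NAN_AS_INT = -128  # bit pattern 0x80: -0.0 in e4m3fn, NaN in e4m3fnuz

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
    """Sketch only: reinterpret e4m3fn storage as e4m3fnuz and compensate scales."""
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == ROCM_FP8_NAN_AS_INT] = 0       # avoid creating NaNs
    weight = as_int8.view(torch.float8_e4m3fnuz)
    weight_scale = weight_scale * 2.0                  # same bits == half the value
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
```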
@@ -1265,6 +1468,9 @@ class DeepEPMoE(EPMoE):
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
+    if global_server_args_dict["enable_flashinfer_moe"]:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FusedMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
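
For reference, a hypothetical usage sketch of the selector above, as a model might construct its MoE layer so the same model code works with all of the backends; the config field names and keyword names here are illustrative only, not taken from this diff:

```python
# Hypothetical: pick the MoE implementation once, then construct it.
MoEImpl = get_moe_impl_class()
experts = MoEImpl(
    num_experts=config.n_routed_experts,        # illustrative config fields
    top_k=config.num_experts_per_tok,
    hidden_size=config.hidden_size,
    intermediate_size=config.moe_intermediate_size,
)
```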
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
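
The new tuning file above (entry 52 in the file list) is keyed by an M bucket (token count seen by the grouped GEMM); each entry holds Triton launch parameters for the fused MoE kernel on H100 with fp8_w8a8 weights and a [128, 128] block shape. A sketch of how such a file is typically consumed, picking the entry whose key is nearest to the actual M; the helper below is illustrative, not the sglang loader:

```python
import json

def load_kernel_config(path: str, num_tokens: int) -> dict:
    """Illustrative nearest-M lookup over a fused-MoE tuning JSON."""
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    best_m = min(configs.keys(), key=lambda m: abs(m - num_tokens))
    return configs[best_m]

cfg = load_kernel_config(
    "E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json",
    num_tokens=200,
)
print(cfg["BLOCK_SIZE_M"], cfg["num_stages"])   # 64 4  (the "256" entry is closest)
```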
@@ -25,9 +25,11 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
+    is_cpu,
     is_cuda,
     is_hip,
     log_info_on_rank0,
@@ -36,9 +38,13 @@ from sglang.srt.utils import (

 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
@@ -32,6 +32,7 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _use_aiter:
     from aiter import ActivationType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

@@ -204,7 +205,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_weights, dtype=torch.float32
             )  # topk_weights must be FP32 (float32)

-            return ck_moe_2stages(
+            return fused_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -241,7 +242,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return moe_forward_native(
             layer,
@@ -260,7 +265,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")

-    forward_native = forward_cuda
+    forward_native = forward_cpu


 class FusedMoE(torch.nn.Module):
@@ -310,6 +315,8 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        enable_flashinfer_moe: Optional[bool] = False,
+        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()

@@ -320,9 +327,40 @@ class FusedMoE(torch.nn.Module):
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.num_experts = num_experts
+        self.expert_map = None
+
+        if enable_flashinfer_moe and quant_config is None:
+            logger.warning("Disable flashinfer MoE when quantization config is None.")
+            enable_flashinfer_moe = False
+            enable_ep_moe = False
+
+        self.enable_flashinfer_moe = enable_flashinfer_moe
+        if enable_ep_moe:
+            assert (
+                self.enable_flashinfer_moe
+            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
+            self.ep_size = self.tp_size
+            self.ep_rank = self.tp_rank
+            self.tp_size = 1
+            self.tp_rank = 0
+            # Create a tensor of size num_experts filled with -1
+            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+            # Create a expert map for the local experts
+            assert num_experts % self.ep_size == 0
+            self.local_num_experts = num_experts // self.ep_size
+            self.expert_map[
+                self.ep_rank
+                * self.local_num_experts : (self.ep_rank + 1)
+                * self.local_num_experts
+            ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
+        else:
+            self.ep_size = 1
+            self.ep_rank = 0
+            self.local_num_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
-        self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
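
The expert map built above is what later lets `weight_loader` skip experts owned by other EP ranks: global expert ids owned by this rank map to local slots `0..local_num_experts-1`, everything else is `-1`. A standalone sketch of the same construction:

```python
import torch

def build_expert_map(num_experts: int, ep_size: int, ep_rank: int) -> torch.Tensor:
    """Global expert id -> local slot on this EP rank, or -1 if not local."""
    assert num_experts % ep_size == 0
    local = num_experts // ep_size
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map[ep_rank * local : (ep_rank + 1) * local] = torch.arange(
        local, dtype=torch.int32
    )
    return expert_map

print(build_expert_map(num_experts=8, ep_size=4, ep_rank=1))
# tensor([-1, -1,  0,  1, -1, -1, -1, -1], dtype=torch.int32)
```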
@@ -340,7 +378,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
-        self.local_num_experts = num_experts

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -348,11 +385,13 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
+            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
+                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
         assert self.quant_method is not None

         self.quant_method.create_weights(
             layer=self,
-            num_experts=num_experts,
+            num_experts=self.local_num_experts,
             hidden_size=hidden_size,
             # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
@@ -446,12 +485,15 @@ class FusedMoE(torch.nn.Module):

         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
         # w3, up_proj: Load into second logical weight of w13.
+        # trtllm cutlass kernel assumes differently
+        assert shard_id in ("w1", "w3")
+        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
+        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
+            start = shard_size
         else:
-            assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+            start = 0
+        expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)

     def _load_w2(
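
The rewrite above collapses the two `narrow()` branches into one offset computation: `w13` stores gate_proj and up_proj back to back, and a quant method can request the opposite order via `load_up_proj_weight_first` (the trtllm cutlass case). A small restatement of the offset logic, with illustrative values:

```python
def w13_copy_offset(shard_id: str, shard_size: int, load_up_proj_weight_first: bool) -> int:
    """Row offset inside the fused w13 tensor where this shard is copied.

    By default w1 (gate_proj) occupies rows [0, shard_size) and w3 (up_proj)
    rows [shard_size, 2 * shard_size); a quant method can flip that order.
    """
    assert shard_id in ("w1", "w3")
    switch_w13 = load_up_proj_weight_first
    if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
        return shard_size
    return 0

assert w13_copy_offset("w1", 128, load_up_proj_weight_first=False) == 0
assert w13_copy_offset("w3", 128, load_up_proj_weight_first=False) == 128
assert w13_copy_offset("w1", 128, load_up_proj_weight_first=True) == 128
```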
@@ -505,6 +547,11 @@ class FusedMoE(torch.nn.Module):
         assert shard_id in ("w1", "w3")
         expert_data.copy_(loaded_weight)

+    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
+        if self.expert_map is None:
+            return expert_id
+        return self.expert_map[expert_id].item()
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
@@ -513,6 +560,13 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+        if expert_id == -1:
+            return
+
+        # TP rank is set to 0 if EP is enabled
+        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -537,7 +591,6 @@ class FusedMoE(torch.nn.Module):
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

         expert_data = param.data[expert_id]
-        tp_rank = get_tensor_model_parallel_rank()

         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
@@ -545,7 +598,7 @@ class FusedMoE(torch.nn.Module):
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)

         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -686,9 +739,19 @@ class FusedMoE(torch.nn.Module):
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             routed_scaling_factor=self.routed_scaling_factor,
+            **(
+                dict(
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                    ep_rank=self.ep_rank,
+                    ep_size=self.ep_size,
+                )
+                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+                else {}
+            ),
         )

-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states
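
The `**( ... if ... else {})` splice above forwards the tp/ep rank and size only to the ModelOpt NvFP4 fused-MoE method, since other `apply()` implementations do not accept those keywords. A generic sketch of the pattern with made-up functions:

```python
# Generic sketch of the conditional-kwargs pattern: extra keyword arguments are
# spliced in only for callees that accept them.
def apply_a(x, *, scale=1.0):
    return x * scale

def apply_b(x, *, scale=1.0, ep_rank=0, ep_size=1):
    return x * scale + ep_rank / ep_size

def run(fn, x, needs_ep: bool):
    return fn(
        x,
        scale=2.0,
        **(dict(ep_rank=1, ep_size=4) if needs_ep else {}),
    )

print(run(apply_a, 1.0, needs_ep=False))  # 2.0
print(run(apply_b, 1.0, needs_ep=True))   # 2.25
```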