sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
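
Note: the new JSON above is the tuned Triton kernel configuration added as file 59 in the list (E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json). Each top-level key is a token-count bucket and the value is the tile/pipeline configuration for that batch size. As a rough, illustrative sketch only (not the package's actual lookup code), using such a file typically means picking the bucket closest to the current number of tokens:

```python
import json


def pick_moe_config(config_path: str, num_tokens: int) -> dict:
    """Hypothetical helper: choose the tuned config whose token-count key
    is closest to the current batch size."""
    with open(config_path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    best_key = min(configs, key=lambda k: abs(k - num_tokens))
    return configs[best_key]


# Example: a batch of 300 tokens would map to the "256" entry above,
# i.e. {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, ...}.
```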
@@ -25,9 +25,11 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
+    is_cpu,
     is_cuda,
     is_hip,
     log_info_on_rank0,
@@ -36,9 +38,13 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
+elif _is_cpu and _is_cpu_amx_available:
+    pass
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
@@ -744,9 +750,11 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
-        max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
+    sorted_ids.fill_(topk_ids.numel())
+
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -762,6 +770,9 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
+        cumsum_buffer = torch.empty(
+            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+        )
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
            dtype=torch.int32,
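
For context on the hunks above: moe_align_block_size no longer calls the init_sorted_ids_and_cumsum_buffer helper; it allocates sorted_ids directly and fills it with topk_ids.numel(), which acts as a padding sentinel (an index one past the last valid token), while the cumsum buffer is now only allocated on the fallback path. A minimal sketch of just the buffer setup, with the alignment kernel launch itself omitted (illustrative, not the shipped function):

```python
import torch
import triton


def init_moe_align_buffers(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    """Sketch of the buffers prepared before the block-alignment kernel."""
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    # Padded slots point one past the last real token, so the kernel can skip them.
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    return sorted_ids, expert_ids
```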
@@ -18,7 +18,14 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    _process_weight_after_loading,
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_hip,
+    set_weight_attrs,
+)
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,10 +35,13 @@ else:
     import logging
 
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
     from aiter import ActivationType
+    from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
 
@@ -116,6 +126,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             requires_grad=False,
         )
         torch.cuda.empty_cache()
+
+        # Pack weight for get better performance on CPU
+        if _is_cpu and _is_cpu_amx_available:
+            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
         return
 
     def apply(
@@ -204,7 +219,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 topk_weights, dtype=torch.float32
             )  # topk_weights must be FP32 (float32)
 
-            return ck_moe_2stages(
+            return fused_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -241,26 +256,75 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         num_fused_shared_experts: int = 0,
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        return moe_forward_native(
-            layer,
-            x,
-            use_grouped_topk,
-            top_k,
-            router_logits,
-            renormalize,
-            topk_group,
-            num_expert_group,
-            num_fused_shared_experts,
-            custom_routing_function,
-            correction_bias,
-        )
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        if (
+            getattr(layer, "use_intel_amx_backend", False)
+            and not apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights.to(
+                    torch.float
+                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+                topk_ids,
+                True,  # inplace
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            return moe_forward_native(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                num_fused_shared_experts,
+                custom_routing_function,
+                correction_bias,
+                activation,
+                apply_router_weight_on_input,
+                inplace,
+                no_combine,
+                routed_scaling_factor,
+            )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
-    forward_native = forward_cuda
+    forward_native = forward_cpu
 
 
 class FusedMoE(torch.nn.Module):
@@ -310,6 +374,8 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        enable_flashinfer_moe: Optional[bool] = False,
+        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()
 
@@ -320,9 +386,40 @@ class FusedMoE(torch.nn.Module):
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.num_experts = num_experts
+        self.expert_map = None
+
+        if enable_flashinfer_moe and quant_config is None:
+            logger.warning("Disable flashinfer MoE when quantization config is None.")
+            enable_flashinfer_moe = False
+            enable_ep_moe = False
+
+        self.enable_flashinfer_moe = enable_flashinfer_moe
+        if enable_ep_moe:
+            assert (
+                self.enable_flashinfer_moe
+            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
+            self.ep_size = self.tp_size
+            self.ep_rank = self.tp_rank
+            self.tp_size = 1
+            self.tp_rank = 0
+            # Create a tensor of size num_experts filled with -1
+            self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
+            # Create a expert map for the local experts
+            assert num_experts % self.ep_size == 0
+            self.local_num_experts = num_experts // self.ep_size
+            self.expert_map[
+                self.ep_rank
+                * self.local_num_experts : (self.ep_rank + 1)
+                * self.local_num_experts
+            ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
+        else:
+            self.ep_size = 1
+            self.ep_rank = 0
+            self.local_num_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
-        self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
         self.reduce_results = reduce_results
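
The expert-parallel branch above builds a global-to-local expert map: the rank keeps a length-num_experts tensor filled with -1 and writes 0..local_num_experts-1 into its own contiguous slice, so a later lookup can tell whether a given global expert lives on this rank (and under which local index). A standalone sketch of that mapping, using a hypothetical helper name but the same logic as the constructor code:

```python
import torch


def build_expert_map(num_experts: int, ep_size: int, ep_rank: int) -> torch.Tensor:
    """Global expert id -> local expert id on this EP rank, or -1 if the expert is remote."""
    assert num_experts % ep_size == 0
    local = num_experts // ep_size
    expert_map = torch.full((num_experts,), -1, dtype=torch.int32)
    expert_map[ep_rank * local : (ep_rank + 1) * local] = torch.arange(
        local, dtype=torch.int32
    )
    return expert_map


# e.g. 8 experts over ep_size=4, rank 2 -> tensor([-1, -1, -1, -1, 0, 1, -1, -1])
```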
@@ -340,7 +437,6 @@ class FusedMoE(torch.nn.Module):
         self.use_presharded_weights = use_presharded_weights
         self.inplace = inplace
         self.no_combine = no_combine
-        self.local_num_experts = num_experts
 
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = (
@@ -348,11 +444,13 @@ class FusedMoE(torch.nn.Module):
            )
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
+            if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
+                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
        assert self.quant_method is not None
 
        self.quant_method.create_weights(
            layer=self,
-            num_experts=num_experts,
+            num_experts=self.local_num_experts,
            hidden_size=hidden_size,
            # FIXME: figure out which intermediate_size to use
            intermediate_size=self.intermediate_size_per_partition,
@@ -446,12 +544,15 @@ class FusedMoE(torch.nn.Module):
 
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
-        if shard_id == "w1":
-            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
         # w3, up_proj: Load into second logical weight of w13.
+        # trtllm cutlass kernel assumes differently
+        assert shard_id in ("w1", "w3")
+        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
+        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
+            start = shard_size
         else:
-            assert shard_id == "w3"
-            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
+            start = 0
+        expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
@@ -505,6 +606,11 @@ class FusedMoE(torch.nn.Module):
         assert shard_id in ("w1", "w3")
         expert_data.copy_(loaded_weight)
 
+    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
+        if self.expert_map is None:
+            return expert_id
+        return self.expert_map[expert_id].item()
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
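
The w13 weight-loading change above collapses the old w1/w3 branches into a single computed offset: the gate projection normally lands in the first half of the fused w13 weight and the up projection in the second half, unless the quant method (e.g. the trtllm cutlass path) requests the swapped order via load_up_proj_weight_first. A simplified sketch of that offset logic on a plain tensor, under the assumption that w13 stacks the two projections along shard_dim:

```python
import torch


def load_w13_shard(
    w13: torch.Tensor,         # fused weight, e.g. [2 * shard_size, hidden]
    loaded: torch.Tensor,      # one shard: gate_proj ("w1") or up_proj ("w3")
    shard_id: str,
    shard_dim: int = 0,
    switch_w13: bool = False,  # True when the kernel expects up_proj first
) -> None:
    assert shard_id in ("w1", "w3")
    shard_size = w13.shape[shard_dim] // 2
    # Second half for w3 (or for w1 when the order is switched).
    if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
        start = shard_size
    else:
        start = 0
    w13.narrow(shard_dim, start, shard_size).copy_(loaded)
```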
@@ -513,6 +619,13 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
+        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+        if expert_id == -1:
+            return
+
+        # TP rank is set to 0 if EP is enabled
+        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
@@ -537,7 +650,6 @@ class FusedMoE(torch.nn.Module):
         SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
 
         expert_data = param.data[expert_id]
-        tp_rank = get_tensor_model_parallel_rank()
 
         # is_transposed: if the dim to shard the weight
         # should be flipped. Required by GPTQ, compressed-tensors
@@ -545,7 +657,7 @@ class FusedMoE(torch.nn.Module):
         is_transposed = getattr(param, "is_transposed", False)
         shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
         if is_transposed:
-            shard_dim = ~shard_dim
+            shard_dim = int(not shard_dim)
 
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
@@ -686,9 +798,19 @@ class FusedMoE(torch.nn.Module):
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
             routed_scaling_factor=self.routed_scaling_factor,
+            **(
+                dict(
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                    ep_rank=self.ep_rank,
+                    ep_size=self.ep_size,
+                )
+                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
+                else {}
+            ),
         )
 
-        if self.reduce_results and self.tp_size > 1:
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states
@@ -28,19 +28,34 @@ from sglang.srt.managers.expert_location_dispatch import (
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    get_compiler_backend,
+    is_cpu,
+    is_cuda,
+    is_hip,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
 
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
+if _use_aiter:
+    try:
+        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
+    except ImportError:
+        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 
 
-def fused_topk_native(
+def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
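
In this hunk fused_topk_native becomes fused_topk_torch_native, and the following hunk adds a CPU variant that routes to torch.ops.sgl_kernel.topk_softmax_cpu. For reference, the torch-native path amounts to a softmax over the router logits followed by a top-k and an optional renormalization; roughly (a paraphrase for illustration, not the function as shipped):

```python
import torch


def fused_topk_reference(gating_output: torch.Tensor, topk: int, renormalize: bool):
    """Rough equivalent of the torch-native top-k routing path."""
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Make the selected expert weights sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)
```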
@@ -61,6 +76,20 @@
     return topk_weights, topk_ids
 
 
+def fused_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    return torch.ops.sgl_kernel.topk_softmax_cpu(
+        hidden_states=hidden_states,
+        gating_output=gating_output,
+        topk=topk,
+        renormalize=renormalize,
+    )
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -115,7 +144,7 @@ def _fused_topk_postprocess(
 
 # This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
-def grouped_topk(
+def grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     topk: int,
@@ -171,6 +200,32 @@
     return topk_weights, topk_ids
 
 
+def grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
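
grouped_topk (now grouped_topk_gpu, with the grouped_topk_cpu counterpart above dispatching to a sgl_kernel CPU op) implements DeepSeek-style grouped routing: experts are partitioned into groups, the best-scoring groups are selected first, and the per-token top-k is then taken only inside those groups. A compact reference of that selection logic, ignoring shared-expert fusion, correction bias, and scaling factors (illustrative only, not the shipped kernel):

```python
import torch


def grouped_topk_reference(
    gating_output: torch.Tensor,  # [num_tokens, num_experts]
    topk: int,
    renormalize: bool,
    num_expert_group: int,
    topk_group: int,
):
    scores = torch.softmax(gating_output.float(), dim=-1)
    num_tokens, num_experts = scores.shape
    # Score each group by its best expert, keep the topk_group best groups.
    group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    # Zero out experts that belong to the discarded groups, then take top-k.
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
        .reshape(num_tokens, num_experts)
    )
    masked_scores = scores.masked_fill(score_mask == 0, 0.0)
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids.to(torch.int32)
```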
@@ -258,7 +313,7 @@ def _biased_grouped_topk_postprocess(
     return topk_ids
 
 
-def biased_grouped_topk(
+def biased_grouped_topk_gpu(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -299,6 +354,25 @@ def biased_grouped_topk(
            topk_ids, expert_location_dispatch_info, num_token_non_padded
        )
        return topk_weights, topk_ids
+    elif _use_aiter:
+        token = gating_output.shape[0]
+        device = gating_output.device
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        aiter_biased_grouped_topk(
+            gating_output,
+            correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
@@ -322,6 +396,45 @@
     )
 
 
+def biased_grouped_topk_cpu(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    num_fused_shared_experts: int = 0,
+    routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+):
+    assert expert_location_dispatch_info is None
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        num_fused_shared_experts,
+        routed_scaling_factor,
+        num_token_non_padded,
+    )
+
+
+if _is_cpu and _is_cpu_amx_available:
+    biased_grouped_topk = biased_grouped_topk_cpu
+    grouped_topk = grouped_topk_cpu
+    fused_topk_native = fused_topk_cpu
+else:
+    biased_grouped_topk = biased_grouped_topk_gpu
+    grouped_topk = grouped_topk_gpu
+    fused_topk_native = fused_topk_torch_native
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
@@ -14,14 +14,18 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_qu
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
+    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
 
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 
-if not _is_cuda:
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant