sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py +115 -25

@@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple
 
 import einops
 import torch
-from sgl_kernel import silu_and_mul
 from torch.nn import Module
 
 from sglang.srt.custom_op import CustomOp
@@ -11,6 +10,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -40,24 +41,34 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     sglang_per_token_quant_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
+    is_npu,
     set_weight_attrs,
 )
 
 _is_hip = is_hip()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if not _is_npu:
+    from sgl_kernel import silu_and_mul
 
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
+if _use_aiter:
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)
 
 
@@ -1046,6 +1057,15 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 w2_weight_scale, requires_grad=False
             )
             layer.w2_input_scale = None
+            if _use_aiter:
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
             return
 
     def apply(
@@ -1117,18 +1137,36 @@ class DeepEPMoE(EPMoE):
         assert (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
         ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
+        if _use_aiter:
+            # expert_mask is of size (self.num_experts_per_partition + 1);
+            # the extra 1 is for the invalid rank_id (in original deepep the invalid rank_id is -1, but aiter does not allow -1, so we use a mask to make those ids invalid)
+            # for instance, if we have 4 experts on this rank, we would have an expert_mask like:
+            #     self.expert_mask = [1, 1, 1, 1, 0]
+            # idx 0-3 is valid and will be processed, while idx == 4 will be masked out
+            self.expert_mask = torch.zeros(
+                (self.num_experts_per_partition + 1),
+                device=torch.cuda.current_device(),
+                dtype=torch.int,
+            )
+            # the last one is the invalid rank_id
+            self.expert_mask[:-1] = 1
+        else:
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                (
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                (
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+            )
 
     def forward(
         self,
@@ -1140,9 +1178,14 @@ class DeepEPMoE(EPMoE):
         masked_m: torch.Tensor,
         expected_m: int,
         num_recv_tokens_per_expert: List[int],
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        if _use_aiter:
+            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside the aiter kernel
+            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
@@ -1274,6 +1317,37 @@ class DeepEPMoE(EPMoE):
         )
         return down_output
 
+    def forward_aiter(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states
+        # in original deepep, idx == -1 means invalid and will not be processed.
+        # aiter does not accept -1, so we use an expert mask to make these idx invalid
+        # (idx == num_experts_per_partition means "not used" in aiter fused_moe)
+        topk_idx_copy = topk_idx.to(torch.int32)
+        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+
+        return fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            topk_weights,
+            topk_idx_copy,
+            w1_scale=self.w13_weight_scale_inv,
+            w2_scale=self.w2_weight_scale_inv,
+            quant_type=QuantType.per_128x128,
+            activation=(
+                ActivationType.Silu
+                if self.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            expert_mask=self.expert_mask,
+        )
+
     def forward_deepgemm_contiguous(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
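
As the comments in the hunk above describe, DeepEP marks tokens routed to experts on other ranks with expert id -1, while aiter's fused_moe wants every id to be a valid index, so an extra always-masked expert slot absorbs them. A minimal standalone sketch of that remapping, with made-up sizes (4 local experts, top-2 routing):

import torch

num_experts_per_partition = 4  # experts owned by this rank (illustrative)
# one extra slot at the end acts as the "invalid" expert; the mask disables it
expert_mask = torch.zeros(num_experts_per_partition + 1, dtype=torch.int)
expert_mask[:-1] = 1  # -> tensor([1, 1, 1, 1, 0])

# DeepEP-style routing result: -1 means "routed to an expert on another rank"
topk_idx = torch.tensor([[0, 2], [3, -1]], dtype=torch.int32)
topk_idx_remapped = topk_idx.clone()
topk_idx_remapped[topk_idx_remapped == -1] = num_experts_per_partition
# -> tensor([[0, 2], [3, 4]]); slot 4 is ignored because expert_mask[4] == 0
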
@@ -1303,10 +1377,19 @@ class DeepEPMoE(EPMoE):
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-            torch.empty(
-                (all_tokens, K // 128),
-                device=hidden_states_fp8.device,
-                dtype=torch.float32,
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
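
The conditional buffer above only changes how the per-group scales are stored. With DEEPGEMM_SCALE_UE8M0 enabled there is still one scale per 128-wide group along K, but the scales are (presumably) one-byte UE8M0 exponents packed four to an int32, which is where ceil_div(K // 128, 4) comes from; otherwise each scale stays a full float32. A quick arithmetic check with an illustrative hidden size:

# illustrative numbers, not taken from the diff
K = 7168
num_groups = K // 128              # 56 quantization groups per token
packed_cols = -(-num_groups // 4)  # ceil_div(56, 4) = 14 int32 columns
assert packed_cols == 14
# float32 scales: 56 * 4 bytes = 224 bytes per token
# packed UE8M0 scales: 14 * 4 bytes = 56 bytes per token (one byte per group)
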
@@ -1332,6 +1415,7 @@ class DeepEPMoE(EPMoE):
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)
 
@@ -1340,7 +1424,8 @@ class DeepEPMoE(EPMoE):
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
@@ -1361,10 +1446,15 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input, scale_block_size
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
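
sglang_per_token_group_quant_fp8, called above on the intermediate activations, is a fused kernel; the math it implements is ordinary per-token-group quantization with group size 128 (the new scale_* keyword arguments only select the UE8M0 scale layout discussed earlier). A naive PyTorch reference of the core idea, for orientation only and assuming a torch build with float8_e4m3fn support:

import torch

def per_token_group_quant_fp8_reference(x: torch.Tensor, group_size: int = 128):
    # x: [num_tokens, K], K divisible by group_size (reference only, not the kernel)
    num_tokens, K = x.shape
    groups = x.view(num_tokens, K // group_size, group_size)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # one scale per (token, group): the group's absmax maps onto the fp8 range
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (groups / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q.view(num_tokens, K), scales.squeeze(-1)
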
sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19

@@ -1,12 +1,16 @@
 import logging
 from dataclasses import dataclass
 
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import DeepEPMode, get_int_env_var, load_json_config
+from sglang.srt.utils import (
+    DeepEPMode,
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    load_json_config,
+)
 
 try:
     from deep_ep import Buffer, Config
@@ -30,7 +34,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -238,7 +244,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
@@ -376,6 +388,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
         https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
         """
+        if _use_aiter:
+            # skip permutation here, as aiter fused_moe has it fused inside
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
 
         reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
             topk_idx, self.num_experts
@@ -409,7 +430,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
            output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
@@ -665,21 +686,21 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode = None,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(forward_mode).dispatch_a(
+        inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._dispatch_intermediate_state = forward_mode, inner_state
+        self._dispatch_intermediate_state = forward_batch, inner_state
 
     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_mode, inner_state = self._dispatch_intermediate_state
+        forward_batch, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_mode).dispatch_b(*inner_state)
+        return self._get_impl(forward_batch).dispatch_b(*inner_state)
 
     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -691,24 +712,26 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_mode).combine_a(
+        inner_state = self._get_impl(forward_batch).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._combine_intermediate_state = forward_mode, inner_state
+        self._combine_intermediate_state = forward_batch, inner_state
 
     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_mode, inner_state = self._combine_intermediate_state
+        forward_batch, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_mode).combine_b(*inner_state)
+        return self._get_impl(forward_batch).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.low_latency:
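
The dispatcher keeps its split-phase API; the diff only changes what is threaded through it, from a ForwardMode to the whole ForwardBatch, so that resolve() can read is_extend_in_batch. A rough sketch of how a caller drives the two phases, based on the method signatures above (the surrounding experts object and the exact values flowing between the phases are assumptions that vary with the resolved DeepEP mode):

def moe_layer_step(dispatcher, experts, hidden_states, topk_idx, topk_weights, forward_batch):
    # The *_a / *_b split lets a caller overlap communication with other work
    # (e.g. two-batch overlap); here the phases simply run back to back.
    dispatcher.dispatch_a(
        hidden_states=hidden_states,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        forward_batch=forward_batch,  # was forward_mode before this release
    )
    dispatched = dispatcher.dispatch_b()

    expert_output = experts(*dispatched, forward_batch=forward_batch)  # schematic

    dispatcher.combine_a(
        hidden_states=expert_output,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        forward_batch=forward_batch,
    )
    return dispatcher.combine_b()
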
sglang/srt/layers/moe/fused_moe_native.py +7 -0

@@ -77,8 +77,15 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    inplace: bool = True,
+    no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError()
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4

@@ -12,7 +12,6 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    ceil_div,
     cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
@@ -32,7 +32,6 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
-    log_info_on_rank0,
     next_power_of_2,
 )
 
@@ -750,9 +749,11 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
-        max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
+    sorted_ids.fill_(topk_ids.numel())
+
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -768,6 +769,9 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
+        cumsum_buffer = torch.empty(
+            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+        )
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
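
The buffer sizing in moe_align_block_size rests on a worst-case bound: token-expert pairs are grouped by expert and each expert's segment is padded up to a multiple of block_size, so at most block_size - 1 pad slots per expert are added on top of the topk_ids.numel() real entries; filling sorted_ids with topk_ids.numel() marks the pad slots with an out-of-range index. A small worked example with illustrative sizes:

num_tokens, top_k, num_experts, block_size = 4, 2, 4, 16
numel = num_tokens * top_k                                      # 8 token-expert pairs
max_num_tokens_padded = numel + num_experts * (block_size - 1)  # 8 + 4 * 15 = 68
max_num_m_blocks = -(-max_num_tokens_padded // block_size)      # what triton.cdiv computes: ceil(68 / 16) = 5
assert (max_num_tokens_padded, max_num_m_blocks) == (68, 5)
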
sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10

@@ -12,13 +12,22 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_hip,
+    set_weight_attrs,
+    use_intel_amx_backend,
+)
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,6 +37,8 @@ else:
 import logging
 
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
@@ -117,6 +128,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             requires_grad=False,
         )
         torch.cuda.empty_cache()
+
+        # Pack weight to get better performance on CPU
+        if _is_cpu and _is_cpu_amx_available:
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
         return
 
     def apply(
@@ -247,6 +263,81 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights.to(
+                    torch.float
+                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+                topk_ids,
+                False,  # inplace # See [Note] inplace should be False in fused_experts.
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            return moe_forward_native(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                num_fused_shared_experts,
+                custom_routing_function,
+                correction_bias,
+                activation,
+                apply_router_weight_on_input,
+                inplace,
+                no_combine,
+                routed_scaling_factor,
+            )
+
+    def forward_npu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return moe_forward_native(
             layer,
@@ -260,6 +351,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_fused_shared_experts,
             custom_routing_function,
             correction_bias,
+            activation,
+            apply_router_weight_on_input,
+            inplace,
+            no_combine,
+            routed_scaling_factor,
         )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
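
forward_cpu, the new forward_npu, and forward_tpu are not called directly; CustomOp-style layers bind one backend-specific forward when they are constructed and route calls through it (sglang/srt/custom_op.py is also touched in this release). A simplified sketch of that dispatch pattern, with assumed method names mirroring the ones above:

import torch

class CustomOpSketch(torch.nn.Module):
    # Simplified stand-in for the CustomOp idea: pick one concrete forward per
    # platform at construction time instead of branching on every call.
    def __init__(self):
        super().__init__()
        if torch.cuda.is_available():
            self._forward_method = self.forward_cuda
        else:
            self._forward_method = self.forward_cpu  # forward_npu / forward_tpu analogous

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cpu(self, *args, **kwargs):
        # default: fall back to the pure-PyTorch implementation
        return self.forward_native(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError
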
@@ -478,11 +574,6 @@ class FusedMoE(torch.nn.Module):
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
 
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
-            )
-
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
@@ -493,7 +584,24 @@ class FusedMoE(torch.nn.Module):
             start = shard_size
         else:
             start = 0
-        expert_data = expert_data.narrow(shard_dim, start, shard_size)
+
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                start,
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
+            )
+        else:
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
+
+            expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
@@ -510,10 +618,21 @@ class FusedMoE(torch.nn.Module):
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
 
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
             )
+        else:
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
 
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
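
The _load_w13 logic above mixes two concerns: picking the row offset inside the fused gate/up parameter (start = 0 for w1/gate_proj, start = shard_size for w3/up_proj) and slicing out this rank's shard of the checkpoint tensor. A self-contained sketch of the non-CPU branch with illustrative shapes (2 tensor-parallel ranks, a fused gate+up parameter of 8 rows per rank):

import torch

tp_rank = 1
w13_param = torch.zeros(8, 4)          # this rank's fused [gate; up] parameter
shard_size = w13_param.shape[0] // 2   # 4 rows for gate, 4 rows for up

def load_w13_shard(loaded_weight: torch.Tensor, shard_id: str):
    # loaded_weight: full (unsharded) gate_proj or up_proj tensor, here [8, 4]
    shard = loaded_weight.narrow(0, shard_size * tp_rank, shard_size)  # this rank's rows
    start = 0 if shard_id == "w1" else shard_size                      # w1 = gate, w3 = up
    w13_param.narrow(0, start, shard_size).copy_(shard)

load_w13_shard(torch.arange(32.0).reshape(8, 4), "w1")  # gate_proj
load_w13_shard(torch.arange(32.0).reshape(8, 4), "w3")  # up_proj
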