sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
@@ -4,9 +4,8 @@ from typing import List, Optional
 import torch
 import triton
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import dispose_tensor, is_cuda
+from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -147,6 +146,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
 
 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
 
@@ -159,9 +159,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     compute_src2dst_triton_kernel[grid](
         reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
     )
+
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
+def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
+
+    seg_indptr = torch.zeros(
+        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
+    )
+    src2dst = torch.empty(
+        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
+    )
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
+    compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
+    )
+
     return reorder_topk_ids, src2dst, seg_indptr
 
 
+@triton.jit
+def pre_reorder_triton_kernel_for_cutlass_moe(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id != num_experts:
+            if a1_scales_ptr is not None:
+                scale = 1.0 / tl.load(a1_scales_ptr)
+            else:
+                scale = 1.0
+
+            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + tl.arange(0, BLOCK_SIZE)
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+                out_data = (in_data * scale).to(OutDtype)
+                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+
 @triton.jit
 def pre_reorder_triton_kernel(
     input_ptr,
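
Note on the two helpers added above in sglang/srt/layers/moe/ep_moe/kernels.py: run_cutlass_moe_ep_preproess is the rank-local twin of run_moe_ep_preproess, and the new Triton kernel only permutes (and optionally rescales) tokens into expert-contiguous order. A rough CPU-side sketch of the same bookkeeping, useful for checking shapes; this is illustrative only and not the code path used at runtime:

    import torch

    def ep_preprocess_reference(topk_ids: torch.Tensor, num_experts: int):
        # Sort the flattened (token, slot) -> expert assignments by expert id.
        reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
        # seg_indptr[e] .. seg_indptr[e + 1] delimits the slots routed to expert e.
        seg_indptr = torch.zeros(num_experts + 1, dtype=torch.int64)
        counts = torch.bincount(reorder_topk_ids, minlength=num_experts)
        seg_indptr[1:] = torch.cumsum(counts, dim=0)
        # src2dst[i] tells where flattened slot i lands in the reordered buffer.
        src2dst = torch.empty_like(reorder_ids, dtype=torch.int32)
        src2dst[reorder_ids] = torch.arange(reorder_ids.numel(), dtype=torch.int32)
        return reorder_topk_ids, src2dst, seg_indptr

    topk_ids = torch.tensor([[1, 0], [2, 1]])  # 2 tokens, topk=2, 3 experts
    print(ep_preprocess_reference(topk_ids, num_experts=3))

For topk_ids = [[1, 0], [2, 1]] and 3 experts this gives seg_indptr = [0, 1, 3, 4] and src2dst = [1, 0, 3, 2], i.e. flattened slot 0 (routed to expert 1) lands at position 1 of the expert-sorted buffer.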
@@ -814,14 +871,17 @@ def _fwd_kernel_ep_scatter_2(
     offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
     mask = offset_in < HIDDEN_SIZE
 
-    offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
-    mask_s = offset_in_s < SCALE_HIDDEN_SIZE
+    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
+    mask_s = index_in_s < SCALE_HIDDEN_SIZE
 
     for token_id_int32 in range(start_token_id, total_token_num, grid_num):
         token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
-            recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
+            recv_x_scale
+            + token_id * recv_x_scale_stride0
+            + index_in_s * recv_x_scale_stride1,
+            mask=mask_s,
         )
 
         for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
@@ -842,7 +902,11 @@ def _fwd_kernel_ep_scatter_2(
                     output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                 )
                 tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
-                tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s)
+                tl.store(
+                    output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
+                    to_copy_s,
+                    mask=mask_s,
+                )
 
 
 # copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@@ -857,6 +921,7 @@ def ep_scatter(
     output_tensor_scale: torch.Tensor,
     m_indices: torch.Tensor,
     output_index: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     BLOCK_E = 128  # token num of per expert is aligned to 128
     BLOCK_D = 128  # block size of quantization
@@ -866,7 +931,15 @@ def ep_scatter(
     # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
     grid = num_experts
 
+    scale_hidden_size = hidden_size // BLOCK_D
+    if scale_ue8m0:
+        # ue8m0 scales are packed here (4 scales per int32),
+        # hence the effective size of this dimension is divided by 4.
+        scale_hidden_size = ceil_div(scale_hidden_size, 4)
+
     assert m_indices.shape[0] % BLOCK_E == 0
+    assert recv_x_scale.dtype == output_tensor_scale.dtype
+    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size
 
     _fwd_kernel_ep_scatter_1[(grid,)](
         num_recv_tokens_per_expert,
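
For the packed-scale arithmetic above: with the 128-wide quantization block there is one scale per 128 activations, and in the UE8M0 layout four such scales share one int32 word, so the scale dimension shrinks to ceil(n / 4). A quick sanity check of that arithmetic (the hidden size 7168 is just an example value, and ceil_div below is the usual rounding-up division, matching the helper imported from sglang.srt.utils):

    def ceil_div(a: int, b: int) -> int:
        return (a + b - 1) // b

    hidden_size = 7168   # example only
    BLOCK_D = 128        # quantization block size used by ep_scatter
    scale_hidden_size = hidden_size // BLOCK_D      # 56 scales per token in the fp32 layout
    packed_words = ceil_div(scale_hidden_size, 4)   # 14 int32 words when scale_ue8m0=True
    print(scale_hidden_size, packed_words)          # 56 14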
@@ -905,8 +978,8 @@ def ep_scatter(
         num_warps=num_warps,
         HIDDEN_SIZE=hidden_size,
         HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
-        SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
-        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
+        SCALE_HIDDEN_SIZE=scale_hidden_size,
+        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
     )
     return
 
@@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple
 
 import einops
 import torch
-from sgl_kernel import silu_and_mul
 from torch.nn import Module
 
 from sglang.srt.custom_op import CustomOp
@@ -11,6 +10,9 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -19,6 +21,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
     run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
@@ -40,22 +44,27 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     sglang_per_token_quant_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
+    is_npu,
     set_weight_attrs,
 )
 
 _is_hip = is_hip()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
+if not _is_npu:
+    from sgl_kernel import silu_and_mul
+
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
@@ -186,7 +195,7 @@ class EPMoE(torch.nn.Module):
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.num_experts_per_partition = self.num_experts // self.tp_size
+        self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
 
@@ -210,6 +219,18 @@ class EPMoE(torch.nn.Module):
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
+            self.use_w4afp8 = False
+        elif isinstance(quant_config, W4AFp8Config):
+            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
+                quant_config
+            )
+            self.use_w4afp8 = True
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.fp8_dtype = torch.float8_e4m3fn
+            self.w13_weight_scale = None
+            self.w2_weight_scale = None
+            self.activation_scheme = quant_config.moe_activation_scheme
         else:
             self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                 quant_config
@@ -223,6 +244,7 @@ class EPMoE(torch.nn.Module):
             )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
+            self.use_w4afp8 = False
 
         self.quant_method.create_weights(
             layer=self,
@@ -248,6 +270,49 @@ class EPMoE(torch.nn.Module):
             self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
         )
 
+    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
+    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
+    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
+        """
+        Calculates how many experts should be assigned to each rank for EP and
+        creates a mapping from global to local expert index. Experts are
+        distributed evenly across ranks. Any remaining are assigned to the
+        last rank.
+
+        Returns:
+            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
+                - local_num_experts (int): The number of experts assigned
+                  to the current rank.
+                - expert_map (Optional[torch.Tensor]): A tensor of shape
+                  (global_num_experts,) mapping from global to local index.
+                  Contains global_num_experts for experts not assigned to the current rank.
+                  Returns None if ep_size is 1.
+        """
+        ep_size = self.tp_size
+        ep_rank = self.tp_rank
+        global_num_experts = self.num_experts
+
+        assert ep_size > 0
+        if ep_size == 1:
+            return (global_num_experts, None)
+
+        local_num_experts = global_num_experts // ep_size
+
+        expert_map = torch.full(
+            (global_num_experts,), self.num_experts, dtype=torch.int32
+        )
+        if ep_rank < (ep_size - 1):
+            expert_map[
+                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
+            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
+        else:
+            local_num_experts = global_num_experts - ep_rank * local_num_experts
+
+            expert_map[-local_num_experts:] = torch.arange(
+                0, local_num_experts, dtype=torch.int32
+            )
+        return (local_num_experts, expert_map)
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, router_logits)
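
A worked example of what determine_expert_map returns, re-implemented as a free function purely for illustration (the method above reads tp_size/tp_rank/num_experts from self, and uses global_num_experts rather than vLLM's -1 as the "not on this rank" sentinel):

    import torch

    def determine_expert_map(ep_size: int, ep_rank: int, global_num_experts: int):
        # Mirrors EPMoE.determine_expert_map above, pulled out of the class.
        if ep_size == 1:
            return global_num_experts, None
        local_num_experts = global_num_experts // ep_size
        expert_map = torch.full((global_num_experts,), global_num_experts, dtype=torch.int32)
        if ep_rank < ep_size - 1:
            expert_map[ep_rank * local_num_experts:(ep_rank + 1) * local_num_experts] = \
                torch.arange(0, local_num_experts, dtype=torch.int32)
        else:
            # The last rank picks up the remainder.
            local_num_experts = global_num_experts - ep_rank * local_num_experts
            expert_map[-local_num_experts:] = torch.arange(0, local_num_experts, dtype=torch.int32)
        return local_num_experts, expert_map

    # 10 experts over 4 ranks: ranks 0-2 own 2 experts each, the last rank owns the remaining 4.
    print(determine_expert_map(ep_size=4, ep_rank=3, global_num_experts=10))
    # -> (4, tensor([10, 10, 10, 10, 10, 10, 0, 1, 2, 3], dtype=torch.int32))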
@@ -435,6 +500,51 @@ class EPMoE(torch.nn.Module):
             ),
         )
 
+        if self.use_w4afp8:
+            local_topk_ids = topk_ids
+            if self.expert_map is not None:
+                "Translate info from expert_map to topk_ids"
+                local_topk_ids = torch.where(
+                    self.expert_map[topk_ids] != self.num_experts,
+                    self.expert_map[topk_ids],
+                    self.num_experts,
+                )
+
+            output = cutlass_w4a8_moe(
+                self.start_expert_id,
+                self.end_expert_id,
+                self.num_experts,
+                hidden_states,
+                self.w13_weight,
+                self.w2_weight,
+                self.w13_weight_scale_inv,
+                self.w2_weight_scale_inv,
+                topk_weights,
+                topk_ids,
+                local_topk_ids,
+                self.quant_method.a_strides1,
+                self.quant_method.b_strides1,
+                self.quant_method.c_strides1,
+                self.quant_method.a_strides2,
+                self.quant_method.b_strides2,
+                self.quant_method.c_strides2,
+                self.quant_method.s_strides13,
+                self.quant_method.s_strides2,
+                self.quant_method.expert_offsets,
+                self.quant_method.problem_sizes1,
+                self.quant_method.problem_sizes2,
+                self.w13_input_scale,
+                self.w2_input_scale,
+            )
+            return output
+
+        if self.grouped_gemm_runner is None:
+            self.grouped_gemm_runner = GroupedGemmRunner(
+                hidden_states.device,
+                use_flashinfer=False,  # TODO: use flashinfer
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
+            )
+
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
             topk_ids, self.num_experts
         )
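
The torch.where in the W4AFp8 branch rewrites the router's global expert ids into rank-local ids before calling cutlass_w4a8_moe; ids owned by other ranks collapse to the num_experts sentinel. A tiny sketch with made-up values, reusing the last-rank expert_map from the previous example:

    import torch

    num_experts = 10
    # expert_map for the last of 4 ranks (owns global experts 6..9 as local 0..3).
    expert_map = torch.tensor([10, 10, 10, 10, 10, 10, 0, 1, 2, 3], dtype=torch.int32)

    topk_ids = torch.tensor([[6, 2], [9, 7]])  # global ids chosen by the router
    local_topk_ids = torch.where(
        expert_map[topk_ids] != num_experts,
        expert_map[topk_ids],
        num_experts,
    )
    print(local_topk_ids)
    # tensor([[0, 10], [3, 1]], dtype=torch.int32) -- global 2 is not local, so it maps to the sentinel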
@@ -444,7 +554,7 @@ class EPMoE(torch.nn.Module):
             device=hidden_states.device,
             dtype=(
                 self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
                 else hidden_states.dtype
             ),
         )
@@ -651,6 +761,23 @@ class EPMoE(torch.nn.Module):
             ]
         ]
 
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
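
The new make_expert_input_scale_params_mapping classmethod is a pure name-mapping helper pairing the fused parameter prefix with each per-expert checkpoint key. A usage sketch, assuming the module imports cleanly in your environment:

    from sglang.srt.layers.moe.ep_moe.layer import EPMoE

    for entry in EPMoE.make_expert_input_scale_params_mapping(num_experts=2):
        print(entry)
    # ('experts.w13_', 'experts.0.w1.', 0, 'w1')
    # ('experts.w2_', 'experts.0.w2.', 0, 'w2')
    # ('experts.w13_', 'experts.0.w3.', 0, 'w3')
    # ('experts.w13_', 'experts.1.w1.', 1, 'w1')
    # ('experts.w2_', 'experts.1.w2.', 1, 'w2')
    # ('experts.w13_', 'experts.1.w3.', 1, 'w3')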
@@ -722,6 +849,15 @@ class EPMoE(torch.nn.Module):
 
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
+            if self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][0] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][1] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
+                return
+
             if (
                 (shard_id == "w1" or shard_id == "w3")
                 and param_data[expert_id] != 1
@@ -747,6 +883,13 @@ class EPMoE(torch.nn.Module):
                     ] = loaded_weight
                 else:  # w2
                     param_data[expert_id] = loaded_weight
+            elif self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][: self.intermediate_size, :] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][self.intermediate_size :, :] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
             # If we are in merged column case (gate_up_proj)
             else:
                 if shard_id in ("w1", "w3"):
@@ -1173,12 +1316,14 @@ class DeepEPMoE(EPMoE):
         masked_m: torch.Tensor,
         expected_m: int,
         num_recv_tokens_per_expert: List[int],
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
@@ -1370,10 +1515,19 @@ class DeepEPMoE(EPMoE):
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-            torch.empty(
-                (all_tokens, K // 128),
-                device=hidden_states_fp8.device,
-                dtype=torch.float32,
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
@@ -1399,6 +1553,7 @@ class DeepEPMoE(EPMoE):
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)
 
@@ -1407,7 +1562,8 @@ class DeepEPMoE(EPMoE):
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
@@ -1428,10 +1584,15 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input, scale_block_size
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
@@ -1,10 +1,8 @@
 import logging
 from dataclasses import dataclass
 
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     DeepEPMode,
@@ -36,7 +34,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 
@@ -246,7 +244,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
@@ -682,21 +686,21 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode = None,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(forward_mode).dispatch_a(
+        inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._dispatch_intermediate_state = forward_mode, inner_state
+        self._dispatch_intermediate_state = forward_batch, inner_state
 
     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-        forward_mode, inner_state = self._dispatch_intermediate_state
+        forward_batch, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(forward_mode).dispatch_b(*inner_state)
+        return self._get_impl(forward_batch).dispatch_b(*inner_state)
 
     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -708,24 +712,26 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        forward_mode: ForwardMode,
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(forward_mode).combine_a(
+        inner_state = self._get_impl(forward_batch).combine_a(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
        )
-        self._combine_intermediate_state = forward_mode, inner_state
+        self._combine_intermediate_state = forward_batch, inner_state
 
     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-        forward_mode, inner_state = self._combine_intermediate_state
+        forward_batch, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(forward_mode).combine_b(*inner_state)
+        return self._get_impl(forward_batch).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
-        resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
+    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.low_latency:
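
With this change DeepEPDispatcher is keyed off the whole ForwardBatch instead of a bare ForwardMode, and it only consumes forward_batch.is_extend_in_batch when choosing the normal vs. low-latency implementation. A rough sketch of the two-phase call sequence after the rename; everything except the dispatcher method names and keyword arguments shown above is a placeholder, and the surrounding model code is elided:

    def moe_dispatch_combine(dispatcher, hidden_states, topk_idx, topk_weights, forward_batch, run_experts):
        # `run_experts` stands in for the per-rank expert computation and
        # `forward_batch` is the ForwardBatch being executed (previously a ForwardMode was passed).
        dispatcher.dispatch_a(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
        )
        dispatched = dispatcher.dispatch_b()

        expert_output = run_experts(dispatched)

        dispatcher.combine_a(
            hidden_states=expert_output,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
        )
        return dispatcher.combine_b()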
@@ -12,7 +12,6 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    ceil_div,
     cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
@@ -32,7 +32,6 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
-    log_info_on_rank0,
     next_power_of_2,
 )
 
@@ -1738,6 +1737,7 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
@@ -1823,6 +1823,7 @@ def fused_moe(
         topk_ids,
         inplace=inplace,
        activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
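
The new apply_router_weight_on_input flag is simply threaded through to the fused-experts call; it signals that the router (top-k) weight has already been multiplied into the input activations, so the kernel should not apply it again to the expert outputs. This convention is typically used with top-1 routing models (e.g. Llama 4-style MoE). Because the expert projection is linear, scaling before or after is numerically equivalent for top-1 routing, which a small standalone check illustrates (just the arithmetic, not the Triton path):

    import torch

    torch.manual_seed(0)
    x = torch.randn(4, 8)             # token activations
    w = torch.randn(8, 8)             # one expert's weight, acting as a linear map
    router_weight = torch.rand(4, 1)  # per-token routing weight, topk = 1

    out_weight_on_output = (x @ w) * router_weight
    out_weight_on_input = (x * router_weight) @ w   # router weight applied on the input

    print(torch.allclose(out_weight_on_output, out_weight_on_input, atol=1e-6))  # True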