sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/cutlass_w4a8_moe.py (new file)
@@ -0,0 +1,215 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Cutlass W4A8 MoE kernel."""
+from typing import Optional
+
+import torch
+from sgl_kernel import (
+    cutlass_w4a8_moe_mm,
+    get_cutlass_w4a8_moe_mm_data,
+    sgl_per_tensor_quant_fp8,
+    silu_and_mul,
+)
+
+from sglang.srt.layers.moe.ep_moe.kernels import (
+    post_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
+)
+
+
+def cutlass_w4a8_moe(
+    start_expert_id: int,
+    end_expert_id: int,
+    total_num_experts: int,
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    local_topk_ids: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [M, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert (
+        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
+        and w1_scale.shape[2] == w1_q.shape[1] * 4
+    ), "W1 scale shape mismatch"
+    assert (
+        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
+        and w2_scale.shape[2] == w2_q.shape[1] * 4
+    ), "W2 scale shape mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    if apply_router_weight_on_input:
+        assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
+
+    device = a.device
+
+    _, src2dst, _ = run_cutlass_moe_ep_preproess(
+        local_topk_ids,
+        num_experts,
+    )
+
+    gateup_input = torch.empty(
+        (m * topk, k),
+        device=device,
+        dtype=torch.float8_e4m3fn,
+    )
+
+    pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
+        a,
+        gateup_input,
+        src2dst,
+        local_topk_ids,
+        a1_scale,
+        total_num_experts,
+        topk,
+        k,
+        BLOCK_SIZE=512,
+    )
+
+    # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
+    # they are kept to allow for a quick switch of the permutation logic
+    # from the current triton kernel implementation to the cutlass-based one if needed.
+    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    get_cutlass_w4a8_moe_mm_data(
+        local_topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a_map,
+        c_map,
+        num_experts,
+        n,
+        k,
+    )
+
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        s_strides13,
+        128,
+        topk,
+    )
+
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    silu_and_mul(c1, intermediate)
+
+    intermediate_q = torch.empty(
+        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
+    )
+    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
+
+    cutlass_w4a8_moe_mm(
+        c2,
+        intermediate_q,
+        w2_q,
+        a2_scale.float(),
+        w2_scale,
+        expert_offsets[:-1],
+        problem_sizes2,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides2,
+        128,
+        topk,
+    )
+
+    output = torch.empty_like(a)
+    post_reorder_triton_kernel[(m,)](
+        c2,
+        output,
+        src2dst,
+        topk_ids_,
+        topk_weights,
+        start_expert_id,
+        end_expert_id,
+        topk,
+        k,
+        0,
+        BLOCK_SIZE=512,
+    )
+    return output
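
For orientation, below is a minimal, unquantized PyTorch sketch of the computation cutlass_w4a8_moe implements: dispatch each token to its top-k experts, run the gate/up projection, apply silu_and_mul, run the down projection, and combine with the router weights. The reference_moe name and the unpacked w1/w2 layouts are illustrative assumptions, not part of the diff; the real kernel operates on int4-packed weights with fp8 activations via CUTLASS grouped GEMM.

import torch
import torch.nn.functional as F


def reference_moe(
    a: torch.Tensor,             # [M, K] input activations
    w1: torch.Tensor,            # [E, 2 * N, K] gate/up projection per expert
    w2: torch.Tensor,            # [E, K, N] down projection per expert
    topk_weights: torch.Tensor,  # [M, topk] router weights
    topk_ids: torch.Tensor,      # [M, topk] expert indices
) -> torch.Tensor:
    out = torch.zeros_like(a)
    for t in range(a.shape[0]):
        for slot in range(topk_ids.shape[1]):
            e = int(topk_ids[t, slot])
            gate_up = w1[e] @ a[t]       # first grouped GEMM: [2 * N]
            gate, up = gate_up.chunk(2)
            hidden = F.silu(gate) * up   # silu_and_mul
            out[t] += topk_weights[t, slot] * (w2[e] @ hidden)  # second grouped GEMM
    return out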
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -6,6 +6,7 @@ import triton
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import ceil_div, dispose_tensor, is_cuda
+from sglang.utils import is_in_ci
 
 logger = logging.getLogger(__name__)
 
@@ -146,6 +147,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
 
 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
 
@@ -158,9 +160,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     compute_src2dst_triton_kernel[grid](
         reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
     )
+
     return reorder_topk_ids, src2dst, seg_indptr
 
 
+def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
+
+    seg_indptr = torch.zeros(
+        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
+    )
+    src2dst = torch.empty(
+        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
+    )
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
+    compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
+    )
+
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
+@triton.jit
+def pre_reorder_triton_kernel_for_cutlass_moe(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id != num_experts:
+            if a1_scales_ptr is not None:
+                scale = 1.0 / tl.load(a1_scales_ptr)
+            else:
+                scale = 1.0
+
+            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + tl.arange(0, BLOCK_SIZE)
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+                out_data = (in_data * scale).to(OutDtype)
+                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+
 @triton.jit
 def pre_reorder_triton_kernel(
     input_ptr,
@@ -1000,7 +1059,7 @@ def ep_gather(
     input_index: torch.Tensor,
     output_tensor: torch.Tensor,
 ):
-    BLOCK_D = 1024  # block size of quantization
+    BLOCK_D = 1024 if not is_in_ci() else 128  # block size of quantization
     num_warps = 2
     num_tokens = output_tensor.shape[0]
     hidden_size = input_tensor.shape[1]
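
The run_cutlass_moe_ep_preproess helper added above sorts the flattened top-k expert ids and produces a src2dst map from each (token, slot) pair to its row in the expert-grouped activation buffer; pre_reorder_triton_kernel_for_cutlass_moe then scatters and fp8-scales each token into those rows. A plain-PyTorch sketch of that bookkeeping, under the assumption that compute_src2dst_triton_kernel writes the inverse permutation of the stable sort:

import torch

topk_ids = torch.tensor([[0, 2], [1, 0], [2, 2]])  # [num_tokens, topk]
flat = topk_ids.view(-1)
reorder_topk_ids, reorder_ids = torch.sort(flat, stable=True)

# src2dst[i]: destination row of flattened slot i in the expert-sorted buffer
src2dst = torch.empty_like(reorder_ids, dtype=torch.int32)
src2dst[reorder_ids] = torch.arange(flat.numel(), dtype=torch.int32)

print(reorder_topk_ids.tolist())  # [0, 0, 1, 2, 2, 2] -> rows grouped by expert id
print(src2dst.tolist())           # [0, 3, 2, 1, 4, 5]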
sglang/srt/layers/moe/ep_moe/layer.py
@@ -20,6 +20,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
     run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
     silu_and_mul_triton_kernel,
@@ -41,6 +43,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     sglang_per_token_quant_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
@@ -61,6 +64,8 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if not _is_npu:
     from sgl_kernel import silu_and_mul
 
+    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
+
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
@@ -191,7 +196,7 @@ class EPMoE(torch.nn.Module):
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
         self.num_fused_shared_experts = num_fused_shared_experts
-        self.num_experts_per_partition = self.num_experts // self.tp_size
+        self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
        self.start_expert_id = self.tp_rank * self.num_experts_per_partition
        self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
 
@@ -215,6 +220,18 @@ class EPMoE(torch.nn.Module):
             self.use_block_quant = False
             self.block_shape = None
             self.activation_scheme = None
+            self.use_w4afp8 = False
+        elif isinstance(quant_config, W4AFp8Config):
+            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
+                quant_config
+            )
+            self.use_w4afp8 = True
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.fp8_dtype = torch.float8_e4m3fn
+            self.w13_weight_scale = None
+            self.w2_weight_scale = None
+            self.activation_scheme = quant_config.moe_activation_scheme
         else:
             self.quant_method: Optional[QuantizeMethodBase] = Fp8EPMoEMethod(
                 quant_config
@@ -228,6 +245,7 @@ class EPMoE(torch.nn.Module):
             )
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
+            self.use_w4afp8 = False
 
         self.quant_method.create_weights(
             layer=self,
@@ -253,6 +271,49 @@ class EPMoE(torch.nn.Module):
             self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
         )
 
+    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
+    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
+    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
+        """
+        Calculates how many experts should be assigned to each rank for EP and
+        creates a mapping from global to local expert index. Experts are
+        distributed evenly across ranks. Any remaining are assigned to the
+        last rank.
+
+        Returns:
+            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
+                - local_num_experts (int): The number of experts assigned
+                  to the current rank.
+                - expert_map (Optional[torch.Tensor]): A tensor of shape
+                  (global_num_experts,) mapping from global to local index.
+                  Contains global_num_experts for experts not assigned to the current rank.
+                  Returns None if ep_size is 1.
+        """
+        ep_size = self.tp_size
+        ep_rank = self.tp_rank
+        global_num_experts = self.num_experts
+
+        assert ep_size > 0
+        if ep_size == 1:
+            return (global_num_experts, None)
+
+        local_num_experts = global_num_experts // ep_size
+
+        expert_map = torch.full(
+            (global_num_experts,), self.num_experts, dtype=torch.int32
+        )
+        if ep_rank < (ep_size - 1):
+            expert_map[
+                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
+            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
+        else:
+            local_num_experts = global_num_experts - ep_rank * local_num_experts
+
+            expert_map[-local_num_experts:] = torch.arange(
+                0, local_num_experts, dtype=torch.int32
+            )
+        return (local_num_experts, expert_map)
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, router_logits)
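
To make the new expert mapping concrete: determine_expert_map splits the global experts evenly across EP ranks, gives any remainder to the last rank, and fills unassigned positions with the sentinel value num_experts (instead of vLLM's -1). The following standalone sketch mirrors that logic for a hypothetical 10 experts on 4 ranks; the function name and sizes are illustrative only.

from typing import Optional, Tuple

import torch


def determine_expert_map_demo(
    ep_size: int, ep_rank: int, global_num_experts: int
) -> Tuple[int, Optional[torch.Tensor]]:
    # Even split, remainder to the last rank, sentinel = global_num_experts
    if ep_size == 1:
        return global_num_experts, None
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts,), global_num_experts, dtype=torch.int32)
    if ep_rank < ep_size - 1:
        start = ep_rank * local_num_experts
        expert_map[start : start + local_num_experts] = torch.arange(
            local_num_experts, dtype=torch.int32
        )
    else:
        local_num_experts = global_num_experts - ep_rank * local_num_experts
        expert_map[-local_num_experts:] = torch.arange(local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map


print(determine_expert_map_demo(4, 0, 10))  # 2 local experts, map [0, 1, 10, 10, ..., 10]
print(determine_expert_map_demo(4, 3, 10))  # 4 local experts, map [10, ..., 10, 0, 1, 2, 3]

In the w4afp8 forward path below, indexing expert_map with topk_ids then effectively yields local_topk_ids, with the sentinel marking tokens routed to experts owned by other ranks.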
@@ -440,6 +501,51 @@
             ),
         )
 
+        if self.use_w4afp8:
+            local_topk_ids = topk_ids
+            if self.expert_map is not None:
+                "Translate info from expert_map to topk_ids"
+                local_topk_ids = torch.where(
+                    self.expert_map[topk_ids] != self.num_experts,
+                    self.expert_map[topk_ids],
+                    self.num_experts,
+                )
+
+            output = cutlass_w4a8_moe(
+                self.start_expert_id,
+                self.end_expert_id,
+                self.num_experts,
+                hidden_states,
+                self.w13_weight,
+                self.w2_weight,
+                self.w13_weight_scale_inv,
+                self.w2_weight_scale_inv,
+                topk_weights,
+                topk_ids,
+                local_topk_ids,
+                self.quant_method.a_strides1,
+                self.quant_method.b_strides1,
+                self.quant_method.c_strides1,
+                self.quant_method.a_strides2,
+                self.quant_method.b_strides2,
+                self.quant_method.c_strides2,
+                self.quant_method.s_strides13,
+                self.quant_method.s_strides2,
+                self.quant_method.expert_offsets,
+                self.quant_method.problem_sizes1,
+                self.quant_method.problem_sizes2,
+                self.w13_input_scale,
+                self.w2_input_scale,
+            )
+            return output
+
+        if self.grouped_gemm_runner is None:
+            self.grouped_gemm_runner = GroupedGemmRunner(
+                hidden_states.device,
+                use_flashinfer=False,  # TODO: use flashinfer
+                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
+            )
+
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
             topk_ids, self.num_experts
         )
@@ -449,7 +555,7 @@
             device=hidden_states.device,
             dtype=(
                 self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                if ((self.use_fp8_w8a8 or self.use_w4afp8) and not self.use_block_quant)
                 else hidden_states.dtype
             ),
         )
@@ -656,6 +762,23 @@
             ]
         ]
 
+    @classmethod
+    def make_expert_input_scale_params_mapping(
+        cls,
+        num_experts: int,
+    ) -> List[Tuple[str, str, int, str]]:
+        # (param_name, weight_name, expert_id, shard_id)
+        return [
+            (
+                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
+                f"experts.{expert_id}.{shard_id}.",
+                expert_id,
+                shard_id,
+            )
+            for expert_id in range(num_experts)
+            for shard_id in ["w1", "w2", "w3"]
+        ]
+
     def weight_loader(
         self,
         param: torch.nn.Parameter,
@@ -727,6 +850,15 @@
 
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
+            if self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][0] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][1] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
+                return
+
             if (
                 (shard_id == "w1" or shard_id == "w3")
                 and param_data[expert_id] != 1
@@ -752,6 +884,13 @@
                    ] = loaded_weight
                else:  # w2
                    param_data[expert_id] = loaded_weight
+            elif self.use_w4afp8:
+                if shard_id == "w1":
+                    param_data[expert_id][: self.intermediate_size, :] = loaded_weight
+                elif shard_id == "w3":
+                    param_data[expert_id][self.intermediate_size :, :] = loaded_weight
+                else:
+                    param_data[expert_id] = loaded_weight
            # If we are in merged column case (gate_up_proj)
            else:
                if shard_id in ("w1", "w3"):
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -1737,6 +1737,7 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_grouped_topk: bool = False,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
@@ -1822,6 +1823,7 @@
         topk_ids,
         inplace=inplace,
         activation=activation,
+        apply_router_weight_on_input=apply_router_weight_on_input,
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
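
The new apply_router_weight_on_input argument is only threaded through to the underlying kernel here. Operationally, and as the cutlass_w4a8_moe docstring notes for the top-1 case, it moves the router-weight multiplication from the expert output to the expert input. The sketch below contrasts the two conventions with a toy nonlinear expert; all names here are illustrative, not part of the diff.

import torch
import torch.nn.functional as F


def toy_expert(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the gate/up + SiLU + down expert MLP
    return F.silu(x) * x


x = torch.randn(4, 8)   # [num_tokens, hidden]
w = torch.rand(4, 1)    # one router weight per token (topk == 1)

out_on_output = w * toy_expert(x)  # default: scale the expert output
out_on_input = toy_expert(w * x)   # apply_router_weight_on_input: scale the input

# Not numerically identical for a nonlinear expert; the flag matches the
# convention a given checkpoint was trained with.
print((out_on_output - out_on_input).abs().max())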