sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +12 -1
  3. sglang/srt/conversation.py +35 -1
  4. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  5. sglang/srt/entrypoints/http_server_engine.py +1 -1
  6. sglang/srt/layers/communicator.py +3 -1
  7. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  8. sglang/srt/layers/layernorm.py +2 -2
  9. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  10. sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
  11. sglang/srt/layers/moe/ep_moe/layer.py +140 -2
  12. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  13. sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
  14. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  15. sglang/srt/layers/quantization/__init__.py +2 -0
  16. sglang/srt/layers/quantization/fp8.py +28 -7
  17. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  18. sglang/srt/layers/quantization/w4afp8.py +264 -0
  19. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  20. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  21. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  22. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  23. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  24. sglang/srt/managers/cache_controller.py +41 -195
  25. sglang/srt/managers/io_struct.py +8 -1
  26. sglang/srt/managers/mm_utils.py +4 -2
  27. sglang/srt/managers/schedule_batch.py +1 -1
  28. sglang/srt/managers/scheduler.py +17 -5
  29. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  30. sglang/srt/mem_cache/memory_pool.py +113 -63
  31. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  32. sglang/srt/mem_cache/radix_cache.py +8 -4
  33. sglang/srt/models/deepseek_v2.py +16 -2
  34. sglang/srt/models/mllama4.py +360 -79
  35. sglang/srt/multimodal/mm_utils.py +2 -2
  36. sglang/srt/multimodal/processors/mllama4.py +62 -60
  37. sglang/srt/server_args.py +15 -0
  38. sglang/srt/two_batch_overlap.py +3 -0
  39. sglang/srt/utils.py +37 -17
  40. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  41. sglang/utils.py +5 -5
  42. sglang/version.py +1 -1
  43. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
  44. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
  45. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  46. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  47. {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py CHANGED
@@ -814,9 +814,9 @@ def sample_mmmu_requests(
         List of tuples (prompt, prompt_token_len, output_token_len).
     """
     try:
-        import base64
         import io
 
+        import pybase64
         from datasets import load_dataset
     except ImportError:
         raise ImportError("Please install datasets: pip install datasets")
@@ -867,7 +867,7 @@ def sample_mmmu_requests(
                 # Encode image to base64
                 buffered = io.BytesIO()
                 image.save(buffered, format="JPEG")
-                img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+                img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
                 image_data = f"data:image/jpeg;base64,{img_str}"
             else:
                 continue
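
pybase64 is an API-compatible, faster drop-in for the stdlib base64 module, so the change above only swaps the import. A minimal sketch of the encode path, using a hypothetical in-memory PIL image instead of an MMMU sample:

    import io

    import pybase64
    from PIL import Image

    image = Image.new("RGB", (8, 8), color=(255, 0, 0))  # placeholder image, not an MMMU sample
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
    image_data = f"data:image/jpeg;base64,{img_str}"
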
sglang/srt/configs/model_config.py CHANGED
@@ -359,7 +359,17 @@ class ModelConfig:
                 if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
                     quant_cfg = modelopt_quant_config
             elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
-                quant_cfg = modelopt_quant_config
+                quant_config_file = os.path.join(
+                    self.model_path, "hf_quant_config.json"
+                )
+                with open(quant_config_file) as f:
+                    quant_config_dict = json.load(f)
+                json_quant_configs = quant_config_dict["quantization"]
+                quant_algo = json_quant_configs.get("quant_algo", None)
+                if quant_algo == "MIXED_PRECISION":
+                    quant_cfg = {"quant_method": "w4afp8"}
+                else:
+                    quant_cfg = modelopt_quant_config
         return quant_cfg
 
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
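The new branch keys off the "quant_algo" field inside the "quantization" section of hf_quant_config.json: MIXED_PRECISION selects the new w4afp8 path, anything else keeps the ModelOpt config. A standalone sketch of that decision (the config dict below is illustrative, not taken from a real checkpoint):

    # MIXED_PRECISION -> w4afp8; any other value falls back to ModelOpt.
    quant_config_dict = {"quantization": {"quant_algo": "MIXED_PRECISION"}}
    quant_algo = quant_config_dict["quantization"].get("quant_algo", None)
    if quant_algo == "MIXED_PRECISION":
        quant_cfg = {"quant_method": "w4afp8"}
    else:
        quant_cfg = {"quant_method": "modelopt"}
    print(quant_cfg)  # -> {'quant_method': 'w4afp8'}
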
@@ -389,6 +399,7 @@ class ModelConfig:
             "w8a8_fp8",
             "moe_wna16",
             "qoq",
+            "w4afp8",
         ]
         compatible_quantization_methods = {
             "modelopt_fp4": ["modelopt"],
sglang/srt/conversation.py CHANGED
@@ -921,6 +921,19 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="mimo-vl",
+        system_message="You are MiMo, an AI assistant developed by Xiaomi.",
+        system_template="<|im_start|>system\n{system_message}",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=["<|im_end|>"],
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
+
 
 register_conv_template(
     Conversation(
@@ -935,6 +948,19 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="llama_4_vision",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA4,
+        sep="",
+        stop_str="<|eot|>",
+        image_token="<|image|>",
+    )
+)
+
 
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
@@ -943,9 +969,11 @@ def match_internvl(model_path: str):
 
 
 @register_conv_template_matching_function
-def match_llama_3_vision(model_path: str):
+def match_llama_vision(model_path: str):
     if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
         return "llama_3_vision"
+    if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
+        return "llama_4_vision"
 
 
 @register_conv_template_matching_function
@@ -1034,3 +1062,9 @@ def match_phi_4_mm(model_path: str):
 def match_vila(model_path: str):
     if re.search(r"vila", model_path, re.IGNORECASE):
         return "chatml"
+
+
+@register_conv_template_matching_function
+def match_mimo_vl(model_path: str):
+    if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
+        return "mimo-vl"
sglang/srt/disaggregation/mooncake/conn.py CHANGED
@@ -185,9 +185,11 @@ class MooncakeKVManager(BaseKVManager):
                 threading.Thread(
                     target=self.transfer_worker, args=(queue, executor), daemon=True
                 ).start()
-
-            self.bootstrap_time_out = get_int_env_var(
-                "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
+            # If a timeout happens on the prefill side, it means prefill instances
+            # fail to receive the KV indices from the decode instance of this request.
+            # These timeout requests should be aborted to release the tree cache.
+            self.bootstrap_timeout = get_int_env_var(
+                "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.heartbeat_failures = {}
@@ -209,6 +211,12 @@ class MooncakeKVManager(BaseKVManager):
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
             self.prefill_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
+            # If a timeout happens on the decode side, it means decode instances
+            # fail to receive the KV Cache transfer done signal after bootstrapping.
+            # These timeout requests should be aborted to release the tree cache.
+            self.waiting_timeout = get_int_env_var(
+                "SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
+            )
         else:
             raise ValueError(
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
@@ -938,7 +946,12 @@ class MooncakeKVSender(BaseKVSender):
             if self.init_time is not None:
                 now = time.time()
                 elapsed = now - self.init_time
-                if elapsed >= self.kv_mgr.bootstrap_time_out:
+                if elapsed >= self.kv_mgr.bootstrap_timeout:
+                    logger.warning_once(
+                        "Some requests timed out when bootstrapping, "
+                        "which means prefill instances fail to receive the KV indices from the decode instance of this request. "
+                        "If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
+                    )
                     self.kv_mgr.record_failure(
                         self.bootstrap_room,
                         f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.Bootstrapping",
@@ -987,6 +1000,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
         self.conclude_state = None
+        self.init_time = None
         self.data_parallel_rank = data_parallel_rank
 
         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
@@ -1222,14 +1236,31 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 str(self.required_dst_info_num).encode("ascii"),
             ]
         )
+        self.init_time = time.time()
 
     def poll(self) -> KVPoll:
         if self.conclude_state is None:
             status = self.kv_mgr.check_status(self.bootstrap_room)
             if status in (KVPoll.Success, KVPoll.Failed):
                 self.conclude_state = status
+            elif status == KVPoll.WaitingForInput:
+                if self.init_time is not None:
+                    now = time.time()
+                    elapsed = now - self.init_time
+                    if elapsed >= self.kv_mgr.waiting_timeout:
+                        logger.warning_once(
+                            "Some requests fail to receive KV Cache transfer done signal after bootstrapping. "
+                            "If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
+                        )
+                        self.kv_mgr.record_failure(
+                            self.bootstrap_room,
+                            f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.WaitingForInput",
+                        )
+                        self.conclude_state = KVPoll.Failed
+                        return KVPoll.Failed
 
             return status
+
         else:
             return self.conclude_state
 
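Both timeouts are read once from integer environment variables when the KV manager is constructed (defaults are now 300 s). A minimal sketch of relaxing them before launching the prefill and decode instances, using the 600 s value suggested by the new warning messages:

    import os

    # Prefill side: how long a request may sit in KVPoll.Bootstrapping.
    os.environ["SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT"] = "600"
    # Decode side: how long to wait for the KV cache transfer done signal.
    os.environ["SGLANG_DISAGGREGATION_WAITING_TIMEOUT"] = "600"
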
sglang/srt/entrypoints/http_server_engine.py CHANGED
@@ -1,4 +1,3 @@
-import base64
 import copy
 import dataclasses
 import multiprocessing
@@ -7,6 +6,7 @@ import threading
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union
 
+import pybase64
 import requests
 import torch
 import torch.distributed as dist
sglang/srt/layers/communicator.py CHANGED
@@ -402,12 +402,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
             if hidden_states.shape[0] != 0:
                 hidden_states = layernorm(hidden_states)
         else:
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
             if (
                 _is_sm100_supported
                 and _is_flashinfer_available
                 and hasattr(layernorm, "forward_with_allreduce_fusion")
                 and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-                and hidden_states.shape[0] <= 1024
+                and hidden_states.shape[0] <= 128
             ):
                 hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                     hidden_states, residual
sglang/srt/layers/flashinfer_comm_fusion.py CHANGED
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()
 
 
 def ensure_workspace_initialized(
-    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -119,12 +119,12 @@ def ensure_workspace_initialized(
     return _workspace_manager.initialized
 
 
-def flashinfer_allreduce_add_rmsnorm(
+def flashinfer_allreduce_residual_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
-    max_token_num: int = 1024,
+    max_token_num: int = 128,
    use_oneshot: bool = True,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
sglang/srt/layers/layernorm.py CHANGED
@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
         if residual is not None:
             from sglang.srt.distributed import get_tensor_model_parallel_world_size
             from sglang.srt.layers.flashinfer_comm_fusion import (
-                flashinfer_allreduce_add_rmsnorm,
+                flashinfer_allreduce_residual_rmsnorm,
             )
 
             if get_tensor_model_parallel_world_size() > 1:
-                fused_result = flashinfer_allreduce_add_rmsnorm(
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
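
The renamed helper fuses the tensor-parallel all-reduce, the residual add, and the RMSNorm into one kernel (and, per the communicator change above, is now only taken for batches of at most 128 tokens). A plain-PyTorch sketch of the assumed unfused semantics, for reference only (requires an initialized process group):

    import torch
    import torch.distributed as dist

    def allreduce_residual_rmsnorm_ref(x, residual, weight, eps=1e-6):
        # 1) sum the partial hidden states across tensor-parallel ranks
        dist.all_reduce(x)
        # 2) add the residual stream; this becomes the new residual
        residual = x + residual
        # 3) RMSNorm over the last dimension
        variance = residual.float().pow(2).mean(-1, keepdim=True)
        out = residual.float() * torch.rsqrt(variance + eps) * weight.float()
        return out.to(x.dtype), residual
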
sglang/srt/layers/moe/cutlass_w4a8_moe.py ADDED
@@ -0,0 +1,215 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Cutlass W4A8 MoE kernel."""
+from typing import Optional
+
+import torch
+from sgl_kernel import (
+    cutlass_w4a8_moe_mm,
+    get_cutlass_w4a8_moe_mm_data,
+    sgl_per_tensor_quant_fp8,
+    silu_and_mul,
+)
+
+from sglang.srt.layers.moe.ep_moe.kernels import (
+    post_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
+)
+
+
+def cutlass_w4a8_moe(
+    start_expert_id: int,
+    end_expert_id: int,
+    total_num_experts: int,
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    local_topk_ids: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [M, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert (
+        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
+        and w1_scale.shape[2] == w1_q.shape[1] * 4
+    ), "W1 scale shape mismatch"
+    assert (
+        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
+        and w2_scale.shape[2] == w2_q.shape[1] * 4
+    ), "W2 scale shape mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    if apply_router_weight_on_input:
+        assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
+
+    device = a.device
+
+    _, src2dst, _ = run_cutlass_moe_ep_preproess(
+        local_topk_ids,
+        num_experts,
+    )
+
+    gateup_input = torch.empty(
+        (m * topk, k),
+        device=device,
+        dtype=torch.float8_e4m3fn,
+    )
+
+    pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
+        a,
+        gateup_input,
+        src2dst,
+        local_topk_ids,
+        a1_scale,
+        total_num_experts,
+        topk,
+        k,
+        BLOCK_SIZE=512,
+    )
+
+    # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
+    # they are kept to allow for a quick switch of the permutation logic
+    # from the current triton kernel implementation to the cutlass-based one if needed.
+    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    get_cutlass_w4a8_moe_mm_data(
+        local_topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a_map,
+        c_map,
+        num_experts,
+        n,
+        k,
+    )
+
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        s_strides13,
+        128,
+        topk,
+    )
+
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    silu_and_mul(c1, intermediate)
+
+    intermediate_q = torch.empty(
+        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
+    )
+    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
+
+    cutlass_w4a8_moe_mm(
+        c2,
+        intermediate_q,
+        w2_q,
+        a2_scale.float(),
+        w2_scale,
+        expert_offsets[:-1],
+        problem_sizes2,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides2,
+        128,
+        topk,
+    )
+
+    output = torch.empty_like(a)
+    post_reorder_triton_kernel[(m,)](
+        c2,
+        output,
+        src2dst,
+        topk_ids_,
+        topk_weights,
+        start_expert_id,
+        end_expert_id,
+        topk,
+        k,
+        0,
+        BLOCK_SIZE=512,
+    )
+    return output
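
A quick shape check for the packed-int4 layout described in the docstring above; the sizes below are hypothetical and only illustrate how k and n are recovered from the packed weight tensors:

    import torch

    E, K, N = 8, 7168, 2048                                    # hypothetical expert count / dims
    w1_q = torch.empty(E, 2 * N, K // 2, dtype=torch.int8)     # [num_experts, N * 2, K // 2]
    w2_q = torch.empty(E, K, N // 2, dtype=torch.int8)         # [num_experts, K, N // 2]

    k = w1_q.size(2) * 2   # unpacks to K, as in the kernel
    n = w2_q.size(2) * 2   # unpacks to N, as in the kernel
    assert (k, n) == (K, N)

    # corresponding fp32 scale shapes from the docstring:
    w1_scale = torch.empty(E, K // 512, N * 8)
    w2_scale = torch.empty(E, N // 512, K * 4)
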
sglang/srt/layers/moe/ep_moe/kernels.py CHANGED
@@ -146,6 +146,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
 
 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
 
@@ -158,9 +159,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     compute_src2dst_triton_kernel[grid](
         reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
     )
+
     return reorder_topk_ids, src2dst, seg_indptr
 
 
+def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
+
+    seg_indptr = torch.zeros(
+        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
+    )
+    src2dst = torch.empty(
+        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
+    )
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
+    compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
+    )
+
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
+@triton.jit
+def pre_reorder_triton_kernel_for_cutlass_moe(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id != num_experts:
+            if a1_scales_ptr is not None:
+                scale = 1.0 / tl.load(a1_scales_ptr)
+            else:
+                scale = 1.0
+
+            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + tl.arange(0, BLOCK_SIZE)
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+                out_data = (in_data * scale).to(OutDtype)
+                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+
 @triton.jit
 def pre_reorder_triton_kernel(
     input_ptr,
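
For intuition, the src2dst mapping consumed by the pre-reorder kernel can be sketched in plain PyTorch (assumed semantics, inferred from the kernels above: sorting the flattened top-k expert ids groups slots by expert, and src2dst[i] is the destination row of flattened slot i in that expert-sorted buffer):

    import torch

    topk_ids = torch.tensor([[2, 0], [1, 2], [0, 1]])   # [num_tokens, topk], illustrative routing
    flat = topk_ids.view(-1)
    _, reorder_ids = torch.sort(flat, stable=True)       # expert-sorted order of flattened slots
    src2dst = torch.empty_like(reorder_ids)
    src2dst[reorder_ids] = torch.arange(flat.numel())     # inverse permutation: slot -> destination row
    print(src2dst.view(3, 2))
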