sglang 0.4.9__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +12 -1
- sglang/srt/conversation.py +35 -1
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/layers/communicator.py +3 -1
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +58 -0
- sglang/srt/layers/moe/ep_moe/layer.py +140 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +135 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +8 -1
- sglang/srt/managers/mm_utils.py +4 -2
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +17 -5
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +113 -63
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/models/deepseek_v2.py +16 -2
- sglang/srt/models/mllama4.py +360 -79
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +62 -60
- sglang/srt/server_args.py +15 -0
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +37 -17
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +4 -3
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +47 -43
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -814,9 +814,9 @@ def sample_mmmu_requests(
         List of tuples (prompt, prompt_token_len, output_token_len).
     """
     try:
-        import base64
         import io

+        import pybase64
         from datasets import load_dataset
     except ImportError:
         raise ImportError("Please install datasets: pip install datasets")
@@ -867,7 +867,7 @@ def sample_mmmu_requests(
                 # Encode image to base64
                 buffered = io.BytesIO()
                 image.save(buffered, format="JPEG")
-                img_str =
+                img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
                 image_data = f"data:image/jpeg;base64,{img_str}"
             else:
                 continue
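Note: pybase64 is used here as a drop-in, faster replacement for the standard-library base64 codec, so only the import and the b64encode call change. A minimal standalone sketch of the updated encoding path with a hypothetical in-memory image (PIL assumed available, since the benchmark already handles PIL images from the MMMU dataset):

import io

import pybase64
from PIL import Image

# Encode a small JPEG as a data URL, mirroring the updated benchmark code.
image = Image.new("RGB", (8, 8), color=(255, 0, 0))
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
image_data = f"data:image/jpeg;base64,{img_str}"
print(image_data[:40])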
sglang/srt/configs/model_config.py
CHANGED
@@ -359,7 +359,17 @@ class ModelConfig:
        if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
            quant_cfg = modelopt_quant_config
        elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
-
+            quant_config_file = os.path.join(
+                self.model_path, "hf_quant_config.json"
+            )
+            with open(quant_config_file) as f:
+                quant_config_dict = json.load(f)
+            json_quant_configs = quant_config_dict["quantization"]
+            quant_algo = json_quant_configs.get("quant_algo", None)
+            if quant_algo == "MIXED_PRECISION":
+                quant_cfg = {"quant_method": "w4afp8"}
+            else:
+                quant_cfg = modelopt_quant_config
        return quant_cfg

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
@@ -389,6 +399,7 @@ class ModelConfig:
            "w8a8_fp8",
            "moe_wna16",
            "qoq",
+            "w4afp8",
        ]
        compatible_quantization_methods = {
            "modelopt_fp4": ["modelopt"],
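Note: the new branch keys on the "quantization" -> "quant_algo" field of a local hf_quant_config.json and maps "MIXED_PRECISION" to the new w4afp8 method, falling back to the ModelOpt config otherwise. A hedged sketch of that detection logic (the path and the fallback value are placeholders, not the real ModelOpt config object):

import json
import os


def detect_quant_method(model_path: str):
    cfg_file = os.path.join(model_path, "hf_quant_config.json")
    if not os.path.exists(cfg_file):
        return None
    with open(cfg_file) as f:
        quant_algo = json.load(f)["quantization"].get("quant_algo", None)
    if quant_algo == "MIXED_PRECISION":
        return {"quant_method": "w4afp8"}
    return "modelopt"  # placeholder for the ModelOpt quant config used in the real code


# A hypothetical hf_quant_config.json that would select w4afp8:
# {"quantization": {"quant_algo": "MIXED_PRECISION"}}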
sglang/srt/conversation.py
CHANGED
@@ -921,6 +921,19 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="mimo-vl",
+        system_message="You are MiMo, an AI assistant developed by Xiaomi.",
+        system_template="<|im_start|>system\n{system_message}",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=["<|im_end|>"],
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+    )
+)
+

 register_conv_template(
     Conversation(
@@ -935,6 +948,19 @@ register_conv_template(
     )
 )

+register_conv_template(
+    Conversation(
+        name="llama_4_vision",
+        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
+        system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA4,
+        sep="",
+        stop_str="<|eot|>",
+        image_token="<|image|>",
+    )
+)
+

 @register_conv_template_matching_function
 def match_internvl(model_path: str):
@@ -943,9 +969,11 @@ def match_internvl(model_path: str):


 @register_conv_template_matching_function
-def
+def match_llama_vision(model_path: str):
     if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
         return "llama_3_vision"
+    if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
+        return "llama_4_vision"


 @register_conv_template_matching_function
@@ -1034,3 +1062,9 @@ def match_phi_4_mm(model_path: str):
 def match_vila(model_path: str):
     if re.search(r"vila", model_path, re.IGNORECASE):
         return "chatml"
+
+
+@register_conv_template_matching_function
+def match_mimo_vl(model_path: str):
+    if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
+        return "mimo-vl"
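Note: the matchers above map model paths to template names with case-insensitive regexes. A simplified, self-contained sketch of that matching behavior (model paths are hypothetical; the real server dispatches through its registered matching functions):

import re


def pick_template(model_path: str):
    if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
        return "llama_3_vision"
    if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
        return "llama_4_vision"
    if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
        return "mimo-vl"
    return None


print(pick_template("org/Llama-4-Instruct"))  # llama_4_vision
print(pick_template("org/MiMo-VL-7B"))        # mimo-vl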
sglang/srt/disaggregation/mooncake/conn.py
CHANGED
@@ -185,9 +185,11 @@ class MooncakeKVManager(BaseKVManager):
             threading.Thread(
                 target=self.transfer_worker, args=(queue, executor), daemon=True
             ).start()
-
-
-
+            # If a timeout happens on the prefill side, it means prefill instances
+            # fail to receive the KV indices from the decode instance of this request.
+            # These timeout requests should be aborted to release the tree cache.
+            self.bootstrap_timeout = get_int_env_var(
+                "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.heartbeat_failures = {}
@@ -209,6 +211,12 @@ class MooncakeKVManager(BaseKVManager):
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
             self.prefill_tp_size_table: Dict[str, int] = {}
             self.prefill_dp_size_table: Dict[str, int] = {}
+            # If a timeout happens on the decode side, it means decode instances
+            # fail to receive the KV Cache transfer done signal after bootstrapping.
+            # These timeout requests should be aborted to release the tree cache.
+            self.waiting_timeout = get_int_env_var(
+                "SGLANG_DISAGGREGATION_WAITING_TIMEOUT", 300
+            )
         else:
             raise ValueError(
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
@@ -938,7 +946,12 @@ class MooncakeKVSender(BaseKVSender):
             if self.init_time is not None:
                 now = time.time()
                 elapsed = now - self.init_time
-                if elapsed >= self.kv_mgr.
+                if elapsed >= self.kv_mgr.bootstrap_timeout:
+                    logger.warning_once(
+                        "Some requests timed out when bootstrapping, "
+                        "which means prefill instances fail to receive the KV indices from the decode instance of this request. "
+                        "If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
+                    )
                     self.kv_mgr.record_failure(
                         self.bootstrap_room,
                         f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.Bootstrapping",
@@ -987,6 +1000,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.session_id = self.kv_mgr.get_session_id()
         self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
         self.conclude_state = None
+        self.init_time = None
         self.data_parallel_rank = data_parallel_rank

         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
@@ -1222,14 +1236,31 @@
                     str(self.required_dst_info_num).encode("ascii"),
                 ]
             )
+        self.init_time = time.time()

     def poll(self) -> KVPoll:
         if self.conclude_state is None:
             status = self.kv_mgr.check_status(self.bootstrap_room)
             if status in (KVPoll.Success, KVPoll.Failed):
                 self.conclude_state = status
+            elif status == KVPoll.WaitingForInput:
+                if self.init_time is not None:
+                    now = time.time()
+                    elapsed = now - self.init_time
+                    if elapsed >= self.kv_mgr.waiting_timeout:
+                        logger.warning_once(
+                            "Some requests fail to receive KV Cache transfer done signal after bootstrapping. "
+                            "If a greater mean TTFT is acceptable, you can 'export SGLANG_DISAGGREGATION_WAITING_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
+                        )
+                        self.kv_mgr.record_failure(
+                            self.bootstrap_room,
+                            f"Request {self.bootstrap_room} timed out after {elapsed:.1f}s in KVPoll.WaitingForInput",
+                        )
+                        self.conclude_state = KVPoll.Failed
+                        return KVPoll.Failed

             return status
+
         else:
             return self.conclude_state
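Note: both timeouts above default to 300 seconds and are read through get_int_env_var. A minimal sketch of overriding them before launching the prefill and decode instances; the 600-second value only mirrors the hint in the warning messages and is not a recommendation:

import os

# Prefill side: how long to wait for KV indices from the decode instance while bootstrapping.
os.environ["SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT"] = "600"
# Decode side: how long to wait for the KV cache transfer-done signal after bootstrapping.
os.environ["SGLANG_DISAGGREGATION_WAITING_TIMEOUT"] = "600"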
sglang/srt/entrypoints/http_server_engine.py
CHANGED
@@ -1,4 +1,3 @@
-import base64
 import copy
 import dataclasses
 import multiprocessing
@@ -7,6 +6,7 @@ import threading
 import time
 from typing import Any, Dict, List, Optional, Tuple, Union

+import pybase64
 import requests
 import torch
 import torch.distributed as dist
sglang/srt/layers/communicator.py
CHANGED
@@ -402,12 +402,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if hidden_states.shape[0] != 0:
             hidden_states = layernorm(hidden_states)
         else:
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
             if (
                 _is_sm100_supported
                 and _is_flashinfer_available
                 and hasattr(layernorm, "forward_with_allreduce_fusion")
                 and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-                and hidden_states.shape[0] <=
+                and hidden_states.shape[0] <= 128
             ):
                 hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                     hidden_states, residual
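Note: the constant 128 above matches the new max_token_num default in flashinfer_comm_fusion.py (next section). A small, hedged sketch of the gating logic only, with the hardware and availability checks reduced to plain booleans (names are illustrative, not the real flags):

def can_use_flashinfer_allreduce_fusion(
    num_tokens: int,
    is_sm100: bool,
    flashinfer_available: bool,
    fusion_enabled: bool,
    max_token_num: int = 128,
) -> bool:
    # Mirrors the guard above: the one-shot fused allreduce + RMSNorm path is only
    # taken for small token counts on supported GPUs when the fusion flag is enabled.
    return is_sm100 and flashinfer_available and fusion_enabled and num_tokens <= max_token_num


print(can_use_flashinfer_allreduce_fusion(64, True, True, True))   # True
print(can_use_flashinfer_allreduce_fusion(256, True, True, True))  # False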
sglang/srt/layers/flashinfer_comm_fusion.py
CHANGED
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


 def ensure_workspace_initialized(
-    max_token_num: int =
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -119,12 +119,12 @@ ensure_workspace_initialized(
     return _workspace_manager.initialized


-def
+def flashinfer_allreduce_residual_rmsnorm(
     input_tensor: torch.Tensor,
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int =
+    max_token_num: int = 128,
     use_oneshot: bool = True,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
sglang/srt/layers/layernorm.py
CHANGED
@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
         if residual is not None:
             from sglang.srt.distributed import get_tensor_model_parallel_world_size
             from sglang.srt.layers.flashinfer_comm_fusion import (
-
+                flashinfer_allreduce_residual_rmsnorm,
             )

             if get_tensor_model_parallel_world_size() > 1:
-                fused_result =
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
sglang/srt/layers/moe/cutlass_w4a8_moe.py
ADDED
@@ -0,0 +1,215 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Cutlass W4A8 MoE kernel."""
+from typing import Optional
+
+import torch
+from sgl_kernel import (
+    cutlass_w4a8_moe_mm,
+    get_cutlass_w4a8_moe_mm_data,
+    sgl_per_tensor_quant_fp8,
+    silu_and_mul,
+)
+
+from sglang.srt.layers.moe.ep_moe.kernels import (
+    post_reorder_triton_kernel,
+    pre_reorder_triton_kernel_for_cutlass_moe,
+    run_cutlass_moe_ep_preproess,
+)
+
+
+def cutlass_w4a8_moe(
+    start_expert_id: int,
+    end_expert_id: int,
+    total_num_experts: int,
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    local_topk_ids: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    apply_router_weight_on_input: bool = False,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [M, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+    assert (
+        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
+        and w1_scale.shape[2] == w1_q.shape[1] * 4
+    ), "W1 scale shape mismatch"
+    assert (
+        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
+        and w2_scale.shape[2] == w2_q.shape[1] * 4
+    ), "W2 scale shape mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(0)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    if apply_router_weight_on_input:
+        assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
+
+    device = a.device
+
+    _, src2dst, _ = run_cutlass_moe_ep_preproess(
+        local_topk_ids,
+        num_experts,
+    )
+
+    gateup_input = torch.empty(
+        (m * topk, k),
+        device=device,
+        dtype=torch.float8_e4m3fn,
+    )
+
+    pre_reorder_triton_kernel_for_cutlass_moe[(m,)](
+        a,
+        gateup_input,
+        src2dst,
+        local_topk_ids,
+        a1_scale,
+        total_num_experts,
+        topk,
+        k,
+        BLOCK_SIZE=512,
+    )
+
+    # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
+    # they are kept to allow for a quick switch of the permutation logic
+    # from the current triton kernel implementation to the cutlass-based one if needed.
+    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    get_cutlass_w4a8_moe_mm_data(
+        local_topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        a_map,
+        c_map,
+        num_experts,
+        n,
+        k,
+    )
+
+    c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
+    c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        s_strides13,
+        128,
+        topk,
+    )
+
+    intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
+    silu_and_mul(c1, intermediate)
+
+    intermediate_q = torch.empty(
+        intermediate.shape, dtype=torch.float8_e4m3fn, device=device
+    )
+    sgl_per_tensor_quant_fp8(intermediate, intermediate_q, a2_scale.float(), True)
+
+    cutlass_w4a8_moe_mm(
+        c2,
+        intermediate_q,
+        w2_q,
+        a2_scale.float(),
+        w2_scale,
+        expert_offsets[:-1],
+        problem_sizes2,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides2,
+        128,
+        topk,
+    )
+
+    output = torch.empty_like(a)
+    post_reorder_triton_kernel[(m,)](
+        c2,
+        output,
+        src2dst,
+        topk_ids_,
+        topk_weights,
+        start_expert_id,
+        end_expert_id,
+        topk,
+        k,
+        0,
+        BLOCK_SIZE=512,
+    )
+    return output
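Note: the shape relationships asserted in cutlass_w4a8_moe are the easiest thing to get wrong when wiring this kernel up. A small host-side sketch that builds int4-packed weights and group scales consistent with those asserts, using hypothetical sizes (no GPU kernel is launched here):

import torch

num_experts, M, K, N = 8, 4, 7168, 2048  # illustrative sizes only

a = torch.randn(M, K)                                             # activations [M, K]
w1_q = torch.empty(num_experts, N * 2, K // 2, dtype=torch.int8)  # int4-packed, transposed
w2_q = torch.empty(num_experts, K, N // 2, dtype=torch.int8)      # int4-packed, transposed
w1_scale = torch.empty(num_experts, K // 512, N * 8)              # group scales for w1_q
w2_scale = torch.empty(num_experts, N // 512, K * 4)              # group scales for w2_q

# Same checks as in cutlass_w4a8_moe above.
assert a.shape[1] // 2 == w1_q.shape[2]
assert w1_q.shape[2] * 2 == w2_q.shape[1]
assert w1_scale.shape[1] == w1_q.shape[2] * 2 / 512 and w1_scale.shape[2] == w1_q.shape[1] * 4
assert w2_scale.shape[1] == w2_q.shape[2] * 2 / 512 and w2_scale.shape[2] == w2_q.shape[1] * 4
k = w1_q.size(2) * 2  # hidden size K, recovered from the packed weight
n = w2_q.size(2) * 2  # intermediate size N, recovered from the packed weight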
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -146,6 +146,7 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):

 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)

@@ -158,9 +159,66 @@ def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     compute_src2dst_triton_kernel[grid](
         reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
     )
+
     return reorder_topk_ids, src2dst, seg_indptr


+def run_cutlass_moe_ep_preproess(local_topk_ids: torch.Tensor, local_num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(local_topk_ids.view(-1), stable=True)
+
+    seg_indptr = torch.zeros(
+        local_num_experts + 1, device=local_topk_ids.device, dtype=torch.int64
+    )
+    src2dst = torch.empty(
+        local_topk_ids.numel(), device=local_topk_ids.device, dtype=torch.int32
+    )
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(local_topk_ids.numel(), BLOCK_SIZE),)
+    compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, local_topk_ids.numel(), BLOCK_SIZE
+    )
+
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
+@triton.jit
+def pre_reorder_triton_kernel_for_cutlass_moe(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    num_experts,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id != num_experts:
+            if a1_scales_ptr is not None:
+                scale = 1.0 / tl.load(a1_scales_ptr)
+            else:
+                scale = 1.0
+
+            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + tl.arange(0, BLOCK_SIZE)
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+                out_data = (in_data * scale).to(OutDtype)
+                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+
 @triton.jit
 def pre_reorder_triton_kernel(
     input_ptr,