sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +8 -6
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +14 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -28
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/conn.py +301 -64
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +20 -15
- sglang/srt/disaggregation/utils.py +47 -18
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +27 -31
- sglang/srt/entrypoints/http_server.py +149 -79
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +897 -0
- sglang/srt/entrypoints/openai/serving_completions.py +425 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +28 -3
- sglang/srt/layers/attention/aiter_backend.py +5 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
- sglang/srt/layers/attention/flashattention_backend.py +43 -23
- sglang/srt/layers/attention/flashinfer_backend.py +9 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
- sglang/srt/layers/attention/flashmla_backend.py +5 -2
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +19 -11
- sglang/srt/layers/communicator.py +5 -5
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +44 -2
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +14 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
- sglang/srt/layers/moe/ep_moe/layer.py +286 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
- sglang/srt/layers/moe/topk.py +117 -4
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/rotary_embedding.py +144 -12
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/lora/lora_manager.py +173 -74
- sglang/srt/lora/mem_pool.py +49 -45
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -15
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +19 -14
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +49 -32
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +189 -68
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +11 -8
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -16
- sglang/srt/mem_cache/hiradix_cache.py +34 -23
- sglang/srt/mem_cache/memory_pool.py +118 -114
- sglang/srt/mem_cache/radix_cache.py +20 -16
- sglang/srt/model_executor/cuda_graph_runner.py +77 -46
- sglang/srt/model_executor/forward_batch_info.py +18 -5
- sglang/srt/model_executor/model_runner.py +27 -8
- sglang/srt/model_loader/loader.py +50 -8
- sglang/srt/model_loader/weight_utils.py +100 -2
- sglang/srt/models/deepseek_nextn.py +35 -30
- sglang/srt/models/deepseek_v2.py +255 -30
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/server_args.py +51 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
- sglang/srt/speculative/eagle_utils.py +80 -8
- sglang/srt/speculative/eagle_worker.py +124 -41
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +248 -11
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -2148
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
@@ -478,11 +478,13 @@ def post_reorder_triton_kernel(
     end_expert_id,
     topk,
     hidden_size,
+    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty

-    [removed line; content not captured in this diff rendering]
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     topk_weights_ptr = topk_weights_ptr + src_idx * topk
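Note on the hunk above: the program id is widened to int64 before being multiplied by per-row strides. The following is a small illustrative PyTorch sketch of the 32-bit overflow this avoids; it is not part of the package, and the values are arbitrary.

import torch

row = torch.tensor(1_200_000, dtype=torch.int32)     # token / program index
hidden = torch.tensor(4096, dtype=torch.int32)       # per-row stride

wrapped = row * hidden                  # int32 multiply wraps silently to a wrong offset
correct = row.to(torch.int64) * hidden  # 4_915_200_000, as the kernel now computes
print(wrapped.item(), correct.item())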
@@ -501,7 +503,9 @@ def post_reorder_triton_kernel(
         expert_id = tl.load(topk_ids_ptr + idx)
         if expert_id >= start_expert_id and expert_id <= end_expert_id:
             computed = True
-            [removed line; content not captured in this diff rendering]
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - dst_start
             weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
             load_ptr = down_output_ptr + dst_idx * hidden_size
             in_data = tl.load(load_ptr + offset, mask=mask)
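For orientation, below is a rough pure-PyTorch reference of the gather-and-combine the updated post-reorder kernel performs once destination indices are shifted by dst_start. The function and tensor names are illustrative only and assume a flattened (num_local_experts * m_max, hidden) destination layout; this sketch is not part of the package.

import torch

def post_reorder_reference(down_output, src2dst, topk_ids, topk_weights,
                           start_expert_id, end_expert_id, dst_start):
    num_tokens, topk = topk_ids.shape
    hidden = down_output.shape[-1]
    out = torch.zeros(num_tokens, hidden, dtype=down_output.dtype)
    flat = down_output.reshape(-1, hidden)  # rows of the rank-local expert buffer
    for t in range(num_tokens):
        for k in range(topk):
            e = int(topk_ids[t, k])
            if start_expert_id <= e <= end_expert_id:
                dst = int(src2dst[t, k]) - dst_start  # shift into this rank's buffer
                out[t] += topk_weights[t, k].to(down_output.dtype) * flat[dst]
    return out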
@@ -1086,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
         BLOCK_SIZE_K=BLOCK_SIZE_K,
     )
     return output.t()[:m]
+
+
+@triton.jit
+def compute_masked_m_triton_kernel(seg_indptr, masked_m):
+    expert_id = tl.program_id(0)
+    start = tl.load(seg_indptr + expert_id)
+    end = tl.load(seg_indptr + expert_id + 1)
+    tl.store(masked_m + expert_id, (end - start))
+
+
+@triton.jit
+def deepgemm_compute_src2dst_triton_kernel(
+    topk_ids,
+    reorder_ids,
+    seg_indptr,
+    src2dst,
+    m_max,
+    num_toks,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
+    expert_dst_start = tl.load(seg_indptr + expert_id)
+    expert_dst_offset = dst_id - expert_dst_start
+    dst_id = expert_id * m_max + expert_dst_offset
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def fill_gateup_input_triton_kernel(
+    input_ptr,
+    scale_ptr,
+    gateup_input_ptr,
+    gateup_input_scale_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    m_max,
+    hidden_size,
+    scale_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    src_ptr = input_ptr + src_idx * hidden_size
+    scale_src_ptr = scale_ptr + src_idx * scale_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - start_expert_id * m_max
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask)
+                tl.store(dst_ptr + offset, in_data, mask=mask)
+            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
+            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < scale_size
+                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
+                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
+
+
+def moe_ep_deepgemm_preprocess(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    hidden_states: torch.Tensor,
+    top_k: int,
+    start_expert_id,
+    end_expert_id,
+    block_shape,
+    output_dtype: torch.dtype = torch.float8_e4m3fn,
+):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
+    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
+
+    compute_seg_indptr_triton_kernel[(num_experts,)](
+        reorder_topk_ids, seg_indptr, topk_ids.numel()
+    )
+
+    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
+    compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
+
+    # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
+    m_max = (hidden_states.size(0) + 255) // 256 * 256
+    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
+    gateup_input = torch.empty(
+        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
+        device=hidden_states.device,
+        dtype=output_dtype,
+    )
+
+    deepgemm_compute_src2dst_triton_kernel[grid](
+        topk_ids,
+        reorder_ids,
+        seg_indptr,
+        src2dst,
+        m_max,
+        topk_ids.numel(),
+        BLOCK_SIZE=256,
+    )
+
+    if block_shape is None:
+        block_shape = [128, 128]
+    assert len(block_shape) == 2
+    block_n, block_k = block_shape[0], block_shape[1]
+    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
+
+    gateup_input_scale = torch.empty(
+        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
+        device=hidden_states.device,
+        dtype=scale.dtype,
+    )
+
+    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
+        hidden_states,
+        scale,
+        gateup_input,
+        gateup_input_scale,
+        src2dst,
+        topk_ids,
+        start_expert_id,
+        end_expert_id,
+        top_k,
+        m_max,
+        hidden_states.size(1),
+        scale.size(1),
+        BLOCK_SIZE=1024,
+    )
+
+    return (
+        m_max,
+        masked_m[start_expert_id : (end_expert_id + 1)],
+        expected_m,
+        src2dst,
+        gateup_input,
+        gateup_input_scale,
+    )
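The new moe_ep_deepgemm_preprocess mainly builds per-expert bookkeeping for DeepGEMM's masked grouped GEMM. Below is a rough host-side sketch of the quantities it computes (masked_m, m_max, expected_m) using plain PyTorch and toy values; the real code derives them with the Triton kernels above, and this sketch is illustrative only.

import torch

topk_ids = torch.tensor([[0, 2], [1, 2], [2, 3]])  # (num_tokens, top_k), toy routing
num_experts = 4

# masked_m: how many routed tokens land on each expert
masked_m = torch.bincount(topk_ids.view(-1), minlength=num_experts)  # tensor([1, 1, 3, 1])

# m_max: per-expert buffer rows, padded to a multiple of 256 for the masked grouped GEMM
num_tokens = topk_ids.shape[0]
m_max = (num_tokens + 255) // 256 * 256            # 256 for this toy batch

# expected_m: average tokens per expert, used as the GEMM's expected problem size
expected_m = (topk_ids.numel() + num_experts - 1) // num_experts  # 2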
@@ -16,6 +16,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_scatter,
     gelu_and_mul_triton_kernel,
     grouped_gemm_triton,
+    moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
     pre_reorder_triton_kernel,
     run_moe_ep_preproess,
@@ -33,10 +34,12 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     scaled_fp8_quant,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.managers.expert_location import get_global_expert_location_metadata
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -50,10 +53,17 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant

+if _use_aiter:
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)

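The _use_aiter switch above is driven by the SGLANG_USE_AITER environment variable combined with an is_hip() check. A hedged sketch of this kind of boolean env-var gate is shown below; the actual helper is get_bool_env_var in sglang.srt.utils and may differ in detail, so the helper name here is illustrative.

import os

def env_flag(name: str, default: str = "false") -> bool:
    # Treat common truthy spellings as "enabled"
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")

use_aiter = env_flag("SGLANG_USE_AITER")  # only honored on ROCm (is_hip()) in the diff above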
@@ -175,6 +185,7 @@ class EPMoE(torch.nn.Module):
         assert (
             num_fused_shared_experts == 0
         ), "num_fused_shared_experts is not supported in EP"
+        self.num_fused_shared_experts = num_fused_shared_experts
         self.num_experts_per_partition = self.num_experts // self.tp_size
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
@@ -224,13 +235,182 @@ class EPMoE(torch.nn.Module):

         self.grouped_gemm_runner = None

+        self.w13_weight_fp8 = (
+            self.w13_weight,
+            (
+                self.w13_weight_scale_inv
+                if self.use_block_quant
+                else self.w13_weight_scale
+            ),
+        )
+        self.w2_weight_fp8 = (
+            self.w2_weight,
+            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
+        )
+
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
+            return self.forward_deepgemm(hidden_states, router_logits)
+        else:
+            return self.forward_normal(hidden_states, router_logits)
+
+    def forward_deepgemm(
+        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
+        topk_weights, topk_ids = select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            correction_bias=self.correction_bias,
+            custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )

-        [removed line; content not captured in this diff rendering]
+        if not self.use_block_quant:
+            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
+            scale_block_size = 128
+            w13_weight_scale_n = 2 * (
+                (self.intermediate_size + scale_block_size - 1) // scale_block_size
+            )
+            w13_weight_scale_k = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w13_weight_scale = (
+                self.w13_weight_scale.unsqueeze(1)
+                .repeat_interleave(w13_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w13_weight_scale_k, dim=2)
+            )
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                w13_weight_scale,
+            )
+            w2_weight_scale_n = (
+                hidden_states_shape[-1] + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale_k = (
+                self.intermediate_size + scale_block_size - 1
+            ) // scale_block_size
+            w2_weight_scale = (
+                self.w2_weight_scale.unsqueeze(1)
+                .repeat_interleave(w2_weight_scale_n, dim=1)
+                .unsqueeze(2)
+                .repeat_interleave(w2_weight_scale_k, dim=2)
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                w2_weight_scale,
+            )
+
+        # PreReorder
+        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
+            moe_ep_deepgemm_preprocess(
+                topk_ids,
+                self.num_experts,
+                hidden_states,
+                self.top_k,
+                self.start_expert_id,
+                self.end_expert_id,
+                self.block_shape,
+            )
+        )
+
+        dispose_tensor(hidden_states)

+        # GroupGemm-0
+        gateup_input_fp8 = (
+            gateup_input,
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(gateup_input_scale),
+        )
+        num_groups, m, k = gateup_input_fp8[0].size()
+        n = self.w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            gateup_input_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
+        )
+        del gateup_input
+        del gateup_input_fp8
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=self.fp8_dtype,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+        )
+        del gateup_output
+
+        # GroupGemm-1
+        n = self.w2_weight.size(1)
+        down_input_fp8 = (
+            down_input,
+            deep_gemm_wrapper.get_col_major_tma_aligned_tensor(down_input_scale),
+        )
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m
+        )
+        del down_input
+        del down_input_fp8
+
+        # PostReorder
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
+            down_output,
+            output,
+            src2dst,
+            topk_ids,
+            topk_weights,
+            self.start_expert_id,
+            self.end_expert_id,
+            self.top_k,
+            hidden_states_shape[1],
+            m_max * self.start_expert_id,
+            BLOCK_SIZE=512,
+        )
+        return output
+
+    def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
         if self.grouped_gemm_runner is None:
             self.grouped_gemm_runner = GroupedGemmRunner(
                 hidden_states.device,
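The per-tensor-to-per-block scale conversion in forward_deepgemm tiles a single per-expert scale into the (n_blocks, k_blocks) grid that a block-quantized GEMM expects. A minimal illustrative sketch with toy shapes (not part of the package):

import torch

num_experts, n_blocks, k_blocks = 2, 3, 4
per_tensor_scale = torch.tensor([0.5, 2.0])           # one scale per expert

block_scale = (
    per_tensor_scale.unsqueeze(1)
    .repeat_interleave(n_blocks, dim=1)                # (num_experts, n_blocks)
    .unsqueeze(2)
    .repeat_interleave(k_blocks, dim=2)                # (num_experts, n_blocks, k_blocks)
)
assert block_scale.shape == (num_experts, n_blocks, k_blocks)
assert torch.all(block_scale[0] == 0.5) and torch.all(block_scale[1] == 2.0)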
@@ -246,6 +426,7 @@ class EPMoE(torch.nn.Module):
             renormalize=self.renormalize,
             topk_group=self.topk_group,
             num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -437,6 +618,7 @@ class EPMoE(torch.nn.Module):
             self.end_expert_id,
             self.top_k,
             hidden_states_shape[1],
+            0,
             BLOCK_SIZE=512,
         )
         return output
@@ -843,6 +1025,42 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 torch.max(layer.w13_weight_scale, dim=1).values,
                 requires_grad=False,
             )
+        if self.block_quant:
+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if _is_fp8_fnuz:
+                # activation_scheme: dynamic
+                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w13_weight,
+                    weight_scale=layer.w13_weight_scale_inv,
+                    input_scale=None,
+                )
+                w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                    weight=layer.w2_weight,
+                    weight_scale=layer.w2_weight_scale_inv,
+                    input_scale=None,
+                )
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(
+                    w13_weight, requires_grad=False
+                )
+                layer.w13_weight_scale_inv = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False
+                )
+                layer.w13_input_scale = None
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_weight_scale_inv = torch.nn.Parameter(
+                    w2_weight_scale, requires_grad=False
+                )
+                layer.w2_input_scale = None
+            if _use_aiter:
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
         return

     def apply(
@@ -914,18 +1132,36 @@ class DeepEPMoE(EPMoE):
         assert (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
         ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
-        [12 removed lines; content not captured in this diff rendering]
+        if _use_aiter:
+            # expert_mask is of size (self.num_experts_per_partition + 1),
+            # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
+            # for instance, if we have 4 experts on this rank, we would have a expert_mask like:
+            #     self.expert_mask = [1, 1, 1, 1, 0]
+            # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
+            self.expert_mask = torch.zeros(
+                (self.num_experts_per_partition + 1),
+                device=torch.cuda.current_device(),
+                dtype=torch.int,
+            )
+            # the last one is invalid rank_id
+            self.expert_mask[:-1] = 1
+        else:
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                (
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                (
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+            )

     def forward(
         self,
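A minimal sketch of the expert-mask remapping described in the comments above: DeepEP marks "not routed to this rank" slots with -1, while aiter's fused_moe wants valid indices, so -1 is remapped to an extra sentinel expert whose mask bit is 0. Values below are toy examples, not from the package.

import torch

num_local_experts = 4
expert_mask = torch.zeros(num_local_experts + 1, dtype=torch.int)
expert_mask[:-1] = 1                         # tensor([1, 1, 1, 1, 0])

topk_idx = torch.tensor([[0, -1], [3, 2]])   # -1 = token not routed to this rank
topk_idx_copy = topk_idx.to(torch.int32)
topk_idx_copy[topk_idx_copy == -1] = num_local_experts
# topk_idx_copy == [[0, 4], [3, 2]]; slot 4 is masked out by expert_mask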
@@ -939,6 +1175,9 @@ class DeepEPMoE(EPMoE):
         num_recv_tokens_per_expert: List[int],
         forward_mode: ForwardMode,
     ):
+        if _use_aiter:
+            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
+            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
@@ -1071,6 +1310,37 @@ class DeepEPMoE(EPMoE):
         )
         return down_output

+    def forward_aiter(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states
+        # in original deepep, idx == -1 meaning invalid and will not be processed.
+        # aiter does not accept -1, we use a expert mask to make these idx invalid
+        # (idx == num_experts_per_partition) meaning not used in aiter fused_moe
+        topk_idx_copy = topk_idx.to(torch.int32)
+        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+
+        return fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            topk_weights,
+            topk_idx_copy,
+            w1_scale=self.w13_weight_scale_inv,
+            w2_scale=self.w2_weight_scale_inv,
+            quant_type=QuantType.per_128x128,
+            activation=(
+                ActivationType.Silu
+                if self.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            expert_mask=self.expert_mask,
+        )
+
     def forward_deepgemm_contiguous(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
@@ -1265,6 +1535,9 @@ class DeepEPMoE(EPMoE):
 def get_moe_impl_class():
     if global_server_args_dict["enable_deepep_moe"]:
         return DeepEPMoE
+    if global_server_args_dict["enable_flashinfer_moe"]:
+        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
+        return FusedMoE
     if global_server_args_dict["enable_ep_moe"]:
         return EPMoE
     return FusedMoE
@@ -6,7 +6,13 @@ from sglang.srt.managers.expert_distribution import (
     get_global_expert_distribution_recorder,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    DeepEPMode,
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    load_json_config,
+)

 try:
     from deep_ep import Buffer, Config
@@ -32,6 +38,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode

+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+
 logger = logging.getLogger(__name__)

@@ -376,6 +384,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
         https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
         """
+        if _use_aiter:
+            # skip permutation here as aiter fused_moe has fused inside
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states

         reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
             topk_idx, self.num_experts
@@ -409,7 +426,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
@@ -77,8 +77,15 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    inplace: bool = True,
+    no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError()
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,