sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/ep_moe/layer.py

@@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple
 
 import einops
 import torch
-from sgl_kernel import silu_and_mul
 from torch.nn import Module
 
 from sglang.srt.custom_op import CustomOp
@@ -11,6 +10,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -40,24 +41,34 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     sglang_per_token_quant_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
+    is_npu,
     set_weight_attrs,
 )
 
 _is_hip = is_hip()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if not _is_npu:
+    from sgl_kernel import silu_and_mul
 
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
+if _use_aiter:
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)
 
 
@@ -1046,6 +1057,15 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 w2_weight_scale, requires_grad=False
             )
             layer.w2_input_scale = None
+            if _use_aiter:
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
             return
 
     def apply(
@@ -1117,18 +1137,36 @@ class DeepEPMoE(EPMoE):
         assert (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
         ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
-
-        self.
-        (
-
-
-
-
-
-
-
-
-
+        if _use_aiter:
+            # expert_mask is of size (self.num_experts_per_partition + 1),
+            # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
+            # for instance, if we have 4 experts on this rank, we would have a expert_mask like:
+            # self.expert_mask = [1, 1, 1, 1, 0]
+            # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
+            self.expert_mask = torch.zeros(
+                (self.num_experts_per_partition + 1),
+                device=torch.cuda.current_device(),
+                dtype=torch.int,
+            )
+            # the last one is invalid rank_id
+            self.expert_mask[:-1] = 1
+        else:
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                (
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                (
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+            )
 
     def forward(
         self,
@@ -1140,9 +1178,14 @@ class DeepEPMoE(EPMoE):
         masked_m: torch.Tensor,
         expected_m: int,
         num_recv_tokens_per_expert: List[int],
-
+        forward_batch: ForwardBatch,
     ):
-
+        if _use_aiter:
+            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
+            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
@@ -1274,6 +1317,37 @@ class DeepEPMoE(EPMoE):
         )
         return down_output
 
+    def forward_aiter(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states
+        # in original deepep, idx == -1 meaning invalid and will not be processed.
+        # aiter does not accept -1, we use a expert mask to make these idx invalid
+        # (idx == num_experts_per_partition) meaning not used in aiter fused_moe
+        topk_idx_copy = topk_idx.to(torch.int32)
+        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+
+        return fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            topk_weights,
+            topk_idx_copy,
+            w1_scale=self.w13_weight_scale_inv,
+            w2_scale=self.w2_weight_scale_inv,
+            quant_type=QuantType.per_128x128,
+            activation=(
+                ActivationType.Silu
+                if self.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            expert_mask=self.expert_mask,
+        )
+
     def forward_deepgemm_contiguous(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
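
The expert-mask handling above can be read in isolation: DeepEP marks routing slots that are not handled on the local rank with -1, aiter's fused_moe does not accept -1, so those entries are redirected to one extra expert slot that the mask zeroes out. A minimal standalone sketch of that remapping (illustrative only, with made-up routing indices):

    import torch

    num_experts_per_partition = 4

    # One extra slot: indices 0-3 are real local experts, index 4 is the masked slot.
    expert_mask = torch.zeros(num_experts_per_partition + 1, dtype=torch.int)
    expert_mask[:-1] = 1  # tensor([1, 1, 1, 1, 0])

    # -1 means "not routed to this rank"; remap it to the masked slot.
    topk_idx = torch.tensor([[0, 2], [-1, 3], [1, -1]])
    topk_idx_copy = topk_idx.to(torch.int32)
    topk_idx_copy[topk_idx_copy == -1] = num_experts_per_partition
    print(topk_idx_copy.tolist())  # [[0, 2], [4, 3], [1, 4]]
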
@@ -1303,10 +1377,19 @@ class DeepEPMoE(EPMoE):
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-
-
-
-
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
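
One way to read the two scale buffers allocated above: with 128-wide per-token-group FP8 quantization, each of the all_tokens rows carries K // 128 scales. When DEEPGEMM_SCALE_UE8M0 is enabled the scales appear to be stored as packed 8-bit exponents, four to an int32, which is where ceil_div(K // 128, 4) comes from; otherwise they are plain float32. A small shape check with assumed sizes (not taken from the package):

    import torch

    def ceil_div(a: int, b: int) -> int:
        return -(-a // b)

    all_tokens, K = 1024, 7168  # assumed example sizes
    num_groups = K // 128       # 56 scales per token
    packed = torch.zeros(
        (ceil_div(num_groups, 4), all_tokens), dtype=torch.int
    ).transpose(0, 1)
    plain = torch.empty((all_tokens, num_groups), dtype=torch.float32)
    print(packed.shape, plain.shape)  # torch.Size([1024, 14]) torch.Size([1024, 56])
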
@@ -1332,6 +1415,7 @@ class DeepEPMoE(EPMoE):
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)
 
@@ -1340,7 +1424,8 @@ class DeepEPMoE(EPMoE):
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
@@ -1361,10 +1446,15 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input,
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,

sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -1,12 +1,16 @@
 import logging
 from dataclasses import dataclass
 
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    DeepEPMode,
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    load_json_config,
+)
 
 try:
     from deep_ep import Buffer, Config
@@ -30,7 +34,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
 )
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -238,7 +244,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
@@ -376,6 +388,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
         https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
         """
+        if _use_aiter:
+            # skip permutation here as aiter fused_moe has fused inside
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
 
         reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
             topk_idx, self.num_experts
@@ -409,7 +430,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
@@ -665,21 +686,21 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(
+        inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._dispatch_intermediate_state =
+        self._dispatch_intermediate_state = forward_batch, inner_state
 
     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-
+        forward_batch, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(
+        return self._get_impl(forward_batch).dispatch_b(*inner_state)
 
     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -691,24 +712,26 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(
+        inner_state = self._get_impl(forward_batch).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._combine_intermediate_state =
+        self._combine_intermediate_state = forward_batch, inner_state
 
     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-
+        forward_batch, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(
+        return self._get_impl(forward_batch).combine_b(*inner_state)
 
-    def _get_impl(self,
-        resolved_deepep_mode = self.deepep_mode.resolve(
+    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.low_latency:
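
The dispatcher change above keeps the two-phase shape of the API: each *_a call picks an implementation from the forward batch, starts the work, and stashes (forward_batch, inner_state); the matching *_b call pops that state and finishes on the same implementation. A stripped-down sketch of that control flow with hypothetical class and attribute names (not the package's types):

    class TwoPhaseDispatcher:
        # Illustrative only: mirrors the dispatch_a / dispatch_b split above.

        def __init__(self, normal_impl, low_latency_impl):
            self._normal = normal_impl
            self._low_latency = low_latency_impl
            self._pending = None

        def _get_impl(self, forward_batch):
            # The real code resolves a DeepEPMode from forward_batch.is_extend_in_batch.
            return self._normal if forward_batch.is_extend_in_batch else self._low_latency

        def dispatch_a(self, hidden_states, topk_idx, topk_weights, forward_batch):
            inner_state = self._get_impl(forward_batch).dispatch_a(
                hidden_states=hidden_states, topk_idx=topk_idx, topk_weights=topk_weights
            )
            self._pending = (forward_batch, inner_state)  # stashed for the B phase

        def dispatch_b(self):
            forward_batch, inner_state = self._pending
            self._pending = None
            return self._get_impl(forward_batch).dispatch_b(*inner_state)
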

sglang/srt/layers/moe/fused_moe_native.py

@@ -77,8 +77,15 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    inplace: bool = True,
+    no_combine: bool = False,
     routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
+
+    if apply_router_weight_on_input:
+        raise NotImplementedError()
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,

sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -12,7 +12,6 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    ceil_div,
     cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
@@ -32,7 +32,6 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
-    log_info_on_rank0,
     next_power_of_2,
 )
 
@@ -750,9 +749,11 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids
-        max_num_tokens_padded,
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
+    sorted_ids.fill_(topk_ids.numel())
+
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
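
The buffer sizing above follows from the alignment requirement: each expert's token list is padded up to a multiple of block_size, and rounding up adds at most block_size - 1 slots per expert, so sorted_ids holds topk_ids.numel() + num_experts * (block_size - 1) entries, pre-filled with topk_ids.numel() as an out-of-range sentinel that marks padding slots. A tiny worked example with assumed numbers (not from the package):

    import torch

    topk_ids = torch.tensor([[0, 2], [0, 1], [2, 2]])  # 3 tokens, top-2 routing
    num_experts, block_size = 4, 4

    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32)
    sorted_ids.fill_(topk_ids.numel())  # sentinel value 6 marks padding slots

    print(topk_ids.numel(), max_num_tokens_padded)  # 6 18
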
@@ -768,6 +769,9 @@
         num_tokens_post_pad,
     )
     else:
+        cumsum_buffer = torch.empty(
+            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+        )
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,

sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -12,13 +12,22 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.
+from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_hip,
+    set_weight_attrs,
+    use_intel_amx_backend,
+)
 
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,6 +37,8 @@ else:
     import logging
 
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
@@ -117,6 +128,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             requires_grad=False,
         )
         torch.cuda.empty_cache()
+
+        # Pack weight for get better performance on CPU
+        if _is_cpu and _is_cpu_amx_available:
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
         return
 
     def apply(
@@ -247,6 +263,81 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        assert activation == "silu", f"activation = {activation} is not supported."
+
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+
+            # TODO: support apply_router_weight_on_input in the fused_experts_cpu kernel
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights.to(
+                    torch.float
+                ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
+                topk_ids,
+                False,  # inplace # See [Note] inplace should be False in fused_experts.
+                False,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                None,  # w1_scale
+                None,  # w2_scale
+                None,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+        else:
+            return moe_forward_native(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                num_fused_shared_experts,
+                custom_routing_function,
+                correction_bias,
+                activation,
+                apply_router_weight_on_input,
+                inplace,
+                no_combine,
+                routed_scaling_factor,
+            )
+
+    def forward_npu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return moe_forward_native(
             layer,
@@ -260,6 +351,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_fused_shared_experts,
             custom_routing_function,
             correction_bias,
+            activation,
+            apply_router_weight_on_input,
+            inplace,
+            no_combine,
+            routed_scaling_factor,
         )
 
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
@@ -478,11 +574,6 @@ class FusedMoE(torch.nn.Module):
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
 
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
-            )
-
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
@@ -493,7 +584,24 @@ class FusedMoE(torch.nn.Module):
             start = shard_size
         else:
             start = 0
-
+
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                start,
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
+            )
+        else:
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
+
+            expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
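
In the non-CPU branch above, sharding is plain Tensor.narrow along shard_dim: rank tp_rank reads the slice of length shard_size starting at shard_size * tp_rank. A toy illustration of that indexing (standalone, not the loader itself):

    import torch

    full_weight = torch.arange(24).reshape(8, 3)  # pretend checkpoint tensor
    tp_size, shard_dim = 4, 0
    shard_size = full_weight.shape[shard_dim] // tp_size  # 2 rows per rank

    for tp_rank in range(tp_size):
        shard = full_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
        print(tp_rank, shard.shape)  # each rank sees its own 2 x 3 slice
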
@@ -510,10 +618,21 @@ class FusedMoE(torch.nn.Module):
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
 
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
             )
+        else:
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
 
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)