sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +48 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +34 -0
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +36 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +11 -7
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +50 -13
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +77 -84
- sglang/srt/managers/scheduler.py +113 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +181 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +69 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +200 -27
- sglang/srt/utils.py +306 -146
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
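The most visible structural change in this release is the relocation of the EPLB (expert-parallel load balancing) modules out of `sglang.srt.managers` / `sglang.srt.model_executor` into a dedicated `sglang.srt.eplb` package, and of the multimodal processors into `sglang.srt.multimodal.processors`. The sketch below is not part of the release; it only shows one way downstream code could import across both layouts, using the old and new module paths taken verbatim from the rename entries above (the wrapper function name is this example's own).

```python
# Hedged sketch: tolerate both the 0.4.8.post1 and 0.4.9 module layouts.
# Requires sglang to be installed; paths come from the rename entries above.
try:
    # 0.4.9 layout: EPLB code lives in its own package.
    from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
except ImportError:
    # 0.4.8.post1 layout: the same symbol lived under managers.
    from sglang.srt.managers.expert_location import get_global_expert_location_metadata


def expert_location_metadata():
    """Version-agnostic accessor used by the rest of this example."""
    return get_global_expert_location_metadata()
```

The detailed hunks below cover only the MoE-related files; the remaining files in the list are summarized by their line counts.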
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
```diff
@@ -3,7 +3,6 @@ from typing import Callable, List, Optional, Tuple
 
 import einops
 import torch
-from sgl_kernel import silu_and_mul
 from torch.nn import Module
 
 from sglang.srt.custom_op import CustomOp
@@ -11,6 +10,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -40,22 +41,26 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     sglang_per_token_quant_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
-from sglang.srt.managers.expert_location import get_global_expert_location_metadata
-from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
+    is_npu,
     set_weight_attrs,
 )
 
 _is_hip = is_hip()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
+if not _is_npu:
+    from sgl_kernel import silu_and_mul
+
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
@@ -1173,12 +1178,14 @@ class DeepEPMoE(EPMoE):
         masked_m: torch.Tensor,
         expected_m: int,
         num_recv_tokens_per_expert: List[int],
-
+        forward_batch: ForwardBatch,
     ):
         if _use_aiter:
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
                 return self.forward_deepgemm_contiguous(
@@ -1370,10 +1377,19 @@ class DeepEPMoE(EPMoE):
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-
-
-
-
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
@@ -1399,6 +1415,7 @@ class DeepEPMoE(EPMoE):
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)
 
@@ -1407,7 +1424,8 @@ class DeepEPMoE(EPMoE):
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
@@ -1428,10 +1446,15 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input,
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del down_input
-
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
```
sglang/srt/layers/moe/ep_moe/token_dispatcher.py
CHANGED
```diff
@@ -1,10 +1,8 @@
 import logging
 from dataclasses import dataclass
 
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.managers.expert_distribution import (
-    get_global_expert_distribution_recorder,
-)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     DeepEPMode,
@@ -36,7 +34,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
 )
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 
@@ -246,7 +244,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
@@ -682,21 +686,21 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
-        inner_state = self._get_impl(
+        inner_state = self._get_impl(forward_batch).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._dispatch_intermediate_state =
+        self._dispatch_intermediate_state = forward_batch, inner_state
 
     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
-
+        forward_batch, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
-        return self._get_impl(
+        return self._get_impl(forward_batch).dispatch_b(*inner_state)
 
     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
@@ -708,24 +712,26 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ):
         self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
-        inner_state = self._get_impl(
+        inner_state = self._get_impl(forward_batch).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
         )
-        self._combine_intermediate_state =
+        self._combine_intermediate_state = forward_batch, inner_state
 
     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
-
+        forward_batch, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
-        return self._get_impl(
+        return self._get_impl(forward_batch).combine_b(*inner_state)
 
-    def _get_impl(self,
-        resolved_deepep_mode = self.deepep_mode.resolve(
+    def _get_impl(self, forward_batch: ForwardBatch) -> _DeepEPDispatcherImplBase:
+        resolved_deepep_mode = self.deepep_mode.resolve(
+            forward_batch.is_extend_in_batch
+        )
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
         elif resolved_deepep_mode == DeepEPMode.low_latency:
```
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
```diff
@@ -12,7 +12,6 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import ceil_div
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
@@ -25,6 +24,7 @@ from sglang.srt.layers.quantization.int8_kernel import (
     sglang_per_token_group_quant_int8,
 )
 from sglang.srt.utils import (
+    ceil_div,
     cpu_has_amx_support,
     direct_register_custom_op,
     get_bool_env_var,
@@ -32,7 +32,6 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
-    log_info_on_rank0,
     next_power_of_2,
 )
 
```
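This small hunk group (+1/-2, consistent with sglang/srt/layers/moe/fused_moe_triton/fused_moe.py in the file list) just re-sources `ceil_div` from `sglang.srt.utils` instead of `sglang.math_utils` (which shrinks by 8 lines in this release). A minimal stand-in, assuming the standard ceiling-division definition; the real helper may differ in details.

```python
def ceil_div(a: int, b: int) -> int:
    # Assumed semantics of the helper now imported from sglang.srt.utils.
    return (a + b - 1) // b


assert ceil_div(512 // 128, 4) == 1  # the shape arithmetic used by the UE8M0 scale buffer above
assert ceil_div(9, 4) == 3
```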
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
```diff
@@ -12,19 +12,21 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
     is_hip,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 if torch.cuda.is_available():
@@ -129,7 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         # Pack weight for get better performance on CPU
         if _is_cpu and _is_cpu_amx_available:
-
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
 
             return
 
@@ -264,10 +266,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."
 
-        if (
-            getattr(layer, "use_intel_amx_backend", False)
-            and not apply_router_weight_on_input
-        ):
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
@@ -291,7 +290,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                     torch.float
                 ),  # TODO: the topk_weights of llama4 is computed via Llama4MoE:custom_routing_function and is bfloat16 while the kernel requires it to be float32
                 topk_ids,
-
+                False,  # inplace # See [Note] inplace should be False in fused_experts.
                 False,  # use_int8_w8a8
                 False,  # use_fp8_w8a16
                 None,  # w1_scale
@@ -321,6 +320,44 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             routed_scaling_factor,
         )
 
+    def forward_npu(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        use_grouped_topk: bool,
+        top_k: int,
+        router_logits: torch.Tensor,
+        renormalize: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        return moe_forward_native(
+            layer,
+            x,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            num_fused_shared_experts,
+            custom_routing_function,
+            correction_bias,
+            activation,
+            apply_router_weight_on_input,
+            inplace,
+            no_combine,
+            routed_scaling_factor,
+        )
+
     def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("The TPU backend currently does not support MoE.")
 
@@ -537,11 +574,6 @@ class FusedMoE(torch.nn.Module):
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
 
-        if not self.use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                shard_dim, shard_size * tp_rank, shard_size
-            )
-
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
@@ -552,7 +584,24 @@ class FusedMoE(torch.nn.Module):
             start = shard_size
         else:
             start = 0
-
+
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                start,
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
+            )
+        else:
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
+
+            expert_data = expert_data.narrow(shard_dim, start, shard_size)
         expert_data.copy_(loaded_weight)
 
     def _load_w2(
@@ -569,10 +618,21 @@ class FusedMoE(torch.nn.Module):
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
 
-        if
-            loaded_weight =
-
+        if _is_cpu:
+            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                expert_data,
+                loaded_weight,
+                0,  # param_data_start
+                shard_size * tp_rank,
+                shard_dim,
+                shard_size,
+                not self.use_presharded_weights,
             )
+        else:
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    shard_dim, shard_size * tp_rank, shard_size
+                )
 
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
```
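Besides the Intel AMX weight packing and the `use_intel_amx_backend(layer)` helper, the `FusedMoE._load_w13` / `_load_w2` hunks above (sglang/srt/layers/moe/fused_moe_triton/layer.py, +76/-16 in the file list) add a CPU path that narrows both the possibly padded parameter and the checkpoint weight via `narrow_padded_param_and_loaded_weight`. The sketch below only illustrates the idea behind that call with a plain stand-in; the real helper lives in `sglang.srt.model_loader.weight_utils` and its exact signature and padding handling may differ.

```python
import torch


def narrow_padded(param, loaded, param_start, weight_start, dim, shard_size, narrow_weight):
    """Hedged stand-in: narrow param and (optionally) loaded weight to the same shard."""
    if narrow_weight:
        loaded = loaded.narrow(dim, weight_start, shard_size)
    # narrow() returns a view, so copying into it writes into the original parameter.
    param = param.narrow(dim, param_start, shard_size)
    return param, loaded


# Usage: copy one tensor-parallel shard of an unsharded checkpoint weight
# into a (possibly padded) expert parameter.
expert_param = torch.empty(256, 64)   # padded parameter, shard dim 0
checkpoint_w = torch.randn(512, 64)   # full, unsharded checkpoint weight
tp_rank, shard_size = 1, 128
p, w = narrow_padded(expert_param, checkpoint_w, 0, shard_size * tp_rank, 0, shard_size, True)
p.copy_(w)
print(torch.equal(expert_param[:128], checkpoint_w[128:256]))  # True
```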
sglang/srt/layers/moe/router.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import triton
@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
     topk_ids_ptr,  # output (bs, topk)
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     num_experts: tl.constexpr,
     topk: tl.constexpr,
     moe_softcapping: tl.constexpr,
@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
     bottom = exped + 1
     logits_softcapped = top / bottom * moe_softcapping
 
+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
+        logits_softcapped = logits_softcapped + bias
+
     # topk
     # assert 1 <= topk <= num_experts
 
@@ -109,6 +116,7 @@ def fused_moe_router_impl(
     router_weight: torch.Tensor,
     topk: int,
     moe_softcapping: float,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -117,23 +125,23 @@ def fused_moe_router_impl(
     # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None
 
-
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)),
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
-    fused_moe_router_kernel[
+    fused_moe_router_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
         topk_ids,
+        correction_bias,
+        is_correction_bias=is_correction_bias,
         num_experts=num_experts,
         topk=topk,
         moe_softcapping=moe_softcapping,
@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
     topk_ids_ptr,  # output (bs, topk)
     bs,
     num_experts: tl.constexpr,
-    topk: tl.constexpr,  # only support topk
+    topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
     K: tl.constexpr,
@@ -204,25 +212,53 @@ def fused_moe_router_large_bs_kernel(
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
     # 5. top1
-
-
+    arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
+    cond_top1 = arange_block_size_n < num_experts
+    top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
     top1_v = tl.max(
-        tl.where(
+        tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
     )
-
-    tl.where(
+    top1_invsumexp = 1.0 / tl.sum(
+        tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
     )
 
-    # 6. store to output
-
-
-    tl.store(topk_ids_ptr +
+    # 6. store top1 to output
+    offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
+    top1_mask = offs_top1 < bs * topk
+    tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
     tl.store(
-        topk_weights_ptr +
-
-        mask=
+        topk_weights_ptr + offs_top1,
+        top1_invsumexp,
+        mask=top1_mask,
     )
 
+    # 7. handle topk == 2
+    if topk == 2:
+        cond_top2 = (arange_block_size_n < num_experts) and (
+            arange_block_size_n != top1[:, None]
+        )
+        top2 = tl.argmax(
+            tl.where(cond_top2, logits_softcapped, float("-inf")),
+            axis=1,
+            keep_dims=True,
+        )
+        top2_v = tl.sum(
+            logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
+        )
+        top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]
+
+        # store top2
+        offs_top2 = (
+            pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
+        )
+        top2_mask = offs_top2 < bs * topk
+        tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
+        tl.store(
+            topk_weights_ptr + offs_top2,
+            top2_invsumexp,
+            mask=top2_mask,
+        )
+
 
 def fused_moe_router_large_bs_impl(
     x: torch.Tensor,
@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(
 
     assert num_experts <= BLOCK_SIZE_N
     assert hidden_dim % BLOCK_SIZE_K == 0
-    assert topk
+    assert topk <= 2
 
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
@@ -273,6 +309,7 @@ def fused_moe_router_shim(
     gating_output,
     topk,
     renormalize,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert not renormalize
     assert (
@@ -286,7 +323,7 @@ def fused_moe_router_shim(
     BLOCK_SIZE_K = 256
     if (
         bs >= 512
-        and topk
+        and topk <= 2
         and num_experts <= BLOCK_SIZE_N
         and hidden_dim % BLOCK_SIZE_K == 0
     ):
@@ -305,6 +342,7 @@ def fused_moe_router_shim(
         router_weight=gating_output,
         topk=topk,
         moe_softcapping=moe_softcapping,
+        correction_bias=correction_bias,
    )
 
 
```
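The router changes add an optional `correction_bias` that is applied to the soft-capped logits before top-k selection (small-batch kernel only, per the hunks above) and restore the explicit `topk <= 2` constraint on the large-batch kernel. The PyTorch snippet below is a hedged reference for just the selection math shown in the diff, not the Triton kernel itself: it assumes `exped = exp(2 * logits / cap)`, which reproduces `cap * tanh(logits / cap)`, and it deliberately leaves out the weight normalization that lives in the unchanged part of the kernel.

```python
import torch


def softcapped_topk_reference(logits, moe_softcapping, topk, correction_bias=None):
    """Hedged reference for the routing selection touched by the diff above."""
    exped = torch.exp(2 * logits / moe_softcapping)
    softcapped = (exped - 1) / (exped + 1) * moe_softcapping  # == cap * tanh(x / cap)
    if correction_bias is not None:
        # Bias is added after softcapping, matching the new is_correction_bias branch.
        softcapped = softcapped + correction_bias
    return torch.topk(softcapped, k=topk, dim=-1).indices


logits = torch.randn(4, 8)        # (bs, num_experts)
bias = torch.zeros(8)             # per-expert correction bias
print(softcapped_topk_reference(logits, moe_softcapping=30.0, topk=2, correction_bias=bias))
```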
sglang/srt/layers/moe/topk.py
CHANGED
```diff
@@ -18,12 +18,12 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
-from sglang.srt.
-from sglang.srt.
+from sglang.srt.eplb import expert_location_dispatch
+from sglang.srt.eplb.expert_distribution import (
     ExpertDistributionRecorder,
     get_global_expert_distribution_recorder,
 )
-from sglang.srt.
+from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
 )
 
 _is_cuda = is_cuda()
@@ -42,6 +43,7 @@ _is_hip = is_hip()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -106,37 +108,14 @@ def fused_topk(
         M, topk, dtype=torch.float32, device=hidden_states.device
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
 
     topk_softmax(
         topk_weights,
         topk_ids,
-
-
-    )
-    del token_expert_indicies
-
-    return _fused_topk_postprocess(
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        renormalize=renormalize,
-        expert_location_dispatch_info=expert_location_dispatch_info,
-        num_token_non_padded=num_token_non_padded,
+        gating_output,
+        renormalize,
     )
 
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def _fused_topk_postprocess(
-    topk_weights,
-    topk_ids,
-    renormalize,
-    expert_location_dispatch_info,
-    num_token_non_padded,
-):
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
     return topk_weights, topk_ids
@@ -159,6 +138,9 @@ def grouped_topk_gpu(
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     scores = torch.softmax(gating_output, dim=-1)
+    # NPU compiler limitation
+    if _is_npu and scores.dtype == torch.bfloat16:
+        scores = scores.to(torch.float16)
     num_token = scores.shape[0]
     num_experts = scores.shape[1]
     group_scores = (
```