sglang 0.5.0rc2__py3-none-any.whl → 0.5.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -0
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +375 -51
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/__init__.py
@@ -0,0 +1,31 @@
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.utils import (
+    DeepEPMode,
+    MoeA2ABackend,
+    MoeRunnerBackend,
+    get_deepep_config,
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_moe_runner_backend,
+    get_tbo_token_distribution_threshold,
+    initialize_moe_config,
+    is_tbo_enabled,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
+)
+
+__all__ = [
+    "DeepEPMode",
+    "MoeA2ABackend",
+    "MoeRunnerConfig",
+    "MoeRunnerBackend",
+    "initialize_moe_config",
+    "get_moe_a2a_backend",
+    "get_moe_runner_backend",
+    "get_deepep_mode",
+    "should_use_flashinfer_trtllm_moe",
+    "should_use_flashinfer_cutlass_moe_fp4_allgather",
+    "is_tbo_enabled",
+    "get_tbo_token_distribution_threshold",
+    "get_deepep_config",
+]
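The new package __init__ only re-exports the MoE configuration helpers, so callers can import from sglang.srt.layers.moe instead of reaching into sglang.srt.layers.moe.utils (the ep_moe/layer.py hunks below switch to exactly these imports). A minimal consumer sketch, assuming initialize_moe_config() has already populated the backend getters; the function below is illustrative and not part of the diff:

# Illustrative sketch only: names are the re-exports above; assumes the MoE
# config has been initialized so the backend getters return populated objects.
from sglang.srt.layers.moe import (
    get_deepep_mode,
    get_moe_a2a_backend,
    should_use_flashinfer_trtllm_moe,
)

def pick_moe_dispatch_path() -> str:
    # Mirrors the checks that get_moe_impl_class() performs in ep_moe/layer.py below.
    if get_moe_a2a_backend().is_deepep():
        return f"deepep ({get_deepep_mode()})"
    if should_use_flashinfer_trtllm_moe():
        return "flashinfer-trtllm"
    return "fused-moe-triton"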
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,11 +1,17 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Union

 import torch

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    get_moe_runner_backend,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
 from sglang.srt.layers.moe.topk import TopKOutput
-from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8 import (
-    Fp8Config,
-    Fp8MoEMethod,
-    get_tile_tokens_dim,
-)
+from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
@@ -51,7 +52,6 @@ if not (_is_npu or _is_hip):
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
-    from aiter.ops.shuffle import shuffle_weight

 logger = logging.getLogger(__name__)

@@ -89,12 +89,11 @@ class EPMoE(FusedMoE):
         num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-
-
+        gemm1_alpha: Optional[float] = None,
+        gemm1_clamp_limit: Optional[float] = None,
         with_bias: bool = False,
     ):
         super().__init__(
@@ -106,13 +105,12 @@ class EPMoE(FusedMoE):
             top_k=top_k,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=prefix,
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-
-
+            gemm1_alpha=gemm1_alpha,
+            gemm1_clamp_limit=gemm1_clamp_limit,
             with_bias=with_bias,
         )

@@ -163,7 +161,8 @@ class EPMoE(FusedMoE):
         )

         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
+
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
@@ -327,8 +326,8 @@ class EPMoE(FusedMoE):
             m_max * self.start_expert_id,
             BLOCK_SIZE=512,
         )
-        if self.routed_scaling_factor is not None:
-            output *= self.routed_scaling_factor
+        if self.moe_runner_config.routed_scaling_factor is not None:
+            output *= self.moe_runner_config.routed_scaling_factor
         return output


@@ -349,11 +348,9 @@ class DeepEPMoE(EPMoE):
         num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
-        tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -364,12 +361,11 @@ class DeepEPMoE(EPMoE):
             num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=prefix,
             activation=activation,
             routed_scaling_factor=routed_scaling_factor,
         )
-        self.deepep_mode =
+        self.deepep_mode = get_deepep_mode()

         # TODO: move to the beginning of the file
         from sglang.srt.distributed.parallel_state import get_tp_group
@@ -383,7 +379,7 @@ class DeepEPMoE(EPMoE):
             num_local_experts=self.num_local_experts,
             hidden_size=hidden_size,
             params_dtype=params_dtype,
-            deepep_mode=deepep_mode,
+            deepep_mode=self.deepep_mode,
             async_finish=True,  # TODO
             return_recv_hook=True,
         )
@@ -458,15 +454,19 @@ class DeepEPMoE(EPMoE):
         )

     def moe_impl(self, dispatch_output: DispatchOutput):
+        from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
+
         if _use_aiter:
+            assert DispatchOutputChecker.format_is_deepep(dispatch_output)
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if _is_npu:
+            assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
             return self.forward_npu(dispatch_output)
-        if
+        if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_contiguous(dispatch_output)
-        elif
+        elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
             assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
@@ -490,7 +490,7 @@ class DeepEPMoE(EPMoE):

     def forward_aiter(
         self,
-        dispatch_output: DeepEPNormalOutput,
+        dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
     ):
         hidden_states, topk_idx, topk_weights = (
             dispatch_output.hidden_states,
@@ -516,7 +516,7 @@ class DeepEPMoE(EPMoE):
             quant_type=QuantType.per_128x128,
             activation=(
                 ActivationType.Silu
-                if self.activation == "silu"
+                if self.moe_runner_config.activation == "silu"
                 else ActivationType.Gelu
             ),
             expert_mask=self.expert_mask,
@@ -531,7 +531,7 @@ class DeepEPMoE(EPMoE):
         )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"
         if num_recv_tokens_per_expert is None:
             return hidden_states_fp8.bfloat16()
         all_tokens = sum(num_recv_tokens_per_expert)
@@ -652,7 +652,7 @@ class DeepEPMoE(EPMoE):
     ):
         hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"

         # GroupGemm-0
         num_groups, m, k = hidden_states_fp8[0].size()
@@ -735,7 +735,7 @@ class DeepEPMoE(EPMoE):
         assert isinstance(dispatch_output, AscendDeepEPLLOutput)
         hidden_states, topk_idx, topk_weights, _, seg_indptr, _ = dispatch_output
         assert self.quant_method is not None
-        assert self.activation == "silu"
+        assert self.moe_runner_config.activation == "silu"

         # NOTE: Ascend's Dispatch & Combine does not support FP16
         output_dtype = torch.bfloat16
@@ -782,13 +782,17 @@ class DeepEPMoE(EPMoE):
         return hidden_states


-def get_moe_impl_class():
-    if
+def get_moe_impl_class(quant_config: Optional[QuantizationConfig] = None):
+    if get_moe_a2a_backend().is_deepep():
         return DeepEPMoE

     # NEW: Direct FP4 detection (bypasses EP requirements)
     # Check for FP4 quantization with TRTLLM flag, regardless of EP
-    if
+    if get_moe_runner_backend().is_flashinfer_trtllm():
+        # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod.
+        # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead.
+        if quant_config is None:
+            return FusedMoE
         try:
             # Check the quantization argument directly
             quantization = global_server_args_dict.get("quantization")
@@ -803,7 +807,7 @@ def get_moe_impl_class():

     if should_use_flashinfer_trtllm_moe():
         return FlashInferFusedMoE
-    if
+    if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
     if get_moe_expert_parallel_world_size() > 1:
         return EPMoE
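These hunks replace per-layer attributes (self.activation, self.routed_scaling_factor, and the removed tp_size/deepep_mode arguments) with reads from a shared moe_runner_config. The real MoeRunnerConfig is defined under sglang/srt/layers/moe/moe_runner/ (also touched in this release) and is not shown here; the dataclass below is a hypothetical sketch limited to the fields these hunks actually read:

# Hypothetical sketch only: the real MoeRunnerConfig may carry more fields.
# Shown here are just the fields the hunks above access via moe_runner_config.<field>.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfigSketch:
    activation: str = "silu"                       # checked by the "silu" asserts
    apply_router_weight_on_input: bool = False     # guarded in the fused_moe_native.py hunks below
    routed_scaling_factor: Optional[float] = None  # applied to the combined output


def scale_output(output, cfg: MoeRunnerConfigSketch):
    # Mirrors the EPMoE hunk above: scale only when a routed scaling factor is set.
    if cfg.routed_scaling_factor is not None:
        output *= cfg.routed_scaling_factor
    return output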
sglang/srt/layers/moe/fused_moe_native.py
@@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile.
 It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
 """

-from typing import Callable, Optional
-
 import torch
 from torch.nn import functional as F

 from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
-from sglang.srt.layers.moe.
+from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+from sglang.srt.layers.moe.topk import StandardTopKOutput


 def fused_moe_forward_native(
     layer: torch.nn.Module,
     x: torch.Tensor,
-    topk_output:
-
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    inplace: bool = True,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig,
 ) -> torch.Tensor:

-    if apply_router_weight_on_input:
+    if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()

     topk_weights, topk_ids, _ = topk_output
@@ -33,12 +27,12 @@ def fused_moe_forward_native(
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
     x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    if activation == "silu":
+    if moe_runner_config.activation == "silu":
         x1 = F.silu(x1)
-    elif activation == "gelu":
+    elif moe_runner_config.activation == "gelu":
         x1 = F.gelu(x1)
     else:
-        raise ValueError(f"Unsupported activation: {activation=}")
+        raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -47,16 +41,11 @@
 def moe_forward_native(
     layer: torch.nn.Module,
     x: torch.Tensor,
-    topk_output:
-
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    inplace: bool = True,
-    no_combine: bool = False,
-    routed_scaling_factor: Optional[float] = None,
+    topk_output: StandardTopKOutput,
+    moe_runner_config: MoeRunnerConfig,
 ) -> torch.Tensor:

-    if apply_router_weight_on_input:
+    if moe_runner_config.apply_router_weight_on_input:
         raise NotImplementedError()

     topk_weights, topk_ids, _ = topk_output
@@ -72,12 +61,12 @@ def moe_forward_native(
     sorted_tokens = x[idxs // topk_ids.shape[1]]
     tokens_per_expert = tokens_per_expert.cpu().numpy()

-    if activation == "silu":
+    if moe_runner_config.activation == "silu":
         act = SiluAndMul()
-    elif activation == "gelu":
+    elif moe_runner_config.activation == "gelu":
         act = GeluAndMul()
     else:
-        raise ValueError(f"Unsupported activation: {activation=}")
+        raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")

     outputs = []
     start_idx = 0
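After this change the torch-native entry points take only (layer, x, topk_output, moe_runner_config). A smoke-test sketch of the new calling convention, using stand-in layer/config objects (not sglang classes) whose shapes follow the einsums in the hunk above; it assumes sglang is importable:

# Illustrative smoke test of fused_moe_forward_native's new signature.
# The SimpleNamespace layer/config objects are stand-ins, not sglang classes.
from types import SimpleNamespace

import torch

from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native

E, H, INTER, T, K = 8, 64, 128, 4, 2  # experts, hidden, intermediate, tokens, top-k
layer = SimpleNamespace(
    w13_weight=torch.randn(E, 2 * INTER, H),  # fused gate+up projection per expert
    w2_weight=torch.randn(E, H, INTER),       # down projection per expert
)
cfg = SimpleNamespace(activation="silu", apply_router_weight_on_input=False)

x = torch.randn(T, H)
topk_ids = torch.randint(0, E, (T, K))
topk_weights = torch.softmax(torch.randn(T, K), dim=-1)

# topk_output is unpacked as (topk_weights, topk_ids, _) inside the function.
out = fused_moe_forward_native(layer, x, (topk_weights, topk_ids, None), cfg)
assert out.shape == (T, H)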
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    }
+}
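The two new JSON files above are Triton fused-MoE tuning tables (two of the four +146-line configs listed at the top): each top-level key is a token-batch-size bucket and each value is a Triton launch configuration (block sizes, group size, warps, stages). The snippet below is an illustrative sketch of the usual lookup pattern, not sglang's verbatim selection code, which lives in fused_moe_triton/fused_moe.py:

# Sketch only: load a tuned config table and pick the bucket closest to the
# actual token count M. The nearest-key fallback is an assumption for illustration.
import json

def load_tuned_config(path: str, m: int) -> dict:
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    nearest_m = min(configs, key=lambda k: abs(k - m))  # fall back to the closest bucket
    return configs[nearest_m]

# e.g. load_tuned_config(
#     "E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json",
#     m=384,
# ) -> {"BLOCK_SIZE_M": ..., "BLOCK_SIZE_N": ..., "BLOCK_SIZE_K": ..., ...}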